source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 8e2cd79

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 8e2cd79 was 8e2cd79, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

Copy/paste fitting parameters SASVIEW-933

  • Property mode set to 100644
File size: 20.8 KB
RevLine 
[fde5bcd]1import copy
[1bc27f1]2
[4992ff2]3from PyQt5 import QtCore
4from PyQt5 import QtGui
[4d457df]5
[6fd4e36]6import numpy
7
[dc5ef15]8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
[6fd4e36]10
[f54ce30]11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
[d0dfcb2]14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
[f54ce30]18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
[d0dfcb2]20                        'Function', 'Filename']
[f54ce30]21
22poly_header_tooltips = ['Select parameter for fitting',
[d0dfcb2]23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
[f54ce30]31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
[4d457df]35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
[b3e8629]46    return list([par for par in model.iq_parameters if "[" in par.name])
[4d457df]47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
[a9b568c]54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
[4d457df]62
[aca8418]63def addParametersToModel(parameters, kernel_module, is2D):
[4d457df]64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
[57be490]69
[b5cc06e]70    if is2D:
71        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
72    else:
73        params = parameters.iq_parameters
[1970780]74    item = []
75    for param in params:
[4d457df]76        # don't include shell parameters
77        if param.name == multishell_param_name:
78            continue
79        # Modify parameter name from <param>[n] to <param>1
80        item_name = param.name
81        if param in multishell_parameters:
[b1e36a3]82            continue
83        #    item_name = replaceShellName(param.name, 1)
[4d457df]84
85        item1 = QtGui.QStandardItem(item_name)
86        item1.setCheckable(True)
[2add354]87        item1.setEditable(False)
[1bc27f1]88        # item_err = QtGui.QStandardItem()
[4d457df]89        # check for polydisp params
90        if param.polydisperse:
91            poly_item = QtGui.QStandardItem("Polydispersity")
[2add354]92            poly_item.setEditable(False)
[4d457df]93            item1_1 = QtGui.QStandardItem("Distribution")
[2add354]94            item1_1.setEditable(False)
[4d457df]95            # Find param in volume_params
96            for p in parameters.form_volume_parameters:
97                if p.name != param.name:
98                    continue
[aca8418]99                width = kernel_module.getParam(p.name+'.width')
[8e2cd79]100                ptype = kernel_module.getParam(p.name+'.type')
[aca8418]101
102                item1_2 = QtGui.QStandardItem(str(width))
[2add354]103                item1_2.setEditable(False)
[aca8418]104                item1_3 = QtGui.QStandardItem()
[2add354]105                item1_3.setEditable(False)
[aca8418]106                item1_4 = QtGui.QStandardItem()
[2add354]107                item1_4.setEditable(False)
[8e2cd79]108                item1_5 = QtGui.QStandardItem(ptype)
[2add354]109                item1_5.setEditable(False)
[4d457df]110                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
111                break
112            # Add the polydisp item as a child
113            item1.appendRow([poly_item])
114        # Param values
115        item2 = QtGui.QStandardItem(str(param.default))
116        # TODO: the error column.
117        # Either add a proxy model or a custom view delegate
118        #item_err = QtGui.QStandardItem()
119        item3 = QtGui.QStandardItem(str(param.limits[0]))
120        item4 = QtGui.QStandardItem(str(param.limits[1]))
121        item5 = QtGui.QStandardItem(param.units)
[2add354]122        item5.setEditable(False)
[1970780]123        item.append([item1, item2, item3, item4, item5])
124    return item
[4d457df]125
[1970780]126def addSimpleParametersToModel(parameters, is2D):
[4d457df]127    """
128    Update local ModelModel with sasmodel parameters
129    """
[b5cc06e]130    if is2D:
131        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
132    else:
133        params = parameters.iq_parameters
[1970780]134    item = []
135    for param in params:
[7248d75d]136        # Create the top level, checkable item
[4d457df]137        item_name = param.name
138        item1 = QtGui.QStandardItem(item_name)
139        item1.setCheckable(True)
[2add354]140        item1.setEditable(False)
[4d457df]141        # Param values
[2add354]142        # TODO: add delegate for validation of cells
[4d457df]143        item2 = QtGui.QStandardItem(str(param.default))
[f182f93]144        item4 = QtGui.QStandardItem(str(param.limits[0]))
145        item5 = QtGui.QStandardItem(str(param.limits[1]))
146        item6 = QtGui.QStandardItem(param.units)
[2add354]147        item6.setEditable(False)
[1970780]148        item.append([item1, item2, item4, item5, item6])
149    return item
[4d457df]150
151def addCheckedListToModel(model, param_list):
152    """
153    Add a QItem to model. Makes the QItem checkable
154    """
155    assert isinstance(model, QtGui.QStandardItemModel)
156    item_list = [QtGui.QStandardItem(item) for item in param_list]
157    item_list[0].setCheckable(True)
158    model.appendRow(item_list)
159
160def addHeadersToModel(model):
161    """
162    Adds predefined headers to the model
163    """
[f54ce30]164    for i, item in enumerate(model_header_captions):
[b3e8629]165        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
[f54ce30]166
[fde5bcd]167    model.header_tooltips = copy.copy(model_header_tooltips)
[4d457df]168
[f182f93]169def addErrorHeadersToModel(model):
170    """
171    Adds predefined headers to the model
172    """
[fde5bcd]173    model_header_error_captions = copy.copy(model_header_captions)
[f54ce30]174    model_header_error_captions.insert(2, header_error_caption)
175    for i, item in enumerate(model_header_error_captions):
[b3e8629]176        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
[f182f93]177
[fde5bcd]178    model_header_error_tooltips = copy.copy(model_header_tooltips)
[f54ce30]179    model_header_error_tooltips.insert(2, error_tooltip)
[fde5bcd]180    model.header_tooltips = copy.copy(model_header_error_tooltips)
[a95c44b]181
[4d457df]182def addPolyHeadersToModel(model):
183    """
184    Adds predefined headers to the model
185    """
[f54ce30]186    for i, item in enumerate(poly_header_captions):
[b3e8629]187        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
[f54ce30]188
[fde5bcd]189    model.header_tooltips = copy.copy(poly_header_tooltips)
[4d457df]190
[a95c44b]191
[aca8418]192def addErrorPolyHeadersToModel(model):
193    """
194    Adds predefined headers to the model
195    """
[fde5bcd]196    poly_header_error_captions = copy.copy(poly_header_captions)
[f54ce30]197    poly_header_error_captions.insert(2, header_error_caption)
198    for i, item in enumerate(poly_header_error_captions):
[b3e8629]199        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
[f54ce30]200
[fde5bcd]201    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
[f54ce30]202    poly_header_error_tooltips.insert(2, error_tooltip)
[fde5bcd]203    model.header_tooltips = copy.copy(poly_header_error_tooltips)
[a95c44b]204
[4d457df]205def addShellsToModel(parameters, model, index):
206    """
207    Find out multishell parameters and update the model with the requested number of them
208    """
209    multishell_parameters = getIterParams(parameters)
210
[b3e8629]211    for i in range(index):
[4d457df]212        for par in multishell_parameters:
[b1e36a3]213            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
214            param_name = replaceShellName(par.name, i+1)
[4d457df]215            item1 = QtGui.QStandardItem(param_name)
216            item1.setCheckable(True)
217            # check for polydisp params
218            if par.polydisperse:
219                poly_item = QtGui.QStandardItem("Polydispersity")
220                item1_1 = QtGui.QStandardItem("Distribution")
221                # Find param in volume_params
222                for p in parameters.form_volume_parameters:
223                    if p.name != par.name:
224                        continue
225                    item1_2 = QtGui.QStandardItem(str(p.default))
226                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
227                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
228                    item1_5 = QtGui.QStandardItem(p.units)
229                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
230                    break
231                item1.appendRow([poly_item])
232
233            item2 = QtGui.QStandardItem(str(par.default))
234            item3 = QtGui.QStandardItem(str(par.limits[0]))
235            item4 = QtGui.QStandardItem(str(par.limits[1]))
236            item5 = QtGui.QStandardItem(par.units)
237            model.appendRow([item1, item2, item3, item4, item5])
238
[6fd4e36]239def calculateChi2(reference_data, current_data):
240    """
241    Calculate Chi2 value between two sets of data
242    """
243
244    # WEIGHING INPUT
245    #from sas.sasgui.perspectives.fitting.utils import get_weight
246    #flag = self.get_weight_flag()
247    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
[1bc27f1]248    chisqr = None
249    if reference_data is None:
250        return chisqr
[6fd4e36]251
252    # temporary default values for index and weight
253    index = None
254    weight = None
255
256    # Get data: data I, theory I, and data dI in order
257    if isinstance(reference_data, Data2D):
[1bc27f1]258        if index is None:
[6fd4e36]259            index = numpy.ones(len(current_data.data), dtype=bool)
[1bc27f1]260        if weight is not None:
[6fd4e36]261            current_data.err_data = weight
262        # get rid of zero error points
263        index = index & (current_data.err_data != 0)
264        index = index & (numpy.isfinite(current_data.data))
265        fn = current_data.data[index]
266        gn = reference_data.data[index]
267        en = current_data.err_data[index]
268    else:
269        # 1 d theory from model_thread is only in the range of index
[1bc27f1]270        if index is None:
[6fd4e36]271            index = numpy.ones(len(current_data.y), dtype=bool)
[1bc27f1]272        if weight is not None:
[6fd4e36]273            current_data.dy = weight
[1bc27f1]274        if current_data.dy is None or current_data.dy == []:
[6fd4e36]275            dy = numpy.ones(len(current_data.y))
276        else:
277            ## Set consistently w/AbstractFitengine:
278            # But this should be corrected later.
[fde5bcd]279            dy = copy.deepcopy(current_data.dy)
[6fd4e36]280            dy[dy == 0] = 1
281        fn = current_data.y[index]
282        gn = reference_data.y
283        en = dy[index]
284    # Calculate the residual
285    try:
286        res = (fn - gn) / en
287    except ValueError:
[180bd54]288        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
[0268aed]289        return None
[6fd4e36]290
291    residuals = res[numpy.isfinite(res)]
292    chisqr = numpy.average(residuals * residuals)
293
294    return chisqr
295
[0268aed]296def residualsData1D(reference_data, current_data):
297    """
[7d077d1]298    Calculate the residuals for difference of two Data1D sets
[0268aed]299    """
300    # temporary default values for index and weight
301    index = None
302    weight = None
303
304    # 1d theory from model_thread is only in the range of index
[180bd54]305    if current_data.dy is None or current_data.dy == []:
[0268aed]306        dy = numpy.ones(len(current_data.y))
307    else:
[180bd54]308        dy = weight if weight is not None else numpy.ones(len(current_data.y))
[0268aed]309        dy[dy == 0] = 1
310    fn = current_data.y[index][0]
311    gn = reference_data.y
312    en = dy[index][0]
313    # build residuals
314    residuals = Data1D()
[180bd54]315    if len(fn) == len(gn):
[0268aed]316        y = (fn - gn)/en
317        residuals.y = -y
[180bd54]318    else:
[d48cc19]319        # TODO: fix case where applying new data from file on top of existing model data
[689222c]320        try:
321            y = (fn - gn[index][0]) / en
322            residuals.y = y
323        except ValueError:
324            # value errors may show up every once in a while for malformed columns,
325            # just reuse what's there already
326            pass
[180bd54]327
[0268aed]328    residuals.x = current_data.x[index][0]
329    residuals.dy = numpy.ones(len(residuals.y))
330    residuals.dx = None
331    residuals.dxl = None
332    residuals.dxw = None
333    residuals.ytransform = 'y'
[1bc27f1]334    # For latter scale changes
[0268aed]335    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
336    residuals.yaxis('\\rm{Residuals} ', 'normalized')
337
338    return residuals
339
340def residualsData2D(reference_data, current_data):
341    """
[7d077d1]342    Calculate the residuals for difference of two Data2D sets
[0268aed]343    """
344    # temporary default values for index and weight
[1bc27f1]345    # index = None
[0268aed]346    weight = None
347
348    # build residuals
349    residuals = Data2D()
350    # Not for trunk the line below, instead use the line above
351    current_data.clone_without_data(len(current_data.data), residuals)
352    residuals.data = None
353    fn = current_data.data
354    gn = reference_data.data
[180bd54]355    en = current_data.err_data if weight is None else weight
[0268aed]356    residuals.data = (fn - gn) / en
357    residuals.qx_data = current_data.qx_data
358    residuals.qy_data = current_data.qy_data
359    residuals.q_data = current_data.q_data
360    residuals.err_data = numpy.ones(len(residuals.data))
361    residuals.xmin = min(residuals.qx_data)
362    residuals.xmax = max(residuals.qx_data)
363    residuals.ymin = min(residuals.qy_data)
364    residuals.ymax = max(residuals.qy_data)
365    residuals.q_data = current_data.q_data
366    residuals.mask = current_data.mask
367    residuals.scale = 'linear'
368    # check the lengths
369    if len(residuals.data) != len(residuals.q_data):
370        return None
371    return residuals
372
373def plotResiduals(reference_data, current_data):
374    """
375    Create Data1D/Data2D with residuals, ready for plotting
376    """
[fde5bcd]377    data_copy = copy.deepcopy(current_data)
[0268aed]378    # Get data: data I, theory I, and data dI in order
379    method_name = current_data.__class__.__name__
380    residuals_dict = {"Data1D": residualsData1D,
381                      "Data2D": residualsData2D}
382
383    residuals = residuals_dict[method_name](reference_data, data_copy)
384
385    theory_name = str(current_data.name.split()[0])
386    residuals.name = "Residuals for " + str(theory_name) + "[" + \
387                    str(reference_data.filename) + "]"
388    residuals.title = residuals.name
[f182f93]389    residuals.ytransform = 'y'
390
[0268aed]391    # when 2 data have the same id override the 1 st plotted
392    # include the last part if keeping charts for separate models is required
393    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
394    # group_id specify on which panel to plot this data
395    group_id = reference_data.group_id
396    residuals.group_id = "res" + str(group_id)
[1bc27f1]397
[0268aed]398    # Symbol
399    residuals.symbol = 0
400    residuals.hide_error = False
401
402    return residuals
403
[6fd4e36]404def binary_encode(i, digits):
[b3e8629]405    return [i >> d & 1 for d in range(digits)]
[6fd4e36]406
[fd1ae6d1]407def getWeight(data, is2d, flag=None):
408    """
409    Received flag and compute error on data.
410    :param flag: flag to transform error of data.
411    """
412    weight = None
413    if is2d:
414        dy_data = data.err_data
415        data = data.data
416    else:
417        dy_data = data.dy
418        data = data.y
419
420    if flag == 0:
421        weight = numpy.ones_like(data)
422    elif flag == 1:
423        weight = dy_data
424    elif flag == 2:
425        weight = numpy.sqrt(numpy.abs(data))
426    elif flag == 3:
427        weight = numpy.abs(data)
428    return weight
[d4dac80]429
430def updateKernelWithResults(kernel, results):
431    """
432    Takes model kernel and applies results dict to its parameters,
433    returning the modified (deep) copy of the kernel.
434    """
[8e2cd79]435    assert isinstance(results, dict)
[d4dac80]436    local_kernel = copy.deepcopy(kernel)
437
438    for parameter in results.keys():
439        # Update the parameter value - note: this supports +/-inf as well
440        local_kernel.setParam(parameter, results[parameter][0])
441
442    return local_kernel
443
444
[57be490]445def getStandardParam(model=None):
446    """
447    Returns a list with standard parameters for the current model
448    """
449    param = []
450    num_rows = model.rowCount()
451    if num_rows < 1:
452        return None
453
454    for row in range(num_rows):
455        param_name = model.item(row, 0).text()
[8e2cd79]456        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
457        value = model.item(row, 1).text()
[57be490]458        column_shift = 0
459        if model.columnCount() == 5: # no error column
460            error_state = False
461            error_value = 0.0
462        else:
463            error_state = True
464            error_value = model.item(row, 2).text()
465            column_shift = 1
466        min_state = True
467        max_state = True
468        min_value = model.item(row, 2+column_shift).text()
469        max_value = model.item(row, 3+column_shift).text()
470        unit = ""
471        if model.item(row, 4+column_shift) is not None:
472            unit = model.item(row, 4+column_shift).text()
473
474        param.append([checkbox_state, param_name, value, "",
475                        [error_state, error_value],
476                        [min_state, min_value],
477                        [max_state, max_value], unit])
478
479    return param
480
481def getOrientationParam(kernel_module=None):
482    """
483    Get the dictionary with orientation parameters
484    """
485    param = []
[8e2cd79]486    if kernel_module is None:
[57be490]487        return None
488    for param_name in list(kernel_module.params.keys()):
489        name = param_name
490        value = kernel_module.params[param_name]
491        min_state = True
492        max_state = True
493        error_state = False
494        error_value = 0.0
495        checkbox_state = True #??
496        details = kernel_module.details[param_name] #[unit, mix, max]
497        param.append([checkbox_state, name, value, "",
498                     [error_state, error_value],
499                     [min_state, details[1]],
500                     [max_state, details[2]], details[0]])
501
502    return param
[8e2cd79]503
504def formatParameters(parameters):
505    """
506    Prepare the parameter string in the standard SasView layout
507    """
508    assert parameters is not None
509    assert isinstance(parameters, list)
510    output_string = "sasview_parameter_values:"
511    for parameter in parameters:
512        output_string += ",".join([p for p in parameter if p is not None])
513        output_string += ":"
514    return output_string
515
516def formatParametersExcel(parameters):
517    """
518    Prepare the parameter string in the Excel format (tab delimited)
519    """
520    assert parameters is not None
521    assert isinstance(parameters, list)
522    crlf = chr(13) + chr(10)
523    tab = chr(9)
524
525    output_string = ""
526    # names
527    names = ""
528    values = ""
529    for parameter in parameters:
530        names += parameter[0]+tab
531        # Add the error column if fitted
532        if parameter[1] == "True" and parameter[3] is not None:
533            names += parameter[0]+"_err"+tab
534
535        values += parameter[2]+tab
536        if parameter[1] == "True" and parameter[3] is not None:
537            values += parameter[3]+tab
538        # add .npts and .nsigmas when necessary
539        if parameter[0][-6:] == ".width":
540            names += parameter[0].replace('.width', '.nsigmas') + tab
541            names += parameter[0].replace('.width', '.npts') + tab
542            values += parameter[5] + tab + parameter[4] + tab
543
544    output_string = names + crlf + values
545    return output_string
546
547def formatParametersLatex(parameters):
548    """
549    Prepare the parameter string in latex
550    """
551    assert parameters is not None
552    assert isinstance(parameters, list)
553    output_string = r'\begin{table}'
554    output_string += r'\begin{tabular}[h]'
555
556    crlf = chr(13) + chr(10)
557    output_string += '{|'
558    output_string += 'l|l|'*len(parameters)
559    output_string += r'}\hline'
560    output_string += crlf
561
562    for index, parameter in enumerate(parameters):
563        name = parameter[0] # Parameter name
564        output_string += name.replace('_', r'\_')  # Escape underscores
565        # Add the error column if fitted
566        if parameter[1] == "True" and parameter[3] is not None:
567            output_string += ' & '
568            output_string += parameter[0]+r'\_err'
569
570        if index < len(parameters) - 1:
571            output_string += ' & '
572
573        # add .npts and .nsigmas when necessary
574        if parameter[0][-6:] == ".width":
575            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
576            output_string += parameter[0].replace('.width', '.npts')
577
578            if index < len(parameters) - 1:
579                output_string += ' & '
580
581    output_string += r'\\ \hline'
582    output_string += crlf
583
584    # Construct row of values and errors
585    for index, parameter in enumerate(parameters):
586        output_string += parameter[2]
587        if parameter[1] == "True" and parameter[3] is not None:
588            output_string += ' & '
589            output_string += parameter[3]
590
591        if index < len(parameters) - 1:
592            output_string += ' & '
593
594        # add .npts and .nsigmas when necessary
595        if parameter[0][-6:] == ".width":
596            output_string += parameter[5] + ' & '
597            output_string += parameter[4]
598
599            if index < len(parameters) - 1:
600                output_string += ' & '
601
602    output_string += r'\\ \hline'
603    output_string += crlf
604    output_string += r'\end{tabular}'
605    output_string += r'\end{table}'
606
607    return output_string
Note: See TracBrowser for help on using the repository browser.