source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ dcabba7

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since dcabba7 was b764ae5, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

processEvents() helps with proper chart generation. - SASVIEW-890
Fixed weighing in fitting - SASVIEW-1017
Fixed error bars after fitting - SASVIEW-1004

  • Property mode set to 100644
File size: 22.4 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list([par for par in model.iq_parameters if "[" in par.name])
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69
70    if is2D:
71        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
72    else:
73        params = parameters.iq_parameters
74    item = []
75    for param in params:
76        # don't include shell parameters
77        if param.name == multishell_param_name:
78            continue
79        # Modify parameter name from <param>[n] to <param>1
80        item_name = param.name
81        if param in multishell_parameters:
82            continue
83        #    item_name = replaceShellName(param.name, 1)
84
85        item1 = QtGui.QStandardItem(item_name)
86        item1.setCheckable(True)
87        item1.setEditable(False)
88        # item_err = QtGui.QStandardItem()
89        # check for polydisp params
90        if param.polydisperse:
91            poly_item = QtGui.QStandardItem("Polydispersity")
92            poly_item.setEditable(False)
93            item1_1 = QtGui.QStandardItem("Distribution")
94            item1_1.setEditable(False)
95            # Find param in volume_params
96            for p in parameters.form_volume_parameters:
97                if p.name != param.name:
98                    continue
99                width = kernel_module.getParam(p.name+'.width')
100                ptype = kernel_module.getParam(p.name+'.type')
101
102                item1_2 = QtGui.QStandardItem(str(width))
103                item1_2.setEditable(False)
104                item1_3 = QtGui.QStandardItem()
105                item1_3.setEditable(False)
106                item1_4 = QtGui.QStandardItem()
107                item1_4.setEditable(False)
108                item1_5 = QtGui.QStandardItem(ptype)
109                item1_5.setEditable(False)
110                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
111                break
112            # Add the polydisp item as a child
113            item1.appendRow([poly_item])
114        # Param values
115        item2 = QtGui.QStandardItem(str(param.default))
116        # TODO: the error column.
117        # Either add a proxy model or a custom view delegate
118        #item_err = QtGui.QStandardItem()
119        item3 = QtGui.QStandardItem(str(param.limits[0]))
120        item4 = QtGui.QStandardItem(str(param.limits[1]))
121        item5 = QtGui.QStandardItem(param.units)
122        item5.setEditable(False)
123        item.append([item1, item2, item3, item4, item5])
124    return item
125
126def addSimpleParametersToModel(parameters, is2D):
127    """
128    Update local ModelModel with sasmodel parameters
129    """
130    if is2D:
131        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
132    else:
133        params = parameters.iq_parameters
134    item = []
135    for param in params:
136        # Create the top level, checkable item
137        item_name = param.name
138        item1 = QtGui.QStandardItem(item_name)
139        item1.setCheckable(True)
140        item1.setEditable(False)
141        # Param values
142        # TODO: add delegate for validation of cells
143        item2 = QtGui.QStandardItem(str(param.default))
144        item4 = QtGui.QStandardItem(str(param.limits[0]))
145        item5 = QtGui.QStandardItem(str(param.limits[1]))
146        item6 = QtGui.QStandardItem(param.units)
147        item6.setEditable(False)
148        item.append([item1, item2, item4, item5, item6])
149    return item
150
151def markParameterDisabled(model, row):
152    """Given the QModel row number, format to show it is not available for fitting"""
153
154    # If an error column is present, there are a total of 6 columns.
155    items = [model.item(row, c) for c in range(6)]
156
157    model.blockSignals(True)
158
159    for item in items:
160        if item is None:
161            continue
162        item.setEditable(False)
163        item.setCheckable(False)
164
165    item = items[0]
166
167    font = QtGui.QFont()
168    font.setItalic(True)
169    item.setFont(font)
170    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
171    item.setToolTip("This parameter cannot be fitted.")
172
173    model.blockSignals(False)
174
175def addCheckedListToModel(model, param_list):
176    """
177    Add a QItem to model. Makes the QItem checkable
178    """
179    assert isinstance(model, QtGui.QStandardItemModel)
180    item_list = [QtGui.QStandardItem(item) for item in param_list]
181    item_list[0].setCheckable(True)
182    model.appendRow(item_list)
183
184def addHeadersToModel(model):
185    """
186    Adds predefined headers to the model
187    """
188    for i, item in enumerate(model_header_captions):
189        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
190
191    model.header_tooltips = copy.copy(model_header_tooltips)
192
193def addErrorHeadersToModel(model):
194    """
195    Adds predefined headers to the model
196    """
197    model_header_error_captions = copy.copy(model_header_captions)
198    model_header_error_captions.insert(2, header_error_caption)
199    for i, item in enumerate(model_header_error_captions):
200        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
201
202    model_header_error_tooltips = copy.copy(model_header_tooltips)
203    model_header_error_tooltips.insert(2, error_tooltip)
204    model.header_tooltips = copy.copy(model_header_error_tooltips)
205
206def addPolyHeadersToModel(model):
207    """
208    Adds predefined headers to the model
209    """
210    for i, item in enumerate(poly_header_captions):
211        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
212
213    model.header_tooltips = copy.copy(poly_header_tooltips)
214
215
216def addErrorPolyHeadersToModel(model):
217    """
218    Adds predefined headers to the model
219    """
220    poly_header_error_captions = copy.copy(poly_header_captions)
221    poly_header_error_captions.insert(2, header_error_caption)
222    for i, item in enumerate(poly_header_error_captions):
223        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
224
225    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
226    poly_header_error_tooltips.insert(2, error_tooltip)
227    model.header_tooltips = copy.copy(poly_header_error_tooltips)
228
229def addShellsToModel(parameters, model, index):
230    """
231    Find out multishell parameters and update the model with the requested number of them
232    """
233    multishell_parameters = getIterParams(parameters)
234
235    for i in range(index):
236        for par in multishell_parameters:
237            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
238            param_name = replaceShellName(par.name, i+1)
239            item1 = QtGui.QStandardItem(param_name)
240            item1.setCheckable(True)
241            # check for polydisp params
242            if par.polydisperse:
243                poly_item = QtGui.QStandardItem("Polydispersity")
244                item1_1 = QtGui.QStandardItem("Distribution")
245                # Find param in volume_params
246                for p in parameters.form_volume_parameters:
247                    if p.name != par.name:
248                        continue
249                    item1_2 = QtGui.QStandardItem(str(p.default))
250                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
251                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
252                    item1_5 = QtGui.QStandardItem(p.units)
253                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
254                    break
255                item1.appendRow([poly_item])
256
257            item2 = QtGui.QStandardItem(str(par.default))
258            item3 = QtGui.QStandardItem(str(par.limits[0]))
259            item4 = QtGui.QStandardItem(str(par.limits[1]))
260            item5 = QtGui.QStandardItem(par.units)
261            model.appendRow([item1, item2, item3, item4, item5])
262
263def calculateChi2(reference_data, current_data):
264    """
265    Calculate Chi2 value between two sets of data
266    """
267    if reference_data is None or current_data is None:
268        return None
269    # WEIGHING INPUT
270    #from sas.sasgui.perspectives.fitting.utils import get_weight
271    #flag = self.get_weight_flag()
272    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
273    chisqr = None
274    if reference_data is None:
275        return chisqr
276
277    # temporary default values for index and weight
278    index = None
279    weight = None
280
281    # Get data: data I, theory I, and data dI in order
282    if isinstance(reference_data, Data2D):
283        if index is None:
284            index = numpy.ones(len(current_data.data), dtype=bool)
285        if weight is not None:
286            current_data.err_data = weight
287        # get rid of zero error points
288        index = index & (current_data.err_data != 0)
289        index = index & (numpy.isfinite(current_data.data))
290        fn = current_data.data[index]
291        gn = reference_data.data[index]
292        en = current_data.err_data[index]
293    else:
294        # 1 d theory from model_thread is only in the range of index
295        if index is None:
296            index = numpy.ones(len(current_data.y), dtype=bool)
297        if weight is not None:
298            current_data.dy = weight
299        if current_data.dy is None or current_data.dy == []:
300            dy = numpy.ones(len(current_data.y))
301        else:
302            ## Set consistently w/AbstractFitengine:
303            # But this should be corrected later.
304            dy = copy.deepcopy(current_data.dy)
305            dy[dy == 0] = 1
306        fn = current_data.y[index]
307        gn = reference_data.y
308        en = dy[index]
309    # Calculate the residual
310    try:
311        res = (fn - gn) / en
312    except ValueError:
313        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
314        return None
315
316    residuals = res[numpy.isfinite(res)]
317    chisqr = numpy.average(residuals * residuals)
318
319    return chisqr
320
321def residualsData1D(reference_data, current_data):
322    """
323    Calculate the residuals for difference of two Data1D sets
324    """
325    # temporary default values for index and weight
326    index = None
327    weight = None
328
329    # 1d theory from model_thread is only in the range of index
330    if current_data.dy is None or current_data.dy == []:
331        dy = numpy.ones(len(current_data.y))
332    else:
333        dy = weight if weight is not None else numpy.ones(len(current_data.y))
334        dy[dy == 0] = 1
335    fn = current_data.y[index][0]
336    gn = reference_data.y
337    en = dy[index][0]
338
339    # x values
340    x_current = current_data.x
341    x_reference = reference_data.x
342
343    # build residuals
344    residuals = Data1D()
345    if len(fn) == len(gn):
346        y = (fn - gn)/en
347        residuals.y = -y
348    elif len(fn) > len(gn):
349        residuals.y = (fn - gn[1:len(fn)])/en
350    else:
351        try:
352            y = numpy.zeros(len(current_data.y))
353            begin = 0
354            for i, x_value in enumerate(x_reference):
355                if x_value in x_current:
356                    begin = i
357                    break
358            end = len(x_reference)
359            endl = 0
360            for i, x_value in enumerate(list(x_reference)[::-1]):
361                if x_value in x_current:
362                    endl = i
363                    break
364            # make sure we have correct lengths
365            assert len(x_current) == len(x_reference[begin:end-endl])
366
367            y = (fn - gn[begin:end-endl])/en
368            residuals.y = y
369        except ValueError:
370            # value errors may show up every once in a while for malformed columns,
371            # just reuse what's there already
372            pass
373
374    residuals.x = current_data.x[index][0]
375    residuals.dy = numpy.ones(len(residuals.y))
376    residuals.dx = None
377    residuals.dxl = None
378    residuals.dxw = None
379    residuals.ytransform = 'y'
380    # For latter scale changes
381    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
382    residuals.yaxis('\\rm{Residuals} ', 'normalized')
383
384    return residuals
385
386def residualsData2D(reference_data, current_data):
387    """
388    Calculate the residuals for difference of two Data2D sets
389    """
390    # temporary default values for index and weight
391    # index = None
392    weight = None
393
394    # build residuals
395    residuals = Data2D()
396    # Not for trunk the line below, instead use the line above
397    current_data.clone_without_data(len(current_data.data), residuals)
398    residuals.data = None
399    fn = current_data.data
400    gn = reference_data.data
401    en = current_data.err_data if weight is None else weight
402    residuals.data = (fn - gn) / en
403    residuals.qx_data = current_data.qx_data
404    residuals.qy_data = current_data.qy_data
405    residuals.q_data = current_data.q_data
406    residuals.err_data = numpy.ones(len(residuals.data))
407    residuals.xmin = min(residuals.qx_data)
408    residuals.xmax = max(residuals.qx_data)
409    residuals.ymin = min(residuals.qy_data)
410    residuals.ymax = max(residuals.qy_data)
411    residuals.q_data = current_data.q_data
412    residuals.mask = current_data.mask
413    residuals.scale = 'linear'
414    # check the lengths
415    if len(residuals.data) != len(residuals.q_data):
416        return None
417    return residuals
418
419def plotResiduals(reference_data, current_data):
420    """
421    Create Data1D/Data2D with residuals, ready for plotting
422    """
423    data_copy = copy.deepcopy(current_data)
424    # Get data: data I, theory I, and data dI in order
425    method_name = current_data.__class__.__name__
426    residuals_dict = {"Data1D": residualsData1D,
427                      "Data2D": residualsData2D}
428
429    residuals = residuals_dict[method_name](reference_data, data_copy)
430
431    theory_name = str(current_data.name.split()[0])
432    res_name = reference_data.filename if reference_data.filename else reference_data.name
433    residuals.name = "Residuals for " + str(theory_name) + "[" + res_name + "]"
434    residuals.title = residuals.name
435    residuals.ytransform = 'y'
436
437    # when 2 data have the same id override the 1 st plotted
438    # include the last part if keeping charts for separate models is required
439    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
440    # group_id specify on which panel to plot this data
441    group_id = reference_data.group_id
442    residuals.group_id = "res" + str(group_id)
443
444    # Symbol
445    residuals.symbol = 0
446    residuals.hide_error = False
447
448    return residuals
449
450def binary_encode(i, digits):
451    return [i >> d & 1 for d in range(digits)]
452
453def getWeight(data, is2d, flag=None):
454    """
455    Received flag and compute error on data.
456    :param flag: flag to transform error of data.
457    """
458    weight = None
459    if data is None:
460        return []
461    if is2d:
462        if not hasattr(data, 'err_data'):
463            return []
464        dy_data = data.err_data
465        data = data.data
466    else:
467        if not hasattr(data, 'dy'):
468            return []
469        dy_data = data.dy
470        data = data.y
471
472    if flag == 0:
473        weight = numpy.ones_like(data)
474    elif flag == 1:
475        weight = dy_data
476    elif flag == 2:
477        weight = numpy.sqrt(numpy.abs(data))
478    elif flag == 3:
479        weight = numpy.abs(data)
480    return weight
481
482def updateKernelWithResults(kernel, results):
483    """
484    Takes model kernel and applies results dict to its parameters,
485    returning the modified (deep) copy of the kernel.
486    """
487    assert isinstance(results, dict)
488    local_kernel = copy.deepcopy(kernel)
489
490    for parameter in results.keys():
491        # Update the parameter value - note: this supports +/-inf as well
492        local_kernel.setParam(parameter, results[parameter][0])
493
494    return local_kernel
495
496
497def getStandardParam(model=None):
498    """
499    Returns a list with standard parameters for the current model
500    """
501    param = []
502    num_rows = model.rowCount()
503    if num_rows < 1:
504        return None
505
506    for row in range(num_rows):
507        param_name = model.item(row, 0).text()
508        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
509        value = model.item(row, 1).text()
510        column_shift = 0
511        if model.columnCount() == 5: # no error column
512            error_state = False
513            error_value = 0.0
514        else:
515            error_state = True
516            error_value = model.item(row, 2).text()
517            column_shift = 1
518        min_state = True
519        max_state = True
520        min_value = model.item(row, 2+column_shift).text()
521        max_value = model.item(row, 3+column_shift).text()
522        unit = ""
523        if model.item(row, 4+column_shift) is not None:
524            unit = model.item(row, 4+column_shift).text()
525
526        param.append([checkbox_state, param_name, value, "",
527                        [error_state, error_value],
528                        [min_state, min_value],
529                        [max_state, max_value], unit])
530
531    return param
532
533def getOrientationParam(kernel_module=None):
534    """
535    Get the dictionary with orientation parameters
536    """
537    param = []
538    if kernel_module is None:
539        return None
540    for param_name in list(kernel_module.params.keys()):
541        name = param_name
542        value = kernel_module.params[param_name]
543        min_state = True
544        max_state = True
545        error_state = False
546        error_value = 0.0
547        checkbox_state = True #??
548        details = kernel_module.details[param_name] #[unit, mix, max]
549        param.append([checkbox_state, name, value, "",
550                     [error_state, error_value],
551                     [min_state, details[1]],
552                     [max_state, details[2]], details[0]])
553
554    return param
555
556def formatParameters(parameters):
557    """
558    Prepare the parameter string in the standard SasView layout
559    """
560    assert parameters is not None
561    assert isinstance(parameters, list)
562    output_string = "sasview_parameter_values:"
563    for parameter in parameters:
564        output_string += ",".join([p for p in parameter if p is not None])
565        output_string += ":"
566    return output_string
567
568def formatParametersExcel(parameters):
569    """
570    Prepare the parameter string in the Excel format (tab delimited)
571    """
572    assert parameters is not None
573    assert isinstance(parameters, list)
574    crlf = chr(13) + chr(10)
575    tab = chr(9)
576
577    output_string = ""
578    # names
579    names = ""
580    values = ""
581    for parameter in parameters:
582        names += parameter[0]+tab
583        # Add the error column if fitted
584        if parameter[1] == "True" and parameter[3] is not None:
585            names += parameter[0]+"_err"+tab
586
587        values += parameter[2]+tab
588        if parameter[1] == "True" and parameter[3] is not None:
589            values += parameter[3]+tab
590        # add .npts and .nsigmas when necessary
591        if parameter[0][-6:] == ".width":
592            names += parameter[0].replace('.width', '.nsigmas') + tab
593            names += parameter[0].replace('.width', '.npts') + tab
594            values += parameter[5] + tab + parameter[4] + tab
595
596    output_string = names + crlf + values
597    return output_string
598
599def formatParametersLatex(parameters):
600    """
601    Prepare the parameter string in latex
602    """
603    assert parameters is not None
604    assert isinstance(parameters, list)
605    output_string = r'\begin{table}'
606    output_string += r'\begin{tabular}[h]'
607
608    crlf = chr(13) + chr(10)
609    output_string += '{|'
610    output_string += 'l|l|'*len(parameters)
611    output_string += r'}\hline'
612    output_string += crlf
613
614    for index, parameter in enumerate(parameters):
615        name = parameter[0] # Parameter name
616        output_string += name.replace('_', r'\_')  # Escape underscores
617        # Add the error column if fitted
618        if parameter[1] == "True" and parameter[3] is not None:
619            output_string += ' & '
620            output_string += parameter[0]+r'\_err'
621
622        if index < len(parameters) - 1:
623            output_string += ' & '
624
625        # add .npts and .nsigmas when necessary
626        if parameter[0][-6:] == ".width":
627            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
628            output_string += parameter[0].replace('.width', '.npts')
629
630            if index < len(parameters) - 1:
631                output_string += ' & '
632
633    output_string += r'\\ \hline'
634    output_string += crlf
635
636    # Construct row of values and errors
637    for index, parameter in enumerate(parameters):
638        output_string += parameter[2]
639        if parameter[1] == "True" and parameter[3] is not None:
640            output_string += ' & '
641            output_string += parameter[3]
642
643        if index < len(parameters) - 1:
644            output_string += ' & '
645
646        # add .npts and .nsigmas when necessary
647        if parameter[0][-6:] == ".width":
648            output_string += parameter[5] + ' & '
649            output_string += parameter[4]
650
651            if index < len(parameters) - 1:
652                output_string += ' & '
653
654    output_string += r'\\ \hline'
655    output_string += crlf
656    output_string += r'\end{tabular}'
657    output_string += r'\end{table}'
658
659    return output_string
Note: See TracBrowser for help on using the repository browser.