source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 4ea8020

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 4ea8020 was 4ea8020, checked in by GitHub <noreply@…>, 3 years ago

Merge branch 'ESS_GUI' into ESS_GUI_iss966

  • Property mode set to 100644
File size: 24.0 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list([par for par in model.iq_parameters if "[" in par.name])
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69
70    if is2D:
71        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
72    else:
73        params = parameters.iq_parameters
74    item = []
75    for param in params:
76        # don't include shell parameters
77        if param.name == multishell_param_name:
78            continue
79        # Modify parameter name from <param>[n] to <param>1
80        item_name = param.name
81        if param in multishell_parameters:
82            continue
83        #    item_name = replaceShellName(param.name, 1)
84
85        item1 = QtGui.QStandardItem(item_name)
86        item1.setCheckable(True)
87        item1.setEditable(False)
88        # item_err = QtGui.QStandardItem()
89        # check for polydisp params
90        if param.polydisperse:
91            poly_item = QtGui.QStandardItem("Polydispersity")
92            poly_item.setEditable(False)
93            item1_1 = QtGui.QStandardItem("Distribution")
94            item1_1.setEditable(False)
95            # Find param in volume_params
96            for p in parameters.form_volume_parameters:
97                if p.name != param.name:
98                    continue
99                width = kernel_module.getParam(p.name+'.width')
100                ptype = kernel_module.getParam(p.name+'.type')
101
102                item1_2 = QtGui.QStandardItem(str(width))
103                item1_2.setEditable(False)
104                item1_3 = QtGui.QStandardItem()
105                item1_3.setEditable(False)
106                item1_4 = QtGui.QStandardItem()
107                item1_4.setEditable(False)
108                item1_5 = QtGui.QStandardItem(ptype)
109                item1_5.setEditable(False)
110                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
111                break
112            # Add the polydisp item as a child
113            item1.appendRow([poly_item])
114        # Param values
115        item2 = QtGui.QStandardItem(str(param.default))
116        # TODO: the error column.
117        # Either add a proxy model or a custom view delegate
118        #item_err = QtGui.QStandardItem()
119        item3 = QtGui.QStandardItem(str(param.limits[0]))
120        item4 = QtGui.QStandardItem(str(param.limits[1]))
121        item5 = QtGui.QStandardItem(param.units)
122        item5.setEditable(False)
123        item.append([item1, item2, item3, item4, item5])
124    return item
125
126def addSimpleParametersToModel(parameters, is2D, parameters_original=None):
127    """
128    Update local ModelModel with sasmodel parameters
129    parameters_original: list of parameters before any tagging on their IDs, e.g. for product model
130    (so that those are the display names; see below)
131    """
132    if is2D:
133        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
134    else:
135        params = parameters.iq_parameters
136
137    if parameters_original:
138        # 'parameters_original' contains the parameters as they are to be DISPLAYED, while 'parameters'
139        # contains the parameters as they were renamed; this is for handling name collisions in product model.
140        # The 'real name' of the parameter will be stored in the item's user data.
141        if is2D:
142            params_orig = [p for p in parameters_original.kernel_parameters if p.type != 'magnetic']
143        else:
144            params_orig = parameters_original.iq_parameters
145    else:
146        # no difference in names anyway
147        params_orig = params
148
149    item = []
150    for param, param_orig in zip(params, params_orig):
151        # Create the top level, checkable item
152        item_name = param_orig.name
153        item1 = QtGui.QStandardItem(item_name)
154        item1.setData(param.name, QtCore.Qt.UserRole)
155        item1.setCheckable(True)
156        item1.setEditable(False)
157        # Param values
158        # TODO: add delegate for validation of cells
159        item2 = QtGui.QStandardItem(str(param.default))
160        item4 = QtGui.QStandardItem(str(param.limits[0]))
161        item5 = QtGui.QStandardItem(str(param.limits[1]))
162        item6 = QtGui.QStandardItem(str(param.units))
163        item6.setEditable(False)
164        item.append([item1, item2, item4, item5, item6])
165    return item
166
167def markParameterDisabled(model, row):
168    """Given the QModel row number, format to show it is not available for fitting"""
169
170    # If an error column is present, there are a total of 6 columns.
171    items = [model.item(row, c) for c in range(6)]
172
173    model.blockSignals(True)
174
175    for item in items:
176        if item is None:
177            continue
178        item.setEditable(False)
179        item.setCheckable(False)
180
181    item = items[0]
182
183    font = QtGui.QFont()
184    font.setItalic(True)
185    item.setFont(font)
186    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
187    item.setToolTip("This parameter cannot be fitted.")
188
189    model.blockSignals(False)
190
191def addCheckedListToModel(model, param_list):
192    """
193    Add a QItem to model. Makes the QItem checkable
194    """
195    assert isinstance(model, QtGui.QStandardItemModel)
196    item_list = [QtGui.QStandardItem(item) for item in param_list]
197    item_list[0].setCheckable(True)
198    model.appendRow(item_list)
199
200def addHeadingRowToModel(model, name):
201    """adds a non-interactive top-level row to the model"""
202    header_row = [QtGui.QStandardItem() for i in range(5)]
203    header_row[0].setText(name)
204
205    font = header_row[0].font()
206    font.setBold(True)
207    header_row[0].setFont(font)
208
209    for item in header_row:
210        item.setEditable(False)
211        item.setCheckable(False)
212        item.setSelectable(False)
213
214    model.appendRow(header_row)
215
216def addHeadersToModel(model):
217    """
218    Adds predefined headers to the model
219    """
220    for i, item in enumerate(model_header_captions):
221        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
222
223    model.header_tooltips = copy.copy(model_header_tooltips)
224
225def addErrorHeadersToModel(model):
226    """
227    Adds predefined headers to the model
228    """
229    model_header_error_captions = copy.copy(model_header_captions)
230    model_header_error_captions.insert(2, header_error_caption)
231    for i, item in enumerate(model_header_error_captions):
232        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
233
234    model_header_error_tooltips = copy.copy(model_header_tooltips)
235    model_header_error_tooltips.insert(2, error_tooltip)
236    model.header_tooltips = copy.copy(model_header_error_tooltips)
237
238def addPolyHeadersToModel(model):
239    """
240    Adds predefined headers to the model
241    """
242    for i, item in enumerate(poly_header_captions):
243        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
244
245    model.header_tooltips = copy.copy(poly_header_tooltips)
246
247
248def addErrorPolyHeadersToModel(model):
249    """
250    Adds predefined headers to the model
251    """
252    poly_header_error_captions = copy.copy(poly_header_captions)
253    poly_header_error_captions.insert(2, header_error_caption)
254    for i, item in enumerate(poly_header_error_captions):
255        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
256
257    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
258    poly_header_error_tooltips.insert(2, error_tooltip)
259    model.header_tooltips = copy.copy(poly_header_error_tooltips)
260
261def addShellsToModel(parameters, model, index, row_num=None):
262    """
263    Find out multishell parameters and update the model with the requested number of them.
264    Inserts them after the row at row_num, if not None; otherwise, appends to end.
265    Returns a list of lists of QStandardItem objects.
266    """
267    multishell_parameters = getIterParams(parameters)
268
269    rows = []
270    for i in range(index):
271        for par in multishell_parameters:
272            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
273            param_name = replaceShellName(par.name, i+1)
274            item1 = QtGui.QStandardItem(param_name)
275            item1.setCheckable(True)
276            # check for polydisp params
277            if par.polydisperse:
278                poly_item = QtGui.QStandardItem("Polydispersity")
279                item1_1 = QtGui.QStandardItem("Distribution")
280                # Find param in volume_params
281                for p in parameters.form_volume_parameters:
282                    if p.name != par.name:
283                        continue
284                    item1_2 = QtGui.QStandardItem(str(p.default))
285                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
286                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
287                    item1_5 = QtGui.QStandardItem(p.units)
288                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
289                    break
290                item1.appendRow([poly_item])
291
292            item2 = QtGui.QStandardItem(str(par.default))
293            item3 = QtGui.QStandardItem(str(par.limits[0]))
294            item4 = QtGui.QStandardItem(str(par.limits[1]))
295            item5 = QtGui.QStandardItem(par.units)
296            row = [item1, item2, item3, item4, item5]
297            rows.append(row)
298
299            if row_num is None:
300                model.appendRow(row)
301            else:
302                model.insertRow(row_num, row)
303                row_num += 1
304
305    return rows
306
307def calculateChi2(reference_data, current_data):
308    """
309    Calculate Chi2 value between two sets of data
310    """
311    if reference_data is None or current_data is None:
312        return None
313    # WEIGHING INPUT
314    #from sas.sasgui.perspectives.fitting.utils import get_weight
315    #flag = self.get_weight_flag()
316    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
317    chisqr = None
318    if reference_data is None:
319        return chisqr
320
321    # temporary default values for index and weight
322    index = None
323    weight = None
324
325    # Get data: data I, theory I, and data dI in order
326    if isinstance(reference_data, Data2D):
327        if index is None:
328            index = numpy.ones(len(current_data.data), dtype=bool)
329        if weight is not None:
330            current_data.err_data = weight
331        # get rid of zero error points
332        index = index & (current_data.err_data != 0)
333        index = index & (numpy.isfinite(current_data.data))
334        fn = current_data.data[index]
335        gn = reference_data.data[index]
336        en = current_data.err_data[index]
337    else:
338        # 1 d theory from model_thread is only in the range of index
339        if index is None:
340            index = numpy.ones(len(current_data.y), dtype=bool)
341        if weight is not None:
342            current_data.dy = weight
343        if current_data.dy is None or current_data.dy == []:
344            dy = numpy.ones(len(current_data.y))
345        else:
346            ## Set consistently w/AbstractFitengine:
347            # But this should be corrected later.
348            dy = copy.deepcopy(current_data.dy)
349            dy[dy == 0] = 1
350        fn = current_data.y[index]
351        gn = reference_data.y
352        en = dy[index]
353    # Calculate the residual
354    try:
355        res = (fn - gn) / en
356    except ValueError:
357        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
358        return None
359
360    residuals = res[numpy.isfinite(res)]
361    chisqr = numpy.average(residuals * residuals)
362
363    return chisqr
364
365def residualsData1D(reference_data, current_data):
366    """
367    Calculate the residuals for difference of two Data1D sets
368    """
369    # temporary default values for index and weight
370    index = None
371    weight = None
372
373    # 1d theory from model_thread is only in the range of index
374    if current_data.dy is None or current_data.dy == []:
375        dy = numpy.ones(len(current_data.y))
376    else:
377        dy = weight if weight is not None else numpy.ones(len(current_data.y))
378        dy[dy == 0] = 1
379    fn = current_data.y[index][0]
380    gn = reference_data.y
381    en = dy[index][0]
382
383    # x values
384    x_current = current_data.x
385    x_reference = reference_data.x
386
387    # build residuals
388    residuals = Data1D()
389    if len(fn) == len(gn):
390        y = (fn - gn)/en
391        residuals.y = -y
392    elif len(fn) > len(gn):
393        residuals.y = (fn - gn[1:len(fn)])/en
394    else:
395        try:
396            y = numpy.zeros(len(current_data.y))
397            begin = 0
398            for i, x_value in enumerate(x_reference):
399                if x_value in x_current:
400                    begin = i
401                    break
402            end = len(x_reference)
403            endl = 0
404            for i, x_value in enumerate(list(x_reference)[::-1]):
405                if x_value in x_current:
406                    endl = i
407                    break
408            # make sure we have correct lengths
409            assert len(x_current) == len(x_reference[begin:end-endl])
410
411            y = (fn - gn[begin:end-endl])/en
412            residuals.y = y
413        except ValueError:
414            # value errors may show up every once in a while for malformed columns,
415            # just reuse what's there already
416            pass
417
418    residuals.x = current_data.x[index][0]
419    residuals.dy = numpy.ones(len(residuals.y))
420    residuals.dx = None
421    residuals.dxl = None
422    residuals.dxw = None
423    residuals.ytransform = 'y'
424    # For latter scale changes
425    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
426    residuals.yaxis('\\rm{Residuals} ', 'normalized')
427
428    return residuals
429
430def residualsData2D(reference_data, current_data):
431    """
432    Calculate the residuals for difference of two Data2D sets
433    """
434    # temporary default values for index and weight
435    # index = None
436    weight = None
437
438    # build residuals
439    residuals = Data2D()
440    # Not for trunk the line below, instead use the line above
441    current_data.clone_without_data(len(current_data.data), residuals)
442    residuals.data = None
443    fn = current_data.data
444    gn = reference_data.data
445    en = current_data.err_data if weight is None else weight
446    residuals.data = (fn - gn) / en
447    residuals.qx_data = current_data.qx_data
448    residuals.qy_data = current_data.qy_data
449    residuals.q_data = current_data.q_data
450    residuals.err_data = numpy.ones(len(residuals.data))
451    residuals.xmin = min(residuals.qx_data)
452    residuals.xmax = max(residuals.qx_data)
453    residuals.ymin = min(residuals.qy_data)
454    residuals.ymax = max(residuals.qy_data)
455    residuals.q_data = current_data.q_data
456    residuals.mask = current_data.mask
457    residuals.scale = 'linear'
458    # check the lengths
459    if len(residuals.data) != len(residuals.q_data):
460        return None
461    return residuals
462
463def plotResiduals(reference_data, current_data):
464    """
465    Create Data1D/Data2D with residuals, ready for plotting
466    """
467    data_copy = copy.deepcopy(current_data)
468    # Get data: data I, theory I, and data dI in order
469    method_name = current_data.__class__.__name__
470    residuals_dict = {"Data1D": residualsData1D,
471                      "Data2D": residualsData2D}
472
473    residuals = residuals_dict[method_name](reference_data, data_copy)
474
475    theory_name = str(current_data.name.split()[0])
476    res_name = reference_data.filename if reference_data.filename else reference_data.name
477    residuals.name = "Residuals for " + str(theory_name) + "[" + res_name + "]"
478    residuals.title = residuals.name
479    residuals.ytransform = 'y'
480
481    # when 2 data have the same id override the 1 st plotted
482    # include the last part if keeping charts for separate models is required
483    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
484    # group_id specify on which panel to plot this data
485    group_id = reference_data.group_id
486    residuals.group_id = "res" + str(group_id)
487
488    # Symbol
489    residuals.symbol = 0
490    residuals.hide_error = False
491
492    return residuals
493
494def binary_encode(i, digits):
495    return [i >> d & 1 for d in range(digits)]
496
497def getWeight(data, is2d, flag=None):
498    """
499    Received flag and compute error on data.
500    :param flag: flag to transform error of data.
501    """
502    weight = None
503    if data is None:
504        return []
505    if is2d:
506        if not hasattr(data, 'err_data'):
507            return []
508        dy_data = data.err_data
509        data = data.data
510    else:
511        if not hasattr(data, 'dy'):
512            return []
513        dy_data = data.dy
514        data = data.y
515
516    if flag == 0:
517        weight = numpy.ones_like(data)
518    elif flag == 1:
519        weight = dy_data
520    elif flag == 2:
521        weight = numpy.sqrt(numpy.abs(data))
522    elif flag == 3:
523        weight = numpy.abs(data)
524    return weight
525
526def updateKernelWithResults(kernel, results):
527    """
528    Takes model kernel and applies results dict to its parameters,
529    returning the modified (deep) copy of the kernel.
530    """
531    assert isinstance(results, dict)
532    local_kernel = copy.deepcopy(kernel)
533
534    for parameter in results.keys():
535        # Update the parameter value - note: this supports +/-inf as well
536        local_kernel.setParam(parameter, results[parameter][0])
537
538    return local_kernel
539
540
541def getStandardParam(model=None):
542    """
543    Returns a list with standard parameters for the current model
544    """
545    param = []
546    num_rows = model.rowCount()
547    if num_rows < 1:
548        return None
549
550    for row in range(num_rows):
551        param_name = model.item(row, 0).text()
552        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
553        value = model.item(row, 1).text()
554        column_shift = 0
555        if model.columnCount() == 5: # no error column
556            error_state = False
557            error_value = 0.0
558        else:
559            error_state = True
560            error_value = model.item(row, 2).text()
561            column_shift = 1
562        min_state = True
563        max_state = True
564        min_value = model.item(row, 2+column_shift).text()
565        max_value = model.item(row, 3+column_shift).text()
566        unit = ""
567        if model.item(row, 4+column_shift) is not None:
568            unit = model.item(row, 4+column_shift).text()
569
570        param.append([checkbox_state, param_name, value, "",
571                        [error_state, error_value],
572                        [min_state, min_value],
573                        [max_state, max_value], unit])
574
575    return param
576
577def getOrientationParam(kernel_module=None):
578    """
579    Get the dictionary with orientation parameters
580    """
581    param = []
582    if kernel_module is None:
583        return None
584    for param_name in list(kernel_module.params.keys()):
585        name = param_name
586        value = kernel_module.params[param_name]
587        min_state = True
588        max_state = True
589        error_state = False
590        error_value = 0.0
591        checkbox_state = True #??
592        details = kernel_module.details[param_name] #[unit, mix, max]
593        param.append([checkbox_state, name, value, "",
594                     [error_state, error_value],
595                     [min_state, details[1]],
596                     [max_state, details[2]], details[0]])
597
598    return param
599
600def formatParameters(parameters):
601    """
602    Prepare the parameter string in the standard SasView layout
603    """
604    assert parameters is not None
605    assert isinstance(parameters, list)
606    output_string = "sasview_parameter_values:"
607    for parameter in parameters:
608        output_string += ",".join([p for p in parameter if p is not None])
609        output_string += ":"
610    return output_string
611
612def formatParametersExcel(parameters):
613    """
614    Prepare the parameter string in the Excel format (tab delimited)
615    """
616    assert parameters is not None
617    assert isinstance(parameters, list)
618    crlf = chr(13) + chr(10)
619    tab = chr(9)
620
621    output_string = ""
622    # names
623    names = ""
624    values = ""
625    for parameter in parameters:
626        names += parameter[0]+tab
627        # Add the error column if fitted
628        if parameter[1] == "True" and parameter[3] is not None:
629            names += parameter[0]+"_err"+tab
630
631        values += parameter[2]+tab
632        if parameter[1] == "True" and parameter[3] is not None:
633            values += parameter[3]+tab
634        # add .npts and .nsigmas when necessary
635        if parameter[0][-6:] == ".width":
636            names += parameter[0].replace('.width', '.nsigmas') + tab
637            names += parameter[0].replace('.width', '.npts') + tab
638            values += parameter[5] + tab + parameter[4] + tab
639
640    output_string = names + crlf + values
641    return output_string
642
643def formatParametersLatex(parameters):
644    """
645    Prepare the parameter string in latex
646    """
647    assert parameters is not None
648    assert isinstance(parameters, list)
649    output_string = r'\begin{table}'
650    output_string += r'\begin{tabular}[h]'
651
652    crlf = chr(13) + chr(10)
653    output_string += '{|'
654    output_string += 'l|l|'*len(parameters)
655    output_string += r'}\hline'
656    output_string += crlf
657
658    for index, parameter in enumerate(parameters):
659        name = parameter[0] # Parameter name
660        output_string += name.replace('_', r'\_')  # Escape underscores
661        # Add the error column if fitted
662        if parameter[1] == "True" and parameter[3] is not None:
663            output_string += ' & '
664            output_string += parameter[0]+r'\_err'
665
666        if index < len(parameters) - 1:
667            output_string += ' & '
668
669        # add .npts and .nsigmas when necessary
670        if parameter[0][-6:] == ".width":
671            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
672            output_string += parameter[0].replace('.width', '.npts')
673
674            if index < len(parameters) - 1:
675                output_string += ' & '
676
677    output_string += r'\\ \hline'
678    output_string += crlf
679
680    # Construct row of values and errors
681    for index, parameter in enumerate(parameters):
682        output_string += parameter[2]
683        if parameter[1] == "True" and parameter[3] is not None:
684            output_string += ' & '
685            output_string += parameter[3]
686
687        if index < len(parameters) - 1:
688            output_string += ' & '
689
690        # add .npts and .nsigmas when necessary
691        if parameter[0][-6:] == ".width":
692            output_string += parameter[5] + ' & '
693            output_string += parameter[4]
694
695            if index < len(parameters) - 1:
696                output_string += ' & '
697
698    output_string += r'\\ \hline'
699    output_string += crlf
700    output_string += r'\end{tabular}'
701    output_string += r'\end{table}'
702
703    return output_string
Note: See TracBrowser for help on using the repository browser.