source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 04f775d

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 04f775d was 04f775d, checked in by Torin Cooper-Bennun <torin.cooper-bennun@…>, 6 years ago

cherry-pick fixed-choice param support, made more generic and cleaner

  • Property mode set to 100644
File size: 24.0 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11from sas.qtgui.Perspectives.Fitting.AssociatedComboBox import AssociatedComboBox
12
13model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
14
15model_header_tooltips = ['Select parameter for fitting',
16                         'Enter parameter value',
17                         'Enter minimum value for parameter',
18                         'Enter maximum value for parameter',
19                         'Unit of the parameter']
20
21poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
22                        'Function', 'Filename']
23
24poly_header_tooltips = ['Select parameter for fitting',
25                        'Enter polydispersity ratio (STD/mean). '
26                        'STD: standard deviation from the mean value',
27                        'Enter minimum value for parameter',
28                        'Enter maximum value for parameter',
29                        'Enter number of points for parameter',
30                        'Enter number of sigmas parameter',
31                        'Select distribution function',
32                        'Select filename with user-definable distribution']
33
34error_tooltip = 'Error value for fitted parameter'
35header_error_caption = 'Error'
36
37def replaceShellName(param_name, value):
38    """
39    Updates parameter name from <param_name>[n_shell] to <param_name>value
40    """
41    assert '[' in param_name
42    return param_name[:param_name.index('[')]+str(value)
43
44def getIterParams(model):
45    """
46    Returns a list of all multi-shell parameters in 'model'
47    """
48    return list([par for par in model.iq_parameters if "[" in par.name])
49
50def getMultiplicity(model):
51    """
52    Finds out if 'model' has multishell parameters.
53    If so, returns the name of the counter parameter and the number of shells
54    """
55    iter_params = getIterParams(model)
56    param_name = ""
57    param_length = 0
58    if iter_params:
59        param_length = iter_params[0].length
60        param_name = iter_params[0].length_control
61        if param_name is None and '[' in iter_params[0].name:
62            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
63    return (param_name, param_length)
64
65def createFixedChoiceComboBox(param, item_row):
66    """
67    Determines whether param is a fixed-choice parameter, modifies items in item_row appropriately and returns a combo
68    box containing the fixed choices. Returns None if param is not fixed-choice.
69   
70    item_row is a list of QStandardItem objects for insertion into the parameter table.
71    """
72
73    # Determine whether this is a fixed-choice parameter. There are lots of conditionals, simply because the
74    # implementation is not yet concrete; there are several possible indicators that the parameter is fixed-choice.
75    # TODO: (when the sasmodels implementation is concrete, clean this up)
76    choices = None
77    if type(param.choices) is list and len(param.choices) > 0:
78        # The choices property is concrete in sasmodels, probably will use this
79        choices = param.choices
80    elif type(param.units) is list:
81        choices = param.units
82
83    cbox = None
84    if choices is not None:
85        # Use combo box for input, if it is fixed-choice
86        cbox = AssociatedComboBox(item_row[1], idx_as_value=True)
87        cbox.addItems(choices)
88        item_row[2].setEditable(False)
89        item_row[3].setEditable(False)
90
91    return cbox
92
93def addParametersToModel(model, view, parameters, kernel_module, is2D):
94    """
95    Update local ModelModel with sasmodel parameters
96    """
97    multishell_parameters = getIterParams(parameters)
98    multishell_param_name, _ = getMultiplicity(parameters)
99
100    if is2D:
101        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
102    else:
103        params = parameters.iq_parameters
104
105    for param in params:
106        # don't include shell parameters
107        if param.name == multishell_param_name:
108            continue
109
110        # Modify parameter name from <param>[n] to <param>1
111        item_name = param.name
112        if param in multishell_parameters:
113            continue
114
115        item1 = QtGui.QStandardItem(item_name)
116        item1.setCheckable(True)
117        item1.setEditable(False)
118
119        # check for polydisp params
120        if param.polydisperse:
121            poly_item = QtGui.QStandardItem("Polydispersity")
122            poly_item.setEditable(False)
123            item1_1 = QtGui.QStandardItem("Distribution")
124            item1_1.setEditable(False)
125
126            # Find param in volume_params
127            for p in parameters.form_volume_parameters:
128                if p.name != param.name:
129                    continue
130                width = kernel_module.getParam(p.name+'.width')
131                ptype = kernel_module.getParam(p.name+'.type')
132                item1_2 = QtGui.QStandardItem(str(width))
133                item1_2.setEditable(False)
134                item1_3 = QtGui.QStandardItem()
135                item1_3.setEditable(False)
136                item1_4 = QtGui.QStandardItem()
137                item1_4.setEditable(False)
138                item1_5 = QtGui.QStandardItem(ptype)
139                item1_5.setEditable(False)
140                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
141                break
142
143            # Add the polydisp item as a child
144            item1.appendRow([poly_item])
145
146        # Param values
147        item2 = QtGui.QStandardItem(str(param.default))
148        item3 = QtGui.QStandardItem(str(param.limits[0]))
149        item4 = QtGui.QStandardItem(str(param.limits[1]))
150        item5 = QtGui.QStandardItem(str(param.units))
151        item5.setEditable(False)
152
153        # Check if fixed-choice (returns combobox, if so, also makes some items uneditable)
154        row = [item1, item2, item3, item4, item5]
155        cbox = createFixedChoiceComboBox(param, row)
156
157        # Append to the model and use the combobox, if required
158        model.appendRow(row)
159        if cbox is not None:
160            view.setIndexWidget(item2.index(), cbox)
161
162def addSimpleParametersToModel(model, view, parameters, is2D):
163    """
164    Update local ModelModel with sasmodel parameters (non-dispersed, non-magnetic)
165    """
166    if is2D:
167        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
168    else:
169        params = parameters.iq_parameters
170
171    for param in params:
172        # Create the top level, checkable item
173        item_name = param.name
174        item1 = QtGui.QStandardItem(item_name)
175        item1.setCheckable(True)
176        item1.setEditable(False)
177
178        # Param values
179        # TODO: add delegate for validation of cells
180        item2 = QtGui.QStandardItem(str(param.default))
181        item3 = QtGui.QStandardItem(str(param.limits[0]))
182        item4 = QtGui.QStandardItem(str(param.limits[1]))
183        item5 = QtGui.QStandardItem(str(param.units))
184        item5.setEditable(False)
185
186        # Check if fixed-choice (returns combobox, if so, also makes some items uneditable)
187        row = [item1, item2, item3, item4, item5]
188        cbox = createFixedChoiceComboBox(param, row)
189
190        # Append to the model and use the combobox, if required
191        model.appendRow(row)
192        if cbox is not None:
193            view.setIndexWidget(item2.index(), cbox)
194
195def markParameterDisabled(model, row):
196    """Given the QModel row number, format to show it is not available for fitting"""
197
198    # If an error column is present, there are a total of 6 columns.
199    items = [model.item(row, c) for c in range(6)]
200
201    model.blockSignals(True)
202
203    for item in items:
204        if item is None:
205            continue
206        item.setEditable(False)
207        item.setCheckable(False)
208
209    item = items[0]
210
211    font = QtGui.QFont()
212    font.setItalic(True)
213    item.setFont(font)
214    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
215    item.setToolTip("This parameter cannot be fitted.")
216
217    model.blockSignals(False)
218
219def addCheckedListToModel(model, param_list):
220    """
221    Add a QItem to model. Makes the QItem checkable
222    """
223    assert isinstance(model, QtGui.QStandardItemModel)
224    item_list = [QtGui.QStandardItem(item) for item in param_list]
225    item_list[0].setCheckable(True)
226    model.appendRow(item_list)
227
228def addHeadersToModel(model):
229    """
230    Adds predefined headers to the model
231    """
232    for i, item in enumerate(model_header_captions):
233        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
234
235    model.header_tooltips = copy.copy(model_header_tooltips)
236
237def addErrorHeadersToModel(model):
238    """
239    Adds predefined headers to the model
240    """
241    model_header_error_captions = copy.copy(model_header_captions)
242    model_header_error_captions.insert(2, header_error_caption)
243    for i, item in enumerate(model_header_error_captions):
244        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
245
246    model_header_error_tooltips = copy.copy(model_header_tooltips)
247    model_header_error_tooltips.insert(2, error_tooltip)
248    model.header_tooltips = copy.copy(model_header_error_tooltips)
249
250def addPolyHeadersToModel(model):
251    """
252    Adds predefined headers to the model
253    """
254    for i, item in enumerate(poly_header_captions):
255        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
256
257    model.header_tooltips = copy.copy(poly_header_tooltips)
258
259
260def addErrorPolyHeadersToModel(model):
261    """
262    Adds predefined headers to the model
263    """
264    poly_header_error_captions = copy.copy(poly_header_captions)
265    poly_header_error_captions.insert(2, header_error_caption)
266    for i, item in enumerate(poly_header_error_captions):
267        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
268
269    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
270    poly_header_error_tooltips.insert(2, error_tooltip)
271    model.header_tooltips = copy.copy(poly_header_error_tooltips)
272
273def addShellsToModel(parameters, model, index):
274    """
275    Find out multishell parameters and update the model with the requested number of them
276    """
277    multishell_parameters = getIterParams(parameters)
278
279    for i in range(index):
280        for par in multishell_parameters:
281            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
282            param_name = replaceShellName(par.name, i+1)
283            item1 = QtGui.QStandardItem(param_name)
284            item1.setCheckable(True)
285            # check for polydisp params
286            if par.polydisperse:
287                poly_item = QtGui.QStandardItem("Polydispersity")
288                item1_1 = QtGui.QStandardItem("Distribution")
289                # Find param in volume_params
290                for p in parameters.form_volume_parameters:
291                    if p.name != par.name:
292                        continue
293                    item1_2 = QtGui.QStandardItem(str(p.default))
294                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
295                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
296                    item1_5 = QtGui.QStandardItem(p.units)
297                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
298                    break
299                item1.appendRow([poly_item])
300
301            item2 = QtGui.QStandardItem(str(par.default))
302            item3 = QtGui.QStandardItem(str(par.limits[0]))
303            item4 = QtGui.QStandardItem(str(par.limits[1]))
304            item5 = QtGui.QStandardItem(par.units)
305            model.appendRow([item1, item2, item3, item4, item5])
306
307def calculateChi2(reference_data, current_data):
308    """
309    Calculate Chi2 value between two sets of data
310    """
311    if reference_data is None or current_data is None:
312        return None
313    # WEIGHING INPUT
314    #from sas.sasgui.perspectives.fitting.utils import get_weight
315    #flag = self.get_weight_flag()
316    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
317    chisqr = None
318    if reference_data is None:
319        return chisqr
320
321    # temporary default values for index and weight
322    index = None
323    weight = None
324
325    # Get data: data I, theory I, and data dI in order
326    if isinstance(reference_data, Data2D):
327        if index is None:
328            index = numpy.ones(len(current_data.data), dtype=bool)
329        if weight is not None:
330            current_data.err_data = weight
331        # get rid of zero error points
332        index = index & (current_data.err_data != 0)
333        index = index & (numpy.isfinite(current_data.data))
334        fn = current_data.data[index]
335        gn = reference_data.data[index]
336        en = current_data.err_data[index]
337    else:
338        # 1 d theory from model_thread is only in the range of index
339        if index is None:
340            index = numpy.ones(len(current_data.y), dtype=bool)
341        if weight is not None:
342            current_data.dy = weight
343        if current_data.dy is None or current_data.dy == []:
344            dy = numpy.ones(len(current_data.y))
345        else:
346            ## Set consistently w/AbstractFitengine:
347            # But this should be corrected later.
348            dy = copy.deepcopy(current_data.dy)
349            dy[dy == 0] = 1
350        fn = current_data.y[index]
351        gn = reference_data.y
352        en = dy[index]
353    # Calculate the residual
354    try:
355        res = (fn - gn) / en
356    except ValueError:
357        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
358        return None
359
360    residuals = res[numpy.isfinite(res)]
361    chisqr = numpy.average(residuals * residuals)
362
363    return chisqr
364
365def residualsData1D(reference_data, current_data):
366    """
367    Calculate the residuals for difference of two Data1D sets
368    """
369    # temporary default values for index and weight
370    index = None
371    weight = None
372
373    # 1d theory from model_thread is only in the range of index
374    if current_data.dy is None or current_data.dy == []:
375        dy = numpy.ones(len(current_data.y))
376    else:
377        dy = weight if weight is not None else numpy.ones(len(current_data.y))
378        dy[dy == 0] = 1
379    fn = current_data.y[index][0]
380    gn = reference_data.y
381    en = dy[index][0]
382
383    # x values
384    x_current = current_data.x
385    x_reference = reference_data.x
386
387    # build residuals
388    residuals = Data1D()
389    if len(fn) == len(gn):
390        y = (fn - gn)/en
391        residuals.y = -y
392    elif len(fn) > len(gn):
393        residuals.y = (fn - gn[1:len(fn)])/en
394    else:
395        try:
396            y = numpy.zeros(len(current_data.y))
397            begin = 0
398            for i, x_value in enumerate(x_reference):
399                if x_value in x_current:
400                    begin = i
401                    break
402            end = len(x_reference)
403            endl = 0
404            for i, x_value in enumerate(list(x_reference)[::-1]):
405                if x_value in x_current:
406                    endl = i
407                    break
408            # make sure we have correct lengths
409            assert len(x_current) == len(x_reference[begin:end-endl])
410
411            y = (fn - gn[begin:end-endl])/en
412            residuals.y = y
413        except ValueError:
414            # value errors may show up every once in a while for malformed columns,
415            # just reuse what's there already
416            pass
417
418    residuals.x = current_data.x[index][0]
419    residuals.dy = numpy.ones(len(residuals.y))
420    residuals.dx = None
421    residuals.dxl = None
422    residuals.dxw = None
423    residuals.ytransform = 'y'
424    # For latter scale changes
425    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
426    residuals.yaxis('\\rm{Residuals} ', 'normalized')
427
428    return residuals
429
430def residualsData2D(reference_data, current_data):
431    """
432    Calculate the residuals for difference of two Data2D sets
433    """
434    # temporary default values for index and weight
435    # index = None
436    weight = None
437
438    # build residuals
439    residuals = Data2D()
440    # Not for trunk the line below, instead use the line above
441    current_data.clone_without_data(len(current_data.data), residuals)
442    residuals.data = None
443    fn = current_data.data
444    gn = reference_data.data
445    en = current_data.err_data if weight is None else weight
446    residuals.data = (fn - gn) / en
447    residuals.qx_data = current_data.qx_data
448    residuals.qy_data = current_data.qy_data
449    residuals.q_data = current_data.q_data
450    residuals.err_data = numpy.ones(len(residuals.data))
451    residuals.xmin = min(residuals.qx_data)
452    residuals.xmax = max(residuals.qx_data)
453    residuals.ymin = min(residuals.qy_data)
454    residuals.ymax = max(residuals.qy_data)
455    residuals.q_data = current_data.q_data
456    residuals.mask = current_data.mask
457    residuals.scale = 'linear'
458    # check the lengths
459    if len(residuals.data) != len(residuals.q_data):
460        return None
461    return residuals
462
463def plotResiduals(reference_data, current_data):
464    """
465    Create Data1D/Data2D with residuals, ready for plotting
466    """
467    data_copy = copy.deepcopy(current_data)
468    # Get data: data I, theory I, and data dI in order
469    method_name = current_data.__class__.__name__
470    residuals_dict = {"Data1D": residualsData1D,
471                      "Data2D": residualsData2D}
472
473    residuals = residuals_dict[method_name](reference_data, data_copy)
474
475    theory_name = str(current_data.name.split()[0])
476    res_name = reference_data.filename if reference_data.filename else reference_data.name
477    residuals.name = "Residuals for " + str(theory_name) + "[" + res_name + "]"
478    residuals.title = residuals.name
479    residuals.ytransform = 'y'
480
481    # when 2 data have the same id override the 1 st plotted
482    # include the last part if keeping charts for separate models is required
483    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
484    # group_id specify on which panel to plot this data
485    group_id = reference_data.group_id
486    residuals.group_id = "res" + str(group_id)
487
488    # Symbol
489    residuals.symbol = 0
490    residuals.hide_error = False
491
492    return residuals
493
494def binary_encode(i, digits):
495    return [i >> d & 1 for d in range(digits)]
496
497def getWeight(data, is2d, flag=None):
498    """
499    Received flag and compute error on data.
500    :param flag: flag to transform error of data.
501    """
502    weight = None
503    if data is None:
504        return []
505    if is2d:
506        if not hasattr(data, 'err_data'):
507            return []
508        dy_data = data.err_data
509        data = data.data
510    else:
511        if not hasattr(data, 'dy'):
512            return []
513        dy_data = data.dy
514        data = data.y
515
516    if flag == 0:
517        weight = numpy.ones_like(data)
518    elif flag == 1:
519        weight = dy_data
520    elif flag == 2:
521        weight = numpy.sqrt(numpy.abs(data))
522    elif flag == 3:
523        weight = numpy.abs(data)
524    return weight
525
526def updateKernelWithResults(kernel, results):
527    """
528    Takes model kernel and applies results dict to its parameters,
529    returning the modified (deep) copy of the kernel.
530    """
531    assert isinstance(results, dict)
532    local_kernel = copy.deepcopy(kernel)
533
534    for parameter in results.keys():
535        # Update the parameter value - note: this supports +/-inf as well
536        local_kernel.setParam(parameter, results[parameter][0])
537
538    return local_kernel
539
540
541def getStandardParam(model=None):
542    """
543    Returns a list with standard parameters for the current model
544    """
545    param = []
546    num_rows = model.rowCount()
547    if num_rows < 1:
548        return None
549
550    for row in range(num_rows):
551        param_name = model.item(row, 0).text()
552        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
553        value = model.item(row, 1).text()
554        column_shift = 0
555        if model.columnCount() == 5: # no error column
556            error_state = False
557            error_value = 0.0
558        else:
559            error_state = True
560            error_value = model.item(row, 2).text()
561            column_shift = 1
562        min_state = True
563        max_state = True
564        min_value = model.item(row, 2+column_shift).text()
565        max_value = model.item(row, 3+column_shift).text()
566        unit = ""
567        if model.item(row, 4+column_shift) is not None:
568            unit = model.item(row, 4+column_shift).text()
569
570        param.append([checkbox_state, param_name, value, "",
571                        [error_state, error_value],
572                        [min_state, min_value],
573                        [max_state, max_value], unit])
574
575    return param
576
577def getOrientationParam(kernel_module=None):
578    """
579    Get the dictionary with orientation parameters
580    """
581    param = []
582    if kernel_module is None:
583        return None
584    for param_name in list(kernel_module.params.keys()):
585        name = param_name
586        value = kernel_module.params[param_name]
587        min_state = True
588        max_state = True
589        error_state = False
590        error_value = 0.0
591        checkbox_state = True #??
592        details = kernel_module.details[param_name] #[unit, mix, max]
593        param.append([checkbox_state, name, value, "",
594                     [error_state, error_value],
595                     [min_state, details[1]],
596                     [max_state, details[2]], details[0]])
597
598    return param
599
600def formatParameters(parameters):
601    """
602    Prepare the parameter string in the standard SasView layout
603    """
604    assert parameters is not None
605    assert isinstance(parameters, list)
606    output_string = "sasview_parameter_values:"
607    for parameter in parameters:
608        output_string += ",".join([p for p in parameter if p is not None])
609        output_string += ":"
610    return output_string
611
612def formatParametersExcel(parameters):
613    """
614    Prepare the parameter string in the Excel format (tab delimited)
615    """
616    assert parameters is not None
617    assert isinstance(parameters, list)
618    crlf = chr(13) + chr(10)
619    tab = chr(9)
620
621    output_string = ""
622    # names
623    names = ""
624    values = ""
625    for parameter in parameters:
626        names += parameter[0]+tab
627        # Add the error column if fitted
628        if parameter[1] == "True" and parameter[3] is not None:
629            names += parameter[0]+"_err"+tab
630
631        values += parameter[2]+tab
632        if parameter[1] == "True" and parameter[3] is not None:
633            values += parameter[3]+tab
634        # add .npts and .nsigmas when necessary
635        if parameter[0][-6:] == ".width":
636            names += parameter[0].replace('.width', '.nsigmas') + tab
637            names += parameter[0].replace('.width', '.npts') + tab
638            values += parameter[5] + tab + parameter[4] + tab
639
640    output_string = names + crlf + values
641    return output_string
642
643def formatParametersLatex(parameters):
644    """
645    Prepare the parameter string in latex
646    """
647    assert parameters is not None
648    assert isinstance(parameters, list)
649    output_string = r'\begin{table}'
650    output_string += r'\begin{tabular}[h]'
651
652    crlf = chr(13) + chr(10)
653    output_string += '{|'
654    output_string += 'l|l|'*len(parameters)
655    output_string += r'}\hline'
656    output_string += crlf
657
658    for index, parameter in enumerate(parameters):
659        name = parameter[0] # Parameter name
660        output_string += name.replace('_', r'\_')  # Escape underscores
661        # Add the error column if fitted
662        if parameter[1] == "True" and parameter[3] is not None:
663            output_string += ' & '
664            output_string += parameter[0]+r'\_err'
665
666        if index < len(parameters) - 1:
667            output_string += ' & '
668
669        # add .npts and .nsigmas when necessary
670        if parameter[0][-6:] == ".width":
671            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
672            output_string += parameter[0].replace('.width', '.npts')
673
674            if index < len(parameters) - 1:
675                output_string += ' & '
676
677    output_string += r'\\ \hline'
678    output_string += crlf
679
680    # Construct row of values and errors
681    for index, parameter in enumerate(parameters):
682        output_string += parameter[2]
683        if parameter[1] == "True" and parameter[3] is not None:
684            output_string += ' & '
685            output_string += parameter[3]
686
687        if index < len(parameters) - 1:
688            output_string += ' & '
689
690        # add .npts and .nsigmas when necessary
691        if parameter[0][-6:] == ".width":
692            output_string += parameter[5] + ' & '
693            output_string += parameter[4]
694
695            if index < len(parameters) - 1:
696                output_string += ' & '
697
698    output_string += r'\\ \hline'
699    output_string += crlf
700    output_string += r'\end{tabular}'
701    output_string += r'\end{table}'
702
703    return output_string
Note: See TracBrowser for help on using the repository browser.