source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 70f4458

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 70f4458 was 70f4458, checked in by Torin Cooper-Bennun <torin.cooper-bennun@…>, 6 years ago

shell parameters appear in P(Q) section correctly

  • Property mode set to 100644
File size: 23.8 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list([par for par in model.iq_parameters if "[" in par.name])
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69
70    if is2D:
71        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
72    else:
73        params = parameters.iq_parameters
74    item = []
75    for param in params:
76        # don't include shell parameters
77        if param.name == multishell_param_name:
78            continue
79        # Modify parameter name from <param>[n] to <param>1
80        item_name = param.name
81        if param in multishell_parameters:
82            continue
83        #    item_name = replaceShellName(param.name, 1)
84
85        item1 = QtGui.QStandardItem(item_name)
86        item1.setCheckable(True)
87        item1.setEditable(False)
88        # item_err = QtGui.QStandardItem()
89        # check for polydisp params
90        if param.polydisperse:
91            poly_item = QtGui.QStandardItem("Polydispersity")
92            poly_item.setEditable(False)
93            item1_1 = QtGui.QStandardItem("Distribution")
94            item1_1.setEditable(False)
95            # Find param in volume_params
96            for p in parameters.form_volume_parameters:
97                if p.name != param.name:
98                    continue
99                width = kernel_module.getParam(p.name+'.width')
100                ptype = kernel_module.getParam(p.name+'.type')
101
102                item1_2 = QtGui.QStandardItem(str(width))
103                item1_2.setEditable(False)
104                item1_3 = QtGui.QStandardItem()
105                item1_3.setEditable(False)
106                item1_4 = QtGui.QStandardItem()
107                item1_4.setEditable(False)
108                item1_5 = QtGui.QStandardItem(ptype)
109                item1_5.setEditable(False)
110                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
111                break
112            # Add the polydisp item as a child
113            item1.appendRow([poly_item])
114        # Param values
115        item2 = QtGui.QStandardItem(str(param.default))
116        # TODO: the error column.
117        # Either add a proxy model or a custom view delegate
118        #item_err = QtGui.QStandardItem()
119        item3 = QtGui.QStandardItem(str(param.limits[0]))
120        item4 = QtGui.QStandardItem(str(param.limits[1]))
121        item5 = QtGui.QStandardItem(param.units)
122        item5.setEditable(False)
123        item.append([item1, item2, item3, item4, item5])
124    return item
125
126def addSimpleParametersToModel(parameters, is2D, parameters_original=None):
127    """
128    Update local ModelModel with sasmodel parameters
129    parameters_original: list of parameters before any tagging on their IDs, e.g. for product model
130    (so that those are the display names; see below)
131    """
132    if is2D:
133        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
134    else:
135        params = parameters.iq_parameters
136
137    if parameters_original:
138        # 'parameters_original' contains the parameters as they are to be DISPLAYED, while 'parameters'
139        # contains the parameters as they were renamed; this is for handling name collisions in product model.
140        # The 'real name' of the parameter will be stored in the item's user data.
141        if is2D:
142            params_orig = [p for p in parameters_original.kernel_parameters if p.type != 'magnetic']
143        else:
144            params_orig = parameters_original.iq_parameters
145    else:
146        # no difference in names anyway
147        params_orig = params
148
149    item = []
150    for param, param_orig in zip(params, params_orig):
151        # Create the top level, checkable item
152        item_name = param_orig.name
153        item1 = QtGui.QStandardItem(item_name)
154        item1.setData(param.name, QtCore.Qt.UserRole)
155        item1.setCheckable(True)
156        item1.setEditable(False)
157        # Param values
158        # TODO: add delegate for validation of cells
159        item2 = QtGui.QStandardItem(str(param.default))
160        item4 = QtGui.QStandardItem(str(param.limits[0]))
161        item5 = QtGui.QStandardItem(str(param.limits[1]))
162        item6 = QtGui.QStandardItem(str(param.units))
163        item6.setEditable(False)
164        item.append([item1, item2, item4, item5, item6])
165    return item
166
167def markParameterDisabled(model, row):
168    """Given the QModel row number, format to show it is not available for fitting"""
169
170    # If an error column is present, there are a total of 6 columns.
171    items = [model.item(row, c) for c in range(6)]
172
173    model.blockSignals(True)
174
175    for item in items:
176        if item is None:
177            continue
178        item.setEditable(False)
179        item.setCheckable(False)
180
181    item = items[0]
182
183    font = QtGui.QFont()
184    font.setItalic(True)
185    item.setFont(font)
186    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
187    item.setToolTip("This parameter cannot be fitted.")
188
189    model.blockSignals(False)
190
191def addCheckedListToModel(model, param_list):
192    """
193    Add a QItem to model. Makes the QItem checkable
194    """
195    assert isinstance(model, QtGui.QStandardItemModel)
196    item_list = [QtGui.QStandardItem(item) for item in param_list]
197    item_list[0].setCheckable(True)
198    model.appendRow(item_list)
199
200def addHeadingRowToModel(model, name):
201    """adds a non-interactive top-level row to the model"""
202    header_row = [QtGui.QStandardItem() for i in range(5)]
203    header_row[0].setText(name)
204
205    font = header_row[0].font()
206    font.setBold(True)
207    header_row[0].setFont(font)
208
209    for item in header_row:
210        item.setEditable(False)
211        item.setCheckable(False)
212        item.setSelectable(False)
213
214    model.appendRow(header_row)
215
216def addHeadersToModel(model):
217    """
218    Adds predefined headers to the model
219    """
220    for i, item in enumerate(model_header_captions):
221        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
222
223    model.header_tooltips = copy.copy(model_header_tooltips)
224
225def addErrorHeadersToModel(model):
226    """
227    Adds predefined headers to the model
228    """
229    model_header_error_captions = copy.copy(model_header_captions)
230    model_header_error_captions.insert(2, header_error_caption)
231    for i, item in enumerate(model_header_error_captions):
232        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
233
234    model_header_error_tooltips = copy.copy(model_header_tooltips)
235    model_header_error_tooltips.insert(2, error_tooltip)
236    model.header_tooltips = copy.copy(model_header_error_tooltips)
237
238def addPolyHeadersToModel(model):
239    """
240    Adds predefined headers to the model
241    """
242    for i, item in enumerate(poly_header_captions):
243        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
244
245    model.header_tooltips = copy.copy(poly_header_tooltips)
246
247
248def addErrorPolyHeadersToModel(model):
249    """
250    Adds predefined headers to the model
251    """
252    poly_header_error_captions = copy.copy(poly_header_captions)
253    poly_header_error_captions.insert(2, header_error_caption)
254    for i, item in enumerate(poly_header_error_captions):
255        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
256
257    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
258    poly_header_error_tooltips.insert(2, error_tooltip)
259    model.header_tooltips = copy.copy(poly_header_error_tooltips)
260
261def addShellsToModel(parameters, model, index, row_num=None):
262    """
263    Find out multishell parameters and update the model with the requested number of them.
264    Inserts them after the row at row_num, if not None; otherwise, appends to end.
265    Returns a list of lists of QStandardItem objects.
266    """
267    multishell_parameters = getIterParams(parameters)
268
269    rows = []
270    for i in range(index):
271        for par in multishell_parameters:
272            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
273            param_name = replaceShellName(par.name, i+1)
274            item1 = QtGui.QStandardItem(param_name)
275            item1.setCheckable(True)
276            # check for polydisp params
277            if par.polydisperse:
278                poly_item = QtGui.QStandardItem("Polydispersity")
279                item1_1 = QtGui.QStandardItem("Distribution")
280                # Find param in volume_params
281                for p in parameters.form_volume_parameters:
282                    if p.name != par.name:
283                        continue
284                    item1_2 = QtGui.QStandardItem(str(p.default))
285                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
286                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
287                    item1_5 = QtGui.QStandardItem(p.units)
288                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
289                    break
290                item1.appendRow([poly_item])
291
292            item2 = QtGui.QStandardItem(str(par.default))
293            item3 = QtGui.QStandardItem(str(par.limits[0]))
294            item4 = QtGui.QStandardItem(str(par.limits[1]))
295            item5 = QtGui.QStandardItem(par.units)
296            row = [item1, item2, item3, item4, item5]
297            rows.append(row)
298
299            if row_num is None:
300                model.appendRow(row)
301            else:
302                model.insertRow(row_num, row)
303                row_num += 1
304
305    return rows
306
307def calculateChi2(reference_data, current_data):
308    """
309    Calculate Chi2 value between two sets of data
310    """
311    if reference_data is None or current_data is None:
312        return None
313    # WEIGHING INPUT
314    #from sas.sasgui.perspectives.fitting.utils import get_weight
315    #flag = self.get_weight_flag()
316    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
317    chisqr = None
318    if reference_data is None:
319        return chisqr
320
321    # temporary default values for index and weight
322    index = None
323    weight = None
324
325    # Get data: data I, theory I, and data dI in order
326    if isinstance(reference_data, Data2D):
327        if index is None:
328            index = numpy.ones(len(current_data.data), dtype=bool)
329        if weight is not None:
330            current_data.err_data = weight
331        # get rid of zero error points
332        index = index & (current_data.err_data != 0)
333        index = index & (numpy.isfinite(current_data.data))
334        fn = current_data.data[index]
335        gn = reference_data.data[index]
336        en = current_data.err_data[index]
337    else:
338        # 1 d theory from model_thread is only in the range of index
339        if index is None:
340            index = numpy.ones(len(current_data.y), dtype=bool)
341        if weight is not None:
342            current_data.dy = weight
343        if current_data.dy is None or current_data.dy == []:
344            dy = numpy.ones(len(current_data.y))
345        else:
346            ## Set consistently w/AbstractFitengine:
347            # But this should be corrected later.
348            dy = copy.deepcopy(current_data.dy)
349            dy[dy == 0] = 1
350        fn = current_data.y[index]
351        gn = reference_data.y
352        en = dy[index]
353    # Calculate the residual
354    try:
355        res = (fn - gn) / en
356    except ValueError:
357        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
358        return None
359
360    residuals = res[numpy.isfinite(res)]
361    chisqr = numpy.average(residuals * residuals)
362
363    return chisqr
364
365def residualsData1D(reference_data, current_data):
366    """
367    Calculate the residuals for difference of two Data1D sets
368    """
369    # temporary default values for index and weight
370    index = None
371    weight = None
372
373    # 1d theory from model_thread is only in the range of index
374    if current_data.dy is None or current_data.dy == []:
375        dy = numpy.ones(len(current_data.y))
376    else:
377        dy = weight if weight is not None else numpy.ones(len(current_data.y))
378        dy[dy == 0] = 1
379    fn = current_data.y[index][0]
380    gn = reference_data.y
381    en = dy[index][0]
382
383    # x values
384    x_current = current_data.x
385    x_reference = reference_data.x
386
387    # build residuals
388    residuals = Data1D()
389    if len(fn) == len(gn):
390        y = (fn - gn)/en
391        residuals.y = -y
392    elif len(fn) > len(gn):
393        residuals.y = (fn - gn[1:len(fn)])/en
394    else:
395        try:
396            y = numpy.zeros(len(current_data.y))
397            begin = 0
398            for i, x_value in enumerate(x_reference):
399                if x_value in x_current:
400                    begin = i
401                    break
402            end = len(x_reference)
403            endl = 0
404            for i, x_value in enumerate(list(x_reference)[::-1]):
405                if x_value in x_current:
406                    endl = i
407                    break
408            # make sure we have correct lengths
409            assert len(x_current) == len(x_reference[begin:end-endl])
410
411            y = (fn - gn[begin:end-endl])/en
412            residuals.y = y
413        except ValueError:
414            # value errors may show up every once in a while for malformed columns,
415            # just reuse what's there already
416            pass
417
418    residuals.x = current_data.x[index][0]
419    residuals.dy = numpy.ones(len(residuals.y))
420    residuals.dx = None
421    residuals.dxl = None
422    residuals.dxw = None
423    residuals.ytransform = 'y'
424    # For latter scale changes
425    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
426    residuals.yaxis('\\rm{Residuals} ', 'normalized')
427
428    return residuals
429
430def residualsData2D(reference_data, current_data):
431    """
432    Calculate the residuals for difference of two Data2D sets
433    """
434    # temporary default values for index and weight
435    # index = None
436    weight = None
437
438    # build residuals
439    residuals = Data2D()
440    # Not for trunk the line below, instead use the line above
441    current_data.clone_without_data(len(current_data.data), residuals)
442    residuals.data = None
443    fn = current_data.data
444    gn = reference_data.data
445    en = current_data.err_data if weight is None else weight
446    residuals.data = (fn - gn) / en
447    residuals.qx_data = current_data.qx_data
448    residuals.qy_data = current_data.qy_data
449    residuals.q_data = current_data.q_data
450    residuals.err_data = numpy.ones(len(residuals.data))
451    residuals.xmin = min(residuals.qx_data)
452    residuals.xmax = max(residuals.qx_data)
453    residuals.ymin = min(residuals.qy_data)
454    residuals.ymax = max(residuals.qy_data)
455    residuals.q_data = current_data.q_data
456    residuals.mask = current_data.mask
457    residuals.scale = 'linear'
458    # check the lengths
459    if len(residuals.data) != len(residuals.q_data):
460        return None
461    return residuals
462
463def plotResiduals(reference_data, current_data):
464    """
465    Create Data1D/Data2D with residuals, ready for plotting
466    """
467    data_copy = copy.deepcopy(current_data)
468    # Get data: data I, theory I, and data dI in order
469    method_name = current_data.__class__.__name__
470    residuals_dict = {"Data1D": residualsData1D,
471                      "Data2D": residualsData2D}
472
473    residuals = residuals_dict[method_name](reference_data, data_copy)
474
475    theory_name = str(current_data.name.split()[0])
476    residuals.name = "Residuals for " + str(theory_name) + "[" + \
477                    str(reference_data.filename) + "]"
478    residuals.title = residuals.name
479    residuals.ytransform = 'y'
480
481    # when 2 data have the same id override the 1 st plotted
482    # include the last part if keeping charts for separate models is required
483    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
484    # group_id specify on which panel to plot this data
485    group_id = reference_data.group_id
486    residuals.group_id = "res" + str(group_id)
487
488    # Symbol
489    residuals.symbol = 0
490    residuals.hide_error = False
491
492    return residuals
493
494def binary_encode(i, digits):
495    return [i >> d & 1 for d in range(digits)]
496
497def getWeight(data, is2d, flag=None):
498    """
499    Received flag and compute error on data.
500    :param flag: flag to transform error of data.
501    """
502    weight = None
503    if is2d:
504        dy_data = data.err_data
505        data = data.data
506    else:
507        dy_data = data.dy
508        data = data.y
509
510    if flag == 0:
511        weight = numpy.ones_like(data)
512    elif flag == 1:
513        weight = dy_data
514    elif flag == 2:
515        weight = numpy.sqrt(numpy.abs(data))
516    elif flag == 3:
517        weight = numpy.abs(data)
518    return weight
519
520def updateKernelWithResults(kernel, results):
521    """
522    Takes model kernel and applies results dict to its parameters,
523    returning the modified (deep) copy of the kernel.
524    """
525    assert isinstance(results, dict)
526    local_kernel = copy.deepcopy(kernel)
527
528    for parameter in results.keys():
529        # Update the parameter value - note: this supports +/-inf as well
530        local_kernel.setParam(parameter, results[parameter][0])
531
532    return local_kernel
533
534
535def getStandardParam(model=None):
536    """
537    Returns a list with standard parameters for the current model
538    """
539    param = []
540    num_rows = model.rowCount()
541    if num_rows < 1:
542        return None
543
544    for row in range(num_rows):
545        param_name = model.item(row, 0).text()
546        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
547        value = model.item(row, 1).text()
548        column_shift = 0
549        if model.columnCount() == 5: # no error column
550            error_state = False
551            error_value = 0.0
552        else:
553            error_state = True
554            error_value = model.item(row, 2).text()
555            column_shift = 1
556        min_state = True
557        max_state = True
558        min_value = model.item(row, 2+column_shift).text()
559        max_value = model.item(row, 3+column_shift).text()
560        unit = ""
561        if model.item(row, 4+column_shift) is not None:
562            unit = model.item(row, 4+column_shift).text()
563
564        param.append([checkbox_state, param_name, value, "",
565                        [error_state, error_value],
566                        [min_state, min_value],
567                        [max_state, max_value], unit])
568
569    return param
570
571def getOrientationParam(kernel_module=None):
572    """
573    Get the dictionary with orientation parameters
574    """
575    param = []
576    if kernel_module is None:
577        return None
578    for param_name in list(kernel_module.params.keys()):
579        name = param_name
580        value = kernel_module.params[param_name]
581        min_state = True
582        max_state = True
583        error_state = False
584        error_value = 0.0
585        checkbox_state = True #??
586        details = kernel_module.details[param_name] #[unit, mix, max]
587        param.append([checkbox_state, name, value, "",
588                     [error_state, error_value],
589                     [min_state, details[1]],
590                     [max_state, details[2]], details[0]])
591
592    return param
593
594def formatParameters(parameters):
595    """
596    Prepare the parameter string in the standard SasView layout
597    """
598    assert parameters is not None
599    assert isinstance(parameters, list)
600    output_string = "sasview_parameter_values:"
601    for parameter in parameters:
602        output_string += ",".join([p for p in parameter if p is not None])
603        output_string += ":"
604    return output_string
605
606def formatParametersExcel(parameters):
607    """
608    Prepare the parameter string in the Excel format (tab delimited)
609    """
610    assert parameters is not None
611    assert isinstance(parameters, list)
612    crlf = chr(13) + chr(10)
613    tab = chr(9)
614
615    output_string = ""
616    # names
617    names = ""
618    values = ""
619    for parameter in parameters:
620        names += parameter[0]+tab
621        # Add the error column if fitted
622        if parameter[1] == "True" and parameter[3] is not None:
623            names += parameter[0]+"_err"+tab
624
625        values += parameter[2]+tab
626        if parameter[1] == "True" and parameter[3] is not None:
627            values += parameter[3]+tab
628        # add .npts and .nsigmas when necessary
629        if parameter[0][-6:] == ".width":
630            names += parameter[0].replace('.width', '.nsigmas') + tab
631            names += parameter[0].replace('.width', '.npts') + tab
632            values += parameter[5] + tab + parameter[4] + tab
633
634    output_string = names + crlf + values
635    return output_string
636
637def formatParametersLatex(parameters):
638    """
639    Prepare the parameter string in latex
640    """
641    assert parameters is not None
642    assert isinstance(parameters, list)
643    output_string = r'\begin{table}'
644    output_string += r'\begin{tabular}[h]'
645
646    crlf = chr(13) + chr(10)
647    output_string += '{|'
648    output_string += 'l|l|'*len(parameters)
649    output_string += r'}\hline'
650    output_string += crlf
651
652    for index, parameter in enumerate(parameters):
653        name = parameter[0] # Parameter name
654        output_string += name.replace('_', r'\_')  # Escape underscores
655        # Add the error column if fitted
656        if parameter[1] == "True" and parameter[3] is not None:
657            output_string += ' & '
658            output_string += parameter[0]+r'\_err'
659
660        if index < len(parameters) - 1:
661            output_string += ' & '
662
663        # add .npts and .nsigmas when necessary
664        if parameter[0][-6:] == ".width":
665            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
666            output_string += parameter[0].replace('.width', '.npts')
667
668            if index < len(parameters) - 1:
669                output_string += ' & '
670
671    output_string += r'\\ \hline'
672    output_string += crlf
673
674    # Construct row of values and errors
675    for index, parameter in enumerate(parameters):
676        output_string += parameter[2]
677        if parameter[1] == "True" and parameter[3] is not None:
678            output_string += ' & '
679            output_string += parameter[3]
680
681        if index < len(parameters) - 1:
682            output_string += ' & '
683
684        # add .npts and .nsigmas when necessary
685        if parameter[0][-6:] == ".width":
686            output_string += parameter[5] + ' & '
687            output_string += parameter[4]
688
689            if index < len(parameters) - 1:
690                output_string += ' & '
691
692    output_string += r'\\ \hline'
693    output_string += crlf
694    output_string += r'\end{tabular}'
695    output_string += r'\end{table}'
696
697    return output_string
Note: See TracBrowser for help on using the repository browser.