source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 00b7ddf0

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 00b7ddf0 was 00b7ddf0, checked in by Torin Cooper-Bennun <torin.cooper-bennun@…>, 6 years ago

[CHERRY-PICK FROM 4f80a834f] put sub-heading creation into FittingUtilities?; fix a couple of caused issues; fix affected tests

  • Property mode set to 100644
File size: 22.6 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list([par for par in model.iq_parameters if "[" in par.name])
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69
70    if is2D:
71        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
72    else:
73        params = parameters.iq_parameters
74    item = []
75    for param in params:
76        # don't include shell parameters
77        if param.name == multishell_param_name:
78            continue
79        # Modify parameter name from <param>[n] to <param>1
80        item_name = param.name
81        if param in multishell_parameters:
82            continue
83        #    item_name = replaceShellName(param.name, 1)
84
85        item1 = QtGui.QStandardItem(item_name)
86        item1.setCheckable(True)
87        item1.setEditable(False)
88        # item_err = QtGui.QStandardItem()
89        # check for polydisp params
90        if param.polydisperse:
91            poly_item = QtGui.QStandardItem("Polydispersity")
92            poly_item.setEditable(False)
93            item1_1 = QtGui.QStandardItem("Distribution")
94            item1_1.setEditable(False)
95            # Find param in volume_params
96            for p in parameters.form_volume_parameters:
97                if p.name != param.name:
98                    continue
99                width = kernel_module.getParam(p.name+'.width')
100                ptype = kernel_module.getParam(p.name+'.type')
101
102                item1_2 = QtGui.QStandardItem(str(width))
103                item1_2.setEditable(False)
104                item1_3 = QtGui.QStandardItem()
105                item1_3.setEditable(False)
106                item1_4 = QtGui.QStandardItem()
107                item1_4.setEditable(False)
108                item1_5 = QtGui.QStandardItem(ptype)
109                item1_5.setEditable(False)
110                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
111                break
112            # Add the polydisp item as a child
113            item1.appendRow([poly_item])
114        # Param values
115        item2 = QtGui.QStandardItem(str(param.default))
116        # TODO: the error column.
117        # Either add a proxy model or a custom view delegate
118        #item_err = QtGui.QStandardItem()
119        item3 = QtGui.QStandardItem(str(param.limits[0]))
120        item4 = QtGui.QStandardItem(str(param.limits[1]))
121        item5 = QtGui.QStandardItem(param.units)
122        item5.setEditable(False)
123        item.append([item1, item2, item3, item4, item5])
124    return item
125
126def addSimpleParametersToModel(parameters, is2D):
127    """
128    Update local ModelModel with sasmodel parameters
129    """
130    if is2D:
131        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
132    else:
133        params = parameters.iq_parameters
134    item = []
135    for param in params:
136        # Create the top level, checkable item
137        item_name = param.name
138        item1 = QtGui.QStandardItem(item_name)
139        item1.setCheckable(True)
140        item1.setEditable(False)
141        # Param values
142        # TODO: add delegate for validation of cells
143        item2 = QtGui.QStandardItem(str(param.default))
144        item4 = QtGui.QStandardItem(str(param.limits[0]))
145        item5 = QtGui.QStandardItem(str(param.limits[1]))
146        item6 = QtGui.QStandardItem(param.units)
147        item6.setEditable(False)
148        item.append([item1, item2, item4, item5, item6])
149    return item
150
151def markParameterDisabled(model, row):
152    """Given the QModel row number, format to show it is not available for fitting"""
153
154    # If an error column is present, there are a total of 6 columns.
155    items = [model.item(row, c) for c in range(6)]
156
157    model.blockSignals(True)
158
159    for item in items:
160        if item is None:
161            continue
162        item.setEditable(False)
163        item.setCheckable(False)
164
165    item = items[0]
166
167    font = QtGui.QFont()
168    font.setItalic(True)
169    item.setFont(font)
170    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
171    item.setToolTip("This parameter cannot be fitted.")
172
173    model.blockSignals(False)
174
175def addCheckedListToModel(model, param_list):
176    """
177    Add a QItem to model. Makes the QItem checkable
178    """
179    assert isinstance(model, QtGui.QStandardItemModel)
180    item_list = [QtGui.QStandardItem(item) for item in param_list]
181    item_list[0].setCheckable(True)
182    model.appendRow(item_list)
183
184def addHeadingRowToModel(model, name):
185    """adds a non-interactive top-level row to the model"""
186    header_row = [QtGui.QStandardItem() for i in range(5)]
187    header_row[0].setText(name)
188
189    font = header_row[0].font()
190    font.setBold(True)
191    header_row[0].setFont(font)
192
193    for item in header_row:
194        item.setEditable(False)
195        item.setCheckable(False)
196        item.setSelectable(False)
197
198    model.appendRow(header_row)
199
200def addHeadersToModel(model):
201    """
202    Adds predefined headers to the model
203    """
204    for i, item in enumerate(model_header_captions):
205        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
206
207    model.header_tooltips = copy.copy(model_header_tooltips)
208
209def addErrorHeadersToModel(model):
210    """
211    Adds predefined headers to the model
212    """
213    model_header_error_captions = copy.copy(model_header_captions)
214    model_header_error_captions.insert(2, header_error_caption)
215    for i, item in enumerate(model_header_error_captions):
216        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
217
218    model_header_error_tooltips = copy.copy(model_header_tooltips)
219    model_header_error_tooltips.insert(2, error_tooltip)
220    model.header_tooltips = copy.copy(model_header_error_tooltips)
221
222def addPolyHeadersToModel(model):
223    """
224    Adds predefined headers to the model
225    """
226    for i, item in enumerate(poly_header_captions):
227        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
228
229    model.header_tooltips = copy.copy(poly_header_tooltips)
230
231
232def addErrorPolyHeadersToModel(model):
233    """
234    Adds predefined headers to the model
235    """
236    poly_header_error_captions = copy.copy(poly_header_captions)
237    poly_header_error_captions.insert(2, header_error_caption)
238    for i, item in enumerate(poly_header_error_captions):
239        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
240
241    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
242    poly_header_error_tooltips.insert(2, error_tooltip)
243    model.header_tooltips = copy.copy(poly_header_error_tooltips)
244
245def addShellsToModel(parameters, model, index):
246    """
247    Find out multishell parameters and update the model with the requested number of them
248    """
249    multishell_parameters = getIterParams(parameters)
250
251    for i in range(index):
252        for par in multishell_parameters:
253            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
254            param_name = replaceShellName(par.name, i+1)
255            item1 = QtGui.QStandardItem(param_name)
256            item1.setCheckable(True)
257            # check for polydisp params
258            if par.polydisperse:
259                poly_item = QtGui.QStandardItem("Polydispersity")
260                item1_1 = QtGui.QStandardItem("Distribution")
261                # Find param in volume_params
262                for p in parameters.form_volume_parameters:
263                    if p.name != par.name:
264                        continue
265                    item1_2 = QtGui.QStandardItem(str(p.default))
266                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
267                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
268                    item1_5 = QtGui.QStandardItem(p.units)
269                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
270                    break
271                item1.appendRow([poly_item])
272
273            item2 = QtGui.QStandardItem(str(par.default))
274            item3 = QtGui.QStandardItem(str(par.limits[0]))
275            item4 = QtGui.QStandardItem(str(par.limits[1]))
276            item5 = QtGui.QStandardItem(par.units)
277            model.appendRow([item1, item2, item3, item4, item5])
278
279def calculateChi2(reference_data, current_data):
280    """
281    Calculate Chi2 value between two sets of data
282    """
283    if reference_data is None or current_data is None:
284        return None
285    # WEIGHING INPUT
286    #from sas.sasgui.perspectives.fitting.utils import get_weight
287    #flag = self.get_weight_flag()
288    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
289    chisqr = None
290    if reference_data is None:
291        return chisqr
292
293    # temporary default values for index and weight
294    index = None
295    weight = None
296
297    # Get data: data I, theory I, and data dI in order
298    if isinstance(reference_data, Data2D):
299        if index is None:
300            index = numpy.ones(len(current_data.data), dtype=bool)
301        if weight is not None:
302            current_data.err_data = weight
303        # get rid of zero error points
304        index = index & (current_data.err_data != 0)
305        index = index & (numpy.isfinite(current_data.data))
306        fn = current_data.data[index]
307        gn = reference_data.data[index]
308        en = current_data.err_data[index]
309    else:
310        # 1 d theory from model_thread is only in the range of index
311        if index is None:
312            index = numpy.ones(len(current_data.y), dtype=bool)
313        if weight is not None:
314            current_data.dy = weight
315        if current_data.dy is None or current_data.dy == []:
316            dy = numpy.ones(len(current_data.y))
317        else:
318            ## Set consistently w/AbstractFitengine:
319            # But this should be corrected later.
320            dy = copy.deepcopy(current_data.dy)
321            dy[dy == 0] = 1
322        fn = current_data.y[index]
323        gn = reference_data.y
324        en = dy[index]
325    # Calculate the residual
326    try:
327        res = (fn - gn) / en
328    except ValueError:
329        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
330        return None
331
332    residuals = res[numpy.isfinite(res)]
333    chisqr = numpy.average(residuals * residuals)
334
335    return chisqr
336
337def residualsData1D(reference_data, current_data):
338    """
339    Calculate the residuals for difference of two Data1D sets
340    """
341    # temporary default values for index and weight
342    index = None
343    weight = None
344
345    # 1d theory from model_thread is only in the range of index
346    if current_data.dy is None or current_data.dy == []:
347        dy = numpy.ones(len(current_data.y))
348    else:
349        dy = weight if weight is not None else numpy.ones(len(current_data.y))
350        dy[dy == 0] = 1
351    fn = current_data.y[index][0]
352    gn = reference_data.y
353    en = dy[index][0]
354
355    # x values
356    x_current = current_data.x
357    x_reference = reference_data.x
358
359    # build residuals
360    residuals = Data1D()
361    if len(fn) == len(gn):
362        y = (fn - gn)/en
363        residuals.y = -y
364    elif len(fn) > len(gn):
365        residuals.y = (fn - gn[1:len(fn)])/en
366    else:
367        try:
368            y = numpy.zeros(len(current_data.y))
369            begin = 0
370            for i, x_value in enumerate(x_reference):
371                if x_value in x_current:
372                    begin = i
373                    break
374            end = len(x_reference)
375            endl = 0
376            for i, x_value in enumerate(list(x_reference)[::-1]):
377                if x_value in x_current:
378                    endl = i
379                    break
380            # make sure we have correct lengths
381            assert len(x_current) == len(x_reference[begin:end-endl])
382
383            y = (fn - gn[begin:end-endl])/en
384            residuals.y = y
385        except ValueError:
386            # value errors may show up every once in a while for malformed columns,
387            # just reuse what's there already
388            pass
389
390    residuals.x = current_data.x[index][0]
391    residuals.dy = numpy.ones(len(residuals.y))
392    residuals.dx = None
393    residuals.dxl = None
394    residuals.dxw = None
395    residuals.ytransform = 'y'
396    # For latter scale changes
397    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
398    residuals.yaxis('\\rm{Residuals} ', 'normalized')
399
400    return residuals
401
402def residualsData2D(reference_data, current_data):
403    """
404    Calculate the residuals for difference of two Data2D sets
405    """
406    # temporary default values for index and weight
407    # index = None
408    weight = None
409
410    # build residuals
411    residuals = Data2D()
412    # Not for trunk the line below, instead use the line above
413    current_data.clone_without_data(len(current_data.data), residuals)
414    residuals.data = None
415    fn = current_data.data
416    gn = reference_data.data
417    en = current_data.err_data if weight is None else weight
418    residuals.data = (fn - gn) / en
419    residuals.qx_data = current_data.qx_data
420    residuals.qy_data = current_data.qy_data
421    residuals.q_data = current_data.q_data
422    residuals.err_data = numpy.ones(len(residuals.data))
423    residuals.xmin = min(residuals.qx_data)
424    residuals.xmax = max(residuals.qx_data)
425    residuals.ymin = min(residuals.qy_data)
426    residuals.ymax = max(residuals.qy_data)
427    residuals.q_data = current_data.q_data
428    residuals.mask = current_data.mask
429    residuals.scale = 'linear'
430    # check the lengths
431    if len(residuals.data) != len(residuals.q_data):
432        return None
433    return residuals
434
435def plotResiduals(reference_data, current_data):
436    """
437    Create Data1D/Data2D with residuals, ready for plotting
438    """
439    data_copy = copy.deepcopy(current_data)
440    # Get data: data I, theory I, and data dI in order
441    method_name = current_data.__class__.__name__
442    residuals_dict = {"Data1D": residualsData1D,
443                      "Data2D": residualsData2D}
444
445    residuals = residuals_dict[method_name](reference_data, data_copy)
446
447    theory_name = str(current_data.name.split()[0])
448    residuals.name = "Residuals for " + str(theory_name) + "[" + \
449                    str(reference_data.filename) + "]"
450    residuals.title = residuals.name
451    residuals.ytransform = 'y'
452
453    # when 2 data have the same id override the 1 st plotted
454    # include the last part if keeping charts for separate models is required
455    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
456    # group_id specify on which panel to plot this data
457    group_id = reference_data.group_id
458    residuals.group_id = "res" + str(group_id)
459
460    # Symbol
461    residuals.symbol = 0
462    residuals.hide_error = False
463
464    return residuals
465
466def binary_encode(i, digits):
467    return [i >> d & 1 for d in range(digits)]
468
469def getWeight(data, is2d, flag=None):
470    """
471    Received flag and compute error on data.
472    :param flag: flag to transform error of data.
473    """
474    weight = None
475    if is2d:
476        dy_data = data.err_data
477        data = data.data
478    else:
479        dy_data = data.dy
480        data = data.y
481
482    if flag == 0:
483        weight = numpy.ones_like(data)
484    elif flag == 1:
485        weight = dy_data
486    elif flag == 2:
487        weight = numpy.sqrt(numpy.abs(data))
488    elif flag == 3:
489        weight = numpy.abs(data)
490    return weight
491
492def updateKernelWithResults(kernel, results):
493    """
494    Takes model kernel and applies results dict to its parameters,
495    returning the modified (deep) copy of the kernel.
496    """
497    assert isinstance(results, dict)
498    local_kernel = copy.deepcopy(kernel)
499
500    for parameter in results.keys():
501        # Update the parameter value - note: this supports +/-inf as well
502        local_kernel.setParam(parameter, results[parameter][0])
503
504    return local_kernel
505
506
507def getStandardParam(model=None):
508    """
509    Returns a list with standard parameters for the current model
510    """
511    param = []
512    num_rows = model.rowCount()
513    if num_rows < 1:
514        return None
515
516    for row in range(num_rows):
517        param_name = model.item(row, 0).text()
518        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
519        value = model.item(row, 1).text()
520        column_shift = 0
521        if model.columnCount() == 5: # no error column
522            error_state = False
523            error_value = 0.0
524        else:
525            error_state = True
526            error_value = model.item(row, 2).text()
527            column_shift = 1
528        min_state = True
529        max_state = True
530        min_value = model.item(row, 2+column_shift).text()
531        max_value = model.item(row, 3+column_shift).text()
532        unit = ""
533        if model.item(row, 4+column_shift) is not None:
534            unit = model.item(row, 4+column_shift).text()
535
536        param.append([checkbox_state, param_name, value, "",
537                        [error_state, error_value],
538                        [min_state, min_value],
539                        [max_state, max_value], unit])
540
541    return param
542
543def getOrientationParam(kernel_module=None):
544    """
545    Get the dictionary with orientation parameters
546    """
547    param = []
548    if kernel_module is None:
549        return None
550    for param_name in list(kernel_module.params.keys()):
551        name = param_name
552        value = kernel_module.params[param_name]
553        min_state = True
554        max_state = True
555        error_state = False
556        error_value = 0.0
557        checkbox_state = True #??
558        details = kernel_module.details[param_name] #[unit, mix, max]
559        param.append([checkbox_state, name, value, "",
560                     [error_state, error_value],
561                     [min_state, details[1]],
562                     [max_state, details[2]], details[0]])
563
564    return param
565
566def formatParameters(parameters):
567    """
568    Prepare the parameter string in the standard SasView layout
569    """
570    assert parameters is not None
571    assert isinstance(parameters, list)
572    output_string = "sasview_parameter_values:"
573    for parameter in parameters:
574        output_string += ",".join([p for p in parameter if p is not None])
575        output_string += ":"
576    return output_string
577
578def formatParametersExcel(parameters):
579    """
580    Prepare the parameter string in the Excel format (tab delimited)
581    """
582    assert parameters is not None
583    assert isinstance(parameters, list)
584    crlf = chr(13) + chr(10)
585    tab = chr(9)
586
587    output_string = ""
588    # names
589    names = ""
590    values = ""
591    for parameter in parameters:
592        names += parameter[0]+tab
593        # Add the error column if fitted
594        if parameter[1] == "True" and parameter[3] is not None:
595            names += parameter[0]+"_err"+tab
596
597        values += parameter[2]+tab
598        if parameter[1] == "True" and parameter[3] is not None:
599            values += parameter[3]+tab
600        # add .npts and .nsigmas when necessary
601        if parameter[0][-6:] == ".width":
602            names += parameter[0].replace('.width', '.nsigmas') + tab
603            names += parameter[0].replace('.width', '.npts') + tab
604            values += parameter[5] + tab + parameter[4] + tab
605
606    output_string = names + crlf + values
607    return output_string
608
609def formatParametersLatex(parameters):
610    """
611    Prepare the parameter string in latex
612    """
613    assert parameters is not None
614    assert isinstance(parameters, list)
615    output_string = r'\begin{table}'
616    output_string += r'\begin{tabular}[h]'
617
618    crlf = chr(13) + chr(10)
619    output_string += '{|'
620    output_string += 'l|l|'*len(parameters)
621    output_string += r'}\hline'
622    output_string += crlf
623
624    for index, parameter in enumerate(parameters):
625        name = parameter[0] # Parameter name
626        output_string += name.replace('_', r'\_')  # Escape underscores
627        # Add the error column if fitted
628        if parameter[1] == "True" and parameter[3] is not None:
629            output_string += ' & '
630            output_string += parameter[0]+r'\_err'
631
632        if index < len(parameters) - 1:
633            output_string += ' & '
634
635        # add .npts and .nsigmas when necessary
636        if parameter[0][-6:] == ".width":
637            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
638            output_string += parameter[0].replace('.width', '.npts')
639
640            if index < len(parameters) - 1:
641                output_string += ' & '
642
643    output_string += r'\\ \hline'
644    output_string += crlf
645
646    # Construct row of values and errors
647    for index, parameter in enumerate(parameters):
648        output_string += parameter[2]
649        if parameter[1] == "True" and parameter[3] is not None:
650            output_string += ' & '
651            output_string += parameter[3]
652
653        if index < len(parameters) - 1:
654            output_string += ' & '
655
656        # add .npts and .nsigmas when necessary
657        if parameter[0][-6:] == ".width":
658            output_string += parameter[5] + ' & '
659            output_string += parameter[4]
660
661            if index < len(parameters) - 1:
662                output_string += ' & '
663
664    output_string += r'\\ \hline'
665    output_string += crlf
666    output_string += r'\end{tabular}'
667    output_string += r'\end{table}'
668
669    return output_string
Note: See TracBrowser for help on using the repository browser.