source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ f3a19ad

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since f3a19ad was c7809a9, checked in by Torin Cooper-Bennun <torin.cooper-bennun@…>, 6 years ago

[CHERRY-PICK FROM 099369c03] fix mishaps when still using beta_approx / beta_approx_lazy_results branches of sasmodels

  • Property mode set to 100644
File size: 23.4 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list([par for par in model.iq_parameters if "[" in par.name])
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69
70    if is2D:
71        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
72    else:
73        params = parameters.iq_parameters
74    item = []
75    for param in params:
76        # don't include shell parameters
77        if param.name == multishell_param_name:
78            continue
79        # Modify parameter name from <param>[n] to <param>1
80        item_name = param.name
81        if param in multishell_parameters:
82            continue
83        #    item_name = replaceShellName(param.name, 1)
84
85        item1 = QtGui.QStandardItem(item_name)
86        item1.setCheckable(True)
87        item1.setEditable(False)
88        # item_err = QtGui.QStandardItem()
89        # check for polydisp params
90        if param.polydisperse:
91            poly_item = QtGui.QStandardItem("Polydispersity")
92            poly_item.setEditable(False)
93            item1_1 = QtGui.QStandardItem("Distribution")
94            item1_1.setEditable(False)
95            # Find param in volume_params
96            for p in parameters.form_volume_parameters:
97                if p.name != param.name:
98                    continue
99                width = kernel_module.getParam(p.name+'.width')
100                ptype = kernel_module.getParam(p.name+'.type')
101
102                item1_2 = QtGui.QStandardItem(str(width))
103                item1_2.setEditable(False)
104                item1_3 = QtGui.QStandardItem()
105                item1_3.setEditable(False)
106                item1_4 = QtGui.QStandardItem()
107                item1_4.setEditable(False)
108                item1_5 = QtGui.QStandardItem(ptype)
109                item1_5.setEditable(False)
110                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
111                break
112            # Add the polydisp item as a child
113            item1.appendRow([poly_item])
114        # Param values
115        item2 = QtGui.QStandardItem(str(param.default))
116        # TODO: the error column.
117        # Either add a proxy model or a custom view delegate
118        #item_err = QtGui.QStandardItem()
119        item3 = QtGui.QStandardItem(str(param.limits[0]))
120        item4 = QtGui.QStandardItem(str(param.limits[1]))
121        item5 = QtGui.QStandardItem(param.units)
122        item5.setEditable(False)
123        item.append([item1, item2, item3, item4, item5])
124    return item
125
126def addSimpleParametersToModel(parameters, is2D, parameters_original=None):
127    """
128    Update local ModelModel with sasmodel parameters
129    parameters_original: list of parameters before any tagging on their IDs, e.g. for product model
130    (so that those are the display names; see below)
131    """
132    if is2D:
133        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
134    else:
135        params = parameters.iq_parameters
136
137    if parameters_original:
138        # 'parameters_original' contains the parameters as they are to be DISPLAYED, while 'parameters'
139        # contains the parameters as they were renamed; this is for handling name collisions in product model.
140        # The 'real name' of the parameter will be stored in the item's user data.
141        if is2D:
142            params_orig = [p for p in parameters_original.kernel_parameters if p.type != 'magnetic']
143        else:
144            params_orig = parameters_original.iq_parameters
145    else:
146        # no difference in names anyway
147        params_orig = params
148
149    item = []
150    for param, param_orig in zip(params, params_orig):
151        # Create the top level, checkable item
152        item_name = param_orig.name
153        item1 = QtGui.QStandardItem(item_name)
154        item1.setData(param.name, QtCore.Qt.UserRole)
155        item1.setCheckable(True)
156        item1.setEditable(False)
157        # Param values
158        # TODO: add delegate for validation of cells
159        item2 = QtGui.QStandardItem(str(param.default))
160        item4 = QtGui.QStandardItem(str(param.limits[0]))
161        item5 = QtGui.QStandardItem(str(param.limits[1]))
162        item6 = QtGui.QStandardItem(str(param.units))
163        item6.setEditable(False)
164        item.append([item1, item2, item4, item5, item6])
165    return item
166
167def markParameterDisabled(model, row):
168    """Given the QModel row number, format to show it is not available for fitting"""
169
170    # If an error column is present, there are a total of 6 columns.
171    items = [model.item(row, c) for c in range(6)]
172
173    model.blockSignals(True)
174
175    for item in items:
176        if item is None:
177            continue
178        item.setEditable(False)
179        item.setCheckable(False)
180
181    item = items[0]
182
183    font = QtGui.QFont()
184    font.setItalic(True)
185    item.setFont(font)
186    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
187    item.setToolTip("This parameter cannot be fitted.")
188
189    model.blockSignals(False)
190
191def addCheckedListToModel(model, param_list):
192    """
193    Add a QItem to model. Makes the QItem checkable
194    """
195    assert isinstance(model, QtGui.QStandardItemModel)
196    item_list = [QtGui.QStandardItem(item) for item in param_list]
197    item_list[0].setCheckable(True)
198    model.appendRow(item_list)
199
200def addHeadingRowToModel(model, name):
201    """adds a non-interactive top-level row to the model"""
202    header_row = [QtGui.QStandardItem() for i in range(5)]
203    header_row[0].setText(name)
204
205    font = header_row[0].font()
206    font.setBold(True)
207    header_row[0].setFont(font)
208
209    for item in header_row:
210        item.setEditable(False)
211        item.setCheckable(False)
212        item.setSelectable(False)
213
214    model.appendRow(header_row)
215
216def addHeadersToModel(model):
217    """
218    Adds predefined headers to the model
219    """
220    for i, item in enumerate(model_header_captions):
221        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
222
223    model.header_tooltips = copy.copy(model_header_tooltips)
224
225def addErrorHeadersToModel(model):
226    """
227    Adds predefined headers to the model
228    """
229    model_header_error_captions = copy.copy(model_header_captions)
230    model_header_error_captions.insert(2, header_error_caption)
231    for i, item in enumerate(model_header_error_captions):
232        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
233
234    model_header_error_tooltips = copy.copy(model_header_tooltips)
235    model_header_error_tooltips.insert(2, error_tooltip)
236    model.header_tooltips = copy.copy(model_header_error_tooltips)
237
238def addPolyHeadersToModel(model):
239    """
240    Adds predefined headers to the model
241    """
242    for i, item in enumerate(poly_header_captions):
243        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
244
245    model.header_tooltips = copy.copy(poly_header_tooltips)
246
247
248def addErrorPolyHeadersToModel(model):
249    """
250    Adds predefined headers to the model
251    """
252    poly_header_error_captions = copy.copy(poly_header_captions)
253    poly_header_error_captions.insert(2, header_error_caption)
254    for i, item in enumerate(poly_header_error_captions):
255        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
256
257    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
258    poly_header_error_tooltips.insert(2, error_tooltip)
259    model.header_tooltips = copy.copy(poly_header_error_tooltips)
260
261def addShellsToModel(parameters, model, index):
262    """
263    Find out multishell parameters and update the model with the requested number of them
264    """
265    multishell_parameters = getIterParams(parameters)
266
267    for i in range(index):
268        for par in multishell_parameters:
269            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
270            param_name = replaceShellName(par.name, i+1)
271            item1 = QtGui.QStandardItem(param_name)
272            item1.setCheckable(True)
273            # check for polydisp params
274            if par.polydisperse:
275                poly_item = QtGui.QStandardItem("Polydispersity")
276                item1_1 = QtGui.QStandardItem("Distribution")
277                # Find param in volume_params
278                for p in parameters.form_volume_parameters:
279                    if p.name != par.name:
280                        continue
281                    item1_2 = QtGui.QStandardItem(str(p.default))
282                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
283                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
284                    item1_5 = QtGui.QStandardItem(p.units)
285                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
286                    break
287                item1.appendRow([poly_item])
288
289            item2 = QtGui.QStandardItem(str(par.default))
290            item3 = QtGui.QStandardItem(str(par.limits[0]))
291            item4 = QtGui.QStandardItem(str(par.limits[1]))
292            item5 = QtGui.QStandardItem(par.units)
293            model.appendRow([item1, item2, item3, item4, item5])
294
295def calculateChi2(reference_data, current_data):
296    """
297    Calculate Chi2 value between two sets of data
298    """
299    if reference_data is None or current_data is None:
300        return None
301    # WEIGHING INPUT
302    #from sas.sasgui.perspectives.fitting.utils import get_weight
303    #flag = self.get_weight_flag()
304    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
305    chisqr = None
306    if reference_data is None:
307        return chisqr
308
309    # temporary default values for index and weight
310    index = None
311    weight = None
312
313    # Get data: data I, theory I, and data dI in order
314    if isinstance(reference_data, Data2D):
315        if index is None:
316            index = numpy.ones(len(current_data.data), dtype=bool)
317        if weight is not None:
318            current_data.err_data = weight
319        # get rid of zero error points
320        index = index & (current_data.err_data != 0)
321        index = index & (numpy.isfinite(current_data.data))
322        fn = current_data.data[index]
323        gn = reference_data.data[index]
324        en = current_data.err_data[index]
325    else:
326        # 1 d theory from model_thread is only in the range of index
327        if index is None:
328            index = numpy.ones(len(current_data.y), dtype=bool)
329        if weight is not None:
330            current_data.dy = weight
331        if current_data.dy is None or current_data.dy == []:
332            dy = numpy.ones(len(current_data.y))
333        else:
334            ## Set consistently w/AbstractFitengine:
335            # But this should be corrected later.
336            dy = copy.deepcopy(current_data.dy)
337            dy[dy == 0] = 1
338        fn = current_data.y[index]
339        gn = reference_data.y
340        en = dy[index]
341    # Calculate the residual
342    try:
343        res = (fn - gn) / en
344    except ValueError:
345        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
346        return None
347
348    residuals = res[numpy.isfinite(res)]
349    chisqr = numpy.average(residuals * residuals)
350
351    return chisqr
352
353def residualsData1D(reference_data, current_data):
354    """
355    Calculate the residuals for difference of two Data1D sets
356    """
357    # temporary default values for index and weight
358    index = None
359    weight = None
360
361    # 1d theory from model_thread is only in the range of index
362    if current_data.dy is None or current_data.dy == []:
363        dy = numpy.ones(len(current_data.y))
364    else:
365        dy = weight if weight is not None else numpy.ones(len(current_data.y))
366        dy[dy == 0] = 1
367    fn = current_data.y[index][0]
368    gn = reference_data.y
369    en = dy[index][0]
370
371    # x values
372    x_current = current_data.x
373    x_reference = reference_data.x
374
375    # build residuals
376    residuals = Data1D()
377    if len(fn) == len(gn):
378        y = (fn - gn)/en
379        residuals.y = -y
380    elif len(fn) > len(gn):
381        residuals.y = (fn - gn[1:len(fn)])/en
382    else:
383        try:
384            y = numpy.zeros(len(current_data.y))
385            begin = 0
386            for i, x_value in enumerate(x_reference):
387                if x_value in x_current:
388                    begin = i
389                    break
390            end = len(x_reference)
391            endl = 0
392            for i, x_value in enumerate(list(x_reference)[::-1]):
393                if x_value in x_current:
394                    endl = i
395                    break
396            # make sure we have correct lengths
397            assert len(x_current) == len(x_reference[begin:end-endl])
398
399            y = (fn - gn[begin:end-endl])/en
400            residuals.y = y
401        except ValueError:
402            # value errors may show up every once in a while for malformed columns,
403            # just reuse what's there already
404            pass
405
406    residuals.x = current_data.x[index][0]
407    residuals.dy = numpy.ones(len(residuals.y))
408    residuals.dx = None
409    residuals.dxl = None
410    residuals.dxw = None
411    residuals.ytransform = 'y'
412    # For latter scale changes
413    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
414    residuals.yaxis('\\rm{Residuals} ', 'normalized')
415
416    return residuals
417
418def residualsData2D(reference_data, current_data):
419    """
420    Calculate the residuals for difference of two Data2D sets
421    """
422    # temporary default values for index and weight
423    # index = None
424    weight = None
425
426    # build residuals
427    residuals = Data2D()
428    # Not for trunk the line below, instead use the line above
429    current_data.clone_without_data(len(current_data.data), residuals)
430    residuals.data = None
431    fn = current_data.data
432    gn = reference_data.data
433    en = current_data.err_data if weight is None else weight
434    residuals.data = (fn - gn) / en
435    residuals.qx_data = current_data.qx_data
436    residuals.qy_data = current_data.qy_data
437    residuals.q_data = current_data.q_data
438    residuals.err_data = numpy.ones(len(residuals.data))
439    residuals.xmin = min(residuals.qx_data)
440    residuals.xmax = max(residuals.qx_data)
441    residuals.ymin = min(residuals.qy_data)
442    residuals.ymax = max(residuals.qy_data)
443    residuals.q_data = current_data.q_data
444    residuals.mask = current_data.mask
445    residuals.scale = 'linear'
446    # check the lengths
447    if len(residuals.data) != len(residuals.q_data):
448        return None
449    return residuals
450
451def plotResiduals(reference_data, current_data):
452    """
453    Create Data1D/Data2D with residuals, ready for plotting
454    """
455    data_copy = copy.deepcopy(current_data)
456    # Get data: data I, theory I, and data dI in order
457    method_name = current_data.__class__.__name__
458    residuals_dict = {"Data1D": residualsData1D,
459                      "Data2D": residualsData2D}
460
461    residuals = residuals_dict[method_name](reference_data, data_copy)
462
463    theory_name = str(current_data.name.split()[0])
464    residuals.name = "Residuals for " + str(theory_name) + "[" + \
465                    str(reference_data.filename) + "]"
466    residuals.title = residuals.name
467    residuals.ytransform = 'y'
468
469    # when 2 data have the same id override the 1 st plotted
470    # include the last part if keeping charts for separate models is required
471    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
472    # group_id specify on which panel to plot this data
473    group_id = reference_data.group_id
474    residuals.group_id = "res" + str(group_id)
475
476    # Symbol
477    residuals.symbol = 0
478    residuals.hide_error = False
479
480    return residuals
481
482def binary_encode(i, digits):
483    return [i >> d & 1 for d in range(digits)]
484
485def getWeight(data, is2d, flag=None):
486    """
487    Received flag and compute error on data.
488    :param flag: flag to transform error of data.
489    """
490    weight = None
491    if is2d:
492        dy_data = data.err_data
493        data = data.data
494    else:
495        dy_data = data.dy
496        data = data.y
497
498    if flag == 0:
499        weight = numpy.ones_like(data)
500    elif flag == 1:
501        weight = dy_data
502    elif flag == 2:
503        weight = numpy.sqrt(numpy.abs(data))
504    elif flag == 3:
505        weight = numpy.abs(data)
506    return weight
507
508def updateKernelWithResults(kernel, results):
509    """
510    Takes model kernel and applies results dict to its parameters,
511    returning the modified (deep) copy of the kernel.
512    """
513    assert isinstance(results, dict)
514    local_kernel = copy.deepcopy(kernel)
515
516    for parameter in results.keys():
517        # Update the parameter value - note: this supports +/-inf as well
518        local_kernel.setParam(parameter, results[parameter][0])
519
520    return local_kernel
521
522
523def getStandardParam(model=None):
524    """
525    Returns a list with standard parameters for the current model
526    """
527    param = []
528    num_rows = model.rowCount()
529    if num_rows < 1:
530        return None
531
532    for row in range(num_rows):
533        param_name = model.item(row, 0).text()
534        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
535        value = model.item(row, 1).text()
536        column_shift = 0
537        if model.columnCount() == 5: # no error column
538            error_state = False
539            error_value = 0.0
540        else:
541            error_state = True
542            error_value = model.item(row, 2).text()
543            column_shift = 1
544        min_state = True
545        max_state = True
546        min_value = model.item(row, 2+column_shift).text()
547        max_value = model.item(row, 3+column_shift).text()
548        unit = ""
549        if model.item(row, 4+column_shift) is not None:
550            unit = model.item(row, 4+column_shift).text()
551
552        param.append([checkbox_state, param_name, value, "",
553                        [error_state, error_value],
554                        [min_state, min_value],
555                        [max_state, max_value], unit])
556
557    return param
558
559def getOrientationParam(kernel_module=None):
560    """
561    Get the dictionary with orientation parameters
562    """
563    param = []
564    if kernel_module is None:
565        return None
566    for param_name in list(kernel_module.params.keys()):
567        name = param_name
568        value = kernel_module.params[param_name]
569        min_state = True
570        max_state = True
571        error_state = False
572        error_value = 0.0
573        checkbox_state = True #??
574        details = kernel_module.details[param_name] #[unit, mix, max]
575        param.append([checkbox_state, name, value, "",
576                     [error_state, error_value],
577                     [min_state, details[1]],
578                     [max_state, details[2]], details[0]])
579
580    return param
581
582def formatParameters(parameters):
583    """
584    Prepare the parameter string in the standard SasView layout
585    """
586    assert parameters is not None
587    assert isinstance(parameters, list)
588    output_string = "sasview_parameter_values:"
589    for parameter in parameters:
590        output_string += ",".join([p for p in parameter if p is not None])
591        output_string += ":"
592    return output_string
593
594def formatParametersExcel(parameters):
595    """
596    Prepare the parameter string in the Excel format (tab delimited)
597    """
598    assert parameters is not None
599    assert isinstance(parameters, list)
600    crlf = chr(13) + chr(10)
601    tab = chr(9)
602
603    output_string = ""
604    # names
605    names = ""
606    values = ""
607    for parameter in parameters:
608        names += parameter[0]+tab
609        # Add the error column if fitted
610        if parameter[1] == "True" and parameter[3] is not None:
611            names += parameter[0]+"_err"+tab
612
613        values += parameter[2]+tab
614        if parameter[1] == "True" and parameter[3] is not None:
615            values += parameter[3]+tab
616        # add .npts and .nsigmas when necessary
617        if parameter[0][-6:] == ".width":
618            names += parameter[0].replace('.width', '.nsigmas') + tab
619            names += parameter[0].replace('.width', '.npts') + tab
620            values += parameter[5] + tab + parameter[4] + tab
621
622    output_string = names + crlf + values
623    return output_string
624
625def formatParametersLatex(parameters):
626    """
627    Prepare the parameter string in latex
628    """
629    assert parameters is not None
630    assert isinstance(parameters, list)
631    output_string = r'\begin{table}'
632    output_string += r'\begin{tabular}[h]'
633
634    crlf = chr(13) + chr(10)
635    output_string += '{|'
636    output_string += 'l|l|'*len(parameters)
637    output_string += r'}\hline'
638    output_string += crlf
639
640    for index, parameter in enumerate(parameters):
641        name = parameter[0] # Parameter name
642        output_string += name.replace('_', r'\_')  # Escape underscores
643        # Add the error column if fitted
644        if parameter[1] == "True" and parameter[3] is not None:
645            output_string += ' & '
646            output_string += parameter[0]+r'\_err'
647
648        if index < len(parameters) - 1:
649            output_string += ' & '
650
651        # add .npts and .nsigmas when necessary
652        if parameter[0][-6:] == ".width":
653            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
654            output_string += parameter[0].replace('.width', '.npts')
655
656            if index < len(parameters) - 1:
657                output_string += ' & '
658
659    output_string += r'\\ \hline'
660    output_string += crlf
661
662    # Construct row of values and errors
663    for index, parameter in enumerate(parameters):
664        output_string += parameter[2]
665        if parameter[1] == "True" and parameter[3] is not None:
666            output_string += ' & '
667            output_string += parameter[3]
668
669        if index < len(parameters) - 1:
670            output_string += ' & '
671
672        # add .npts and .nsigmas when necessary
673        if parameter[0][-6:] == ".width":
674            output_string += parameter[5] + ' & '
675            output_string += parameter[4]
676
677            if index < len(parameters) - 1:
678                output_string += ' & '
679
680    output_string += r'\\ \hline'
681    output_string += crlf
682    output_string += r'\end{tabular}'
683    output_string += r'\end{table}'
684
685    return output_string
Note: See TracBrowser for help on using the repository browser.