source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ f7d39c9

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since f7d39c9 was f7d39c9, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

Several fixes for runtime errors. SASVIEW-988.
Failures in Algorithm Option dialog
Q Ranges not showing properly with charts

  • Property mode set to 100644
File size: 22.1 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list([par for par in model.iq_parameters if "[" in par.name])
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69
70    if is2D:
71        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
72    else:
73        params = parameters.iq_parameters
74    item = []
75    for param in params:
76        # don't include shell parameters
77        if param.name == multishell_param_name:
78            continue
79        # Modify parameter name from <param>[n] to <param>1
80        item_name = param.name
81        if param in multishell_parameters:
82            continue
83        #    item_name = replaceShellName(param.name, 1)
84
85        item1 = QtGui.QStandardItem(item_name)
86        item1.setCheckable(True)
87        item1.setEditable(False)
88        # item_err = QtGui.QStandardItem()
89        # check for polydisp params
90        if param.polydisperse:
91            poly_item = QtGui.QStandardItem("Polydispersity")
92            poly_item.setEditable(False)
93            item1_1 = QtGui.QStandardItem("Distribution")
94            item1_1.setEditable(False)
95            # Find param in volume_params
96            for p in parameters.form_volume_parameters:
97                if p.name != param.name:
98                    continue
99                width = kernel_module.getParam(p.name+'.width')
100                ptype = kernel_module.getParam(p.name+'.type')
101
102                item1_2 = QtGui.QStandardItem(str(width))
103                item1_2.setEditable(False)
104                item1_3 = QtGui.QStandardItem()
105                item1_3.setEditable(False)
106                item1_4 = QtGui.QStandardItem()
107                item1_4.setEditable(False)
108                item1_5 = QtGui.QStandardItem(ptype)
109                item1_5.setEditable(False)
110                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
111                break
112            # Add the polydisp item as a child
113            item1.appendRow([poly_item])
114        # Param values
115        item2 = QtGui.QStandardItem(str(param.default))
116        # TODO: the error column.
117        # Either add a proxy model or a custom view delegate
118        #item_err = QtGui.QStandardItem()
119        item3 = QtGui.QStandardItem(str(param.limits[0]))
120        item4 = QtGui.QStandardItem(str(param.limits[1]))
121        item5 = QtGui.QStandardItem(param.units)
122        item5.setEditable(False)
123        item.append([item1, item2, item3, item4, item5])
124    return item
125
126def addSimpleParametersToModel(parameters, is2D):
127    """
128    Update local ModelModel with sasmodel parameters
129    """
130    if is2D:
131        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
132    else:
133        params = parameters.iq_parameters
134    item = []
135    for param in params:
136        # Create the top level, checkable item
137        item_name = param.name
138        item1 = QtGui.QStandardItem(item_name)
139        item1.setCheckable(True)
140        item1.setEditable(False)
141        # Param values
142        # TODO: add delegate for validation of cells
143        item2 = QtGui.QStandardItem(str(param.default))
144        item4 = QtGui.QStandardItem(str(param.limits[0]))
145        item5 = QtGui.QStandardItem(str(param.limits[1]))
146        item6 = QtGui.QStandardItem(param.units)
147        item6.setEditable(False)
148        item.append([item1, item2, item4, item5, item6])
149    return item
150
151def markParameterDisabled(model, row):
152    """Given the QModel row number, format to show it is not available for fitting"""
153
154    # If an error column is present, there are a total of 6 columns.
155    items = [model.item(row, c) for c in range(6)]
156
157    model.blockSignals(True)
158
159    for item in items:
160        if item is None:
161            continue
162        item.setEditable(False)
163        item.setCheckable(False)
164
165    item = items[0]
166
167    font = QtGui.QFont()
168    font.setItalic(True)
169    item.setFont(font)
170    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
171    item.setToolTip("This parameter cannot be fitted.")
172
173    model.blockSignals(False)
174
175def addCheckedListToModel(model, param_list):
176    """
177    Add a QItem to model. Makes the QItem checkable
178    """
179    assert isinstance(model, QtGui.QStandardItemModel)
180    item_list = [QtGui.QStandardItem(item) for item in param_list]
181    item_list[0].setCheckable(True)
182    model.appendRow(item_list)
183
184def addHeadersToModel(model):
185    """
186    Adds predefined headers to the model
187    """
188    for i, item in enumerate(model_header_captions):
189        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
190
191    model.header_tooltips = copy.copy(model_header_tooltips)
192
193def addErrorHeadersToModel(model):
194    """
195    Adds predefined headers to the model
196    """
197    model_header_error_captions = copy.copy(model_header_captions)
198    model_header_error_captions.insert(2, header_error_caption)
199    for i, item in enumerate(model_header_error_captions):
200        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
201
202    model_header_error_tooltips = copy.copy(model_header_tooltips)
203    model_header_error_tooltips.insert(2, error_tooltip)
204    model.header_tooltips = copy.copy(model_header_error_tooltips)
205
206def addPolyHeadersToModel(model):
207    """
208    Adds predefined headers to the model
209    """
210    for i, item in enumerate(poly_header_captions):
211        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
212
213    model.header_tooltips = copy.copy(poly_header_tooltips)
214
215
216def addErrorPolyHeadersToModel(model):
217    """
218    Adds predefined headers to the model
219    """
220    poly_header_error_captions = copy.copy(poly_header_captions)
221    poly_header_error_captions.insert(2, header_error_caption)
222    for i, item in enumerate(poly_header_error_captions):
223        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
224
225    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
226    poly_header_error_tooltips.insert(2, error_tooltip)
227    model.header_tooltips = copy.copy(poly_header_error_tooltips)
228
229def addShellsToModel(parameters, model, index):
230    """
231    Find out multishell parameters and update the model with the requested number of them
232    """
233    multishell_parameters = getIterParams(parameters)
234
235    for i in range(index):
236        for par in multishell_parameters:
237            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
238            param_name = replaceShellName(par.name, i+1)
239            item1 = QtGui.QStandardItem(param_name)
240            item1.setCheckable(True)
241            # check for polydisp params
242            if par.polydisperse:
243                poly_item = QtGui.QStandardItem("Polydispersity")
244                item1_1 = QtGui.QStandardItem("Distribution")
245                # Find param in volume_params
246                for p in parameters.form_volume_parameters:
247                    if p.name != par.name:
248                        continue
249                    item1_2 = QtGui.QStandardItem(str(p.default))
250                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
251                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
252                    item1_5 = QtGui.QStandardItem(p.units)
253                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
254                    break
255                item1.appendRow([poly_item])
256
257            item2 = QtGui.QStandardItem(str(par.default))
258            item3 = QtGui.QStandardItem(str(par.limits[0]))
259            item4 = QtGui.QStandardItem(str(par.limits[1]))
260            item5 = QtGui.QStandardItem(par.units)
261            model.appendRow([item1, item2, item3, item4, item5])
262
263def calculateChi2(reference_data, current_data):
264    """
265    Calculate Chi2 value between two sets of data
266    """
267
268    # WEIGHING INPUT
269    #from sas.sasgui.perspectives.fitting.utils import get_weight
270    #flag = self.get_weight_flag()
271    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
272    chisqr = None
273    if reference_data is None:
274        return chisqr
275
276    # temporary default values for index and weight
277    index = None
278    weight = None
279
280    # Get data: data I, theory I, and data dI in order
281    if isinstance(reference_data, Data2D):
282        if index is None:
283            index = numpy.ones(len(current_data.data), dtype=bool)
284        if weight is not None:
285            current_data.err_data = weight
286        # get rid of zero error points
287        index = index & (current_data.err_data != 0)
288        index = index & (numpy.isfinite(current_data.data))
289        fn = current_data.data[index]
290        gn = reference_data.data[index]
291        en = current_data.err_data[index]
292    else:
293        # 1 d theory from model_thread is only in the range of index
294        if index is None:
295            index = numpy.ones(len(current_data.y), dtype=bool)
296        if weight is not None:
297            current_data.dy = weight
298        if current_data.dy is None or current_data.dy == []:
299            dy = numpy.ones(len(current_data.y))
300        else:
301            ## Set consistently w/AbstractFitengine:
302            # But this should be corrected later.
303            dy = copy.deepcopy(current_data.dy)
304            dy[dy == 0] = 1
305        fn = current_data.y[index]
306        gn = reference_data.y
307        en = dy[index]
308    # Calculate the residual
309    try:
310        res = (fn - gn) / en
311    except ValueError:
312        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
313        return None
314
315    residuals = res[numpy.isfinite(res)]
316    chisqr = numpy.average(residuals * residuals)
317
318    return chisqr
319
320def residualsData1D(reference_data, current_data):
321    """
322    Calculate the residuals for difference of two Data1D sets
323    """
324    # temporary default values for index and weight
325    index = None
326    weight = None
327
328    # 1d theory from model_thread is only in the range of index
329    if current_data.dy is None or current_data.dy == []:
330        dy = numpy.ones(len(current_data.y))
331    else:
332        dy = weight if weight is not None else numpy.ones(len(current_data.y))
333        dy[dy == 0] = 1
334    fn = current_data.y[index][0]
335    gn = reference_data.y
336    en = dy[index][0]
337
338    # x values
339    x_current = current_data.x
340    x_reference = reference_data.x
341
342    # build residuals
343    residuals = Data1D()
344    if len(fn) == len(gn):
345        y = (fn - gn)/en
346        residuals.y = -y
347    elif len(fn) > len(gn):
348        residuals.y = (fn - gn[1:len(fn)])/en
349    else:
350        try:
351            y = numpy.zeros(len(current_data.y))
352            begin = 0
353            for i, x_value in enumerate(x_reference):
354                if x_value in x_current:
355                    begin = i
356                    break
357            end = len(x_reference)
358            endl = 0
359            for i, x_value in enumerate(list(x_reference)[::-1]):
360                if x_value in x_current:
361                    endl = i
362                    break
363            # make sure we have correct lengths
364            assert len(x_current) == len(x_reference[begin:end-endl])
365
366            y = (fn - gn[begin:end-endl])/en
367            residuals.y = y
368        except ValueError:
369            # value errors may show up every once in a while for malformed columns,
370            # just reuse what's there already
371            pass
372
373    residuals.x = current_data.x[index][0]
374    residuals.dy = numpy.ones(len(residuals.y))
375    residuals.dx = None
376    residuals.dxl = None
377    residuals.dxw = None
378    residuals.ytransform = 'y'
379    # For latter scale changes
380    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
381    residuals.yaxis('\\rm{Residuals} ', 'normalized')
382
383    return residuals
384
385def residualsData2D(reference_data, current_data):
386    """
387    Calculate the residuals for difference of two Data2D sets
388    """
389    # temporary default values for index and weight
390    # index = None
391    weight = None
392
393    # build residuals
394    residuals = Data2D()
395    # Not for trunk the line below, instead use the line above
396    current_data.clone_without_data(len(current_data.data), residuals)
397    residuals.data = None
398    fn = current_data.data
399    gn = reference_data.data
400    en = current_data.err_data if weight is None else weight
401    residuals.data = (fn - gn) / en
402    residuals.qx_data = current_data.qx_data
403    residuals.qy_data = current_data.qy_data
404    residuals.q_data = current_data.q_data
405    residuals.err_data = numpy.ones(len(residuals.data))
406    residuals.xmin = min(residuals.qx_data)
407    residuals.xmax = max(residuals.qx_data)
408    residuals.ymin = min(residuals.qy_data)
409    residuals.ymax = max(residuals.qy_data)
410    residuals.q_data = current_data.q_data
411    residuals.mask = current_data.mask
412    residuals.scale = 'linear'
413    # check the lengths
414    if len(residuals.data) != len(residuals.q_data):
415        return None
416    return residuals
417
418def plotResiduals(reference_data, current_data):
419    """
420    Create Data1D/Data2D with residuals, ready for plotting
421    """
422    data_copy = copy.deepcopy(current_data)
423    # Get data: data I, theory I, and data dI in order
424    method_name = current_data.__class__.__name__
425    residuals_dict = {"Data1D": residualsData1D,
426                      "Data2D": residualsData2D}
427
428    residuals = residuals_dict[method_name](reference_data, data_copy)
429
430    theory_name = str(current_data.name.split()[0])
431    residuals.name = "Residuals for " + str(theory_name) + "[" + \
432                    str(reference_data.filename) + "]"
433    residuals.title = residuals.name
434    residuals.ytransform = 'y'
435
436    # when 2 data have the same id override the 1 st plotted
437    # include the last part if keeping charts for separate models is required
438    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
439    # group_id specify on which panel to plot this data
440    group_id = reference_data.group_id
441    residuals.group_id = "res" + str(group_id)
442
443    # Symbol
444    residuals.symbol = 0
445    residuals.hide_error = False
446
447    return residuals
448
449def binary_encode(i, digits):
450    return [i >> d & 1 for d in range(digits)]
451
452def getWeight(data, is2d, flag=None):
453    """
454    Received flag and compute error on data.
455    :param flag: flag to transform error of data.
456    """
457    weight = None
458    if is2d:
459        dy_data = data.err_data
460        data = data.data
461    else:
462        dy_data = data.dy
463        data = data.y
464
465    if flag == 0:
466        weight = numpy.ones_like(data)
467    elif flag == 1:
468        weight = dy_data
469    elif flag == 2:
470        weight = numpy.sqrt(numpy.abs(data))
471    elif flag == 3:
472        weight = numpy.abs(data)
473    return weight
474
475def updateKernelWithResults(kernel, results):
476    """
477    Takes model kernel and applies results dict to its parameters,
478    returning the modified (deep) copy of the kernel.
479    """
480    assert isinstance(results, dict)
481    local_kernel = copy.deepcopy(kernel)
482
483    for parameter in results.keys():
484        # Update the parameter value - note: this supports +/-inf as well
485        local_kernel.setParam(parameter, results[parameter][0])
486
487    return local_kernel
488
489
490def getStandardParam(model=None):
491    """
492    Returns a list with standard parameters for the current model
493    """
494    param = []
495    num_rows = model.rowCount()
496    if num_rows < 1:
497        return None
498
499    for row in range(num_rows):
500        param_name = model.item(row, 0).text()
501        checkbox_state = model.item(row, 0).checkState() == QtCore.Qt.Checked
502        value = model.item(row, 1).text()
503        column_shift = 0
504        if model.columnCount() == 5: # no error column
505            error_state = False
506            error_value = 0.0
507        else:
508            error_state = True
509            error_value = model.item(row, 2).text()
510            column_shift = 1
511        min_state = True
512        max_state = True
513        min_value = model.item(row, 2+column_shift).text()
514        max_value = model.item(row, 3+column_shift).text()
515        unit = ""
516        if model.item(row, 4+column_shift) is not None:
517            unit = model.item(row, 4+column_shift).text()
518
519        param.append([checkbox_state, param_name, value, "",
520                        [error_state, error_value],
521                        [min_state, min_value],
522                        [max_state, max_value], unit])
523
524    return param
525
526def getOrientationParam(kernel_module=None):
527    """
528    Get the dictionary with orientation parameters
529    """
530    param = []
531    if kernel_module is None:
532        return None
533    for param_name in list(kernel_module.params.keys()):
534        name = param_name
535        value = kernel_module.params[param_name]
536        min_state = True
537        max_state = True
538        error_state = False
539        error_value = 0.0
540        checkbox_state = True #??
541        details = kernel_module.details[param_name] #[unit, mix, max]
542        param.append([checkbox_state, name, value, "",
543                     [error_state, error_value],
544                     [min_state, details[1]],
545                     [max_state, details[2]], details[0]])
546
547    return param
548
549def formatParameters(parameters):
550    """
551    Prepare the parameter string in the standard SasView layout
552    """
553    assert parameters is not None
554    assert isinstance(parameters, list)
555    output_string = "sasview_parameter_values:"
556    for parameter in parameters:
557        output_string += ",".join([p for p in parameter if p is not None])
558        output_string += ":"
559    return output_string
560
561def formatParametersExcel(parameters):
562    """
563    Prepare the parameter string in the Excel format (tab delimited)
564    """
565    assert parameters is not None
566    assert isinstance(parameters, list)
567    crlf = chr(13) + chr(10)
568    tab = chr(9)
569
570    output_string = ""
571    # names
572    names = ""
573    values = ""
574    for parameter in parameters:
575        names += parameter[0]+tab
576        # Add the error column if fitted
577        if parameter[1] == "True" and parameter[3] is not None:
578            names += parameter[0]+"_err"+tab
579
580        values += parameter[2]+tab
581        if parameter[1] == "True" and parameter[3] is not None:
582            values += parameter[3]+tab
583        # add .npts and .nsigmas when necessary
584        if parameter[0][-6:] == ".width":
585            names += parameter[0].replace('.width', '.nsigmas') + tab
586            names += parameter[0].replace('.width', '.npts') + tab
587            values += parameter[5] + tab + parameter[4] + tab
588
589    output_string = names + crlf + values
590    return output_string
591
592def formatParametersLatex(parameters):
593    """
594    Prepare the parameter string in latex
595    """
596    assert parameters is not None
597    assert isinstance(parameters, list)
598    output_string = r'\begin{table}'
599    output_string += r'\begin{tabular}[h]'
600
601    crlf = chr(13) + chr(10)
602    output_string += '{|'
603    output_string += 'l|l|'*len(parameters)
604    output_string += r'}\hline'
605    output_string += crlf
606
607    for index, parameter in enumerate(parameters):
608        name = parameter[0] # Parameter name
609        output_string += name.replace('_', r'\_')  # Escape underscores
610        # Add the error column if fitted
611        if parameter[1] == "True" and parameter[3] is not None:
612            output_string += ' & '
613            output_string += parameter[0]+r'\_err'
614
615        if index < len(parameters) - 1:
616            output_string += ' & '
617
618        # add .npts and .nsigmas when necessary
619        if parameter[0][-6:] == ".width":
620            output_string += parameter[0].replace('.width', '.nsigmas') + ' & '
621            output_string += parameter[0].replace('.width', '.npts')
622
623            if index < len(parameters) - 1:
624                output_string += ' & '
625
626    output_string += r'\\ \hline'
627    output_string += crlf
628
629    # Construct row of values and errors
630    for index, parameter in enumerate(parameters):
631        output_string += parameter[2]
632        if parameter[1] == "True" and parameter[3] is not None:
633            output_string += ' & '
634            output_string += parameter[3]
635
636        if index < len(parameters) - 1:
637            output_string += ' & '
638
639        # add .npts and .nsigmas when necessary
640        if parameter[0][-6:] == ".width":
641            output_string += parameter[5] + ' & '
642            output_string += parameter[4]
643
644            if index < len(parameters) - 1:
645                output_string += ' & '
646
647    output_string += r'\\ \hline'
648    output_string += crlf
649    output_string += r'\end{tabular}'
650    output_string += r'\end{table}'
651
652    return output_string
Note: See TracBrowser for help on using the repository browser.