source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 516ee4b

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 516ee4b was b87dc1a, checked in by Torin Cooper-Bennun <torin.cooper-bennun@…>, 6 years ago

keep non-fittable rows selectable, to not break layout

  • Property mode set to 100644
File size: 18.1 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5from PyQt5 import QtWidgets
6
7import numpy
8
9from sas.qtgui.Plotting.PlotterData import Data1D
10from sas.qtgui.Plotting.PlotterData import Data2D
11
12model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
13
14model_header_tooltips = ['Select parameter for fitting',
15                         'Enter parameter value',
16                         'Enter minimum value for parameter',
17                         'Enter maximum value for parameter',
18                         'Unit of the parameter']
19
20poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
21                        'Function', 'Filename']
22
23poly_header_tooltips = ['Select parameter for fitting',
24                        'Enter polydispersity ratio (STD/mean). '
25                        'STD: standard deviation from the mean value',
26                        'Enter minimum value for parameter',
27                        'Enter maximum value for parameter',
28                        'Enter number of points for parameter',
29                        'Enter number of sigmas parameter',
30                        'Select distribution function',
31                        'Select filename with user-definable distribution']
32
33error_tooltip = 'Error value for fitted parameter'
34header_error_caption = 'Error'
35
36def replaceShellName(param_name, value):
37    """
38    Updates parameter name from <param_name>[n_shell] to <param_name>value
39    """
40    assert '[' in param_name
41    return param_name[:param_name.index('[')]+str(value)
42
43def getIterParams(model):
44    """
45    Returns a list of all multi-shell parameters in 'model'
46    """
47    return list([par for par in model.iq_parameters if "[" in par.name])
48
49def getMultiplicity(model):
50    """
51    Finds out if 'model' has multishell parameters.
52    If so, returns the name of the counter parameter and the number of shells
53    """
54    iter_params = getIterParams(model)
55    param_name = ""
56    param_length = 0
57    if iter_params:
58        param_length = iter_params[0].length
59        param_name = iter_params[0].length_control
60        if param_name is None and '[' in iter_params[0].name:
61            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
62    return (param_name, param_length)
63
64def addParametersToModel(parameters, kernel_module, is2D):
65    """
66    Update local ModelModel with sasmodel parameters
67    """
68    multishell_parameters = getIterParams(parameters)
69    multishell_param_name, _ = getMultiplicity(parameters)
70
71    if is2D:
72        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
73    else:
74        params = parameters.iq_parameters
75    item = []
76    for param in params:
77        # don't include shell parameters
78        if param.name == multishell_param_name:
79            continue
80        # Modify parameter name from <param>[n] to <param>1
81        item_name = param.name
82        if param in multishell_parameters:
83            continue
84        #    item_name = replaceShellName(param.name, 1)
85
86        item1 = QtGui.QStandardItem(item_name)
87        item1.setCheckable(True)
88        item1.setEditable(False)
89        # item_err = QtGui.QStandardItem()
90        # check for polydisp params
91        if param.polydisperse:
92            poly_item = QtGui.QStandardItem("Polydispersity")
93            poly_item.setEditable(False)
94            item1_1 = QtGui.QStandardItem("Distribution")
95            item1_1.setEditable(False)
96            # Find param in volume_params
97            for p in parameters.form_volume_parameters:
98                if p.name != param.name:
99                    continue
100                width = kernel_module.getParam(p.name+'.width')
101                type = kernel_module.getParam(p.name+'.type')
102
103                item1_2 = QtGui.QStandardItem(str(width))
104                item1_2.setEditable(False)
105                item1_3 = QtGui.QStandardItem()
106                item1_3.setEditable(False)
107                item1_4 = QtGui.QStandardItem()
108                item1_4.setEditable(False)
109                item1_5 = QtGui.QStandardItem(type)
110                item1_5.setEditable(False)
111                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
112                break
113            # Add the polydisp item as a child
114            item1.appendRow([poly_item])
115        # Param values
116        item2 = QtGui.QStandardItem(str(param.default))
117        # TODO: the error column.
118        # Either add a proxy model or a custom view delegate
119        #item_err = QtGui.QStandardItem()
120        item3 = QtGui.QStandardItem(str(param.limits[0]))
121        item4 = QtGui.QStandardItem(str(param.limits[1]))
122        item5 = QtGui.QStandardItem(param.units)
123        item5.setEditable(False)
124        item.append([item1, item2, item3, item4, item5])
125    return item
126
127def addSimpleParametersToModel(parameters, is2D):
128    """
129    Update local ModelModel with sasmodel parameters
130    """
131    if is2D:
132        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
133    else:
134        params = parameters.iq_parameters
135    item = []
136    for param in params:
137        # Create the top level, checkable item
138        item_name = param.name
139        item1 = QtGui.QStandardItem(item_name)
140        item1.setCheckable(True)
141        item1.setEditable(False)
142        # Param values
143        # TODO: add delegate for validation of cells
144        item2 = QtGui.QStandardItem(str(param.default))
145        item4 = QtGui.QStandardItem(str(param.limits[0]))
146        item5 = QtGui.QStandardItem(str(param.limits[1]))
147        item6 = QtGui.QStandardItem(param.units)
148        item6.setEditable(False)
149        item.append([item1, item2, item4, item5, item6])
150    return item
151
152def markParameterDisabled(model, row):
153    """Given the QModel row number, format to show it is not available for fitting"""
154
155    # If an error column is present, there are a total of 6 columns.
156    items = [model.item(row, c) for c in range(6)]
157
158    model.blockSignals(True)
159
160    for item in items:
161        if item is None:
162            continue
163        item.setEditable(False)
164        item.setCheckable(False)
165
166    item = items[0]
167
168    font = QtGui.QFont()
169    font.setItalic(True)
170    item.setFont(font)
171    item.setForeground(QtGui.QBrush(QtGui.QColor(100, 100, 100)))
172    item.setToolTip("This parameter cannot be fitted.")
173
174    model.blockSignals(False)
175
176def addCheckedListToModel(model, param_list):
177    """
178    Add a QItem to model. Makes the QItem checkable
179    """
180    assert isinstance(model, QtGui.QStandardItemModel)
181    item_list = [QtGui.QStandardItem(item) for item in param_list]
182    item_list[0].setCheckable(True)
183    model.appendRow(item_list)
184
185def addHeadersToModel(model):
186    """
187    Adds predefined headers to the model
188    """
189    for i, item in enumerate(model_header_captions):
190        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
191
192    model.header_tooltips = copy.copy(model_header_tooltips)
193
194def addErrorHeadersToModel(model):
195    """
196    Adds predefined headers to the model
197    """
198    model_header_error_captions = copy.copy(model_header_captions)
199    model_header_error_captions.insert(2, header_error_caption)
200    for i, item in enumerate(model_header_error_captions):
201        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
202
203    model_header_error_tooltips = copy.copy(model_header_tooltips)
204    model_header_error_tooltips.insert(2, error_tooltip)
205    model.header_tooltips = copy.copy(model_header_error_tooltips)
206
207def addPolyHeadersToModel(model):
208    """
209    Adds predefined headers to the model
210    """
211    for i, item in enumerate(poly_header_captions):
212        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
213
214    model.header_tooltips = copy.copy(poly_header_tooltips)
215
216
217def addErrorPolyHeadersToModel(model):
218    """
219    Adds predefined headers to the model
220    """
221    poly_header_error_captions = copy.copy(poly_header_captions)
222    poly_header_error_captions.insert(2, header_error_caption)
223    for i, item in enumerate(poly_header_error_captions):
224        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
225
226    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
227    poly_header_error_tooltips.insert(2, error_tooltip)
228    model.header_tooltips = copy.copy(poly_header_error_tooltips)
229
230def addShellsToModel(parameters, model, index):
231    """
232    Find out multishell parameters and update the model with the requested number of them
233    """
234    multishell_parameters = getIterParams(parameters)
235
236    for i in range(index):
237        for par in multishell_parameters:
238            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
239            param_name = replaceShellName(par.name, i+1)
240            item1 = QtGui.QStandardItem(param_name)
241            item1.setCheckable(True)
242            # check for polydisp params
243            if par.polydisperse:
244                poly_item = QtGui.QStandardItem("Polydispersity")
245                item1_1 = QtGui.QStandardItem("Distribution")
246                # Find param in volume_params
247                for p in parameters.form_volume_parameters:
248                    if p.name != par.name:
249                        continue
250                    item1_2 = QtGui.QStandardItem(str(p.default))
251                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
252                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
253                    item1_5 = QtGui.QStandardItem(p.units)
254                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
255                    break
256                item1.appendRow([poly_item])
257
258            item2 = QtGui.QStandardItem(str(par.default))
259            item3 = QtGui.QStandardItem(str(par.limits[0]))
260            item4 = QtGui.QStandardItem(str(par.limits[1]))
261            item5 = QtGui.QStandardItem(par.units)
262            model.appendRow([item1, item2, item3, item4, item5])
263
264def calculateChi2(reference_data, current_data):
265    """
266    Calculate Chi2 value between two sets of data
267    """
268
269    # WEIGHING INPUT
270    #from sas.sasgui.perspectives.fitting.utils import get_weight
271    #flag = self.get_weight_flag()
272    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
273    chisqr = None
274    if reference_data is None:
275        return chisqr
276
277    # temporary default values for index and weight
278    index = None
279    weight = None
280
281    # Get data: data I, theory I, and data dI in order
282    if isinstance(reference_data, Data2D):
283        if index is None:
284            index = numpy.ones(len(current_data.data), dtype=bool)
285        if weight is not None:
286            current_data.err_data = weight
287        # get rid of zero error points
288        index = index & (current_data.err_data != 0)
289        index = index & (numpy.isfinite(current_data.data))
290        fn = current_data.data[index]
291        gn = reference_data.data[index]
292        en = current_data.err_data[index]
293    else:
294        # 1 d theory from model_thread is only in the range of index
295        if index is None:
296            index = numpy.ones(len(current_data.y), dtype=bool)
297        if weight is not None:
298            current_data.dy = weight
299        if current_data.dy is None or current_data.dy == []:
300            dy = numpy.ones(len(current_data.y))
301        else:
302            ## Set consistently w/AbstractFitengine:
303            # But this should be corrected later.
304            dy = copy.deepcopy(current_data.dy)
305            dy[dy == 0] = 1
306        fn = current_data.y[index]
307        gn = reference_data.y
308        en = dy[index]
309    # Calculate the residual
310    try:
311        res = (fn - gn) / en
312    except ValueError:
313        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
314        return None
315
316    residuals = res[numpy.isfinite(res)]
317    chisqr = numpy.average(residuals * residuals)
318
319    return chisqr
320
321def residualsData1D(reference_data, current_data):
322    """
323    Calculate the residuals for difference of two Data1D sets
324    """
325    # temporary default values for index and weight
326    index = None
327    weight = None
328
329    # 1d theory from model_thread is only in the range of index
330    if current_data.dy is None or current_data.dy == []:
331        dy = numpy.ones(len(current_data.y))
332    else:
333        dy = weight if weight is not None else numpy.ones(len(current_data.y))
334        dy[dy == 0] = 1
335    fn = current_data.y[index][0]
336    gn = reference_data.y
337    en = dy[index][0]
338    # build residuals
339    residuals = Data1D()
340    if len(fn) == len(gn):
341        y = (fn - gn)/en
342        residuals.y = -y
343    else:
344        # TODO: fix case where applying new data from file on top of existing model data
345        try:
346            y = (fn - gn[index][0]) / en
347            residuals.y = y
348        except ValueError:
349            # value errors may show up every once in a while for malformed columns,
350            # just reuse what's there already
351            pass
352
353    residuals.x = current_data.x[index][0]
354    residuals.dy = numpy.ones(len(residuals.y))
355    residuals.dx = None
356    residuals.dxl = None
357    residuals.dxw = None
358    residuals.ytransform = 'y'
359    # For latter scale changes
360    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
361    residuals.yaxis('\\rm{Residuals} ', 'normalized')
362
363    return residuals
364
365def residualsData2D(reference_data, current_data):
366    """
367    Calculate the residuals for difference of two Data2D sets
368    """
369    # temporary default values for index and weight
370    # index = None
371    weight = None
372
373    # build residuals
374    residuals = Data2D()
375    # Not for trunk the line below, instead use the line above
376    current_data.clone_without_data(len(current_data.data), residuals)
377    residuals.data = None
378    fn = current_data.data
379    gn = reference_data.data
380    en = current_data.err_data if weight is None else weight
381    residuals.data = (fn - gn) / en
382    residuals.qx_data = current_data.qx_data
383    residuals.qy_data = current_data.qy_data
384    residuals.q_data = current_data.q_data
385    residuals.err_data = numpy.ones(len(residuals.data))
386    residuals.xmin = min(residuals.qx_data)
387    residuals.xmax = max(residuals.qx_data)
388    residuals.ymin = min(residuals.qy_data)
389    residuals.ymax = max(residuals.qy_data)
390    residuals.q_data = current_data.q_data
391    residuals.mask = current_data.mask
392    residuals.scale = 'linear'
393    # check the lengths
394    if len(residuals.data) != len(residuals.q_data):
395        return None
396    return residuals
397
398def plotResiduals(reference_data, current_data):
399    """
400    Create Data1D/Data2D with residuals, ready for plotting
401    """
402    data_copy = copy.deepcopy(current_data)
403    # Get data: data I, theory I, and data dI in order
404    method_name = current_data.__class__.__name__
405    residuals_dict = {"Data1D": residualsData1D,
406                      "Data2D": residualsData2D}
407
408    residuals = residuals_dict[method_name](reference_data, data_copy)
409
410    theory_name = str(current_data.name.split()[0])
411    residuals.name = "Residuals for " + str(theory_name) + "[" + \
412                    str(reference_data.filename) + "]"
413    residuals.title = residuals.name
414    residuals.ytransform = 'y'
415
416    # when 2 data have the same id override the 1 st plotted
417    # include the last part if keeping charts for separate models is required
418    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
419    # group_id specify on which panel to plot this data
420    group_id = reference_data.group_id
421    residuals.group_id = "res" + str(group_id)
422
423    # Symbol
424    residuals.symbol = 0
425    residuals.hide_error = False
426
427    return residuals
428
429def binary_encode(i, digits):
430    return [i >> d & 1 for d in range(digits)]
431
432def getWeight(data, is2d, flag=None):
433    """
434    Received flag and compute error on data.
435    :param flag: flag to transform error of data.
436    """
437    weight = None
438    if is2d:
439        dy_data = data.err_data
440        data = data.data
441    else:
442        dy_data = data.dy
443        data = data.y
444
445    if flag == 0:
446        weight = numpy.ones_like(data)
447    elif flag == 1:
448        weight = dy_data
449    elif flag == 2:
450        weight = numpy.sqrt(numpy.abs(data))
451    elif flag == 3:
452        weight = numpy.abs(data)
453    return weight
454
455def updateKernelWithResults(kernel, results):
456    """
457    Takes model kernel and applies results dict to its parameters,
458    returning the modified (deep) copy of the kernel.
459    """
460    assert(isinstance(results, dict))
461    local_kernel = copy.deepcopy(kernel)
462
463    for parameter in results.keys():
464        # Update the parameter value - note: this supports +/-inf as well
465        local_kernel.setParam(parameter, results[parameter][0])
466
467    return local_kernel
468
469
470def getStandardParam(model=None):
471    """
472    Returns a list with standard parameters for the current model
473    """
474    param = []
475    num_rows = model.rowCount()
476    if num_rows < 1:
477        return None
478
479    for row in range(num_rows):
480        param_name = model.item(row, 0).text()
481        checkbox_state = model.item(row,0).checkState() == QtCore.Qt.Checked
482        value= model.item(row, 1).text()
483        column_shift = 0
484        if model.columnCount() == 5: # no error column
485            error_state = False
486            error_value = 0.0
487        else:
488            error_state = True
489            error_value = model.item(row, 2).text()
490            column_shift = 1
491        min_state = True
492        max_state = True
493        min_value = model.item(row, 2+column_shift).text()
494        max_value = model.item(row, 3+column_shift).text()
495        unit = ""
496        if model.item(row, 4+column_shift) is not None:
497            unit = model.item(row, 4+column_shift).text()
498
499        param.append([checkbox_state, param_name, value, "",
500                        [error_state, error_value],
501                        [min_state, min_value],
502                        [max_state, max_value], unit])
503
504    return param
505
506def getOrientationParam(kernel_module=None):
507    """
508    Get the dictionary with orientation parameters
509    """
510    param = []
511    if kernel_module is None: 
512        return None
513    for param_name in list(kernel_module.params.keys()):
514        name = param_name
515        value = kernel_module.params[param_name]
516        min_state = True
517        max_state = True
518        error_state = False
519        error_value = 0.0
520        checkbox_state = True #??
521        details = kernel_module.details[param_name] #[unit, mix, max]
522        param.append([checkbox_state, name, value, "",
523                     [error_state, error_value],
524                     [min_state, details[1]],
525                     [max_state, details[2]], details[0]])
526
527    return param
Note: See TracBrowser for help on using the repository browser.