source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 085e3c9d

ESS_GUI_iss959
Last change on this file since 085e3c9d was 57be490, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Merged ESS_GUI_reporting

  • Property mode set to 100644
File size: 17.5 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5from PyQt5 import QtWidgets
6
7import numpy
8
9from sas.qtgui.Plotting.PlotterData import Data1D
10from sas.qtgui.Plotting.PlotterData import Data2D
11
12model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
13
14model_header_tooltips = ['Select parameter for fitting',
15                         'Enter parameter value',
16                         'Enter minimum value for parameter',
17                         'Enter maximum value for parameter',
18                         'Unit of the parameter']
19
20poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
21                        'Function', 'Filename']
22
23poly_header_tooltips = ['Select parameter for fitting',
24                        'Enter polydispersity ratio (STD/mean). '
25                        'STD: standard deviation from the mean value',
26                        'Enter minimum value for parameter',
27                        'Enter maximum value for parameter',
28                        'Enter number of points for parameter',
29                        'Enter number of sigmas parameter',
30                        'Select distribution function',
31                        'Select filename with user-definable distribution']
32
33error_tooltip = 'Error value for fitted parameter'
34header_error_caption = 'Error'
35
36def replaceShellName(param_name, value):
37    """
38    Updates parameter name from <param_name>[n_shell] to <param_name>value
39    """
40    assert '[' in param_name
41    return param_name[:param_name.index('[')]+str(value)
42
43def getIterParams(model):
44    """
45    Returns a list of all multi-shell parameters in 'model'
46    """
47    return list([par for par in model.iq_parameters if "[" in par.name])
48
49def getMultiplicity(model):
50    """
51    Finds out if 'model' has multishell parameters.
52    If so, returns the name of the counter parameter and the number of shells
53    """
54    iter_params = getIterParams(model)
55    param_name = ""
56    param_length = 0
57    if iter_params:
58        param_length = iter_params[0].length
59        param_name = iter_params[0].length_control
60        if param_name is None and '[' in iter_params[0].name:
61            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
62    return (param_name, param_length)
63
64def addParametersToModel(parameters, kernel_module, is2D):
65    """
66    Update local ModelModel with sasmodel parameters
67    """
68    multishell_parameters = getIterParams(parameters)
69    multishell_param_name, _ = getMultiplicity(parameters)
70
71    if is2D:
72        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
73    else:
74        params = parameters.iq_parameters
75    item = []
76    for param in params:
77        # don't include shell parameters
78        if param.name == multishell_param_name:
79            continue
80        # Modify parameter name from <param>[n] to <param>1
81        item_name = param.name
82        if param in multishell_parameters:
83            continue
84        #    item_name = replaceShellName(param.name, 1)
85
86        item1 = QtGui.QStandardItem(item_name)
87        item1.setCheckable(True)
88        item1.setEditable(False)
89        # item_err = QtGui.QStandardItem()
90        # check for polydisp params
91        if param.polydisperse:
92            poly_item = QtGui.QStandardItem("Polydispersity")
93            poly_item.setEditable(False)
94            item1_1 = QtGui.QStandardItem("Distribution")
95            item1_1.setEditable(False)
96            # Find param in volume_params
97            for p in parameters.form_volume_parameters:
98                if p.name != param.name:
99                    continue
100                width = kernel_module.getParam(p.name+'.width')
101                type = kernel_module.getParam(p.name+'.type')
102
103                item1_2 = QtGui.QStandardItem(str(width))
104                item1_2.setEditable(False)
105                item1_3 = QtGui.QStandardItem()
106                item1_3.setEditable(False)
107                item1_4 = QtGui.QStandardItem()
108                item1_4.setEditable(False)
109                item1_5 = QtGui.QStandardItem(type)
110                item1_5.setEditable(False)
111                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
112                break
113            # Add the polydisp item as a child
114            item1.appendRow([poly_item])
115        # Param values
116        item2 = QtGui.QStandardItem(str(param.default))
117        # TODO: the error column.
118        # Either add a proxy model or a custom view delegate
119        #item_err = QtGui.QStandardItem()
120        item3 = QtGui.QStandardItem(str(param.limits[0]))
121        item4 = QtGui.QStandardItem(str(param.limits[1]))
122        item5 = QtGui.QStandardItem(param.units)
123        item5.setEditable(False)
124        item.append([item1, item2, item3, item4, item5])
125    return item
126
127def addSimpleParametersToModel(parameters, is2D):
128    """
129    Update local ModelModel with sasmodel parameters
130    """
131    if is2D:
132        params = [p for p in parameters.kernel_parameters if p.type != 'magnetic']
133    else:
134        params = parameters.iq_parameters
135    item = []
136    for param in params:
137        # Create the top level, checkable item
138        item_name = param.name
139        item1 = QtGui.QStandardItem(item_name)
140        item1.setCheckable(True)
141        item1.setEditable(False)
142        # Param values
143        # TODO: add delegate for validation of cells
144        item2 = QtGui.QStandardItem(str(param.default))
145        item4 = QtGui.QStandardItem(str(param.limits[0]))
146        item5 = QtGui.QStandardItem(str(param.limits[1]))
147        item6 = QtGui.QStandardItem(param.units)
148        item6.setEditable(False)
149        item.append([item1, item2, item4, item5, item6])
150    return item
151
152def addCheckedListToModel(model, param_list):
153    """
154    Add a QItem to model. Makes the QItem checkable
155    """
156    assert isinstance(model, QtGui.QStandardItemModel)
157    item_list = [QtGui.QStandardItem(item) for item in param_list]
158    item_list[0].setCheckable(True)
159    model.appendRow(item_list)
160
161def addHeadersToModel(model):
162    """
163    Adds predefined headers to the model
164    """
165    for i, item in enumerate(model_header_captions):
166        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
167
168    model.header_tooltips = copy.copy(model_header_tooltips)
169
170def addErrorHeadersToModel(model):
171    """
172    Adds predefined headers to the model
173    """
174    model_header_error_captions = copy.copy(model_header_captions)
175    model_header_error_captions.insert(2, header_error_caption)
176    for i, item in enumerate(model_header_error_captions):
177        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
178
179    model_header_error_tooltips = copy.copy(model_header_tooltips)
180    model_header_error_tooltips.insert(2, error_tooltip)
181    model.header_tooltips = copy.copy(model_header_error_tooltips)
182
183def addPolyHeadersToModel(model):
184    """
185    Adds predefined headers to the model
186    """
187    for i, item in enumerate(poly_header_captions):
188        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
189
190    model.header_tooltips = copy.copy(poly_header_tooltips)
191
192
193def addErrorPolyHeadersToModel(model):
194    """
195    Adds predefined headers to the model
196    """
197    poly_header_error_captions = copy.copy(poly_header_captions)
198    poly_header_error_captions.insert(2, header_error_caption)
199    for i, item in enumerate(poly_header_error_captions):
200        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
201
202    poly_header_error_tooltips = copy.copy(poly_header_tooltips)
203    poly_header_error_tooltips.insert(2, error_tooltip)
204    model.header_tooltips = copy.copy(poly_header_error_tooltips)
205
206def addShellsToModel(parameters, model, index):
207    """
208    Find out multishell parameters and update the model with the requested number of them
209    """
210    multishell_parameters = getIterParams(parameters)
211
212    for i in range(index):
213        for par in multishell_parameters:
214            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
215            param_name = replaceShellName(par.name, i+1)
216            item1 = QtGui.QStandardItem(param_name)
217            item1.setCheckable(True)
218            # check for polydisp params
219            if par.polydisperse:
220                poly_item = QtGui.QStandardItem("Polydispersity")
221                item1_1 = QtGui.QStandardItem("Distribution")
222                # Find param in volume_params
223                for p in parameters.form_volume_parameters:
224                    if p.name != par.name:
225                        continue
226                    item1_2 = QtGui.QStandardItem(str(p.default))
227                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
228                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
229                    item1_5 = QtGui.QStandardItem(p.units)
230                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
231                    break
232                item1.appendRow([poly_item])
233
234            item2 = QtGui.QStandardItem(str(par.default))
235            item3 = QtGui.QStandardItem(str(par.limits[0]))
236            item4 = QtGui.QStandardItem(str(par.limits[1]))
237            item5 = QtGui.QStandardItem(par.units)
238            model.appendRow([item1, item2, item3, item4, item5])
239
240def calculateChi2(reference_data, current_data):
241    """
242    Calculate Chi2 value between two sets of data
243    """
244
245    # WEIGHING INPUT
246    #from sas.sasgui.perspectives.fitting.utils import get_weight
247    #flag = self.get_weight_flag()
248    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
249    chisqr = None
250    if reference_data is None:
251        return chisqr
252
253    # temporary default values for index and weight
254    index = None
255    weight = None
256
257    # Get data: data I, theory I, and data dI in order
258    if isinstance(reference_data, Data2D):
259        if index is None:
260            index = numpy.ones(len(current_data.data), dtype=bool)
261        if weight is not None:
262            current_data.err_data = weight
263        # get rid of zero error points
264        index = index & (current_data.err_data != 0)
265        index = index & (numpy.isfinite(current_data.data))
266        fn = current_data.data[index]
267        gn = reference_data.data[index]
268        en = current_data.err_data[index]
269    else:
270        # 1 d theory from model_thread is only in the range of index
271        if index is None:
272            index = numpy.ones(len(current_data.y), dtype=bool)
273        if weight is not None:
274            current_data.dy = weight
275        if current_data.dy is None or current_data.dy == []:
276            dy = numpy.ones(len(current_data.y))
277        else:
278            ## Set consistently w/AbstractFitengine:
279            # But this should be corrected later.
280            dy = copy.deepcopy(current_data.dy)
281            dy[dy == 0] = 1
282        fn = current_data.y[index]
283        gn = reference_data.y
284        en = dy[index]
285    # Calculate the residual
286    try:
287        res = (fn - gn) / en
288    except ValueError:
289        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
290        return None
291
292    residuals = res[numpy.isfinite(res)]
293    chisqr = numpy.average(residuals * residuals)
294
295    return chisqr
296
297def residualsData1D(reference_data, current_data):
298    """
299    Calculate the residuals for difference of two Data1D sets
300    """
301    # temporary default values for index and weight
302    index = None
303    weight = None
304
305    # 1d theory from model_thread is only in the range of index
306    if current_data.dy is None or current_data.dy == []:
307        dy = numpy.ones(len(current_data.y))
308    else:
309        dy = weight if weight is not None else numpy.ones(len(current_data.y))
310        dy[dy == 0] = 1
311    fn = current_data.y[index][0]
312    gn = reference_data.y
313    en = dy[index][0]
314    # build residuals
315    residuals = Data1D()
316    if len(fn) == len(gn):
317        y = (fn - gn)/en
318        residuals.y = -y
319    else:
320        # TODO: fix case where applying new data from file on top of existing model data
321        try:
322            y = (fn - gn[index][0]) / en
323            residuals.y = y
324        except ValueError:
325            # value errors may show up every once in a while for malformed columns,
326            # just reuse what's there already
327            pass
328
329    residuals.x = current_data.x[index][0]
330    residuals.dy = numpy.ones(len(residuals.y))
331    residuals.dx = None
332    residuals.dxl = None
333    residuals.dxw = None
334    residuals.ytransform = 'y'
335    # For latter scale changes
336    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
337    residuals.yaxis('\\rm{Residuals} ', 'normalized')
338
339    return residuals
340
341def residualsData2D(reference_data, current_data):
342    """
343    Calculate the residuals for difference of two Data2D sets
344    """
345    # temporary default values for index and weight
346    # index = None
347    weight = None
348
349    # build residuals
350    residuals = Data2D()
351    # Not for trunk the line below, instead use the line above
352    current_data.clone_without_data(len(current_data.data), residuals)
353    residuals.data = None
354    fn = current_data.data
355    gn = reference_data.data
356    en = current_data.err_data if weight is None else weight
357    residuals.data = (fn - gn) / en
358    residuals.qx_data = current_data.qx_data
359    residuals.qy_data = current_data.qy_data
360    residuals.q_data = current_data.q_data
361    residuals.err_data = numpy.ones(len(residuals.data))
362    residuals.xmin = min(residuals.qx_data)
363    residuals.xmax = max(residuals.qx_data)
364    residuals.ymin = min(residuals.qy_data)
365    residuals.ymax = max(residuals.qy_data)
366    residuals.q_data = current_data.q_data
367    residuals.mask = current_data.mask
368    residuals.scale = 'linear'
369    # check the lengths
370    if len(residuals.data) != len(residuals.q_data):
371        return None
372    return residuals
373
374def plotResiduals(reference_data, current_data):
375    """
376    Create Data1D/Data2D with residuals, ready for plotting
377    """
378    data_copy = copy.deepcopy(current_data)
379    # Get data: data I, theory I, and data dI in order
380    method_name = current_data.__class__.__name__
381    residuals_dict = {"Data1D": residualsData1D,
382                      "Data2D": residualsData2D}
383
384    residuals = residuals_dict[method_name](reference_data, data_copy)
385
386    theory_name = str(current_data.name.split()[0])
387    residuals.name = "Residuals for " + str(theory_name) + "[" + \
388                    str(reference_data.filename) + "]"
389    residuals.title = residuals.name
390    residuals.ytransform = 'y'
391
392    # when 2 data have the same id override the 1 st plotted
393    # include the last part if keeping charts for separate models is required
394    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
395    # group_id specify on which panel to plot this data
396    group_id = reference_data.group_id
397    residuals.group_id = "res" + str(group_id)
398
399    # Symbol
400    residuals.symbol = 0
401    residuals.hide_error = False
402
403    return residuals
404
405def binary_encode(i, digits):
406    return [i >> d & 1 for d in range(digits)]
407
408def getWeight(data, is2d, flag=None):
409    """
410    Received flag and compute error on data.
411    :param flag: flag to transform error of data.
412    """
413    weight = None
414    if is2d:
415        dy_data = data.err_data
416        data = data.data
417    else:
418        dy_data = data.dy
419        data = data.y
420
421    if flag == 0:
422        weight = numpy.ones_like(data)
423    elif flag == 1:
424        weight = dy_data
425    elif flag == 2:
426        weight = numpy.sqrt(numpy.abs(data))
427    elif flag == 3:
428        weight = numpy.abs(data)
429    return weight
430
431def updateKernelWithResults(kernel, results):
432    """
433    Takes model kernel and applies results dict to its parameters,
434    returning the modified (deep) copy of the kernel.
435    """
436    assert(isinstance(results, dict))
437    local_kernel = copy.deepcopy(kernel)
438
439    for parameter in results.keys():
440        # Update the parameter value - note: this supports +/-inf as well
441        local_kernel.setParam(parameter, results[parameter][0])
442
443    return local_kernel
444
445
446def getStandardParam(model=None):
447    """
448    Returns a list with standard parameters for the current model
449    """
450    param = []
451    num_rows = model.rowCount()
452    if num_rows < 1:
453        return None
454
455    for row in range(num_rows):
456        param_name = model.item(row, 0).text()
457        checkbox_state = model.item(row,0).checkState() == QtCore.Qt.Checked
458        value= model.item(row, 1).text()
459        column_shift = 0
460        if model.columnCount() == 5: # no error column
461            error_state = False
462            error_value = 0.0
463        else:
464            error_state = True
465            error_value = model.item(row, 2).text()
466            column_shift = 1
467        min_state = True
468        max_state = True
469        min_value = model.item(row, 2+column_shift).text()
470        max_value = model.item(row, 3+column_shift).text()
471        unit = ""
472        if model.item(row, 4+column_shift) is not None:
473            unit = model.item(row, 4+column_shift).text()
474
475        param.append([checkbox_state, param_name, value, "",
476                        [error_state, error_value],
477                        [min_state, min_value],
478                        [max_state, max_value], unit])
479
480    return param
481
482def getOrientationParam(kernel_module=None):
483    """
484    Get the dictionary with orientation parameters
485    """
486    param = []
487    if kernel_module is None: 
488        return None
489    for param_name in list(kernel_module.params.keys()):
490        name = param_name
491        value = kernel_module.params[param_name]
492        min_state = True
493        max_state = True
494        error_state = False
495        error_value = 0.0
496        checkbox_state = True #??
497        details = kernel_module.details[param_name] #[unit, mix, max]
498        param.append([checkbox_state, name, value, "",
499                     [error_state, error_value],
500                     [min_state, details[1]],
501                     [max_state, details[2]], details[0]])
502
503    return param
Note: See TracBrowser for help on using the repository browser.