source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ d4dac80

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since d4dac80 was d4dac80, checked in by Piotr Rozyczko <rozyczko@…>, 4 years ago

Merge branch 'ESS_GUI' into ESS_GUI_better_batch

  • Property mode set to 100644
File size: 15.4 KB
Line 
1import copy
2
3from PyQt5 import QtCore
4from PyQt5 import QtGui
5from PyQt5 import QtWidgets
6
7import numpy
8
9from sas.qtgui.Plotting.PlotterData import Data1D
10from sas.qtgui.Plotting.PlotterData import Data2D
11
# Column captions of the main parameter table.
model_header_captions = [
    'Parameter',
    'Value',
    'Min',
    'Max',
    'Units',
]

# Tooltips of the main parameter table, one per caption above.
model_header_tooltips = [
    'Select parameter for fitting',
    'Enter parameter value',
    'Enter minimum value for parameter',
    'Enter maximum value for parameter',
    'Unit of the parameter',
]

# Column captions of the polydispersity table.
poly_header_captions = [
    'Parameter',
    'PD[ratio]',
    'Min',
    'Max',
    'Npts',
    'Nsigs',
    'Function',
    'Filename',
]

# Tooltips of the polydispersity table, one per caption above.
poly_header_tooltips = [
    'Select parameter for fitting',
    ('Enter polydispersity ratio (STD/mean). '
     'STD: standard deviation from the mean value'),
    'Enter minimum value for parameter',
    'Enter maximum value for parameter',
    'Enter number of points for parameter',
    'Enter number of sigmas parameter',
    'Select distribution function',
    'Select filename with user-definable distribution',
]

# Tooltip and caption of the extra "Error" column; the *Error* header
# helpers insert these at column index 2.
error_tooltip = 'Error value for fitted parameter'
header_error_caption = 'Error'
35
def replaceShellName(param_name, value):
    """
    Updates parameter name from <param_name>[n_shell] to <param_name>value,
    e.g. ``replaceShellName("sld[n]", 2)`` returns ``"sld2"``.

    Everything from the first '[' onwards is replaced by ``str(value)``.

    :raises ValueError: if param_name contains no '['
    """
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``. ValueError is also what str.index would raise here.
    base, bracket, _ = param_name.partition('[')
    if not bracket:
        raise ValueError("Not a multishell parameter name: %r" % (param_name,))
    return base + str(value)
42
def getIterParams(model):
    """
    Returns a list of all multi-shell parameters in 'model'.

    A parameter counts as multi-shell when its name contains '[', e.g.
    "sld[n]". Only ``model.iq_parameters`` is inspected; order is preserved.
    """
    return [par for par in model.iq_parameters if "[" in par.name]
48
def getMultiplicity(model):
    """
    Finds out if 'model' has multishell parameters.

    :return: tuple ``(name, length)``. ``name`` is the ``length_control``
        of the first multishell parameter or, when that is None, the part of
        its name before '['. ``length`` is that parameter's ``length``.
        Returns ``("", 0)`` when the model has no multishell parameters.
    """
    # Multishell parameters are those whose name contains '[' (e.g. "sld[n]").
    shell_params = [p for p in model.iq_parameters if "[" in p.name]
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    shell_count = first.length
    counter_name = first.length_control
    if counter_name is None and '[' in first.name:
        # No explicit control parameter: fall back to the base name.
        counter_name = first.name[:first.name.index('[')]
    return (counter_name, shell_count)
63
def addParametersToModel(parameters, kernel_module, is2D):
    """
    Build the rows of the main parameter table from sasmodel parameters.

    :param parameters: sasmodels parameter table. ``iq_parameters`` or
        ``iqxy_parameters`` supply the rows. ``form_volume_parameters``
        decides which polydisperse parameters get a distribution child row.
    :param kernel_module: model kernel whose ``getParam`` is queried for the
        current ``<name>.width`` and ``<name>.type`` values
    :param is2D: use ``iqxy_parameters`` when True, else ``iq_parameters``
    :return: list of rows ``[name, value, min, max, units]`` of
        QStandardItem objects. The multishell counter parameter and the
        per-shell ``<name>[...]`` parameters are not included.
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Skip the shell counter and the per-shell parameters; presumably
        # their rows are created separately (cf. addShellsToModel).
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)
        name_item.setEditable(False)

        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            poly_item.setEditable(False)
            distribution_item = QtGui.QStandardItem("Distribution")
            distribution_item.setEditable(False)
            # Only parameters also listed in form_volume_parameters get a
            # distribution row; its values are read from the kernel.
            for p in parameters.form_volume_parameters:
                if p.name != param.name:
                    continue
                width = kernel_module.getParam(p.name + '.width')
                dist_type = kernel_module.getParam(p.name + '.type')

                width_item = QtGui.QStandardItem(str(width))
                width_item.setEditable(False)
                blank_item_1 = QtGui.QStandardItem()
                blank_item_1.setEditable(False)
                blank_item_2 = QtGui.QStandardItem()
                blank_item_2.setEditable(False)
                type_item = QtGui.QStandardItem(dist_type)
                type_item.setEditable(False)
                poly_item.appendRow([distribution_item, width_item,
                                     blank_item_1, blank_item_2, type_item])
                break
            # Add the polydispersity item as a child of the parameter row
            name_item.appendRow([poly_item])

        value_item = QtGui.QStandardItem(str(param.default))
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        min_item = QtGui.QStandardItem(str(param.limits[0]))
        max_item = QtGui.QStandardItem(str(param.limits[1]))
        units_item = QtGui.QStandardItem(param.units)
        units_item.setEditable(False)
        rows.append([name_item, value_item, min_item, max_item, units_item])
    return rows
122
def addSimpleParametersToModel(parameters, is2D):
    """
    Build plain parameter-table rows (no polydispersity children) from
    sasmodel parameters.

    :param parameters: sasmodels parameter table
    :param is2D: use ``iqxy_parameters`` when True, else ``iq_parameters``
    :return: list of ``[name, value, min, max, units]`` rows of QStandardItem
    """
    def make_row(par):
        # Checkable, read-only name cell
        name_cell = QtGui.QStandardItem(par.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cell = QtGui.QStandardItem(str(par.default))
        low_cell = QtGui.QStandardItem(str(par.limits[0]))
        high_cell = QtGui.QStandardItem(str(par.limits[1]))
        units_cell = QtGui.QStandardItem(par.units)
        units_cell.setEditable(False)
        return [name_cell, value_cell, low_cell, high_cell, units_cell]

    source = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    return [make_row(par) for par in source]
144
def addCheckedListToModel(model, param_list):
    """
    Append one row to 'model', one QStandardItem per entry of 'param_list'.
    Only the first cell of the new row is made checkable.

    'model' must be a QtGui.QStandardItemModel (asserted).
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_cell = row[0]
    first_cell.setCheckable(True)
    model.appendRow(row)
153
def addHeadersToModel(model):
    """
    Set the main parameter-table column captions on 'model' and store the
    matching tooltips as ``model.header_tooltips``.
    """
    horizontal = QtCore.Qt.Horizontal
    for column, caption in enumerate(model_header_captions):
        model.setHeaderData(column, horizontal, caption)

    # A fresh list, so later edits on the model do not touch the module constant.
    model.header_tooltips = list(model_header_tooltips)
162
def addErrorHeadersToModel(model):
    """
    Set the main parameter-table column captions on 'model' with an extra
    'Error' column at index 2 (right after 'Value'), and store the matching
    tooltips as ``model.header_tooltips``.
    """
    captions = (model_header_captions[:2] + [header_error_caption]
                + model_header_captions[2:])
    horizontal = QtCore.Qt.Horizontal
    for column, caption in enumerate(captions):
        model.setHeaderData(column, horizontal, caption)

    # Slicing builds a new list, leaving the module constant untouched.
    model.header_tooltips = (model_header_tooltips[:2] + [error_tooltip]
                             + model_header_tooltips[2:])
175
def addPolyHeadersToModel(model):
    """
    Set the polydispersity-table column captions on 'model' and store the
    matching tooltips as ``model.header_tooltips``.
    """
    horizontal = QtCore.Qt.Horizontal
    for column, caption in enumerate(poly_header_captions):
        model.setHeaderData(column, horizontal, caption)

    # A fresh list, so later edits on the model do not touch the module constant.
    model.header_tooltips = list(poly_header_tooltips)
184
185
def addErrorPolyHeadersToModel(model):
    """
    Set the polydispersity-table column captions on 'model' with an extra
    'Error' column at index 2 (right after 'PD[ratio]'), and store the
    matching tooltips as ``model.header_tooltips``.
    """
    captions = (poly_header_captions[:2] + [header_error_caption]
                + poly_header_captions[2:])
    horizontal = QtCore.Qt.Horizontal
    for column, caption in enumerate(captions):
        model.setHeaderData(column, horizontal, caption)

    # Slicing builds a new list, leaving the module constant untouched.
    model.header_tooltips = (poly_header_tooltips[:2] + [error_tooltip]
                             + poly_header_tooltips[2:])
198
def addShellsToModel(parameters, model, index):
    """
    Find out multishell parameters and append rows for 'index' shells of
    them to 'model'.

    For every shell number 1..index, and for every multishell parameter
    ``<name>[...]`` within it, a row ``[<name><shell>, value, min, max, units]``
    is appended, e.g. "sld1", "sld2", ... for "sld[n]".

    :param parameters: sasmodels parameter table
    :param model: item model receiving the rows (via ``appendRow``)
    :param index: number of shells to create rows for
    """
    multishell_parameters = getIterParams(parameters)

    for shell in range(1, index + 1):
        for par in multishell_parameters:
            # Create the name: <param><shell>, e.g. "sld1" for parameter "sld[n]"
            param_name = replaceShellName(par.name, shell)
            name_item = QtGui.QStandardItem(param_name)
            name_item.setCheckable(True)
            # check for polydisp params
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                distribution_item = QtGui.QStandardItem("Distribution")
                # Find param in volume_params
                for p in parameters.form_volume_parameters:
                    if p.name != par.name:
                        continue
                    # NOTE(review): unlike addParametersToModel, this row shows
                    # the parameter's default/limits/units rather than the
                    # kernel's current PD width and distribution type - confirm.
                    default_item = QtGui.QStandardItem(str(p.default))
                    low_item = QtGui.QStandardItem(str(p.limits[0]))
                    high_item = QtGui.QStandardItem(str(p.limits[1]))
                    p_units_item = QtGui.QStandardItem(p.units)
                    poly_item.appendRow([distribution_item, default_item,
                                         low_item, high_item, p_units_item])
                    break
                name_item.appendRow([poly_item])

            value_item = QtGui.QStandardItem(str(par.default))
            min_item = QtGui.QStandardItem(str(par.limits[0]))
            max_item = QtGui.QStandardItem(str(par.limits[1]))
            units_item = QtGui.QStandardItem(par.units)
            # Units are descriptive, not user input: read-only, as in the
            # other row builders of this module.
            units_item.setEditable(False)
            model.appendRow([name_item, value_item, min_item, max_item, units_item])
232
def calculateChi2(reference_data, current_data):
    """
    Calculate Chi2 value between two sets of data.

    Each point's residual is ``(current - reference) / error``, with the
    error taken from ``current_data`` (``err_data`` for 2D, ``dy`` for 1D).
    The result is the mean of the squared residuals that are finite.

    :param reference_data: Data2D or 1D data set; None short-circuits
    :param current_data: data set of the same kind, also supplying errors
    :return: chi2 value, or None when reference_data is None or the arrays
        cannot be combined (e.g. mismatched lengths)
    """
    # TODO: weighting of the input (get_weight in the old sasgui fitting
    # perspective) is not wired up yet.
    if reference_data is None:
        return None

    # Placeholders for point selection / weighting; while they are None every
    # point is considered and the data's own errors are used.
    index = None
    weight = None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        if index is None:
            index = numpy.ones(len(current_data.data), dtype=bool)
        if weight is not None:
            current_data.err_data = weight
        # get rid of zero-error and non-finite points
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1 d theory from model_thread is only in the range of index
        if index is None:
            index = numpy.ones(len(current_data.y), dtype=bool)
        if weight is not None:
            current_data.dy = weight
        # len() instead of ``== []``: on a numpy array ``== []`` is an
        # elementwise comparison that, depending on the NumPy version,
        # warns or raises (ValueError for mismatched shapes since 1.25).
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            # Set consistently w/AbstractFitengine: zero errors count as 1.
            # numpy.array() copies (the caller's dy is left untouched) and
            # makes the boolean-mask assignment work for plain lists too.
            dy = numpy.array(current_data.dy, dtype=float)
            dy[dy == 0] = 1
        fn = current_data.y[index]
        gn = reference_data.y
        en = dy[index]

    # Calculate the residual
    try:
        res = (fn - gn) / en
    except ValueError:
        # Unmatched lengths of fn, gn and en
        return None

    residuals = res[numpy.isfinite(res)]
    return numpy.average(residuals * residuals)
289
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data1D sets.

    :param reference_data: Data1D whose ``y`` is compared against
    :param current_data: Data1D supplying ``x``, ``y`` and ``dy``
    :return: new Data1D with ``x`` from current_data and, when the lengths
        match, ``y = (reference.y - current.y) / error``. Every error is
        currently 1 (see the ``weight`` placeholder).
    """
    # Placeholder for a future weighting scheme. While it is None every point
    # gets unit error, so the residuals are plain differences.
    weight = None

    n_points = len(current_data.y)
    # len() instead of ``== []``: on a numpy array ``== []`` is an elementwise
    # comparison that warns or raises rather than testing for emptiness.
    if weight is None or current_data.dy is None or len(current_data.dy) == 0:
        dy = numpy.ones(n_points)
    else:
        # Copy, so the caller's weights are not modified; zero errors count as 1.
        dy = numpy.array(weight, dtype=float)
        dy[dy == 0] = 1

    fn = numpy.asarray(current_data.y)
    gn = numpy.asarray(reference_data.y)
    en = dy

    # build residuals
    residuals = Data1D()
    if len(fn) == len(gn):
        y = (fn - gn) / en
        residuals.y = -y
    else:
        # TODO: fix case where applying new data from file on top of existing model data
        # NOTE(review): the sign here is the opposite of the equal-length
        # branch above - confirm which convention is intended.
        try:
            residuals.y = (fn - gn) / en
        except ValueError:
            # value errors may show up every once in a while for malformed
            # columns, just reuse what's there already
            pass

    residuals.x = numpy.asarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
333
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data2D sets.

    :param reference_data: Data2D whose ``data`` is subtracted
    :param current_data: Data2D supplying data, errors (``err_data``) and
        the q arrays
    :return: new Data2D with ``data = (current - reference) / err`` on
        current_data's q arrays, or None when the number of residuals does
        not match the number of q points
    """
    # Placeholder for a future weighting scheme; while None the data's own
    # errors are used.
    weight = None

    # build residuals
    residuals = Data2D()
    # Presumably copies current_data's attributes, but not its data arrays,
    # into residuals (project helper) - data is overwritten just below anyway.
    current_data.clone_without_data(len(current_data.data), residuals)

    fn = current_data.data
    gn = reference_data.data
    en = current_data.err_data if weight is None else weight
    residuals.data = (fn - gn) / en
    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    # numpy reductions instead of builtin min/max, which would walk the
    # arrays element by element in Python.
    residuals.xmin = numpy.min(residuals.qx_data)
    residuals.xmax = numpy.max(residuals.qx_data)
    residuals.ymin = numpy.min(residuals.qy_data)
    residuals.ymax = numpy.max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    # check the lengths
    if len(residuals.data) != len(residuals.q_data):
        return None
    return residuals
366
def plotResiduals(reference_data, current_data):
    """
    Build a plottable Data1D/Data2D of the residuals between two data sets.

    current_data is deep-copied before being handed to the residuals
    builder. The result is named/titled
    "Residuals for <first word of current_data.name>[<reference filename>]",
    and gets ``id``/``group_id`` derived from reference_data with a "res" prefix.
    """
    copied = copy.deepcopy(current_data)
    # Dispatch on the class *name*, so any class called Data1D/Data2D works.
    builders = {
        "Data1D": residualsData1D,
        "Data2D": residualsData2D,
    }
    build = builders[current_data.__class__.__name__]
    residuals = build(reference_data, copied)

    theory_name = str(current_data.name.split()[0])
    residuals.name = "Residuals for {}[{}]".format(theory_name,
                                                   str(reference_data.filename))
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # Same id for every theory plotted against this data set, so a new
    # residual plot overrides the first one (append theory_name to the id
    # if separate charts per model are wanted).
    residuals.id = "res" + str(reference_data.id)
    # group_id decides on which panel the residuals are plotted
    residuals.group_id = "res" + str(reference_data.group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False
    return residuals
397
def binary_encode(i, digits):
    """
    Return the lowest 'digits' bits of integer 'i' as a list of 0/1 ints,
    least-significant bit first, e.g. ``binary_encode(6, 4) == [0, 1, 1, 0]``.
    """
    return [1 if i & (1 << position) else 0 for position in range(digits)]
400
def getWeight(data, is2d, flag=None):
    """
    Compute the per-point weights (errors) for 'data' according to 'flag'.

    :param data: object with ``data``/``err_data`` when is2d is True,
        otherwise with ``y``/``dy``
    :param is2d: selects which pair of attributes is read
    :param flag: 0 -> ones, 1 -> the stored errors, 2 -> sqrt(|I|),
        3 -> |I|; any other value (including None) -> None
    :return: the weights, or None for an unrecognised flag
    """
    if is2d:
        errors = data.err_data
        intensity = data.data
    else:
        errors = data.dy
        intensity = data.y

    if flag == 0:
        return numpy.ones_like(intensity)
    if flag == 1:
        return errors
    if flag == 2:
        return numpy.sqrt(numpy.abs(intensity))
    if flag == 3:
        return numpy.abs(intensity)
    return None
423
def updateKernelWithResults(kernel, results):
    """
    Return a deep copy of 'kernel' in which every parameter named in
    'results' is set to ``results[name][0]``; 'kernel' itself is not touched.

    :param results: dict mapping parameter name to a sequence whose first
        item is the new value (+/-inf is accepted as well)
    """
    assert isinstance(results, dict)
    updated = copy.deepcopy(kernel)
    for name, fitted in results.items():
        updated.setParam(name, fitted[0])
    return updated
437
438
Note: See TracBrowser for help on using the repository browser.