source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ b3e8629

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since b3e8629 was b3e8629, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

Initial changes to make SasView? run with python3

  • Property mode set to 100755
File size: 15.0 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list([par for par in model.iq_parameters if "[" in par.name])
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
70    item = []
71    for param in params:
72        # don't include shell parameters
73        if param.name == multishell_param_name:
74            continue
75        # Modify parameter name from <param>[n] to <param>1
76        item_name = param.name
77        if param in multishell_parameters:
78            continue
79        #    item_name = replaceShellName(param.name, 1)
80
81        item1 = QtGui.QStandardItem(item_name)
82        item1.setCheckable(True)
83        item1.setEditable(False)
84        # item_err = QtGui.QStandardItem()
85        # check for polydisp params
86        if param.polydisperse:
87            poly_item = QtGui.QStandardItem("Polydispersity")
88            poly_item.setEditable(False)
89            item1_1 = QtGui.QStandardItem("Distribution")
90            item1_1.setEditable(False)
91            # Find param in volume_params
92            for p in parameters.form_volume_parameters:
93                if p.name != param.name:
94                    continue
95                width = kernel_module.getParam(p.name+'.width')
96                type = kernel_module.getParam(p.name+'.type')
97
98                item1_2 = QtGui.QStandardItem(str(width))
99                item1_2.setEditable(False)
100                item1_3 = QtGui.QStandardItem()
101                item1_3.setEditable(False)
102                item1_4 = QtGui.QStandardItem()
103                item1_4.setEditable(False)
104                item1_5 = QtGui.QStandardItem(type)
105                item1_5.setEditable(False)
106                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
107                break
108            # Add the polydisp item as a child
109            item1.appendRow([poly_item])
110        # Param values
111        item2 = QtGui.QStandardItem(str(param.default))
112        # TODO: the error column.
113        # Either add a proxy model or a custom view delegate
114        #item_err = QtGui.QStandardItem()
115        item3 = QtGui.QStandardItem(str(param.limits[0]))
116        item4 = QtGui.QStandardItem(str(param.limits[1]))
117        item5 = QtGui.QStandardItem(param.units)
118        item5.setEditable(False)
119        item.append([item1, item2, item3, item4, item5])
120    return item
121
122def addSimpleParametersToModel(parameters, is2D):
123    """
124    Update local ModelModel with sasmodel parameters
125    """
126    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
127    item = []
128    for param in params:
129        # Create the top level, checkable item
130        item_name = param.name
131        item1 = QtGui.QStandardItem(item_name)
132        item1.setCheckable(True)
133        item1.setEditable(False)
134        # Param values
135        # TODO: add delegate for validation of cells
136        item2 = QtGui.QStandardItem(str(param.default))
137        item4 = QtGui.QStandardItem(str(param.limits[0]))
138        item5 = QtGui.QStandardItem(str(param.limits[1]))
139        item6 = QtGui.QStandardItem(param.units)
140        item6.setEditable(False)
141        item.append([item1, item2, item4, item5, item6])
142    return item
143
144def addCheckedListToModel(model, param_list):
145    """
146    Add a QItem to model. Makes the QItem checkable
147    """
148    assert isinstance(model, QtGui.QStandardItemModel)
149    item_list = [QtGui.QStandardItem(item) for item in param_list]
150    item_list[0].setCheckable(True)
151    model.appendRow(item_list)
152
153def addHeadersToModel(model):
154    """
155    Adds predefined headers to the model
156    """
157    for i, item in enumerate(model_header_captions):
158        #model.setHeaderData(i, QtCore.Qt.Horizontal, QtCore.QVariant(item))
159        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
160
161    model.header_tooltips = model_header_tooltips
162
163def addErrorHeadersToModel(model):
164    """
165    Adds predefined headers to the model
166    """
167    model_header_error_captions = model_header_captions
168    model_header_error_captions.insert(2, header_error_caption)
169    for i, item in enumerate(model_header_error_captions):
170        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
171
172    model_header_error_tooltips = model_header_tooltips
173    model_header_error_tooltips.insert(2, error_tooltip)
174    model.header_tooltips = model_header_error_tooltips
175
176def addPolyHeadersToModel(model):
177    """
178    Adds predefined headers to the model
179    """
180    for i, item in enumerate(poly_header_captions):
181        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
182
183    model.header_tooltips = poly_header_tooltips
184
185
186def addErrorPolyHeadersToModel(model):
187    """
188    Adds predefined headers to the model
189    """
190    poly_header_error_captions = poly_header_captions
191    poly_header_error_captions.insert(2, header_error_caption)
192    for i, item in enumerate(poly_header_error_captions):
193        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
194
195    poly_header_error_tooltips = poly_header_tooltips
196    poly_header_error_tooltips.insert(2, error_tooltip)
197    model.header_tooltips = poly_header_error_tooltips
198
199def addShellsToModel(parameters, model, index):
200    """
201    Find out multishell parameters and update the model with the requested number of them
202    """
203    multishell_parameters = getIterParams(parameters)
204
205    for i in range(index):
206        for par in multishell_parameters:
207            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
208            param_name = replaceShellName(par.name, i+1)
209            item1 = QtGui.QStandardItem(param_name)
210            item1.setCheckable(True)
211            # check for polydisp params
212            if par.polydisperse:
213                poly_item = QtGui.QStandardItem("Polydispersity")
214                item1_1 = QtGui.QStandardItem("Distribution")
215                # Find param in volume_params
216                for p in parameters.form_volume_parameters:
217                    if p.name != par.name:
218                        continue
219                    item1_2 = QtGui.QStandardItem(str(p.default))
220                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
221                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
222                    item1_5 = QtGui.QStandardItem(p.units)
223                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
224                    break
225                item1.appendRow([poly_item])
226
227            item2 = QtGui.QStandardItem(str(par.default))
228            item3 = QtGui.QStandardItem(str(par.limits[0]))
229            item4 = QtGui.QStandardItem(str(par.limits[1]))
230            item5 = QtGui.QStandardItem(par.units)
231            model.appendRow([item1, item2, item3, item4, item5])
232
233def calculateChi2(reference_data, current_data):
234    """
235    Calculate Chi2 value between two sets of data
236    """
237
238    # WEIGHING INPUT
239    #from sas.sasgui.perspectives.fitting.utils import get_weight
240    #flag = self.get_weight_flag()
241    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
242    chisqr = None
243    if reference_data is None:
244        return chisqr
245
246    # temporary default values for index and weight
247    index = None
248    weight = None
249
250    # Get data: data I, theory I, and data dI in order
251    if isinstance(reference_data, Data2D):
252        if index is None:
253            index = numpy.ones(len(current_data.data), dtype=bool)
254        if weight is not None:
255            current_data.err_data = weight
256        # get rid of zero error points
257        index = index & (current_data.err_data != 0)
258        index = index & (numpy.isfinite(current_data.data))
259        fn = current_data.data[index]
260        gn = reference_data.data[index]
261        en = current_data.err_data[index]
262    else:
263        # 1 d theory from model_thread is only in the range of index
264        if index is None:
265            index = numpy.ones(len(current_data.y), dtype=bool)
266        if weight is not None:
267            current_data.dy = weight
268        if current_data.dy is None or current_data.dy == []:
269            dy = numpy.ones(len(current_data.y))
270        else:
271            ## Set consistently w/AbstractFitengine:
272            # But this should be corrected later.
273            dy = deepcopy(current_data.dy)
274            dy[dy == 0] = 1
275        fn = current_data.y[index]
276        gn = reference_data.y
277        en = dy[index]
278    # Calculate the residual
279    try:
280        res = (fn - gn) / en
281    except ValueError:
282        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
283        return None
284
285    residuals = res[numpy.isfinite(res)]
286    chisqr = numpy.average(residuals * residuals)
287
288    return chisqr
289
290def residualsData1D(reference_data, current_data):
291    """
292    Calculate the residuals for difference of two Data1D sets
293    """
294    # temporary default values for index and weight
295    index = None
296    weight = None
297
298    # 1d theory from model_thread is only in the range of index
299    if current_data.dy is None or current_data.dy == []:
300        dy = numpy.ones(len(current_data.y))
301    else:
302        dy = weight if weight is not None else numpy.ones(len(current_data.y))
303        dy[dy == 0] = 1
304    fn = current_data.y[index][0]
305    gn = reference_data.y
306    en = dy[index][0]
307    # build residuals
308    residuals = Data1D()
309    if len(fn) == len(gn):
310        y = (fn - gn)/en
311        residuals.y = -y
312    else:
313        # TODO: fix case where applying new data from file on top of existing model data
314        try:
315            y = (fn - gn[index][0]) / en
316            residuals.y = y
317        except ValueError:
318            # value errors may show up every once in a while for malformed columns,
319            # just reuse what's there already
320            pass
321
322    residuals.x = current_data.x[index][0]
323    residuals.dy = numpy.ones(len(residuals.y))
324    residuals.dx = None
325    residuals.dxl = None
326    residuals.dxw = None
327    residuals.ytransform = 'y'
328    # For latter scale changes
329    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
330    residuals.yaxis('\\rm{Residuals} ', 'normalized')
331
332    return residuals
333
334def residualsData2D(reference_data, current_data):
335    """
336    Calculate the residuals for difference of two Data2D sets
337    """
338    # temporary default values for index and weight
339    # index = None
340    weight = None
341
342    # build residuals
343    residuals = Data2D()
344    # Not for trunk the line below, instead use the line above
345    current_data.clone_without_data(len(current_data.data), residuals)
346    residuals.data = None
347    fn = current_data.data
348    gn = reference_data.data
349    en = current_data.err_data if weight is None else weight
350    residuals.data = (fn - gn) / en
351    residuals.qx_data = current_data.qx_data
352    residuals.qy_data = current_data.qy_data
353    residuals.q_data = current_data.q_data
354    residuals.err_data = numpy.ones(len(residuals.data))
355    residuals.xmin = min(residuals.qx_data)
356    residuals.xmax = max(residuals.qx_data)
357    residuals.ymin = min(residuals.qy_data)
358    residuals.ymax = max(residuals.qy_data)
359    residuals.q_data = current_data.q_data
360    residuals.mask = current_data.mask
361    residuals.scale = 'linear'
362    # check the lengths
363    if len(residuals.data) != len(residuals.q_data):
364        return None
365    return residuals
366
367def plotResiduals(reference_data, current_data):
368    """
369    Create Data1D/Data2D with residuals, ready for plotting
370    """
371    data_copy = deepcopy(current_data)
372    # Get data: data I, theory I, and data dI in order
373    method_name = current_data.__class__.__name__
374    residuals_dict = {"Data1D": residualsData1D,
375                      "Data2D": residualsData2D}
376
377    residuals = residuals_dict[method_name](reference_data, data_copy)
378
379    theory_name = str(current_data.name.split()[0])
380    residuals.name = "Residuals for " + str(theory_name) + "[" + \
381                    str(reference_data.filename) + "]"
382    residuals.title = residuals.name
383    residuals.ytransform = 'y'
384
385    # when 2 data have the same id override the 1 st plotted
386    # include the last part if keeping charts for separate models is required
387    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
388    # group_id specify on which panel to plot this data
389    group_id = reference_data.group_id
390    residuals.group_id = "res" + str(group_id)
391
392    # Symbol
393    residuals.symbol = 0
394    residuals.hide_error = False
395
396    return residuals
397
398def binary_encode(i, digits):
399    return [i >> d & 1 for d in range(digits)]
400
401def getWeight(data, is2d, flag=None):
402    """
403    Received flag and compute error on data.
404    :param flag: flag to transform error of data.
405    """
406    weight = None
407    if is2d:
408        dy_data = data.err_data
409        data = data.data
410    else:
411        dy_data = data.dy
412        data = data.y
413
414    if flag == 0:
415        weight = numpy.ones_like(data)
416    elif flag == 1:
417        weight = dy_data
418    elif flag == 2:
419        weight = numpy.sqrt(numpy.abs(data))
420    elif flag == 3:
421        weight = numpy.abs(data)
422    return weight
Note: See TracBrowser for help on using the repository browser.