source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 8222f171

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 8222f171 was 8222f171, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

SASVIEW-625: code review fixes. Corrected handling for cancelling the file open dialog, updated test cases.

  • Property mode set to 100644
File size: 14.7 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
def replaceShellName(param_name, value):
    """
    Turn a shell parameter name '<name>[...]' into '<name><value>'.

    Everything from the first '[' onwards is dropped and str(value) is
    appended, e.g. replaceShellName("sld[n_shells]", 2) -> "sld2".
    """
    assert '[' in param_name
    cut = param_name.index('[')
    stem = param_name[:cut]
    return stem + str(value)
17
def getIterParams(model):
    """
    Collect the multi-shell parameters of 'model'.

    :param model: object exposing an 'iq_parameters' sequence whose items
                  carry a 'name' attribute
    :return: list (in original order) of the parameters whose name
             contains '['
    """
    return [par for par in model.iq_parameters if "[" in par.name]
23
def getMultiplicity(model):
    """
    Find out whether 'model' has multi-shell parameters.

    :param model: object accepted by getIterParams
    :return: tuple (control_name, shell_count) taken from the first
             multi-shell parameter; control_name falls back to the text
             before '[' in the parameter name when it has no
             length_control.  Returns ("", 0) when there are none.
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    count = first.length
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        control_name = first.name.split('[', 1)[0]
    return (control_name, count)
38
def addParametersToModel(parameters, kernel_module, is2D):
    """
    Build the rows of QStandardItems describing a model's parameters.

    The multi-shell length-control parameter and the "<name>[...]" shell
    parameters themselves are skipped here (shell rows are presumably added
    via addShellsToModel).

    :param parameters: model parameter info (presumably a sasmodels
                       ParameterTable) exposing iq_parameters,
                       iqxy_parameters and form_volume_parameters
    :param kernel_module: object whose getParam() returns the
                          '<name>.width' and '<name>.type' values used for
                          the polydispersity child row
    :param is2D: use iqxy_parameters when True, iq_parameters otherwise
    :return: list of rows, each [name, value, min, max, units]
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Skip the parameter that controls the number of shells...
        if param.name == multishell_param_name:
            continue
        # ...and the indexed "<name>[n]" shell parameters
        if param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        item1.setEditable(False)

        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            poly_item.setEditable(False)
            item1_1 = QtGui.QStandardItem("Distribution")
            item1_1.setEditable(False)
            # A distribution row is only added when the parameter is also
            # listed among the form-volume parameters
            volume_param = next((p for p in parameters.form_volume_parameters
                                 if p.name == param.name), None)
            if volume_param is not None:
                width = kernel_module.getParam(volume_param.name + '.width')
                pd_type = kernel_module.getParam(volume_param.name + '.type')

                item1_2 = QtGui.QStandardItem(str(width))
                item1_2.setEditable(False)
                item1_3 = QtGui.QStandardItem()
                item1_3.setEditable(False)
                item1_4 = QtGui.QStandardItem()
                item1_4.setEditable(False)
                item1_5 = QtGui.QStandardItem(pd_type)
                item1_5.setEditable(False)
                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
            # Add the polydisp item as a child
            item1.appendRow([poly_item])

        # Param values
        item2 = QtGui.QStandardItem(str(param.default))
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        item3 = QtGui.QStandardItem(str(param.limits[0]))
        item4 = QtGui.QStandardItem(str(param.limits[1]))
        item5 = QtGui.QStandardItem(param.units)
        item5.setEditable(False)
        rows.append([item1, item2, item3, item4, item5])
    return rows
97
def addSimpleParametersToModel(parameters, is2D):
    """
    Build one row of QStandardItems per model parameter, without any
    polydispersity or multi-shell handling.

    :param parameters: object exposing iq_parameters / iqxy_parameters
    :param is2D: use iqxy_parameters when True, iq_parameters otherwise
    :return: list of rows, each [name, value, min, max, units]
    """
    source = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for par in source:
        # Checkable, read-only name cell
        name_cell = QtGui.QStandardItem(par.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cell = QtGui.QStandardItem(str(par.default))
        min_cell = QtGui.QStandardItem(str(par.limits[0]))
        max_cell = QtGui.QStandardItem(str(par.limits[1]))
        units_cell = QtGui.QStandardItem(par.units)
        units_cell.setEditable(False)
        rows.append([name_cell, value_cell, min_cell, max_cell, units_cell])
    return rows
119
def addCheckedListToModel(model, param_list):
    """
    Append one row to 'model' built from the entries of 'param_list'.

    Only the first cell of the row is made checkable.  'param_list' must be
    non-empty.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_cell = row[0]
    first_cell.setCheckable(True)
    model.appendRow(row)
128
def addHeadersToModel(model):
    """
    Set the horizontal headers of a parameter model:
    Parameter, Value, Min, Max, Units.
    """
    titles = ("Parameter", "Value", "Min", "Max", "Units")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
138
def addErrorHeadersToModel(model):
    """
    Set the horizontal headers of a parameter model that has an error
    column: Parameter, Value, Error, Min, Max, Units.
    """
    titles = ("Parameter", "Value", "Error", "Min", "Max", "Units")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
149
def addPolyHeadersToModel(model):
    """
    Set the horizontal headers of a polydispersity model:
    Parameter, PD[ratio], Min, Max, Npts, Nsigs, Function, Filename.
    """
    titles = ("Parameter", "PD[ratio]", "Min", "Max",
              "Npts", "Nsigs", "Function", "Filename")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
162
def addErrorPolyHeadersToModel(model):
    """
    Set the horizontal headers of a polydispersity model that has an error
    column: Parameter, PD[ratio], Error, Min, Max, Npts, Nsigs, Function,
    Filename.
    """
    titles = ("Parameter", "PD[ratio]", "Error", "Min", "Max",
              "Npts", "Nsigs", "Function", "Filename")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
176
def addShellsToModel(parameters, model, index):
    """
    Find out multishell parameters and update the model with the requested
    number of them.

    For each shell number i = 1..index and each "<name>[...]" parameter, a
    row [<name><i>, default, min, max, units] is appended to 'model'.
    Polydisperse parameters also get a "Polydispersity" child item, holding
    a "Distribution" row when the parameter is a form-volume parameter.

    :param parameters: object accepted by getIterParams that also exposes
                       form_volume_parameters
    :param model: item model receiving the rows (presumably a
                  QStandardItemModel)
    :param index: number of shells to add
    """
    multishell_parameters = getIterParams(parameters)

    # range (not xrange) so this also works under Python 3
    for i in range(index):
        for par in multishell_parameters:
            # Build the name <param><i>, e.g. "sld1" for parameter "sld[n]"
            param_name = replaceShellName(par.name, i + 1)
            item1 = QtGui.QStandardItem(param_name)
            item1.setCheckable(True)
            # check for polydisp params
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                item1_1 = QtGui.QStandardItem("Distribution")
                # Find param in volume_params
                for p in parameters.form_volume_parameters:
                    if p.name != par.name:
                        continue
                    item1_2 = QtGui.QStandardItem(str(p.default))
                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
                    item1_5 = QtGui.QStandardItem(p.units)
                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                    break
                item1.appendRow([poly_item])

            item2 = QtGui.QStandardItem(str(par.default))
            item3 = QtGui.QStandardItem(str(par.limits[0]))
            item4 = QtGui.QStandardItem(str(par.limits[1]))
            item5 = QtGui.QStandardItem(par.units)
            model.appendRow([item1, item2, item3, item4, item5])
210
def calculateChi2(reference_data, current_data):
    """
    Calculate Chi2 value between two sets of data

    Computes the mean of ((current - reference) / error)**2 over all finite
    normalized residuals.

    :param reference_data: Data1D or Data2D to compare against; a Data2D
                           instance selects the 2D code path
    :param current_data: data set whose values and errors (err_data for 2D,
                         dy for 1D) are used
    :return: the averaged squared residual, or None when reference_data is
             None or the arrays cannot be combined (mismatched lengths)
    """
    # NOTE: user-selected weighting is not applied yet; the legacy GUI used
    # sas.sasgui.perspectives.fitting.utils.get_weight for this.
    if reference_data is None:
        return None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        index = numpy.ones(len(current_data.data), dtype=bool)
        # get rid of zero-error and non-finite points
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1D theory from model_thread is only in the range of index
        index = numpy.ones(len(current_data.y), dtype=bool)
        # Test emptiness with len(): 'dy == []' on a numpy array is an
        # elementwise comparison (warning or error), not an emptiness check.
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            # Float copy: leaves the caller's errors untouched and makes the
            # boolean mask below work even if dy is a plain list.
            # Zero errors are set to 1, consistent with AbstractFitEngine
            # (should be corrected later).
            dy = numpy.array(current_data.dy, dtype=float)
            dy[dy == 0] = 1
        fn = current_data.y[index]
        gn = reference_data.y
        en = dy[index]

    # Calculate the residual
    try:
        res = (fn - gn) / en
    except ValueError:
        # Unmatched lengths of fn, gn, en
        return None

    residuals = res[numpy.isfinite(res)]
    return numpy.average(residuals * residuals)
267
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data1D sets

    Each residual is (reference.y - current.y) / error.  Per-point errors
    are not applied yet, so the error term is 1 for every point.

    :param reference_data: Data1D supplying the reference values (y)
    :param current_data: Data1D supplying the compared values (x, y)
    :return: Data1D holding the residuals as y, on current_data's x values
    """
    # NOTE: weighting by current_data.dy is not implemented (TODO); every
    # point is normalized by 1.
    fn = numpy.asarray(current_data.y)
    gn = reference_data.y
    en = numpy.ones(len(fn))

    # build residuals
    residuals = Data1D()
    try:
        # Same sign for every case: reference minus current
        residuals.y = (gn - fn) / en
    except ValueError:
        # TODO: fix case where applying new data from file on top of
        # existing model data (lengths that cannot be broadcast).
        # value errors may show up every once in a while for malformed
        # columns, just reuse what's there already
        pass

    residuals.x = numpy.asarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
311
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data2D sets

    :param reference_data: Data2D supplying the reference values (data)
    :param current_data: Data2D supplying the compared values, their errors
                         (err_data), the q grid and the mask
    :return: Data2D holding (current.data - reference.data) / err_data, or
             None when the residual length differs from current_data.q_data
    """
    # build residuals; clone_without_data is expected to copy
    # current_data's non-data attributes into the new object
    residuals = Data2D()
    current_data.clone_without_data(len(current_data.data), residuals)

    fn = current_data.data
    gn = reference_data.data
    # NOTE: weighting is not implemented; the stored errors are always used
    en = current_data.err_data
    residuals.data = (fn - gn) / en

    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    residuals.xmin = min(residuals.qx_data)
    residuals.xmax = max(residuals.qx_data)
    residuals.ymin = min(residuals.qy_data)
    residuals.ymax = max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    # check the lengths
    if len(residuals.data) != len(residuals.q_data):
        return None
    return residuals
344
def plotResiduals(reference_data, current_data):
    """
    Create Data1D/Data2D with residuals, ready for plotting

    :param reference_data: data set the residuals are computed against; its
                           filename, id and group_id are used to name and
                           group the result
    :param current_data: Data1D or Data2D (dispatched on the class name); a
                         deep copy is passed on, so it is not modified here
    :return: residuals data set, or None when the residuals could not be
             built (residualsData2D returns None on a length mismatch)
    """
    data_copy = deepcopy(current_data)
    # Get data: data I, theory I, and data dI in order
    method_name = current_data.__class__.__name__
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}

    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        return None

    # First word of the data name; tolerate blank names
    name_parts = current_data.name.split()
    theory_name = str(name_parts[0]) if name_parts else ""
    residuals.name = "Residuals for {}[{}]".format(theory_name,
                                                   reference_data.filename)
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # when 2 data have the same id override the 1 st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
    # group_id specify on which panel to plot this data
    group_id = reference_data.group_id
    residuals.group_id = "res" + str(group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
375
def binary_encode(i, digits):
    """
    Return the lowest 'digits' bits of integer 'i', least significant first.

    e.g. binary_encode(6, 4) -> [0, 1, 1, 0]
    """
    # range (not xrange) so this also works under Python 3
    return [(i >> d) & 1 for d in range(digits)]
378
def getWeight(data, is2d, flag=None):
    """
    Compute the weighting array for 'data' according to 'flag'.

    :param data: object with data/err_data (2D) or y/dy (1D) attributes
    :param is2d: True to read data.data and data.err_data,
                 False to read data.y and data.dy
    :param flag: 0 -> array of ones shaped like the intensities,
                 1 -> the stored errors,
                 2 -> sqrt(|intensity|),
                 3 -> |intensity|;
                 any other value (including None) -> None
    :return: the weight array, or None
    """
    if is2d:
        errors, intensity = data.err_data, data.data
    else:
        errors, intensity = data.dy, data.y

    if flag == 0:
        return numpy.ones_like(intensity)
    if flag == 1:
        return errors
    if flag == 2:
        return numpy.sqrt(numpy.abs(intensity))
    if flag == 3:
        return numpy.abs(intensity)
    return None
Note: See TracBrowser for help on using the repository browser.