source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 1bc27f1

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 1bc27f1 was 1bc27f1, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Code review fixes SASVIEW-588
Pylint-related fixes in Perspectives/Fitting.

  • Property mode set to 100755
File size: 13.3 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.sasgui.guiframe.dataFitting import Data1D
9from sas.sasgui.guiframe.dataFitting import Data2D
10
def replaceShellName(param_name, value):
    """
    Turn a multishell parameter name into a per-shell name.

    Everything from the first '[' onwards is dropped and replaced by
    str(value), e.g. ("sld[n_shells]", 2) -> "sld2".
    The name must contain '[' (checked with an assert).
    """
    assert '[' in param_name
    stem, _, _ = param_name.partition('[')
    return stem + str(value)
17
def getIterParams(model):
    """
    Return the multishell (vector) parameters of 'model'.

    A parameter counts as multishell when its name contains '['.
    Only model.iq_parameters is searched; the result is a new list that
    keeps the original parameter order.
    """
    return [par for par in model.iq_parameters if "[" in par.name]
23
def getMultiplicity(model):
    """
    Report the shell-count control parameter of 'model', if any.

    Looks at the first multishell parameter returned by getIterParams and
    returns a (name, length) tuple: 'name' is its 'length_control'
    attribute, or the part of its own name before '[' when that attribute
    is None; 'length' is its 'length' attribute.
    Returns ("", 0) when the model has no multishell parameters.
    """
    iter_params = getIterParams(model)
    if not iter_params:
        return ("", 0)

    first = iter_params[0]
    shell_count = first.length
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        # Fall back to the stem of the vector parameter's own name
        control_name = first.name[:first.name.index('[')]
    return (control_name, shell_count)
38
def addParametersToModel(parameters, is2D):
    """
    Build the rows of the main parameter table for a sasmodel.

    :param parameters: sasmodel parameter table; iq_parameters,
        iqxy_parameters and form_volume_parameters are read
    :param is2D: use iqxy_parameters when True, iq_parameters otherwise
    :return: list of rows, each [name, value, min, max, units]
        QStandardItems

    The shell-count control parameter (from getMultiplicity) and all
    multishell parameters (from getIterParams) are skipped. Each
    polydisperse parameter gets a read-only "Polydispersity" child row,
    filled from the form_volume_parameters entry of the same name when
    one exists.
    """
    def read_only(text):
        # Non-editable cell
        cell = QtGui.QStandardItem(text)
        cell.setEditable(False)
        return cell

    vector_params = getIterParams(parameters)
    shell_counter_name, _ = getMultiplicity(parameters)
    if is2D:
        source = parameters.iqxy_parameters
    else:
        source = parameters.iq_parameters

    rows = []
    for param in source:
        # Skip the shell counter and the per-shell (vector) parameters
        if param.name == shell_counter_name or param in vector_params:
            continue

        name_cell = QtGui.QStandardItem(param.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)

        if param.polydisperse:
            poly_cell = read_only("Polydispersity")
            dist_cell = read_only("Distribution")
            match = next((vp for vp in parameters.form_volume_parameters
                          if vp.name == param.name), None)
            if match is not None:
                poly_cell.appendRow([dist_cell,
                                     read_only(str(match.default)),
                                     read_only(str(match.limits[0])),
                                     read_only(str(match.limits[1])),
                                     read_only(match.units)])
            name_cell.appendRow([poly_cell])

        # TODO: the error column needs a proxy model or a custom view delegate
        value_cell = QtGui.QStandardItem(str(param.default))
        min_cell = QtGui.QStandardItem(str(param.limits[0]))
        max_cell = QtGui.QStandardItem(str(param.limits[1]))
        units_cell = read_only(param.units)
        rows.append([name_cell, value_cell, min_cell, max_cell, units_cell])
    return rows
94
def addSimpleParametersToModel(parameters, is2D):
    """
    Build the rows of a flat parameter table for a sasmodel.

    Every parameter of the selected list is included (multishell ones too)
    and no polydispersity children are created.

    :param parameters: sasmodel parameter table
    :param is2D: use iqxy_parameters when True, iq_parameters otherwise
    :return: list of rows, each [name, value, min, max, units]
        QStandardItems; only the name (checkable) and units cells are
        read-only
    """
    source = parameters.iqxy_parameters if is2D else parameters.iq_parameters

    def make_row(param):
        name_cell = QtGui.QStandardItem(param.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cell = QtGui.QStandardItem(str(param.default))
        min_cell = QtGui.QStandardItem(str(param.limits[0]))
        max_cell = QtGui.QStandardItem(str(param.limits[1]))
        units_cell = QtGui.QStandardItem(param.units)
        units_cell.setEditable(False)
        return [name_cell, value_cell, min_cell, max_cell, units_cell]

    return [make_row(param) for param in source]
116
def addCheckedListToModel(model, param_list):
    """
    Append one row built from 'param_list' to a QStandardItemModel.

    Each entry of param_list becomes one QStandardItem (one column);
    only the first column is made checkable. param_list must be non-empty.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_cell = row[0]
    first_cell.setCheckable(True)
    model.appendRow(row)
125
def addHeadersToModel(model):
    """
    Label the columns of the main parameter table, left to right:
    Parameter, Value, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
135
def addErrorHeadersToModel(model):
    """
    Label the columns of the parameter table that includes an error
    column, left to right: Parameter, Value, Error, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Error", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
146
def addPolyHeadersToModel(model):
    """
    Label the columns of the polydispersity table, left to right:
    Parameter, PD[ratio], Min, Max, Npts, Nsigs, Function.
    """
    titles = ("Parameter", "PD[ratio]", "Min", "Max", "Npts", "Nsigs", "Function")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
158
def addShellsToModel(parameters, model, index):
    """
    Append the rows for 'index' shells of multishell parameters to 'model'.

    For every shell i = 1..index and every multishell parameter (as
    returned by getIterParams) one row [name, value, min, max, units] is
    appended. The name is the parameter name with its '[...]' suffix
    replaced by i (e.g. "sld[n]" -> "sld1"). Polydisperse parameters also
    get a "Polydispersity" child row, filled from the entry of
    parameters.form_volume_parameters with the same name when one exists.

    :param parameters: sasmodel parameter table
    :param model: QStandardItemModel the rows are appended to
    :param index: number of shells to add
    """
    multishell_parameters = getIterParams(parameters)

    # range (not the Python-2-only xrange) so this runs on Python 2 and 3
    for i in range(index):
        for par in multishell_parameters:
            # Create the name: <param><i>, e.g. "sld1" for parameter "sld[n]"
            param_name = replaceShellName(par.name, i+1)
            item1 = QtGui.QStandardItem(param_name)
            item1.setCheckable(True)
            # NOTE(review): unlike addParametersToModel, these cells are left
            # editable - confirm this is intended.
            # check for polydisp params
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                item1_1 = QtGui.QStandardItem("Distribution")
                # Find param in volume_params
                for p in parameters.form_volume_parameters:
                    if p.name != par.name:
                        continue
                    item1_2 = QtGui.QStandardItem(str(p.default))
                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
                    item1_5 = QtGui.QStandardItem(p.units)
                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                    break
                item1.appendRow([poly_item])

            item2 = QtGui.QStandardItem(str(par.default))
            item3 = QtGui.QStandardItem(str(par.limits[0]))
            item4 = QtGui.QStandardItem(str(par.limits[1]))
            item5 = QtGui.QStandardItem(par.units)
            model.appendRow([item1, item2, item3, item4, item5])
192
def calculateChi2(reference_data, current_data):
    """
    Calculate the reduced chi^2 between two data sets.

    chi^2 is the mean of the squared normalized residuals
    (current - reference) / error over all points with a finite residual.
    The errors come from current_data: err_data for 2D, dy for 1D.

    :param reference_data: Data1D/Data2D the residuals are measured
        against; if None, None is returned
    :param current_data: data set of the same dimensionality that also
        supplies the errors (never modified)
    :return: chi^2, or None when it cannot be computed (no reference
        data, incompatible lengths, or no finite residuals)
    """
    # TODO: weighting (cf. sas.sasgui.perspectives.fitting.utils.get_weight)
    # is not applied yet; errors always come from current_data.
    if reference_data is None:
        return None

    if isinstance(reference_data, Data2D):
        # Keep only points with a non-zero error and a finite value
        index = numpy.ones(len(current_data.data), dtype=bool)
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        fn = numpy.asarray(current_data.y)
        dy = current_data.dy
        # len() instead of "dy == []": comparing a numpy array with an empty
        # list is an ambiguous elementwise comparison (an error on new numpy)
        if dy is None or len(dy) == 0:
            en = numpy.ones(len(fn))
        else:
            # Float copy, so the caller's dy is never modified; zero errors
            # are replaced by 1 (set consistently with AbstractFitEngine).
            en = numpy.array(dy, dtype=float)
            en[en == 0] = 1
            if len(en) != len(fn):
                return None
        # 1D theory from model_thread only covers the fitted range,
        # so reference_data.y is used without indexing.
        gn = reference_data.y

    try:
        res = (fn - gn) / en
    except ValueError:
        # Incompatible array lengths
        return None

    residuals = res[numpy.isfinite(res)]
    if residuals.size == 0:
        # Nothing to average: chi^2 is undefined
        return None
    return numpy.average(residuals * residuals)
249
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals between two Data1D sets.

    :param reference_data: Data1D whose y values are compared against
    :param current_data: Data1D supplying x and the compared y values
    :return: new Data1D with x from current_data, the residuals as y,
        unit dy and no x resolution

    No error weighting is applied yet: every point is divided by 1.
    When both sets have the same length the residual is
    (reference - current); otherwise (current - reference) is used
    (see the TODO below).
    """
    fn = numpy.asarray(current_data.y)
    gn = numpy.asarray(reference_data.y)
    # Unit errors. Weighting is not implemented yet; the previous code also
    # always ended up with ones here, but its "dy == []" check could raise
    # on numpy arrays.
    en = numpy.ones(len(fn))

    # build residuals
    residuals = Data1D()
    if len(fn) == len(gn):
        y = (fn - gn) / en
        residuals.y = -y
    else:
        # TODO: fix case where applying new data from file on top of existing model data
        y = (fn - gn) / en
        residuals.y = y

    residuals.x = numpy.asarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
296
def residualsData2D(reference_data, current_data):
    """
    Calculate the normalized residuals between two Data2D sets.

    Each residual is (current - reference) / error, with the error taken
    from current_data.err_data (no extra weighting yet). The result is a
    data-less clone of current_data (via clone_without_data) carrying the
    residuals, current_data's qx/qy/q coordinates and mask, unit errors,
    x/y ranges spanning the qx/qy data, and a linear scale.

    :return: the residuals Data2D, or None if the number of residuals
        differs from the number of q points
    """
    out = Data2D()
    current_data.clone_without_data(len(current_data.data), out)

    measured = current_data.data
    reference = reference_data.data
    errors = current_data.err_data  # no extra weighting applied yet
    out.data = (measured - reference) / errors

    out.qx_data = current_data.qx_data
    out.qy_data = current_data.qy_data
    out.q_data = current_data.q_data
    out.err_data = numpy.ones(len(out.data))
    out.xmin = min(out.qx_data)
    out.xmax = max(out.qx_data)
    out.ymin = min(out.qy_data)
    out.ymax = max(out.qy_data)
    out.mask = current_data.mask
    out.scale = 'linear'

    # Reject results whose length does not match the q grid
    return None if len(out.data) != len(out.q_data) else out
329
def plotResiduals(reference_data, current_data):
    """
    Create a Data1D/Data2D of residuals, ready for plotting.

    :param reference_data: data set the residuals are computed against;
        its filename, id and group_id are used for the name, id and
        group_id of the result
    :param current_data: object whose class name is "Data1D" or "Data2D"
        (any other class raises KeyError); the first whitespace-separated
        token of its name goes into the title. It is deep-copied, so the
        caller's object is not passed on to the residual calculation.
    :return: the residuals data object, or None when the residuals could
        not be computed (residualsData2D returns None on a length mismatch)
    """
    data_copy = deepcopy(current_data)
    # Dispatch on the class name of the data object
    method_name = current_data.__class__.__name__
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}

    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        # Propagate "cannot compute" instead of failing on attribute access
        return None

    theory_name = str(current_data.name.split()[0])
    residuals.name = "Residuals for " + theory_name + "[" + \
                    str(reference_data.filename) + "]"
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # when 2 data have the same id override the 1 st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
    # group_id specify on which panel to plot this data
    group_id = reference_data.group_id
    residuals.group_id = "res" + str(group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
360
361
def binary_encode(i, digits):
    """
    Return the 'digits' least significant bits of integer 'i' as a list
    of 0/1 ints, least significant bit first.

    >>> binary_encode(6, 4)
    [0, 1, 1, 0]
    """
    # range (not the Python-2-only xrange) so this also runs on Python 3
    return [(i >> d) & 1 for d in range(digits)]
364
Note: See TracBrowser for help on using the repository browser.