source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 2add354

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 2add354 was 2add354, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Code review fixes for SASVIEW-273

  • Property mode set to 100755
File size: 13.3 KB
Line 
1from PyQt4 import QtGui
2from PyQt4 import QtCore
3
4import numpy
5from copy import deepcopy
6
7from sas.sasgui.guiframe.dataFitting import Data1D
8from sas.sasgui.guiframe.dataFitting import Data2D
9
def replaceShellName(param_name, value):
    """
    Rename a multi-shell parameter for one concrete shell.

    Everything from the first '[' onwards is replaced by ``str(value)``,
    e.g. ``replaceShellName("sld[n]", 2)`` returns ``"sld2"``.

    :param param_name: parameter name that must contain '[' (asserted)
    :param value: shell index, or anything convertible with str()
    :return: the rewritten parameter name
    """
    assert '[' in param_name
    bracket_at = param_name.index('[')
    stem = param_name[:bracket_at]
    return stem + str(value)
16
def getIterParams(model):
    """
    Collect the multi-shell parameters of 'model'.

    A parameter is considered multi-shell when its name contains '['
    (e.g. ``sld[n]``). Only ``model.iq_parameters`` is inspected.

    :param model: object exposing an ``iq_parameters`` iterable whose items
                  have a ``name`` attribute
    :return: list of the matching parameters, in their original order
    """
    return [par for par in model.iq_parameters if "[" in par.name]
22
def getMultiplicity(model):
    """
    Report whether 'model' has multi-shell parameters.

    :param model: object exposing ``iq_parameters`` (see getIterParams)
    :return: tuple ``(control_name, length)``. When the model has shell
             parameters, ``length`` is the first one's ``length`` and
             ``control_name`` its ``length_control``. If ``length_control``
             is None, the parameter name with its ``[...]`` suffix stripped
             is used instead. ``("", 0)`` for models without shells.
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        control_name = first.name[:first.name.index('[')]
    return (control_name, first.length)
37
def addParametersToModel(parameters, is2D):
    """
    Build the rows of the main parameter table from sasmodel parameters.

    The shell counter parameter (as reported by getMultiplicity) and all
    multi-shell parameters (names containing '[') are skipped. Rows for
    individual shells are built by addShellsToModel instead.

    Each row is ``[name, value, min, max, units]``:
    - the name cell is checkable and read-only;
    - the units cell is read-only.
    Polydisperse parameters get a read-only "Polydispersity" child. Its
    "Distribution" row is filled from the entry of
    ``parameters.form_volume_parameters`` with the same name, if one exists.

    :param parameters: model info exposing ``iq_parameters``,
                       ``iqxy_parameters`` and ``form_volume_parameters``
    :param is2D: use ``iqxy_parameters`` when True, ``iq_parameters``
                 otherwise
    :return: list of rows, each a list of QStandardItem
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Skip the shell counter and the per-shell parameters
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        item1.setEditable(False)
        # check for polydisp params
        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            poly_item.setEditable(False)
            item1_1 = QtGui.QStandardItem("Distribution")
            item1_1.setEditable(False)
            # Find param in volume_params
            for p in parameters.form_volume_parameters:
                if p.name != param.name:
                    continue
                item1_2 = QtGui.QStandardItem(str(p.default))
                item1_2.setEditable(False)
                item1_3 = QtGui.QStandardItem(str(p.limits[0]))
                item1_3.setEditable(False)
                item1_4 = QtGui.QStandardItem(str(p.limits[1]))
                item1_4.setEditable(False)
                item1_5 = QtGui.QStandardItem(p.units)
                item1_5.setEditable(False)
                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                break
            # Add the polydisp item as a child
            item1.appendRow([poly_item])
        # Param values
        item2 = QtGui.QStandardItem(str(param.default))
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        item3 = QtGui.QStandardItem(str(param.limits[0]))
        item4 = QtGui.QStandardItem(str(param.limits[1]))
        item5 = QtGui.QStandardItem(param.units)
        item5.setEditable(False)
        rows.append([item1, item2, item3, item4, item5])
    return rows
93
def addSimpleParametersToModel(parameters, is2D):
    """
    Build plain ``[name, value, min, max, units]`` rows for every parameter.

    No parameter is skipped and no polydispersity children are created.
    The name cell is checkable and read-only; the units cell is read-only.

    :param parameters: model info exposing ``iq_parameters`` and
                       ``iqxy_parameters``
    :param is2D: use ``iqxy_parameters`` when True, ``iq_parameters``
                 otherwise
    :return: list of rows, each a list of QStandardItem
    """
    if is2D:
        source = parameters.iqxy_parameters
    else:
        source = parameters.iq_parameters

    rows = []
    for par in source:
        # Top level, checkable, read-only name cell
        name_cell = QtGui.QStandardItem(par.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cell = QtGui.QStandardItem(str(par.default))
        min_cell = QtGui.QStandardItem(str(par.limits[0]))
        max_cell = QtGui.QStandardItem(str(par.limits[1]))
        units_cell = QtGui.QStandardItem(par.units)
        units_cell.setEditable(False)
        rows.append([name_cell, value_cell, min_cell, max_cell, units_cell])
    return rows
115
def addCheckedListToModel(model, param_list):
    """
    Append one row to 'model', built from the texts in 'param_list'.

    Only the first cell of the new row is made checkable.

    :param model: QStandardItemModel receiving the row (asserted)
    :param param_list: non-empty sequence of cell texts; an empty one
                       raises IndexError
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_cell = row[0]
    first_cell.setCheckable(True)
    model.appendRow(row)
124
def addHeadersToModel(model):
    """
    Set the horizontal headers of the main parameter table.

    Columns 0-4 are labelled Parameter, Value, Min, Max and [Units].
    """
    headers = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, title in enumerate(headers):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
134
def addErrorHeadersToModel(model):
    """
    Set the horizontal headers of the parameter table with an error column.

    Columns 0-5 are labelled Parameter, Value, Error, Min, Max and [Units].
    """
    headers = ("Parameter", "Value", "Error", "Min", "Max", "[Units]")
    for column, title in enumerate(headers):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
145
def addPolyHeadersToModel(model):
    """
    Set the horizontal headers of the polydispersity table.

    Columns 0-6 are labelled Parameter, PD[ratio], Min, Max, Npts, Nsigs
    and Function.
    """
    headers = ("Parameter", "PD[ratio]", "Min", "Max",
               "Npts", "Nsigs", "Function")
    for column, title in enumerate(headers):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
157
def addShellsToModel(parameters, model, index):
    """
    Append rows for 'index' shells of every multi-shell parameter to 'model'.

    Shells are numbered from 1. For shell i, each parameter ``<name>[...]``
    yields a row ``[<name>i, value, min, max, units]``, e.g. "sld1" for
    "sld[n]". Rows are appended shell by shell.

    Polydisperse parameters get a "Polydispersity" child whose
    "Distribution" row is filled from the entry of
    ``parameters.form_volume_parameters`` with the same name, if one exists.

    The name cell is checkable. Name, units and polydispersity cells are
    read-only, matching the rows built by addParametersToModel, so shell
    parameter names cannot be edited in the view.

    :param parameters: model info exposing ``iq_parameters`` and
                       ``form_volume_parameters``
    :param model: QStandardItemModel receiving the rows
    :param index: number of shells to add (0 or negative adds nothing)
    """
    multishell_parameters = getIterParams(parameters)

    # range (not the Python 2-only xrange) keeps this portable to Python 3
    for shell in range(1, index + 1):
        for par in multishell_parameters:
            # Create the name: <param><shell>, e.g. "sld1" for parameter "sld[n]"
            param_name = replaceShellName(par.name, shell)
            item1 = QtGui.QStandardItem(param_name)
            item1.setCheckable(True)
            item1.setEditable(False)
            # check for polydisp params
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                poly_item.setEditable(False)
                item1_1 = QtGui.QStandardItem("Distribution")
                item1_1.setEditable(False)
                # Find param in volume_params
                for p in parameters.form_volume_parameters:
                    if p.name != par.name:
                        continue
                    item1_2 = QtGui.QStandardItem(str(p.default))
                    item1_2.setEditable(False)
                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
                    item1_3.setEditable(False)
                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
                    item1_4.setEditable(False)
                    item1_5 = QtGui.QStandardItem(p.units)
                    item1_5.setEditable(False)
                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                    break
                item1.appendRow([poly_item])

            item2 = QtGui.QStandardItem(str(par.default))
            item3 = QtGui.QStandardItem(str(par.limits[0]))
            item4 = QtGui.QStandardItem(str(par.limits[1]))
            item5 = QtGui.QStandardItem(par.units)
            item5.setEditable(False)
            model.appendRow([item1, item2, item3, item4, item5])
191
def calculateChi2(reference_data, current_data):
    """
    Calculate the chi^2 value between two data sets.

    Residuals ``(current - reference) / error`` are computed point by
    point. Non-finite residuals are discarded, and the average of the
    squared residuals is returned.

    For 2D data, points with zero error or non-finite data are excluded
    before computing residuals. For 1D data, a missing or empty ``dy``
    means unit errors, and zero errors are replaced by 1.

    :param reference_data: Data1D/Data2D set the current data is compared to
    :param current_data: Data1D/Data2D set providing values and errors
    :return: chi^2 value, or None when reference_data is None or the data
             sets cannot be combined (mismatched lengths)
    """
    # TODO: weighting is not supported yet; the wx GUI obtained it from
    # sas.sasgui.perspectives.fitting.utils.get_weight.

    if reference_data is None:
        # Nothing to compare against
        return None

    # temporary default values for index and weight
    index = None
    weight = None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        if index is None:
            index = numpy.ones(len(current_data.data), dtype=bool)
        if weight is not None:
            current_data.err_data = weight
        # get rid of zero error points and non-finite data
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1 d theory from model_thread is only in the range of index
        if index is None:
            index = numpy.ones(len(current_data.y), dtype=bool)
        if weight is not None:
            current_data.dy = weight
        # len() rather than '== None' / '== []': dy is normally a numpy
        # array, for which those comparisons are elementwise and cannot be
        # used as a truth value.
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            ## Set consistently w/AbstractFitengine:
            # But this should be corrected later.
            # Work on a float copy so the caller's dy is never modified and
            # the zero-replacement also works if dy is a plain list.
            dy = numpy.array(current_data.dy, dtype=float)
            dy[dy == 0] = 1
        fn = current_data.y[index]
        gn = reference_data.y
        en = dy[index]

    # Calculate the residual; numpy raises ValueError on mismatched lengths
    try:
        res = (fn - gn) / en
    except ValueError:
        return None

    residuals = res[numpy.isfinite(res)]
    chisqr = numpy.average(residuals * residuals)

    return chisqr
248
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data1D sets.

    :param reference_data: Data1D whose ``y`` values are compared against
    :param current_data: Data1D providing ``x``, ``y`` and ``dy``
    :return: new Data1D with the residuals in ``y``, unit ``dy`` and no
             x resolution (``dx``, ``dxl``, ``dxw`` set to None)
    """
    # temporary default values for index and weight.
    # With index=None, ``arr[index][0]`` adds and then removes a leading
    # axis, so for a numpy array it is equivalent to ``arr`` itself.
    index = None
    weight = None

    # 1d theory from model_thread is only in the range of index.
    # Errors default to 1. A weight array (zeros replaced by 1, as in
    # AbstractFitengine) is used only when current_data carries errors.
    # len() is used instead of '== []' because dy is normally a numpy
    # array, for which '== []' is an elementwise comparison.
    # NOTE(review): current_data.dy itself is never used as the error, so
    # residuals are not normalized by it -- confirm this is intended.
    dy = numpy.ones(len(current_data.y))
    has_errors = current_data.dy is not None and len(current_data.dy) > 0
    if has_errors and weight is not None:
        # Copy so the caller's weight array is never modified
        dy = numpy.array(weight, dtype=float)
        dy[dy == 0] = 1

    fn = current_data.y[index][0]
    gn = reference_data.y
    en = dy[index][0]

    # build residuals
    residuals = Data1D()
    if len(fn) == len(gn):
        residuals.y = -((fn - gn) / en)
    else:
        # TODO: fix case where applying new data from file on top of
        # existing model data. With index=None, gn[index][0] is gn itself,
        # so this raises ValueError unless numpy can broadcast the arrays
        # (i.e. one of them has length 1).
        residuals.y = (fn - gn[index][0]) / en

    residuals.x = current_data.x[index][0]
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For later scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
295
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data2D sets.

    :param reference_data: Data2D whose ``data`` is compared against
    :param current_data: Data2D providing data, err_data, q coordinates
                         and mask
    :return: new Data2D holding ``(current - reference) / err_data`` with
             unit errors and linear scale, or None if the residual and
             ``q_data`` lengths differ
    """
    # temporary default value for weight
    weight = None

    # build residuals; populate 'residuals' from current_data via
    # clone_without_data before filling in the computed arrays
    residuals = Data2D()
    current_data.clone_without_data(len(current_data.data), residuals)

    fn = current_data.data
    gn = reference_data.data
    en = current_data.err_data if weight is None else weight
    # NOTE(review): zero entries in err_data give inf/nan residuals here;
    # calculateChi2 filters such points out, this function does not.
    residuals.data = (fn - gn) / en
    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    residuals.xmin = min(residuals.qx_data)
    residuals.xmax = max(residuals.qx_data)
    residuals.ymin = min(residuals.qy_data)
    residuals.ymax = max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    # check the lengths
    if len(residuals.data) != len(residuals.q_data):
        return None
    return residuals
328
def plotResiduals(reference_data, current_data):
    """
    Create a Data1D/Data2D object holding the residuals, ready for plotting.

    current_data is deep-copied before being handed to residualsData1D or
    residualsData2D. The helper is chosen by the class name of current_data.

    :param reference_data: data set the residuals are computed against; its
                           ``filename``, ``id`` and ``group_id`` label the
                           result
    :param current_data: Data1D or Data2D whose ``name`` labels the result
    :return: the residuals data object, or None if the helper could not
             build one
    :raises KeyError: if current_data's class is neither Data1D nor Data2D
    """
    data_copy = deepcopy(current_data)
    # Get data: data I, theory I, and data dI in order
    method_name = current_data.__class__.__name__
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}

    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        # residualsData2D returns None when its array lengths disagree
        return None

    # Label with the first word of the theory name (empty if there is none)
    name_parts = (current_data.name or "").split()
    theory_name = str(name_parts[0]) if name_parts else ""
    residuals.name = "Residuals for " + theory_name + "[" + \
                    str(reference_data.filename) + "]"
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # when 2 data have the same id override the 1 st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
    # group_id specify on which panel to plot this data
    group_id = reference_data.group_id
    residuals.group_id = "res" + str(group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
359
360
def binary_encode(i, digits):
    """
    Return the lowest 'digits' bits of integer 'i', least significant first.

    >>> binary_encode(6, 4)
    [0, 1, 1, 0]

    :param i: integer to encode
    :param digits: number of bits to return; 0 or negative gives []
    :return: list of 0/1 ints
    """
    # range (not the Python 2-only xrange) keeps this portable to Python 3
    return [(i >> bit) & 1 for bit in range(digits)]
363
Note: See TracBrowser for help on using the repository browser.