source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 0268aed

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 0268aed was 0268aed, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Plotting residuals in fitting.
PlotHelper updates.
Minor refactoring.

  • Property mode set to 100755
File size: 11.9 KB
Line 
1from PyQt4 import QtGui
2from PyQt4 import QtCore
3
4import numpy
5from copy import deepcopy
6
7from sas.sasgui.guiframe.dataFitting import Data1D
8from sas.sasgui.guiframe.dataFitting import Data2D
9
def replaceShellName(param_name, value):
    """
    Turn a shell parameter name such as "sld[n_shells]" into "sld<value>".

    Everything from the first '[' onward is dropped and str(value) is
    appended. The name must contain '[' (checked with an assert).
    """
    assert '[' in param_name
    base_name, _, _ = param_name.partition('[')
    return base_name + str(value)
16
def getIterParams(model):
    """
    Collect the multi-shell parameters of 'model'.

    Returns a new list holding, in their original order, the entries of
    model.iq_parameters whose name contains '['.
    """
    return [param for param in model.iq_parameters if "[" in param.name]
22
def getMultiplicity(model):
    """
    Report whether 'model' has multi-shell parameters.

    Returns a (name, length) tuple. When getIterParams finds no parameter
    this is ("", 0). Otherwise both values come from the first multi-shell
    parameter: length is its .length, and name is its .length_control or,
    when that is None, the part of its own name before '['.
    """
    iter_params = getIterParams(model)
    if not iter_params:
        return ("", 0)

    first = iter_params[0]
    counter_name = first.length_control
    if counter_name is None and '[' in first.name:
        # No explicit control parameter: derive it from the base name
        counter_name = first.name[:first.name.index('[')]
    return (counter_name, first.length)
37
def addParametersToModel(parameters, model):
    """
    Update local ModelModel with sasmodel parameters.

    Appends one checkable five-column row (name, default, min, max, units)
    to 'model' for each entry of parameters.iq_parameters. Two kinds of
    parameters are skipped: the counter named by getMultiplicity and the
    multi-shell parameters from getIterParams (addShellsToModel in this
    module creates rows for those). A polydisperse parameter gets a
    "Polydispersity" child row. That child is filled from the matching
    entry of parameters.form_volume_parameters when one exists.
    """
    shell_params = getIterParams(parameters)
    counter_name, _ = getMultiplicity(parameters)

    for param in parameters.iq_parameters:
        # Skip the shell counter and every per-shell parameter
        if param.name == counter_name or param in shell_params:
            continue

        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)

        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            volume_param = next(
                (p for p in parameters.form_volume_parameters
                 if p.name == param.name),
                None)
            if volume_param is not None:
                poly_item.appendRow([
                    QtGui.QStandardItem("Distribution"),
                    QtGui.QStandardItem(str(volume_param.default)),
                    QtGui.QStandardItem(str(volume_param.limits[0])),
                    QtGui.QStandardItem(str(volume_param.limits[1])),
                    QtGui.QStandardItem(volume_param.units),
                ])
            # The child is attached even when no volume parameter matched
            name_item.appendRow([poly_item])

        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        model.appendRow([
            name_item,
            QtGui.QStandardItem(str(param.default)),
            QtGui.QStandardItem(str(param.limits[0])),
            QtGui.QStandardItem(str(param.limits[1])),
            QtGui.QStandardItem(param.units),
        ])
82
def addSimpleParametersToModel(parameters, model):
    """
    Update local ModelModel with sasmodel parameters.

    Nothing is filtered out and no polydispersity children are created.
    Every entry of parameters.iq_parameters becomes one checkable
    five-column row (name, default, min, max, units) appended to 'model'.
    """
    for param in parameters.iq_parameters:
        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        value_texts = (str(param.default), str(param.limits[0]), str(param.limits[1]))
        row = [name_item] + [QtGui.QStandardItem(text) for text in value_texts]
        row.append(QtGui.QStandardItem(param.units))
        model.appendRow(row)
101
def addCheckedListToModel(model, param_list):
    """
    Append one row to 'model' built from the strings in 'param_list'.

    Each string becomes a QStandardItem, and only the leftmost item is
    made checkable. 'model' must be a QStandardItemModel (asserted). An
    empty 'param_list' raises IndexError.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    row[0].setCheckable(True)
    model.appendRow(row)
110
def addHeadersToModel(model):
    """
    Label the first five horizontal header sections of 'model'.

    The labels are, in column order: Parameter, Value, Min, Max, [Units].
    """
    headers = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, title in enumerate(headers):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
120
def addPolyHeadersToModel(model):
    """
    Label the first seven horizontal header sections of 'model' for the
    polydispersity table.

    The labels are, in column order: Parameter, PD[ratio], Min, Max,
    Npts, Nsigs, Function.
    """
    headers = ("Parameter", "PD[ratio]", "Min", "Max", "Npts", "Nsigs", "Function")
    for column, title in enumerate(headers):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
132
def addShellsToModel(parameters, model, index):
    """
    Find out multishell parameters and update the model with the requested
    number of them.

    The outer loop runs over shell numbers 1..index and the inner loop over
    the parameters returned by getIterParams. For each pair, a checkable
    five-column row (name, default, min, max, units) is appended to 'model'.
    Row names come from replaceShellName, e.g. "sld[n]" -> "sld1", "sld2".
    A polydisperse parameter gets a "Polydispersity" child row. That child
    is filled from the matching entry of parameters.form_volume_parameters
    when one exists.
    """
    shell_params = getIterParams(parameters)

    for shell_number in range(1, index + 1):
        for par in shell_params:
            name_item = QtGui.QStandardItem(replaceShellName(par.name, shell_number))
            name_item.setCheckable(True)

            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                volume_param = next(
                    (p for p in parameters.form_volume_parameters
                     if p.name == par.name),
                    None)
                if volume_param is not None:
                    poly_item.appendRow([
                        QtGui.QStandardItem("Distribution"),
                        QtGui.QStandardItem(str(volume_param.default)),
                        QtGui.QStandardItem(str(volume_param.limits[0])),
                        QtGui.QStandardItem(str(volume_param.limits[1])),
                        QtGui.QStandardItem(volume_param.units),
                    ])
                name_item.appendRow([poly_item])

            model.appendRow([
                name_item,
                QtGui.QStandardItem(str(par.default)),
                QtGui.QStandardItem(str(par.limits[0])),
                QtGui.QStandardItem(str(par.limits[1])),
                QtGui.QStandardItem(par.units),
            ])
166
def calculateChi2(reference_data, current_data):
    """
    Calculate Chi2 value between two sets of data.

    Each residual is (current - reference) / error, with the error taken
    from current_data:
      * 2D: err_data. Points with zero error or a non-finite current value
        are excluded.
      * 1D: dy. A missing or empty dy means unit errors, and zero entries
        are replaced by 1.
    Non-finite residuals are dropped. The mean of the squared residuals
    is returned.

    :param reference_data: Data1D/Data2D to compare against; its class
        selects the 1D or 2D branch
    :param current_data: Data1D/Data2D supplying values and errors
    :return: the chi2 value, or None if reference_data is None or the
        arrays cannot be combined (mismatched lengths)
    """
    # TODO: apply the user-selected weighting
    # (see sas.sasgui.perspectives.fitting.utils.get_weight)

    if reference_data is None:
        # Nothing to compare against
        return None

    # Placeholders until masking/weighting are supported
    index = None
    weight = None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        if index is None:
            index = numpy.ones(len(current_data.data), dtype=bool)
        if weight is not None:
            current_data.err_data = weight
        # get rid of zero error points and non-finite data points
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1D theory from model_thread is only in the range of index
        if index is None:
            index = numpy.ones(len(current_data.y), dtype=bool)
        if weight is not None:
            current_data.dy = weight
        # 'is None' / len() instead of '== None' / '== []', which are
        # ambiguous (element-wise) for numpy arrays
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            # Set consistently with AbstractFitengine: zero errors become 1.
            # numpy.array copies, so the caller's dy is untouched, and the
            # boolean mask below also works when dy is a plain list.
            dy = numpy.array(current_data.dy, dtype=float)
            dy[dy == 0] = 1
        fn = current_data.y[index]
        # NOTE(review): reference y is used unindexed; assumes it already
        # covers the same points as current_data — TODO confirm
        gn = reference_data.y
        en = dy[index]

    # Calculate the residual
    try:
        res = (fn - gn) / en
    except ValueError:
        # numpy raises ValueError when the arrays cannot be broadcast
        print("Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en)))
        return None

    residuals = res[numpy.isfinite(res)]
    return numpy.average(residuals * residuals)
223
def residualsData1D(reference_data, current_data):
    """
    Build a Data1D of residuals for plotting.

    Each point is (reference_data.y - current_data.y) / error. While no
    weighting is applied (the 'weight' placeholder below is None), the
    error is 1 for every point. x is taken from current_data, dy is set
    to ones, and dx/dxl/dxw are set to None.

    :param reference_data: Data1D supplying the reference y values
    :param current_data: Data1D supplying y and x
    :return: Data1D of residuals
    :raises ValueError: if the y/error arrays cannot be broadcast together
        (e.g. different numbers of points); a message is printed first
    """
    # Placeholder until weighting is supported; with None all errors are 1
    weight = None

    # 'is None' / len() instead of '== None' / '== []', which are
    # ambiguous (element-wise) for numpy arrays
    has_dy = current_data.dy is not None and len(current_data.dy) > 0
    if has_dy and weight is not None:
        # Copy so the caller's weight array is not modified; avoid / 0
        dy = numpy.array(weight, dtype=float)
        dy[dy == 0] = 1
    else:
        dy = numpy.ones(len(current_data.y))

    # build residuals
    try:
        y = (current_data.y - reference_data.y) / dy
    except ValueError:
        # numpy cannot broadcast arrays with different numbers of points
        print("ResidualPlot Error: different # of data points in theory")
        raise

    residuals = Data1D()
    # Same sign convention for every path: reference minus current
    residuals.y = -y
    residuals.x = current_data.x
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
264
def residualsData2D(reference_data, current_data):
    """
    Build a Data2D of residuals for plotting.

    Each point is (current_data.data - reference_data.data) / error. The
    error is current_data.err_data, since no weighting is applied yet.
    Metadata is cloned from current_data via clone_without_data. The
    qx/qy/q arrays and the mask are copied from current_data, and err_data
    is set to ones. The x/y ranges are the extremes of qx_data/qy_data.

    :param reference_data: Data2D supplying the reference values
    :param current_data: Data2D supplying values, errors, q arrays and mask
    :return: Data2D of residuals, or None if the residual array and
        current_data.q_data differ in length
    """
    # Placeholder until weighting is supported
    weight = None

    # build residuals
    residuals = Data2D()
    current_data.clone_without_data(len(current_data.data), residuals)

    en = current_data.err_data if weight is None else weight
    residuals.data = (current_data.data - reference_data.data) / en
    # Reject results that do not line up with the q grid
    if len(residuals.data) != len(current_data.q_data):
        return None

    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    # numpy reductions instead of Python-level min()/max() over arrays
    residuals.xmin = numpy.min(residuals.qx_data)
    residuals.xmax = numpy.max(residuals.qx_data)
    residuals.ymin = numpy.min(residuals.qy_data)
    residuals.ymax = numpy.max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    return residuals
299
def plotResiduals(reference_data, current_data):
    """
    Create Data1D/Data2D with residuals, ready for plotting.

    The residual builder is chosen by the class name of current_data
    ("Data1D" -> residualsData1D, "Data2D" -> residualsData2D) and is
    handed a deep copy of current_data. The result is titled
    "Residuals for <first word of current_data.name>[<reference filename>]".
    Its id/group_id are prefixed with "res" and derived from reference_data.

    :return: the residuals data object, or None if the builder returned
        None (residualsData2D does so on a length mismatch)
    :raises KeyError: if current_data's class name is neither Data1D nor
        Data2D
    """
    # Dispatch on the class name rather than isinstance
    residual_builders = {"Data1D": residualsData1D,
                         "Data2D": residualsData2D}
    build_residuals = residual_builders[current_data.__class__.__name__]

    # Work on a copy so current_data is insulated from the builder
    residuals = build_residuals(reference_data, deepcopy(current_data))
    if residuals is None:
        # The builder rejected the data; there is nothing to label
        return None

    # First word of the theory name; fall back to the whole (possibly
    # empty) name instead of raising IndexError
    name_words = current_data.name.split()
    theory_name = str(name_words[0]) if name_words else str(current_data.name)
    residuals.name = "Residuals for " + theory_name + "[" + \
                     str(reference_data.filename) + "]"
    residuals.title = residuals.name
    # when 2 data have the same id override the 1st plotted;
    # append theory_name if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id)
    # group_id specifies on which panel to plot this data
    residuals.group_id = "res" + str(reference_data.group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
329
330
def binary_encode(i, digits):
    """
    Return the lowest 'digits' bits of integer 'i', least significant first.

    Uses the builtin range (not xrange) so it runs on Python 2 and 3.

    >>> binary_encode(6, 4)
    [0, 1, 1, 0]
    """
    return [(i >> d) & 1 for d in range(digits)]
333
Note: See TracBrowser for help on using the repository browser.