source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 0215e0a

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 0215e0a was 180bd54, checked in by Piotr Rozyczko <rozyczko@…>, 8 years ago

Refactored fitting options tab

  • Property mode set to 100755
File size: 12.9 KB
Line 
1from PyQt4 import QtGui
2from PyQt4 import QtCore
3
4import numpy
5from copy import deepcopy
6
7from sas.sasgui.guiframe.dataFitting import Data1D
8from sas.sasgui.guiframe.dataFitting import Data2D
9
def replaceShellName(param_name, value):
    """
    Turn a shell-indexed name such as "sld[n_shells]" into "sld<value>".

    Everything from the first '[' onward is replaced by str(value).
    param_name must contain '['; an AssertionError is raised otherwise.
    """
    prefix, bracket, _ = param_name.partition('[')
    assert bracket
    return prefix + str(value)
16
def getIterParams(model):
    """
    Collect the multi-shell parameters of 'model'.

    A parameter counts as multi-shell when its name contains '['.
    The result preserves the order of model.iq_parameters.
    """
    return [par for par in model.iq_parameters if "[" in par.name]
22
def getMultiplicity(model):
    """
    Report the multishell control parameter of 'model', if it has one.

    Only the first multishell parameter is inspected.

    :return: tuple (name, length). 'name' is that parameter's
        length_control, or, when length_control is None, its name with the
        '[...]' suffix stripped. 'length' is its 'length' attribute.
        ("", 0) is returned when the model has no multishell parameters.
    """
    iter_params = getIterParams(model)
    if not iter_params:
        return ("", 0)

    first = iter_params[0]
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        control_name = first.name.split('[', 1)[0]
    return (control_name, first.length)
37
def addParametersToModel(parameters, is2D):
    """
    Build parameter-table rows for the non-multishell parameters of a model.

    The multishell control parameter (as reported by getMultiplicity) and
    the "name[n]" shell parameters themselves are skipped. Per-shell rows are
    produced separately by addShellsToModel.

    A polydisperse parameter gets a "Polydispersity" child row. That row is
    filled from the form_volume_parameters entry with the same name. If no
    such entry exists, the child item is still added, with no row under it.

    :param parameters: model info object exposing iq_parameters,
        iqxy_parameters and form_volume_parameters.
    :param is2D: use iqxy_parameters when True, iq_parameters otherwise.
    :return: list of rows [name (checkable), default, min, max, units],
        each row a list of QtGui.QStandardItem.
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Skip the shell-count control parameter and the per-shell
        # "name[n]" parameters; the latter are added by addShellsToModel.
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            # Matching entry in form_volume_parameters, if any
            volume_par = next((p for p in parameters.form_volume_parameters
                               if p.name == param.name), None)
            if volume_par is not None:
                poly_item.appendRow([QtGui.QStandardItem("Distribution"),
                                     QtGui.QStandardItem(str(volume_par.default)),
                                     QtGui.QStandardItem(str(volume_par.limits[0])),
                                     QtGui.QStandardItem(str(volume_par.limits[1])),
                                     QtGui.QStandardItem(volume_par.units)])
            # Add the polydispersity item as a child of the parameter row
            item1.appendRow([poly_item])

        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        rows.append([item1,
                     QtGui.QStandardItem(str(param.default)),
                     QtGui.QStandardItem(str(param.limits[0])),
                     QtGui.QStandardItem(str(param.limits[1])),
                     QtGui.QStandardItem(param.units)])
    return rows
85
def addSimpleParametersToModel(parameters, is2D):
    """
    Build parameter-table rows for every parameter of a model.

    Unlike addParametersToModel, no multishell filtering and no
    polydispersity child rows are applied.

    :param parameters: model info object exposing iq_parameters and
        iqxy_parameters.
    :param is2D: use iqxy_parameters when True, iq_parameters otherwise.
    :return: list of rows [name (checkable), default, min, max, units],
        each row a list of QtGui.QStandardItem.
    """
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Top-level, checkable name cell
        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        rows.append([name_item,
                     QtGui.QStandardItem(str(param.default)),
                     QtGui.QStandardItem(str(param.limits[0])),
                     QtGui.QStandardItem(str(param.limits[1])),
                     QtGui.QStandardItem(param.units)])
    return rows
107
def addCheckedListToModel(model, param_list):
    """
    Append one row, built from the entries of param_list, to 'model'.

    Every entry becomes a QStandardItem. The first cell of the row is made
    checkable. 'model' must be a QStandardItemModel; param_list must be
    non-empty, or an IndexError is raised before anything is appended.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_cell = row[0]
    first_cell.setCheckable(True)
    model.appendRow(row)
116
def addHeadersToModel(model):
    """
    Label the horizontal headers of a five-column parameter table:
    Parameter, Value, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
126
def addErrorHeadersToModel(model):
    """
    Label the horizontal headers of a six-column parameter table that
    includes an error column: Parameter, Value, Error, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Error", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
137
def addPolyHeadersToModel(model):
    """
    Label the horizontal headers of the seven-column polydispersity table:
    Parameter, PD[ratio], Min, Max, Npts, Nsigs, Function.
    """
    titles = ("Parameter", "PD[ratio]", "Min", "Max",
              "Npts", "Nsigs", "Function")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
149
def addShellsToModel(parameters, model, index):
    """
    Append rows for the first 'index' shells of every multishell parameter.

    Names of the form "name[n]" are expanded to "name1", "name2", and so on.
    All parameters of shell 1 are appended before those of shell 2, etc.
    A polydisperse parameter gets a "Polydispersity" child item. That item
    is filled from the form_volume_parameters entry whose name equals the
    unexpanded "name[n]".

    :param parameters: model info object exposing iq_parameters and
        form_volume_parameters.
    :param model: model to append rows to via appendRow.
    :param index: number of shells to add; values <= 0 add nothing.
    """
    multishell_parameters = getIterParams(parameters)

    # range (not the Python-2-only xrange) keeps this working on Python 3
    for i in range(index):
        for par in multishell_parameters:
            # Create the name: <param><i>, e.g. "sld1" for parameter "sld[n]"
            param_name = replaceShellName(par.name, i + 1)
            item1 = QtGui.QStandardItem(param_name)
            item1.setCheckable(True)
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                # Matching entry in form_volume_parameters, if any
                volume_par = next((p for p in parameters.form_volume_parameters
                                   if p.name == par.name), None)
                if volume_par is not None:
                    poly_item.appendRow([QtGui.QStandardItem("Distribution"),
                                         QtGui.QStandardItem(str(volume_par.default)),
                                         QtGui.QStandardItem(str(volume_par.limits[0])),
                                         QtGui.QStandardItem(str(volume_par.limits[1])),
                                         QtGui.QStandardItem(volume_par.units)])
                item1.appendRow([poly_item])

            model.appendRow([item1,
                             QtGui.QStandardItem(str(par.default)),
                             QtGui.QStandardItem(str(par.limits[0])),
                             QtGui.QStandardItem(str(par.limits[1])),
                             QtGui.QStandardItem(par.units)])
183
def calculateChi2(reference_data, current_data):
    """
    Calculate the chi^2 value between two data sets.

    The residuals are (current - reference) / error. The error comes from
    current_data: err_data for Data2D, dy for 1D data. Non-finite residuals
    are discarded, and the mean of the squared remaining residuals is
    returned.

    :param reference_data: Data2D or 1D data set to compare against;
        may be None.
    :param current_data: data set supplying the values and uncertainties.
    :return: the averaged squared residual, or None when reference_data is
        None or the arrays have lengths numpy cannot broadcast.
    """
    # WEIGHING INPUT
    #from sas.sasgui.perspectives.fitting.utils import get_weight
    #flag = self.get_weight_flag()
    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)

    # Nothing to compare against. (Previously this returned the
    # not-yet-assigned local 'chisqr', raising UnboundLocalError.)
    if reference_data is None:
        return None

    # temporary default values for index and weight (hooks for future
    # masking/weighting; both are currently always None)
    index = None
    weight = None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        if index is None:
            index = numpy.ones(len(current_data.data), dtype=bool)
        if weight is not None:
            current_data.err_data = weight
        # get rid of zero error points and non-finite values
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1 d theory from model_thread is only in the range of index
        if index is None:
            index = numpy.ones(len(current_data.y), dtype=bool)
        if weight is not None:
            current_data.dy = weight
        # dy may be a numpy array: '== None' / '== []' compare elementwise
        # and cannot be used as a truth value, so test explicitly.
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            ## Set consistently w/AbstractFitengine:
            # But this should be corrected later.
            # numpy.array copies, so the caller's dy is not modified, and
            # boolean-mask assignment also works if dy arrived as a list.
            dy = numpy.array(current_data.dy, dtype=float)
            dy[dy == 0] = 1
        fn = current_data.y[index]
        gn = reference_data.y
        en = dy[index]

    # Calculate the residual; unmatched lengths fail to broadcast
    try:
        res = (fn - gn) / en
    except ValueError:
        return None

    residuals = res[numpy.isfinite(res)]
    chisqr = numpy.average(residuals * residuals)

    return chisqr
240
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data1D sets.

    :param reference_data: Data1D whose y values are subtracted.
    :param current_data: Data1D supplying x, y and dy.
    :return: a new Data1D with the residuals in y, x taken from
        current_data, unit dy, and residual-plot axis labels.
    """
    # Hook for a future weighting scheme; currently always None, so the
    # residuals are normalized by 1 rather than by current_data.dy.
    weight = None

    # dy may be a numpy array: '== None' / '== []' compare elementwise and
    # cannot be used as a truth value, so test explicitly.
    if current_data.dy is None or len(current_data.dy) == 0:
        dy = numpy.ones(len(current_data.y))
    else:
        dy = weight if weight is not None else numpy.ones(len(current_data.y))
        dy[dy == 0] = 1
    fn = current_data.y
    gn = reference_data.y
    en = dy

    # build residuals
    residuals = Data1D()
    if len(fn) == len(gn):
        y = (fn - gn) / en
        residuals.y = -y
    else:
        # NOTE(review): no mask is applied to gn here, so this only succeeds
        # when numpy can broadcast gn against fn (e.g. len(gn) == 1);
        # otherwise it raises ValueError. The sign also differs from the
        # branch above -- confirm which convention is intended.
        y = (fn - gn) / en
        residuals.y = y

    residuals.x = current_data.x
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
286
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data2D sets.

    :param reference_data: Data2D whose data values are subtracted.
    :param current_data: Data2D supplying data, err_data, q arrays and mask.
    :return: a new Data2D holding (current - reference) / err_data in
        'data', with unit err_data and q ranges taken from current_data.
        Returns None when the residual array length differs from q_data.
    """
    # Hook for a future weighting scheme; currently always None, so the
    # residuals are normalized by current_data.err_data.
    weight = None

    # build residuals
    residuals = Data2D()
    # Populate residuals from current_data (see Data2D.clone_without_data);
    # its data array is replaced just below.
    current_data.clone_without_data(len(current_data.data), residuals)
    fn = current_data.data
    gn = reference_data.data
    en = current_data.err_data if weight is None else weight
    residuals.data = (fn - gn) / en
    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    # check the lengths before deriving anything else from the arrays
    if len(residuals.data) != len(residuals.q_data):
        return None
    residuals.err_data = numpy.ones(len(residuals.data))
    # Vectorized min/max instead of the builtins iterating over an ndarray
    residuals.xmin = numpy.min(residuals.qx_data)
    residuals.xmax = numpy.max(residuals.qx_data)
    residuals.ymin = numpy.min(residuals.qy_data)
    residuals.ymax = numpy.max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    return residuals
319
def plotResiduals(reference_data, current_data):
    """
    Create Data1D/Data2D with residuals, ready for plotting.

    :param reference_data: data set compared against; its filename, id and
        group_id are used to name and place the residual plot.
    :param current_data: Data1D or Data2D whose deep copy is passed to the
        matching residuals helper; the first word of its name goes into the
        plot title.
    :return: the residuals data set, or None if the residuals could not be
        computed (residualsData2D returns None on a length mismatch).
    :raises KeyError: if current_data's class is not named Data1D/Data2D.
    """
    # Work on a copy so the residual helpers cannot alter the caller's data
    data_copy = deepcopy(current_data)
    # Dispatch on the class *name*, not isinstance, so any class named
    # Data1D/Data2D is accepted.
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}
    method_name = current_data.__class__.__name__
    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        return None

    theory_name = str(current_data.name.split()[0])
    residuals.name = "Residuals for " + theory_name + "[" + \
                    str(reference_data.filename) + "]"
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # when 2 data have the same id override the 1 st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
    # group_id specify on which panel to plot this data
    group_id = reference_data.group_id
    residuals.group_id = "res" + str(group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
350
351
def binary_encode(i, digits):
    """
    Return the lowest 'digits' bits of integer i as a list of 0/1 ints,
    least-significant bit first, e.g. binary_encode(6, 4) -> [0, 1, 1, 0].
    """
    # range (not the Python-2-only xrange) keeps this working on Python 3
    return [(i >> d) & 1 for d in range(digits)]
354
Note: See TracBrowser for help on using the repository browser.