source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ cd2cc745

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since cd2cc745 was 1970780, checked in by Piotr Rozyczko <rozyczko@…>, 8 years ago

Add display of orientational parameters + minor refactoring

  • Property mode set to 100755
File size: 12.8 KB
Line 
1from PyQt4 import QtGui
2from PyQt4 import QtCore
3
4import numpy
5from copy import deepcopy
6
7from sas.sasgui.guiframe.dataFitting import Data1D
8from sas.sasgui.guiframe.dataFitting import Data2D
9
def replaceShellName(param_name, value):
    """
    Turn a shell parameter name "<base>[...]" into "<base><value>".

    Everything from the first '[' onwards is replaced by str(value),
    e.g. ("sld[n]", 2) -> "sld2".
    """
    assert '[' in param_name
    base, _, _ = param_name.partition('[')
    return base + str(value)
16
def getIterParams(model):
    """
    Collect the multi-shell parameters of 'model'.

    A parameter is treated as multi-shell when its name contains '['
    (e.g. "sld[n]"). The result is a new list in model.iq_parameters order.
    """
    return [param for param in model.iq_parameters if "[" in param.name]
22
def getMultiplicity(model):
    """
    Report whether 'model' has multi-shell parameters.

    Looks at the first multi-shell parameter and returns a tuple
    (counter_name, length):
      - counter_name is its length_control parameter name, or, when that
        is None, the part of its own name before '[';
      - length is its declared length.
    Returns ("", 0) when the model has no multi-shell parameters.
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    shell_count = first.length
    counter_name = first.length_control
    if counter_name is None and '[' in first.name:
        counter_name = first.name.split('[', 1)[0]
    return (counter_name, shell_count)
37
def addParametersToModel(parameters, is2D):
    """
    Build the rows of the main parameter view from sasmodel parameters.

    The shell-counter parameter (as reported by getMultiplicity) and all
    multi-shell parameters (names containing '[') are skipped here;
    addShellsToModel creates per-shell rows for the latter.

    :param parameters: sasmodel parameter table; must expose iq_parameters,
        iqxy_parameters and form_volume_parameters
    :param is2D: when True use iqxy_parameters, otherwise iq_parameters
    :return: list of rows, each [name (checkable), default, min, max, units]
        as QStandardItems. A polydisperse parameter's name item gets a
        "Polydispersity" child item.
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Neither the shell counter nor the per-shell parameters are listed here
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            item1_1 = QtGui.QStandardItem("Distribution")
            # Take the default/limits/units from the matching volume parameter
            for p in parameters.form_volume_parameters:
                if p.name != param.name:
                    continue
                item1_2 = QtGui.QStandardItem(str(p.default))
                item1_3 = QtGui.QStandardItem(str(p.limits[0]))
                item1_4 = QtGui.QStandardItem(str(p.limits[1]))
                item1_5 = QtGui.QStandardItem(p.units)
                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                break
            # Add the polydispersity item as a child of the parameter name
            item1.appendRow([poly_item])

        # Param values
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        item2 = QtGui.QStandardItem(str(param.default))
        item3 = QtGui.QStandardItem(str(param.limits[0]))
        item4 = QtGui.QStandardItem(str(param.limits[1]))
        item5 = QtGui.QStandardItem(param.units)
        rows.append([item1, item2, item3, item4, item5])
    return rows
85
def addSimpleParametersToModel(parameters, is2D):
    """
    Build flat parameter rows from sasmodel parameters.

    Every parameter in the selected list is included and no child items
    are attached.

    :param parameters: sasmodel parameter table; must expose iq_parameters
        and iqxy_parameters
    :param is2D: when True use iqxy_parameters, otherwise iq_parameters
    :return: list of rows, each [name (checkable), default, min, max, units]
        as QStandardItems
    """
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Create the top level, checkable item
        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)
        # Param values
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        value_item = QtGui.QStandardItem(str(param.default))
        min_item = QtGui.QStandardItem(str(param.limits[0]))
        max_item = QtGui.QStandardItem(str(param.limits[1]))
        units_item = QtGui.QStandardItem(param.units)
        rows.append([name_item, value_item, min_item, max_item, units_item])
    return rows
107
def addCheckedListToModel(model, param_list):
    """
    Append one row to 'model', built from the strings in 'param_list'.

    Only the first cell of the new row is made checkable.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_cell = row[0]
    first_cell.setCheckable(True)
    model.appendRow(row)
116
def addHeadersToModel(model):
    """
    Set the horizontal headers of a five-column parameter model:
    Parameter, Value, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
126
def addErrorHeadersToModel(model):
    """
    Set the horizontal headers of a six-column parameter model that has
    an error column: Parameter, Value, Error, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Error", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
137
def addPolyHeadersToModel(model):
    """
    Set the horizontal headers of the seven-column polydispersity model:
    Parameter, PD[ratio], Min, Max, Npts, Nsigs, Function.
    """
    titles = ("Parameter", "PD[ratio]", "Min", "Max", "Npts", "Nsigs", "Function")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
149
def addShellsToModel(parameters, model, index):
    """
    Append rows for the first 'index' shells of every multi-shell parameter.

    For shell i = 1..index and each multi-shell parameter "<name>[...]",
    a row "<name><i>" is appended to 'model' with the parameter's default,
    limits and units. A polydisperse parameter's name item also gets a
    "Polydispersity" child item.

    :param parameters: sasmodel parameter table; must expose iq_parameters
        and form_volume_parameters
    :param model: item model that receives the rows via appendRow
    :param index: number of shells to create rows for
    """
    multishell_parameters = getIterParams(parameters)

    # range (not xrange) so this also runs under Python 3
    for shell in range(1, index + 1):
        for par in multishell_parameters:
            # Create the name: <param><shell>, e.g. "sld1" for parameter "sld[n]"
            item1 = QtGui.QStandardItem(replaceShellName(par.name, shell))
            item1.setCheckable(True)
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                item1_1 = QtGui.QStandardItem("Distribution")
                # First volume parameter with the same name, if any
                volume_par = next((p for p in parameters.form_volume_parameters
                                   if p.name == par.name), None)
                if volume_par is not None:
                    item1_2 = QtGui.QStandardItem(str(volume_par.default))
                    item1_3 = QtGui.QStandardItem(str(volume_par.limits[0]))
                    item1_4 = QtGui.QStandardItem(str(volume_par.limits[1]))
                    item1_5 = QtGui.QStandardItem(volume_par.units)
                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                item1.appendRow([poly_item])

            item2 = QtGui.QStandardItem(str(par.default))
            item3 = QtGui.QStandardItem(str(par.limits[0]))
            item4 = QtGui.QStandardItem(str(par.limits[1]))
            item5 = QtGui.QStandardItem(par.units)
            model.appendRow([item1, item2, item3, item4, item5])
183
def calculateChi2(reference_data, current_data):
    """
    Calculate the chi^2 value between two sets of data.

    chi^2 is the mean of the squared, finite residuals (fn - gn) / en,
    where fn and en (values and errors) come from 'current_data' and gn
    from 'reference_data'. The branch is selected by whether
    'reference_data' is a Data2D.

    2D: points with zero error or non-finite values in current_data
    are excluded.
    1D: zero (or missing) errors are replaced by 1.

    :param reference_data: Data1D/Data2D, or None
    :param current_data: data set of the same dimensionality
    :return: chi^2, or None if reference_data is None or the arrays have
        mismatched lengths
    """
    # NOTE: weighting is not wired in yet; the older fitting perspective
    # obtained it via sas.sasgui.perspectives.fitting.utils.get_weight.
    if reference_data is None:
        # Nothing to compare against
        return None

    # temporary default values for index and weight
    index = None
    weight = None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        if index is None:
            index = numpy.ones(len(current_data.data), dtype=bool)
        if weight is not None:
            current_data.err_data = weight
        # get rid of zero error points and non-finite values
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1D theory from model_thread is only in the range of index
        if index is None:
            index = numpy.ones(len(current_data.y), dtype=bool)
        if weight is not None:
            current_data.dy = weight
        # 'is None' / len() instead of '== None' / '== []': those comparisons
        # are elementwise (hence ambiguous in 'if') for numpy arrays.
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            # Copy into an ndarray so the mask assignment below works for
            # lists too and the caller's errors are left untouched.
            # Zero errors -> 1, consistent with AbstractFitEngine.
            dy = numpy.array(current_data.dy, dtype=float)
            dy[dy == 0] = 1
        fn = numpy.asarray(current_data.y)[index]
        gn = reference_data.y
        en = dy[index]

    # Calculate the residual
    try:
        res = (fn - gn) / en
    except ValueError:
        print("Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en)))
        return None

    residuals = res[numpy.isfinite(res)]
    chisqr = numpy.average(residuals * residuals)

    return chisqr
240
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data1D sets.

    residual = -(current_data.y - reference_data.y) / dy, placed on
    current_data.x. No weighting is applied yet, so dy is currently
    always 1.

    :param reference_data: Data1D providing the reference y values
    :param current_data: Data1D providing the compared y values, the x grid
        and (for future weighting) dy
    :return: Data1D with the residuals and unit errors
    :raises ValueError: if the two y arrays cannot be broadcast together
        (e.g. different numbers of points)
    """
    # temporary default value for weight; weighting is not wired in yet
    weight = None

    # 'is None' / len() instead of '== None' / '== []': those comparisons are
    # elementwise (hence ambiguous in 'if') for numpy arrays.
    has_dy = current_data.dy is not None and len(current_data.dy) > 0
    if has_dy and weight is not None:
        # Copy so the caller's weight array is not modified in place
        dy = numpy.array(weight, dtype=float)
        dy[dy == 0] = 1
    else:
        dy = numpy.ones(len(current_data.y))

    fn = numpy.asarray(current_data.y)
    gn = numpy.asarray(reference_data.y)

    # build residuals
    residuals = Data1D()
    try:
        residuals.y = -(fn - gn) / dy
    except ValueError:
        msg = "ResidualPlot Error: different # of data points in theory"
        print(msg)
        raise
    residuals.x = numpy.asarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For later scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
282
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data2D sets.

    residual = (current_data.data - reference_data.data) / err, where err
    is current_data.err_data (weighting is not wired in yet). Q arrays and
    the mask are taken from current_data.

    :param reference_data: Data2D providing the reference values
    :param current_data: Data2D providing the compared values, errors, Q
        arrays and mask; clone_without_data is called on it
    :return: Data2D of residuals with unit errors, or None when the
        residual array and current_data.q_data differ in length
    """
    # temporary default value for weight; weighting is not wired in yet
    weight = None

    # build residuals
    residuals = Data2D()
    current_data.clone_without_data(len(current_data.data), residuals)
    fn = current_data.data
    gn = reference_data.data
    en = current_data.err_data if weight is None else weight
    residuals.data = (fn - gn) / en
    residuals.q_data = current_data.q_data
    # check the lengths before doing any further work
    if len(residuals.data) != len(residuals.q_data):
        return None

    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.err_data = numpy.ones(len(residuals.data))
    residuals.xmin = min(residuals.qx_data)
    residuals.xmax = max(residuals.qx_data)
    residuals.ymin = min(residuals.qy_data)
    residuals.ymax = max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    return residuals
318
def plotResiduals(reference_data, current_data):
    """
    Create Data1D/Data2D with residuals, ready for plotting.

    The residual builder is chosen by the class name of current_data
    ("Data1D" or "Data2D") and is given a deep copy of current_data.

    :param reference_data: data set whose filename, id and group_id label
        the result
    :param current_data: Data1D/Data2D whose name's first word labels the
        result
    :return: residuals named "Residuals for <name>[<filename>]", or None
        when the builder could not produce residuals
    :raises KeyError: if current_data's class is not named Data1D or Data2D
    """
    # Pass a copy so current_data itself is never handed to the builders
    data_copy = deepcopy(current_data)
    # Get data: data I, theory I, and data dI in order
    method_name = current_data.__class__.__name__
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}

    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        # e.g. residualsData2D returns None on a length mismatch
        return None

    theory_name = str(current_data.name.split()[0])
    residuals.name = "Residuals for " + theory_name + "[" + \
                     str(reference_data.filename) + "]"
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # when 2 data have the same id override the 1st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id)  # + str(theory_name)
    # group_id specifies on which panel to plot this data
    residuals.group_id = "res" + str(reference_data.group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
349
350
def binary_encode(i, digits):
    """
    Return the lowest 'digits' bits of integer 'i', least significant first.

    e.g. binary_encode(6, 4) -> [0, 1, 1, 0]
    """
    # range (not xrange) so this also runs under Python 3
    return [(i >> d) & 1 for d in range(digits)]
353
Note: See TracBrowser for help on using the repository browser.