source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ d48cc19

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since d48cc19 was d48cc19, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Compute/Show Plot button logic: SASVIEW-271
Unit tests for plotting in fitting: SASVIEW-501

  • Property mode set to 100755
File size: 13.0 KB
Line 
1from PyQt4 import QtGui
2from PyQt4 import QtCore
3
4import numpy
5from copy import deepcopy
6
7from sas.sasgui.guiframe.dataFitting import Data1D
8from sas.sasgui.guiframe.dataFitting import Data2D
9
def replaceShellName(param_name, value):
    """
    Turn a multishell parameter name into its per-shell form.

    Everything from the first '[' onwards is replaced by ``str(value)``,
    e.g. ``replaceShellName("sld[n_shell]", 2)`` gives ``"sld2"``.
    An AssertionError is raised when *param_name* contains no '['.
    """
    assert '[' in param_name
    stem, _, _ = param_name.partition('[')
    return stem + str(value)
16
def getIterParams(model):
    """
    Return the multishell parameters of *model*, in their original order.

    A parameter counts as multishell when its name contains '['
    (for example ``sld[n_shell]``). Only ``model.iq_parameters`` is searched.
    """
    return [param for param in model.iq_parameters if "[" in param.name]
22
def getMultiplicity(model):
    """
    Describe the multishell structure of *model*, if it has one.

    Returns a ``(name, length)`` tuple taken from the first multishell
    parameter reported by getIterParams():
    - ``name`` is that parameter's ``length_control``. When that is None,
      it falls back to the part of the parameter name before '['.
    - ``length`` is that parameter's ``length`` attribute.

    A model without multishell parameters yields ``("", 0)``.
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        control_name = first.name.split('[', 1)[0]
    return (control_name, first.length)
37
def addParametersToModel(parameters, is2D):
    """
    Build the QStandardItem rows for a model's non-multishell parameters.

    :param parameters: sasmodels parameter table; ``iq_parameters``,
        ``iqxy_parameters`` and ``form_volume_parameters`` are read
    :param is2D: use ``iqxy_parameters`` when True, ``iq_parameters`` otherwise
    :return: list of rows ``[name, default, min, max, units]``. The name item
        is checkable. For a polydisperse parameter it also carries a
        "Polydispersity" child row.

    Two kinds of parameter are skipped:
    - the parameter named by getMultiplicity() (normally the shell-count
      control);
    - the ``<param>[n]`` multishell parameters. addShellsToModel() adds
      per-shell rows for these.
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Shell-count and <param>[n] parameters are not listed here.
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item_name = QtGui.QStandardItem(param.name)
        item_name.setCheckable(True)
        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            # The polydispersity child row is filled from the matching
            # form-volume parameter, when one exists.
            volume_param = next(
                (p for p in parameters.form_volume_parameters
                 if p.name == param.name),
                None)
            if volume_param is not None:
                poly_item.appendRow([
                    QtGui.QStandardItem("Distribution"),
                    QtGui.QStandardItem(str(volume_param.default)),
                    QtGui.QStandardItem(str(volume_param.limits[0])),
                    QtGui.QStandardItem(str(volume_param.limits[1])),
                    QtGui.QStandardItem(volume_param.units),
                ])
            item_name.appendRow([poly_item])

        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        rows.append([
            item_name,
            QtGui.QStandardItem(str(param.default)),
            QtGui.QStandardItem(str(param.limits[0])),
            QtGui.QStandardItem(str(param.limits[1])),
            QtGui.QStandardItem(param.units),
        ])
    return rows
85
def addSimpleParametersToModel(parameters, is2D):
    """
    Build one checkable QStandardItem row per model parameter.

    Every parameter in the selected list is included as-is. There is no
    multishell filtering and no polydispersity child rows.

    :param parameters: sasmodels parameter table
    :param is2D: use ``iqxy_parameters`` when True, ``iq_parameters`` otherwise
    :return: list of rows ``[name, default, min, max, units]``
    """
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        item_name = QtGui.QStandardItem(param.name)
        item_name.setCheckable(True)
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        rows.append([
            item_name,
            QtGui.QStandardItem(str(param.default)),
            QtGui.QStandardItem(str(param.limits[0])),
            QtGui.QStandardItem(str(param.limits[1])),
            QtGui.QStandardItem(param.units),
        ])
    return rows
107
def addCheckedListToModel(model, param_list):
    """
    Append a single row built from *param_list* to *model*.

    Each entry becomes a QStandardItem; only the first (leftmost) item is
    made checkable. *model* must be a QStandardItemModel (asserted), and an
    empty *param_list* raises IndexError.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_item = row[0]
    first_item.setCheckable(True)
    model.appendRow(row)
116
def addHeadersToModel(model):
    """
    Label the horizontal headers of a parameter model.

    Columns 0-4 become: Parameter, Value, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
126
def addErrorHeadersToModel(model):
    """
    Label the horizontal headers of a parameter model that has an error column.

    Columns 0-5 become: Parameter, Value, Error, Min, Max, [Units].
    """
    titles = ("Parameter", "Value", "Error", "Min", "Max", "[Units]")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
137
def addPolyHeadersToModel(model):
    """
    Label the horizontal headers of a polydispersity model.

    Columns 0-6 become: Parameter, PD[ratio], Min, Max, Npts, Nsigs, Function.
    """
    titles = ("Parameter", "PD[ratio]", "Min", "Max", "Npts", "Nsigs", "Function")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
149
def addShellsToModel(parameters, model, index):
    """
    Append per-shell parameter rows to *model*.

    For each shell number 1..index and each multishell parameter
    ``<param>[n]`` (see getIterParams()), a row
    ``[<param><shell>, default, min, max, units]`` is appended. Polydisperse
    parameters also get a "Polydispersity" child row. Rows are grouped shell
    by shell.

    :param parameters: sasmodels parameter table
    :param model: item model receiving the rows (via ``appendRow``)
    :param index: number of shells to add
    """
    multishell_parameters = getIterParams(parameters)

    # range() instead of the Python-2-only xrange() so this also runs on Python 3.
    for shell in range(1, index + 1):
        for par in multishell_parameters:
            # Create the name: <param><shell>, e.g. "sld1" for parameter "sld[n]"
            param_name = replaceShellName(par.name, shell)
            item_name = QtGui.QStandardItem(param_name)
            item_name.setCheckable(True)
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                # The polydispersity child row is filled from the matching
                # form-volume parameter, when one exists.
                volume_param = next(
                    (p for p in parameters.form_volume_parameters
                     if p.name == par.name),
                    None)
                if volume_param is not None:
                    poly_item.appendRow([
                        QtGui.QStandardItem("Distribution"),
                        QtGui.QStandardItem(str(volume_param.default)),
                        QtGui.QStandardItem(str(volume_param.limits[0])),
                        QtGui.QStandardItem(str(volume_param.limits[1])),
                        QtGui.QStandardItem(volume_param.units),
                    ])
                item_name.appendRow([poly_item])

            model.appendRow([
                item_name,
                QtGui.QStandardItem(str(par.default)),
                QtGui.QStandardItem(str(par.limits[0])),
                QtGui.QStandardItem(str(par.limits[1])),
                QtGui.QStandardItem(par.units),
            ])
183
def calculateChi2(reference_data, current_data):
    """
    Calculate a chi2 value between two data sets.

    The value is ``mean(((current - reference) / error) ** 2)`` over the
    finite residuals. The error is taken from *current_data*:
    - 2D: ``err_data``. Points with zero error or non-finite data are
      excluded.
    - 1D: ``dy``. If ``dy`` is missing or empty, ones are used; zero errors
      are replaced by 1.

    :param reference_data: Data1D/Data2D-like object, or None
    :param current_data: data set of the same dimensionality
    :return: the chi2 value, or None when *reference_data* is None or the
        arrays cannot be combined (e.g. different lengths)
    """
    # TODO: apply data weighting, cf.
    # sas.sasgui.perspectives.fitting.utils.get_weight(data, is2d, flag)

    if reference_data is None:
        # Nothing to compare against. (This used to return an unbound local.)
        return None

    # temporary default values for index and weight
    index = None
    weight = None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        if index is None:
            index = numpy.ones(len(current_data.data), dtype=bool)
        if weight is not None:
            current_data.err_data = weight
        # get rid of zero error points and non-finite data points
        index = index & (current_data.err_data != 0)
        index = index & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1D theory from model_thread is only in the range of index
        if index is None:
            index = numpy.ones(len(current_data.y), dtype=bool)
        if weight is not None:
            current_data.dy = weight
        # len() rather than "== []": comparing a numpy array to a list is
        # ambiguous or an error in recent numpy.
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            # Set consistently with AbstractFitEngine (should be corrected later).
            # An explicit float array copy leaves current_data untouched and
            # makes the boolean-mask assignment work even when dy is a list.
            dy = numpy.array(current_data.dy, dtype=float)
            dy[dy == 0] = 1
        fn = current_data.y[index]
        gn = reference_data.y
        en = dy[index]

    # Calculate the residual
    try:
        res = (fn - gn) / en
    except ValueError:
        # Unmatched lengths of fn, gn and en
        return None

    residuals = res[numpy.isfinite(res)]
    chisqr = numpy.average(residuals * residuals)
    return chisqr
240
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data1D sets.

    Each residual is ``-(current.y - reference.y) / error``. This is the
    negative of the difference used by residualsData2D.

    No weighting is wired in yet (``weight`` is None), so ``error`` is
    currently 1 at every point.

    :param reference_data: Data1D-like object providing ``y``
    :param current_data: Data1D-like object providing ``x``, ``y`` and ``dy``
    :return: new Data1D holding ``x`` from *current_data*, the residuals as
        ``y`` and unit ``dy``
    :raises ValueError: if the two sets have a different number of points
    """
    # temporary default value for weight
    weight = None

    current_y = numpy.asarray(current_data.y)
    reference_y = numpy.asarray(reference_data.y)

    # len() rather than "== []": comparing a numpy array to a list is
    # ambiguous or an error in recent numpy.
    has_dy = current_data.dy is not None and len(current_data.dy) > 0
    if has_dy and weight is not None:
        errors = numpy.array(weight, dtype=float)
        # avoid division by zero
        errors[errors == 0] = 1
    else:
        errors = numpy.ones(len(current_y))

    if len(current_y) != len(reference_y):
        # TODO: fix case where applying new data from file on top of existing model data
        raise ValueError(
            "ResidualPlot Error: different number of data points: "
            "%d (current) vs %d (reference)" % (len(current_y), len(reference_y)))

    # build residuals
    residuals = Data1D()
    residuals.y = -((current_y - reference_y) / errors)
    residuals.x = numpy.asarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
287
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data2D sets.

    Each residual is ``(current.data - reference.data) / error``. ``error``
    is ``current_data.err_data``; no weighting is wired in yet.

    :param reference_data: Data2D-like object providing ``data``
    :param current_data: Data2D-like object; its metadata (via
        ``clone_without_data``), q arrays and mask are carried over
    :return: new Data2D holding the residuals and unit errors, or None
        when the residual array and ``q_data`` differ in length
    """
    # temporary default value for weight
    weight = None

    # Residuals object with current_data's metadata but none of its data
    residuals = Data2D()
    current_data.clone_without_data(len(current_data.data), residuals)

    errors = current_data.err_data if weight is None else weight
    residuals.data = (current_data.data - reference_data.data) / errors
    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    # numpy reductions instead of builtin min/max, which iterate in Python
    residuals.xmin = numpy.min(residuals.qx_data)
    residuals.xmax = numpy.max(residuals.qx_data)
    residuals.ymin = numpy.min(residuals.qy_data)
    residuals.ymax = numpy.max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    # check the lengths
    if len(residuals.data) != len(residuals.q_data):
        return None
    return residuals
320
def plotResiduals(reference_data, current_data):
    """
    Create Data1D/Data2D with residuals, ready for plotting.

    The residual function is chosen by the class name of *current_data*
    ("Data1D" or "Data2D"). It runs on a deep copy, so the caller's object
    is not modified.

    The result is named after the first word of ``current_data.name`` and
    ``reference_data.filename``. Its ``id`` and ``group_id`` are derived
    from *reference_data* with a "res" prefix.

    :return: the residuals data object, or None when the residual
        computation raises ValueError or returns None (e.g. the two sets
        have a different number of points)
    :raises KeyError: if *current_data* is neither a Data1D nor a Data2D
    """
    data_copy = deepcopy(current_data)
    # Get data: data I, theory I, and data dI in order
    method_name = current_data.__class__.__name__
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}
    compute_residuals = residuals_dict[method_name]

    try:
        residuals = compute_residuals(reference_data, data_copy)
    except ValueError:
        # The data sets cannot be combined point by point.
        return None
    if residuals is None:
        return None

    theory_name = str(current_data.name.split()[0])
    residuals.name = "Residuals for " + theory_name + "[" + \
                    str(reference_data.filename) + "]"
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # when 2 data have the same id override the 1 st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
    # group_id specify on which panel to plot this data
    group_id = reference_data.group_id
    residuals.group_id = "res" + str(group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
351
352
def binary_encode(i, digits):
    """
    Return the lowest *digits* bits of the integer *i*, least significant first.

    >>> binary_encode(6, 4)
    [0, 1, 1, 0]
    """
    # range() rather than the Python-2-only xrange() so this runs on Python 3 too.
    return [i >> d & 1 for d in range(digits)]
355
Note: See TracBrowser for help on using the repository browser.