source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 7d077d1

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 7d077d1 was 7d077d1, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

When chart is shown - react and show changes to Qmodel

  • Property mode set to 100755
File size: 12.1 KB
Line 
1from PyQt4 import QtGui
2from PyQt4 import QtCore
3
4import numpy
5from copy import deepcopy
6
7from sas.sasgui.guiframe.dataFitting import Data1D
8from sas.sasgui.guiframe.dataFitting import Data2D
9
def replaceShellName(param_name, value):
    """
    Turn a multi-shell parameter name into a concrete shell name.

    Everything from the first '[' onwards is replaced by str(value),
    e.g. "sld[n]" with value 2 becomes "sld2".

    :param param_name: parameter name; must contain '['
    :param value: shell number (or any object) appended to the base name
    :return: the rewritten parameter name
    :raises AssertionError: if param_name contains no '['
    """
    base_name, bracket, _ = param_name.partition('[')
    assert bracket
    return base_name + str(value)
16
def getIterParams(model):
    """
    Collect the multi-shell parameters of 'model'.

    A parameter counts as multi-shell when its name contains '['
    (e.g. "sld[n]"). The result keeps the order of model.iq_parameters.

    :param model: object exposing 'iq_parameters' (presumably a sasmodels
        parameter table - confirm against callers)
    :return: list of the matching parameter objects
    """
    return [par for par in model.iq_parameters if "[" in par.name]
22
def getMultiplicity(model):
    """
    Report the shell-counter parameter of 'model' and its shell count.

    Only the first multi-shell parameter (see getIterParams) is inspected.
    The counter name is its 'length_control'; when that is None, the
    parameter's own name up to the '[' is used instead.

    :param model: object exposing 'iq_parameters'
    :return: (counter_name, length) tuple; ("", 0) if 'model' has no
        multi-shell parameters
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        control_name = first.name[:first.name.index('[')]
    return (control_name, first.length)
37
def addParametersToModel(parameters, model):
    """
    Update local ModelModel with sasmodel parameters

    Appends one row per entry of parameters.iq_parameters to 'model':
    [name (checkable), default, min, max, units].

    Parameters whose names contain '[' are skipped, and so is the
    shell-counter parameter reported by getMultiplicity. Per-shell rows are
    added by addShellsToModel.

    A polydisperse parameter also gets a "Polydispersity" child row. That
    row is filled from the form_volume_parameters entry with the same name,
    when such an entry exists.
    """
    shell_params = getIterParams(parameters)
    counter_name, _ = getMultiplicity(parameters)

    for param in parameters.iq_parameters:
        # The shell counter and the per-shell "<name>[n]" parameters are not listed here
        if param.name == counter_name or param in shell_params:
            continue

        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)

        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            dist_item = QtGui.QStandardItem("Distribution")
            volume_param = next((vp for vp in parameters.form_volume_parameters
                                 if vp.name == param.name), None)
            if volume_param is not None:
                poly_item.appendRow([dist_item,
                                     QtGui.QStandardItem(str(volume_param.default)),
                                     QtGui.QStandardItem(str(volume_param.limits[0])),
                                     QtGui.QStandardItem(str(volume_param.limits[1])),
                                     QtGui.QStandardItem(volume_param.units)])
            # The polydispersity item is attached as a child of the parameter row
            name_item.appendRow([poly_item])

        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        model.appendRow([name_item,
                         QtGui.QStandardItem(str(param.default)),
                         QtGui.QStandardItem(str(param.limits[0])),
                         QtGui.QStandardItem(str(param.limits[1])),
                         QtGui.QStandardItem(param.units)])
82
def addSimpleParametersToModel(parameters, model):
    """
    Update local ModelModel with sasmodel parameters

    Appends one flat row per entry of parameters.iq_parameters:
    [name (checkable), default, min, max, units]. Unlike
    addParametersToModel, no parameter is skipped and no polydispersity
    child rows are created.
    """
    for param in parameters.iq_parameters:
        # Top-level, checkable name cell
        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)

        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        number_texts = (str(param.default), str(param.limits[0]), str(param.limits[1]))
        number_items = [QtGui.QStandardItem(text) for text in number_texts]
        units_item = QtGui.QStandardItem(param.units)

        model.appendRow([name_item] + number_items + [units_item])
101
def addCheckedListToModel(model, param_list):
    """
    Append 'param_list' to 'model' as a single row of QStandardItems.

    Only the first cell of the row is made checkable.

    :param model: QStandardItemModel that receives the row
    :param param_list: non-empty sequence of cell texts (an empty sequence
        raises IndexError)
    :raises AssertionError: if 'model' is not a QStandardItemModel
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    first_cell = row[0]
    first_cell.setCheckable(True)
    model.appendRow(row)
110
def addHeadersToModel(model):
    """
    Set the horizontal header labels of the main parameter table.

    Columns 0-4 become: Parameter, Value, Min, Max, [Units].
    """
    header_labels = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, label in enumerate(header_labels):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(label))
120
def addPolyHeadersToModel(model):
    """
    Set the horizontal header labels of the polydispersity table.

    Columns 0-6 become: Parameter, PD[ratio], Min, Max, Npts, Nsigs, Function.
    """
    header_labels = ("Parameter", "PD[ratio]", "Min", "Max",
                     "Npts", "Nsigs", "Function")
    for column, label in enumerate(header_labels):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(label))
132
def addShellsToModel(parameters, model, index):
    """
    Append rows for the first 'index' shells of every multi-shell parameter.

    For each shell i = 1..index, and for each parameter whose name contains
    '[' (see getIterParams), a row [name (checkable), default, min, max,
    units] is appended. The bracketed suffix of the name is replaced by the
    shell number, e.g. "sld[n]" -> "sld1", "sld2", ...

    Rows are ordered shell by shell. Polydisperse parameters get a
    "Polydispersity" child row filled from the matching
    form_volume_parameters entry, if there is one.

    :param parameters: object exposing 'iq_parameters' and
        'form_volume_parameters'
    :param model: model receiving the rows (QStandardItemModel-like)
    :param index: number of shells to add
    """
    multishell_parameters = getIterParams(parameters)

    # range instead of the Python-2-only xrange; the sequence is only iterated
    for i in range(index):
        for par in multishell_parameters:
            # Create the name: "sld[n]" -> "sld1" for the first shell, "sld2" for the second, ...
            param_name = replaceShellName(par.name, i+1)
            item1 = QtGui.QStandardItem(param_name)
            item1.setCheckable(True)
            # check for polydisp params
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                item1_1 = QtGui.QStandardItem("Distribution")
                # Find param in volume_params
                for p in parameters.form_volume_parameters:
                    if p.name != par.name:
                        continue
                    item1_2 = QtGui.QStandardItem(str(p.default))
                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
                    item1_5 = QtGui.QStandardItem(p.units)
                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                    break
                item1.appendRow([poly_item])

            item2 = QtGui.QStandardItem(str(par.default))
            item3 = QtGui.QStandardItem(str(par.limits[0]))
            item4 = QtGui.QStandardItem(str(par.limits[1]))
            item5 = QtGui.QStandardItem(par.units)
            model.appendRow([item1, item2, item3, item4, item5])
166
def calculateChi2(reference_data, current_data):
    """
    Calculate Chi2 value between two sets of data

    The value is the mean of the squared normalised residuals
    ((current - reference) / error) ** 2, taken over the points where that
    residual is finite.

    2D (reference_data is a Data2D): the error is current_data.err_data.
    Points with zero error or non-finite current_data.data are dropped.

    1D: the error is current_data.dy. All-ones is used when dy is missing
    or empty, and zero entries are replaced by 1.

    :param reference_data: data set the residuals are taken against
    :param current_data: data set of the same dimensionality whose errors
        normalise the residuals
    :return: the chi2 value, or None when reference_data is None or the
        arrays cannot be combined (mismatched lengths)
    """

    # WEIGHING INPUT (not implemented yet)
    #from sas.sasgui.perspectives.fitting.utils import get_weight
    #flag = self.get_weight_flag()
    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)

    if reference_data is None:
        # Nothing to compare against (no chi2 has been computed yet)
        return None

    # temporary default values for index and weight
    index = None
    weight = None

    # Get data: data I, theory I, and data dI in order
    if isinstance(reference_data, Data2D):
        if index is None:
            index = numpy.ones(len(current_data.data), dtype=bool)
        if weight is not None:
            current_data.err_data = weight
        # get rid of zero error points and non-finite data
        index = index & (current_data.err_data != 0)
        index = index & (numpy.isfinite(current_data.data))
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        # 1 d theory from model_thread is only in the range of index
        if index is None:
            index = numpy.ones(len(current_data.y), dtype=bool)
        if weight is not None:
            current_data.dy = weight
        # Explicit None/empty test: '== None' / '== []' is ambiguous or
        # raises when dy is a numpy array
        if current_data.dy is None or len(current_data.dy) == 0:
            dy = numpy.ones(len(current_data.y))
        else:
            ## Set consistently w/AbstractFitengine:
            # But this should be corrected later.
            # numpy.array copies (current_data.dy stays untouched) and makes
            # the boolean-mask assignment below valid for plain lists too.
            dy = numpy.array(current_data.dy)
            dy[dy == 0] = 1
        fn = numpy.asarray(current_data.y)[index]
        gn = reference_data.y
        en = dy[index]
    # Calculate the residual
    try:
        res = (fn - gn) / en
    except ValueError:
        print("Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en)))
        return None

    residuals = res[numpy.isfinite(res)]
    chisqr = numpy.average(residuals * residuals)

    return chisqr
223
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data1D sets

    Builds a Data1D on current_data's x grid with
    y = (reference_data.y - current_data.y) / dy.

    dy is all ones unless a weight is supplied (weighting is not wired in
    yet, see 'weight' below) and current_data carries uncertainties.

    :return: the residual Data1D
    :raises ValueError: if the two y arrays cannot be combined, e.g. they
        have different numbers of points
    """
    # temporary default values for index and weight
    index = None
    weight = None

    n_points = len(current_data.y)
    if index is None:
        # 1d theory from model_thread is only in the range of index;
        # until a real mask is supplied, every point is selected
        index = numpy.ones(n_points, dtype=bool)

    # Test dy explicitly: 'dy == None or dy == []' is ambiguous (or raises)
    # for numpy arrays
    has_dy = current_data.dy is not None and len(current_data.dy) > 0
    if has_dy and weight is not None:
        dy = numpy.array(weight)  # copy, so the zero fix leaves 'weight' alone
        dy[dy == 0] = 1
    else:
        dy = numpy.ones(n_points)

    fn = numpy.asarray(current_data.y)[index]
    gn = numpy.asarray(reference_data.y)
    en = dy[index]

    # build residuals: (reference - current) / dy
    residuals = Data1D()
    try:
        residuals.y = -((fn - gn) / en)
    except ValueError:
        # index selects every point, so there is no trimmed fallback to try
        msg = "ResidualPlot Error: different # of data points in theory"
        raise ValueError("%s (%d vs %d)" % (msg, len(fn), len(gn)))
    residuals.x = numpy.asarray(current_data.x)[index]
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
265
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data2D sets

    Builds a Data2D whose metadata is cloned from current_data (via
    clone_without_data) and whose data is
    (current_data.data - reference_data.data) / err.

    err is current_data.err_data unless a weight is supplied (weighting is
    not wired in yet). The q arrays and mask are taken from current_data,
    the error array is all ones, and the x/y ranges come from qx/qy.

    :return: the residual Data2D, or None when the residual array length
        differs from current_data.q_data
    """
    # temporary default value for weight
    weight = None

    # build residuals
    residuals = Data2D()
    current_data.clone_without_data(len(current_data.data), residuals)
    fn = current_data.data
    gn = reference_data.data
    en = current_data.err_data if weight is None else weight
    residuals.data = (fn - gn) / en
    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    # numpy reductions run in C; builtin min/max iterate element by element
    residuals.xmin = numpy.min(residuals.qx_data)
    residuals.xmax = numpy.max(residuals.qx_data)
    residuals.ymin = numpy.min(residuals.qy_data)
    residuals.ymax = numpy.max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    # check the lengths
    if len(residuals.data) != len(residuals.q_data):
        return None
    return residuals
301
def plotResiduals(reference_data, current_data):
    """
    Create Data1D/Data2D with residuals, ready for plotting

    Dispatches on the class name of current_data ("Data1D" or "Data2D").
    The residual builder works on a deep copy, so current_data itself is
    not handed to it.

    The result is named after the first word of current_data.name and
    reference_data.filename. Its id and group_id are derived from
    reference_data, so residuals for the same reference share one plot.

    :return: the residual data set, or None if the residuals could not be
        built (residualsData2D returns None on a length mismatch)
    :raises KeyError: if current_data's class is neither Data1D nor Data2D
    """
    data_copy = deepcopy(current_data)
    # Get data: data I, theory I, and data dI in order
    method_name = current_data.__class__.__name__
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}

    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        # Propagate the failure instead of raising AttributeError below
        return None

    # First word of the theory name; tolerate an empty name
    name_parts = current_data.name.split()
    theory_name = str(name_parts[0]) if name_parts else ""
    residuals.name = "Residuals for " + theory_name + "[" + \
                     str(reference_data.filename) + "]"
    residuals.title = residuals.name
    # when 2 data have the same id override the 1 st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
    # group_id specify on which panel to plot this data
    group_id = reference_data.group_id
    residuals.group_id = "res" + str(group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
330
331
def binary_encode(i, digits):
    """
    Return the 'digits' lowest bits of integer 'i', least significant first.

    e.g. binary_encode(6, 4) -> [0, 1, 1, 0]

    :param i: integer to encode
    :param digits: number of bits to return (0 gives an empty list)
    :return: list of 0/1 ints of length 'digits'
    """
    # range rather than xrange: identical under Python 2, and xrange does not exist in Python 3
    return [(i >> bit) & 1 for bit in range(digits)]
334
Note: See TracBrowser for help on using the repository browser.