source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ f182f93

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since f182f93 was f182f93, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Fitting connected. Initial prototype

  • Property mode set to 100755
File size: 12.7 KB
Line 
1from PyQt4 import QtGui
2from PyQt4 import QtCore
3
4import numpy
5from copy import deepcopy
6
7from sas.sasgui.guiframe.dataFitting import Data1D
8from sas.sasgui.guiframe.dataFitting import Data2D
9
def replaceShellName(param_name, value):
    """
    Rename a multishell parameter, e.g. "sld[n_shells]" -> "sld3".

    Everything from the first '[' onward is dropped and str(value) is
    appended in its place. 'param_name' must contain a '['.
    """
    assert '[' in param_name
    stem = param_name.partition('[')[0]
    return stem + str(value)
16
def getIterParams(model):
    """
    Collect the multishell (vector) parameters of 'model'.

    A parameter counts as multishell when its name contains '['.
    The result keeps the order of model.iq_parameters.
    """
    return [param for param in model.iq_parameters if "[" in param.name]
22
def getMultiplicity(model):
    """
    Describe the multishell structure of 'model'.

    Returns a (name, length) tuple taken from the first multishell parameter:
    the name of the parameter controlling the number of shells and that
    parameter's 'length'. Without multishell parameters, ("", 0) is returned.
    """
    iter_params = getIterParams(model)
    if not iter_params:
        return ("", 0)

    first = iter_params[0]
    shell_count = first.length
    control_name = first.length_control
    # No explicit control parameter: fall back to the bare name,
    # e.g. "sld[n]" -> "sld".
    if control_name is None and '[' in first.name:
        control_name = first.name[:first.name.index('[')]
    return (control_name, shell_count)
37
def addParametersToModel(parameters, model):
    """
    Append one row per regular sasmodel parameter to 'model'.

    Each row is [name (checkable), default value, min, max, units].
    The shell-count control parameter and the "<name>[n]" multishell
    parameters are skipped; per-shell rows are added by addShellsToModel.
    Polydisperse parameters get a "Polydispersity" child row, which holds a
    "Distribution" row built from the form_volume_parameters entry of the
    same name (if one exists).

    :param parameters: object exposing 'iq_parameters' and
        'form_volume_parameters' (sasmodel parameter descriptions)
    :param model: QStandardItemModel receiving the rows
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)

    for param in parameters.iq_parameters:
        # Neither the shell counter nor the "<name>[n]" vector parameters
        # are listed as plain rows.
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)

        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            # First form-volume parameter sharing this name, if any.
            volume_param = next(
                (p for p in parameters.form_volume_parameters
                 if p.name == param.name),
                None)
            if volume_param is not None:
                poly_item.appendRow([
                    QtGui.QStandardItem("Distribution"),
                    QtGui.QStandardItem(str(volume_param.default)),
                    QtGui.QStandardItem(str(volume_param.limits[0])),
                    QtGui.QStandardItem(str(volume_param.limits[1])),
                    QtGui.QStandardItem(volume_param.units),
                ])
            name_item.appendRow([poly_item])

        # TODO: the error column. Either add a proxy model or a custom
        # view delegate.
        model.appendRow([
            name_item,
            QtGui.QStandardItem(str(param.default)),
            QtGui.QStandardItem(str(param.limits[0])),
            QtGui.QStandardItem(str(param.limits[1])),
            QtGui.QStandardItem(param.units),
        ])
83
def addSimpleParametersToModel(parameters, model):
    """
    Append one row per sasmodel parameter to 'model', with no special
    handling of multishell or polydisperse parameters.

    Each row is [name (checkable), default value, min, max, units].

    :param parameters: object exposing 'iq_parameters'
    :param model: QStandardItemModel receiving the rows
    """
    for param in parameters.iq_parameters:
        name_item = QtGui.QStandardItem(param.name)
        name_item.setCheckable(True)
        # TODO: the error column. Either add a proxy model or a custom
        # view delegate.
        model.appendRow([
            name_item,
            QtGui.QStandardItem(str(param.default)),
            QtGui.QStandardItem(str(param.limits[0])),
            QtGui.QStandardItem(str(param.limits[1])),
            QtGui.QStandardItem(param.units),
        ])
102
def addCheckedListToModel(model, param_list):
    """
    Append 'param_list' to 'model' as a single row of QStandardItems,
    making the first item of the row checkable.

    An empty 'param_list' raises IndexError.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = [QtGui.QStandardItem(text) for text in param_list]
    first_item = row[0]
    first_item.setCheckable(True)
    model.appendRow(row)
111
def addHeadersToModel(model):
    """
    Label the horizontal headers of a plain parameter table:
    Parameter | Value | Min | Max | [Units].
    """
    labels = ("Parameter", "Value", "Min", "Max", "[Units]")
    for column, label in enumerate(labels):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(label))
121
def addErrorHeadersToModel(model):
    """
    Label the horizontal headers of a parameter table with an error column:
    Parameter | Value | Error | Min | Max | [Units].
    """
    labels = ("Parameter", "Value", "Error", "Min", "Max", "[Units]")
    for column, label in enumerate(labels):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(label))
132
def addPolyHeadersToModel(model):
    """
    Label the horizontal headers of the polydispersity table:
    Parameter | PD[ratio] | Min | Max | Npts | Nsigs | Function.
    """
    labels = ("Parameter", "PD[ratio]", "Min", "Max", "Npts", "Nsigs", "Function")
    for column, label in enumerate(labels):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(label))
144
def addShellsToModel(parameters, model, index):
    """
    Append rows for the first 'index' shells of every multishell parameter.

    Shells are numbered from 1 and rows are grouped by shell: with index=2
    and multishell parameters "sld[n]" and "thickness[n]" the new rows are
    sld1, thickness1, sld2, thickness2. Each row is
    [name (checkable), default value, min, max, units]; polydisperse
    parameters also get a "Polydispersity" child row.
    """
    shell_params = getIterParams(parameters)

    for shell_number in range(1, index + 1):
        for par in shell_params:
            name_item = QtGui.QStandardItem(replaceShellName(par.name, shell_number))
            name_item.setCheckable(True)

            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                distribution = QtGui.QStandardItem("Distribution")
                # First form-volume parameter sharing this name, if any.
                volume_par = next(
                    (p for p in parameters.form_volume_parameters
                     if p.name == par.name),
                    None)
                if volume_par is not None:
                    poly_item.appendRow([
                        distribution,
                        QtGui.QStandardItem(str(volume_par.default)),
                        QtGui.QStandardItem(str(volume_par.limits[0])),
                        QtGui.QStandardItem(str(volume_par.limits[1])),
                        QtGui.QStandardItem(volume_par.units),
                    ])
                name_item.appendRow([poly_item])

            model.appendRow([
                name_item,
                QtGui.QStandardItem(str(par.default)),
                QtGui.QStandardItem(str(par.limits[0])),
                QtGui.QStandardItem(str(par.limits[1])),
                QtGui.QStandardItem(par.units),
            ])
178
def calculateChi2(reference_data, current_data):
    """
    Calculate the chi2 value between two data sets.

    Residuals (current - reference) / error are formed point by point and
    the mean of the squared finite residuals is returned.

    * Data2D (decided by the type of reference_data): errors come from
      current_data.err_data; points with zero error or a non-finite
      current_data.data value are skipped.
    * Otherwise (1D): errors come from current_data.dy, or are all ones
      when dy is missing or empty; zero errors are replaced by 1.

    :param reference_data: data set to compare against, or None
    :param current_data: data set of the same dimensionality
    :return: the chi2 value, or None if reference_data is None or the
        arrays have mismatched lengths.
    """
    # TODO: user-selected weighting, as in the old GUI:
    #   from sas.sasgui.perspectives.fitting.utils import get_weight
    #   weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
    if reference_data is None:
        # Nothing to compare against.
        return None

    if isinstance(reference_data, Data2D):
        # Skip zero-error points and non-finite values.
        index = (current_data.err_data != 0) & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        fn = numpy.asarray(current_data.y)
        gn = reference_data.y
        # 'is None' / len() instead of '== None' / '== []': the latter are
        # elementwise (and ambiguous in a boolean context) for numpy arrays.
        if current_data.dy is None or len(current_data.dy) == 0:
            en = numpy.ones(len(fn))
        else:
            # numpy.array copies, so the caller's dy is not modified.
            # Replacing zero errors by 1 keeps this consistent with
            # AbstractFitengine.
            en = numpy.array(current_data.dy, dtype=float)
            en[en == 0] = 1

    try:
        res = (fn - gn) / en
    except ValueError:
        print("Chi2 calculations: Unmatched lengths %s, %s, %s"
              % (len(fn), len(gn), len(en)))
        return None

    residuals = res[numpy.isfinite(res)]
    return numpy.average(residuals * residuals)
235
def residualsData1D(reference_data, current_data):
    """
    Build a Data1D holding the residuals reference.y - current.y of two
    1D data sets, plotted against current_data.x.

    NOTE(review): the residuals are divided by unit errors only. The
    previous implementation also ended up with all-ones errors in every
    branch, even though the y axis is labelled "normalized". Confirm whether
    the data errors should be applied here.

    :raises ValueError: if the two sets have a different number of points.
    """
    fn = numpy.asarray(current_data.y)
    gn = numpy.asarray(reference_data.y)
    # Unit errors: no weighting is applied yet.
    en = numpy.ones(len(fn))

    # Compute the values before building the container, so a size mismatch
    # is reported without creating a half-filled Data1D.
    try:
        residual_y = (gn - fn) / en
    except ValueError:
        msg = ("ResidualPlot Error: different # of data points in theory "
               "(%d theory vs %d data)" % (len(fn), len(gn)))
        raise ValueError(msg)

    residuals = Data1D()
    residuals.y = residual_y
    residuals.x = numpy.asarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For later scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
277
def residualsData2D(reference_data, current_data):
    """
    Build a Data2D holding the residuals
    (current.data - reference.data) / current.err_data of two 2D data sets.

    The q arrays and mask are taken from current_data, and the residual
    errors are set to one.

    :return: the residuals Data2D, or None if the number of residuals does
        not match len(current_data.q_data).
    """
    residuals = Data2D()
    # NOTE(review): clone_without_data presumably fills 'residuals' in place
    # (its return value is not used) - confirm against the Data2D API.
    current_data.clone_without_data(len(current_data.data), residuals)

    # TODO: apply user-selected weighting instead of the raw errors.
    en = current_data.err_data
    residuals.data = (current_data.data - reference_data.data) / en

    # Check sizes before doing any further work.
    if len(residuals.data) != len(current_data.q_data):
        return None

    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    # numpy reductions instead of the builtins, which loop in Python.
    residuals.xmin = numpy.min(residuals.qx_data)
    residuals.xmax = numpy.max(residuals.qx_data)
    residuals.ymin = numpy.min(residuals.qy_data)
    residuals.ymax = numpy.max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    return residuals
313
def plotResiduals(reference_data, current_data):
    """
    Create a Data1D/Data2D with the residuals of 'current_data' against
    'reference_data', ready for plotting.

    The residual calculator is chosen by the class name of current_data
    ("Data1D" or "Data2D"). current_data is deep-copied first, so the
    caller's object is passed to the calculator only as a copy.

    :return: the residuals data set, or None if the calculator could not
        build one (residualsData2D returns None on a size mismatch).
    :raises KeyError: if current_data's class name is neither "Data1D"
        nor "Data2D".
    """
    data_copy = deepcopy(current_data)
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}
    method_name = current_data.__class__.__name__

    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        return None

    # Label as "Residuals for <first word of theory name>[<data file>]".
    name_parts = current_data.name.split() if current_data.name else []
    theory_name = str(name_parts[0]) if name_parts else ""
    residuals.name = ("Residuals for " + theory_name + "[" +
                      str(reference_data.filename) + "]")
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # When two data sets have the same id, the later one overrides the first
    # plotted. Append theory_name if separate charts per model are required.
    residuals.id = "res" + str(reference_data.id)
    # group_id specifies the panel this data is plotted on.
    residuals.group_id = "res" + str(reference_data.group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
344
345
def binary_encode(i, digits):
    """
    Return the lowest 'digits' bits of integer 'i', least-significant first.

    >>> binary_encode(6, 4)
    [0, 1, 1, 0]
    """
    return [(i >> shift) & 1 for shift in range(digits)]
348
Note: See TracBrowser for help on using the repository browser.