source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ c1e380e

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since c1e380e was 6964d44, checked in by Piotr Rozyczko <rozyczko@…>, 8 years ago

Minor fixes in fitpage

  • Property mode set to 100755
File size: 13.6 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
def replaceShellName(param_name, value):
    """
    Turn a multishell parameter name into the name of one concrete shell.

    Everything from the first '[' onwards is replaced by str(value),
    e.g. ("sld[n]", 2) -> "sld2".
    """
    assert '[' in param_name
    stem = param_name.partition('[')[0]
    return stem + str(value)
17
def getIterParams(model):
    """
    Return the multishell parameters of 'model', i.e. the entries of
    model.iq_parameters whose name contains '[', in their original order.
    """
    return [parameter for parameter in model.iq_parameters
            if "[" in parameter.name]
23
def getMultiplicity(model):
    """
    Describe the multishell structure of 'model'.

    Only the first multishell parameter (see getIterParams) is inspected.
    Its length_control attribute names the shell-count parameter; when that
    is None, the part of the parameter's own name before '[' is used instead.

    :return: tuple (counter_name, number_of_shells), where number_of_shells
             is the first multishell parameter's 'length';
             ("", 0) when the model has no multishell parameters
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    counter_name = first.length_control
    if counter_name is None and '[' in first.name:
        counter_name = first.name.split('[', 1)[0]
    return (counter_name, first.length)
38
def addParametersToModel(parameters, is2D):
    """
    Build table rows for the ordinary (non-multishell) parameters of a model.

    Two kinds of parameter are skipped here:
      * the shell-count parameter whose name getMultiplicity reports, and
      * every multishell parameter returned by getIterParams
        (names containing '[', e.g. "sld[n]").
    Rows for the individual shells are built by addShellsToModel.

    :param parameters: parameter table exposing iq_parameters,
        iqxy_parameters and form_volume_parameters
        (presumably a sasmodels ParameterTable - TODO confirm)
    :param is2D: read iqxy_parameters when True, iq_parameters otherwise
    :return: list of rows, each [name, value, min, max, units] as
        QStandardItems. The name item is checkable and read-only, and so
        is the units item. A polydisperse parameter's name item gets a
        "Polydispersity" child row.
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # The shell counter and the per-shell parameters are not listed as
        # ordinary parameters.
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        item1.setEditable(False)
        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            poly_item.setEditable(False)
            item1_1 = QtGui.QStandardItem("Distribution")
            item1_1.setEditable(False)
            # The child row shows the matching form-volume parameter's
            # default/limits/units. If there is no match, "Polydispersity"
            # gets no child row.
            volume_param = next((p for p in parameters.form_volume_parameters
                                 if p.name == param.name), None)
            if volume_param is not None:
                item1_2 = QtGui.QStandardItem(str(volume_param.default))
                item1_2.setEditable(False)
                item1_3 = QtGui.QStandardItem(str(volume_param.limits[0]))
                item1_3.setEditable(False)
                item1_4 = QtGui.QStandardItem(str(volume_param.limits[1]))
                item1_4.setEditable(False)
                item1_5 = QtGui.QStandardItem(volume_param.units)
                item1_5.setEditable(False)
                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
            item1.appendRow([poly_item])

        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        item2 = QtGui.QStandardItem(str(param.default))
        item3 = QtGui.QStandardItem(str(param.limits[0]))
        item4 = QtGui.QStandardItem(str(param.limits[1]))
        item5 = QtGui.QStandardItem(param.units)
        item5.setEditable(False)
        rows.append([item1, item2, item3, item4, item5])
    return rows
94
def addSimpleParametersToModel(parameters, is2D):
    """
    Build one table row per model parameter.

    This handles neither multishell nor polydispersity parameters.

    :param parameters: object exposing iq_parameters / iqxy_parameters
        (presumably a sasmodels parameter table - TODO confirm)
    :param is2D: take iqxy_parameters when True, iq_parameters otherwise
    :return: list of rows, each [name, default, min, max, units] as
        QStandardItems. The name cell is checkable; the name and units
        cells are made read-only.
    """
    def _make_row(param):
        # Top-level, checkable name cell
        name_cell = QtGui.QStandardItem(param.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cell = QtGui.QStandardItem(str(param.default))
        min_cell = QtGui.QStandardItem(str(param.limits[0]))
        max_cell = QtGui.QStandardItem(str(param.limits[1]))
        units_cell = QtGui.QStandardItem(param.units)
        units_cell.setEditable(False)
        return [name_cell, value_cell, min_cell, max_cell, units_cell]

    source = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    return [_make_row(param) for param in source]
116
def addCheckedListToModel(model, param_list):
    """
    Append one row to 'model', with one QStandardItem per entry of
    param_list. Only the first item of the row is made checkable.
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    row = list(map(QtGui.QStandardItem, param_list))
    row[0].setCheckable(True)
    model.appendRow(row)
125
def addHeadersToModel(model):
    """
    Label the horizontal headers of a five-column parameter table:
    Parameter | Value | Min | Max | Units
    """
    titles = ("Parameter", "Value", "Min", "Max", "Units")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
135
def addErrorHeadersToModel(model):
    """
    Label the horizontal headers of a six-column parameter table that also
    shows fit errors:
    Parameter | Value | Error | Min | Max | Units
    """
    titles = ("Parameter", "Value", "Error", "Min", "Max", "Units")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
146
def addPolyHeadersToModel(model):
    """
    Label the horizontal headers of the seven-column polydispersity table:
    Parameter | PD[ratio] | Min | Max | Npts | Nsigs | Function
    """
    titles = ("Parameter", "PD[ratio]", "Min", "Max", "Npts", "Nsigs", "Function")
    for column, title in enumerate(titles):
        model.setHeaderData(column, QtCore.Qt.Horizontal, QtCore.QVariant(title))
158
def addShellsToModel(parameters, model, index):
    """
    Append rows for the first 'index' shells of every multishell parameter.

    For shell i = 1..index and for each parameter returned by getIterParams
    (e.g. "sld[n]"), one row named "<stem><i>" (e.g. "sld1") is appended to
    'model' as [name, default, min, max, units].

    A polydisperse parameter's name item gets a "Polydispersity" child row.
    That row is filled from the matching form-volume parameter.

    :param parameters: parameter table exposing iq_parameters and
        form_volume_parameters (presumably a sasmodels ParameterTable -
        TODO confirm)
    :param model: item model the rows are appended to (QStandardItemModel)
    :param index: number of shells to add
    """
    multishell_parameters = getIterParams(parameters)

    # range(), not the Python-2-only xrange(), so this also runs on Python 3
    for i in range(index):
        for par in multishell_parameters:
            # Create the name: <param><i>, e.g. "sld1" for parameter "sld[n]"
            param_name = replaceShellName(par.name, i+1)
            item1 = QtGui.QStandardItem(param_name)
            item1.setCheckable(True)
            # Name, polydispersity and units cells are read-only, matching
            # the ordinary parameter rows built by addParametersToModel.
            item1.setEditable(False)
            # check for polydisp params
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                poly_item.setEditable(False)
                item1_1 = QtGui.QStandardItem("Distribution")
                item1_1.setEditable(False)
                # Find param in volume_params
                volume_par = next((p for p in parameters.form_volume_parameters
                                   if p.name == par.name), None)
                if volume_par is not None:
                    item1_2 = QtGui.QStandardItem(str(volume_par.default))
                    item1_2.setEditable(False)
                    item1_3 = QtGui.QStandardItem(str(volume_par.limits[0]))
                    item1_3.setEditable(False)
                    item1_4 = QtGui.QStandardItem(str(volume_par.limits[1]))
                    item1_4.setEditable(False)
                    item1_5 = QtGui.QStandardItem(volume_par.units)
                    item1_5.setEditable(False)
                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
                item1.appendRow([poly_item])

            item2 = QtGui.QStandardItem(str(par.default))
            item3 = QtGui.QStandardItem(str(par.limits[0]))
            item4 = QtGui.QStandardItem(str(par.limits[1]))
            item5 = QtGui.QStandardItem(par.units)
            item5.setEditable(False)
            model.appendRow([item1, item2, item3, item4, item5])
192
def calculateChi2(reference_data, current_data):
    """
    Calculate the chi^2 value between two sets of data.

    chi^2 is the mean of ((current - reference) / error)^2 over all points
    where that normalized residual is finite.

    * If reference_data is a Data2D:
      - intensities are read from .data and errors from current_data.err_data;
      - points with zero error or a non-finite current intensity are excluded;
      - a missing err_data (None) is treated as unit errors.
    * Otherwise (1D):
      - intensities are read from .y and errors from current_data.dy;
      - zero errors are replaced by 1, and a missing or empty dy is treated
        as unit errors;
      - current_data.dy itself is never modified.

    :param reference_data: data set compared against, or None
    :param current_data: data set providing the intensities and errors
    :return: the chi^2 value; None if reference_data is None or if the
             arrays of the two data sets have incompatible lengths
    """
    # TODO: weighting - use the fit weight (cf. get_weight in
    # sas.sasgui.perspectives.fitting.utils) instead of the raw errors.
    if reference_data is None:
        return None

    try:
        if isinstance(reference_data, Data2D):
            current = numpy.asanyarray(current_data.data)
            if current_data.err_data is None:
                errors = numpy.ones(len(current))
            else:
                errors = numpy.asanyarray(current_data.err_data)
            # get rid of zero error points and non-finite intensities
            index = (errors != 0) & numpy.isfinite(current)
            fn = current[index]
            gn = numpy.asanyarray(reference_data.data)[index]
            en = errors[index]
        else:
            fn = numpy.asanyarray(current_data.y)
            gn = numpy.asanyarray(reference_data.y)
            dy = current_data.dy
            # len() check instead of "dy == []", which is an (ambiguous)
            # elementwise comparison when dy is a numpy array
            if dy is None or len(dy) == 0:
                en = numpy.ones(len(fn))
            else:
                # Set consistently w/AbstractFitengine: zero errors -> 1.
                # Work on a float array copy: this never touches the caller's
                # dy, and also works when dy is a plain list (where
                # "dy[dy == 0] = 1" would overwrite dy[0]).
                en = numpy.array(dy, dtype=float)
                en[en == 0] = 1
        res = (fn - gn) / en
    except (ValueError, IndexError):
        # Unmatched lengths between intensities/errors of the data sets
        return None

    residuals = res[numpy.isfinite(res)]
    return numpy.average(residuals * residuals)
249
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data1D sets.

    residual_i = (reference.y_i - current.y_i) / e_i, placed on current_data.x.
    No weighting is applied yet: e_i is 1 for every point, and the errors
    of current_data are not used.

    :param reference_data: Data1D compared against
    :param current_data: Data1D providing the y values and the x axis
    :return: new Data1D holding the residuals, with unit error bars
    :raises ValueError: if the two data sets have different numbers of points
    """
    fn = numpy.asanyarray(current_data.y)
    gn = numpy.asanyarray(reference_data.y)
    if len(fn) != len(gn):
        # TODO: fix case where applying new data from file on top of
        # existing model data (the two sets then cover different points)
        raise ValueError("Residuals: unmatched lengths %d, %d" % (len(fn), len(gn)))

    # TODO: weighting - unit errors until a weight is passed in
    en = numpy.ones(len(fn))

    # build residuals
    residuals = Data1D()
    residuals.y = -(fn - gn) / en
    residuals.x = numpy.asanyarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For latter scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
288
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for the difference of two Data2D sets.

    residual_i = (current.data_i - reference.data_i) / err_i, where err is
    taken from current_data.err_data.

    Zero errors are replaced by 1, and a missing err_data (None) by unit
    errors, so the residual map contains no inf/nan from division by zero.
    current_data.err_data itself is not modified.

    :param reference_data: Data2D compared against
    :param current_data: Data2D providing intensities, errors, q arrays and mask
    :return: new Data2D holding the residuals (unit err_data), or None if
             the number of residuals differs from the number of q points
    """
    residuals = Data2D()
    # Presumably copies current_data's non-array state into residuals
    # (TODO confirm clone_without_data); the arrays are filled in below.
    current_data.clone_without_data(len(current_data.data), residuals)

    fn = current_data.data
    gn = reference_data.data
    if current_data.err_data is None:
        en = numpy.ones(len(fn))
    else:
        # float copy so current_data.err_data is left untouched
        en = numpy.array(current_data.err_data, dtype=float)
        en[en == 0] = 1
    residuals.data = (fn - gn) / en

    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    residuals.xmin = numpy.min(residuals.qx_data)
    residuals.xmax = numpy.max(residuals.qx_data)
    residuals.ymin = numpy.min(residuals.qy_data)
    residuals.ymax = numpy.max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'
    # check the lengths
    if len(residuals.data) != len(residuals.q_data):
        return None
    return residuals
321
def plotResiduals(reference_data, current_data):
    """
    Create a Data1D/Data2D with the residuals of current_data with respect
    to reference_data, ready for plotting.

    :param reference_data: data set compared against. Its filename is used
        in the plot name, and its id and group_id select the "res..."
        plot and panel.
    :param current_data: Data1D or Data2D. It is deep-copied before the
        residuals are computed, so the caller's object is not modified.
        The first word of its name is used in the plot name.
    :return: the residuals data object, or None if the residuals could not
             be built (residualsData2D returns None on a length mismatch)
    :raises KeyError: if current_data is neither a Data1D nor a Data2D
    """
    data_copy = deepcopy(current_data)
    # Get data: data I, theory I, and data dI in order
    method_name = current_data.__class__.__name__
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}

    residuals = residuals_dict[method_name](reference_data, data_copy)
    if residuals is None:
        return None

    # Guard against an empty or missing name instead of raising IndexError
    name_parts = current_data.name.split() if current_data.name else []
    theory_name = str(name_parts[0]) if name_parts else ""
    residuals.name = "Residuals for " + theory_name + "[" + \
                    str(reference_data.filename) + "]"
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # when 2 data have the same id override the 1 st plotted
    # include the last part if keeping charts for separate models is required
    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
    # group_id specify on which panel to plot this data
    residuals.group_id = "res" + str(reference_data.group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
352
def binary_encode(i, digits):
    """
    Return the 'digits' least-significant bits of the integer i as a list
    of 0/1 values, least-significant bit first.

    e.g. binary_encode(6, 4) -> [0, 1, 1, 0]
    """
    # range() instead of the Python-2-only xrange(), so this runs on Python 3
    return [(i >> d) & 1 for d in range(digits)]
355
def getWeight(data, is2d, flag=None):
    """
    Compute the weighting array for a data set.

    :param data: object with .data/.err_data (when is2d) or .y/.dy
    :param is2d: selects which pair of attributes is read
    :param flag: 0 -> array of ones shaped like the intensities,
                 1 -> the error array itself (returned as is, not copied),
                 2 -> sqrt(|intensity|),
                 3 -> |intensity|;
                 any other value, including the default None -> None
    :return: the weight array, or None
    """
    if is2d:
        intensity, errors = data.data, data.err_data
    else:
        intensity, errors = data.y, data.dy

    rules = (
        (0, lambda: numpy.ones_like(intensity)),
        (1, lambda: errors),
        (2, lambda: numpy.sqrt(numpy.abs(intensity))),
        (3, lambda: numpy.abs(intensity)),
    )
    for code, make_weight in rules:
        if flag == code:
            return make_weight()
    return None
Note: See TracBrowser for help on using the repository browser.