source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 689222c

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 689222c was 689222c, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Allow more flexibility for malformed files SASVIEW-597

  • Property mode set to 100755
File size: 14.6 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
def replaceShellName(param_name, value):
    """
    Turn a multishell parameter name such as 'sld[n_shell]' into 'sld<value>'.

    Everything from the first '[' onwards is replaced by str(value).
    """
    assert '[' in param_name
    stem, _, _ = param_name.partition('[')
    return stem + str(value)
17
def getIterParams(model):
    """
    Return, in their original order, the entries of model.iq_parameters
    whose name contains '[' (the multi-shell parameters).
    """
    return [param for param in model.iq_parameters if "[" in param.name]
23
def getMultiplicity(model):
    """
    Describe the multishell structure of 'model'.

    :return: (counter_name, shell_count) taken from the first multishell
        parameter: its length_control (or, when that is None, its name up
        to the first '[') and its length. ("", 0) when the model has no
        multishell parameters.
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    count = first.length
    counter = first.length_control
    if counter is None and '[' in first.name:
        # Fall back to the bare parameter name, e.g. "sld" for "sld[n]"
        counter = first.name.split('[', 1)[0]
    return (counter, count)
38
def addParametersToModel(parameters, kernel_module, is2D):
    """
    Build the rows of the main parameter table from sasmodel parameters.

    :param parameters: model info providing iq_parameters, iqxy_parameters
        and form_volume_parameters
    :param kernel_module: model object whose getParam('<name>.width') and
        getParam('<name>.type') fill the polydispersity child row
    :param is2D: read iqxy_parameters instead of iq_parameters
    :return: list of rows, each [name, value, min, max, units] of
        QStandardItem

    The multishell counter parameter and the per-shell "<name>[n]"
    parameters are skipped; shell rows are expected to be added separately
    (addShellsToModel builds such rows).
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Skip the shell counter and every per-shell parameter
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        item1.setEditable(False)

        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            poly_item.setEditable(False)
            # A distribution row is only added when the parameter is also
            # one of the form-volume parameters
            for volume_param in parameters.form_volume_parameters:
                if volume_param.name != param.name:
                    continue
                width = kernel_module.getParam(volume_param.name + '.width')
                dist_type = kernel_module.getParam(volume_param.name + '.type')
                poly_row = [QtGui.QStandardItem("Distribution"),
                            QtGui.QStandardItem(str(width)),
                            QtGui.QStandardItem(),
                            QtGui.QStandardItem(),
                            QtGui.QStandardItem(dist_type)]
                for cell in poly_row:
                    cell.setEditable(False)
                poly_item.appendRow(poly_row)
                break
            # Add the polydisp item as a child
            item1.appendRow([poly_item])

        # Param values
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        item2 = QtGui.QStandardItem(str(param.default))
        item3 = QtGui.QStandardItem(str(param.limits[0]))
        item4 = QtGui.QStandardItem(str(param.limits[1]))
        item5 = QtGui.QStandardItem(param.units)
        item5.setEditable(False)
        rows.append([item1, item2, item3, item4, item5])
    return rows
97
def addSimpleParametersToModel(parameters, is2D):
    """
    Build plain parameter-table rows (no polydispersity children).

    :param parameters: model info providing iq_parameters / iqxy_parameters
    :param is2D: read iqxy_parameters instead of iq_parameters
    :return: list of rows, each [name, value, min, max, units] of
        QStandardItem
    """
    source = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in source:
        # Checkable, read-only name cell
        name_cell = QtGui.QStandardItem(param.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cells = [QtGui.QStandardItem(str(v))
                       for v in (param.default, param.limits[0], param.limits[1])]
        units_cell = QtGui.QStandardItem(param.units)
        units_cell.setEditable(False)
        rows.append([name_cell] + value_cells + [units_cell])
    return rows
109        item1.setEditable(False)
110        # Param values
111        # TODO: add delegate for validation of cells
112        item2 = QtGui.QStandardItem(str(param.default))
113        item4 = QtGui.QStandardItem(str(param.limits[0]))
114        item5 = QtGui.QStandardItem(str(param.limits[1]))
115        item6 = QtGui.QStandardItem(param.units)
116        item6.setEditable(False)
117        item.append([item1, item2, item4, item5, item6])
118    return item
119
120def addCheckedListToModel(model, param_list):
121    """
122    Add a QItem to model. Makes the QItem checkable
123    """
124    assert isinstance(model, QtGui.QStandardItemModel)
125    item_list = [QtGui.QStandardItem(item) for item in param_list]
126    item_list[0].setCheckable(True)
127    model.appendRow(item_list)
128
129def addHeadersToModel(model):
130    """
131    Adds predefined headers to the model
132    """
133    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
134    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("Value"))
135    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
136    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
137    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Units"))
138
139def addErrorHeadersToModel(model):
140    """
141    Adds predefined headers to the model
142    """
143    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
144    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("Value"))
145    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Error"))
146    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
147    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
148    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Units"))
149
150def addPolyHeadersToModel(model):
151    """
152    Adds predefined headers to the model
153    """
154    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
155    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("PD[ratio]"))
156    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
157    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
158    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Npts"))
159    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Nsigs"))
160    model.setHeaderData(6, QtCore.Qt.Horizontal, QtCore.QVariant("Function"))
161
162def addErrorPolyHeadersToModel(model):
163    """
164    Adds predefined headers to the model
165    """
166    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
167    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("PD[ratio]"))
168    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Error"))
169    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
170    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
171    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Npts"))
172    model.setHeaderData(6, QtCore.Qt.Horizontal, QtCore.QVariant("Nsigs"))
173    model.setHeaderData(7, QtCore.Qt.Horizontal, QtCore.QVariant("Function"))
174
175def addShellsToModel(parameters, model, index):
176    """
177    Find out multishell parameters and update the model with the requested number of them
178    """
179    multishell_parameters = getIterParams(parameters)
180
181    for i in xrange(index):
182        for par in multishell_parameters:
183            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
184            param_name = replaceShellName(par.name, i+1)
185            item1 = QtGui.QStandardItem(param_name)
186            item1.setCheckable(True)
187            # check for polydisp params
188            if par.polydisperse:
189                poly_item = QtGui.QStandardItem("Polydispersity")
190                item1_1 = QtGui.QStandardItem("Distribution")
191                # Find param in volume_params
192                for p in parameters.form_volume_parameters:
193                    if p.name != par.name:
194                        continue
195                    item1_2 = QtGui.QStandardItem(str(p.default))
196                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
197                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
198                    item1_5 = QtGui.QStandardItem(p.units)
199                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
200                    break
201                item1.appendRow([poly_item])
202
203            item2 = QtGui.QStandardItem(str(par.default))
204            item3 = QtGui.QStandardItem(str(par.limits[0]))
205            item4 = QtGui.QStandardItem(str(par.limits[1]))
206            item5 = QtGui.QStandardItem(par.units)
207            model.appendRow([item1, item2, item3, item4, item5])
208
209def calculateChi2(reference_data, current_data):
210    """
211    Calculate Chi2 value between two sets of data
212    """
213
214    # WEIGHING INPUT
215    #from sas.sasgui.perspectives.fitting.utils import get_weight
216    #flag = self.get_weight_flag()
217    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
218    chisqr = None
219    if reference_data is None:
220        return chisqr
221
222    # temporary default values for index and weight
223    index = None
224    weight = None
225
226    # Get data: data I, theory I, and data dI in order
227    if isinstance(reference_data, Data2D):
228        if index is None:
229            index = numpy.ones(len(current_data.data), dtype=bool)
230        if weight is not None:
231            current_data.err_data = weight
232        # get rid of zero error points
233        index = index & (current_data.err_data != 0)
234        index = index & (numpy.isfinite(current_data.data))
235        fn = current_data.data[index]
236        gn = reference_data.data[index]
237        en = current_data.err_data[index]
238    else:
239        # 1 d theory from model_thread is only in the range of index
240        if index is None:
241            index = numpy.ones(len(current_data.y), dtype=bool)
242        if weight is not None:
243            current_data.dy = weight
244        if current_data.dy is None or current_data.dy == []:
245            dy = numpy.ones(len(current_data.y))
246        else:
247            ## Set consistently w/AbstractFitengine:
248            # But this should be corrected later.
249            dy = deepcopy(current_data.dy)
250            dy[dy == 0] = 1
251        fn = current_data.y[index]
252        gn = reference_data.y
253        en = dy[index]
254    # Calculate the residual
255    try:
256        res = (fn - gn) / en
257    except ValueError:
258        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
259        return None
260
261    residuals = res[numpy.isfinite(res)]
262    chisqr = numpy.average(residuals * residuals)
263
264    return chisqr
265
266def residualsData1D(reference_data, current_data):
267    """
268    Calculate the residuals for difference of two Data1D sets
269    """
270    # temporary default values for index and weight
271    index = None
272    weight = None
273
274    # 1d theory from model_thread is only in the range of index
275    if current_data.dy is None or current_data.dy == []:
276        dy = numpy.ones(len(current_data.y))
277    else:
278        dy = weight if weight is not None else numpy.ones(len(current_data.y))
279        dy[dy == 0] = 1
280    fn = current_data.y[index][0]
281    gn = reference_data.y
282    en = dy[index][0]
283    # build residuals
284    residuals = Data1D()
285    if len(fn) == len(gn):
286        y = (fn - gn)/en
287        residuals.y = -y
288    else:
289        # TODO: fix case where applying new data from file on top of existing model data
290        try:
291            y = (fn - gn[index][0]) / en
292            residuals.y = y
293        except ValueError:
294            # value errors may show up every once in a while for malformed columns,
295            # just reuse what's there already
296            pass
297
298    residuals.x = current_data.x[index][0]
299    residuals.dy = numpy.ones(len(residuals.y))
300    residuals.dx = None
301    residuals.dxl = None
302    residuals.dxw = None
303    residuals.ytransform = 'y'
304    # For latter scale changes
305    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
306    residuals.yaxis('\\rm{Residuals} ', 'normalized')
307
308    return residuals
309
310def residualsData2D(reference_data, current_data):
311    """
312    Calculate the residuals for difference of two Data2D sets
313    """
314    # temporary default values for index and weight
315    # index = None
316    weight = None
317
318    # build residuals
319    residuals = Data2D()
320    # Not for trunk the line below, instead use the line above
321    current_data.clone_without_data(len(current_data.data), residuals)
322    residuals.data = None
323    fn = current_data.data
324    gn = reference_data.data
325    en = current_data.err_data if weight is None else weight
326    residuals.data = (fn - gn) / en
327    residuals.qx_data = current_data.qx_data
328    residuals.qy_data = current_data.qy_data
329    residuals.q_data = current_data.q_data
330    residuals.err_data = numpy.ones(len(residuals.data))
331    residuals.xmin = min(residuals.qx_data)
332    residuals.xmax = max(residuals.qx_data)
333    residuals.ymin = min(residuals.qy_data)
334    residuals.ymax = max(residuals.qy_data)
335    residuals.q_data = current_data.q_data
336    residuals.mask = current_data.mask
337    residuals.scale = 'linear'
338    # check the lengths
339    if len(residuals.data) != len(residuals.q_data):
340        return None
341    return residuals
342
343def plotResiduals(reference_data, current_data):
344    """
345    Create Data1D/Data2D with residuals, ready for plotting
346    """
347    data_copy = deepcopy(current_data)
348    # Get data: data I, theory I, and data dI in order
349    method_name = current_data.__class__.__name__
350    residuals_dict = {"Data1D": residualsData1D,
351                      "Data2D": residualsData2D}
352
353    residuals = residuals_dict[method_name](reference_data, data_copy)
354
355    theory_name = str(current_data.name.split()[0])
356    residuals.name = "Residuals for " + str(theory_name) + "[" + \
357                    str(reference_data.filename) + "]"
358    residuals.title = residuals.name
359    residuals.ytransform = 'y'
360
361    # when 2 data have the same id override the 1 st plotted
362    # include the last part if keeping charts for separate models is required
363    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
364    # group_id specify on which panel to plot this data
365    group_id = reference_data.group_id
366    residuals.group_id = "res" + str(group_id)
367
368    # Symbol
369    residuals.symbol = 0
370    residuals.hide_error = False
371
372    return residuals
373
374def binary_encode(i, digits):
375    return [i >> d & 1 for d in xrange(digits)]
376
377def getWeight(data, is2d, flag=None):
378    """
379    Received flag and compute error on data.
380    :param flag: flag to transform error of data.
381    """
382    weight = None
383    if is2d:
384        dy_data = data.err_data
385        data = data.data
386    else:
387        dy_data = data.dy
388        data = data.y
389
390    if flag == 0:
391        weight = numpy.ones_like(data)
392    elif flag == 1:
393        weight = dy_data
394    elif flag == 2:
395        weight = numpy.sqrt(numpy.abs(data))
396    elif flag == 3:
397        weight = numpy.abs(data)
398    return weight
Note: See TracBrowser for help on using the repository browser.