source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ ca7c6bd

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since ca7c6bd was ca7c6bd, checked in by wojciech, 6 years ago

Attempt to pass parameters

  • Property mode set to 100644
File size: 14.8 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
def replaceShellName(param_name, value):
    """
    Turn a multishell parameter name into a concrete per-shell name.

    Everything from the first '[' onwards is replaced by str(value), so
    ("sld[n]", 2) becomes "sld2".
    """
    assert '[' in param_name
    bracket_pos = param_name.index('[')
    return "".join((param_name[:bracket_pos], str(value)))
17
def getIterParams(model):
    """
    Collect the multi-shell ("vector") parameters of 'model'.

    A parameter is multi-shell when its name contains '[' (e.g. "sld[n]").
    Only model.iq_parameters is scanned; a new list is returned, preserving
    the original order.
    """
    return [par for par in model.iq_parameters if "[" in par.name]
23
def getMultiplicity(model):
    """
    Report the multishell structure of 'model'.

    Returns a (name, length) tuple built from the first multishell parameter:
    its length_control (or, when that is None, the part of its name before
    '[') and its length. A model without multishell parameters yields ("", 0).
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    shell_count = first.length
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        # Fall back to the stem of the vector name, e.g. "sld" for "sld[n]".
        control_name = first.name[:first.name.index('[')]
    return (control_name, shell_count)
38
def addParametersToModel(parameters, kernel_module, is2D):
    """
    Build the rows of the main parameter table from sasmodel parameters.

    :param parameters: model description providing iq_parameters,
        iqxy_parameters and form_volume_parameters.
    :param kernel_module: model object queried through getParam() for the
        '<name>.width' and '<name>.type' polydispersity settings.
    :param is2D: use iqxy_parameters when True, iq_parameters otherwise.
    :return: list of rows, each [name, value, min, max, units] made of
        QtGui.QStandardItem objects.

    The shell-count parameter and the multishell "<name>[n]" parameters are
    left out of the returned rows.
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Skip the shell counter and every per-shell (vector) parameter.
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        item1.setEditable(False)
        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            poly_item.setEditable(False)
            # Only parameters that are also form-volume parameters get a
            # "Distribution" row with the current width and distribution type.
            volume_param = next((p for p in parameters.form_volume_parameters
                                 if p.name == param.name), None)
            if volume_param is not None:
                width = kernel_module.getParam(volume_param.name + '.width')
                # 'dist_type' rather than 'type' to avoid shadowing the builtin.
                dist_type = kernel_module.getParam(volume_param.name + '.type')
                item1_1 = QtGui.QStandardItem("Distribution")
                item1_2 = QtGui.QStandardItem(str(width))
                item1_3 = QtGui.QStandardItem()
                item1_4 = QtGui.QStandardItem()
                item1_5 = QtGui.QStandardItem(dist_type)
                for cell in (item1_1, item1_2, item1_3, item1_4, item1_5):
                    cell.setEditable(False)
                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
            # Add the polydispersity item as a child of the parameter name.
            item1.appendRow([poly_item])

        # Param values
        item2 = QtGui.QStandardItem(str(param.default))
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        item3 = QtGui.QStandardItem(str(param.limits[0]))
        item4 = QtGui.QStandardItem(str(param.limits[1]))
        item5 = QtGui.QStandardItem(param.units)
        item5.setEditable(False)
        rows.append([item1, item2, item3, item4, item5])
    return rows
97
def addSimpleParametersToModel(parameters, is2D):
    """
    Build parameter-table rows without polydispersity or multishell handling.

    Each parameter of parameters.iqxy_parameters (is2D) or
    parameters.iq_parameters yields one row:
    [name (checkable, read-only), default value, min, max, units (read-only)].
    """
    source = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for par in source:
        name_cell = QtGui.QStandardItem(par.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cell = QtGui.QStandardItem(str(par.default))
        min_cell = QtGui.QStandardItem(str(par.limits[0]))
        max_cell = QtGui.QStandardItem(str(par.limits[1]))
        units_cell = QtGui.QStandardItem(par.units)
        units_cell.setEditable(False)
        rows.append([name_cell, value_cell, min_cell, max_cell, units_cell])
    return rows
109        item1.setEditable(False)
110        # Param values
111        # TODO: add delegate for validation of cells
112        item2 = QtGui.QStandardItem(str(param.default))
113        item4 = QtGui.QStandardItem(str(param.limits[0]))
114        item5 = QtGui.QStandardItem(str(param.limits[1]))
115        item6 = QtGui.QStandardItem(param.units)
116        item6.setEditable(False)
117        item.append([item1, item2, item4, item5, item6])
118    return item
119
120def addCheckedListToModel(model, param_list):
121    """
122    Add a QItem to model. Makes the QItem checkable
123    """
124    assert isinstance(model, QtGui.QStandardItemModel)
125    item_list = [QtGui.QStandardItem(item) for item in param_list]
126    item_list[0].setCheckable(True)
127    model.appendRow(item_list)
128
129def addHeadersToModel(model):
130    """
131    Adds predefined headers to the model
132    """
133    model.parameter_headers = ['Parameter', 'Value', 'Min', 'Max', 'Units']
134    #model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
135    #model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("Value"))
136    #model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
137    #model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
138    #model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Units"))
139
140def addErrorHeadersToModel(model):
141    """
142    Adds predefined headers to the model
143    """
144    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
145    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("Value"))
146    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Error"))
147    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
148    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
149    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Units"))
150
151def addPolyHeadersToModel(model):
152    """
153    Adds predefined headers to the model
154    """
155    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
156    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("PD[ratio]"))
157    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
158    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
159    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Npts"))
160    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Nsigs"))
161    model.setHeaderData(6, QtCore.Qt.Horizontal, QtCore.QVariant("Function"))
162    model.setHeaderData(7, QtCore.Qt.Horizontal, QtCore.QVariant("Filename"))
163
164def addErrorPolyHeadersToModel(model):
165    """
166    Adds predefined headers to the model
167    """
168    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
169    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("PD[ratio]"))
170    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Error"))
171    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
172    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
173    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Npts"))
174    model.setHeaderData(6, QtCore.Qt.Horizontal, QtCore.QVariant("Nsigs"))
175    model.setHeaderData(7, QtCore.Qt.Horizontal, QtCore.QVariant("Function"))
176    model.setHeaderData(8, QtCore.Qt.Horizontal, QtCore.QVariant("Filename"))
177
178def addShellsToModel(parameters, model, index):
179    """
180    Find out multishell parameters and update the model with the requested number of them
181    """
182    multishell_parameters = getIterParams(parameters)
183
184    for i in xrange(index):
185        for par in multishell_parameters:
186            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
187            param_name = replaceShellName(par.name, i+1)
188            item1 = QtGui.QStandardItem(param_name)
189            item1.setCheckable(True)
190            # check for polydisp params
191            if par.polydisperse:
192                poly_item = QtGui.QStandardItem("Polydispersity")
193                item1_1 = QtGui.QStandardItem("Distribution")
194                # Find param in volume_params
195                for p in parameters.form_volume_parameters:
196                    if p.name != par.name:
197                        continue
198                    item1_2 = QtGui.QStandardItem(str(p.default))
199                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
200                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
201                    item1_5 = QtGui.QStandardItem(p.units)
202                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
203                    break
204                item1.appendRow([poly_item])
205
206            item2 = QtGui.QStandardItem(str(par.default))
207            item3 = QtGui.QStandardItem(str(par.limits[0]))
208            item4 = QtGui.QStandardItem(str(par.limits[1]))
209            item5 = QtGui.QStandardItem(par.units)
210            model.appendRow([item1, item2, item3, item4, item5])
211
212def calculateChi2(reference_data, current_data):
213    """
214    Calculate Chi2 value between two sets of data
215    """
216
217    # WEIGHING INPUT
218    #from sas.sasgui.perspectives.fitting.utils import get_weight
219    #flag = self.get_weight_flag()
220    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
221    chisqr = None
222    if reference_data is None:
223        return chisqr
224
225    # temporary default values for index and weight
226    index = None
227    weight = None
228
229    # Get data: data I, theory I, and data dI in order
230    if isinstance(reference_data, Data2D):
231        if index is None:
232            index = numpy.ones(len(current_data.data), dtype=bool)
233        if weight is not None:
234            current_data.err_data = weight
235        # get rid of zero error points
236        index = index & (current_data.err_data != 0)
237        index = index & (numpy.isfinite(current_data.data))
238        fn = current_data.data[index]
239        gn = reference_data.data[index]
240        en = current_data.err_data[index]
241    else:
242        # 1 d theory from model_thread is only in the range of index
243        if index is None:
244            index = numpy.ones(len(current_data.y), dtype=bool)
245        if weight is not None:
246            current_data.dy = weight
247        if current_data.dy is None or current_data.dy == []:
248            dy = numpy.ones(len(current_data.y))
249        else:
250            ## Set consistently w/AbstractFitengine:
251            # But this should be corrected later.
252            dy = deepcopy(current_data.dy)
253            dy[dy == 0] = 1
254        fn = current_data.y[index]
255        gn = reference_data.y
256        en = dy[index]
257    # Calculate the residual
258    try:
259        res = (fn - gn) / en
260    except ValueError:
261        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
262        return None
263
264    residuals = res[numpy.isfinite(res)]
265    chisqr = numpy.average(residuals * residuals)
266
267    return chisqr
268
269def residualsData1D(reference_data, current_data):
270    """
271    Calculate the residuals for difference of two Data1D sets
272    """
273    # temporary default values for index and weight
274    index = None
275    weight = None
276
277    # 1d theory from model_thread is only in the range of index
278    if current_data.dy is None or current_data.dy == []:
279        dy = numpy.ones(len(current_data.y))
280    else:
281        dy = weight if weight is not None else numpy.ones(len(current_data.y))
282        dy[dy == 0] = 1
283    fn = current_data.y[index][0]
284    gn = reference_data.y
285    en = dy[index][0]
286    # build residuals
287    residuals = Data1D()
288    if len(fn) == len(gn):
289        y = (fn - gn)/en
290        residuals.y = -y
291    else:
292        # TODO: fix case where applying new data from file on top of existing model data
293        try:
294            y = (fn - gn[index][0]) / en
295            residuals.y = y
296        except ValueError:
297            # value errors may show up every once in a while for malformed columns,
298            # just reuse what's there already
299            pass
300
301    residuals.x = current_data.x[index][0]
302    residuals.dy = numpy.ones(len(residuals.y))
303    residuals.dx = None
304    residuals.dxl = None
305    residuals.dxw = None
306    residuals.ytransform = 'y'
307    # For latter scale changes
308    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
309    residuals.yaxis('\\rm{Residuals} ', 'normalized')
310
311    return residuals
312
313def residualsData2D(reference_data, current_data):
314    """
315    Calculate the residuals for difference of two Data2D sets
316    """
317    # temporary default values for index and weight
318    # index = None
319    weight = None
320
321    # build residuals
322    residuals = Data2D()
323    # Not for trunk the line below, instead use the line above
324    current_data.clone_without_data(len(current_data.data), residuals)
325    residuals.data = None
326    fn = current_data.data
327    gn = reference_data.data
328    en = current_data.err_data if weight is None else weight
329    residuals.data = (fn - gn) / en
330    residuals.qx_data = current_data.qx_data
331    residuals.qy_data = current_data.qy_data
332    residuals.q_data = current_data.q_data
333    residuals.err_data = numpy.ones(len(residuals.data))
334    residuals.xmin = min(residuals.qx_data)
335    residuals.xmax = max(residuals.qx_data)
336    residuals.ymin = min(residuals.qy_data)
337    residuals.ymax = max(residuals.qy_data)
338    residuals.q_data = current_data.q_data
339    residuals.mask = current_data.mask
340    residuals.scale = 'linear'
341    # check the lengths
342    if len(residuals.data) != len(residuals.q_data):
343        return None
344    return residuals
345
346def plotResiduals(reference_data, current_data):
347    """
348    Create Data1D/Data2D with residuals, ready for plotting
349    """
350    data_copy = deepcopy(current_data)
351    # Get data: data I, theory I, and data dI in order
352    method_name = current_data.__class__.__name__
353    residuals_dict = {"Data1D": residualsData1D,
354                      "Data2D": residualsData2D}
355
356    residuals = residuals_dict[method_name](reference_data, data_copy)
357
358    theory_name = str(current_data.name.split()[0])
359    residuals.name = "Residuals for " + str(theory_name) + "[" + \
360                    str(reference_data.filename) + "]"
361    residuals.title = residuals.name
362    residuals.ytransform = 'y'
363
364    # when 2 data have the same id override the 1 st plotted
365    # include the last part if keeping charts for separate models is required
366    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
367    # group_id specify on which panel to plot this data
368    group_id = reference_data.group_id
369    residuals.group_id = "res" + str(group_id)
370
371    # Symbol
372    residuals.symbol = 0
373    residuals.hide_error = False
374
375    return residuals
376
377def binary_encode(i, digits):
378    return [i >> d & 1 for d in xrange(digits)]
379
380def getWeight(data, is2d, flag=None):
381    """
382    Received flag and compute error on data.
383    :param flag: flag to transform error of data.
384    """
385    weight = None
386    if is2d:
387        dy_data = data.err_data
388        data = data.data
389    else:
390        dy_data = data.dy
391        data = data.y
392
393    if flag == 0:
394        weight = numpy.ones_like(data)
395    elif flag == 1:
396        weight = dy_data
397    elif flag == 2:
398        weight = numpy.sqrt(numpy.abs(data))
399    elif flag == 3:
400        weight = numpy.abs(data)
401    return weight
Note: See TracBrowser for help on using the repository browser.