source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ cadd595a

Last change on this file since cadd595a was a95c44b, checked in by wojciech, 7 years ago

Added descriptive tooltips to column headers

  • Property mode set to 100644
File size: 16.7 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11def replaceShellName(param_name, value):
12    """
13    Updates parameter name from <param_name>[n_shell] to <param_name>value
14    """
15    assert '[' in param_name
16    return param_name[:param_name.index('[')]+str(value)
17
18def getIterParams(model):
19    """
20    Returns a list of all multi-shell parameters in 'model'
21    """
22    return list(filter(lambda par: "[" in par.name, model.iq_parameters))
23
24def getMultiplicity(model):
25    """
26    Finds out if 'model' has multishell parameters.
27    If so, returns the name of the counter parameter and the number of shells
28    """
29    iter_params = getIterParams(model)
30    param_name = ""
31    param_length = 0
32    if iter_params:
33        param_length = iter_params[0].length
34        param_name = iter_params[0].length_control
35        if param_name is None and '[' in iter_params[0].name:
36            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
37    return (param_name, param_length)
38
39def addParametersToModel(parameters, kernel_module, is2D):
40    """
41    Update local ModelModel with sasmodel parameters
42    """
43    multishell_parameters = getIterParams(parameters)
44    multishell_param_name, _ = getMultiplicity(parameters)
45    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
46    item = []
47    for param in params:
48        # don't include shell parameters
49        if param.name == multishell_param_name:
50            continue
51        # Modify parameter name from <param>[n] to <param>1
52        item_name = param.name
53        if param in multishell_parameters:
54            continue
55        #    item_name = replaceShellName(param.name, 1)
56
57        item1 = QtGui.QStandardItem(item_name)
58        item1.setCheckable(True)
59        item1.setEditable(False)
60        # item_err = QtGui.QStandardItem()
61        # check for polydisp params
62        if param.polydisperse:
63            poly_item = QtGui.QStandardItem("Polydispersity")
64            poly_item.setEditable(False)
65            item1_1 = QtGui.QStandardItem("Distribution")
66            item1_1.setEditable(False)
67            # Find param in volume_params
68            for p in parameters.form_volume_parameters:
69                if p.name != param.name:
70                    continue
71                width = kernel_module.getParam(p.name+'.width')
72                type = kernel_module.getParam(p.name+'.type')
73
74                item1_2 = QtGui.QStandardItem(str(width))
75                item1_2.setEditable(False)
76                item1_3 = QtGui.QStandardItem()
77                item1_3.setEditable(False)
78                item1_4 = QtGui.QStandardItem()
79                item1_4.setEditable(False)
80                item1_5 = QtGui.QStandardItem(type)
81                item1_5.setEditable(False)
82                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
83                break
84            # Add the polydisp item as a child
85            item1.appendRow([poly_item])
86        # Param values
87        item2 = QtGui.QStandardItem(str(param.default))
88        # TODO: the error column.
89        # Either add a proxy model or a custom view delegate
90        #item_err = QtGui.QStandardItem()
91        item3 = QtGui.QStandardItem(str(param.limits[0]))
92        item4 = QtGui.QStandardItem(str(param.limits[1]))
93        item5 = QtGui.QStandardItem(param.units)
94        item5.setEditable(False)
95        item.append([item1, item2, item3, item4, item5])
96    return item
97
98def addSimpleParametersToModel(parameters, is2D):
99    """
100    Update local ModelModel with sasmodel parameters
101    """
102    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
103    item = []
104    for param in params:
105        # Create the top level, checkable item
106        item_name = param.name
107        item1 = QtGui.QStandardItem(item_name)
108        item1.setCheckable(True)
109        item1.setEditable(False)
110        # Param values
111        # TODO: add delegate for validation of cells
112        item2 = QtGui.QStandardItem(str(param.default))
113        item4 = QtGui.QStandardItem(str(param.limits[0]))
114        item5 = QtGui.QStandardItem(str(param.limits[1]))
115        item6 = QtGui.QStandardItem(param.units)
116        item6.setEditable(False)
117        item.append([item1, item2, item4, item5, item6])
118    return item
119
120def addCheckedListToModel(model, param_list):
121    """
122    Add a QItem to model. Makes the QItem checkable
123    """
124    assert isinstance(model, QtGui.QStandardItemModel)
125    item_list = [QtGui.QStandardItem(item) for item in param_list]
126    item_list[0].setCheckable(True)
127    model.appendRow(item_list)
128
129def addHeadersToModel(model):
130    """
131    Adds predefined headers to the model
132    """
133    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
134    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("Value"))
135    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
136    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
137    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Units"))
138
139    model.header_tooltips = ['Select parameter for fitting',
140                             'Enter parameter value',
141                             'Enter minimum value for parameter',
142                             'Enter maximum value for parameter',
143                             'Unit of the parameter']
144def addErrorHeadersToModel(model):
145    """
146    Adds predefined headers to the model
147    """
148    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
149    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("Value"))
150    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Error"))
151    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
152    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
153    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Units"))
154
155    model.header_tooltips = ['Select parameter for fitting',
156                             'Enter parameter value',
157                             'Error value for fitted parameter',
158                             'Enter minimum value for parameter',
159                             'Enter maximum value for parameter',
160                             'Unit of the parameter']
161
162def addPolyHeadersToModel(model):
163    """
164    Adds predefined headers to the model
165    """
166    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
167    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("PD[ratio]"))
168    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
169    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
170    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Npts"))
171    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Nsigs"))
172    model.setHeaderData(6, QtCore.Qt.Horizontal, QtCore.QVariant("Function"))
173    model.setHeaderData(7, QtCore.Qt.Horizontal, QtCore.QVariant("Filename"))
174
175    model.header_tooltips = ['Select parameter for fitting',
176                             'Enter polydispersity ratio (STD/mean). '
177                             'STD: standard deviation from the mean value',
178                             'Enter minimum value for parameter',
179                             'Enter maximum value for parameter',
180                             'Enter number of points for parameter',
181                             'Enter number of sigmas parameter',
182                             'Select distribution function',
183                             'Select filename with user-definable distribution']
184
185def addErrorPolyHeadersToModel(model):
186    """
187    Adds predefined headers to the model
188    """
189    model.setHeaderData(0, QtCore.Qt.Horizontal, QtCore.QVariant("Parameter"))
190    model.setHeaderData(1, QtCore.Qt.Horizontal, QtCore.QVariant("PD[ratio]"))
191    model.setHeaderData(2, QtCore.Qt.Horizontal, QtCore.QVariant("Error"))
192    model.setHeaderData(3, QtCore.Qt.Horizontal, QtCore.QVariant("Min"))
193    model.setHeaderData(4, QtCore.Qt.Horizontal, QtCore.QVariant("Max"))
194    model.setHeaderData(5, QtCore.Qt.Horizontal, QtCore.QVariant("Npts"))
195    model.setHeaderData(6, QtCore.Qt.Horizontal, QtCore.QVariant("Nsigs"))
196    model.setHeaderData(7, QtCore.Qt.Horizontal, QtCore.QVariant("Function"))
197    model.setHeaderData(8, QtCore.Qt.Horizontal, QtCore.QVariant("Filename"))
198
199    model.header_tooltips = ['Select parameter for fitting',
200                             'Enter polydispersity ratio (STD/mean). '
201                             'STD: standard deviation from the mean value',
202                             'Error value for fitted parameter',
203                             'Enter minimum value for parameter',
204                             'Enter maximum value for parameter',
205                             'Enter number of points for parameter',
206                             'Enter number of sigmas parameter',
207                             'Select distribution function',
208                             'Select filename with user-definable distribution']
209
210def addShellsToModel(parameters, model, index):
211    """
212    Find out multishell parameters and update the model with the requested number of them
213    """
214    multishell_parameters = getIterParams(parameters)
215
216    for i in xrange(index):
217        for par in multishell_parameters:
218            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
219            param_name = replaceShellName(par.name, i+1)
220            item1 = QtGui.QStandardItem(param_name)
221            item1.setCheckable(True)
222            # check for polydisp params
223            if par.polydisperse:
224                poly_item = QtGui.QStandardItem("Polydispersity")
225                item1_1 = QtGui.QStandardItem("Distribution")
226                # Find param in volume_params
227                for p in parameters.form_volume_parameters:
228                    if p.name != par.name:
229                        continue
230                    item1_2 = QtGui.QStandardItem(str(p.default))
231                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
232                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
233                    item1_5 = QtGui.QStandardItem(p.units)
234                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
235                    break
236                item1.appendRow([poly_item])
237
238            item2 = QtGui.QStandardItem(str(par.default))
239            item3 = QtGui.QStandardItem(str(par.limits[0]))
240            item4 = QtGui.QStandardItem(str(par.limits[1]))
241            item5 = QtGui.QStandardItem(par.units)
242            model.appendRow([item1, item2, item3, item4, item5])
243
244def calculateChi2(reference_data, current_data):
245    """
246    Calculate Chi2 value between two sets of data
247    """
248
249    # WEIGHING INPUT
250    #from sas.sasgui.perspectives.fitting.utils import get_weight
251    #flag = self.get_weight_flag()
252    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
253    chisqr = None
254    if reference_data is None:
255        return chisqr
256
257    # temporary default values for index and weight
258    index = None
259    weight = None
260
261    # Get data: data I, theory I, and data dI in order
262    if isinstance(reference_data, Data2D):
263        if index is None:
264            index = numpy.ones(len(current_data.data), dtype=bool)
265        if weight is not None:
266            current_data.err_data = weight
267        # get rid of zero error points
268        index = index & (current_data.err_data != 0)
269        index = index & (numpy.isfinite(current_data.data))
270        fn = current_data.data[index]
271        gn = reference_data.data[index]
272        en = current_data.err_data[index]
273    else:
274        # 1 d theory from model_thread is only in the range of index
275        if index is None:
276            index = numpy.ones(len(current_data.y), dtype=bool)
277        if weight is not None:
278            current_data.dy = weight
279        if current_data.dy is None or current_data.dy == []:
280            dy = numpy.ones(len(current_data.y))
281        else:
282            ## Set consistently w/AbstractFitengine:
283            # But this should be corrected later.
284            dy = deepcopy(current_data.dy)
285            dy[dy == 0] = 1
286        fn = current_data.y[index]
287        gn = reference_data.y
288        en = dy[index]
289    # Calculate the residual
290    try:
291        res = (fn - gn) / en
292    except ValueError:
293        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
294        return None
295
296    residuals = res[numpy.isfinite(res)]
297    chisqr = numpy.average(residuals * residuals)
298
299    return chisqr
300
301def residualsData1D(reference_data, current_data):
302    """
303    Calculate the residuals for difference of two Data1D sets
304    """
305    # temporary default values for index and weight
306    index = None
307    weight = None
308
309    # 1d theory from model_thread is only in the range of index
310    if current_data.dy is None or current_data.dy == []:
311        dy = numpy.ones(len(current_data.y))
312    else:
313        dy = weight if weight is not None else numpy.ones(len(current_data.y))
314        dy[dy == 0] = 1
315    fn = current_data.y[index][0]
316    gn = reference_data.y
317    en = dy[index][0]
318    # build residuals
319    residuals = Data1D()
320    if len(fn) == len(gn):
321        y = (fn - gn)/en
322        residuals.y = -y
323    else:
324        # TODO: fix case where applying new data from file on top of existing model data
325        try:
326            y = (fn - gn[index][0]) / en
327            residuals.y = y
328        except ValueError:
329            # value errors may show up every once in a while for malformed columns,
330            # just reuse what's there already
331            pass
332
333    residuals.x = current_data.x[index][0]
334    residuals.dy = numpy.ones(len(residuals.y))
335    residuals.dx = None
336    residuals.dxl = None
337    residuals.dxw = None
338    residuals.ytransform = 'y'
339    # For latter scale changes
340    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
341    residuals.yaxis('\\rm{Residuals} ', 'normalized')
342
343    return residuals
344
345def residualsData2D(reference_data, current_data):
346    """
347    Calculate the residuals for difference of two Data2D sets
348    """
349    # temporary default values for index and weight
350    # index = None
351    weight = None
352
353    # build residuals
354    residuals = Data2D()
355    # Not for trunk the line below, instead use the line above
356    current_data.clone_without_data(len(current_data.data), residuals)
357    residuals.data = None
358    fn = current_data.data
359    gn = reference_data.data
360    en = current_data.err_data if weight is None else weight
361    residuals.data = (fn - gn) / en
362    residuals.qx_data = current_data.qx_data
363    residuals.qy_data = current_data.qy_data
364    residuals.q_data = current_data.q_data
365    residuals.err_data = numpy.ones(len(residuals.data))
366    residuals.xmin = min(residuals.qx_data)
367    residuals.xmax = max(residuals.qx_data)
368    residuals.ymin = min(residuals.qy_data)
369    residuals.ymax = max(residuals.qy_data)
370    residuals.q_data = current_data.q_data
371    residuals.mask = current_data.mask
372    residuals.scale = 'linear'
373    # check the lengths
374    if len(residuals.data) != len(residuals.q_data):
375        return None
376    return residuals
377
378def plotResiduals(reference_data, current_data):
379    """
380    Create Data1D/Data2D with residuals, ready for plotting
381    """
382    data_copy = deepcopy(current_data)
383    # Get data: data I, theory I, and data dI in order
384    method_name = current_data.__class__.__name__
385    residuals_dict = {"Data1D": residualsData1D,
386                      "Data2D": residualsData2D}
387
388    residuals = residuals_dict[method_name](reference_data, data_copy)
389
390    theory_name = str(current_data.name.split()[0])
391    residuals.name = "Residuals for " + str(theory_name) + "[" + \
392                    str(reference_data.filename) + "]"
393    residuals.title = residuals.name
394    residuals.ytransform = 'y'
395
396    # when 2 data have the same id override the 1 st plotted
397    # include the last part if keeping charts for separate models is required
398    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
399    # group_id specify on which panel to plot this data
400    group_id = reference_data.group_id
401    residuals.group_id = "res" + str(group_id)
402
403    # Symbol
404    residuals.symbol = 0
405    residuals.hide_error = False
406
407    return residuals
408
409def binary_encode(i, digits):
410    return [i >> d & 1 for d in xrange(digits)]
411
412def getWeight(data, is2d, flag=None):
413    """
414    Received flag and compute error on data.
415    :param flag: flag to transform error of data.
416    """
417    weight = None
418    if is2d:
419        dy_data = data.err_data
420        data = data.data
421    else:
422        dy_data = data.dy
423        data = data.y
424
425    if flag == 0:
426        weight = numpy.ones_like(data)
427    elif flag == 1:
428        weight = dy_data
429    elif flag == 2:
430        weight = numpy.sqrt(numpy.abs(data))
431    elif flag == 3:
432        weight = numpy.abs(data)
433    return weight
Note: See TracBrowser for help on using the repository browser.