source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ 4a9786d8

Last change on this file since 4a9786d8 was d0dfcb2, checked in by wojciech, 7 years ago

Pylint fixes

  • Property mode set to 100644
File size: 15.0 KB
Line 
1from copy import deepcopy
2
3from PyQt4 import QtGui
4from PyQt4 import QtCore
5
6import numpy
7
8from sas.qtgui.Plotting.PlotterData import Data1D
9from sas.qtgui.Plotting.PlotterData import Data2D
10
11model_header_captions = ['Parameter', 'Value', 'Min', 'Max', 'Units']
12
13model_header_tooltips = ['Select parameter for fitting',
14                         'Enter parameter value',
15                         'Enter minimum value for parameter',
16                         'Enter maximum value for parameter',
17                         'Unit of the parameter']
18
19poly_header_captions = ['Parameter', 'PD[ratio]', 'Min', 'Max', 'Npts', 'Nsigs',
20                        'Function', 'Filename']
21
22poly_header_tooltips = ['Select parameter for fitting',
23                        'Enter polydispersity ratio (STD/mean). '
24                        'STD: standard deviation from the mean value',
25                        'Enter minimum value for parameter',
26                        'Enter maximum value for parameter',
27                        'Enter number of points for parameter',
28                        'Enter number of sigmas parameter',
29                        'Select distribution function',
30                        'Select filename with user-definable distribution']
31
32error_tooltip = 'Error value for fitted parameter'
33header_error_caption = 'Error'
34
35def replaceShellName(param_name, value):
36    """
37    Updates parameter name from <param_name>[n_shell] to <param_name>value
38    """
39    assert '[' in param_name
40    return param_name[:param_name.index('[')]+str(value)
41
42def getIterParams(model):
43    """
44    Returns a list of all multi-shell parameters in 'model'
45    """
46    return list(filter(lambda par: "[" in par.name, model.iq_parameters))
47
48def getMultiplicity(model):
49    """
50    Finds out if 'model' has multishell parameters.
51    If so, returns the name of the counter parameter and the number of shells
52    """
53    iter_params = getIterParams(model)
54    param_name = ""
55    param_length = 0
56    if iter_params:
57        param_length = iter_params[0].length
58        param_name = iter_params[0].length_control
59        if param_name is None and '[' in iter_params[0].name:
60            param_name = iter_params[0].name[:iter_params[0].name.index('[')]
61    return (param_name, param_length)
62
63def addParametersToModel(parameters, kernel_module, is2D):
64    """
65    Update local ModelModel with sasmodel parameters
66    """
67    multishell_parameters = getIterParams(parameters)
68    multishell_param_name, _ = getMultiplicity(parameters)
69    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
70    item = []
71    for param in params:
72        # don't include shell parameters
73        if param.name == multishell_param_name:
74            continue
75        # Modify parameter name from <param>[n] to <param>1
76        item_name = param.name
77        if param in multishell_parameters:
78            continue
79        #    item_name = replaceShellName(param.name, 1)
80
81        item1 = QtGui.QStandardItem(item_name)
82        item1.setCheckable(True)
83        item1.setEditable(False)
84        # item_err = QtGui.QStandardItem()
85        # check for polydisp params
86        if param.polydisperse:
87            poly_item = QtGui.QStandardItem("Polydispersity")
88            poly_item.setEditable(False)
89            item1_1 = QtGui.QStandardItem("Distribution")
90            item1_1.setEditable(False)
91            # Find param in volume_params
92            for p in parameters.form_volume_parameters:
93                if p.name != param.name:
94                    continue
95                width = kernel_module.getParam(p.name+'.width')
96                type = kernel_module.getParam(p.name+'.type')
97
98                item1_2 = QtGui.QStandardItem(str(width))
99                item1_2.setEditable(False)
100                item1_3 = QtGui.QStandardItem()
101                item1_3.setEditable(False)
102                item1_4 = QtGui.QStandardItem()
103                item1_4.setEditable(False)
104                item1_5 = QtGui.QStandardItem(type)
105                item1_5.setEditable(False)
106                poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
107                break
108            # Add the polydisp item as a child
109            item1.appendRow([poly_item])
110        # Param values
111        item2 = QtGui.QStandardItem(str(param.default))
112        # TODO: the error column.
113        # Either add a proxy model or a custom view delegate
114        #item_err = QtGui.QStandardItem()
115        item3 = QtGui.QStandardItem(str(param.limits[0]))
116        item4 = QtGui.QStandardItem(str(param.limits[1]))
117        item5 = QtGui.QStandardItem(param.units)
118        item5.setEditable(False)
119        item.append([item1, item2, item3, item4, item5])
120    return item
121
122def addSimpleParametersToModel(parameters, is2D):
123    """
124    Update local ModelModel with sasmodel parameters
125    """
126    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
127    item = []
128    for param in params:
129        # Create the top level, checkable item
130        item_name = param.name
131        item1 = QtGui.QStandardItem(item_name)
132        item1.setCheckable(True)
133        item1.setEditable(False)
134        # Param values
135        # TODO: add delegate for validation of cells
136        item2 = QtGui.QStandardItem(str(param.default))
137        item4 = QtGui.QStandardItem(str(param.limits[0]))
138        item5 = QtGui.QStandardItem(str(param.limits[1]))
139        item6 = QtGui.QStandardItem(param.units)
140        item6.setEditable(False)
141        item.append([item1, item2, item4, item5, item6])
142    return item
143
144def addCheckedListToModel(model, param_list):
145    """
146    Add a QItem to model. Makes the QItem checkable
147    """
148    assert isinstance(model, QtGui.QStandardItemModel)
149    item_list = [QtGui.QStandardItem(item) for item in param_list]
150    item_list[0].setCheckable(True)
151    model.appendRow(item_list)
152
153def addHeadersToModel(model):
154    """
155    Adds predefined headers to the model
156    """
157    for i, item in enumerate(model_header_captions):
158        model.setHeaderData(i, QtCore.Qt.Horizontal, QtCore.QVariant(item))
159
160    model.header_tooltips = model_header_tooltips
161
162def addErrorHeadersToModel(model):
163    """
164    Adds predefined headers to the model
165    """
166    model_header_error_captions = model_header_captions
167    model_header_error_captions.insert(2, header_error_caption)
168    for i, item in enumerate(model_header_error_captions):
169        model.setHeaderData(i, QtCore.Qt.Horizontal, QtCore.QVariant(item))
170
171    model_header_error_tooltips = model_header_tooltips
172    model_header_error_tooltips.insert(2, error_tooltip)
173    model.header_tooltips = model_header_error_tooltips
174
175def addPolyHeadersToModel(model):
176    """
177    Adds predefined headers to the model
178    """
179    for i, item in enumerate(poly_header_captions):
180        model.setHeaderData(i, QtCore.Qt.Horizontal, QtCore.QVariant(item))
181
182    model.header_tooltips = poly_header_tooltips
183
184
185def addErrorPolyHeadersToModel(model):
186    """
187    Adds predefined headers to the model
188    """
189    poly_header_error_captions = poly_header_captions
190    poly_header_error_captions.insert(2, header_error_caption)
191    for i, item in enumerate(poly_header_error_captions):
192        model.setHeaderData(i, QtCore.Qt.Horizontal, QtCore.QVariant(item))
193
194    poly_header_error_tooltips = poly_header_tooltips
195    poly_header_error_tooltips.insert(2, error_tooltip)
196    model.header_tooltips = poly_header_error_tooltips
197
198def addShellsToModel(parameters, model, index):
199    """
200    Find out multishell parameters and update the model with the requested number of them
201    """
202    multishell_parameters = getIterParams(parameters)
203
204    for i in xrange(index):
205        for par in multishell_parameters:
206            # Create the name: <param>[<i>], e.g. "sld1" for parameter "sld[n]"
207            param_name = replaceShellName(par.name, i+1)
208            item1 = QtGui.QStandardItem(param_name)
209            item1.setCheckable(True)
210            # check for polydisp params
211            if par.polydisperse:
212                poly_item = QtGui.QStandardItem("Polydispersity")
213                item1_1 = QtGui.QStandardItem("Distribution")
214                # Find param in volume_params
215                for p in parameters.form_volume_parameters:
216                    if p.name != par.name:
217                        continue
218                    item1_2 = QtGui.QStandardItem(str(p.default))
219                    item1_3 = QtGui.QStandardItem(str(p.limits[0]))
220                    item1_4 = QtGui.QStandardItem(str(p.limits[1]))
221                    item1_5 = QtGui.QStandardItem(p.units)
222                    poly_item.appendRow([item1_1, item1_2, item1_3, item1_4, item1_5])
223                    break
224                item1.appendRow([poly_item])
225
226            item2 = QtGui.QStandardItem(str(par.default))
227            item3 = QtGui.QStandardItem(str(par.limits[0]))
228            item4 = QtGui.QStandardItem(str(par.limits[1]))
229            item5 = QtGui.QStandardItem(par.units)
230            model.appendRow([item1, item2, item3, item4, item5])
231
232def calculateChi2(reference_data, current_data):
233    """
234    Calculate Chi2 value between two sets of data
235    """
236
237    # WEIGHING INPUT
238    #from sas.sasgui.perspectives.fitting.utils import get_weight
239    #flag = self.get_weight_flag()
240    #weight = get_weight(data=self.data, is2d=self._is_2D(), flag=flag)
241    chisqr = None
242    if reference_data is None:
243        return chisqr
244
245    # temporary default values for index and weight
246    index = None
247    weight = None
248
249    # Get data: data I, theory I, and data dI in order
250    if isinstance(reference_data, Data2D):
251        if index is None:
252            index = numpy.ones(len(current_data.data), dtype=bool)
253        if weight is not None:
254            current_data.err_data = weight
255        # get rid of zero error points
256        index = index & (current_data.err_data != 0)
257        index = index & (numpy.isfinite(current_data.data))
258        fn = current_data.data[index]
259        gn = reference_data.data[index]
260        en = current_data.err_data[index]
261    else:
262        # 1 d theory from model_thread is only in the range of index
263        if index is None:
264            index = numpy.ones(len(current_data.y), dtype=bool)
265        if weight is not None:
266            current_data.dy = weight
267        if current_data.dy is None or current_data.dy == []:
268            dy = numpy.ones(len(current_data.y))
269        else:
270            ## Set consistently w/AbstractFitengine:
271            # But this should be corrected later.
272            dy = deepcopy(current_data.dy)
273            dy[dy == 0] = 1
274        fn = current_data.y[index]
275        gn = reference_data.y
276        en = dy[index]
277    # Calculate the residual
278    try:
279        res = (fn - gn) / en
280    except ValueError:
281        #print "Chi2 calculations: Unmatched lengths %s, %s, %s" % (len(fn), len(gn), len(en))
282        return None
283
284    residuals = res[numpy.isfinite(res)]
285    chisqr = numpy.average(residuals * residuals)
286
287    return chisqr
288
289def residualsData1D(reference_data, current_data):
290    """
291    Calculate the residuals for difference of two Data1D sets
292    """
293    # temporary default values for index and weight
294    index = None
295    weight = None
296
297    # 1d theory from model_thread is only in the range of index
298    if current_data.dy is None or current_data.dy == []:
299        dy = numpy.ones(len(current_data.y))
300    else:
301        dy = weight if weight is not None else numpy.ones(len(current_data.y))
302        dy[dy == 0] = 1
303    fn = current_data.y[index][0]
304    gn = reference_data.y
305    en = dy[index][0]
306    # build residuals
307    residuals = Data1D()
308    if len(fn) == len(gn):
309        y = (fn - gn)/en
310        residuals.y = -y
311    else:
312        # TODO: fix case where applying new data from file on top of existing model data
313        try:
314            y = (fn - gn[index][0]) / en
315            residuals.y = y
316        except ValueError:
317            # value errors may show up every once in a while for malformed columns,
318            # just reuse what's there already
319            pass
320
321    residuals.x = current_data.x[index][0]
322    residuals.dy = numpy.ones(len(residuals.y))
323    residuals.dx = None
324    residuals.dxl = None
325    residuals.dxw = None
326    residuals.ytransform = 'y'
327    # For latter scale changes
328    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
329    residuals.yaxis('\\rm{Residuals} ', 'normalized')
330
331    return residuals
332
333def residualsData2D(reference_data, current_data):
334    """
335    Calculate the residuals for difference of two Data2D sets
336    """
337    # temporary default values for index and weight
338    # index = None
339    weight = None
340
341    # build residuals
342    residuals = Data2D()
343    # Not for trunk the line below, instead use the line above
344    current_data.clone_without_data(len(current_data.data), residuals)
345    residuals.data = None
346    fn = current_data.data
347    gn = reference_data.data
348    en = current_data.err_data if weight is None else weight
349    residuals.data = (fn - gn) / en
350    residuals.qx_data = current_data.qx_data
351    residuals.qy_data = current_data.qy_data
352    residuals.q_data = current_data.q_data
353    residuals.err_data = numpy.ones(len(residuals.data))
354    residuals.xmin = min(residuals.qx_data)
355    residuals.xmax = max(residuals.qx_data)
356    residuals.ymin = min(residuals.qy_data)
357    residuals.ymax = max(residuals.qy_data)
358    residuals.q_data = current_data.q_data
359    residuals.mask = current_data.mask
360    residuals.scale = 'linear'
361    # check the lengths
362    if len(residuals.data) != len(residuals.q_data):
363        return None
364    return residuals
365
366def plotResiduals(reference_data, current_data):
367    """
368    Create Data1D/Data2D with residuals, ready for plotting
369    """
370    data_copy = deepcopy(current_data)
371    # Get data: data I, theory I, and data dI in order
372    method_name = current_data.__class__.__name__
373    residuals_dict = {"Data1D": residualsData1D,
374                      "Data2D": residualsData2D}
375
376    residuals = residuals_dict[method_name](reference_data, data_copy)
377
378    theory_name = str(current_data.name.split()[0])
379    residuals.name = "Residuals for " + str(theory_name) + "[" + \
380                    str(reference_data.filename) + "]"
381    residuals.title = residuals.name
382    residuals.ytransform = 'y'
383
384    # when 2 data have the same id override the 1 st plotted
385    # include the last part if keeping charts for separate models is required
386    residuals.id = "res" + str(reference_data.id) # + str(theory_name)
387    # group_id specify on which panel to plot this data
388    group_id = reference_data.group_id
389    residuals.group_id = "res" + str(group_id)
390
391    # Symbol
392    residuals.symbol = 0
393    residuals.hide_error = False
394
395    return residuals
396
397def binary_encode(i, digits):
398    return [i >> d & 1 for d in xrange(digits)]
399
400def getWeight(data, is2d, flag=None):
401    """
402    Received flag and compute error on data.
403    :param flag: flag to transform error of data.
404    """
405    weight = None
406    if is2d:
407        dy_data = data.err_data
408        data = data.data
409    else:
410        dy_data = data.dy
411        data = data.y
412
413    if flag == 0:
414        weight = numpy.ones_like(data)
415    elif flag == 1:
416        weight = dy_data
417    elif flag == 2:
418        weight = numpy.sqrt(numpy.abs(data))
419    elif flag == 3:
420        weight = numpy.abs(data)
421    return weight
Note: See TracBrowser for help on using the repository browser.