source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @ d4dac80

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since d4dac80 was d4dac80, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

Merge branch 'ESS_GUI' into ESS_GUI_better_batch

  • Property mode set to 100644
File size: 15.4 KB
RevLine 
[fde5bcd]1import copy
[1bc27f1]2
[4992ff2]3from PyQt5 import QtCore
4from PyQt5 import QtGui
5from PyQt5 import QtWidgets
[4d457df]6
[6fd4e36]7import numpy
8
[dc5ef15]9from sas.qtgui.Plotting.PlotterData import Data1D
10from sas.qtgui.Plotting.PlotterData import Data2D
[6fd4e36]11
# Column captions and matching tooltips for the main parameter table.
model_header_captions = [
    'Parameter',
    'Value',
    'Min',
    'Max',
    'Units',
]

model_header_tooltips = [
    'Select parameter for fitting',
    'Enter parameter value',
    'Enter minimum value for parameter',
    'Enter maximum value for parameter',
    'Unit of the parameter',
]

# Column captions and matching tooltips for the polydispersity table.
poly_header_captions = [
    'Parameter',
    'PD[ratio]',
    'Min',
    'Max',
    'Npts',
    'Nsigs',
    'Function',
    'Filename',
]

poly_header_tooltips = [
    'Select parameter for fitting',
    # Adjacent literals: joined into one tooltip string.
    'Enter polydispersity ratio (STD/mean). '
    'STD: standard deviation from the mean value',
    'Enter minimum value for parameter',
    'Enter maximum value for parameter',
    'Enter number of points for parameter',
    'Enter number of sigmas parameter',
    'Select distribution function',
    'Select filename with user-definable distribution',
]

# Caption/tooltip of the extra column that the error-header helpers insert
# at index 2 when fit errors are shown.
error_tooltip = 'Error value for fitted parameter'
header_error_caption = 'Error'
35
def replaceShellName(param_name, value):
    """
    Turn a multi-shell parameter name into a concrete per-shell name.

    Everything from the first '[' onward is dropped and ``str(value)`` is
    appended, e.g. ``replaceShellName("sld[n]", 2)`` gives ``"sld2"``.
    """
    assert '[' in param_name
    stem, _, _ = param_name.partition('[')
    return stem + str(value)
42
def getIterParams(model):
    """
    Return the multi-shell parameters of ``model``.

    A parameter counts as multi-shell when its name contains '[' (for
    example ``sld[n]``). Only ``model.iq_parameters`` is searched.

    :param model: object exposing an ``iq_parameters`` sequence whose items
        have a ``name`` attribute
    :return: new list of the matching parameters, in their original order
    """
    return [par for par in model.iq_parameters if "[" in par.name]
[4d457df]48
def getMultiplicity(model):
    """
    Report the shell-count controller of a multi-shell model.

    :return: tuple ``(name, length)``. ``name`` is the ``length_control`` of
        the first multi-shell parameter, or, when that is None, the part of
        its name before '['. ``length`` is that parameter's ``length``.
        ``("", 0)`` when the model has no multi-shell parameters.
    """
    shell_params = getIterParams(model)
    if not shell_params:
        return ("", 0)

    first = shell_params[0]
    shell_count = first.length
    control_name = first.length_control
    if control_name is None and '[' in first.name:
        control_name = first.name[:first.name.index('[')]
    return (control_name, shell_count)
[4d457df]63
def addParametersToModel(parameters, kernel_module, is2D):
    """
    Build table rows for a model's ordinary (non-shell) parameters.

    The shell-count parameter and every per-shell ``name[n]`` parameter are
    skipped; ``addShellsToModel`` creates the per-shell rows.

    :param parameters: model parameter description exposing
        ``iq_parameters``, ``iqxy_parameters`` and ``form_volume_parameters``
    :param kernel_module: model kernel, queried with ``getParam`` for the
        current polydispersity width and distribution type
    :param is2D: use ``iqxy_parameters`` when True, ``iq_parameters`` otherwise
    :return: list of rows ``[name, value, min, max, units]`` of
        QStandardItem objects; polydisperse parameters carry a
        "Polydispersity" child item
    """
    multishell_parameters = getIterParams(parameters)
    multishell_param_name, _ = getMultiplicity(parameters)
    params = parameters.iqxy_parameters if is2D else parameters.iq_parameters
    rows = []
    for param in params:
        # Shell parameters (the counter and "<name>[n]" entries) are not
        # listed here.
        if param.name == multishell_param_name or param in multishell_parameters:
            continue

        item1 = QtGui.QStandardItem(param.name)
        item1.setCheckable(True)
        item1.setEditable(False)

        if param.polydisperse:
            poly_item = QtGui.QStandardItem("Polydispersity")
            poly_item.setEditable(False)
            # Only form-volume parameters get a distribution row.
            is_volume_param = any(p.name == param.name
                                  for p in parameters.form_volume_parameters)
            if is_volume_param:
                width = kernel_module.getParam(param.name + '.width')
                # Named ``dist_type`` so the builtin ``type`` is not shadowed.
                dist_type = kernel_module.getParam(param.name + '.type')
                # label, width, two empty placeholder cells, distribution type
                distribution_row = [QtGui.QStandardItem("Distribution"),
                                    QtGui.QStandardItem(str(width)),
                                    QtGui.QStandardItem(),
                                    QtGui.QStandardItem(),
                                    QtGui.QStandardItem(dist_type)]
                for cell in distribution_row:
                    cell.setEditable(False)
                poly_item.appendRow(distribution_row)
            # Add the polydisp item as a child
            item1.appendRow([poly_item])

        # Param values
        item2 = QtGui.QStandardItem(str(param.default))
        # TODO: the error column.
        # Either add a proxy model or a custom view delegate
        item3 = QtGui.QStandardItem(str(param.limits[0]))
        item4 = QtGui.QStandardItem(str(param.limits[1]))
        item5 = QtGui.QStandardItem(param.units)
        item5.setEditable(False)
        rows.append([item1, item2, item3, item4, item5])
    return rows
[4d457df]122
def addSimpleParametersToModel(parameters, is2D):
    """
    Create one table row per model parameter, with no shell or
    polydispersity handling.

    :param parameters: object exposing ``iq_parameters`` and
        ``iqxy_parameters`` sequences
    :param is2D: pick ``iqxy_parameters`` when True, ``iq_parameters`` otherwise
    :return: list of rows; each row is
        ``[name, default value, lower limit, upper limit, units]``
        as QStandardItem objects
    """
    source = parameters.iqxy_parameters if is2D else parameters.iq_parameters

    def make_row(param):
        # Name cell: checkable (selects the parameter for fitting), read-only.
        name_cell = QtGui.QStandardItem(param.name)
        name_cell.setCheckable(True)
        name_cell.setEditable(False)
        # TODO: add delegate for validation of cells
        value_cell = QtGui.QStandardItem(str(param.default))
        lower_cell = QtGui.QStandardItem(str(param.limits[0]))
        upper_cell = QtGui.QStandardItem(str(param.limits[1]))
        # Units are shown but cannot be edited.
        units_cell = QtGui.QStandardItem(param.units)
        units_cell.setEditable(False)
        return [name_cell, value_cell, lower_cell, upper_cell, units_cell]

    return [make_row(param) for param in source]
[4d457df]144
def addCheckedListToModel(model, param_list):
    """
    Append one row of text cells to ``model``; the first cell is checkable.

    :param model: QStandardItemModel receiving the row
    :param param_list: sequence of strings, one per column; must be non-empty
        (the first cell is accessed unconditionally)
    """
    assert isinstance(model, QtGui.QStandardItemModel)
    cells = list(map(QtGui.QStandardItem, param_list))
    cells[0].setCheckable(True)
    model.appendRow(cells)
153
def addHeadersToModel(model):
    """
    Label the columns of a plain parameter table.

    Sets the horizontal header text from ``model_header_captions`` and stores
    a private copy of ``model_header_tooltips`` on ``model.header_tooltips``.
    """
    horizontal = QtCore.Qt.Horizontal
    for column, caption in enumerate(model_header_captions):
        model.setHeaderData(column, horizontal, caption)

    # Copy, so later edits on the model never touch the module constant.
    model.header_tooltips = copy.copy(model_header_tooltips)
[4d457df]162
def addErrorHeadersToModel(model):
    """
    Label the columns of a parameter table that also shows fit errors.

    Like the plain parameter headers, but with a ``header_error_caption``
    column (tooltip ``error_tooltip``) placed at index 2, right after the
    value column. ``model.header_tooltips`` receives a new list.
    """
    captions = (model_header_captions[:2] + [header_error_caption]
                + model_header_captions[2:])
    for column, caption in enumerate(captions):
        model.setHeaderData(column, QtCore.Qt.Horizontal, caption)

    # Slicing builds a fresh list, so it can be stored directly.
    model.header_tooltips = (model_header_tooltips[:2] + [error_tooltip]
                             + model_header_tooltips[2:])
[a95c44b]175
[4d457df]176def addPolyHeadersToModel(model):
177    """
178    Adds predefined headers to the model
179    """
[f54ce30]180    for i, item in enumerate(poly_header_captions):
[b3e8629]181        model.setHeaderData(i, QtCore.Qt.Horizontal, item)
[f54ce30]182
[fde5bcd]183    model.header_tooltips = copy.copy(poly_header_tooltips)
[4d457df]184
[a95c44b]185
def addErrorPolyHeadersToModel(model):
    """
    Label the columns of a polydispersity table that also shows fit errors.

    Like the plain polydispersity headers, but with a ``header_error_caption``
    column (tooltip ``error_tooltip``) placed at index 2, right after the
    polydispersity-ratio column. ``model.header_tooltips`` receives a new list.
    """
    captions = (poly_header_captions[:2] + [header_error_caption]
                + poly_header_captions[2:])
    for column, caption in enumerate(captions):
        model.setHeaderData(column, QtCore.Qt.Horizontal, caption)

    # Slicing builds a fresh list, so it can be stored directly.
    model.header_tooltips = (poly_header_tooltips[:2] + [error_tooltip]
                             + poly_header_tooltips[2:])
[a95c44b]198
def addShellsToModel(parameters, model, index):
    """
    Append rows for ``index`` shells of every multi-shell parameter to ``model``.

    Rows are grouped by shell: all multi-shell parameters for shell 1, then
    shell 2, and so on. A parameter named ``sld[n]`` yields ``sld1``,
    ``sld2``, ...

    Names, units and polydispersity cells are read-only, matching the rows
    built by addParametersToModel and addSimpleParametersToModel.

    :param parameters: model parameter description; multi-shell parameters
        are taken from it via ``getIterParams`` and polydispersity defaults
        from its ``form_volume_parameters``
    :param model: QStandardItemModel receiving the rows
    :param index: number of shells to add
    """
    multishell_parameters = getIterParams(parameters)

    for shell in range(1, index + 1):
        for par in multishell_parameters:
            # Create the name: <param><shell>, e.g. "sld1" for parameter "sld[n]"
            item1 = QtGui.QStandardItem(replaceShellName(par.name, shell))
            item1.setCheckable(True)
            item1.setEditable(False)
            # check for polydisp params
            if par.polydisperse:
                poly_item = QtGui.QStandardItem("Polydispersity")
                poly_item.setEditable(False)
                # Find param in volume_params
                for p in parameters.form_volume_parameters:
                    if p.name != par.name:
                        continue
                    poly_row = [QtGui.QStandardItem("Distribution"),
                                QtGui.QStandardItem(str(p.default)),
                                QtGui.QStandardItem(str(p.limits[0])),
                                QtGui.QStandardItem(str(p.limits[1])),
                                QtGui.QStandardItem(p.units)]
                    for cell in poly_row:
                        cell.setEditable(False)
                    poly_item.appendRow(poly_row)
                    break
                item1.appendRow([poly_item])

            item2 = QtGui.QStandardItem(str(par.default))
            item3 = QtGui.QStandardItem(str(par.limits[0]))
            item4 = QtGui.QStandardItem(str(par.limits[1]))
            item5 = QtGui.QStandardItem(par.units)
            item5.setEditable(False)
            model.appendRow([item1, item2, item3, item4, item5])
232
def calculateChi2(reference_data, current_data):
    """
    Calculate the chi-squared value between two data sets.

    Point-wise residuals ``(current - reference) / error`` are formed,
    non-finite residuals are discarded, and the mean of the squared
    residuals is returned.

    :param reference_data: Data1D/Data2D compared against; may be None
    :param current_data: data of the same kind as ``reference_data``; its
        errors (``err_data`` in 2D, ``dy`` in 1D) weight the residuals
    :return: the chi-squared value; None when ``reference_data`` is None or
        the arrays cannot be combined (mismatched lengths); nan when no
        finite residual remains
    """
    if reference_data is None:
        return None

    if isinstance(reference_data, Data2D):
        # Drop zero-error points (division by zero) and non-finite theory.
        index = (current_data.err_data != 0) & numpy.isfinite(current_data.data)
        fn = current_data.data[index]
        gn = reference_data.data[index]
        en = current_data.err_data[index]
    else:
        dy = current_data.dy
        # ``len`` rather than ``dy == []``: on a numpy array that comparison
        # is elementwise and fails for mismatched shapes.
        if dy is None or len(dy) == 0:
            en = numpy.ones(len(current_data.y))
        else:
            # Float copy, so the caller's dy is untouched and boolean masking
            # also works when dy is a plain list.
            en = numpy.array(dy, dtype=float)
            ## Set consistently w/AbstractFitengine:
            # But this should be corrected later.
            en[en == 0] = 1
        fn = numpy.asarray(current_data.y)
        gn = reference_data.y

    # Calculate the residual
    try:
        res = (fn - gn) / en
    except ValueError:
        # Arrays of different lengths cannot be broadcast together.
        return None

    residuals = res[numpy.isfinite(res)]
    return numpy.average(residuals * residuals)
289
def residualsData1D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data1D sets.

    Residuals are ``(reference - current) / error``. No weighting is applied
    yet, so ``error`` is 1 at every point and the residuals are the plain
    differences.

    :param reference_data: Data1D the current data is compared with
    :param current_data: Data1D whose ``x`` becomes the residuals' ``x``
    :return: new Data1D holding the residuals in ``y`` with unit ``dy``
    """
    fn = numpy.asarray(current_data.y)
    gn = numpy.asarray(reference_data.y)
    # Weighting is not implemented yet: every point gets unit error.
    en = numpy.ones(len(fn))

    # build residuals
    residuals = Data1D()
    try:
        # A single sign convention, (reference - current) / error, is used
        # whether or not the lengths match.
        # TODO: fix case where applying new data from file on top of existing model data
        residuals.y = (gn - fn) / en
    except ValueError:
        # value errors may show up every once in a while for malformed columns,
        # just reuse what's there already
        pass

    residuals.x = numpy.asarray(current_data.x)
    residuals.dy = numpy.ones(len(residuals.y))
    residuals.dx = None
    residuals.dxl = None
    residuals.dxw = None
    residuals.ytransform = 'y'
    # For later scale changes
    residuals.xaxis('\\rm{Q} ', 'A^{-1}')
    residuals.yaxis('\\rm{Residuals} ', 'normalized')

    return residuals
333
def residualsData2D(reference_data, current_data):
    """
    Calculate the residuals for difference of two Data2D sets.

    Each residual is ``(current - reference) / current.err_data``. The result
    is prepared with ``current_data.clone_without_data`` and takes its q
    arrays and mask from ``current_data``.

    :return: Data2D with the residuals in ``data`` and unit ``err_data``, or
        None when the number of residuals differs from the number of q points
    """
    residuals = Data2D()
    # Let current_data fill in the new object; the data arrays are set below.
    current_data.clone_without_data(len(current_data.data), residuals)

    theory = current_data.data
    measured = reference_data.data
    errors = current_data.err_data
    residuals.data = (theory - measured) / errors

    # Same q-grid and mask as current_data.
    residuals.qx_data = current_data.qx_data
    residuals.qy_data = current_data.qy_data
    residuals.q_data = current_data.q_data
    residuals.err_data = numpy.ones(len(residuals.data))
    residuals.xmin, residuals.xmax = min(residuals.qx_data), max(residuals.qx_data)
    residuals.ymin, residuals.ymax = min(residuals.qy_data), max(residuals.qy_data)
    residuals.mask = current_data.mask
    residuals.scale = 'linear'

    if len(residuals.data) != len(residuals.q_data):
        return None
    return residuals
366
def plotResiduals(reference_data, current_data):
    """
    Create Data1D/Data2D with residuals, ready for plotting.

    :param reference_data: data set the residuals refer to; its ``filename``,
        ``id`` and ``group_id`` name the plot and pick its panel
    :param current_data: Data1D or Data2D; its class name selects the
        residual calculation, which runs on a deep copy of it
    :return: residuals object with name, ids and plotting attributes set,
        or None when the residuals cannot be computed (incompatible arrays)
    """
    data_copy = copy.deepcopy(current_data)
    residuals_dict = {"Data1D": residualsData1D,
                      "Data2D": residualsData2D}
    method = residuals_dict[current_data.__class__.__name__]

    try:
        residuals = method(reference_data, data_copy)
    except ValueError:
        # Data and theory arrays could not be combined (mismatched shapes).
        return None
    if residuals is None:
        # residualsData2D returns None when the lengths do not match.
        return None

    theory_name = current_data.name.split()[0]
    residuals.name = "Residuals for " + str(theory_name) + "[" + \
                    str(reference_data.filename) + "]"
    residuals.title = residuals.name
    residuals.ytransform = 'y'

    # When two data sets have the same id, the later one replaces the first
    # plotted. Include the last part if keeping charts for separate models
    # is required.
    residuals.id = "res" + str(reference_data.id)  # + str(theory_name)
    # group_id specifies on which panel to plot this data
    residuals.group_id = "res" + str(reference_data.group_id)

    # Symbol
    residuals.symbol = 0
    residuals.hide_error = False

    return residuals
397
def binary_encode(i, digits):
    """
    Return the lowest ``digits`` bits of integer ``i``, least significant first.

    Example: ``binary_encode(6, 4)`` -> ``[0, 1, 1, 0]``.
    """
    return [(i >> bit) & 1 for bit in range(digits)]
[6fd4e36]400
def getWeight(data, is2d, flag=None):
    """
    Compute the weighting array for ``data`` selected by ``flag``.

    :param data: Data2D-like (``err_data``/``data``) when ``is2d`` is True,
        otherwise Data1D-like (``dy``/``y``)
    :param is2d: selects which pair of attributes is read
    :param flag: 0 -> ones shaped like the intensities; 1 -> the stored
        errors; 2 -> sqrt(|intensity|); 3 -> |intensity|
    :return: the weight array, or None for any other ``flag`` (including None)
    """
    if is2d:
        errors, values = data.err_data, data.data
    else:
        errors, values = data.dy, data.y

    if flag == 0:
        return numpy.ones_like(values)
    if flag == 1:
        return errors
    if flag == 2:
        return numpy.sqrt(numpy.abs(values))
    if flag == 3:
        return numpy.abs(values)
    return None
[d4dac80]423
def updateKernelWithResults(kernel, results):
    """
    Return a deep copy of ``kernel`` with fitted values applied.

    :param kernel: model kernel; only ``setParam(name, value)`` is called,
        and only on the copy
    :param results: dict mapping parameter name to a sequence whose first
        item is the new value (+/-inf values are passed through as-is)
    :return: the updated copy; ``kernel`` itself is not modified here
    """
    assert isinstance(results, dict)
    updated = copy.deepcopy(kernel)
    for name, fitted in results.items():
        # Only the value (first entry) is applied; any further entries are ignored.
        updated.setParam(name, fitted[0])
    return updated
437
438
Note: See TracBrowser for help on using the repository browser.