source: sasview/src/sas/qtgui/Plotting/Fittings.py @ 9909967

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 9909967 was dc5ef15, checked in by Piotr Rozyczko <rozyczko@…>, 8 years ago

Removed qtgui dependency on sasgui and wx SASVIEW-590

  • Property mode set to 100755
File size: 3.1 KB
Line 
"""
This module is used to fit a set of x,y data to a model passed to it. It is
used to calculate the slope and intercepts for the linearized fits.  Two things
should be noted:

First, this fitting module uses the nonlinear least-squares solver of SciPy
(scipy.optimize.leastsq) rather than a linear fit.  This along with a few
other modules could probably be removed if we move to a linear regression
approach.

Second, this infrastructure does not allow for resolution smearing of the
models.  Hence the results are not that accurate even for pinhole
collimation of SANS but may be good for SAXS.  It is completely wrong for
slit smeared data.

"""
16from scipy import optimize
17
18
class Parameter(object):
    """
    Class to handle model parameters - sets the parameters and their
    initial value from the model passed to it.

    A Parameter is a thin handle onto one named parameter of *model*:
    calling the Parameter reads the value through ``model.getParam`` and
    :meth:`set` writes it through ``model.setParam``.
    """
    def __init__(self, model, name, value=None):
        """
        :param model: model object providing ``setParam(name, value)`` and
            ``getParam(name)``
        :param name: name of the parameter within *model*
        :param value: optional initial value; anything other than None
            (including 0) is written to the model immediately
        """
        self.model = model
        self.name = name
        # Identity test rather than ``value == None``: equality is
        # elementwise (and ambiguous in a boolean context) for array values
        # and can be overridden by a custom __eq__.
        if value is not None:
            self.model.setParam(self.name, value)

    def set(self, value):
        """
            Set the value of the parameter on the underlying model.

            :param value: new value passed to ``model.setParam``
        """
        self.model.setParam(self.name, value)

    def __call__(self):
        """
            Return the current value of the parameter, as reported by
            ``model.getParam``.
        """
        return self.model.getParam(self.name)
41
42
def sasfit(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
    Fit function

    Fit *model* to the points (x, y), weighted by err_y, by varying the
    parameters in *pars* with scipy.optimize.leastsq.

    :param model: sas model object; must provide ``runXY(x)``
    :param pars: list of parameters; each must be callable (returning its
        current value) and provide ``set(value)``, like :class:`Parameter`
    :param x: vector of x data
    :param y: vector of y data
    :param err_y: vector of y errors (one per point; must be non-zero)
    :param qmin: inclusive lower bound on x; None means no lower bound
    :param qmax: inclusive upper bound on x; None means no upper bound
    :return: tuple (chisqr, out, cov_x): chi^2 at the best fit, the best-fit
        parameter values as returned by leastsq, and leastsq's covariance
        estimate (SciPy returns None for it when the problem is singular)
    """
    import numpy

    def in_range(xv):
        """Return True if xv lies inside the (possibly open) q-range."""
        return ((qmin is None or xv >= qmin) and
                (qmax is None or xv <= qmax))

    def f(params):
        """
        Calculates the vector of residuals for each point
        in y for a given set of input parameters.

        Side effect: pushes *params* onto the model via each parameter's
        ``set`` before evaluating it.

        :param params: list of parameter values
        :return: vector of residuals (y - model) / err_y for in-range points
        """
        for param, value in zip(pars, params):
            param.set(value)
        return [(yj - model.runXY(xj)) / ej
                for xj, yj, ej in zip(x, y, err_y) if in_range(xj)]

    def chi2(params):
        """
        Calculates chi^2

        :param params: list of parameter values
        :return: chi^2, the sum of squared residuals
        """
        return sum(r * r for r in f(params))

    p = [param() for param in pars]
    out, cov_x, _info, _mesg, _success = optimize.leastsq(f, p, full_output=1)
    # leastsq returns the best-fit values as an array. The original code
    # wrapped a single-parameter result in a list, presumably for SciPy
    # versions that returned a bare scalar. atleast_1d handles both cases,
    # so each parameter receives a plain value. Evaluating chi2 also leaves
    # the model's parameters set to the best-fit values.
    chisqr = chi2(numpy.atleast_1d(out))

    return chisqr, out, cov_x
96
97
def calcCommandline(event):
    """
    Manual smoke test: fit a LineModel to data generated from the model itself.

    :param event: object carrying ``cstA`` and ``cstB`` (line parameters
        used both to generate the data and as the starting point) and ``x``
        (the x values)
    """
    # Testing implementation
    # Fit a Line model
    from LineModel import LineModel
    line = LineModel()
    cstA = Parameter(line, 'A', event.cstA)
    cstB = Parameter(line, 'B', event.cstB)
    # Evaluate the model point by point, the same way sasfit does, so that
    # y has one entry per value in event.x.
    y = [line.runXY(xi) for xi in event.x]
    # Unit weights: sasfit indexes err_y per point and divides by it, so a
    # scalar 0 would fail to index (and would divide by zero).
    err_y = [1.0] * len(y)
    # Explicit bounds so that every point is included in the fit.
    chisqr, out, cov = sasfit(line, [cstA, cstB], event.x, y, err_y,
                              qmin=min(event.x), qmax=max(event.x))
    # NOTE(review): the expected values are really (event.cstA, event.cstB);
    # this message assumes they are 70.0 and 1.0.
    print("The right answer is [70.0, 1.0]")
    print(chisqr, out, cov)
110
Note: See TracBrowser for help on using the repository browser.