source: sasview/src/sas/sasgui/plottools/fittings.py @ 25dd9c9

Last change on this file since 25dd9c9 was ac07a3a, checked in by andyfaff, 8 years ago

MAINT: replace 'not x is None' by 'x is not None'

  • Property mode set to 100644
File size: 2.9 KB
Line 
1"""
2This module is used to fit a set of x,y data to a model passed to it. It is
3used to calculate the slope and intercepts for the linearized fits.  Two things
4should be noted:
5
First, this fitting module uses SciPy's nonlinear least-squares solver
(``scipy.optimize.leastsq``) rather than a linear fit.  This, along with a few
other modules, could probably be removed if we move to a linear regression
approach.
9
Second, this infrastructure does not allow for resolution smearing of the
models.  Hence the results are not very accurate even for pinhole
collimation of SANS, but may be adequate for SAXS.  They are completely wrong
for slit-smeared data.
14
15"""
16from scipy import optimize
17
18
class Parameter(object):
    """
    Handle to one named parameter of a model.

    The value itself lives in the model: it is written with
    ``model.setParam`` and read back with ``model.getParam``.  When an
    initial value is supplied it is written to the model on construction.
    """
    def __init__(self, model, name, value=None):
        self.model = model
        self.name = name
        # Only push a value to the model when one was actually supplied.
        if value is None:
            return
        self.model.setParam(self.name, value)

    def set(self, value):
        """
            Store *value* in the model under this parameter's name
        """
        target = self.model
        target.setParam(self.name, value)

    def __call__(self):
        """
            Read this parameter's current value back from the model
        """
        read = self.model.getParam
        return read(self.name)
41
42
def sasfit(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
    Fit function: least-squares fit of *model* to (x, y, err_y).

    Only points with ``qmin <= x <= qmax`` contribute to the fit; a bound
    left as ``None`` is treated as unbounded.

    :param model: sas model object; must provide ``runXY(x)``
    :param pars: list of parameters to fit; each must be callable (returning
        its current value) and provide ``set(value)``, like ``Parameter``
    :param x: vector of x data
    :param y: vector of y data (indexed in step with x)
    :param err_y: vector of y errors (indexed in step with x); must be
        non-zero for every point inside the fit range
    :param qmin: lower x bound of the fit range, or None for no lower bound
    :param qmax: upper x bound of the fit range, or None for no upper bound

    :return: tuple (chisqr, out, cov_x): chisqr is the sum of squared
        weighted residuals at the solution, out the best-fit parameter
        values and cov_x the covariance estimate, both as returned by
        ``optimize.leastsq``.  On return, the parameters in *pars* are set
        to the values in *out*.
    """
    import numpy as np

    # None means "no bound".  Comparing a number with None silently
    # excludes every point on Python 2 and raises TypeError on Python 3.
    lower = -float('inf') if qmin is None else qmin
    upper = float('inf') if qmax is None else qmax

    def f(params):
        """
        Calculates the vector of residuals for each point
        in y for a given set of input parameters.

        :param params: list of parameter values, in the order of *pars*
        :return: list of weighted residuals (y - model) / err_y for the
            points inside [lower, upper]
        """
        for param, value in zip(pars, params):
            param.set(value)
        return [(y[j] - model.runXY(x_j)) / err_y[j]
                for j, x_j in enumerate(x)
                if lower <= x_j <= upper]

    def chi2(params):
        """
        Calculates chi^2, the sum of squared weighted residuals.

        :param params: list of parameter values
        :return: chi^2
        """
        return sum(r * r for r in f(params))

    p = [param() for param in pars]
    out, cov_x, _info, _mesg, _ier = optimize.leastsq(f, p, full_output=1)
    # Re-evaluate at the solution.  This also sets the model parameters to
    # the best-fit values; the last point leastsq evaluated may differ.
    # atleast_1d turns either a scalar or an array result into a flat
    # sequence, so each parameter is always set to a scalar (wrapping an
    # array in a list would set a parameter to a 1-element array instead).
    chisqr = chi2(np.atleast_1d(out))
    return chisqr, out, cov_x
96
97
def calcCommandline(event):
    """
    Testing implementation: fit a Line model to data built from *event*.

    *event* must provide ``cstA`` and ``cstB`` (initial values of the line
    parameters ``A`` and ``B``) and ``x`` (a non-empty sequence of x values).
    The y data are generated from the line model itself, with unit errors,
    so the fit should reproduce the initial parameters.  Results are
    printed to stdout.
    """
    from LineModel import LineModel
    line = LineModel()
    cstA = Parameter(line, 'A', event.cstA)
    cstB = Parameter(line, 'B', event.cstB)
    # sasfit compares y[j] against model.runXY(x[j]), so build y the same
    # way, one value per x point.
    y = [line.runXY(xi) for xi in event.x]
    # sasfit indexes and divides by err_y[j]: a scalar 0 can be neither
    # indexed nor divided by, so use unit weights.
    err_y = [1.0] * len(event.x)
    # Pass the full x range explicitly so every point takes part in the fit.
    chisqr, out, cov = sasfit(line, [cstA, cstB], event.x, y, err_y,
                              qmin=min(event.x), qmax=max(event.x))
    # NOTE(review): hard-coded expectation; only matches when
    # event.cstA == 70.0 and event.cstB == 1.0.
    print("The right answer is [70.0, 1.0]")
    # Same output as the Python 2 statement `print chisqr, out, cov`.
    print("%s %s %s" % (chisqr, out, cov))
Note: See TracBrowser for help on using the repository browser.