source: sasview/src/sas/qtgui/Plotting/Fittings.py @ 46ca1f4

Last change on this file since 46ca1f4 was cee5c78, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Converted more syntax not covered by 2to3

  • Property mode set to 100644
File size: 3.0 KB
Line 
"""
This module is used to fit a set of x,y data to a model passed to it. It is
used to calculate the slope and intercepts for the linearized fits.  Two things
should be noted:

First, this fitting module uses the non-linear least-squares routine
``scipy.optimize.leastsq`` rather than a linear fit.  This along with a few
other modules could probably be removed if we move to a linear regression
approach.

Second, this infrastructure does not allow for resolution smearing of the
models.  Hence the results are not that accurate even for pinhole
collimation of SANS but may be good for SAXS.  It is completely wrong for
slit smeared data.
"""
class Parameter:
    """
    A named, adjustable parameter of a model.

    The object holds no value of its own: it reads and writes the value on
    the wrapped model through ``getParam`` / ``setParam``.  When an initial
    value is given, it is pushed into the model on construction.
    """

    def __init__(self, model, name, value=None):
        self.model, self.name = model, name
        if value is None:
            # No initial value: leave whatever the model already holds.
            return
        self.model.setParam(self.name, value)

    def set(self, value):
        """
            Store *value* on the model under this parameter's name
        """
        self.model.setParam(self.name, value)

    def __call__(self):
        """
            Fetch this parameter's current value from the model
        """
        return self.model.getParam(self.name)
38
39
def sasfit(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
    Fit a model to x, y data by non-linear least squares.

    The parameters in *pars* are varied by ``scipy.optimize.leastsq`` to
    minimise the weighted residuals ``(y - model(x)) / err_y`` over the
    points whose x lies in the inclusive range [qmin, qmax].  On return the
    model's parameters are left set to the best-fit values.

    :param model: sas model object; evaluated point by point with
        ``model.runXY(x_value)``
    :param pars: list of parameters to fit; each must be callable (returning
        its current value, used as the starting guess) and provide
        ``set(value)``, like :class:`Parameter`
    :param x: vector of x data
    :param y: vector of y data
    :param err_y: vector of y errors (must be non-zero for fitted points)
    :param qmin: inclusive lower bound on x; ``None`` means unbounded
    :param qmax: inclusive upper bound on x; ``None`` means unbounded
    :return: tuple ``(chisqr, out, cov_x)``: the sum of squared weighted
        residuals at the best fit, the best-fit parameter values and the
        ``cov_x`` matrix, both as returned by ``leastsq``
    :raises ValueError: if *pars* is empty
    """
    if not pars:
        raise ValueError("sasfit needs at least one parameter to fit")

    # Select the points inside the fit range once; they do not change between
    # optimiser iterations.  A missing bound (None) means "unbounded": in
    # Python 3 None cannot be ordered against a number.
    points = [(xv, yv, ev) for xv, yv, ev in zip(x, y, err_y)
              if (qmin is None or xv >= qmin) and (qmax is None or xv <= qmax)]

    def f(params):
        """
        Calculates the vector of weighted residuals for the selected points
        for a given set of input parameters.

        :param params: list of parameter values, in the same order as *pars*
        :return: list of residuals ``(y - model) / err_y``
        """
        for par, value in zip(pars, params):
            par.set(value)
        return [(yv - model.runXY(xv)) / ev for xv, yv, ev in points]

    def chi2(params):
        """
        Calculates chi^2 as the sum of the squared weighted residuals.

        :param params: list of parameter values
        :return: chi^2
        """
        return sum(r * r for r in f(params))

    # These imports take a long time, which is why they are here and not at
    # the top of the file.
    import numpy
    from scipy.optimize import leastsq

    p0 = [param() for param in pars]
    out, cov_x, _info, _mesg, _success = leastsq(f, p0, full_output=1)
    # Depending on the SciPy version, a single-parameter fit may come back as
    # a scalar or as a length-1 array.  atleast_1d yields a flat sequence of
    # plain values in both cases.  Wrapping a length-1 array in a list would
    # hand an array, not a number, to each parameter's set().
    chisqr = chi2(numpy.atleast_1d(out))

    return chisqr, out, cov_x
94
95
def calcCommandline(event):
    """
    Smoke-test :func:`sasfit` by fitting a line model to data it generated.

    Testing implementation only.  A ``LineModel`` is created with the
    starting values ``event.cstA`` / ``event.cstB`` and evaluated at
    ``event.x`` to produce the y data.  That data is then fitted with unit
    errors, and the result is printed.

    :param event: object providing ``cstA`` and ``cstB`` (initial values for
        the line parameters 'A' and 'B') and ``x`` (vector of x values)
    """
    # Fit a Line model
    from .LineModel import LineModel
    line = LineModel()
    cstA = Parameter(line, 'A', event.cstA)
    cstB = Parameter(line, 'B', event.cstB)
    # sasfit needs one y value per x point and evaluates the model with
    # runXY, so build the synthetic data the same way.
    y = [line.runXY(xv) for xv in event.x]
    # sasfit indexes err_y per point and divides by it, so a scalar 0 cannot
    # be used; weight every point equally instead.
    err_y = [1.0] * len(y)
    chisqr, out, cov = sasfit(line, [cstA, cstB], event.x, y, err_y)
    # NOTE(review): the data is generated from the starting values, so this
    # expectation holds only when event.cstA == 70.0 and event.cstB == 1.0.
    print("The right answer is [70.0, 1.0]")
    print(chisqr, out, cov)
108
Note: See TracBrowser for help on using the repository browser.