source: sasview/src/sas/sasgui/plottools/fittings.py @ 3388337

ticket-1094-headless
Last change on this file since 3388337 was a1b8fee, checked in by andyfaff, 8 years ago

MAINT: from future import print_function

  • Property mode set to 100644
File size: 3.0 KB
RevLine 
[959eb01]1"""
2This module is used to fit a set of x,y data to a model passed to it. It is
3used to calculate the slope and intercepts for the linearized fits.  Two things
4should be noted:
5
6First, this fitting module uses SciPy's non-linear least-squares solver
7(scipy.optimize.leastsq) rather than a linear fit.  This, along with a few
8other modules, could probably be removed if we move to a linear regression approach.
9
10Second, this infrastructure does not allow for resolution smearing of the
11models.  Hence the results are not that accurate even for pinhole
12collimation of SANS but may be good for SAXS.  It is completely wrong for
13slit smeared data.
14
15"""
[a1b8fee]16from __future__ import print_function
17
[959eb01]18from scipy import optimize
19
20
class Parameter(object):
    """
    Handle to one named parameter of a model.

    Holds a reference to ``model`` and the parameter ``name``; all reads and
    writes are delegated to ``model.getParam`` / ``model.setParam``.  When an
    initial ``value`` other than ``None`` is given it is written to the model
    immediately; otherwise the model's current value is left untouched.
    """
    def __init__(self, model, name, value=None):
        self.model, self.name = model, name
        if value is None:
            # Nothing to push: keep whatever the model already holds.
            return
        self.model.setParam(self.name, value)

    def set(self, value):
        """
            Write value to the model under this parameter's name
        """
        model = self.model
        model.setParam(self.name, value)

    def __call__(self):
        """
            Read this parameter's current value back from the model
        """
        model = self.model
        return model.getParam(self.name)
43
44
def sasfit(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
    Fit a model to x,y data by weighted non-linear least squares.

    :param model: sas model object; must provide ``runXY(x)`` returning the
        model value at a single x point
    :param pars: list of parameter objects (e.g. ``Parameter``); each must
        support ``p()`` to read and ``p.set(v)`` to write its value.  Their
        current values are the starting point of the fit.
    :param x: vector of x data
    :param y: vector of y data
    :param err_y: vector of y errors, used as weights; must be non-zero for
        every point inside [qmin, qmax]
    :param qmin: inclusive lower bound on x for points used in the fit;
        ``None`` means no lower bound
    :param qmax: inclusive upper bound on x for points used in the fit;
        ``None`` means no upper bound

    :return: tuple ``(chisqr, out, cov_x)``: the sum of squared weighted
        residuals at the solution, and the fitted parameter values and
        ``cov_x`` matrix exactly as returned by ``scipy.optimize.leastsq``.

    As a side effect the parameters in ``pars`` are left set to the fitted
    values.
    """
    import numpy as np

    def in_range(xv):
        # None means "unbounded".  Comparing against None raises TypeError on
        # Python 3 (and excluded every point on Python 2).
        return ((qmin is None or xv >= qmin)
                and (qmax is None or xv <= qmax))

    # The point selection does not depend on the parameter values, so do it
    # once instead of on every residual evaluation.
    points = [(xv, yv, ev) for xv, yv, ev in zip(x, y, err_y) if in_range(xv)]

    def f(params):
        """
        Calculates the vector of weighted residuals (y - model) / err_y for
        each selected point, for a given set of input parameter values.

        :param params: sequence of parameter values, in the order of ``pars``
        :return: list of residuals
        """
        for par, value in zip(pars, params):
            par.set(value)
        return [(yv - model.runXY(xv)) / ev for xv, yv, ev in points]

    def chi2(params):
        """
        Calculates chi^2, the sum of squared weighted residuals.

        :param params: sequence of parameter values
        :return: chi^2
        """
        return sum(res * res for res in f(params))

    p = [param() for param in pars]
    out, cov_x, _info, _mesg, _ier = optimize.leastsq(f, p, full_output=1)
    # Depending on the SciPy version, leastsq may return a scalar or a
    # length-1 array for a single parameter.  Normalise to a 1-D sequence so
    # f() always hands plain values to the parameters and chi^2 is a scalar.
    chisqr = chi2(np.atleast_1d(out))
    return chisqr, out, cov_x
98
99
def calcCommandline(event):
    """
    Manual smoke test: fit a LineModel to data generated from that model.

    NOTE(review): ``event`` is assumed to carry ``cstA``/``cstB`` (the line
    parameters used to generate the data) and ``x`` (the x points); this is
    inferred from the attributes read here, so confirm it against callers.

    Prints the fit result; returns nothing.
    """
    # Testing implementation
    # Fit a Line model
    # NOTE(review): implicit relative import (Python 2 style).  Under
    # Python 3 this only works if LineModel is importable as a top-level
    # module.
    from LineModel import LineModel
    line = LineModel()
    cstA = Parameter(line, 'A', event.cstA)
    cstB = Parameter(line, 'B', event.cstB)
    # sasfit indexes y[j] and err_y[j] once per x point and divides by
    # err_y[j], so it needs matching vectors.  Evaluate the model at every x
    # and use unit weights.  The original passed err_y=0, which is neither
    # indexable nor a usable divisor.
    xs = list(event.x)
    y = [line.runXY(xv) for xv in xs]
    err_y = [1.0] * len(xs)
    # Explicit bounds so every generated point is used, independent of how
    # sasfit treats qmin/qmax=None.
    chisqr, out, cov = sasfit(line, [cstA, cstB], xs, y, err_y,
                              qmin=min(xs), qmax=max(xs))
    print("The right answer is [70.0, 1.0]")
    print(chisqr, out, cov)
Note: See TracBrowser for help on using the repository browser.