source: sasview/src/sas/plottools/fittings.py @ ff50c51

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since ff50c51 was 2df0b74, checked in by Mathieu Doucet <doucetm@…>, 10 years ago

pylint fixes

  • Property mode set to 100644
File size: 2.4 KB
Line 
1"""
2"""
3from scipy import optimize
4
5
class Parameter(object):
    """
    Handle on one named parameter of a model.

    The handle only forwards to the model: writing goes through
    ``model.setParam(name, value)`` and reading through
    ``model.getParam(name)``.
    """
    def __init__(self, model, name, value=None):
        """
        :param model: model object providing ``setParam`` and ``getParam``
        :param name: name of the parameter on *model*
        :param value: optional initial value; when it is not None it is
            written to the model immediately
        """
        self.model = model
        self.name = name
        # Identity test rather than ``value == None``: an object's __eq__
        # may claim equality with None or compare elementwise (numpy
        # arrays), which would wrongly skip or break the assignment.
        if value is not None:
            self.model.setParam(self.name, value)

    def set(self, value):
        """
        Set the value of the parameter on the model.
        """
        self.model.setParam(self.name, value)

    def __call__(self):
        """
        Return the current value of the parameter as reported by the model.
        """
        return self.model.getParam(self.name)
27
28
def sasfit(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
    Least-squares fit of a model to (x, y) data with scipy's ``leastsq``.

    :param model: sas model object; ``model.runXY(xi)`` is evaluated at
        every selected data point
    :param pars: list of parameter handles (e.g. ``Parameter``); each must
        support ``p()`` to read its current value, used as the starting
        point, and ``p.set(value)`` to write a trial value
    :param x: vector of x data
    :param y: vector of y data
    :param err_y: vector of y errors; each residual is divided by the
        matching error, so entries must be non-zero
    :param qmin: inclusive lower bound on x; None means no lower bound
    :param qmax: inclusive upper bound on x; None means no upper bound
    :return: tuple ``(chisqr, out, cov_x)``: the sum of squared weighted
        residuals at the solution, the solution returned by ``leastsq``
        and the ``cov_x`` output of ``leastsq`` (may be None, see scipy)

    Side effect: on return the parameters in *pars* are left set to the
    fitted values, because chi^2 is evaluated at the solution.
    """
    import numpy

    # The selection of points inside [qmin, qmax] does not depend on the
    # parameter values, so compute it once instead of on every residual
    # evaluation. None disables the corresponding bound (previously a
    # None qmax excluded every point).
    selected = [j for j, x_j in enumerate(x)
                if (qmin is None or x_j >= qmin)
                and (qmax is None or x_j <= qmax)]

    def f(params):
        """
        Calculates the vector of weighted residuals (y - model) / err_y
        over the selected points for a given set of parameter values.

        :param params: sequence of parameter values, in the order of *pars*
        :return: list of residuals
        """
        for par, value in zip(pars, params):
            par.set(value)
        return [(y[j] - model.runXY(x[j])) / err_y[j] for j in selected]

    def chi2(params):
        """
        Calculates chi^2, the sum of squared weighted residuals.

        :param params: sequence of parameter values, in the order of *pars*
        :return: chi^2
        """
        return sum(res * res for res in f(params))

    start = [par() for par in pars]
    out, cov_x = optimize.leastsq(f, start, full_output=1)[:2]
    # Normalise to 1-d so chi2 can index the solution whether leastsq
    # returned a scalar or an array for a single parameter.
    chisqr = chi2(numpy.atleast_1d(out))
    return chisqr, out, cov_x
82
83
def calcCommandline(event):
    """
    Manual smoke test: fit a LineModel to data generated by that same model.

    :param event: object carrying ``cstA`` and ``cstB`` (parameter values
        used both to generate the data and as the fit's starting point)
        and ``x`` (the x values at which data are generated)
    """
    # Testing implementation
    # Fit a Line model
    from LineModel import LineModel
    line = LineModel()
    cstA = Parameter(line, 'A', event.cstA)
    cstB = Parameter(line, 'B', event.cstB)
    # sasfit indexes x, y and err_y with the same index, so build one y
    # value per x point with the same runXY call sasfit uses, and unit
    # errors (a scalar 0 is neither indexable nor a valid divisor).
    y = [line.runXY(x_val) for x_val in event.x]
    err_y = [1.0] * len(y)
    # Explicit bounds covering every point, so no data is dropped by the
    # q-range selection.
    chisqr, out, cov = sasfit(line, [cstA, cstB], event.x, y, err_y,
                              qmin=min(event.x), qmax=max(event.x))
    # Single-string print calls behave the same on Python 2 and 3.
    print("The right answer is [70.0, 1.0]")
    print("%s %s %s" % (chisqr, out, cov))
Note: See TracBrowser for help on using the repository browser.