source: sasview/src/sas/sasgui/plottools/fittings.py @ 934ce649

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 934ce649 was d7bb526, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 9 years ago

Refactored plottools into sasgui

  • Property mode set to 100644
File size: 2.4 KB
RevLine 
[a9d5684]1"""
2"""
3from scipy import optimize
4
5
class Parameter(object):
    """
    Class to handle model parameters.

    A Parameter is a named handle on one parameter of a model: reading it
    calls the model's getParam and setting it calls the model's setParam.
    """
    def __init__(self, model, name, value=None):
        """
        :param model: model object providing setParam(name, value) and
            getParam(name)
        :param name: name of the parameter within the model
        :param value: optional initial value; when it is not None it is
            written to the model immediately
        """
        self.model = model
        self.name = name
        # Identity test: ``value == None`` is evaluated element-wise for
        # array-like values and its truth value is then ambiguous.
        if value is not None:
            self.model.setParam(self.name, value)

    def set(self, value):
        """
        Set the value of the parameter in the model.
        """
        self.model.setParam(self.name, value)

    def __call__(self):
        """
        Return the current value of the parameter, as read from the model.
        """
        return self.model.getParam(self.name)
[2df0b74]27
28
def sasfit(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
    Fit a model to 1D data by least squares.

    The parameters in *pars* are adjusted with scipy.optimize.leastsq to
    minimize the sum of squared weighted residuals
    (y - model.runXY(x)) / err_y over the points with qmin <= x <= qmax.
    On return, the model's parameters are left at the best-fit values.

    :param model: sas model object; must provide runXY(x)
    :param pars: list of parameter objects (see Parameter) to be fitted
    :param x: vector of x data
    :param y: vector of y data
    :param err_y: vector of y errors; a scalar is applied to every point.
        Errors must be non-zero.
    :param qmin: lower x bound of the fit range (None: no lower bound)
    :param qmax: upper x bound of the fit range (None: no upper bound)
    :return: tuple (chisqr, out, cov_x). chisqr is the sum of squared
        weighted residuals at the solution, out holds the best-fit parameter
        values, and cov_x is as returned by scipy.optimize.leastsq (may be None).
    """
    import numpy

    # A scalar uncertainty applies to every data point.
    if numpy.ndim(err_y) == 0:
        err_y = [err_y] * len(x)

    # The fit window does not change during the fit, so select once.
    # None means the corresponding side of the window is open.
    selected = [j for j, xj in enumerate(x)
                if (qmin is None or xj >= qmin)
                and (qmax is None or xj <= qmax)]

    def f(params):
        """
        Calculates the vector of residuals for each point
        in y for a given set of input parameters.

        :param params: list of parameter values
        :return: vector of residuals
        """
        for par, value in zip(pars, params):
            par.set(value)
        return [(y[j] - model.runXY(x[j])) / err_y[j] for j in selected]

    def chi2(params):
        """
        Calculates chi^2 (sum of squared weighted residuals).

        :param params: list of parameter values
        :return: chi^2
        """
        return sum(res * res for res in f(params))

    p = [param() for param in pars]
    out, cov_x, _infodict, _mesg, _ier = optimize.leastsq(f, p,
                                                          full_output=1)
    # Evaluating chi^2 at the solution also writes the best-fit values back
    # into the model. atleast_1d makes the one-parameter case uniform.
    chisqr = chi2(numpy.atleast_1d(out))
    return chisqr, out, cov_x
82
83
def calcCommandline(event):
    """
    Ad-hoc manual test: fit a LineModel through sasfit and print the result.

    :param event: object providing x (vector of x values) and cstA, cstB
        (starting values for the line parameters A and B)
    """
    # Testing implementation
    # Fit a Line model
    # NOTE(review): relies on LineModel being importable as a top-level
    # module (implicit relative import under Python 2) -- confirm.
    from LineModel import LineModel
    line = LineModel()
    cstA = Parameter(line, 'A', event.cstA)
    cstB = Parameter(line, 'B', event.cstB)
    # NOTE(review): run() is called without an x argument; confirm it
    # returns one y value per point of event.x, as sasfit indexes y[j].
    y = line.run()
    # Unit weights: an uncertainty of 0 would make every residual divide
    # by zero.
    err_y = [1.0] * len(event.x)
    # Explicit bounds covering all of event.x, so the result does not depend
    # on how sasfit treats qmin/qmax=None.
    chisqr, out, cov = sasfit(line, [cstA, cstB], event.x, y, err_y,
                              qmin=min(event.x), qmax=max(event.x))
    print("The right answer is [70.0, 1.0]")
    print("%s %s %s" % (chisqr, out, cov))
Note: See TracBrowser for help on using the repository browser.