source: sasview/src/danse/common/plottools/fittings.py @ f468791

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since f468791 was 5777106, checked in by Mathieu Doucet <doucetm@…>, 11 years ago

Moving things around. Will definitely not build.

  • Property mode set to 100644
File size: 2.5 KB
RevLine 
[82a54b8]1"""
2"""
3from scipy import optimize
[10bfeb3]4
[82a54b8]5
class Parameter(object):
    """
    Class to handle model parameters.

    A Parameter is a handle on one named parameter of a model. Calling the
    instance reads the current value through ``model.getParam(name)``, and
    :meth:`set` writes a new value through ``model.setParam(name, value)``.
    """
    def __init__(self, model, name, value=None):
        """
        :param model: model object exposing ``setParam`` and ``getParam``
        :param name: name of the parameter within the model
        :param value: optional initial value; when it is not None it is
            written to the model immediately
        """
        self.model = model
        self.name = name
        # Identity test on purpose: ``value == None`` is element-wise for
        # numpy arrays, and negating the resulting array raises ValueError.
        if value is not None:
            self.model.setParam(self.name, value)

    def set(self, value):
        """
        Set the value of the parameter in the model.

        :param value: new parameter value
        """
        self.model.setParam(self.name, value)

    def __call__(self):
        """
        Return the current value of the parameter as reported by the model.
        """
        return self.model.getParam(self.name)
27   
[10bfeb3]28   
def sansfit(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
    Fit function: least-squares fit of a model to 1D data.

    The parameters in *pars* are adjusted with ``scipy.optimize.leastsq``
    to minimise the weighted residuals ``(y - model.runXY(x)) / err_y``
    over the points whose x value lies in ``[qmin, qmax]``. On return the
    model's parameters are left at the best-fit values.

    :param model: sans model object; must provide ``runXY(x)``
    :param pars: list of parameters; each must support ``p()`` to read its
        current value and ``p.set(value)`` to change it
    :param x: vector of x data
    :param y: vector of y data
    :param err_y: vector of y errors (one per point), a single scalar error
        shared by every point, or None for an unweighted fit. Errors must
        be non-zero.
    :param qmin: lowest x value included in the fit; None means no lower bound
    :param qmax: highest x value included in the fit; None means no upper bound
    :return: tuple ``(chisqr, out, cov_x)``: chi^2 at the best fit, the
        best-fit parameter values as returned by ``leastsq``, and ``cov_x``
        as returned by ``leastsq`` (None if a singular matrix was found)
    """
    import numpy as np

    # Normalise the errors to one value per point.
    if err_y is None:
        err_y = [1.0] * len(x)
    elif np.isscalar(err_y):
        err_y = [err_y] * len(x)

    # None means "unbounded". Comparing against None directly would exclude
    # every point on Python 2 and raise TypeError on Python 3.
    lower = -np.inf if qmin is None else qmin
    upper = np.inf if qmax is None else qmax

    # The fit range does not change during the fit, so select the points once.
    points = [(x[j], y[j], err_y[j]) for j in range(len(x))
              if lower <= x[j] <= upper]

    def f(params):
        """
        Calculates the vector of residuals for each selected point
        in y for a given set of input parameters.

        :param params: list of parameter values
        :return: vector of residuals
        """
        for par, value in zip(pars, params):
            par.set(value)
        return [(yj - model.runXY(xj)) / ej for xj, yj, ej in points]

    def chi2(params):
        """
        Calculates chi^2 (sum of squared residuals).

        :param params: list of parameter values
        :return: chi^2
        """
        return sum(res * res for res in f(params))

    p0 = [par() for par in pars]
    out, cov_x, _info, _mesg, _ier = optimize.leastsq(f, p0, full_output=1)
    # Older SciPy returned a bare scalar for a single parameter; atleast_1d
    # gives chi2 a sequence either way. Evaluating chi2 also re-applies the
    # best-fit values to the model, because the last evaluation made by
    # leastsq need not have been at the optimum.
    chisqr = chi2(np.atleast_1d(out))
    return chisqr, out, cov_x
82
[82a54b8]83
def calcCommandline(event):
    """
    Command-line smoke test: fit a LineModel to data generated from itself.

    :param event: object providing ``cstA`` and ``cstB`` (the values the
        model's 'A' and 'B' parameters are set to before the data are
        generated, which are also the fit's starting point) and ``x`` (the
        x values of the generated data)
    """
    # Testing implementation
    # Fit a Line model
    from LineModel import LineModel
    line = LineModel()
    cstA = Parameter(line, 'A', event.cstA)
    cstB = Parameter(line, 'B', event.cstB)
    # Generate one y value per x value. ``line.run()`` is called without
    # any x, so it cannot supply per-point data.
    # NOTE(review): assumes LineModel.runXY takes a single x value, as
    # sansfit's residual calculation does -- confirm against LineModel.
    y = [line.runXY(xi) for xi in event.x]
    # sansfit indexes err_y per point and divides by it, so a scalar 0 can
    # never work; use unit errors (an unweighted fit). The fit range is
    # given explicitly so every generated point is included.
    err_y = [1.0] * len(event.x)
    chisqr, out, cov = sansfit(line, [cstA, cstB], event.x, y, err_y,
                               qmin=min(event.x), qmax=max(event.x))
    # The data were generated from (cstA, cstB), so those are the expected
    # best-fit values.
    print("The right answer is [%s, %s]" % (event.cstA, event.cstB))
    print("%s %s %s" % (chisqr, out, cov))
Note: See TracBrowser for help on using the repository browser.