source: sasview/guitools/fittings.py @ c9aa125

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since c9aa125 was 1c94a9f1, checked in by Mathieu Doucet <doucetm@…>, 16 years ago

Fixed all sorts of bugs: replotting problems, bad logic in rescaling function, removed buggy field in dialog box, improved usability for linear fit.

  • Property mode set to 100644
File size: 2.6 KB
Line 
1from scipy import optimize
2#from numpy import *
3
4
5
6class Parameter:
7    """
8        Class to handle model parameters
9    """
10    def __init__(self, model, name, value=None):
11            self.model = model
12            self.name = name
13            if not value==None:
14                self.model.setParam(self.name, value)
15           
16    def set(self, value):
17        """
18            Set the value of the parameter
19        """
20        self.model.setParam(self.name, value)
21
22    def __call__(self):
23        """
24            Return the current value of the parameter
25        """
26        return self.model.getParam(self.name)
27   
28def sansfit(model, pars, x, y, err_y ,qmin=None, qmax=None):
29    """
30        Fit function
31        @param model: sans model object
32        @param pars: list of parameters
33        @param x: vector of x data
34        @param y: vector of y data
35        @param err_y: vector of y errors
36    """
37    def f(params):
38        """
39            Calculates the vector of residuals for each point
40            in y for a given set of input parameters.
41            @param params: list of parameter values
42            @return: vector of residuals
43        """
44        i = 0
45        for p in pars:
46            p.set(params[i])
47            i += 1
48       
49        residuals = []
50        for j in range(len(x)):
51            if x[j]>qmin and x[j]<qmax:
52                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
53       
54        return residuals
55       
56    def chi2(params):
57        """
58            Calculates chi^2
59            @param params: list of parameter values
60            @return: chi^2
61        """
62        sum = 0
63        res = f(params)
64        for item in res:
65            sum += item*item
66        return sum
67       
68    p = [param() for param in pars]
69    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
70    # Calculate chi squared
71    if len(pars)>1:
72        chisqr = chi2(out)
73    elif len(pars)==1:
74        chisqr = chi2([out])
75       
76    return chisqr, out, cov_x   
77
78
79def calcCommandline(self,event):
80    """
81        Testing implementation
82    """
83 
84    # Fit a Line model
85    from LineModel import Line
86    line    = Line()
87    cstA = Parameter(line, 'A', event.cstA)
88    cstB  = Parameter(line, 'B', event.cstB)       
89    y = line.run()
90    chisqr, out, cov = sansfit(line, [cstA, cstB],  event.x, y, 0) 
91    # print "Output parameters:", out
92    print "The right answer is [70.0, 1.0]"
93    print chisqr, out, cov
94
95if __name__ == "__main__": 
96   
97    calcCommandline()
Note: See TracBrowser for help on using the repository browser.