source: sasview/guitools/fittings.py @ 5789654

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 5789654 was 5789654, checked in by Gervaise Alina <gervyh@…>, 17 years ago

modied fitdialog property dialog added fittings , modified plottables

  • Property mode set to 100644
File size: 2.6 KB
Line 
1from scipy import optimize
2#from numpy import *
3
4
5
6class Parameter:
7    """
8        Class to handle model parameters
9    """
10    def __init__(self, model, name, value=None):
11            self.model = model
12            self.name = name
13            if not value==None:
14                self.model.setParam(self.name, value)
15           
16    def set(self, value):
17        """
18            Set the value of the parameter
19        """
20        self.model.setParam(self.name, value)
21
22    def __call__(self):
23        """
24            Return the current value of the parameter
25        """
26        return self.model.getParam(self.name)
27   
28def sansfit(model, pars, x, y, err_y ,qmin=None, qmax=None):
29    """
30        Fit function
31        @param model: sans model object
32        @param pars: list of parameters
33        @param x: vector of x data
34        @param y: vector of y data
35        @param err_y: vector of y errors
36    """
37    def f(params):
38        """
39            Calculates the vector of residuals for each point
40            in y for a given set of input parameters.
41            @param params: list of parameter values
42            @return: vector of residuals
43        """
44        i = 0
45        for p in pars:
46            p.set(params[i])
47            i += 1
48       
49        residuals = []
50        for j in range(len(x)):
51            if x[j]>qmin and x[j]<qmax:
52                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
53       
54        return residuals
55       
56    def chi2(params):
57        """
58            Calculates chi^2
59            @param params: list of parameter values
60            @return: chi^2
61        """
62        sum = 0
63        res = f(params)
64        for item in res:
65            sum += item*item
66        return sum
67       
68    p = [param() for param in pars]
69    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
70    print info, mesg, success
71    # Calculate chi squared
72    if len(pars)>1:
73        chisqr = chi2(out)
74    elif len(pars)==1:
75        chisqr = chi2([out])
76       
77    return chisqr, out, cov_x   
78
79
80def calcCommandline(self,event):
81    """
82        Testing implementation
83    """
84 
85    # Fit a Line model
86    from LineModel import Line
87    line    = Line()
88    cstA = Parameter(line, 'A', event.cstA)
89    cstB  = Parameter(line, 'B', event.cstB)       
90    y = line.run()
91    chisqr, out, cov = sansfit(line, [cstA, cstB],  event.x, y, 0) 
92    # print "Output parameters:", out
93    print "The right answer is [70.0, 1.0]"
94    print chisqr, out, cov
95
96if __name__ == "__main__": 
97   
98    calcCommandline()
Note: See TracBrowser for help on using the repository browser.