source: sasview/park_integration/docs/FittingModule.py @ f24b8f4

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since f24b8f4 was f24b8f4, checked in by Gervaise Alina <gervyh@…>, 16 years ago

made some chagnges on the unit test

  • Property mode set to 100644
File size: 7.9 KB
Line 
1#class Fitting
2from sans.guitools.plottables import Data1D
3from Loader import Load
4from scipy import optimize
5
6
7class FitArrange:
8    def __init__(self):
9        """
10            Store a set of data for a given model to perform the Fit
11            @param model: the model selected by the user
12            @param Ldata: a list of data what the user want to fit
13        """
14        self.model = None
15        self.dList =[]
16       
17    def set_model(self,model):
18        """ set the model """
19        self.model = model
20       
21    def add_data(self,data):
22        """
23            @param data: Data to add in the list
24            fill a self.dataList with data to fit
25        """
26        if not data in self.dList:
27            self.dList.append(data)
28           
29    def get_model(self):
30        """ Return the model"""
31        return self.model   
32     
33    def get_data(self):
34        """ Return list of data"""
35        return self.dList
36     
37    def remove_data(self,data):
38        """
39            Remove one element from the list
40            @param data: Data to remove from the the lsit of data
41        """
42        if data in self.dList:
43            self.dList.remove(data)
44           
45class Fitting:
46    """
47        Performs the Fit.he user determine what kind of data
48    """
49    def __init__(self,data=[]):
50        #this is a dictionary of FitArrange elements
51        self.fitArrangeList={}
52        #the constraint of the Fit
53        self.constraint =None
54        #Specify the use of scipy or park fit
55        self.fitType =None
56       
57    def fit_engine(self,word):
58        """
59            Check the contraint value and specify what kind of fit to use
60        """
61        self.fitType = word
62        return True
63   
64    def fit(self,pars, qmin=None, qmax=None):
65        """
66             Do the fit
67        """
68        #for item in self.fitArrangeList.:
69       
70        fitproblem=self.fitArrangeList.values()[0]
71        listdata=[]
72        model = fitproblem.get_model()
73        listdata = fitproblem.get_data()
74       
75        parameters = self.set_param(model,pars)
76        if listdata==[]:
77            raise ValueError, " data list missing"
78        else:
79            # Do the fit with more than one data set and one model
80            xtemp=[]
81            ytemp=[]
82            dytemp=[]
83           
84            for data in listdata:
85                for i in range(len(data.x)):
86                    if not data.x[i] in xtemp:
87                        xtemp.append(data.x[i])
88                       
89                    if not data.y[i] in ytemp:
90                        ytemp.append(data.y[i])
91                       
92                    if not data.dy[i] in dytemp:
93                        dytemp.append(data.dy[i])
94            if qmin==None:
95                qmin= min(xtemp)
96            if qmax==None:
97                qmax= max(xtemp) 
98            chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax)
99            return chisqr, out, cov
100   
101    def set_model(self,model,Uid):
102        """ Set model """
103        if self.fitArrangeList.has_key(Uid):
104            self.fitArrangeList[Uid].set_model(model)
105        else:
106            fitproblem= FitArrange()
107            fitproblem.set_model(model)
108            self.fitArrangeList[Uid]=fitproblem
109       
110    def set_data(self,data,Uid):
111        """ Receive plottable and create a list of data to fit"""
112       
113        if self.fitArrangeList.has_key(Uid):
114            self.fitArrangeList[Uid].add_data(data)
115        else:
116            fitproblem= FitArrange()
117            fitproblem.add_data(data)
118            self.fitArrangeList[Uid]=fitproblem
119           
120    def get_model(self,Uid):
121        """ return list of data"""
122        return self.fitArrangeList[Uid]
123   
124    def set_param(self,model, pars):
125        """ Recieve a dictionary of parameter and save it """
126        parameters=[]
127        if model==None:
128            raise ValueError, "Cannot set parameters for empty model"
129        else:
130            #for key ,value in pars:
131            for key, value in pars.iteritems():
132                param = Parameter(model, key, value)
133                parameters.append(param)
134        return parameters
135   
136    def add_constraint(self, constraint):
137        """ User specify contraint to fit """
138        self.constraint = str(constraint)
139       
140    def get_constraint(self):
141        """ return the contraint value """
142        return self.constraint
143   
144    def set_constraint(self,constraint):
145        """
146            receive a string as a constraint
147            @param constraint: a string used to constraint some parameters to get a
148                specific value
149        """
150        self.constraint= constraint
151   
152   
153               
154
155class Parameter:
156    """
157        Class to handle model parameters
158    """
159    def __init__(self, model, name, value=None):
160            self.model = model
161            self.name = name
162            if not value==None:
163                self.model.setParam(self.name, value)
164           
165    def set(self, value):
166        """
167            Set the value of the parameter
168        """
169        self.model.setParam(self.name, value)
170
171    def __call__(self):
172        """
173            Return the current value of the parameter
174        """
175        return self.model.getParam(self.name)
176   
177def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
178    """
179        Fit function
180        @param model: sans model object
181        @param pars: list of parameters
182        @param x: vector of x data
183        @param y: vector of y data
184        @param err_y: vector of y errors
185    """
186    def f(params):
187        """
188            Calculates the vector of residuals for each point
189            in y for a given set of input parameters.
190            @param params: list of parameter values
191            @return: vector of residuals
192        """
193        i = 0
194        for p in pars:
195            p.set(params[i])
196            i += 1
197       
198        residuals = []
199        for j in range(len(x)):
200            if x[j]>qmin and x[j]<qmax:
201                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
202       
203        return residuals
204       
205    def chi2(params):
206        """
207            Calculates chi^2
208            @param params: list of parameter values
209            @return: chi^2
210        """
211        sum = 0
212        res = f(params)
213        for item in res:
214            sum += item*item
215        return sum
216       
217    p = [param() for param in pars]
218    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
219    print info, mesg, success
220    # Calculate chi squared
221    if len(pars)>1:
222        chisqr = chi2(out)
223    elif len(pars)==1:
224        chisqr = chi2([out])
225       
226    return chisqr, out, cov_x   
227
228     
229if __name__ == "__main__": 
230    load= Load()
231   
232    # test fit one data set one model
233    load.set_filename("testdata_line.txt")
234    load.set_values()
235    data1 = Data1D(x=[], y=[], dx=None,dy=None)
236    data1.name = "data1"
237    load.load_data(data1)
238    Fit =Fitting()
239   
240    from sans.guitools.LineModel import LineModel
241    model  = LineModel()
242    Fit.set_model(model,1)
243    Fit.set_data(data1,1)
244   
245    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
246    print"fit only one data",chisqr, out, cov
247   
248    # test fit with 2 data and one model
249    Fit =Fitting()
250    Fit.set_model(model,2 )
251    load.set_filename("testdata1.txt")
252    load.set_values()
253    data2 = Data1D(x=[], y=[], dx=None,dy=None)
254    data2.name = "data2"
255   
256    load.load_data(data2)
257    Fit.set_data(data2,2)
258   
259    load.set_filename("testdata2.txt")
260    load.set_values()
261    data3 = Data1D(x=[], y=[], dx=None,dy=None)
262    data3.name = "data2"
263    load.load_data(data3)
264    Fit.set_data(data3,2)
265    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
266    print"fit two data",chisqr, out, cov
267   
Note: See TracBrowser for help on using the repository browser.