source: sasview/park_integration/test/FittingModule.py @ b11d175

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b11d175 was b11d175, checked in by Gervaise Alina <gervyh@…>, 16 years ago

modification on fit function.

  • Property mode set to 100644
File size: 9.5 KB
Line 
1#class Fitting
2from sans.guitools.plottables import Data1D
3from Loader import Load
4from scipy import optimize
5
6
7class FitArrange:
8    def __init__(self):
9        """
10            Store a set of data for a given model to perform the Fit
11            @param model: the model selected by the user
12            @param Ldata: a list of data what the user want to fit
13        """
14        self.model = None
15        self.dList =[]
16       
17    def set_model(self,model):
18        """ set the model """
19        self.model = model
20       
21    def add_data(self,data):
22        """
23            @param data: Data to add in the list
24            fill a self.dataList with data to fit
25        """
26        if not data in self.dList:
27            self.dList.append(data)
28           
29    def get_model(self):
30        """ Return the model"""
31        return self.model   
32     
33    def get_data(self):
34        """ Return list of data"""
35        return self.dList
36     
37    def remove_data(self,data):
38        """
39            Remove one element from the list
40            @param data: Data to remove from the the list of data
41        """
42        if data in self.dList:
43            self.dList.remove(data)
44           
45class Fitting:
46    """
47        Performs the Fit.he user determine what kind of data
48    """
49    def __init__(self,data=[]):
50        #this is a dictionary of FitArrange elements
51        self.fitArrangeList={}
52        #the constraint of the Fit
53        self.constraint =None
54        #Specify the use of scipy or park fit
55        self.fitType =None
56       
57    def fit_engine(self,word):
58        """
59            Check the contraint value and specify what kind of fit to use
60        """
61        word=word.lower()
62        if word =="scipy" or word=="park":
63            self.fitType = word
64            return True
65        else:
66            #raise ValueError, "please enter the keyword scipy or park"
67            return False
68   
69    def fit(self,pars, qmin=None, qmax=None):
70        """
71             Do the fit
72        """
73        #for item in self.fitArrangeList.:
74        if not self.fitType ==None:
75            if self.fitType=="scipy":# sans fit
76                fitproblem = self.fitArrangeList.values()[0]
77                listdata=[]
78                model = fitproblem.get_model()
79                listdata = fitproblem.get_data()
80                parameters = self.set_param(model,pars)
81                if listdata==[]:
82                    raise ValueError, " data list missing"
83                else:
84                    # Do the fit with more than one data set and one model
85                    xtemp=[]
86                    ytemp=[]
87                    dytemp=[]
88                   
89                    for data in listdata:
90                        for i in range(len(data.x)):
91                            if not data.x[i] in xtemp:
92                                xtemp.append(data.x[i])
93                               
94                            if not data.y[i] in ytemp:
95                                ytemp.append(data.y[i])
96                               
97                            if not data.dy[i] in dytemp:
98                                dytemp.append(data.dy[i])
99                    if qmin==None:
100                        qmin= min(xtemp)
101                    if qmax==None:
102                        qmax= max(xtemp) 
103                    chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax)
104                    return chisqr, out, cov
105            else:#park fit
106                parkHelper()
107   
108    def set_model(self,model,Uid):
109        """ Set model """
110        if self.fitArrangeList.has_key(Uid):
111            self.fitArrangeList[Uid].set_model(model)
112        else:
113            fitproblem= FitArrange()
114            fitproblem.set_model(model)
115            self.fitArrangeList[Uid]=fitproblem
116       
117    def set_data(self,data,Uid):
118        """ Receive plottable and create a list of data to fit"""
119       
120        if self.fitArrangeList.has_key(Uid):
121            self.fitArrangeList[Uid].add_data(data)
122        else:
123            fitproblem= FitArrange()
124            fitproblem.add_data(data)
125            self.fitArrangeList[Uid]=fitproblem
126           
127    def get_model(self,Uid):
128        """ return list of data"""
129        return self.fitArrangeList[Uid]
130   
131    def set_param(self,model, pars):
132        """ Recieve a dictionary of parameter and save it """
133        parameters=[]
134        if model==None:
135            raise ValueError, "Cannot set parameters for empty model"
136        else:
137            #for key ,value in pars:
138            for key, value in pars.iteritems():
139                param = Parameter(model, key, value)
140                parameters.append(param)
141        return parameters
142   
143    def add_constraint(self, constraint):
144        """ User specify contraint to fit """
145        self.constraint = str(constraint)
146       
147    def get_constraint(self):
148        """ return the contraint value """
149        return self.constraint
150   
151    def set_constraint(self,constraint):
152        """
153            receive a string as a constraint
154            @param constraint: a string used to constraint some parameters to get a
155                specific value
156        """
157        self.constraint= constraint
158   
159   
160               
161
162class Parameter:
163    """
164        Class to handle model parameters
165    """
166    def __init__(self, model, name, value=None):
167            self.model = model
168            self.name = name
169            if not value==None:
170                self.model.setParam(self.name, value)
171           
172    def set(self, value):
173        """
174            Set the value of the parameter
175        """
176        self.model.setParam(self.name, value)
177
178    def __call__(self):
179        """
180            Return the current value of the parameter
181        """
182        return self.model.getParam(self.name)
183
184class Fitness:
185   
186    def __init__(self,model, pars, x, y, err_y ,qmin=None, qmax=None):
187        self.x = x
188        self.y = y
189        self.model = model
190        self.err_y = err_y
191        self.qmin = qmin
192        self.qmax= qmax
193        self.pars = pars
194       
195    def getParam(self):
196        return [param() for param in self.pars]
197   
198    def __call__(self, params):
199        i = 0
200        for p in self.pars:
201            p.set(params[i])
202            i += 1
203       
204        residuals = []
205        for j in range(len(self.x)):
206            if self.x[j]>self.qmin and self.x[j]<self.qmax:
207                residuals.append( ( self.y[j] - self.model.runXY(self.x[j]) ) / self.err_y[j] )
208       
209        return residuals
210   
211def parkHelper(): 
212    """ park code goes here"""
213
214def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
215    """
216        Fit function
217        @param model: sans model object
218        @param pars: list of parameters
219        @param x: vector of x data
220        @param y: vector of y data
221        @param err_y: vector of y errors
222    """
223   
224    f = Fitness(model, pars, x, y, err_y ,qmin, qmax)
225   
226    def ff(params):
227        """
228            Calculates the vector of residuals for each point
229            in y for a given set of input parameters.
230            @param params: list of parameter values
231            @return: vector of residuals
232        """
233        i = 0
234        for p in pars:
235            p.set(params[i])
236            i += 1
237       
238        residuals = []
239        for j in range(len(x)):
240            if x[j]>qmin and x[j]<qmax:
241                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
242       
243        return residuals
244       
245    def chi2(params):
246        """
247            Calculates chi^2
248            @param params: list of parameter values
249            @return: chi^2
250        """
251        sum = 0
252        res = f(params)
253        for item in res:
254            sum += item*item
255        return sum
256       
257    p = [param() for param in pars]
258    out, cov_x, info, mesg, success = optimize.leastsq(f,f.getParam(), full_output=1, warning=True)
259    #out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
260    print info, mesg, success
261    # Calculate chi squared
262    if len(pars)>1:
263        chisqr = chi2(out)
264    elif len(pars)==1:
265        chisqr = chi2([out])
266       
267    return chisqr, out, cov_x   
268
269     
270if __name__ == "__main__": 
271    load= Load()
272   
273    # test fit one data set one model
274    load.set_filename("testdata_line.txt")
275    load.set_values()
276    data1 = Data1D(x=[], y=[], dx=None,dy=None)
277    data1.name = "data1"
278    load.load_data(data1)
279    Fit =Fitting()
280   
281    from LineModel import LineModel
282    model  = LineModel()
283    Fit.set_model(model,1)
284    Fit.set_data(data1,1)
285    flag=Fit.fit_engine("Scipy")
286    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
287    print"fit only one data",chisqr, out, cov
288   
289    # test fit with 2 data and one model
290    Fit =Fitting()
291    Fit.set_model(model,2 )
292    load.set_filename("testdata1.txt")
293    load.set_values()
294    data2 = Data1D(x=[], y=[], dx=None,dy=None)
295    data2.name = "data2"
296   
297    load.load_data(data2)
298    Fit.set_data(data2,2)
299   
300    load.set_filename("testdata2.txt")
301    load.set_values()
302    data3 = Data1D(x=[], y=[], dx=None,dy=None)
303    data3.name = "data2"
304    load.load_data(data3)
305    Fit.set_data(data3,2)
306    flag=Fit.fit_engine("scipy")
307    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
308    print"fit two data",chisqr, out, cov
309   
Note: See TracBrowser for help on using the repository browser.