source: sasview/park_integration/AbstractFitEngine.py @ f2817bb

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since f2817bb was f2817bb, checked in by Gervaise Alina <gervyh@…>, 16 years ago

remove metadata classes

  • Property mode set to 100644
File size: 17.2 KB
Line 
1
2import park,numpy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        lo,hi = self._model.details[self.name][1:]
41        if lo is None: lo = -numpy.inf
42        if hi is None: hi = numpy.inf
43        return lo,hi
44   
45    def _setrange(self,r):
46        """
47            override _setrange of park parameter
48            @param r: the value of the range to set
49        """
50        self._model.details[self.name][1:] = r
51    range = property(_getrange,_setrange)
52   
53class Model(park.Model):
54    """
55        PARK wrapper for SANS models.
56    """
57    def __init__(self, sans_model, **kw):
58        """
59            @param sans_model: the sans model to wrap using park interface
60        """
61        park.Model.__init__(self, **kw)
62        self.model = sans_model
63        self.name = sans_model.name
64        #list of parameters names
65        self.sansp = sans_model.getParamList()
66        #list of park parameter
67        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
68        #list of parameterset
69        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
70        self.pars=[]
71 
72 
73    def getParams(self,fitparams):
74        """
75            return a list of value of paramter to fit
76            @param fitparams: list of paramaters name to fit
77        """
78        list=[]
79        self.pars=[]
80        self.pars=fitparams
81        for item in fitparams:
82            for element in self.parkp:
83                 if element.name ==str(item):
84                     list.append(element.value)
85        return list
86   
87   
88    def setParams(self,paramlist, params):
89        """
90            Set value for parameters to fit
91            @param params: list of value for parameters to fit
92        """
93        try:
94            for i in range(len(self.parkp)):
95                for j in range(len(paramlist)):
96                    if self.parkp[i].name==paramlist[j]:
97                        self.parkp[i].value = params[j]
98                        self.model.setParam(self.parkp[i].name,params[j])
99        except:
100            raise
101 
102    def eval(self,x):
103        """
104            override eval method of park model.
105            @param x: the x value used to compute a function
106        """
107        return self.model.runXY(x)
108   
109   
110
111
112class Data(object):
113    """ Wrapper class  for SANS data """
114    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
115        """
116            Data can be initital with a data (sans plottable)
117            or with vectors.
118        """
119        if  sans_data !=None:
120            self.x= sans_data.x
121            self.y= sans_data.y
122            self.dx= sans_data.dx
123            self.dy= sans_data.dy
124           
125        elif (x!=None and y!=None and dy!=None):
126                self.x=x
127                self.y=y
128                self.dx=dx
129                self.dy=dy
130        else:
131            raise ValueError,\
132            "Data is missing x, y or dy, impossible to compute residuals later on"
133        self.qmin=None
134        self.qmax=None
135       
136       
137    def setFitRange(self,mini=None,maxi=None):
138        """ to set the fit range"""
139        self.qmin=mini
140        self.qmax=maxi
141       
142       
143    def getFitRange(self):
144        """
145            @return the range of data.x to fit
146        """
147        return self.qmin, self.qmax
148     
149     
150    def residuals(self, fn):
151        """ @param fn: function that return model value
152            @return residuals
153        """
154        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
155        if self.qmin==None and self.qmax==None: 
156            fx =numpy.asarray([fn(v) for v in x])
157            return (y - fx)/dy
158        else:
159            idx = (x>=self.qmin) & (x <= self.qmax)
160            fx = numpy.asarray([fn(item)for item in x[idx ]])
161            return (y[idx] - fx)/dy[idx]
162       
163    def residuals_deriv(self, model, pars=[]):
164        """
165            @return residuals derivatives .
166            @note: in this case just return empty array
167        """
168        return []
169class FitData1D(object):
170    """ Wrapper class  for SANS data """
171    def __init__(self,sans_data1d):
172        """
173            Data can be initital with a data (sans plottable)
174            or with vectors.
175        """
176        self.data=sans_data1d
177        self.x= sans_data1d.x
178        self.y= sans_data1d.y
179        self.dx= sans_data1d.dx
180        self.dy= sans_data1d.dy
181        self.qmin=None
182        self.qmax=None
183       
184       
185    def setFitRange(self,qmin=None,qmax=None,ymin=None,ymax=None,):
186        """ to set the fit range"""
187        self.qmin=qmin
188        self.qmax=qmax
189       
190       
191    def getFitRange(self):
192        """
193            @return the range of data.x to fit
194        """
195        return self.qmin, self.qmax
196     
197     
198    def residuals(self, fn):
199        """ @param fn: function that return model value
200            @return residuals
201        """
202        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
203        if self.qmin==None and self.qmax==None: 
204            fx =numpy.asarray([fn(v) for v in x])
205            return (y - fx)/dy
206        else:
207            idx = (x>=self.qmin) & (x <= self.qmax)
208            fx = numpy.asarray([fn(item)for item in x[idx ]])
209            return (y[idx] - fx)/dy[idx]
210       
211    def residuals_deriv(self, model, pars=[]):
212        """
213            @return residuals derivatives .
214            @note: in this case just return empty array
215        """
216        return []
217   
218   
219class FitData2D(object):
220    """ Wrapper class  for SANS data """
221    def __init__(self,sans_data2d):
222        """
223            Data can be initital with a data (sans plottable)
224            or with vectors.
225        """
226        self.data=sans_data2d
227        self.image = sans_data2d.image
228        self.err_image = sans_data2d.err_image
229        self.x_bins= sans_data2d.x_bins
230        self.y_bins= sans_data2d.y_bins
231       
232        self.xmin= self.data.xmin
233        self.xmax= self.data.xmax
234        self.ymin= self.data.ymin
235        self.ymax= self.data.ymax
236       
237       
238    def setFitRange(self,qmin=None,qmax=None,ymin=None,ymax=None):
239        """ to set the fit range"""
240        self.xmin= qmin
241        self.xmax= qmax
242        self.ymin= ymin
243        self.ymax= ymax
244       
245    def getFitRange(self):
246        """
247            @return the range of data.x to fit
248        """
249        return self.xmin, self.xmax,self.ymin, self.ymax
250     
251     
252    def residuals(self, fn):
253        """ @param fn: function that return model value
254            @return residuals
255        """
256        res=[]
257        if self.xmin==None:
258            self.xmin= self.data.xmin
259        if self.xmax==None:
260            self.xmax= self.data.xmax
261        if self.ymin==None:
262            self.ymin= self.data.ymin
263        if self.ymax==None:
264            self.ymax= self.data.ymax
265           
266        for i in range(len(self.y_bins)):
267            #if self.y_bins[i]>= self.ymin and self.y_bins[i]<= self.ymax:
268            for j in range(len(self.x_bins)):
269                #if self.x_bins[j]>= self.xmin and self.x_bins[j]<= self.xmax:
270                res.append( (self.image[j][i]- fn([self.x_bins[j],self.y_bins[i]]))\
271                            /self.err_image[j][i] )
272       
273        return numpy.array(res)
274       
275         
276    def residuals_deriv(self, model, pars=[]):
277        """
278            @return residuals derivatives .
279            @note: in this case just return empty array
280        """
281        return []
282   
283class sansAssembly:
284    """
285         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
286    """
287    def __init__(self,paramlist,Model=None , Data=None):
288        """
289            @param Model: the model wrapper fro sans -model
290            @param Data: the data wrapper for sans data
291        """
292        self.model = Model
293        self.data  = Data
294        self.paramlist=paramlist
295        self.res=[]
296    def chisq(self, params):
297        """
298            Calculates chi^2
299            @param params: list of parameter values
300            @return: chi^2
301        """
302        sum = 0
303        for item in self.res:
304            sum += item*item
305        return sum
306    def __call__(self,params):
307        """
308            Compute residuals
309            @param params: value of parameters to fit
310        """
311        self.model.setParams(self.paramlist,params)
312        self.res= self.data.residuals(self.model.eval)
313        return self.res
314   
315class FitEngine:
316    def __init__(self):
317        """
318            Base class for scipy and park fit engine
319        """
320        #List of parameter names to fit
321        self.paramList=[]
322        #Dictionnary of fitArrange element (fit problems)
323        self.fitArrangeDict={}
324       
325    def _concatenateData(self, listdata=[]):
326        """ 
327            _concatenateData method concatenates each fields of all data contains ins listdata.
328            @param listdata: list of data
329            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
330             of data concatenanted
331            @raise: if listdata is empty  will return None
332            @raise: if data in listdata don't contain dy field ,will create an error
333            during fitting
334        """
335        if listdata==[]:
336            raise ValueError, " data list missing"
337        else:
338            xtemp=[]
339            ytemp=[]
340            dytemp=[]
341            self.mini=None
342            self.maxi=None
343               
344            for item in listdata:
345                data=item.data
346                mini,maxi=data.getFitRange()
347                if self.mini==None and self.maxi==None:
348                    self.mini=mini
349                    self.maxi=maxi
350                else:
351                    if mini < self.mini:
352                        self.mini=mini
353                    if self.maxi < maxi:
354                        self.maxi=maxi
355                       
356                   
357                for i in range(len(data.x)):
358                    xtemp.append(data.x[i])
359                    ytemp.append(data.y[i])
360                    if data.dy is not None and len(data.dy)==len(data.y):   
361                        dytemp.append(data.dy[i])
362                    else:
363                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
364            data= Data(x=xtemp,y=ytemp,dy=dytemp)
365            data.setFitRange(self.mini, self.maxi)
366            return data
367       
368       
369    def set_model(self,model,Uid,pars=[]):
370        """
371            set a model on a given uid in the fit engine.
372            @param model: the model to fit
373            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
374            @param pars: the list of parameters to fit
375            @note : pars must contains only name of existing model's paramaters
376        """
377        if len(pars) >0:
378            if model==None:
379                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
380            else:
381                for item in pars:
382                    if item in model.model.getParamList():
383                        self.paramList.append(item)
384                    else:
385                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
386                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
387                        return
388            #A fitArrange is already created but contains dList only at Uid
389            if self.fitArrangeDict.has_key(Uid):
390                self.fitArrangeDict[Uid].set_model(model)
391            else:
392            #no fitArrange object has been create with this Uid
393                fitproblem = FitArrange()
394                fitproblem.set_model(model)
395                self.fitArrangeDict[Uid] = fitproblem
396        else:
397            raise ValueError, "park_integration:missing parameters"
398   
399    def set_data(self,data,Uid,qmin=None,qmax=None,ymin=None,ymax=None):
400        """ Receives plottable, creates a list of data to fit,set data
401            in a FitArrange object and adds that object in a dictionary
402            with key Uid.
403            @param data: data added
404            @param Uid: unique key corresponding to a fitArrange object with data
405        """
406        if data.__class__.__name__=='Data2D':
407            fitdata=FitData2D(data)
408        else:
409            fitdata=FitData1D(data)
410       
411        fitdata.setFitRange(qmin=qmin,qmax=qmax, ymin=ymin,ymax=ymax)
412        #A fitArrange is already created but contains model only at Uid
413        if self.fitArrangeDict.has_key(Uid):
414            self.fitArrangeDict[Uid].add_data(fitdata)
415        else:
416        #no fitArrange object has been create with this Uid
417            fitproblem= FitArrange()
418            fitproblem.add_data(fitdata)
419            self.fitArrangeDict[Uid]=fitproblem   
420   
421    def get_model(self,Uid):
422        """
423            @param Uid: Uid is key in the dictionary containing the model to return
424            @return  a model at this uid or None if no FitArrange element was created
425            with this Uid
426        """
427        if self.fitArrangeDict.has_key(Uid):
428            return self.fitArrangeDict[Uid].get_model()
429        else:
430            return None
431   
432    def remove_Fit_Problem(self,Uid):
433        """remove   fitarrange in Uid"""
434        if self.fitArrangeDict.has_key(Uid):
435            del self.fitArrangeDict[Uid]
436           
437    def select_problem_for_fit(self,Uid,value):
438        """
439            select a couple of model and data at the Uid position in dictionary
440            and set in self.selected value to value
441            @param value: the value to allow fitting. can only have the value one or zero
442        """
443        if self.fitArrangeDict.has_key(Uid):
444             self.fitArrangeDict[Uid].set_to_fit( value)
445    def get_problem_to_fit(self,Uid):
446        """
447            return the self.selected value of the fit problem of Uid
448           @param Uid: the Uid of the problem
449        """
450        if self.fitArrangeDict.has_key(Uid):
451             self.fitArrangeDict[Uid].get_to_fit()
452   
453class FitArrange:
454    def __init__(self):
455        """
456            Class FitArrange contains a set of data for a given model
457            to perform the Fit.FitArrange must contain exactly one model
458            and at least one data for the fit to be performed.
459            model: the model selected by the user
460            Ldata: a list of data what the user wants to fit
461           
462        """
463        self.model = None
464        self.dList =[]
465        #self.selected  is zero when this fit problem is not schedule to fit
466        #self.selected is 1 when schedule to fit
467        self.selected = 0
468       
469    def set_model(self,model):
470        """
471            set_model save a copy of the model
472            @param model: the model being set
473        """
474        self.model = model
475       
476    def add_data(self,data):
477        """
478            add_data fill a self.dList with data to fit
479            @param data: Data to add in the list 
480        """
481        if not data in self.dList:
482            self.dList.append(data)
483           
484    def get_model(self):
485        """ @return: saved model """
486        return self.model   
487     
488    def get_data(self):
489        """ @return:  list of data dList"""
490        #return self.dList
491        return self.dList[0] 
492     
493    def remove_data(self,data):
494        """
495            Remove one element from the list
496            @param data: Data to remove from dList
497        """
498        if data in self.dList:
499            self.dList.remove(data)
500    def set_to_fit (self, value=0):
501        """
502           set self.selected to 0 or 1  for other values raise an exception
503           @param value: integer between 0 or 1
504        """
505        self.selected= value
506       
507    def get_to_fit(self):
508        """
509            @return self.selected value
510        """
511        return self.selected
512   
513
514
515   
Note: See TracBrowser for help on using the repository browser.