source: sasview/park_integration/AbstractFitEngine.py @ 3370922

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 3370922 was fff74cb, checked in by Gervaise Alina <gervyh@…>, 15 years ago

compute residual for data2D

  • Property mode set to 100644
File size: 19.6 KB
Line 
1
2import park,numpy,math
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146        self.qmin=mini
147        self.qmax=maxi
148       
149       
150    def getFitRange(self):
151        """
152            @return the range of data.x to fit
153        """
154        return self.qmin, self.qmax
155     
156     
157    def residuals(self, fn):
158        """ @param fn: function that return model value
159            @return residuals
160        """
161        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
162        if self.qmin==None and self.qmax==None: 
163            fx =numpy.asarray([fn(v) for v in x])
164            return (y - fx)/dy
165        else:
166            idx = (x>=self.qmin) & (x <= self.qmax)
167            fx = numpy.asarray([fn(item)for item in x[idx ]])
168            return (y[idx] - fx)/dy[idx]
169       
170    def residuals_deriv(self, model, pars=[]):
171        """
172            @return residuals derivatives .
173            @note: in this case just return empty array
174        """
175        return []
176   
177   
178class FitData1D(object):
179    """ Wrapper class  for SANS data """
180    def __init__(self,sans_data1d, smearer=None):
181        """
182            Data can be initital with a data (sans plottable)
183            or with vectors.
184           
185            self.smearer is an object of class QSmearer or SlitSmearer
186            that will smear the theory data (slit smearing or resolution
187            smearing) when set.
188           
189            The proper way to set the smearing object would be to
190            do the following:
191           
192            from DataLoader.qsmearing import smear_selection
193            fitdata1d = FitData1D(some_data)
194            fitdata1d.smearer = smear_selection(some_data)
195           
196            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
197           
198            Setting it back to None will turn smearing off.
199           
200        """
201       
202        self.smearer = smearer
203     
204        # Initialize from Data1D object
205        self.data=sans_data1d
206        self.x= sans_data1d.x
207        self.y= sans_data1d.y
208        self.dx= sans_data1d.dx
209        self.dy= sans_data1d.dy
210       
211        ## Min Q-value
212        self.qmin= min (self.data.x)
213        ## Max Q-value
214        self.qmax= max (self.data.x)
215       
216       
217    def setFitRange(self,qmin=None,qmax=None):
218        """ to set the fit range"""
219        if qmin!=None:
220            self.qmin = qmin
221        if qmax !=None:
222            self.qmax = qmax
223       
224       
225    def getFitRange(self):
226        """
227            @return the range of data.x to fit
228        """
229        return self.qmin, self.qmax
230     
231     
232    def residuals(self, fn):
233        """
234            Compute residuals.
235           
236            If self.smearer has been set, use if to smear
237            the data before computing chi squared.
238           
239            @param fn: function that return model value
240            @return residuals
241        """
242        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
243        if self.dy ==None or self.dy==[]:
244            dy= numpy.zeros(len(y)) 
245        else:
246            dy= numpy.asarray(self.dy)
247        dy[dy==0]=1
248        idx = (x>=self.qmin) & (x <= self.qmax)
249 
250        # Compute theory data f(x)
251        fx = numpy.zeros(len(x))
252        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
253       
254        # Smear theory data
255     
256        if self.smearer is not None:
257            fx = self.smearer(fx)
258       
259        # Sanity check
260        if numpy.size(dy) < numpy.size(x):
261            raise RuntimeError, "FitData1D: invalid error array"
262                           
263        return (y[idx] - fx[idx])/dy[idx]
264     
265 
266       
267    def residuals_deriv(self, model, pars=[]):
268        """
269            @return residuals derivatives .
270            @note: in this case just return empty array
271        """
272        return []
273   
274   
275class FitData2D(object):
276    """ Wrapper class  for SANS data """
277    def __init__(self,sans_data2d):
278        """
279            Data can be initital with a data (sans plottable)
280            or with vectors.
281        """
282        self.data=sans_data2d
283        self.image = sans_data2d.data
284        self.err_image = sans_data2d.err_data
285        self.x_bins= sans_data2d.x_bins
286        self.y_bins= sans_data2d.y_bins
287       
288        x = max(self.data.xmin, self.data.xmax)
289        y = max(self.data.ymin, self.data.ymax)
290       
291        ## fitting range
292        self.qmin = 0
293        self.qmax = math.sqrt(x*x +y*y)
294       
295       
296       
297    def setFitRange(self,qmin=None,qmax=None):
298        """ to set the fit range"""
299        if qmin!=None:
300            self.qmin= qmin
301        if qmax!=None:
302            self.qmax= qmax
303     
304       
305    def getFitRange(self):
306        """
307            @return the range of data.x to fit
308        """
309        return self.qmin, self.qmax
310     
311     
312    def residuals(self, fn):
313        """ @param fn: function that return model value
314            @return residuals
315        """
316        res=[]
317        if self.err_image== None or self.err_image ==[]:
318            self.err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
319        self.err_image[self.err_image==0]=1
320       
321        for i in range(len(self.x_bins)):
322            for j in range(len(self.y_bins)):
323                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
324                radius= math.sqrt(temp)
325                if self.qmin <= radius and radius <= self.qmax:
326                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
327                            /self.err_image[j][i] )
328       
329        return numpy.array(res)
330       
331         
332    def residuals_deriv(self, model, pars=[]):
333        """
334            @return residuals derivatives .
335            @note: in this case just return empty array
336        """
337        return []
338   
339class sansAssembly:
340    """
341         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
342    """
343    def __init__(self,paramlist,Model=None , Data=None):
344        """
345            @param Model: the model wrapper fro sans -model
346            @param Data: the data wrapper for sans data
347        """
348        self.model = Model
349        self.data  = Data
350        self.paramlist=paramlist
351        self.res=[]
352    def chisq(self, params):
353        """
354            Calculates chi^2
355            @param params: list of parameter values
356            @return: chi^2
357        """
358        sum = 0
359        for item in self.res:
360            sum += item*item
361       
362        return sum/ len(self.res)
363   
364    def __call__(self,params):
365        """
366            Compute residuals
367            @param params: value of parameters to fit
368        """
369        #import thread
370        self.model.setParams(self.paramlist,params)
371        self.res= self.data.residuals(self.model.eval)
372        #print "residuals",thread.get_ident()
373        return self.res
374   
375class FitEngine:
376    def __init__(self):
377        """
378            Base class for scipy and park fit engine
379        """
380        #List of parameter names to fit
381        self.paramList=[]
382        #Dictionnary of fitArrange element (fit problems)
383        self.fitArrangeDict={}
384       
385    def _concatenateData(self, listdata=[]):
386        """ 
387            _concatenateData method concatenates each fields of all data contains ins listdata.
388            @param listdata: list of data
389            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
390             of data concatenanted
391            @raise: if listdata is empty  will return None
392            @raise: if data in listdata don't contain dy field ,will create an error
393            during fitting
394        """
395        #TODO: we have to refactor the way we handle data.
396        # We should move away from plottables and move towards the Data1D objects
397        # defined in DataLoader. Data1D allows data manipulations, which should be
398        # used to concatenate.
399        # In the meantime we should switch off the concatenation.
400        #if len(listdata)>1:
401        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
402        #return listdata[0]
403       
404        if listdata==[]:
405            raise ValueError, " data list missing"
406        else:
407            xtemp=[]
408            ytemp=[]
409            dytemp=[]
410            self.mini=None
411            self.maxi=None
412               
413            for item in listdata:
414                data=item.data
415                mini,maxi=data.getFitRange()
416                if self.mini==None and self.maxi==None:
417                    self.mini=mini
418                    self.maxi=maxi
419                else:
420                    if mini < self.mini:
421                        self.mini=mini
422                    if self.maxi < maxi:
423                        self.maxi=maxi
424                       
425                   
426                for i in range(len(data.x)):
427                    xtemp.append(data.x[i])
428                    ytemp.append(data.y[i])
429                    if data.dy is not None and len(data.dy)==len(data.y):   
430                        dytemp.append(data.dy[i])
431                    else:
432                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
433            data= Data(x=xtemp,y=ytemp,dy=dytemp)
434            data.setFitRange(self.mini, self.maxi)
435            return data
436       
437       
438    def set_model(self,model,Uid,pars=[]):
439        """
440            set a model on a given uid in the fit engine.
441            @param model: the model to fit
442            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
443            @param pars: the list of parameters to fit
444            @note : pars must contains only name of existing model's paramaters
445        """
446        if len(pars) >0:
447            if model==None:
448                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
449            else:
450                temp=[]
451                for item in pars:
452                    if item in model.model.getParamList():
453                        temp.append(item)
454                        self.paramList.append(item)
455                    else:
456                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
457                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
458                        return
459            #A fitArrange is already created but contains dList only at Uid
460            if self.fitArrangeDict.has_key(Uid):
461                self.fitArrangeDict[Uid].set_model(model)
462                self.fitArrangeDict[Uid].pars= pars
463            else:
464            #no fitArrange object has been create with this Uid
465                fitproblem = FitArrange()
466                fitproblem.set_model(model)
467                fitproblem.pars= pars
468                self.fitArrangeDict[Uid] = fitproblem
469               
470        else:
471            raise ValueError, "park_integration:missing parameters"
472   
473    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
474        """ Receives plottable, creates a list of data to fit,set data
475            in a FitArrange object and adds that object in a dictionary
476            with key Uid.
477            @param data: data added
478            @param Uid: unique key corresponding to a fitArrange object with data
479        """
480        if data.__class__.__name__=='Data2D':
481            fitdata=FitData2D(data)
482        else:
483            fitdata=FitData1D(data, smearer)
484       
485        fitdata.setFitRange(qmin=qmin,qmax=qmax)
486        #A fitArrange is already created but contains model only at Uid
487        if self.fitArrangeDict.has_key(Uid):
488            self.fitArrangeDict[Uid].add_data(fitdata)
489        else:
490        #no fitArrange object has been create with this Uid
491            fitproblem= FitArrange()
492            fitproblem.add_data(fitdata)
493            self.fitArrangeDict[Uid]=fitproblem   
494   
495    def get_model(self,Uid):
496        """
497            @param Uid: Uid is key in the dictionary containing the model to return
498            @return  a model at this uid or None if no FitArrange element was created
499            with this Uid
500        """
501        if self.fitArrangeDict.has_key(Uid):
502            return self.fitArrangeDict[Uid].get_model()
503        else:
504            return None
505   
506    def remove_Fit_Problem(self,Uid):
507        """remove   fitarrange in Uid"""
508        if self.fitArrangeDict.has_key(Uid):
509            del self.fitArrangeDict[Uid]
510           
511    def select_problem_for_fit(self,Uid,value):
512        """
513            select a couple of model and data at the Uid position in dictionary
514            and set in self.selected value to value
515            @param value: the value to allow fitting. can only have the value one or zero
516        """
517        if self.fitArrangeDict.has_key(Uid):
518             self.fitArrangeDict[Uid].set_to_fit( value)
519             
520             
521    def get_problem_to_fit(self,Uid):
522        """
523            return the self.selected value of the fit problem of Uid
524           @param Uid: the Uid of the problem
525        """
526        if self.fitArrangeDict.has_key(Uid):
527             self.fitArrangeDict[Uid].get_to_fit()
528   
529class FitArrange:
530    def __init__(self):
531        """
532            Class FitArrange contains a set of data for a given model
533            to perform the Fit.FitArrange must contain exactly one model
534            and at least one data for the fit to be performed.
535            model: the model selected by the user
536            Ldata: a list of data what the user wants to fit
537           
538        """
539        self.model = None
540        self.dList =[]
541        self.pars=[]
542        #self.selected  is zero when this fit problem is not schedule to fit
543        #self.selected is 1 when schedule to fit
544        self.selected = 0
545       
546    def set_model(self,model):
547        """
548            set_model save a copy of the model
549            @param model: the model being set
550        """
551        self.model = model
552       
553    def add_data(self,data):
554        """
555            add_data fill a self.dList with data to fit
556            @param data: Data to add in the list 
557        """
558        if not data in self.dList:
559            self.dList.append(data)
560           
561    def get_model(self):
562        """ @return: saved model """
563        return self.model   
564     
565    def get_data(self):
566        """ @return:  list of data dList"""
567        #return self.dList
568        return self.dList[0] 
569     
570    def remove_data(self,data):
571        """
572            Remove one element from the list
573            @param data: Data to remove from dList
574        """
575        if data in self.dList:
576            self.dList.remove(data)
577    def set_to_fit (self, value=0):
578        """
579           set self.selected to 0 or 1  for other values raise an exception
580           @param value: integer between 0 or 1
581        """
582        self.selected= value
583       
584    def get_to_fit(self):
585        """
586            @return self.selected value
587        """
588        return self.selected
589   
590
591
592   
Note: See TracBrowser for help on using the repository browser.