source: sasview/park_integration/AbstractFitEngine.py @ 882a912

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 882a912 was 54c21f50, checked in by Gervaise Alina <gervyh@…>, 15 years ago

deecopy of error data before computing residual

  • Property mode set to 100644
File size: 19.7 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146        self.qmin=mini
147        self.qmax=maxi
148       
149       
150    def getFitRange(self):
151        """
152            @return the range of data.x to fit
153        """
154        return self.qmin, self.qmax
155     
156     
157    def residuals(self, fn):
158        """ @param fn: function that return model value
159            @return residuals
160        """
161        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
162        if self.qmin==None and self.qmax==None: 
163            fx =numpy.asarray([fn(v) for v in x])
164            return (y - fx)/dy
165        else:
166            idx = (x>=self.qmin) & (x <= self.qmax)
167            fx = numpy.asarray([fn(item)for item in x[idx ]])
168            return (y[idx] - fx)/dy[idx]
169       
170    def residuals_deriv(self, model, pars=[]):
171        """
172            @return residuals derivatives .
173            @note: in this case just return empty array
174        """
175        return []
176   
177   
178class FitData1D(object):
179    """ Wrapper class  for SANS data """
180    def __init__(self,sans_data1d, smearer=None):
181        """
182            Data can be initital with a data (sans plottable)
183            or with vectors.
184           
185            self.smearer is an object of class QSmearer or SlitSmearer
186            that will smear the theory data (slit smearing or resolution
187            smearing) when set.
188           
189            The proper way to set the smearing object would be to
190            do the following:
191           
192            from DataLoader.qsmearing import smear_selection
193            fitdata1d = FitData1D(some_data)
194            fitdata1d.smearer = smear_selection(some_data)
195           
196            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
197           
198            Setting it back to None will turn smearing off.
199           
200        """
201       
202        self.smearer = smearer
203     
204        # Initialize from Data1D object
205        self.data=sans_data1d
206        self.x= sans_data1d.x
207        self.y= sans_data1d.y
208        self.dx= sans_data1d.dx
209        self.dy= sans_data1d.dy
210       
211        ## Min Q-value
212        self.qmin= min (self.data.x)
213        ## Max Q-value
214        self.qmax= max (self.data.x)
215       
216       
217    def setFitRange(self,qmin=None,qmax=None):
218        """ to set the fit range"""
219        if qmin!=None:
220            self.qmin = qmin
221        if qmax !=None:
222            self.qmax = qmax
223       
224       
225    def getFitRange(self):
226        """
227            @return the range of data.x to fit
228        """
229        return self.qmin, self.qmax
230     
231     
232    def residuals(self, fn):
233        """
234            Compute residuals.
235           
236            If self.smearer has been set, use if to smear
237            the data before computing chi squared.
238           
239            @param fn: function that return model value
240            @return residuals
241        """
242        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
243        if self.dy ==None or self.dy==[]:
244            dy= numpy.zeros(len(y)) 
245        else:
246            dy= copy.deepcopy(self.dy)
247            dy= numpy.asarray(dy)
248        dy[dy==0]=1
249        idx = (x>=self.qmin) & (x <= self.qmax)
250 
251        # Compute theory data f(x)
252        fx = numpy.zeros(len(x))
253        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
254       
255        # Smear theory data
256     
257        if self.smearer is not None:
258            fx = self.smearer(fx)
259       
260        # Sanity check
261        if numpy.size(dy) < numpy.size(x):
262            raise RuntimeError, "FitData1D: invalid error array"
263                           
264        return (y[idx] - fx[idx])/dy[idx]
265     
266 
267       
268    def residuals_deriv(self, model, pars=[]):
269        """
270            @return residuals derivatives .
271            @note: in this case just return empty array
272        """
273        return []
274   
275   
276class FitData2D(object):
277    """ Wrapper class  for SANS data """
278    def __init__(self,sans_data2d):
279        """
280            Data can be initital with a data (sans plottable)
281            or with vectors.
282        """
283        self.data=sans_data2d
284        self.image = sans_data2d.data
285        self.err_image = sans_data2d.err_data
286        self.x_bins= sans_data2d.x_bins
287        self.y_bins= sans_data2d.y_bins
288       
289        x = max(self.data.xmin, self.data.xmax)
290        y = max(self.data.ymin, self.data.ymax)
291       
292        ## fitting range
293        self.qmin = 0
294        self.qmax = math.sqrt(x*x +y*y)
295       
296       
297       
298    def setFitRange(self,qmin=None,qmax=None):
299        """ to set the fit range"""
300        if qmin!=None:
301            self.qmin= qmin
302        if qmax!=None:
303            self.qmax= qmax
304     
305       
306    def getFitRange(self):
307        """
308            @return the range of data.x to fit
309        """
310        return self.qmin, self.qmax
311     
312     
313    def residuals(self, fn):
314        """ @param fn: function that return model value
315            @return residuals
316        """
317        res=[]
318        if self.err_image== None or self.err_image ==[]:
319            err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
320        else:
321            err_image = copy.deepcopy(self.err_image)
322           
323        err_image[err_image==0]=1
324        for i in range(len(self.x_bins)):
325            for j in range(len(self.y_bins)):
326                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
327                radius= math.sqrt(temp)
328                if self.qmin <= radius and radius <= self.qmax:
329                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
330                            /err_image[j][i] )
331       
332        return numpy.array(res)
333       
334         
335    def residuals_deriv(self, model, pars=[]):
336        """
337            @return residuals derivatives .
338            @note: in this case just return empty array
339        """
340        return []
341   
342class sansAssembly:
343    """
344         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
345    """
346    def __init__(self,paramlist,Model=None , Data=None):
347        """
348            @param Model: the model wrapper fro sans -model
349            @param Data: the data wrapper for sans data
350        """
351        self.model = Model
352        self.data  = Data
353        self.paramlist=paramlist
354        self.res=[]
355    def chisq(self, params):
356        """
357            Calculates chi^2
358            @param params: list of parameter values
359            @return: chi^2
360        """
361        sum = 0
362        for item in self.res:
363            sum += item*item
364       
365        return sum/ len(self.res)
366   
367    def __call__(self,params):
368        """
369            Compute residuals
370            @param params: value of parameters to fit
371        """
372        #import thread
373        self.model.setParams(self.paramlist,params)
374        self.res= self.data.residuals(self.model.eval)
375        #print "residuals",thread.get_ident()
376        return self.res
377   
378class FitEngine:
379    def __init__(self):
380        """
381            Base class for scipy and park fit engine
382        """
383        #List of parameter names to fit
384        self.paramList=[]
385        #Dictionnary of fitArrange element (fit problems)
386        self.fitArrangeDict={}
387       
388    def _concatenateData(self, listdata=[]):
389        """ 
390            _concatenateData method concatenates each fields of all data contains ins listdata.
391            @param listdata: list of data
392            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
393             of data concatenanted
394            @raise: if listdata is empty  will return None
395            @raise: if data in listdata don't contain dy field ,will create an error
396            during fitting
397        """
398        #TODO: we have to refactor the way we handle data.
399        # We should move away from plottables and move towards the Data1D objects
400        # defined in DataLoader. Data1D allows data manipulations, which should be
401        # used to concatenate.
402        # In the meantime we should switch off the concatenation.
403        #if len(listdata)>1:
404        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
405        #return listdata[0]
406       
407        if listdata==[]:
408            raise ValueError, " data list missing"
409        else:
410            xtemp=[]
411            ytemp=[]
412            dytemp=[]
413            self.mini=None
414            self.maxi=None
415               
416            for item in listdata:
417                data=item.data
418                mini,maxi=data.getFitRange()
419                if self.mini==None and self.maxi==None:
420                    self.mini=mini
421                    self.maxi=maxi
422                else:
423                    if mini < self.mini:
424                        self.mini=mini
425                    if self.maxi < maxi:
426                        self.maxi=maxi
427                       
428                   
429                for i in range(len(data.x)):
430                    xtemp.append(data.x[i])
431                    ytemp.append(data.y[i])
432                    if data.dy is not None and len(data.dy)==len(data.y):   
433                        dytemp.append(data.dy[i])
434                    else:
435                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
436            data= Data(x=xtemp,y=ytemp,dy=dytemp)
437            data.setFitRange(self.mini, self.maxi)
438            return data
439       
440       
441    def set_model(self,model,Uid,pars=[]):
442        """
443            set a model on a given uid in the fit engine.
444            @param model: the model to fit
445            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
446            @param pars: the list of parameters to fit
447            @note : pars must contains only name of existing model's paramaters
448        """
449        if len(pars) >0:
450            if model==None:
451                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
452            else:
453                temp=[]
454                for item in pars:
455                    if item in model.model.getParamList():
456                        temp.append(item)
457                        self.paramList.append(item)
458                    else:
459                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
460                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
461                        return
462            #A fitArrange is already created but contains dList only at Uid
463            if self.fitArrangeDict.has_key(Uid):
464                self.fitArrangeDict[Uid].set_model(model)
465                self.fitArrangeDict[Uid].pars= pars
466            else:
467            #no fitArrange object has been create with this Uid
468                fitproblem = FitArrange()
469                fitproblem.set_model(model)
470                fitproblem.pars= pars
471                self.fitArrangeDict[Uid] = fitproblem
472               
473        else:
474            raise ValueError, "park_integration:missing parameters"
475   
476    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
477        """ Receives plottable, creates a list of data to fit,set data
478            in a FitArrange object and adds that object in a dictionary
479            with key Uid.
480            @param data: data added
481            @param Uid: unique key corresponding to a fitArrange object with data
482        """
483        if data.__class__.__name__=='Data2D':
484            fitdata=FitData2D(data)
485        else:
486            fitdata=FitData1D(data, smearer)
487       
488        fitdata.setFitRange(qmin=qmin,qmax=qmax)
489        #A fitArrange is already created but contains model only at Uid
490        if self.fitArrangeDict.has_key(Uid):
491            self.fitArrangeDict[Uid].add_data(fitdata)
492        else:
493        #no fitArrange object has been create with this Uid
494            fitproblem= FitArrange()
495            fitproblem.add_data(fitdata)
496            self.fitArrangeDict[Uid]=fitproblem   
497   
498    def get_model(self,Uid):
499        """
500            @param Uid: Uid is key in the dictionary containing the model to return
501            @return  a model at this uid or None if no FitArrange element was created
502            with this Uid
503        """
504        if self.fitArrangeDict.has_key(Uid):
505            return self.fitArrangeDict[Uid].get_model()
506        else:
507            return None
508   
509    def remove_Fit_Problem(self,Uid):
510        """remove   fitarrange in Uid"""
511        if self.fitArrangeDict.has_key(Uid):
512            del self.fitArrangeDict[Uid]
513           
514    def select_problem_for_fit(self,Uid,value):
515        """
516            select a couple of model and data at the Uid position in dictionary
517            and set in self.selected value to value
518            @param value: the value to allow fitting. can only have the value one or zero
519        """
520        if self.fitArrangeDict.has_key(Uid):
521             self.fitArrangeDict[Uid].set_to_fit( value)
522             
523             
524    def get_problem_to_fit(self,Uid):
525        """
526            return the self.selected value of the fit problem of Uid
527           @param Uid: the Uid of the problem
528        """
529        if self.fitArrangeDict.has_key(Uid):
530             self.fitArrangeDict[Uid].get_to_fit()
531   
532class FitArrange:
533    def __init__(self):
534        """
535            Class FitArrange contains a set of data for a given model
536            to perform the Fit.FitArrange must contain exactly one model
537            and at least one data for the fit to be performed.
538            model: the model selected by the user
539            Ldata: a list of data what the user wants to fit
540           
541        """
542        self.model = None
543        self.dList =[]
544        self.pars=[]
545        #self.selected  is zero when this fit problem is not schedule to fit
546        #self.selected is 1 when schedule to fit
547        self.selected = 0
548       
549    def set_model(self,model):
550        """
551            set_model save a copy of the model
552            @param model: the model being set
553        """
554        self.model = model
555       
556    def add_data(self,data):
557        """
558            add_data fill a self.dList with data to fit
559            @param data: Data to add in the list 
560        """
561        if not data in self.dList:
562            self.dList.append(data)
563           
564    def get_model(self):
565        """ @return: saved model """
566        return self.model   
567     
568    def get_data(self):
569        """ @return:  list of data dList"""
570        #return self.dList
571        return self.dList[0] 
572     
573    def remove_data(self,data):
574        """
575            Remove one element from the list
576            @param data: Data to remove from dList
577        """
578        if data in self.dList:
579            self.dList.remove(data)
580    def set_to_fit (self, value=0):
581        """
582           set self.selected to 0 or 1  for other values raise an exception
583           @param value: integer between 0 or 1
584        """
585        self.selected= value
586       
587    def get_to_fit(self):
588        """
589            @return self.selected value
590        """
591        return self.selected
592   
593
594
595   
Note: See TracBrowser for help on using the repository browser.