source: sasview/park_integration/AbstractFitEngine.py @ 0e5e586

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 0e5e586 was 058b2d7, checked in by Gervaise Alina <gervyh@…>, 15 years ago

modify fitdata1d and fitdata2d residuals function

  • Property mode set to 100644
File size: 21.4 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= numpy.array(sans_data1d.x)
208        self.y= numpy.array(sans_data1d.y)
209        self.dx= numpy.array(sans_data1d.dx)
210        self.dy= numpy.array(sans_data1d.dy)
211         
212       
213        if self.dy ==None or len(self.dy)==0:
214            self.res_dy= numpy.zeros(len(self.y))
215        else:
216            self.res_dy= copy.deepcopy(self.dy)
217        self.res_dy= numpy.asarray(self.res_dy)
218       
219        self.res_dy[self.res_dy==0]=1
220        ## Min Q-value
221        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
222        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
223            self.qmin = min(self.data.x[self.data.x!=0])
224        else:                             
225            self.qmin= min (self.data.x)
226        ## Max Q-value
227        self.qmax= max (self.data.x)
228       
229       
230    def setFitRange(self,qmin=None,qmax=None):
231        """ to set the fit range"""
232        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
233        #ToDo: Fix this.
234        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
235            self.qmin = min(self.data.x[self.data.x!=0])
236        elif qmin!=None:                       
237            self.qmin = qmin           
238
239        if qmax !=None:
240            self.qmax = qmax
241       
242       
243    def getFitRange(self):
244        """
245            @return the range of data.x to fit
246        """
247        return self.qmin, self.qmax
248     
249     
250    def residuals(self, fn):
251        """
252        Compute residuals.
253       
254        If self.smearer has been set, use if to smear the data before computing chi squared.
255       
256        @param fn: function that return model value @return residuals """
257        # Get the indices of the selected range
258        idx = (self.x>=self.qmin) & (self.x <= self.qmax)
259       
260        # Compute theory data f(x)
261        newy = numpy.zeros(len(self.x))
262        newfx= numpy.zeros(len(self.x))
263        newdy= numpy.zeros(len(self.x))
264       
265        for i_x in range(len(self.x)):
266            try:
267                # Skip the selection here since we want all the contribution from the theory bins #if self.qmin <=x[i_x] and x[i_x]<=self.qmax:
268                value       = fn(self.x[i_x])
269                newfx[i_x]  = value
270                newy[i_x]   = self.y[i_x]
271                newdy[i_x]  = self.res_dy[i_x]
272               
273            except:
274                ## skip error for model.run(x)
275                pass
276       
277        ## Smear theory data
278        if self.smearer is not None:
279           
280            newfx = self.smearer(newfx)
281     
282        ## Sanity check
283        if numpy.size(newdy)!= numpy.size(newfx):
284            raise RuntimeError, "FitData1D: invalid error array"
285       
286        # Sum over the selected range
287        return (newy[idx]- newfx[idx])/newdy[idx]
288       
289
290    def residuals_deriv(self, model, pars=[]):
291        """
292            @return residuals derivatives .
293            @note: in this case just return empty array
294        """
295        return []
296   
297   
298class FitData2D(object):
299    """ Wrapper class  for SANS data """
300    def __init__(self,sans_data2d):
301        """
302            Data can be initital with a data (sans plottable)
303            or with vectors.
304        """
305        self.data=sans_data2d
306        self.image = sans_data2d.data
307        self.err_image = sans_data2d.err_data
308        self.x_bins= sans_data2d.x_bins
309        self.y_bins= sans_data2d.y_bins
310       
311        x = max(self.data.xmin, self.data.xmax)
312        y = max(self.data.ymin, self.data.ymax)
313       
314        ## fitting range
315        self.qmin = 1e-16
316        self.qmax = math.sqrt(x*x +y*y)
317        ## new error image for fitting purpose
318        if self.err_image== None or self.err_image ==[]:
319            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
320        else:
321            self.res_err_image = copy.deepcopy(self.err_image)
322        self.res_err_image[self.err_image==0]=1
323       
324       
325    def setFitRange(self,qmin=None,qmax=None):
326        """ to set the fit range"""
327        if qmin==0.0:
328            self.qmin = 1e-16
329        elif qmin!=None:                       
330            self.qmin = qmin           
331        if qmax!=None:
332            self.qmax= qmax
333     
334       
335    def getFitRange(self):
336        """
337            @return the range of data.x to fit
338        """
339        return self.qmin, self.qmax
340     
341     
342    def residuals(self, fn):
343        """ @param fn: function that return model value
344            @return residuals
345        """
346        res=[]
347       
348        for i in range(len(self.x_bins)):
349            for j in range(len(self.y_bins)):
350                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
351                radius= math.sqrt(temp)
352                if self.qmin <= radius and radius <= self.qmax:
353                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
354                            /self.res_err_image[j][i] )
355       
356        return numpy.array(res)
357       
358         
359    def residuals_deriv(self, model, pars=[]):
360        """
361            @return residuals derivatives .
362            @note: in this case just return empty array
363        """
364        return []
365   
366class FitAbort(Exception):
367    """
368        Exception raise to stop the fit
369    """
370    print"Creating fit abort Exception"
371
372
373class SansAssembly:
374    """
375         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
376    """
377    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
378        """
379            @param Model: the model wrapper fro sans -model
380            @param Data: the data wrapper for sans data
381        """
382        self.model = Model
383        self.data  = Data
384        self.paramlist=paramlist
385        self.curr_thread= curr_thread
386        self.res=[]
387        self.func_name="Functor"
388    def chisq(self, params):
389        """
390            Calculates chi^2
391            @param params: list of parameter values
392            @return: chi^2
393        """
394        sum = 0
395        for item in self.res:
396            sum += item*item
397        if len(self.res)==0:
398            return None
399        return sum/ len(self.res)
400   
401    def __call__(self,params):
402        """
403            Compute residuals
404            @param params: value of parameters to fit
405        """
406        #import thread
407        self.model.setParams(self.paramlist,params)
408        self.res= self.data.residuals(self.model.eval)
409        #if self.curr_thread != None :
410        #    try:
411        #        self.curr_thread.isquit()
412        #    except:
413        #        raise FitAbort,"stop leastsqr optimizer"   
414        return self.res
415   
416class FitEngine:
417    def __init__(self):
418        """
419            Base class for scipy and park fit engine
420        """
421        #List of parameter names to fit
422        self.paramList=[]
423        #Dictionnary of fitArrange element (fit problems)
424        self.fitArrangeDict={}
425       
426    def _concatenateData(self, listdata=[]):
427        """ 
428            _concatenateData method concatenates each fields of all data contains ins listdata.
429            @param listdata: list of data
430            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
431             of data concatenanted
432            @raise: if listdata is empty  will return None
433            @raise: if data in listdata don't contain dy field ,will create an error
434            during fitting
435        """
436        #TODO: we have to refactor the way we handle data.
437        # We should move away from plottables and move towards the Data1D objects
438        # defined in DataLoader. Data1D allows data manipulations, which should be
439        # used to concatenate.
440        # In the meantime we should switch off the concatenation.
441        #if len(listdata)>1:
442        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
443        #return listdata[0]
444       
445        if listdata==[]:
446            raise ValueError, " data list missing"
447        else:
448            xtemp=[]
449            ytemp=[]
450            dytemp=[]
451            self.mini=None
452            self.maxi=None
453               
454            for item in listdata:
455                data=item.data
456                mini,maxi=data.getFitRange()
457                if self.mini==None and self.maxi==None:
458                    self.mini=mini
459                    self.maxi=maxi
460                else:
461                    if mini < self.mini:
462                        self.mini=mini
463                    if self.maxi < maxi:
464                        self.maxi=maxi
465                       
466                   
467                for i in range(len(data.x)):
468                    xtemp.append(data.x[i])
469                    ytemp.append(data.y[i])
470                    if data.dy is not None and len(data.dy)==len(data.y):   
471                        dytemp.append(data.dy[i])
472                    else:
473                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
474            data= Data(x=xtemp,y=ytemp,dy=dytemp)
475            data.setFitRange(self.mini, self.maxi)
476            return data
477       
478       
479    def set_model(self,model,Uid,pars=[]):
480        """
481            set a model on a given uid in the fit engine.
482            @param model: the model to fit
483            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
484            @param pars: the list of parameters to fit
485            @note : pars must contains only name of existing model's paramaters
486        """
487        if len(pars) >0:
488            if model==None:
489                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
490            else:
491                temp=[]
492                for item in pars:
493                    if item in model.model.getParamList():
494                        temp.append(item)
495                        self.paramList.append(item)
496                    else:
497                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
498                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
499                        return
500            #A fitArrange is already created but contains dList only at Uid
501            if self.fitArrangeDict.has_key(Uid):
502                self.fitArrangeDict[Uid].set_model(model)
503                self.fitArrangeDict[Uid].pars= pars
504            else:
505            #no fitArrange object has been create with this Uid
506                fitproblem = FitArrange()
507                fitproblem.set_model(model)
508                fitproblem.pars= pars
509                self.fitArrangeDict[Uid] = fitproblem
510               
511        else:
512            raise ValueError, "park_integration:missing parameters"
513   
514    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
515        """ Receives plottable, creates a list of data to fit,set data
516            in a FitArrange object and adds that object in a dictionary
517            with key Uid.
518            @param data: data added
519            @param Uid: unique key corresponding to a fitArrange object with data
520        """
521        if data.__class__.__name__=='Data2D':
522            fitdata=FitData2D(data)
523        else:
524            fitdata=FitData1D(data, smearer)
525       
526        fitdata.setFitRange(qmin=qmin,qmax=qmax)
527        #A fitArrange is already created but contains model only at Uid
528        if self.fitArrangeDict.has_key(Uid):
529            self.fitArrangeDict[Uid].add_data(fitdata)
530        else:
531        #no fitArrange object has been create with this Uid
532            fitproblem= FitArrange()
533            fitproblem.add_data(fitdata)
534            self.fitArrangeDict[Uid]=fitproblem   
535   
536    def get_model(self,Uid):
537        """
538            @param Uid: Uid is key in the dictionary containing the model to return
539            @return  a model at this uid or None if no FitArrange element was created
540            with this Uid
541        """
542        if self.fitArrangeDict.has_key(Uid):
543            return self.fitArrangeDict[Uid].get_model()
544        else:
545            return None
546   
547    def remove_Fit_Problem(self,Uid):
548        """remove   fitarrange in Uid"""
549        if self.fitArrangeDict.has_key(Uid):
550            del self.fitArrangeDict[Uid]
551           
552    def select_problem_for_fit(self,Uid,value):
553        """
554            select a couple of model and data at the Uid position in dictionary
555            and set in self.selected value to value
556            @param value: the value to allow fitting. can only have the value one or zero
557        """
558        if self.fitArrangeDict.has_key(Uid):
559             self.fitArrangeDict[Uid].set_to_fit( value)
560             
561             
562    def get_problem_to_fit(self,Uid):
563        """
564            return the self.selected value of the fit problem of Uid
565           @param Uid: the Uid of the problem
566        """
567        if self.fitArrangeDict.has_key(Uid):
568             self.fitArrangeDict[Uid].get_to_fit()
569   
570class FitArrange:
571    def __init__(self):
572        """
573            Class FitArrange contains a set of data for a given model
574            to perform the Fit.FitArrange must contain exactly one model
575            and at least one data for the fit to be performed.
576            model: the model selected by the user
577            Ldata: a list of data what the user wants to fit
578           
579        """
580        self.model = None
581        self.dList =[]
582        self.pars=[]
583        #self.selected  is zero when this fit problem is not schedule to fit
584        #self.selected is 1 when schedule to fit
585        self.selected = 0
586       
587    def set_model(self,model):
588        """
589            set_model save a copy of the model
590            @param model: the model being set
591        """
592        self.model = model
593       
594    def add_data(self,data):
595        """
596            add_data fill a self.dList with data to fit
597            @param data: Data to add in the list 
598        """
599        if not data in self.dList:
600            self.dList.append(data)
601           
602    def get_model(self):
603        """ @return: saved model """
604        return self.model   
605     
606    def get_data(self):
607        """ @return:  list of data dList"""
608        #return self.dList
609        return self.dList[0] 
610     
611    def remove_data(self,data):
612        """
613            Remove one element from the list
614            @param data: Data to remove from dList
615        """
616        if data in self.dList:
617            self.dList.remove(data)
618    def set_to_fit (self, value=0):
619        """
620           set self.selected to 0 or 1  for other values raise an exception
621           @param value: integer between 0 or 1
622        """
623        self.selected= value
624       
625    def get_to_fit(self):
626        """
627            @return self.selected value
628        """
629        return self.selected
Note: See TracBrowser for help on using the repository browser.