source: sasview/park_integration/AbstractFitEngine.py @ d5b488b

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since d5b488b was d5b488b, checked in by Gervaise Alina <gervyh@…>, 15 years ago

compute residual rejecting singular point for 1D fitting

  • Property mode set to 100644
File size: 21.3 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
211       
212        ## Min Q-value
213        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
214        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
218        ## Max Q-value
219        self.qmax= max (self.data.x)
220       
221       
222    def setFitRange(self,qmin=None,qmax=None):
223        """ to set the fit range"""
224       
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        #ToDo: Fix this.
227        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
228            self.qmin = min(self.data.x[self.data.x!=0])
229        elif qmin!=None:                       
230            self.qmin = qmin           
231
232        if qmax !=None:
233            self.qmax = qmax
234       
235       
236    def getFitRange(self):
237        """
238            @return the range of data.x to fit
239        """
240        return self.qmin, self.qmax
241     
242     
243    def residuals(self, fn):
244        """
245            Compute residuals.
246           
247            If self.smearer has been set, use if to smear
248            the data before computing chi squared.
249           
250            @param fn: function that return model value
251            @return residuals
252        """
253        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
254        if self.dy ==None or self.dy==[]:
255            dy= numpy.zeros(len(y)) 
256        else:
257            dy= copy.deepcopy(self.dy)
258            dy= numpy.asarray(dy)
259     
260        dy[dy==0]=1
261        #idx = (x>=self.qmin) & (x <= self.qmax)
262 
263        # Compute theory data f(x)
264        #fx = numpy.zeros(len(x))
265        tempy=[]
266        tempfx=[]
267        tempdy=[]
268        #fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
269        for i_x in  range(len(x)):
270            try:
271                if self.qmin <=x[i_x] and x[i_x]<=self.qmax:
272                    value= fn(x[i_x])
273                    tempfx.append( value)
274                    tempy.append(y[i_x])
275                    tempdy.append(dy[i_x])
276            except:
277                ## skip error for model.run(x)
278                pass
279                 
280        ## Smear theory data
281        if self.smearer is not None:
282            tempfx = self.smearer(tempfx)
283        newy= numpy.asarray(tempy)
284        newfx= numpy.asarray(tempfx)
285        newdy= numpy.asarray(tempdy)
286       
287       
288        ## Sanity check
289        if numpy.size(newdy)!= numpy.size(newfx):
290            raise RuntimeError, "FitData1D: invalid error array"
291        #return (y[idx] - fx[idx])/dy[idx]
292       
293        return (newy- newfx)/newdy
294     
295 
296       
297    def residuals_deriv(self, model, pars=[]):
298        """
299            @return residuals derivatives .
300            @note: in this case just return empty array
301        """
302        return []
303   
304   
305class FitData2D(object):
306    """ Wrapper class  for SANS data """
307    def __init__(self,sans_data2d):
308        """
309            Data can be initital with a data (sans plottable)
310            or with vectors.
311        """
312        self.data=sans_data2d
313        self.image = sans_data2d.data
314        self.err_image = sans_data2d.err_data
315        self.x_bins= sans_data2d.x_bins
316        self.y_bins= sans_data2d.y_bins
317       
318        x = max(self.data.xmin, self.data.xmax)
319        y = max(self.data.ymin, self.data.ymax)
320       
321        ## fitting range
322        self.qmin = 1e-16
323        self.qmax = math.sqrt(x*x +y*y)
324       
325       
326       
327    def setFitRange(self,qmin=None,qmax=None):
328        """ to set the fit range"""
329        if qmin==0.0:
330            self.qmin = 1e-16
331        elif qmin!=None:                       
332            self.qmin = qmin           
333        if qmax!=None:
334            self.qmax= qmax
335     
336       
337    def getFitRange(self):
338        """
339            @return the range of data.x to fit
340        """
341        return self.qmin, self.qmax
342     
343     
344    def residuals(self, fn):
345        """ @param fn: function that return model value
346            @return residuals
347        """
348        res=[]
349        if self.err_image== None or self.err_image ==[]:
350            err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
351        else:
352            err_image = copy.deepcopy(self.err_image)
353           
354        err_image[err_image==0]=1
355        for i in range(len(self.x_bins)):
356            for j in range(len(self.y_bins)):
357                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
358                radius= math.sqrt(temp)
359                if self.qmin <= radius and radius <= self.qmax:
360                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
361                            /err_image[j][i] )
362       
363        return numpy.array(res)
364       
365         
366    def residuals_deriv(self, model, pars=[]):
367        """
368            @return residuals derivatives .
369            @note: in this case just return empty array
370        """
371        return []
372   
373class FitAbort(Exception):
374    """
375        Exception raise to stop the fit
376    """
377    print"Creating fit abort Exception"
378
379
380class sansAssembly:
381    """
382         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
383    """
384    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
385        """
386            @param Model: the model wrapper fro sans -model
387            @param Data: the data wrapper for sans data
388        """
389        self.model = Model
390        self.data  = Data
391        self.paramlist=paramlist
392        self.curr_thread= curr_thread
393        self.res=[]
394        self.func_name="Functor"
395    def chisq(self, params):
396        """
397            Calculates chi^2
398            @param params: list of parameter values
399            @return: chi^2
400        """
401        sum = 0
402        for item in self.res:
403            sum += item*item
404        if len(self.res)==0:
405            return None
406        return sum/ len(self.res)
407   
408    def __call__(self,params):
409        """
410            Compute residuals
411            @param params: value of parameters to fit
412        """
413        #import thread
414        self.model.setParams(self.paramlist,params)
415        self.res= self.data.residuals(self.model.eval)
416        if self.curr_thread != None :
417            try:
418                self.curr_thread.isquit()
419            except:
420                raise FitAbort,"stop leastsqr optimizer" 
421               
422        return self.res
423   
424class FitEngine:
425    def __init__(self):
426        """
427            Base class for scipy and park fit engine
428        """
429        #List of parameter names to fit
430        self.paramList=[]
431        #Dictionnary of fitArrange element (fit problems)
432        self.fitArrangeDict={}
433       
434    def _concatenateData(self, listdata=[]):
435        """ 
436            _concatenateData method concatenates each fields of all data contains ins listdata.
437            @param listdata: list of data
438            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
439             of data concatenanted
440            @raise: if listdata is empty  will return None
441            @raise: if data in listdata don't contain dy field ,will create an error
442            during fitting
443        """
444        #TODO: we have to refactor the way we handle data.
445        # We should move away from plottables and move towards the Data1D objects
446        # defined in DataLoader. Data1D allows data manipulations, which should be
447        # used to concatenate.
448        # In the meantime we should switch off the concatenation.
449        #if len(listdata)>1:
450        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
451        #return listdata[0]
452       
453        if listdata==[]:
454            raise ValueError, " data list missing"
455        else:
456            xtemp=[]
457            ytemp=[]
458            dytemp=[]
459            self.mini=None
460            self.maxi=None
461               
462            for item in listdata:
463                data=item.data
464                mini,maxi=data.getFitRange()
465                if self.mini==None and self.maxi==None:
466                    self.mini=mini
467                    self.maxi=maxi
468                else:
469                    if mini < self.mini:
470                        self.mini=mini
471                    if self.maxi < maxi:
472                        self.maxi=maxi
473                       
474                   
475                for i in range(len(data.x)):
476                    xtemp.append(data.x[i])
477                    ytemp.append(data.y[i])
478                    if data.dy is not None and len(data.dy)==len(data.y):   
479                        dytemp.append(data.dy[i])
480                    else:
481                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
482            data= Data(x=xtemp,y=ytemp,dy=dytemp)
483            data.setFitRange(self.mini, self.maxi)
484            return data
485       
486       
487    def set_model(self,model,Uid,pars=[]):
488        """
489            set a model on a given uid in the fit engine.
490            @param model: the model to fit
491            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
492            @param pars: the list of parameters to fit
493            @note : pars must contains only name of existing model's paramaters
494        """
495        if len(pars) >0:
496            if model==None:
497                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
498            else:
499                temp=[]
500                for item in pars:
501                    if item in model.model.getParamList():
502                        temp.append(item)
503                        self.paramList.append(item)
504                    else:
505                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
506                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
507                        return
508            #A fitArrange is already created but contains dList only at Uid
509            if self.fitArrangeDict.has_key(Uid):
510                self.fitArrangeDict[Uid].set_model(model)
511                self.fitArrangeDict[Uid].pars= pars
512            else:
513            #no fitArrange object has been create with this Uid
514                fitproblem = FitArrange()
515                fitproblem.set_model(model)
516                fitproblem.pars= pars
517                self.fitArrangeDict[Uid] = fitproblem
518               
519        else:
520            raise ValueError, "park_integration:missing parameters"
521   
522    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
523        """ Receives plottable, creates a list of data to fit,set data
524            in a FitArrange object and adds that object in a dictionary
525            with key Uid.
526            @param data: data added
527            @param Uid: unique key corresponding to a fitArrange object with data
528        """
529        if data.__class__.__name__=='Data2D':
530            fitdata=FitData2D(data)
531        else:
532            fitdata=FitData1D(data, smearer)
533       
534        fitdata.setFitRange(qmin=qmin,qmax=qmax)
535        #A fitArrange is already created but contains model only at Uid
536        if self.fitArrangeDict.has_key(Uid):
537            self.fitArrangeDict[Uid].add_data(fitdata)
538        else:
539        #no fitArrange object has been create with this Uid
540            fitproblem= FitArrange()
541            fitproblem.add_data(fitdata)
542            self.fitArrangeDict[Uid]=fitproblem   
543   
544    def get_model(self,Uid):
545        """
546            @param Uid: Uid is key in the dictionary containing the model to return
547            @return  a model at this uid or None if no FitArrange element was created
548            with this Uid
549        """
550        if self.fitArrangeDict.has_key(Uid):
551            return self.fitArrangeDict[Uid].get_model()
552        else:
553            return None
554   
555    def remove_Fit_Problem(self,Uid):
556        """remove   fitarrange in Uid"""
557        if self.fitArrangeDict.has_key(Uid):
558            del self.fitArrangeDict[Uid]
559           
560    def select_problem_for_fit(self,Uid,value):
561        """
562            select a couple of model and data at the Uid position in dictionary
563            and set in self.selected value to value
564            @param value: the value to allow fitting. can only have the value one or zero
565        """
566        if self.fitArrangeDict.has_key(Uid):
567             self.fitArrangeDict[Uid].set_to_fit( value)
568             
569             
570    def get_problem_to_fit(self,Uid):
571        """
572            return the self.selected value of the fit problem of Uid
573           @param Uid: the Uid of the problem
574        """
575        if self.fitArrangeDict.has_key(Uid):
576             self.fitArrangeDict[Uid].get_to_fit()
577   
578class FitArrange:
579    def __init__(self):
580        """
581            Class FitArrange contains a set of data for a given model
582            to perform the Fit.FitArrange must contain exactly one model
583            and at least one data for the fit to be performed.
584            model: the model selected by the user
585            Ldata: a list of data what the user wants to fit
586           
587        """
588        self.model = None
589        self.dList =[]
590        self.pars=[]
591        #self.selected  is zero when this fit problem is not schedule to fit
592        #self.selected is 1 when schedule to fit
593        self.selected = 0
594       
595    def set_model(self,model):
596        """
597            set_model save a copy of the model
598            @param model: the model being set
599        """
600        self.model = model
601       
602    def add_data(self,data):
603        """
604            add_data fill a self.dList with data to fit
605            @param data: Data to add in the list 
606        """
607        if not data in self.dList:
608            self.dList.append(data)
609           
610    def get_model(self):
611        """ @return: saved model """
612        return self.model   
613     
614    def get_data(self):
615        """ @return:  list of data dList"""
616        #return self.dList
617        return self.dList[0] 
618     
619    def remove_data(self,data):
620        """
621            Remove one element from the list
622            @param data: Data to remove from dList
623        """
624        if data in self.dList:
625            self.dList.remove(data)
626    def set_to_fit (self, value=0):
627        """
628           set self.selected to 0 or 1  for other values raise an exception
629           @param value: integer between 0 or 1
630        """
631        self.selected= value
632       
633    def get_to_fit(self):
634        """
635            @return self.selected value
636        """
637        return self.selected
638   
639
640
641   
Note: See TracBrowser for help on using the repository browser.