source: sasview/park_integration/AbstractFitEngine.py @ 077809c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 077809c was 70bf68c, checked in by Gervaise Alina <gervyh@…>, 15 years ago

working on residual function

  • Property mode set to 100644
File size: 21.4 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
211       
212        ## Min Q-value
213        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
214        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
218        ## Max Q-value
219        self.qmax= max (self.data.x)
220       
221       
222    def setFitRange(self,qmin=None,qmax=None):
223        """ to set the fit range"""
224       
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        #ToDo: Fix this.
227        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
228            self.qmin = min(self.data.x[self.data.x!=0])
229        elif qmin!=None:                       
230            self.qmin = qmin           
231
232        if qmax !=None:
233            self.qmax = qmax
234       
235       
236    def getFitRange(self):
237        """
238            @return the range of data.x to fit
239        """
240        return self.qmin, self.qmax
241     
242     
243    def residuals(self, fn):
244        """
245            Compute residuals.
246           
247            If self.smearer has been set, use if to smear
248            the data before computing chi squared.
249           
250            @param fn: function that return model value
251            @return residuals
252        """
253        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
254        if self.dy ==None or self.dy==[]:
255            dy= numpy.zeros(len(y)) 
256        else:
257            dy= copy.deepcopy(self.dy)
258            dy= numpy.asarray(dy)
259     
260        dy[dy==0]=1
261       
262        # Compute theory data f(x)
263        tempy=[]
264        fx=numpy.zeros(len(y))
265        tempdy=[]
266        index=[]
267        tempfx=[]
268        for i_x in  range(len(x)):
269            try:
270                if self.qmin <=x[i_x] and x[i_x]<=self.qmax:
271                    value= fn(x[i_x])
272                    fx[i_x] =value
273                    tempy.append(y[i_x])
274                    tempdy.append(dy[i_x])
275                    index.append(i_x)
276            except:
277                ## skip error for model.run(x)
278                pass
279                 
280        ## Smear theory data
281        if self.smearer is not None:
282            fx = self.smearer(fx)
283           
284        for i in index:
285            tempfx.append(fx[i])
286     
287        newy= numpy.asarray(tempy)
288        newfx= numpy.asarray(tempfx)
289        newdy= numpy.asarray(tempdy)
290       
291        ## Sanity check
292        if numpy.size(newdy)!= numpy.size(newfx):
293            raise RuntimeError, "FitData1D: invalid error array"
294       
295        return (newy- newfx)/newdy
296     
297 
298       
299    def residuals_deriv(self, model, pars=[]):
300        """
301            @return residuals derivatives .
302            @note: in this case just return empty array
303        """
304        return []
305   
306   
307class FitData2D(object):
308    """ Wrapper class  for SANS data """
309    def __init__(self,sans_data2d):
310        """
311            Data can be initital with a data (sans plottable)
312            or with vectors.
313        """
314        self.data=sans_data2d
315        self.image = sans_data2d.data
316        self.err_image = sans_data2d.err_data
317        self.x_bins= sans_data2d.x_bins
318        self.y_bins= sans_data2d.y_bins
319       
320        x = max(self.data.xmin, self.data.xmax)
321        y = max(self.data.ymin, self.data.ymax)
322       
323        ## fitting range
324        self.qmin = 1e-16
325        self.qmax = math.sqrt(x*x +y*y)
326        ## new error image for fitting purpose
327        if self.err_image== None or self.err_image ==[]:
328            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
329        else:
330            self.res_err_image = copy.deepcopy(self.err_image)
331        self.res_err_image[self.err_image==0]=1
332       
333       
334    def setFitRange(self,qmin=None,qmax=None):
335        """ to set the fit range"""
336        if qmin==0.0:
337            self.qmin = 1e-16
338        elif qmin!=None:                       
339            self.qmin = qmin           
340        if qmax!=None:
341            self.qmax= qmax
342     
343       
344    def getFitRange(self):
345        """
346            @return the range of data.x to fit
347        """
348        return self.qmin, self.qmax
349     
350     
351    def residuals(self, fn):
352        """ @param fn: function that return model value
353            @return residuals
354        """
355        res=[]
356       
357        for i in range(len(self.x_bins)):
358            for j in range(len(self.y_bins)):
359                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
360                radius= math.sqrt(temp)
361                if self.qmin <= radius and radius <= self.qmax:
362                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
363                            /self.res_err_image[j][i] )
364       
365        return numpy.array(res)
366       
367         
368    def residuals_deriv(self, model, pars=[]):
369        """
370            @return residuals derivatives .
371            @note: in this case just return empty array
372        """
373        return []
374   
375class FitAbort(Exception):
376    """
377        Exception raise to stop the fit
378    """
379    print"Creating fit abort Exception"
380
381
382class SansAssembly:
383    """
384         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
385    """
386    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
387        """
388            @param Model: the model wrapper fro sans -model
389            @param Data: the data wrapper for sans data
390        """
391        self.model = Model
392        self.data  = Data
393        self.paramlist=paramlist
394        self.curr_thread= curr_thread
395        self.res=[]
396        self.func_name="Functor"
397    def chisq(self, params):
398        """
399            Calculates chi^2
400            @param params: list of parameter values
401            @return: chi^2
402        """
403        sum = 0
404        for item in self.res:
405            sum += item*item
406        if len(self.res)==0:
407            return None
408        return sum/ len(self.res)
409   
410    def __call__(self,params):
411        """
412            Compute residuals
413            @param params: value of parameters to fit
414        """
415        #import thread
416        self.model.setParams(self.paramlist,params)
417        self.res= self.data.residuals(self.model.eval)
418        #if self.curr_thread != None :
419        #    try:
420        #        self.curr_thread.isquit()
421        #    except:
422        #        raise FitAbort,"stop leastsqr optimizer"   
423        return self.res
424   
425class FitEngine:
426    def __init__(self):
427        """
428            Base class for scipy and park fit engine
429        """
430        #List of parameter names to fit
431        self.paramList=[]
432        #Dictionnary of fitArrange element (fit problems)
433        self.fitArrangeDict={}
434       
435    def _concatenateData(self, listdata=[]):
436        """ 
437            _concatenateData method concatenates each fields of all data contains ins listdata.
438            @param listdata: list of data
439            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
440             of data concatenanted
441            @raise: if listdata is empty  will return None
442            @raise: if data in listdata don't contain dy field ,will create an error
443            during fitting
444        """
445        #TODO: we have to refactor the way we handle data.
446        # We should move away from plottables and move towards the Data1D objects
447        # defined in DataLoader. Data1D allows data manipulations, which should be
448        # used to concatenate.
449        # In the meantime we should switch off the concatenation.
450        #if len(listdata)>1:
451        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
452        #return listdata[0]
453       
454        if listdata==[]:
455            raise ValueError, " data list missing"
456        else:
457            xtemp=[]
458            ytemp=[]
459            dytemp=[]
460            self.mini=None
461            self.maxi=None
462               
463            for item in listdata:
464                data=item.data
465                mini,maxi=data.getFitRange()
466                if self.mini==None and self.maxi==None:
467                    self.mini=mini
468                    self.maxi=maxi
469                else:
470                    if mini < self.mini:
471                        self.mini=mini
472                    if self.maxi < maxi:
473                        self.maxi=maxi
474                       
475                   
476                for i in range(len(data.x)):
477                    xtemp.append(data.x[i])
478                    ytemp.append(data.y[i])
479                    if data.dy is not None and len(data.dy)==len(data.y):   
480                        dytemp.append(data.dy[i])
481                    else:
482                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
483            data= Data(x=xtemp,y=ytemp,dy=dytemp)
484            data.setFitRange(self.mini, self.maxi)
485            return data
486       
487       
488    def set_model(self,model,Uid,pars=[]):
489        """
490            set a model on a given uid in the fit engine.
491            @param model: the model to fit
492            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
493            @param pars: the list of parameters to fit
494            @note : pars must contains only name of existing model's paramaters
495        """
496        if len(pars) >0:
497            if model==None:
498                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
499            else:
500                temp=[]
501                for item in pars:
502                    if item in model.model.getParamList():
503                        temp.append(item)
504                        self.paramList.append(item)
505                    else:
506                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
507                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
508                        return
509            #A fitArrange is already created but contains dList only at Uid
510            if self.fitArrangeDict.has_key(Uid):
511                self.fitArrangeDict[Uid].set_model(model)
512                self.fitArrangeDict[Uid].pars= pars
513            else:
514            #no fitArrange object has been create with this Uid
515                fitproblem = FitArrange()
516                fitproblem.set_model(model)
517                fitproblem.pars= pars
518                self.fitArrangeDict[Uid] = fitproblem
519               
520        else:
521            raise ValueError, "park_integration:missing parameters"
522   
523    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
524        """ Receives plottable, creates a list of data to fit,set data
525            in a FitArrange object and adds that object in a dictionary
526            with key Uid.
527            @param data: data added
528            @param Uid: unique key corresponding to a fitArrange object with data
529        """
530        if data.__class__.__name__=='Data2D':
531            fitdata=FitData2D(data)
532        else:
533            fitdata=FitData1D(data, smearer)
534       
535        fitdata.setFitRange(qmin=qmin,qmax=qmax)
536        #A fitArrange is already created but contains model only at Uid
537        if self.fitArrangeDict.has_key(Uid):
538            self.fitArrangeDict[Uid].add_data(fitdata)
539        else:
540        #no fitArrange object has been create with this Uid
541            fitproblem= FitArrange()
542            fitproblem.add_data(fitdata)
543            self.fitArrangeDict[Uid]=fitproblem   
544   
545    def get_model(self,Uid):
546        """
547            @param Uid: Uid is key in the dictionary containing the model to return
548            @return  a model at this uid or None if no FitArrange element was created
549            with this Uid
550        """
551        if self.fitArrangeDict.has_key(Uid):
552            return self.fitArrangeDict[Uid].get_model()
553        else:
554            return None
555   
556    def remove_Fit_Problem(self,Uid):
557        """remove   fitarrange in Uid"""
558        if self.fitArrangeDict.has_key(Uid):
559            del self.fitArrangeDict[Uid]
560           
561    def select_problem_for_fit(self,Uid,value):
562        """
563            select a couple of model and data at the Uid position in dictionary
564            and set in self.selected value to value
565            @param value: the value to allow fitting. can only have the value one or zero
566        """
567        if self.fitArrangeDict.has_key(Uid):
568             self.fitArrangeDict[Uid].set_to_fit( value)
569             
570             
571    def get_problem_to_fit(self,Uid):
572        """
573            return the self.selected value of the fit problem of Uid
574           @param Uid: the Uid of the problem
575        """
576        if self.fitArrangeDict.has_key(Uid):
577             self.fitArrangeDict[Uid].get_to_fit()
578   
579class FitArrange:
580    def __init__(self):
581        """
582            Class FitArrange contains a set of data for a given model
583            to perform the Fit.FitArrange must contain exactly one model
584            and at least one data for the fit to be performed.
585            model: the model selected by the user
586            Ldata: a list of data what the user wants to fit
587           
588        """
589        self.model = None
590        self.dList =[]
591        self.pars=[]
592        #self.selected  is zero when this fit problem is not schedule to fit
593        #self.selected is 1 when schedule to fit
594        self.selected = 0
595       
596    def set_model(self,model):
597        """
598            set_model save a copy of the model
599            @param model: the model being set
600        """
601        self.model = model
602       
603    def add_data(self,data):
604        """
605            add_data fill a self.dList with data to fit
606            @param data: Data to add in the list 
607        """
608        if not data in self.dList:
609            self.dList.append(data)
610           
611    def get_model(self):
612        """ @return: saved model """
613        return self.model   
614     
615    def get_data(self):
616        """ @return:  list of data dList"""
617        #return self.dList
618        return self.dList[0] 
619     
620    def remove_data(self,data):
621        """
622            Remove one element from the list
623            @param data: Data to remove from dList
624        """
625        if data in self.dList:
626            self.dList.remove(data)
627    def set_to_fit (self, value=0):
628        """
629           set self.selected to 0 or 1  for other values raise an exception
630           @param value: integer between 0 or 1
631        """
632        self.selected= value
633       
634    def get_to_fit(self):
635        """
636            @return self.selected value
637        """
638        return self.selected
639   
640
641
642   
Note: See TracBrowser for help on using the repository browser.