source: sasview/park_integration/AbstractFitEngine.py @ c9aa125

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since c9aa125 was d8a2e31, checked in by Gervaise Alina <gervyh@…>, 15 years ago

use eval distribution as well as run

  • Property mode set to 100644
File size: 23.7 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        try:
115            return self.model.evalDistribution(x)
116        except:
117            return self.model.runXY(x)
118   
119   
120
121
122class Data(object):
123    """ Wrapper class  for SANS data """
124    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
125        """
126            Data can be initital with a data (sans plottable)
127            or with vectors.
128        """
129        if  sans_data !=None:
130            self.x= sans_data.x
131            self.y= sans_data.y
132            self.dx= sans_data.dx
133            self.dy= sans_data.dy
134           
135        elif (x!=None and y!=None and dy!=None):
136                self.x=x
137                self.y=y
138                self.dx=dx
139                self.dy=dy
140        else:
141            raise ValueError,\
142            "Data is missing x, y or dy, impossible to compute residuals later on"
143        self.qmin=None
144        self.qmax=None
145       
146       
147    def setFitRange(self,mini=None,maxi=None):
148        """ to set the fit range"""
149       
150        self.qmin=mini           
151        self.qmax=maxi
152       
153       
154    def getFitRange(self):
155        """
156            @return the range of data.x to fit
157        """
158        return self.qmin, self.qmax
159     
160     
161    def residuals(self, fn):
162        """ @param fn: function that return model value
163            @return residuals
164        """
165        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
166        if self.qmin==None and self.qmax==None: 
167            fx =numpy.asarray([fn(v) for v in x])
168            return (y - fx)/dy
169        else:
170            idx = (x>=self.qmin) & (x <= self.qmax)
171            fx = numpy.asarray([fn(item)for item in x[idx ]])
172            return (y[idx] - fx)/dy[idx]
173       
174    def residuals_deriv(self, model, pars=[]):
175        """
176            @return residuals derivatives .
177            @note: in this case just return empty array
178        """
179        return []
180   
181   
182class FitData1D(object):
183    """ Wrapper class  for SANS data """
184    def __init__(self,sans_data1d, smearer=None):
185        """
186            Data can be initital with a data (sans plottable)
187            or with vectors.
188           
189            self.smearer is an object of class QSmearer or SlitSmearer
190            that will smear the theory data (slit smearing or resolution
191            smearing) when set.
192           
193            The proper way to set the smearing object would be to
194            do the following:
195           
196            from DataLoader.qsmearing import smear_selection
197            fitdata1d = FitData1D(some_data)
198            fitdata1d.smearer = smear_selection(some_data)
199           
200            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
201           
202            Setting it back to None will turn smearing off.
203           
204        """
205       
206        self.smearer = smearer
207     
208        # Initialize from Data1D object
209        self.data=sans_data1d
210        self.x= sans_data1d.x
211        self.y= sans_data1d.y
212        self.dx= sans_data1d.dx
213        self.dy= sans_data1d.dy
214       
215        ## Min Q-value
216        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
217        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
218            self.qmin = min(self.data.x[self.data.x!=0])
219        else:                             
220            self.qmin= min (self.data.x)
221        ## Max Q-value
222        self.qmax= max (self.data.x)
223       
224        # Range used for input to smearing
225        self._qmin_unsmeared = self.qmin
226        self._qmax_unsmeared = self.qmax
227       
228       
229    def setFitRange(self,qmin=None,qmax=None):
230        """ to set the fit range"""
231        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
232        #ToDo: Fix this.
233        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
234            self.qmin = min(self.data.x[self.data.x!=0])
235        elif qmin!=None:                       
236            self.qmin = qmin           
237
238        if qmax !=None:
239            self.qmax = qmax
240           
241        # Range used for input to smearing
242        self._qmin_unsmeared = self.qmin
243        self._qmax_unsmeared = self.qmax   
244       
245        # Determine the range needed in unsmeared-Q to cover
246        # the smeared Q range
247        #TODO: use the smearing matrix to determine which
248        # bin range to use
249        if self.smearer.__class__.__name__ == 'SlitSmearer':
250            self._qmin_unsmeared = min(self.data.x)
251            self._qmax_unsmeared = max(self.data.x)
252        elif self.smearer.__class__.__name__ == 'QSmearer':
253            # Take 3 sigmas as the offset between smeared and unsmeared space
254            try:
255                offset = 3.0*max(self.smearer.width)
256                self._qmin_unsmeared = max([min(self.data.x), self.qmin-offset])
257                self._qmax_unsmeared = min([max(self.data.x), self.qmax+offset])
258            except:
259                logging.error("FitData1D.setFitRange: %s" % sys.exc_value)
260       
261       
262    def getFitRange(self):
263        """
264            @return the range of data.x to fit
265        """
266        return self.qmin, self.qmax
267       
268    def residuals(self, fn):
269        """
270            Compute residuals.
271           
272            If self.smearer has been set, use if to smear
273            the data before computing chi squared.
274           
275            @param fn: function that return model value
276            @return residuals
277        """
278        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
279        if self.dy ==None or self.dy==[]:
280            dy= numpy.zeros(len(y)) 
281        else:
282            dy= numpy.asarray(self.dy)
283     
284        # For fitting purposes, replace zero errors by 1
285        #TODO: check validity for the rare case where only
286        # a few points have zero errors
287        dy[dy==0]=1
288       
289        # Identify the bin range for the unsmeared and smeared spaces
290        idx = (x>=self.qmin) & (x <= self.qmax)
291        idx_unsmeared = (x>=self._qmin_unsmeared) & (x <= self._qmax_unsmeared)
292 
293        # Compute theory data f(x)
294        fx= numpy.zeros(len(x))
295   
296        _first_bin = None
297        _last_bin  = None
298        for i_x in range(len(x)):
299            try:
300                if idx_unsmeared[i_x]==True:
301                    # Identify first and last bin
302                    #TODO: refactor this to pass q-values to the smearer
303                    # and let it figure out which bin range to use
304                    if _first_bin is None:
305                        _first_bin = i_x
306                    else:
307                        _last_bin  = i_x
308                   
309                    value = fn(x[i_x])
310                    fx[i_x] = value
311            except:
312                ## skip error for model.run(x)
313                pass
314                 
315        ## Smear theory data
316        if self.smearer is not None:
317            fx = self.smearer(fx, _first_bin, _last_bin)
318       
319        ## Sanity check
320        if numpy.size(dy)!= numpy.size(fx):
321            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.size(dy), numpy.size(fx))
322
323        return (y[idx]-fx[idx])/dy[idx]
324     
325 
326       
327    def residuals_deriv(self, model, pars=[]):
328        """
329            @return residuals derivatives .
330            @note: in this case just return empty array
331        """
332        return []
333   
334   
335class FitData2D(object):
336    """ Wrapper class  for SANS data """
337    def __init__(self,sans_data2d):
338        """
339            Data can be initital with a data (sans plottable)
340            or with vectors.
341        """
342        self.data=sans_data2d
343        self.image = sans_data2d.data
344        self.err_image = sans_data2d.err_data
345        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
346                                         [1,len(sans_data2d.x_bins)])
347        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
348                                          [len(sans_data2d.y_bins),1])
349       
350       
351        self.x_bins = sans_data2d.x_bins
352        self.y_bins = sans_data2d.y_bins
353       
354        x = max(self.data.xmin, self.data.xmax)
355        y = max(self.data.ymin, self.data.ymax)
356       
357        ## fitting range
358        self.qmin = 1e-16
359        self.qmax = math.sqrt(x*x +y*y)
360        ## new error image for fitting purpose
361        if self.err_image== None or self.err_image ==[]:
362            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
363        else:
364            self.res_err_image = copy.deepcopy(self.err_image)
365        self.res_err_image[self.err_image==0]=1
366       
367        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
368        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
369       
370       
371    def setFitRange(self,qmin=None,qmax=None):
372        """ to set the fit range"""
373        if qmin==0.0:
374            self.qmin = 1e-16
375        elif qmin!=None:                       
376            self.qmin = qmin           
377        if qmax!=None:
378            self.qmax= qmax
379     
380       
381    def getFitRange(self):
382        """
383            @return the range of data.x to fit
384        """
385        return self.qmin, self.qmax
386     
387    def residuals(self, fn): 
388        try:
389            res=self.index_model*(self.image - fn([self.y_bins_array,
390                             self.x_bins_array]))/self.res_err_image
391            return res.ravel() 
392        except:
393            print "Using old residual method"
394            return self.old_residuals( fn)
395   
396   
397    def old_residuals(self, fn):
398        """ @param fn: function that return model value
399            @return residuals
400        """
401        res=[]
402       
403        for i in range(len(self.x_bins)):
404            for j in range(len(self.y_bins)):
405                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
406                radius= math.sqrt(temp)
407                if self.qmin <= radius and radius <= self.qmax:
408                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
409                            /self.res_err_image[j][i] )
410       
411        return numpy.array(res)
412       
413         
414    def residuals_deriv(self, model, pars=[]):
415        """
416            @return residuals derivatives .
417            @note: in this case just return empty array
418        """
419        return []
420   
421class FitAbort(Exception):
422    """
423        Exception raise to stop the fit
424    """
425    print"Creating fit abort Exception"
426
427
428class SansAssembly:
429    """
430         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
431    """
432    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
433        """
434            @param Model: the model wrapper fro sans -model
435            @param Data: the data wrapper for sans data
436        """
437        self.model = Model
438        self.data  = Data
439        self.paramlist=paramlist
440        self.curr_thread= curr_thread
441        self.res=[]
442        self.func_name="Functor"
443    def chisq(self, params):
444        """
445            Calculates chi^2
446            @param params: list of parameter values
447            @return: chi^2
448        """
449        sum = 0
450        for item in self.res:
451            sum += item*item
452        if len(self.res)==0:
453            return None
454        return sum/ len(self.res)
455   
456    def __call__(self,params):
457        """
458            Compute residuals
459            @param params: value of parameters to fit
460        """
461        #import thread
462        self.model.setParams(self.paramlist,params)
463        self.res= self.data.residuals(self.model.eval)
464        #if self.curr_thread != None :
465        #    try:
466        #        self.curr_thread.isquit()
467        #    except:
468        #        raise FitAbort,"stop leastsqr optimizer"   
469        return self.res
470   
471class FitEngine:
472    def __init__(self):
473        """
474            Base class for scipy and park fit engine
475        """
476        #List of parameter names to fit
477        self.paramList=[]
478        #Dictionnary of fitArrange element (fit problems)
479        self.fitArrangeDict={}
480       
481    def _concatenateData(self, listdata=[]):
482        """ 
483            _concatenateData method concatenates each fields of all data contains ins listdata.
484            @param listdata: list of data
485            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
486             of data concatenanted
487            @raise: if listdata is empty  will return None
488            @raise: if data in listdata don't contain dy field ,will create an error
489            during fitting
490        """
491        #TODO: we have to refactor the way we handle data.
492        # We should move away from plottables and move towards the Data1D objects
493        # defined in DataLoader. Data1D allows data manipulations, which should be
494        # used to concatenate.
495        # In the meantime we should switch off the concatenation.
496        #if len(listdata)>1:
497        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
498        #return listdata[0]
499       
500        if listdata==[]:
501            raise ValueError, " data list missing"
502        else:
503            xtemp=[]
504            ytemp=[]
505            dytemp=[]
506            self.mini=None
507            self.maxi=None
508               
509            for item in listdata:
510                data=item.data
511                mini,maxi=data.getFitRange()
512                if self.mini==None and self.maxi==None:
513                    self.mini=mini
514                    self.maxi=maxi
515                else:
516                    if mini < self.mini:
517                        self.mini=mini
518                    if self.maxi < maxi:
519                        self.maxi=maxi
520                       
521                   
522                for i in range(len(data.x)):
523                    xtemp.append(data.x[i])
524                    ytemp.append(data.y[i])
525                    if data.dy is not None and len(data.dy)==len(data.y):   
526                        dytemp.append(data.dy[i])
527                    else:
528                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
529            data= Data(x=xtemp,y=ytemp,dy=dytemp)
530            data.setFitRange(self.mini, self.maxi)
531            return data
532       
533       
534    def set_model(self,model,Uid,pars=[]):
535        """
536            set a model on a given uid in the fit engine.
537            @param model: the model to fit
538            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
539            @param pars: the list of parameters to fit
540            @note : pars must contains only name of existing model's paramaters
541        """
542        if len(pars) >0:
543            if model==None:
544                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
545            else:
546                temp=[]
547                for item in pars:
548                    if item in model.model.getParamList():
549                        temp.append(item)
550                        self.paramList.append(item)
551                    else:
552                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
553                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
554                        return
555            #A fitArrange is already created but contains dList only at Uid
556            if self.fitArrangeDict.has_key(Uid):
557                self.fitArrangeDict[Uid].set_model(model)
558                self.fitArrangeDict[Uid].pars= pars
559            else:
560            #no fitArrange object has been create with this Uid
561                fitproblem = FitArrange()
562                fitproblem.set_model(model)
563                fitproblem.pars= pars
564                self.fitArrangeDict[Uid] = fitproblem
565               
566        else:
567            raise ValueError, "park_integration:missing parameters"
568   
569    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
570        """ Receives plottable, creates a list of data to fit,set data
571            in a FitArrange object and adds that object in a dictionary
572            with key Uid.
573            @param data: data added
574            @param Uid: unique key corresponding to a fitArrange object with data
575        """
576        if data.__class__.__name__=='Data2D':
577            fitdata=FitData2D(data)
578        else:
579            fitdata=FitData1D(data, smearer)
580       
581        fitdata.setFitRange(qmin=qmin,qmax=qmax)
582        #A fitArrange is already created but contains model only at Uid
583        if self.fitArrangeDict.has_key(Uid):
584            self.fitArrangeDict[Uid].add_data(fitdata)
585        else:
586        #no fitArrange object has been create with this Uid
587            fitproblem= FitArrange()
588            fitproblem.add_data(fitdata)
589            self.fitArrangeDict[Uid]=fitproblem   
590   
591    def get_model(self,Uid):
592        """
593            @param Uid: Uid is key in the dictionary containing the model to return
594            @return  a model at this uid or None if no FitArrange element was created
595            with this Uid
596        """
597        if self.fitArrangeDict.has_key(Uid):
598            return self.fitArrangeDict[Uid].get_model()
599        else:
600            return None
601   
602    def remove_Fit_Problem(self,Uid):
603        """remove   fitarrange in Uid"""
604        if self.fitArrangeDict.has_key(Uid):
605            del self.fitArrangeDict[Uid]
606           
607    def select_problem_for_fit(self,Uid,value):
608        """
609            select a couple of model and data at the Uid position in dictionary
610            and set in self.selected value to value
611            @param value: the value to allow fitting. can only have the value one or zero
612        """
613        if self.fitArrangeDict.has_key(Uid):
614             self.fitArrangeDict[Uid].set_to_fit( value)
615             
616             
617    def get_problem_to_fit(self,Uid):
618        """
619            return the self.selected value of the fit problem of Uid
620           @param Uid: the Uid of the problem
621        """
622        if self.fitArrangeDict.has_key(Uid):
623             self.fitArrangeDict[Uid].get_to_fit()
624   
625class FitArrange:
626    def __init__(self):
627        """
628            Class FitArrange contains a set of data for a given model
629            to perform the Fit.FitArrange must contain exactly one model
630            and at least one data for the fit to be performed.
631            model: the model selected by the user
632            Ldata: a list of data what the user wants to fit
633           
634        """
635        self.model = None
636        self.dList =[]
637        self.pars=[]
638        #self.selected  is zero when this fit problem is not schedule to fit
639        #self.selected is 1 when schedule to fit
640        self.selected = 0
641       
642    def set_model(self,model):
643        """
644            set_model save a copy of the model
645            @param model: the model being set
646        """
647        self.model = model
648       
649    def add_data(self,data):
650        """
651            add_data fill a self.dList with data to fit
652            @param data: Data to add in the list 
653        """
654        if not data in self.dList:
655            self.dList.append(data)
656           
657    def get_model(self):
658        """ @return: saved model """
659        return self.model   
660     
661    def get_data(self):
662        """ @return:  list of data dList"""
663        #return self.dList
664        return self.dList[0] 
665     
666    def remove_data(self,data):
667        """
668            Remove one element from the list
669            @param data: Data to remove from dList
670        """
671        if data in self.dList:
672            self.dList.remove(data)
673    def set_to_fit (self, value=0):
674        """
675           set self.selected to 0 or 1  for other values raise an exception
676           @param value: integer between 0 or 1
677        """
678        self.selected= value
679       
680    def get_to_fit(self):
681        """
682            @return self.selected value
683        """
684        return self.selected
Note: See TracBrowser for help on using the repository browser.