source: sasview/park_integration/AbstractFitEngine.py @ b293683

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b293683 was 7f81665, checked in by Mathieu Doucet <doucetm@…>, 15 years ago

park_integration: remove try-except block that lets evalDistribution silently fail.

  • Property mode set to 100644
File size: 23.6 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        try:
115            return self.model.evalDistribution(x)
116        except:
117            print "AbstractFitEngine.Model.eval [FIXME]:", sys.exc_value
118            raise
119
120class Data(object):
121    """ Wrapper class  for SANS data """
122    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
123        """
124            Data can be initital with a data (sans plottable)
125            or with vectors.
126        """
127        if  sans_data !=None:
128            self.x= sans_data.x
129            self.y= sans_data.y
130            self.dx= sans_data.dx
131            self.dy= sans_data.dy
132           
133        elif (x!=None and y!=None and dy!=None):
134                self.x=x
135                self.y=y
136                self.dx=dx
137                self.dy=dy
138        else:
139            raise ValueError,\
140            "Data is missing x, y or dy, impossible to compute residuals later on"
141        self.qmin=None
142        self.qmax=None
143       
144       
145    def setFitRange(self,mini=None,maxi=None):
146        """ to set the fit range"""
147       
148        self.qmin=mini           
149        self.qmax=maxi
150       
151       
152    def getFitRange(self):
153        """
154            @return the range of data.x to fit
155        """
156        return self.qmin, self.qmax
157     
158     
159    def residuals(self, fn):
160        """ @param fn: function that return model value
161            @return residuals
162        """
163        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
164        if self.qmin==None and self.qmax==None: 
165            fx =numpy.asarray([fn(v) for v in x])
166            return (y - fx)/dy
167        else:
168            idx = (x>=self.qmin) & (x <= self.qmax)
169            fx = numpy.asarray([fn(item)for item in x[idx ]])
170            return (y[idx] - fx)/dy[idx]
171       
172    def residuals_deriv(self, model, pars=[]):
173        """
174            @return residuals derivatives .
175            @note: in this case just return empty array
176        """
177        return []
178   
179   
180class FitData1D(object):
181    """ Wrapper class  for SANS data """
182    def __init__(self,sans_data1d, smearer=None):
183        """
184            Data can be initital with a data (sans plottable)
185            or with vectors.
186           
187            self.smearer is an object of class QSmearer or SlitSmearer
188            that will smear the theory data (slit smearing or resolution
189            smearing) when set.
190           
191            The proper way to set the smearing object would be to
192            do the following:
193           
194            from DataLoader.qsmearing import smear_selection
195            fitdata1d = FitData1D(some_data)
196            fitdata1d.smearer = smear_selection(some_data)
197           
198            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
199           
200            Setting it back to None will turn smearing off.
201           
202        """
203       
204        self.smearer = smearer
205     
206        # Initialize from Data1D object
207        self.data=sans_data1d
208        self.x= sans_data1d.x
209        self.y= sans_data1d.y
210        self.dx= sans_data1d.dx
211        self.dy= sans_data1d.dy
212       
213        ## Min Q-value
214        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
215        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
216            self.qmin = min(self.data.x[self.data.x!=0])
217        else:                             
218            self.qmin= min (self.data.x)
219        ## Max Q-value
220        self.qmax= max (self.data.x)
221       
222        # Range used for input to smearing
223        self._qmin_unsmeared = self.qmin
224        self._qmax_unsmeared = self.qmax
225       
226       
227    def setFitRange(self,qmin=None,qmax=None):
228        """ to set the fit range"""
229        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
230        #ToDo: Fix this.
231        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
232            self.qmin = min(self.data.x[self.data.x!=0])
233        elif qmin!=None:                       
234            self.qmin = qmin           
235
236        if qmax !=None:
237            self.qmax = qmax
238           
239        # Range used for input to smearing
240        self._qmin_unsmeared = self.qmin
241        self._qmax_unsmeared = self.qmax   
242       
243        # Determine the range needed in unsmeared-Q to cover
244        # the smeared Q range
245        #TODO: use the smearing matrix to determine which
246        # bin range to use
247        if self.smearer.__class__.__name__ == 'SlitSmearer':
248            self._qmin_unsmeared = min(self.data.x)
249            self._qmax_unsmeared = max(self.data.x)
250        elif self.smearer.__class__.__name__ == 'QSmearer':
251            # Take 3 sigmas as the offset between smeared and unsmeared space
252            try:
253                offset = 3.0*max(self.smearer.width)
254                self._qmin_unsmeared = max([min(self.data.x), self.qmin-offset])
255                self._qmax_unsmeared = min([max(self.data.x), self.qmax+offset])
256            except:
257                logging.error("FitData1D.setFitRange: %s" % sys.exc_value)
258       
259       
260    def getFitRange(self):
261        """
262            @return the range of data.x to fit
263        """
264        return self.qmin, self.qmax
265       
266    def residuals(self, fn):
267        """
268            Compute residuals.
269           
270            If self.smearer has been set, use if to smear
271            the data before computing chi squared.
272           
273            @param fn: function that return model value
274            @return residuals
275        """
276        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
277        if self.dy ==None or self.dy==[]:
278            dy= numpy.zeros(len(y)) 
279        else:
280            dy= numpy.asarray(self.dy)
281     
282        # For fitting purposes, replace zero errors by 1
283        #TODO: check validity for the rare case where only
284        # a few points have zero errors
285        dy[dy==0]=1
286       
287        # Identify the bin range for the unsmeared and smeared spaces
288        idx = (x>=self.qmin) & (x <= self.qmax)
289        idx_unsmeared = (x>=self._qmin_unsmeared) & (x <= self._qmax_unsmeared)
290 
291        # Compute theory data f(x)
292        fx= numpy.zeros(len(x))
293   
294        _first_bin = None
295        _last_bin  = None
296        for i_x in range(len(x)):
297            try:
298                if idx_unsmeared[i_x]==True:
299                    # Identify first and last bin
300                    #TODO: refactor this to pass q-values to the smearer
301                    # and let it figure out which bin range to use
302                    if _first_bin is None:
303                        _first_bin = i_x
304                    else:
305                        _last_bin  = i_x
306                   
307                    value = fn(x[i_x])
308                    fx[i_x] = value
309            except:
310                ## skip error for model.run(x)
311                pass
312                 
313        ## Smear theory data
314        if self.smearer is not None:
315            fx = self.smearer(fx, _first_bin, _last_bin)
316       
317        ## Sanity check
318        if numpy.size(dy)!= numpy.size(fx):
319            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.size(dy), numpy.size(fx))
320
321        return (y[idx]-fx[idx])/dy[idx]
322     
323 
324       
325    def residuals_deriv(self, model, pars=[]):
326        """
327            @return residuals derivatives .
328            @note: in this case just return empty array
329        """
330        return []
331   
332   
333class FitData2D(object):
334    """ Wrapper class  for SANS data """
335    def __init__(self,sans_data2d):
336        """
337            Data can be initital with a data (sans plottable)
338            or with vectors.
339        """
340        self.data=sans_data2d
341        self.image = sans_data2d.data
342        self.err_image = sans_data2d.err_data
343        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
344                                         [1,len(sans_data2d.x_bins)])
345        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
346                                          [len(sans_data2d.y_bins),1])
347       
348       
349        self.x_bins = sans_data2d.x_bins
350        self.y_bins = sans_data2d.y_bins
351       
352        x = max(self.data.xmin, self.data.xmax)
353        y = max(self.data.ymin, self.data.ymax)
354       
355        ## fitting range
356        self.qmin = 1e-16
357        self.qmax = math.sqrt(x*x +y*y)
358        ## new error image for fitting purpose
359        if self.err_image== None or self.err_image ==[]:
360            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
361        else:
362            self.res_err_image = copy.deepcopy(self.err_image)
363        self.res_err_image[self.err_image==0]=1
364       
365        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
366        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
367       
368       
369    def setFitRange(self,qmin=None,qmax=None):
370        """ to set the fit range"""
371        if qmin==0.0:
372            self.qmin = 1e-16
373        elif qmin!=None:                       
374            self.qmin = qmin           
375        if qmax!=None:
376            self.qmax= qmax
377     
378       
379    def getFitRange(self):
380        """
381            @return the range of data.x to fit
382        """
383        return self.qmin, self.qmax
384     
385    def residuals(self, fn): 
386        res=self.index_model*(self.image - fn([self.y_bins_array,
387                         self.x_bins_array]))/self.res_err_image
388        return res.ravel() 
389   
390   
391    def old_residuals(self, fn):
392        """ @param fn: function that return model value
393            @return residuals
394        """
395        res=[]
396       
397        for i in range(len(self.x_bins)):
398            for j in range(len(self.y_bins)):
399                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
400                radius= math.sqrt(temp)
401                if self.qmin <= radius and radius <= self.qmax:
402                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
403                            /self.res_err_image[j][i] )
404       
405        return numpy.array(res)
406       
407         
408    def residuals_deriv(self, model, pars=[]):
409        """
410            @return residuals derivatives .
411            @note: in this case just return empty array
412        """
413        return []
414   
415class FitAbort(Exception):
416    """
417        Exception raise to stop the fit
418    """
419    print"Creating fit abort Exception"
420
421
422class SansAssembly:
423    """
424         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
425    """
426    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
427        """
428            @param Model: the model wrapper fro sans -model
429            @param Data: the data wrapper for sans data
430        """
431        self.model = Model
432        self.data  = Data
433        self.paramlist=paramlist
434        self.curr_thread= curr_thread
435        self.res=[]
436        self.func_name="Functor"
437    def chisq(self, params):
438        """
439            Calculates chi^2
440            @param params: list of parameter values
441            @return: chi^2
442        """
443        sum = 0
444        for item in self.res:
445            sum += item*item
446        if len(self.res)==0:
447            return None
448        return sum/ len(self.res)
449   
450    def __call__(self,params):
451        """
452            Compute residuals
453            @param params: value of parameters to fit
454        """
455        #import thread
456        self.model.setParams(self.paramlist,params)
457        self.res= self.data.residuals(self.model.eval)
458        #if self.curr_thread != None :
459        #    try:
460        #        self.curr_thread.isquit()
461        #    except:
462        #        raise FitAbort,"stop leastsqr optimizer"   
463        return self.res
464   
465class FitEngine:
466    def __init__(self):
467        """
468            Base class for scipy and park fit engine
469        """
470        #List of parameter names to fit
471        self.paramList=[]
472        #Dictionnary of fitArrange element (fit problems)
473        self.fitArrangeDict={}
474       
475    def _concatenateData(self, listdata=[]):
476        """ 
477            _concatenateData method concatenates each fields of all data contains ins listdata.
478            @param listdata: list of data
479            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
480             of data concatenanted
481            @raise: if listdata is empty  will return None
482            @raise: if data in listdata don't contain dy field ,will create an error
483            during fitting
484        """
485        #TODO: we have to refactor the way we handle data.
486        # We should move away from plottables and move towards the Data1D objects
487        # defined in DataLoader. Data1D allows data manipulations, which should be
488        # used to concatenate.
489        # In the meantime we should switch off the concatenation.
490        #if len(listdata)>1:
491        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
492        #return listdata[0]
493       
494        if listdata==[]:
495            raise ValueError, " data list missing"
496        else:
497            xtemp=[]
498            ytemp=[]
499            dytemp=[]
500            self.mini=None
501            self.maxi=None
502               
503            for item in listdata:
504                data=item.data
505                mini,maxi=data.getFitRange()
506                if self.mini==None and self.maxi==None:
507                    self.mini=mini
508                    self.maxi=maxi
509                else:
510                    if mini < self.mini:
511                        self.mini=mini
512                    if self.maxi < maxi:
513                        self.maxi=maxi
514                       
515                   
516                for i in range(len(data.x)):
517                    xtemp.append(data.x[i])
518                    ytemp.append(data.y[i])
519                    if data.dy is not None and len(data.dy)==len(data.y):   
520                        dytemp.append(data.dy[i])
521                    else:
522                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
523            data= Data(x=xtemp,y=ytemp,dy=dytemp)
524            data.setFitRange(self.mini, self.maxi)
525            return data
526       
527       
528    def set_model(self,model,Uid,pars=[]):
529        """
530            set a model on a given uid in the fit engine.
531            @param model: the model to fit
532            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
533            @param pars: the list of parameters to fit
534            @note : pars must contains only name of existing model's paramaters
535        """
536        if len(pars) >0:
537            if model==None:
538                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
539            else:
540                temp=[]
541                for item in pars:
542                    if item in model.model.getParamList():
543                        temp.append(item)
544                        self.paramList.append(item)
545                    else:
546                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
547                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
548                        return
549            #A fitArrange is already created but contains dList only at Uid
550            if self.fitArrangeDict.has_key(Uid):
551                self.fitArrangeDict[Uid].set_model(model)
552                self.fitArrangeDict[Uid].pars= pars
553            else:
554            #no fitArrange object has been create with this Uid
555                fitproblem = FitArrange()
556                fitproblem.set_model(model)
557                fitproblem.pars= pars
558                self.fitArrangeDict[Uid] = fitproblem
559               
560        else:
561            raise ValueError, "park_integration:missing parameters"
562   
563    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
564        """ Receives plottable, creates a list of data to fit,set data
565            in a FitArrange object and adds that object in a dictionary
566            with key Uid.
567            @param data: data added
568            @param Uid: unique key corresponding to a fitArrange object with data
569        """
570        if data.__class__.__name__=='Data2D':
571            fitdata=FitData2D(data)
572        else:
573            fitdata=FitData1D(data, smearer)
574       
575        fitdata.setFitRange(qmin=qmin,qmax=qmax)
576        #A fitArrange is already created but contains model only at Uid
577        if self.fitArrangeDict.has_key(Uid):
578            self.fitArrangeDict[Uid].add_data(fitdata)
579        else:
580        #no fitArrange object has been create with this Uid
581            fitproblem= FitArrange()
582            fitproblem.add_data(fitdata)
583            self.fitArrangeDict[Uid]=fitproblem   
584   
585    def get_model(self,Uid):
586        """
587            @param Uid: Uid is key in the dictionary containing the model to return
588            @return  a model at this uid or None if no FitArrange element was created
589            with this Uid
590        """
591        if self.fitArrangeDict.has_key(Uid):
592            return self.fitArrangeDict[Uid].get_model()
593        else:
594            return None
595   
596    def remove_Fit_Problem(self,Uid):
597        """remove   fitarrange in Uid"""
598        if self.fitArrangeDict.has_key(Uid):
599            del self.fitArrangeDict[Uid]
600           
601    def select_problem_for_fit(self,Uid,value):
602        """
603            select a couple of model and data at the Uid position in dictionary
604            and set in self.selected value to value
605            @param value: the value to allow fitting. can only have the value one or zero
606        """
607        if self.fitArrangeDict.has_key(Uid):
608             self.fitArrangeDict[Uid].set_to_fit( value)
609             
610             
611    def get_problem_to_fit(self,Uid):
612        """
613            return the self.selected value of the fit problem of Uid
614           @param Uid: the Uid of the problem
615        """
616        if self.fitArrangeDict.has_key(Uid):
617             self.fitArrangeDict[Uid].get_to_fit()
618   
619class FitArrange:
620    def __init__(self):
621        """
622            Class FitArrange contains a set of data for a given model
623            to perform the Fit.FitArrange must contain exactly one model
624            and at least one data for the fit to be performed.
625            model: the model selected by the user
626            Ldata: a list of data what the user wants to fit
627           
628        """
629        self.model = None
630        self.dList =[]
631        self.pars=[]
632        #self.selected  is zero when this fit problem is not schedule to fit
633        #self.selected is 1 when schedule to fit
634        self.selected = 0
635       
636    def set_model(self,model):
637        """
638            set_model save a copy of the model
639            @param model: the model being set
640        """
641        self.model = model
642       
643    def add_data(self,data):
644        """
645            add_data fill a self.dList with data to fit
646            @param data: Data to add in the list 
647        """
648        if not data in self.dList:
649            self.dList.append(data)
650           
651    def get_model(self):
652        """ @return: saved model """
653        return self.model   
654     
655    def get_data(self):
656        """ @return:  list of data dList"""
657        #return self.dList
658        return self.dList[0] 
659     
660    def remove_data(self,data):
661        """
662            Remove one element from the list
663            @param data: Data to remove from dList
664        """
665        if data in self.dList:
666            self.dList.remove(data)
667    def set_to_fit (self, value=0):
668        """
669           set self.selected to 0 or 1  for other values raise an exception
670           @param value: integer between 0 or 1
671        """
672        self.selected= value
673       
674    def get_to_fit(self):
675        """
676            @return self.selected value
677        """
678        return self.selected
Note: See TracBrowser for help on using the repository browser.