source: sasview/park_integration/AbstractFitEngine.py @ 479eced

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 479eced was 36bc34e, checked in by Jae Cho <jhjcho@…>, 15 years ago

fixed 2D fit problem

  • Property mode set to 100644
File size: 23.2 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3from DataLoader.data_info import Data1D
4from DataLoader.data_info import Data2D
5class SansParameter(park.Parameter):
6    """
7        SANS model parameters for use in the PARK fitting service.
8        The parameter attribute value is redirected to the underlying
9        parameter value in the SANS model.
10    """
11    def __init__(self, name, model):
12        """
13            @param name: the name of the model parameter
14            @param model: the sans model to wrap as a park model
15        """
16        self._model, self._name = model,name
17        #set the value for the parameter of the given name
18        self.set(model.getParam(name))
19         
20    def _getvalue(self):
21        """
22            override the _getvalue of park parameter
23            @return value the parameter associates with self.name
24        """
25        return self._model.getParam(self.name)
26   
27    def _setvalue(self,value):
28        """
29            override the _setvalue pf park parameter
30            @param value: the value to set on a given parameter
31        """
32        self._model.setParam(self.name, value)
33       
34    value = property(_getvalue,_setvalue)
35   
36    def _getrange(self):
37        """
38            Override _getrange of park parameter
39            return the range of parameter
40        """
41        #if not  self.name in self._model.getDispParamList():
42        lo,hi = self._model.details[self.name][1:3]
43        if lo is None: lo = -numpy.inf
44        if hi is None: hi = numpy.inf
45        #else:
46            #lo,hi = self._model.details[self.name][1:]
47            #if lo is None: lo = -numpy.inf
48            #if hi is None: hi = numpy.inf
49        if lo >= hi:
50            raise ValueError,"wrong fit range for parameters"
51       
52        return lo,hi
53   
54    def _setrange(self,r):
55        """
56            override _setrange of park parameter
57            @param r: the value of the range to set
58        """
59        self._model.details[self.name][1:3] = r
60    range = property(_getrange,_setrange)
61   
62class Model(park.Model):
63    """
64        PARK wrapper for SANS models.
65    """
66    def __init__(self, sans_model, **kw):
67        """
68            @param sans_model: the sans model to wrap using park interface
69        """
70        park.Model.__init__(self, **kw)
71        self.model = sans_model
72        self.name = sans_model.name
73        #list of parameters names
74        self.sansp = sans_model.getParamList()
75        #list of park parameter
76        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
77        #list of parameterset
78        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
79        self.pars=[]
80 
81 
82    def getParams(self,fitparams):
83        """
84            return a list of value of paramter to fit
85            @param fitparams: list of paramaters name to fit
86        """
87        list=[]
88        self.pars=[]
89        self.pars=fitparams
90        for item in fitparams:
91            for element in self.parkp:
92                 if element.name ==str(item):
93                     list.append(element.value)
94        return list
95   
96   
97    def setParams(self,paramlist, params):
98        """
99            Set value for parameters to fit
100            @param params: list of value for parameters to fit
101        """
102        try:
103            for i in range(len(self.parkp)):
104                for j in range(len(paramlist)):
105                    if self.parkp[i].name==paramlist[j]:
106                        self.parkp[i].value = params[j]
107                        self.model.setParam(self.parkp[i].name,params[j])
108        except:
109            raise
110 
111    def eval(self,x):
112        """
113            override eval method of park model.
114            @param x: the x value used to compute a function
115        """
116        try:
117            return self.model.evalDistribution(x)
118        except:
119            raise
120
121   
122class FitData1D(Data1D):
123    """
124        Wrapper class  for SANS data
125        FitData1D inherits from DataLoader.data_info.Data1D. Implements
126        a way to get residuals from data.
127    """
128    def __init__(self,x, y,dx= None, dy=None, smearer=None):
129        Data1D.__init__(self, x=numpy.array(x), y=numpy.array(y), dx=dx, dy=dy)
130        """
131            @param smearer: is an object of class QSmearer or SlitSmearer
132            that will smear the theory data (slit smearing or resolution
133            smearing) when set.
134           
135            The proper way to set the smearing object would be to
136            do the following:
137           
138            from DataLoader.qsmearing import smear_selection
139            smearer = smear_selection(some_data)
140            fitdata1d = FitData1D( x= [1,3,..,],
141                                    y= [3,4,..,8],
142                                    dx=None,
143                                    dy=[1,2...], smearer= smearer)
144           
145            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
146           
147            Setting it back to None will turn smearing off.
148           
149        """
150       
151        self.smearer = smearer
152        if dy ==None or dy==[]:
153            self.dy= numpy.zeros(len(self.y)) 
154        else:
155            self.dy= numpy.asarray(dy)
156     
157        # For fitting purposes, replace zero errors by 1
158        #TODO: check validity for the rare case where only
159        # a few points have zero errors
160        self.dy[self.dy==0]=1
161       
162        ## Min Q-value
163        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
164        if min (self.x) ==0.0 and self.x[0]==0 and not numpy.isfinite(self.y[0]):
165            self.qmin = min(self.x[self.x!=0])
166        else:                             
167            self.qmin= min (self.x)
168        ## Max Q-value
169        self.qmax = max (self.x)
170       
171        # Range used for input to smearing
172        self._qmin_unsmeared = self.qmin
173        self._qmax_unsmeared = self.qmax
174        # Identify the bin range for the unsmeared and smeared spaces
175        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
176        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
177 
178       
179       
180    def setFitRange(self,qmin=None,qmax=None):
181        """ to set the fit range"""
182        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
183        #ToDo: Fix this.
184        if qmin==0.0 and not numpy.isfinite(self.y[qmin]):
185            self.qmin = min(self.x[self.x!=0])
186        elif qmin!=None:                       
187            self.qmin = qmin           
188
189        if qmax !=None:
190            self.qmax = qmax
191           
192        # Determine the range needed in unsmeared-Q to cover
193        # the smeared Q range
194        self._qmin_unsmeared = self.qmin
195        self._qmax_unsmeared = self.qmax   
196       
197        self._first_unsmeared_bin = 0
198        self._last_unsmeared_bin  = len(self.x)-1
199       
200        if self.smearer!=None:
201            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
202            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
203            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
204           
205        # Identify the bin range for the unsmeared and smeared spaces
206        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
207        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
208 
209       
210    def getFitRange(self):
211        """
212            @return the range of data.x to fit
213        """
214        return self.qmin, self.qmax
215       
216    def residuals(self, fn):
217        """
218            Compute residuals.
219           
220            If self.smearer has been set, use if to smear
221            the data before computing chi squared.
222           
223            @param fn: function that return model value
224            @return residuals
225        """
226        # Compute theory data f(x)
227        fx= numpy.zeros(len(self.x))
228        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
229       
230        ## Smear theory data
231        if self.smearer is not None:
232            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
233       
234        ## Sanity check
235        if numpy.size(self.dy)!= numpy.size(fx):
236            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
237                                                                              numpy.size(fx))
238                                                                             
239        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
240     
241 
242       
243    def residuals_deriv(self, model, pars=[]):
244        """
245            @return residuals derivatives .
246            @note: in this case just return empty array
247        """
248        return []
249   
250   
251class FitData2D(Data2D):
252    """ Wrapper class  for SANS data """
253    def __init__(self,sans_data2d ,data=None, err_data=None):
254        Data2D.__init__(self, data= data, err_data= err_data)
255        """
256            Data can be initital with a data (sans plottable)
257            or with vectors.
258        """
259        self.res_err_image=[]
260        self.index_model=[]
261        self.qmin= None
262        self.qmax= None
263        self.set_data(sans_data2d )
264       
265    def set_data(self, sans_data2d, qmin=None, qmax=None ):
266        """
267            Determine the correct qx_data and qy_data within range to fit
268        """
269        self.data     = sans_data2d.data
270        self.err_data = sans_data2d.err_data
271        self.qx_data = sans_data2d.qx_data
272        self.qy_data = sans_data2d.qy_data
273        self.mask       = sans_data2d.mask
274
275        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
276        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
277       
278        ## fitting range
279        if qmin == None:
280            self.qmin = 1e-16
281        if qmax == None:
282            self.qmax = math.sqrt(x_max*x_max +y_max*y_max)
283        ## new error image for fitting purpose
284        if self.err_data== None or self.err_data ==[]:
285            self.res_err_data= numpy.ones(len(self.data))
286        else:
287            self.res_err_data = copy.deepcopy(self.err_data)
288        self.res_err_data[self.res_err_data==0]=1
289       
290        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
291       
292        # Note: mask = True: for MASK while mask = False for NOT to mask
293        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
294        self.index_model = (self.index_model) & (self.mask)
295        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
296           
297    def setFitRange(self,qmin=None,qmax=None):
298        """ to set the fit range"""
299        if qmin==0.0:
300            self.qmin = 1e-16
301        elif qmin!=None:                       
302            self.qmin = qmin           
303        if qmax!=None:
304            self.qmax= qmax       
305        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
306        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
307        self.index_model = (self.index_model) &(self.mask)
308        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
309       
310    def getFitRange(self):
311        """
312            @return the range of data.x to fit
313        """
314        return self.qmin, self.qmax
315     
316    def residuals(self, fn): 
317        """
318            @return the residuals
319        """       
320        # use only the data point within ROI range
321        res=(self.data[self.index_model] - fn([self.qx_data[self.index_model],
322                             self.qy_data[self.index_model]]))/self.res_err_data[self.index_model]
323        return res
324       
325 
326    def residuals_deriv(self, model, pars=[]):
327        """
328            @return residuals derivatives .
329            @note: in this case just return empty array
330        """
331        return []
332   
333class FitAbort(Exception):
334    """
335        Exception raise to stop the fit
336    """
337    print"Creating fit abort Exception"
338
339
340class SansAssembly:
341    """
342         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
343    """
344    def __init__(self, paramlist, model=None , data=None, fitresult=None,
345                 handler=None, curr_thread=None):
346        """
347            @param Model: the model wrapper fro sans -model
348            @param Data: the data wrapper for sans data
349        """
350        self.model = model
351        self.data  = data
352        self.paramlist = paramlist
353        self.curr_thread = curr_thread
354        self.handler = handler
355        self.fitresult = fitresult
356        self.res = []
357        self.func_name = "Functor"
358       
359    def chisq(self, params):
360        """
361            Calculates chi^2
362            @param params: list of parameter values
363            @return: chi^2
364        """
365        sum = 0
366        for item in self.res:
367            sum += item*item
368        if len(self.res)==0:
369            return None
370        return sum/ len(self.res)
371   
372    def __call__(self,params):
373        """
374            Compute residuals
375            @param params: value of parameters to fit
376        """
377        #import thread
378        self.model.setParams(self.paramlist,params)
379        self.res= self.data.residuals(self.model.eval)
380        if self.fitresult is not None and  self.handler is not None:
381            self.fitresult.set_model(model=self.model)
382            fitness = self.chisq(params=params)
383            self.fitresult.set_fitness(fitness=fitness)
384            self.handler.set_result(result=self.fitresult)
385            self.handler.update_fit()
386       
387        #if self.curr_thread != None :
388        #    try:
389        #        self.curr_thread.isquit()
390        #    except:
391        #        raise FitAbort,"stop leastsqr optimizer"   
392        return self.res
393   
394class FitEngine:
395    def __init__(self):
396        """
397            Base class for scipy and park fit engine
398        """
399        #List of parameter names to fit
400        self.paramList=[]
401        #Dictionnary of fitArrange element (fit problems)
402        self.fitArrangeDict={}
403       
404    def _concatenateData(self, listdata=[]):
405        """ 
406            _concatenateData method concatenates each fields of all data contains ins listdata.
407            @param listdata: list of data
408            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
409             of data concatenanted
410            @raise: if listdata is empty  will return None
411            @raise: if data in listdata don't contain dy field ,will create an error
412            during fitting
413        """
414        #TODO: we have to refactor the way we handle data.
415        # We should move away from plottables and move towards the Data1D objects
416        # defined in DataLoader. Data1D allows data manipulations, which should be
417        # used to concatenate.
418        # In the meantime we should switch off the concatenation.
419        #if len(listdata)>1:
420        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
421        #return listdata[0]
422       
423        if listdata==[]:
424            raise ValueError, " data list missing"
425        else:
426            xtemp=[]
427            ytemp=[]
428            dytemp=[]
429            self.mini=None
430            self.maxi=None
431               
432            for item in listdata:
433                data=item.data
434                mini,maxi=data.getFitRange()
435                if self.mini==None and self.maxi==None:
436                    self.mini=mini
437                    self.maxi=maxi
438                else:
439                    if mini < self.mini:
440                        self.mini=mini
441                    if self.maxi < maxi:
442                        self.maxi=maxi
443                       
444                   
445                for i in range(len(data.x)):
446                    xtemp.append(data.x[i])
447                    ytemp.append(data.y[i])
448                    if data.dy is not None and len(data.dy)==len(data.y):   
449                        dytemp.append(data.dy[i])
450                    else:
451                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
452            data= Data(x=xtemp,y=ytemp,dy=dytemp)
453            data.setFitRange(self.mini, self.maxi)
454            return data
455       
456       
457    def set_model(self,model,Uid,pars=[], constraints=[]):
458        """
459            set a model on a given uid in the fit engine.
460            @param model: sans.models type
461            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
462            @param pars: the list of parameters to fit
463            @param constraints: list of
464                tuple (name of parameter, value of parameters)
465                the value of parameter must be a string to constraint 2 different
466                parameters.
467                Example:
468                we want to fit 2 model M1 and M2 both have parameters A and B.
469                constraints can be:
470                 constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
471            @note : pars must contains only name of existing model's paramaters
472        """
473        if model == None:
474            raise ValueError, "AbstractFitEngine: Need to set model to fit"
475       
476        new_model= model
477        if not issubclass(model.__class__, Model):
478            new_model= Model(model)
479       
480        if len(constraints)>0:
481            for constraint in constraints:
482                name, value = constraint
483                try:
484                    new_model.parameterset[ str(name)].set( str(value) )
485                except:
486                    msg= "Fit Engine: Error occurs when setting the constraint"
487                    msg += " %s for parameter %s "%(value, name)
488                    raise ValueError, msg
489               
490        if len(pars) >0:
491            temp=[]
492            for item in pars:
493                if item in new_model.model.getParamList():
494                    temp.append(item)
495                    self.paramList.append(item)
496                else:
497                   
498                    msg = "wrong parameter %s used"%str(item)
499                    msg += "to set model %s. Choose"%str(new_model.model.name)
500                    msg += "parameter name within %s"%str(new_model.model.getParamList())
501                    raise ValueError,msg
502             
503            #A fitArrange is already created but contains dList only at Uid
504            if self.fitArrangeDict.has_key(Uid):
505                self.fitArrangeDict[Uid].set_model(new_model)
506                self.fitArrangeDict[Uid].pars= pars
507            else:
508            #no fitArrange object has been create with this Uid
509                fitproblem = FitArrange()
510                fitproblem.set_model(new_model)
511                fitproblem.pars= pars
512                self.fitArrangeDict[Uid] = fitproblem
513               
514        else:
515            raise ValueError, "park_integration:missing parameters"
516   
517    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
518        """ Receives plottable, creates a list of data to fit,set data
519            in a FitArrange object and adds that object in a dictionary
520            with key Uid.
521            @param data: data added
522            @param Uid: unique key corresponding to a fitArrange object with data
523        """
524        if data.__class__.__name__=='Data2D':
525            fitdata=FitData2D(sans_data2d=data, data=data.data, err_data= data.err_data)
526        else:
527            fitdata=FitData1D(x=data.x, y=data.y , dx= data.dx,dy=data.dy,smearer=smearer)
528       
529        fitdata.setFitRange(qmin=qmin,qmax=qmax)
530        #A fitArrange is already created but contains model only at Uid
531        if self.fitArrangeDict.has_key(Uid):
532            self.fitArrangeDict[Uid].add_data(fitdata)
533        else:
534        #no fitArrange object has been create with this Uid
535            fitproblem= FitArrange()
536            fitproblem.add_data(fitdata)
537            self.fitArrangeDict[Uid]=fitproblem   
538   
539    def get_model(self,Uid):
540        """
541            @param Uid: Uid is key in the dictionary containing the model to return
542            @return  a model at this uid or None if no FitArrange element was created
543            with this Uid
544        """
545        if self.fitArrangeDict.has_key(Uid):
546            return self.fitArrangeDict[Uid].get_model()
547        else:
548            return None
549   
550    def remove_Fit_Problem(self,Uid):
551        """remove   fitarrange in Uid"""
552        if self.fitArrangeDict.has_key(Uid):
553            del self.fitArrangeDict[Uid]
554           
555    def select_problem_for_fit(self,Uid,value):
556        """
557            select a couple of model and data at the Uid position in dictionary
558            and set in self.selected value to value
559            @param value: the value to allow fitting. can only have the value one or zero
560        """
561        if self.fitArrangeDict.has_key(Uid):
562             self.fitArrangeDict[Uid].set_to_fit( value)
563             
564             
565    def get_problem_to_fit(self,Uid):
566        """
567            return the self.selected value of the fit problem of Uid
568           @param Uid: the Uid of the problem
569        """
570        if self.fitArrangeDict.has_key(Uid):
571             self.fitArrangeDict[Uid].get_to_fit()
572   
573class FitArrange:
574    def __init__(self):
575        """
576            Class FitArrange contains a set of data for a given model
577            to perform the Fit.FitArrange must contain exactly one model
578            and at least one data for the fit to be performed.
579            model: the model selected by the user
580            Ldata: a list of data what the user wants to fit
581           
582        """
583        self.model = None
584        self.dList =[]
585        self.pars=[]
586        #self.selected  is zero when this fit problem is not schedule to fit
587        #self.selected is 1 when schedule to fit
588        self.selected = 0
589       
590    def set_model(self,model):
591        """
592            set_model save a copy of the model
593            @param model: the model being set
594        """
595        self.model = model
596       
597    def add_data(self,data):
598        """
599            add_data fill a self.dList with data to fit
600            @param data: Data to add in the list 
601        """
602        if not data in self.dList:
603            self.dList.append(data)
604           
605    def get_model(self):
606        """ @return: saved model """
607        return self.model   
608     
609    def get_data(self):
610        """ @return:  list of data dList"""
611        #return self.dList
612        return self.dList[0] 
613     
614    def remove_data(self,data):
615        """
616            Remove one element from the list
617            @param data: Data to remove from dList
618        """
619        if data in self.dList:
620            self.dList.remove(data)
621    def set_to_fit (self, value=0):
622        """
623           set self.selected to 0 or 1  for other values raise an exception
624           @param value: integer between 0 or 1
625        """
626        self.selected= value
627       
628    def get_to_fit(self):
629        """
630            @return self.selected value
631        """
632        return self.selected
Note: See TracBrowser for help on using the repository browser.