source: sasview/park_integration/AbstractFitEngine.py @ cc1ead1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since cc1ead1 was f72333f, checked in by Jae Cho <jhjcho@…>, 15 years ago

Plugged in 2D smear: traditional over-sampling method

  • Property mode set to 100644
File size: 23.7 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3from DataLoader.data_info import Data1D
4from DataLoader.data_info import Data2D
5class SansParameter(park.Parameter):
6    """
7        SANS model parameters for use in the PARK fitting service.
8        The parameter attribute value is redirected to the underlying
9        parameter value in the SANS model.
10    """
11    def __init__(self, name, model):
12        """
13            @param name: the name of the model parameter
14            @param model: the sans model to wrap as a park model
15        """
16        self._model, self._name = model,name
17        #set the value for the parameter of the given name
18        self.set(model.getParam(name))
19         
20    def _getvalue(self):
21        """
22            override the _getvalue of park parameter
23            @return value the parameter associates with self.name
24        """
25        return self._model.getParam(self.name)
26   
27    def _setvalue(self,value):
28        """
29            override the _setvalue pf park parameter
30            @param value: the value to set on a given parameter
31        """
32        self._model.setParam(self.name, value)
33       
34    value = property(_getvalue,_setvalue)
35   
36    def _getrange(self):
37        """
38            Override _getrange of park parameter
39            return the range of parameter
40        """
41        #if not  self.name in self._model.getDispParamList():
42        lo,hi = self._model.details[self.name][1:3]
43        if lo is None: lo = -numpy.inf
44        if hi is None: hi = numpy.inf
45        #else:
46            #lo,hi = self._model.details[self.name][1:]
47            #if lo is None: lo = -numpy.inf
48            #if hi is None: hi = numpy.inf
49        if lo >= hi:
50            raise ValueError,"wrong fit range for parameters"
51       
52        return lo,hi
53   
54    def _setrange(self,r):
55        """
56            override _setrange of park parameter
57            @param r: the value of the range to set
58        """
59        self._model.details[self.name][1:3] = r
60    range = property(_getrange,_setrange)
61   
62class Model(park.Model):
63    """
64        PARK wrapper for SANS models.
65    """
66    def __init__(self, sans_model, **kw):
67        """
68            @param sans_model: the sans model to wrap using park interface
69        """
70        park.Model.__init__(self, **kw)
71        self.model = sans_model
72        self.name = sans_model.name
73        #list of parameters names
74        self.sansp = sans_model.getParamList()
75        #list of park parameter
76        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
77        #list of parameterset
78        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
79        self.pars=[]
80 
81 
82    def getParams(self,fitparams):
83        """
84            return a list of value of paramter to fit
85            @param fitparams: list of paramaters name to fit
86        """
87        list=[]
88        self.pars=[]
89        self.pars=fitparams
90        for item in fitparams:
91            for element in self.parkp:
92                 if element.name ==str(item):
93                     list.append(element.value)
94        return list
95   
96   
97    def setParams(self,paramlist, params):
98        """
99            Set value for parameters to fit
100            @param params: list of value for parameters to fit
101        """
102        try:
103            for i in range(len(self.parkp)):
104                for j in range(len(paramlist)):
105                    if self.parkp[i].name==paramlist[j]:
106                        self.parkp[i].value = params[j]
107                        self.model.setParam(self.parkp[i].name,params[j])
108        except:
109            raise
110 
111    def eval(self,x):
112        """
113            override eval method of park model.
114            @param x: the x value used to compute a function
115        """
116        try:
117            return self.model.evalDistribution(x)
118        except:
119            raise
120
121   
122class FitData1D(Data1D):
123    """
124        Wrapper class  for SANS data
125        FitData1D inherits from DataLoader.data_info.Data1D. Implements
126        a way to get residuals from data.
127    """
128    def __init__(self,x, y,dx= None, dy=None, smearer=None):
129        Data1D.__init__(self, x=numpy.array(x), y=numpy.array(y), dx=dx, dy=dy)
130        """
131            @param smearer: is an object of class QSmearer or SlitSmearer
132            that will smear the theory data (slit smearing or resolution
133            smearing) when set.
134           
135            The proper way to set the smearing object would be to
136            do the following:
137           
138            from DataLoader.qsmearing import smear_selection
139            smearer = smear_selection(some_data)
140            fitdata1d = FitData1D( x= [1,3,..,],
141                                    y= [3,4,..,8],
142                                    dx=None,
143                                    dy=[1,2...], smearer= smearer)
144           
145            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
146           
147            Setting it back to None will turn smearing off.
148           
149        """
150       
151        self.smearer = smearer
152        if dy ==None or dy==[]:
153            self.dy= numpy.zeros(len(self.y)) 
154        else:
155            self.dy= numpy.asarray(dy)
156     
157        # For fitting purposes, replace zero errors by 1
158        #TODO: check validity for the rare case where only
159        # a few points have zero errors
160        self.dy[self.dy==0]=1
161       
162        ## Min Q-value
163        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
164        if min (self.x) ==0.0 and self.x[0]==0 and not numpy.isfinite(self.y[0]):
165            self.qmin = min(self.x[self.x!=0])
166        else:                             
167            self.qmin= min (self.x)
168        ## Max Q-value
169        self.qmax = max (self.x)
170       
171        # Range used for input to smearing
172        self._qmin_unsmeared = self.qmin
173        self._qmax_unsmeared = self.qmax
174        # Identify the bin range for the unsmeared and smeared spaces
175        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
176        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
177 
178       
179       
180    def setFitRange(self,qmin=None,qmax=None):
181        """ to set the fit range"""
182        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
183        #ToDo: Fix this.
184        if qmin==0.0 and not numpy.isfinite(self.y[qmin]):
185            self.qmin = min(self.x[self.x!=0])
186        elif qmin!=None:                       
187            self.qmin = qmin           
188
189        if qmax !=None:
190            self.qmax = qmax
191           
192        # Determine the range needed in unsmeared-Q to cover
193        # the smeared Q range
194        self._qmin_unsmeared = self.qmin
195        self._qmax_unsmeared = self.qmax   
196       
197        self._first_unsmeared_bin = 0
198        self._last_unsmeared_bin  = len(self.x)-1
199       
200        if self.smearer!=None:
201            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
202            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
203            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
204           
205        # Identify the bin range for the unsmeared and smeared spaces
206        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
207        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
208 
209       
210    def getFitRange(self):
211        """
212            @return the range of data.x to fit
213        """
214        return self.qmin, self.qmax
215       
216    def residuals(self, fn):
217        """
218            Compute residuals.
219           
220            If self.smearer has been set, use if to smear
221            the data before computing chi squared.
222           
223            @param fn: function that return model value
224            @return residuals
225        """
226        # Compute theory data f(x)
227        fx= numpy.zeros(len(self.x))
228        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
229       
230        ## Smear theory data
231        if self.smearer is not None:
232            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
233       
234        ## Sanity check
235        if numpy.size(self.dy)!= numpy.size(fx):
236            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
237                                                                              numpy.size(fx))
238                                                                             
239        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
240     
241 
242       
243    def residuals_deriv(self, model, pars=[]):
244        """
245            @return residuals derivatives .
246            @note: in this case just return empty array
247        """
248        return []
249   
250   
251class FitData2D(Data2D):
252    """ Wrapper class  for SANS data """
253    def __init__(self,sans_data2d ,data=None, err_data=None):
254        Data2D.__init__(self, data= data, err_data= err_data)
255        """
256            Data can be initital with a data (sans plottable)
257            or with vectors.
258        """
259        self.res_err_image=[]
260        self.index_model=[]
261        self.qmin= None
262        self.qmax= None
263        self.smearer = None
264        self.set_data(sans_data2d )
265
266       
267    def set_data(self, sans_data2d, qmin=None, qmax=None ):
268        """
269            Determine the correct qx_data and qy_data within range to fit
270        """
271        self.data     = sans_data2d.data
272        self.err_data = sans_data2d.err_data
273        self.qx_data = sans_data2d.qx_data
274        self.qy_data = sans_data2d.qy_data
275        self.mask       = sans_data2d.mask
276
277        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
278        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
279       
280        ## fitting range
281        if qmin == None:
282            self.qmin = 1e-16
283        if qmax == None:
284            self.qmax = math.sqrt(x_max*x_max +y_max*y_max)
285        ## new error image for fitting purpose
286        if self.err_data== None or self.err_data ==[]:
287            self.res_err_data= numpy.ones(len(self.data))
288        else:
289            self.res_err_data = copy.deepcopy(self.err_data)
290        self.res_err_data[self.res_err_data==0]=1
291       
292        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
293       
294        # Note: mask = True: for MASK while mask = False for NOT to mask
295        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
296        self.index_model = (self.index_model) & (self.mask)
297        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
298       
299    def set_smearer(self,smearer): 
300        """
301            Set smearer
302        """
303        if smearer == None:
304            return
305        self.smearer = smearer
306        self.smearer.set_index(self.index_model)
307        self.smearer.get_data()
308
309    def setFitRange(self,qmin=None,qmax=None):
310        """ to set the fit range"""
311        if qmin==0.0:
312            self.qmin = 1e-16
313        elif qmin!=None:                       
314            self.qmin = qmin           
315        if qmax!=None:
316            self.qmax= qmax       
317        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
318        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
319        self.index_model = (self.index_model) &(self.mask)
320        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
321       
322    def getFitRange(self):
323        """
324            @return the range of data.x to fit
325        """
326        return self.qmin, self.qmax
327     
328    def residuals(self, fn): 
329        """
330            @return the residuals
331        """ 
332        if self.smearer != None:
333            fn.set_index(self.index_model)
334            # Get necessary data from self.data and set the data for smearing
335            fn.get_data()
336
337            gn = fn.get_value() 
338        else:
339            gn = fn([self.qx_data[self.index_model],self.qy_data[self.index_model]])
340        # use only the data point within ROI range
341        res=(self.data[self.index_model] - gn)/self.res_err_data[self.index_model]
342        return res
343       
344 
345    def residuals_deriv(self, model, pars=[]):
346        """
347            @return residuals derivatives .
348            @note: in this case just return empty array
349        """
350        return []
351   
352class FitAbort(Exception):
353    """
354        Exception raise to stop the fit
355    """
356    print"Creating fit abort Exception"
357
358
359class SansAssembly:
360    """
361         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
362    """
363    def __init__(self, paramlist, model=None , data=None, fitresult=None,
364                 handler=None, curr_thread=None):
365        """
366            @param Model: the model wrapper fro sans -model
367            @param Data: the data wrapper for sans data
368        """
369        self.model = model
370        self.data  = data
371        self.paramlist = paramlist
372        self.curr_thread = curr_thread
373        self.handler = handler
374        self.fitresult = fitresult
375        self.res = []
376        self.func_name = "Functor"
377       
378    def chisq(self, params):
379        """
380            Calculates chi^2
381            @param params: list of parameter values
382            @return: chi^2
383        """
384        sum = 0
385        for item in self.res:
386            sum += item*item
387        if len(self.res)==0:
388            return None
389        return sum/ len(self.res)
390   
391    def __call__(self,params):
392        """
393            Compute residuals
394            @param params: value of parameters to fit
395        """
396        #import thread
397        self.model.setParams(self.paramlist,params)
398        self.res= self.data.residuals(self.model.eval)
399        if self.fitresult is not None and  self.handler is not None:
400            self.fitresult.set_model(model=self.model)
401            fitness = self.chisq(params=params)
402            self.fitresult.set_fitness(fitness=fitness)
403            self.handler.set_result(result=self.fitresult)
404            self.handler.update_fit()
405       
406        #if self.curr_thread != None :
407        #    try:
408        #        self.curr_thread.isquit()
409        #    except:
410        #        raise FitAbort,"stop leastsqr optimizer"   
411        return self.res
412   
413class FitEngine:
414    def __init__(self):
415        """
416            Base class for scipy and park fit engine
417        """
418        #List of parameter names to fit
419        self.paramList=[]
420        #Dictionnary of fitArrange element (fit problems)
421        self.fitArrangeDict={}
422       
423    def _concatenateData(self, listdata=[]):
424        """ 
425            _concatenateData method concatenates each fields of all data contains ins listdata.
426            @param listdata: list of data
427            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
428             of data concatenanted
429            @raise: if listdata is empty  will return None
430            @raise: if data in listdata don't contain dy field ,will create an error
431            during fitting
432        """
433        #TODO: we have to refactor the way we handle data.
434        # We should move away from plottables and move towards the Data1D objects
435        # defined in DataLoader. Data1D allows data manipulations, which should be
436        # used to concatenate.
437        # In the meantime we should switch off the concatenation.
438        #if len(listdata)>1:
439        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
440        #return listdata[0]
441       
442        if listdata==[]:
443            raise ValueError, " data list missing"
444        else:
445            xtemp=[]
446            ytemp=[]
447            dytemp=[]
448            self.mini=None
449            self.maxi=None
450               
451            for item in listdata:
452                data=item.data
453                mini,maxi=data.getFitRange()
454                if self.mini==None and self.maxi==None:
455                    self.mini=mini
456                    self.maxi=maxi
457                else:
458                    if mini < self.mini:
459                        self.mini=mini
460                    if self.maxi < maxi:
461                        self.maxi=maxi
462                       
463                   
464                for i in range(len(data.x)):
465                    xtemp.append(data.x[i])
466                    ytemp.append(data.y[i])
467                    if data.dy is not None and len(data.dy)==len(data.y):   
468                        dytemp.append(data.dy[i])
469                    else:
470                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
471            data= Data(x=xtemp,y=ytemp,dy=dytemp)
472            data.setFitRange(self.mini, self.maxi)
473            return data
474       
475       
476    def set_model(self,model,Uid,pars=[], constraints=[]):
477        """
478            set a model on a given uid in the fit engine.
479            @param model: sans.models type
480            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
481            @param pars: the list of parameters to fit
482            @param constraints: list of
483                tuple (name of parameter, value of parameters)
484                the value of parameter must be a string to constraint 2 different
485                parameters.
486                Example:
487                we want to fit 2 model M1 and M2 both have parameters A and B.
488                constraints can be:
489                 constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
490            @note : pars must contains only name of existing model's paramaters
491        """
492        if model == None:
493            raise ValueError, "AbstractFitEngine: Need to set model to fit"
494       
495        new_model= model
496        if not issubclass(model.__class__, Model):
497            new_model= Model(model)
498       
499        if len(constraints)>0:
500            for constraint in constraints:
501                name, value = constraint
502                try:
503                    new_model.parameterset[ str(name)].set( str(value) )
504                except:
505                    msg= "Fit Engine: Error occurs when setting the constraint"
506                    msg += " %s for parameter %s "%(value, name)
507                    raise ValueError, msg
508               
509        if len(pars) >0:
510            temp=[]
511            for item in pars:
512                if item in new_model.model.getParamList():
513                    temp.append(item)
514                    self.paramList.append(item)
515                else:
516                   
517                    msg = "wrong parameter %s used"%str(item)
518                    msg += "to set model %s. Choose"%str(new_model.model.name)
519                    msg += "parameter name within %s"%str(new_model.model.getParamList())
520                    raise ValueError,msg
521             
522            #A fitArrange is already created but contains dList only at Uid
523            if self.fitArrangeDict.has_key(Uid):
524                self.fitArrangeDict[Uid].set_model(new_model)
525                self.fitArrangeDict[Uid].pars= pars
526            else:
527            #no fitArrange object has been create with this Uid
528                fitproblem = FitArrange()
529                fitproblem.set_model(new_model)
530                fitproblem.pars= pars
531                self.fitArrangeDict[Uid] = fitproblem
532               
533        else:
534            raise ValueError, "park_integration:missing parameters"
535   
536    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
537        """ Receives plottable, creates a list of data to fit,set data
538            in a FitArrange object and adds that object in a dictionary
539            with key Uid.
540            @param data: data added
541            @param Uid: unique key corresponding to a fitArrange object with data
542        """
543        if data.__class__.__name__=='Data2D':
544            fitdata=FitData2D(sans_data2d=data, data=data.data, err_data= data.err_data)
545        else:
546            fitdata=FitData1D(x=data.x, y=data.y , dx= data.dx,dy=data.dy,smearer=smearer)
547       
548        fitdata.setFitRange(qmin=qmin,qmax=qmax)
549        #A fitArrange is already created but contains model only at Uid
550        if self.fitArrangeDict.has_key(Uid):
551            self.fitArrangeDict[Uid].add_data(fitdata)
552        else:
553        #no fitArrange object has been create with this Uid
554            fitproblem= FitArrange()
555            fitproblem.add_data(fitdata)
556            self.fitArrangeDict[Uid]=fitproblem   
557   
558    def get_model(self,Uid):
559        """
560            @param Uid: Uid is key in the dictionary containing the model to return
561            @return  a model at this uid or None if no FitArrange element was created
562            with this Uid
563        """
564        if self.fitArrangeDict.has_key(Uid):
565            return self.fitArrangeDict[Uid].get_model()
566        else:
567            return None
568   
569    def remove_Fit_Problem(self,Uid):
570        """remove   fitarrange in Uid"""
571        if self.fitArrangeDict.has_key(Uid):
572            del self.fitArrangeDict[Uid]
573           
574    def select_problem_for_fit(self,Uid,value):
575        """
576            select a couple of model and data at the Uid position in dictionary
577            and set in self.selected value to value
578            @param value: the value to allow fitting. can only have the value one or zero
579        """
580        if self.fitArrangeDict.has_key(Uid):
581             self.fitArrangeDict[Uid].set_to_fit( value)
582             
583             
584    def get_problem_to_fit(self,Uid):
585        """
586            return the self.selected value of the fit problem of Uid
587           @param Uid: the Uid of the problem
588        """
589        if self.fitArrangeDict.has_key(Uid):
590             self.fitArrangeDict[Uid].get_to_fit()
591   
592class FitArrange:
593    def __init__(self):
594        """
595            Class FitArrange contains a set of data for a given model
596            to perform the Fit.FitArrange must contain exactly one model
597            and at least one data for the fit to be performed.
598            model: the model selected by the user
599            Ldata: a list of data what the user wants to fit
600           
601        """
602        self.model = None
603        self.dList =[]
604        self.pars=[]
605        #self.selected  is zero when this fit problem is not schedule to fit
606        #self.selected is 1 when schedule to fit
607        self.selected = 0
608       
609    def set_model(self,model):
610        """
611            set_model save a copy of the model
612            @param model: the model being set
613        """
614        self.model = model
615       
616    def add_data(self,data):
617        """
618            add_data fill a self.dList with data to fit
619            @param data: Data to add in the list 
620        """
621        if not data in self.dList:
622            self.dList.append(data)
623           
624    def get_model(self):
625        """ @return: saved model """
626        return self.model   
627     
628    def get_data(self):
629        """ @return:  list of data dList"""
630        #return self.dList
631        return self.dList[0] 
632     
633    def remove_data(self,data):
634        """
635            Remove one element from the list
636            @param data: Data to remove from dList
637        """
638        if data in self.dList:
639            self.dList.remove(data)
640    def set_to_fit (self, value=0):
641        """
642           set self.selected to 0 or 1  for other values raise an exception
643           @param value: integer between 0 or 1
644        """
645        self.selected= value
646       
647    def get_to_fit(self):
648        """
649            @return self.selected value
650        """
651        return self.selected
Note: See TracBrowser for help on using the repository browser.