source: sasview/park_integration/AbstractFitEngine.py @ 2140e68

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 2140e68 was 20d30e9, checked in by Gervaise Alina <gervyh@…>, 16 years ago

fit with qmin qmax value only

  • Property mode set to 100644
File size: 19.3 KB
RevLine 
[4c718654]1
[39c3263]2import park,numpy,math
[48882d1]3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
[ca6d914]11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
[48882d1]18         
[ca6d914]19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
[48882d1]25   
[ca6d914]26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
[48882d1]31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
[ca6d914]36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
[c79ee796]40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
[48882d1]47        return lo,hi
48   
49    def _setrange(self,r):
[ca6d914]50        """
51            override _setrange of park parameter
52            @param r: the value of the range to set
53        """
[48882d1]54        self._model.details[self.name][1:] = r
55    range = property(_getrange,_setrange)
[a9e04aa]56   
57class Model(park.Model):
[48882d1]58    """
59        PARK wrapper for SANS models.
60    """
[388309d]61    def __init__(self, sans_model, **kw):
[ca6d914]62        """
63            @param sans_model: the sans model to wrap using park interface
64        """
[a9e04aa]65        park.Model.__init__(self, **kw)
[48882d1]66        self.model = sans_model
[ca6d914]67        self.name = sans_model.name
68        #list of parameters names
[48882d1]69        self.sansp = sans_model.getParamList()
[ca6d914]70        #list of park parameter
[48882d1]71        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
[ca6d914]72        #list of parameterset
[48882d1]73        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
74        self.pars=[]
[ca6d914]75 
76 
[48882d1]77    def getParams(self,fitparams):
[ca6d914]78        """
79            return a list of value of paramter to fit
80            @param fitparams: list of paramaters name to fit
81        """
[48882d1]82        list=[]
83        self.pars=[]
84        self.pars=fitparams
85        for item in fitparams:
86            for element in self.parkp:
87                 if element.name ==str(item):
88                     list.append(element.value)
89        return list
90   
[ca6d914]91   
[e71440c]92    def setParams(self,paramlist, params):
[ca6d914]93        """
94            Set value for parameters to fit
95            @param params: list of value for parameters to fit
96        """
[e71440c]97        try:
98            for i in range(len(self.parkp)):
99                for j in range(len(paramlist)):
100                    if self.parkp[i].name==paramlist[j]:
101                        self.parkp[i].value = params[j]
102                        self.model.setParam(self.parkp[i].name,params[j])
103        except:
104            raise
[ca6d914]105 
[48882d1]106    def eval(self,x):
[ca6d914]107        """
108            override eval method of park model.
109            @param x: the x value used to compute a function
110        """
[48882d1]111        return self.model.runXY(x)
[388309d]112   
113   
[a9e04aa]114
115
[48882d1]116class Data(object):
117    """ Wrapper class  for SANS data """
118    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
[ca6d914]119        """
120            Data can be initital with a data (sans plottable)
121            or with vectors.
122        """
[48882d1]123        if  sans_data !=None:
124            self.x= sans_data.x
125            self.y= sans_data.y
126            self.dx= sans_data.dx
127            self.dy= sans_data.dy
128           
129        elif (x!=None and y!=None and dy!=None):
130                self.x=x
131                self.y=y
132                self.dx=dx
133                self.dy=dy
134        else:
135            raise ValueError,\
136            "Data is missing x, y or dy, impossible to compute residuals later on"
137        self.qmin=None
138        self.qmax=None
139       
[ca6d914]140       
[48882d1]141    def setFitRange(self,mini=None,maxi=None):
142        """ to set the fit range"""
143        self.qmin=mini
144        self.qmax=maxi
[ca6d914]145       
146       
[48882d1]147    def getFitRange(self):
[ca6d914]148        """
149            @return the range of data.x to fit
150        """
151        return self.qmin, self.qmax
152     
153     
[48882d1]154    def residuals(self, fn):
155        """ @param fn: function that return model value
156            @return residuals
157        """
158        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
159        if self.qmin==None and self.qmax==None: 
[ca6d914]160            fx =numpy.asarray([fn(v) for v in x])
[48882d1]161            return (y - fx)/dy
162        else:
163            idx = (x>=self.qmin) & (x <= self.qmax)
[ca6d914]164            fx = numpy.asarray([fn(item)for item in x[idx ]])
[48882d1]165            return (y[idx] - fx)/dy[idx]
[e71440c]166       
[48882d1]167    def residuals_deriv(self, model, pars=[]):
168        """
169            @return residuals derivatives .
170            @note: in this case just return empty array
171        """
172        return []
[7d0c1a8]173class FitData1D(object):
174    """ Wrapper class  for SANS data """
[b461b6d7]175    def __init__(self,sans_data1d, smearer=None):
[7d0c1a8]176        """
177            Data can be initital with a data (sans plottable)
178            or with vectors.
[109e60ab]179           
180            self.smearer is an object of class QSmearer or SlitSmearer
181            that will smear the theory data (slit smearing or resolution
182            smearing) when set.
183           
184            The proper way to set the smearing object would be to
185            do the following:
186           
187            from DataLoader.qsmearing import smear_selection
188            fitdata1d = FitData1D(some_data)
189            fitdata1d.smearer = smear_selection(some_data)
190           
191            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
192           
193            Setting it back to None will turn smearing off.
194           
[7d0c1a8]195        """
[b461b6d7]196       
197        self.smearer = smearer
198     
[109e60ab]199        # Initialize from Data1D object
[7d0c1a8]200        self.data=sans_data1d
201        self.x= sans_data1d.x
202        self.y= sans_data1d.y
203        self.dx= sans_data1d.dx
204        self.dy= sans_data1d.dy
[109e60ab]205       
206        ## Min Q-value
[20d30e9]207        self.qmin= min (self.data.x)
[109e60ab]208        ## Max Q-value
[20d30e9]209        self.qmax= max (self.data.x)
[7d0c1a8]210       
211       
[20d30e9]212    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]213        """ to set the fit range"""
[20d30e9]214        self.qmin = qmin
215        self.qmax = qmax
[7d0c1a8]216       
217       
218    def getFitRange(self):
219        """
220            @return the range of data.x to fit
221        """
222        return self.qmin, self.qmax
223     
224     
225    def residuals(self, fn):
[109e60ab]226        """
227            Compute residuals.
228           
229            If self.smearer has been set, use if to smear
230            the data before computing chi squared.
231           
232            @param fn: function that return model value
233            @return residuals
234        """
235        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
236           
237        # Find entries to consider
[20d30e9]238        if self.qmin==None:
239            self.qmin= min(self.data.x)
240        if  self.qmax==None:
241            self.qmin= max(self.data.x)
242       
243        idx = (x>=self.qmin) & (x <= self.qmax)
244 
[109e60ab]245        # Compute theory data f(x)
246        fx = numpy.zeros(len(x))
247        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
248       
249        # Smear theory data
250        if self.smearer is not None:
251            fx = self.smearer(fx)
[20d30e9]252       
[109e60ab]253        # Sanity check
254        if numpy.size(dy) < numpy.size(x):
255            raise RuntimeError, "FitData1D: invalid error array"
256                           
257        return (y[idx] - fx[idx])/dy[idx]
258     
[20d30e9]259 
[7d0c1a8]260       
261    def residuals_deriv(self, model, pars=[]):
262        """
263            @return residuals derivatives .
264            @note: in this case just return empty array
265        """
266        return []
267   
268   
269class FitData2D(object):
270    """ Wrapper class  for SANS data """
271    def __init__(self,sans_data2d):
272        """
273            Data can be initital with a data (sans plottable)
274            or with vectors.
275        """
276        self.data=sans_data2d
[415bc97]277        self.image = sans_data2d.data
278        self.err_image = sans_data2d.err_data
[7d0c1a8]279        self.x_bins= sans_data2d.x_bins
280        self.y_bins= sans_data2d.y_bins
281       
[20d30e9]282        x = max(self.data.xmin, self.data.xmax)
283        y = max(self.data.ymin, self.data.ymax)
284       
285        ## fitting range
286        self.qmin = 0
287        self.qmax = math.sqrt(x*x +y*y)
[7d0c1a8]288       
289       
[20d30e9]290       
291    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]292        """ to set the fit range"""
[20d30e9]293        self.qmin= qmin
294        self.qmax= qmax
295     
[7d0c1a8]296       
297    def getFitRange(self):
298        """
299            @return the range of data.x to fit
300        """
[20d30e9]301        return self.qmin, self.qmax
[7d0c1a8]302     
303     
304    def residuals(self, fn):
305        """ @param fn: function that return model value
306            @return residuals
307        """
308        res=[]
[20d30e9]309        #Here we define that  qmin >=0 and qmax>=qmain
310        x = max(self.data.xmin, self.data.xmax)
311        y = max(self.data.ymin, self.data.ymax)
312       
313        ## fitting range
314        if self.qmin == None:
315            self.qmin = 0
316        if self.qmax == None:
317            self.qmax = math.sqrt(x*x +y*y)
318       
319       
[0e51519]320        for i in range(len(self.y_bins)):
321            for j in range(len(self.x_bins)):
[20d30e9]322                radius = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
323                if self.qmin <= radius and radius <= self.qmax:
324                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
325                            /self.err_image[j][i] )
[0e51519]326       
327        return numpy.array(res)
328       
329         
[7d0c1a8]330    def residuals_deriv(self, model, pars=[]):
331        """
332            @return residuals derivatives .
333            @note: in this case just return empty array
334        """
335        return []
[48882d1]336   
337class sansAssembly:
[ca6d914]338    """
339         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
340    """
[e71440c]341    def __init__(self,paramlist,Model=None , Data=None):
[ca6d914]342        """
343            @param Model: the model wrapper fro sans -model
344            @param Data: the data wrapper for sans data
345        """
346        self.model = Model
347        self.data  = Data
[e71440c]348        self.paramlist=paramlist
[ca6d914]349        self.res=[]
[48882d1]350    def chisq(self, params):
351        """
352            Calculates chi^2
353            @param params: list of parameter values
354            @return: chi^2
355        """
356        sum = 0
357        for item in self.res:
358            sum += item*item
[20d30e9]359       
[26cb768]360        return sum/ len(self.res)
[20d30e9]361   
[48882d1]362    def __call__(self,params):
[ca6d914]363        """
364            Compute residuals
365            @param params: value of parameters to fit
366        """
[681f0dc]367        #import thread
[e71440c]368        self.model.setParams(self.paramlist,params)
[48882d1]369        self.res= self.data.residuals(self.model.eval)
[681f0dc]370        #print "residuals",thread.get_ident()
[48882d1]371        return self.res
372   
[4c718654]373class FitEngine:
[ee5b04c]374    def __init__(self):
[ca6d914]375        """
376            Base class for scipy and park fit engine
377        """
378        #List of parameter names to fit
[ee5b04c]379        self.paramList=[]
[ca6d914]380        #Dictionnary of fitArrange element (fit problems)
381        self.fitArrangeDict={}
382       
[4c718654]383    def _concatenateData(self, listdata=[]):
384        """ 
385            _concatenateData method concatenates each fields of all data contains ins listdata.
386            @param listdata: list of data
[ca6d914]387            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
388             of data concatenanted
[4c718654]389            @raise: if listdata is empty  will return None
390            @raise: if data in listdata don't contain dy field ,will create an error
391            during fitting
392        """
[109e60ab]393        #TODO: we have to refactor the way we handle data.
394        # We should move away from plottables and move towards the Data1D objects
395        # defined in DataLoader. Data1D allows data manipulations, which should be
396        # used to concatenate.
397        # In the meantime we should switch off the concatenation.
398        #if len(listdata)>1:
399        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
400        #return listdata[0]
401       
[4c718654]402        if listdata==[]:
403            raise ValueError, " data list missing"
404        else:
405            xtemp=[]
406            ytemp=[]
407            dytemp=[]
[48882d1]408            self.mini=None
409            self.maxi=None
[4c718654]410               
[7d0c1a8]411            for item in listdata:
412                data=item.data
[48882d1]413                mini,maxi=data.getFitRange()
414                if self.mini==None and self.maxi==None:
415                    self.mini=mini
416                    self.maxi=maxi
417                else:
418                    if mini < self.mini:
419                        self.mini=mini
420                    if self.maxi < maxi:
421                        self.maxi=maxi
422                       
423                   
[4c718654]424                for i in range(len(data.x)):
425                    xtemp.append(data.x[i])
426                    ytemp.append(data.y[i])
427                    if data.dy is not None and len(data.dy)==len(data.y):   
428                        dytemp.append(data.dy[i])
429                    else:
[ee5b04c]430                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
[20d30e9]431            data= Data(x=xtemp,y=ytemp,dy=dytemp)
[48882d1]432            data.setFitRange(self.mini, self.maxi)
433            return data
[ca6d914]434       
435       
436    def set_model(self,model,Uid,pars=[]):
437        """
438            set a model on a given uid in the fit engine.
439            @param model: the model to fit
440            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
441            @param pars: the list of parameters to fit
442            @note : pars must contains only name of existing model's paramaters
443        """
[f44dbc7]444        if len(pars) >0:
[6831a99]445            if model==None:
[f44dbc7]446                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
[6831a99]447            else:
[ca6d914]448                for item in pars:
449                    if item in model.model.getParamList():
450                        self.paramList.append(item)
451                    else:
452                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
453                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
454                        return
[6831a99]455            #A fitArrange is already created but contains dList only at Uid
[ca6d914]456            if self.fitArrangeDict.has_key(Uid):
457                self.fitArrangeDict[Uid].set_model(model)
[6831a99]458            else:
459            #no fitArrange object has been create with this Uid
[48882d1]460                fitproblem = FitArrange()
[6831a99]461                fitproblem.set_model(model)
[ca6d914]462                self.fitArrangeDict[Uid] = fitproblem
[d4b0687]463        else:
[6831a99]464            raise ValueError, "park_integration:missing parameters"
[48882d1]465   
[20d30e9]466    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
[d4b0687]467        """ Receives plottable, creates a list of data to fit,set data
468            in a FitArrange object and adds that object in a dictionary
469            with key Uid.
470            @param data: data added
471            @param Uid: unique key corresponding to a fitArrange object with data
[ca6d914]472        """
[f2817bb]473        if data.__class__.__name__=='Data2D':
[f8ce013]474            fitdata=FitData2D(data)
475        else:
[b461b6d7]476            fitdata=FitData1D(data, smearer)
[20d30e9]477       
478        fitdata.setFitRange(qmin=qmin,qmax=qmax)
[d4b0687]479        #A fitArrange is already created but contains model only at Uid
[ca6d914]480        if self.fitArrangeDict.has_key(Uid):
[f8ce013]481            self.fitArrangeDict[Uid].add_data(fitdata)
[d4b0687]482        else:
483        #no fitArrange object has been create with this Uid
484            fitproblem= FitArrange()
[f8ce013]485            fitproblem.add_data(fitdata)
[ca6d914]486            self.fitArrangeDict[Uid]=fitproblem   
[20d30e9]487   
[d4b0687]488    def get_model(self,Uid):
489        """
490            @param Uid: Uid is key in the dictionary containing the model to return
491            @return  a model at this uid or None if no FitArrange element was created
492            with this Uid
493        """
[ca6d914]494        if self.fitArrangeDict.has_key(Uid):
495            return self.fitArrangeDict[Uid].get_model()
[d4b0687]496        else:
497            return None
498   
499    def remove_Fit_Problem(self,Uid):
500        """remove   fitarrange in Uid"""
[ca6d914]501        if self.fitArrangeDict.has_key(Uid):
502            del self.fitArrangeDict[Uid]
[a9e04aa]503           
504    def select_problem_for_fit(self,Uid,value):
505        """
506            select a couple of model and data at the Uid position in dictionary
507            and set in self.selected value to value
508            @param value: the value to allow fitting. can only have the value one or zero
509        """
510        if self.fitArrangeDict.has_key(Uid):
511             self.fitArrangeDict[Uid].set_to_fit( value)
512    def get_problem_to_fit(self,Uid):
513        """
514            return the self.selected value of the fit problem of Uid
515           @param Uid: the Uid of the problem
516        """
517        if self.fitArrangeDict.has_key(Uid):
518             self.fitArrangeDict[Uid].get_to_fit()
[4c718654]519   
[d4b0687]520class FitArrange:
521    def __init__(self):
522        """
523            Class FitArrange contains a set of data for a given model
524            to perform the Fit.FitArrange must contain exactly one model
525            and at least one data for the fit to be performed.
526            model: the model selected by the user
527            Ldata: a list of data what the user wants to fit
528           
529        """
530        self.model = None
531        self.dList =[]
[a9e04aa]532        #self.selected  is zero when this fit problem is not schedule to fit
533        #self.selected is 1 when schedule to fit
534        self.selected = 0
[d4b0687]535       
536    def set_model(self,model):
537        """
538            set_model save a copy of the model
539            @param model: the model being set
540        """
541        self.model = model
542       
543    def add_data(self,data):
544        """
545            add_data fill a self.dList with data to fit
546            @param data: Data to add in the list 
547        """
548        if not data in self.dList:
549            self.dList.append(data)
550           
551    def get_model(self):
552        """ @return: saved model """
553        return self.model   
554     
555    def get_data(self):
556        """ @return:  list of data dList"""
[7d0c1a8]557        #return self.dList
558        return self.dList[0] 
[d4b0687]559     
560    def remove_data(self,data):
561        """
562            Remove one element from the list
563            @param data: Data to remove from dList
564        """
565        if data in self.dList:
566            self.dList.remove(data)
[a9e04aa]567    def set_to_fit (self, value=0):
568        """
569           set self.selected to 0 or 1  for other values raise an exception
570           @param value: integer between 0 or 1
571        """
572        self.selected= value
573       
574    def get_to_fit(self):
575        """
576            @return self.selected value
577        """
578        return self.selected
[94b44293]579   
[4c718654]580
581
582   
Note: See TracBrowser for help on using the repository browser.