source: sasview/park_integration/AbstractFitEngine.py @ de5c813

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since de5c813 was eef2e0ed, checked in by Gervaise Alina <gervyh@…>, 16 years ago

remove comments

  • Property mode set to 100644
File size: 19.2 KB
RevLine 
[4c718654]1
[39c3263]2import park,numpy,math
[48882d1]3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
[ca6d914]11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
[48882d1]18         
[ca6d914]19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
[48882d1]25   
[ca6d914]26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
[48882d1]31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
[ca6d914]36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
[c79ee796]40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
[48882d1]47        return lo,hi
48   
49    def _setrange(self,r):
[ca6d914]50        """
51            override _setrange of park parameter
52            @param r: the value of the range to set
53        """
[48882d1]54        self._model.details[self.name][1:] = r
55    range = property(_getrange,_setrange)
[a9e04aa]56   
57class Model(park.Model):
[48882d1]58    """
59        PARK wrapper for SANS models.
60    """
[388309d]61    def __init__(self, sans_model, **kw):
[ca6d914]62        """
63            @param sans_model: the sans model to wrap using park interface
64        """
[a9e04aa]65        park.Model.__init__(self, **kw)
[48882d1]66        self.model = sans_model
[ca6d914]67        self.name = sans_model.name
68        #list of parameters names
[48882d1]69        self.sansp = sans_model.getParamList()
[ca6d914]70        #list of park parameter
[48882d1]71        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
[ca6d914]72        #list of parameterset
[48882d1]73        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
74        self.pars=[]
[ca6d914]75 
76 
[48882d1]77    def getParams(self,fitparams):
[ca6d914]78        """
79            return a list of value of paramter to fit
80            @param fitparams: list of paramaters name to fit
81        """
[48882d1]82        list=[]
83        self.pars=[]
84        self.pars=fitparams
85        for item in fitparams:
86            for element in self.parkp:
87                 if element.name ==str(item):
88                     list.append(element.value)
89        return list
90   
[ca6d914]91   
[e71440c]92    def setParams(self,paramlist, params):
[ca6d914]93        """
94            Set value for parameters to fit
95            @param params: list of value for parameters to fit
96        """
[e71440c]97        try:
98            for i in range(len(self.parkp)):
99                for j in range(len(paramlist)):
100                    if self.parkp[i].name==paramlist[j]:
101                        self.parkp[i].value = params[j]
102                        self.model.setParam(self.parkp[i].name,params[j])
103        except:
104            raise
[ca6d914]105 
[48882d1]106    def eval(self,x):
[ca6d914]107        """
108            override eval method of park model.
109            @param x: the x value used to compute a function
110        """
[48882d1]111        return self.model.runXY(x)
[388309d]112   
113   
[a9e04aa]114
115
[48882d1]116class Data(object):
117    """ Wrapper class  for SANS data """
118    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
[ca6d914]119        """
120            Data can be initital with a data (sans plottable)
121            or with vectors.
122        """
[48882d1]123        if  sans_data !=None:
124            self.x= sans_data.x
125            self.y= sans_data.y
126            self.dx= sans_data.dx
127            self.dy= sans_data.dy
128           
129        elif (x!=None and y!=None and dy!=None):
130                self.x=x
131                self.y=y
132                self.dx=dx
133                self.dy=dy
134        else:
135            raise ValueError,\
136            "Data is missing x, y or dy, impossible to compute residuals later on"
137        self.qmin=None
138        self.qmax=None
139       
[ca6d914]140       
[48882d1]141    def setFitRange(self,mini=None,maxi=None):
142        """ to set the fit range"""
143        self.qmin=mini
144        self.qmax=maxi
[ca6d914]145       
146       
[48882d1]147    def getFitRange(self):
[ca6d914]148        """
149            @return the range of data.x to fit
150        """
151        return self.qmin, self.qmax
152     
153     
[48882d1]154    def residuals(self, fn):
155        """ @param fn: function that return model value
156            @return residuals
157        """
158        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
159        if self.qmin==None and self.qmax==None: 
[ca6d914]160            fx =numpy.asarray([fn(v) for v in x])
[48882d1]161            return (y - fx)/dy
162        else:
163            idx = (x>=self.qmin) & (x <= self.qmax)
[ca6d914]164            fx = numpy.asarray([fn(item)for item in x[idx ]])
[48882d1]165            return (y[idx] - fx)/dy[idx]
[e71440c]166       
[48882d1]167    def residuals_deriv(self, model, pars=[]):
168        """
169            @return residuals derivatives .
170            @note: in this case just return empty array
171        """
172        return []
[7d0c1a8]173class FitData1D(object):
174    """ Wrapper class  for SANS data """
[b461b6d7]175    def __init__(self,sans_data1d, smearer=None):
[7d0c1a8]176        """
177            Data can be initital with a data (sans plottable)
178            or with vectors.
[109e60ab]179           
180            self.smearer is an object of class QSmearer or SlitSmearer
181            that will smear the theory data (slit smearing or resolution
182            smearing) when set.
183           
184            The proper way to set the smearing object would be to
185            do the following:
186           
187            from DataLoader.qsmearing import smear_selection
188            fitdata1d = FitData1D(some_data)
189            fitdata1d.smearer = smear_selection(some_data)
190           
191            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
192           
193            Setting it back to None will turn smearing off.
194           
[7d0c1a8]195        """
[b461b6d7]196       
197        self.smearer = smearer
198     
[109e60ab]199        # Initialize from Data1D object
[7d0c1a8]200        self.data=sans_data1d
201        self.x= sans_data1d.x
202        self.y= sans_data1d.y
203        self.dx= sans_data1d.dx
204        self.dy= sans_data1d.dy
[109e60ab]205       
206        ## Min Q-value
[20d30e9]207        self.qmin= min (self.data.x)
[109e60ab]208        ## Max Q-value
[20d30e9]209        self.qmax= max (self.data.x)
[7d0c1a8]210       
211       
[20d30e9]212    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]213        """ to set the fit range"""
[eef2e0ed]214        if qmin!=None:
215            self.qmin = qmin
216        if qmax !=None:
217            self.qmax = qmax
[7d0c1a8]218       
219       
220    def getFitRange(self):
221        """
222            @return the range of data.x to fit
223        """
224        return self.qmin, self.qmax
225     
226     
227    def residuals(self, fn):
[109e60ab]228        """
229            Compute residuals.
230           
231            If self.smearer has been set, use if to smear
232            the data before computing chi squared.
233           
234            @param fn: function that return model value
235            @return residuals
236        """
237        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
238           
[20d30e9]239        idx = (x>=self.qmin) & (x <= self.qmax)
240 
[109e60ab]241        # Compute theory data f(x)
242        fx = numpy.zeros(len(x))
243        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
244       
245        # Smear theory data
[aed7c57]246     
[109e60ab]247        if self.smearer is not None:
248            fx = self.smearer(fx)
[20d30e9]249       
[109e60ab]250        # Sanity check
251        if numpy.size(dy) < numpy.size(x):
252            raise RuntimeError, "FitData1D: invalid error array"
253                           
254        return (y[idx] - fx[idx])/dy[idx]
255     
[20d30e9]256 
[7d0c1a8]257       
258    def residuals_deriv(self, model, pars=[]):
259        """
260            @return residuals derivatives .
261            @note: in this case just return empty array
262        """
263        return []
264   
265   
266class FitData2D(object):
267    """ Wrapper class  for SANS data """
268    def __init__(self,sans_data2d):
269        """
270            Data can be initital with a data (sans plottable)
271            or with vectors.
272        """
273        self.data=sans_data2d
[415bc97]274        self.image = sans_data2d.data
275        self.err_image = sans_data2d.err_data
[7d0c1a8]276        self.x_bins= sans_data2d.x_bins
277        self.y_bins= sans_data2d.y_bins
278       
[20d30e9]279        x = max(self.data.xmin, self.data.xmax)
280        y = max(self.data.ymin, self.data.ymax)
281       
282        ## fitting range
283        self.qmin = 0
284        self.qmax = math.sqrt(x*x +y*y)
[7d0c1a8]285       
286       
[20d30e9]287       
288    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]289        """ to set the fit range"""
[eef2e0ed]290        if qmin!=None:
291            self.qmin= qmin
292        if qmax!=None:
293            self.qmax= qmax
[20d30e9]294     
[7d0c1a8]295       
296    def getFitRange(self):
297        """
298            @return the range of data.x to fit
299        """
[20d30e9]300        return self.qmin, self.qmax
[7d0c1a8]301     
302     
303    def residuals(self, fn):
304        """ @param fn: function that return model value
305            @return residuals
306        """
307        res=[]
[20d30e9]308       
[0e51519]309        for i in range(len(self.y_bins)):
310            for j in range(len(self.x_bins)):
[20d30e9]311                radius = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
312                if self.qmin <= radius and radius <= self.qmax:
313                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
314                            /self.err_image[j][i] )
[0e51519]315       
316        return numpy.array(res)
317       
318         
[7d0c1a8]319    def residuals_deriv(self, model, pars=[]):
320        """
321            @return residuals derivatives .
322            @note: in this case just return empty array
323        """
324        return []
[48882d1]325   
326class sansAssembly:
[ca6d914]327    """
328         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
329    """
[e71440c]330    def __init__(self,paramlist,Model=None , Data=None):
[ca6d914]331        """
332            @param Model: the model wrapper fro sans -model
333            @param Data: the data wrapper for sans data
334        """
335        self.model = Model
336        self.data  = Data
[e71440c]337        self.paramlist=paramlist
[ca6d914]338        self.res=[]
[48882d1]339    def chisq(self, params):
340        """
341            Calculates chi^2
342            @param params: list of parameter values
343            @return: chi^2
344        """
345        sum = 0
346        for item in self.res:
347            sum += item*item
[20d30e9]348       
[26cb768]349        return sum/ len(self.res)
[20d30e9]350   
[48882d1]351    def __call__(self,params):
[ca6d914]352        """
353            Compute residuals
354            @param params: value of parameters to fit
355        """
[681f0dc]356        #import thread
[e71440c]357        self.model.setParams(self.paramlist,params)
[48882d1]358        self.res= self.data.residuals(self.model.eval)
[681f0dc]359        #print "residuals",thread.get_ident()
[48882d1]360        return self.res
361   
[4c718654]362class FitEngine:
[ee5b04c]363    def __init__(self):
[ca6d914]364        """
365            Base class for scipy and park fit engine
366        """
367        #List of parameter names to fit
[ee5b04c]368        self.paramList=[]
[ca6d914]369        #Dictionnary of fitArrange element (fit problems)
370        self.fitArrangeDict={}
371       
[4c718654]372    def _concatenateData(self, listdata=[]):
373        """ 
374            _concatenateData method concatenates each fields of all data contains ins listdata.
375            @param listdata: list of data
[ca6d914]376            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
377             of data concatenanted
[4c718654]378            @raise: if listdata is empty  will return None
379            @raise: if data in listdata don't contain dy field ,will create an error
380            during fitting
381        """
[109e60ab]382        #TODO: we have to refactor the way we handle data.
383        # We should move away from plottables and move towards the Data1D objects
384        # defined in DataLoader. Data1D allows data manipulations, which should be
385        # used to concatenate.
386        # In the meantime we should switch off the concatenation.
387        #if len(listdata)>1:
388        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
389        #return listdata[0]
390       
[4c718654]391        if listdata==[]:
392            raise ValueError, " data list missing"
393        else:
394            xtemp=[]
395            ytemp=[]
396            dytemp=[]
[48882d1]397            self.mini=None
398            self.maxi=None
[4c718654]399               
[7d0c1a8]400            for item in listdata:
401                data=item.data
[48882d1]402                mini,maxi=data.getFitRange()
403                if self.mini==None and self.maxi==None:
404                    self.mini=mini
405                    self.maxi=maxi
406                else:
407                    if mini < self.mini:
408                        self.mini=mini
409                    if self.maxi < maxi:
410                        self.maxi=maxi
411                       
412                   
[4c718654]413                for i in range(len(data.x)):
414                    xtemp.append(data.x[i])
415                    ytemp.append(data.y[i])
416                    if data.dy is not None and len(data.dy)==len(data.y):   
417                        dytemp.append(data.dy[i])
418                    else:
[ee5b04c]419                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
[20d30e9]420            data= Data(x=xtemp,y=ytemp,dy=dytemp)
[48882d1]421            data.setFitRange(self.mini, self.maxi)
422            return data
[ca6d914]423       
424       
425    def set_model(self,model,Uid,pars=[]):
426        """
427            set a model on a given uid in the fit engine.
428            @param model: the model to fit
429            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
430            @param pars: the list of parameters to fit
431            @note : pars must contains only name of existing model's paramaters
432        """
[f44dbc7]433        if len(pars) >0:
[6831a99]434            if model==None:
[f44dbc7]435                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
[6831a99]436            else:
[aed7c57]437                temp=[]
[ca6d914]438                for item in pars:
439                    if item in model.model.getParamList():
[aed7c57]440                        temp.append(item)
[ca6d914]441                        self.paramList.append(item)
442                    else:
443                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
444                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
445                        return
[6831a99]446            #A fitArrange is already created but contains dList only at Uid
[ca6d914]447            if self.fitArrangeDict.has_key(Uid):
448                self.fitArrangeDict[Uid].set_model(model)
[aed7c57]449                self.fitArrangeDict[Uid].pars= pars
[6831a99]450            else:
451            #no fitArrange object has been create with this Uid
[48882d1]452                fitproblem = FitArrange()
[6831a99]453                fitproblem.set_model(model)
[aed7c57]454                fitproblem.pars= pars
[ca6d914]455                self.fitArrangeDict[Uid] = fitproblem
[aed7c57]456               
[d4b0687]457        else:
[6831a99]458            raise ValueError, "park_integration:missing parameters"
[48882d1]459   
[20d30e9]460    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
[d4b0687]461        """ Receives plottable, creates a list of data to fit,set data
462            in a FitArrange object and adds that object in a dictionary
463            with key Uid.
464            @param data: data added
465            @param Uid: unique key corresponding to a fitArrange object with data
[ca6d914]466        """
[f2817bb]467        if data.__class__.__name__=='Data2D':
[f8ce013]468            fitdata=FitData2D(data)
469        else:
[b461b6d7]470            fitdata=FitData1D(data, smearer)
[20d30e9]471       
472        fitdata.setFitRange(qmin=qmin,qmax=qmax)
[d4b0687]473        #A fitArrange is already created but contains model only at Uid
[ca6d914]474        if self.fitArrangeDict.has_key(Uid):
[f8ce013]475            self.fitArrangeDict[Uid].add_data(fitdata)
[d4b0687]476        else:
477        #no fitArrange object has been create with this Uid
478            fitproblem= FitArrange()
[f8ce013]479            fitproblem.add_data(fitdata)
[ca6d914]480            self.fitArrangeDict[Uid]=fitproblem   
[20d30e9]481   
[d4b0687]482    def get_model(self,Uid):
483        """
484            @param Uid: Uid is key in the dictionary containing the model to return
485            @return  a model at this uid or None if no FitArrange element was created
486            with this Uid
487        """
[ca6d914]488        if self.fitArrangeDict.has_key(Uid):
489            return self.fitArrangeDict[Uid].get_model()
[d4b0687]490        else:
491            return None
492   
493    def remove_Fit_Problem(self,Uid):
494        """remove   fitarrange in Uid"""
[ca6d914]495        if self.fitArrangeDict.has_key(Uid):
496            del self.fitArrangeDict[Uid]
[a9e04aa]497           
498    def select_problem_for_fit(self,Uid,value):
499        """
500            select a couple of model and data at the Uid position in dictionary
501            and set in self.selected value to value
502            @param value: the value to allow fitting. can only have the value one or zero
503        """
504        if self.fitArrangeDict.has_key(Uid):
505             self.fitArrangeDict[Uid].set_to_fit( value)
[eef2e0ed]506             
507             
[a9e04aa]508    def get_problem_to_fit(self,Uid):
509        """
510            return the self.selected value of the fit problem of Uid
511           @param Uid: the Uid of the problem
512        """
513        if self.fitArrangeDict.has_key(Uid):
514             self.fitArrangeDict[Uid].get_to_fit()
[4c718654]515   
[d4b0687]516class FitArrange:
517    def __init__(self):
518        """
519            Class FitArrange contains a set of data for a given model
520            to perform the Fit.FitArrange must contain exactly one model
521            and at least one data for the fit to be performed.
522            model: the model selected by the user
523            Ldata: a list of data what the user wants to fit
524           
525        """
526        self.model = None
527        self.dList =[]
[aed7c57]528        self.pars=[]
[a9e04aa]529        #self.selected  is zero when this fit problem is not schedule to fit
530        #self.selected is 1 when schedule to fit
531        self.selected = 0
[d4b0687]532       
533    def set_model(self,model):
534        """
535            set_model save a copy of the model
536            @param model: the model being set
537        """
538        self.model = model
539       
540    def add_data(self,data):
541        """
542            add_data fill a self.dList with data to fit
543            @param data: Data to add in the list 
544        """
545        if not data in self.dList:
546            self.dList.append(data)
547           
548    def get_model(self):
549        """ @return: saved model """
550        return self.model   
551     
552    def get_data(self):
553        """ @return:  list of data dList"""
[7d0c1a8]554        #return self.dList
555        return self.dList[0] 
[d4b0687]556     
557    def remove_data(self,data):
558        """
559            Remove one element from the list
560            @param data: Data to remove from dList
561        """
562        if data in self.dList:
563            self.dList.remove(data)
[a9e04aa]564    def set_to_fit (self, value=0):
565        """
566           set self.selected to 0 or 1  for other values raise an exception
567           @param value: integer between 0 or 1
568        """
569        self.selected= value
570       
571    def get_to_fit(self):
572        """
573            @return self.selected value
574        """
575        return self.selected
[94b44293]576   
[4c718654]577
578
579   
Note: See TracBrowser for help on using the repository browser.