source: sasview/park_integration/AbstractFitEngine.py @ 3617aa2

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 3617aa2 was 05f14dd, checked in by Gervaise Alina <gervyh@…>, 16 years ago

checking parameters range for park model

  • Property mode set to 100644
File size: 19.3 KB
Line 
1
2import park,numpy,math
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146        self.qmin=mini
147        self.qmax=maxi
148       
149       
150    def getFitRange(self):
151        """
152            @return the range of data.x to fit
153        """
154        return self.qmin, self.qmax
155     
156     
157    def residuals(self, fn):
158        """ @param fn: function that return model value
159            @return residuals
160        """
161        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
162        if self.qmin==None and self.qmax==None: 
163            fx =numpy.asarray([fn(v) for v in x])
164            return (y - fx)/dy
165        else:
166            idx = (x>=self.qmin) & (x <= self.qmax)
167            fx = numpy.asarray([fn(item)for item in x[idx ]])
168            return (y[idx] - fx)/dy[idx]
169       
170    def residuals_deriv(self, model, pars=[]):
171        """
172            @return residuals derivatives .
173            @note: in this case just return empty array
174        """
175        return []
176class FitData1D(object):
177    """ Wrapper class  for SANS data """
178    def __init__(self,sans_data1d, smearer=None):
179        """
180            Data can be initital with a data (sans plottable)
181            or with vectors.
182           
183            self.smearer is an object of class QSmearer or SlitSmearer
184            that will smear the theory data (slit smearing or resolution
185            smearing) when set.
186           
187            The proper way to set the smearing object would be to
188            do the following:
189           
190            from DataLoader.qsmearing import smear_selection
191            fitdata1d = FitData1D(some_data)
192            fitdata1d.smearer = smear_selection(some_data)
193           
194            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
195           
196            Setting it back to None will turn smearing off.
197           
198        """
199       
200        self.smearer = smearer
201     
202        # Initialize from Data1D object
203        self.data=sans_data1d
204        self.x= sans_data1d.x
205        self.y= sans_data1d.y
206        self.dx= sans_data1d.dx
207        self.dy= sans_data1d.dy
208       
209        ## Min Q-value
210        self.qmin= min (self.data.x)
211        ## Max Q-value
212        self.qmax= max (self.data.x)
213       
214       
215    def setFitRange(self,qmin=None,qmax=None):
216        """ to set the fit range"""
217        if qmin!=None:
218            self.qmin = qmin
219        if qmax !=None:
220            self.qmax = qmax
221       
222       
223    def getFitRange(self):
224        """
225            @return the range of data.x to fit
226        """
227        return self.qmin, self.qmax
228     
229     
230    def residuals(self, fn):
231        """
232            Compute residuals.
233           
234            If self.smearer has been set, use if to smear
235            the data before computing chi squared.
236           
237            @param fn: function that return model value
238            @return residuals
239        """
240        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
241           
242        idx = (x>=self.qmin) & (x <= self.qmax)
243 
244        # Compute theory data f(x)
245        fx = numpy.zeros(len(x))
246        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
247       
248        # Smear theory data
249     
250        if self.smearer is not None:
251            fx = self.smearer(fx)
252       
253        # Sanity check
254        if numpy.size(dy) < numpy.size(x):
255            raise RuntimeError, "FitData1D: invalid error array"
256                           
257        return (y[idx] - fx[idx])/dy[idx]
258     
259 
260       
261    def residuals_deriv(self, model, pars=[]):
262        """
263            @return residuals derivatives .
264            @note: in this case just return empty array
265        """
266        return []
267   
268   
269class FitData2D(object):
270    """ Wrapper class  for SANS data """
271    def __init__(self,sans_data2d):
272        """
273            Data can be initital with a data (sans plottable)
274            or with vectors.
275        """
276        self.data=sans_data2d
277        self.image = sans_data2d.data
278        self.err_image = sans_data2d.err_data
279        self.x_bins= sans_data2d.x_bins
280        self.y_bins= sans_data2d.y_bins
281       
282        x = max(self.data.xmin, self.data.xmax)
283        y = max(self.data.ymin, self.data.ymax)
284       
285        ## fitting range
286        self.qmin = 0
287        self.qmax = math.sqrt(x*x +y*y)
288       
289       
290       
291    def setFitRange(self,qmin=None,qmax=None):
292        """ to set the fit range"""
293        if qmin!=None:
294            self.qmin= qmin
295        if qmax!=None:
296            self.qmax= qmax
297     
298       
299    def getFitRange(self):
300        """
301            @return the range of data.x to fit
302        """
303        return self.qmin, self.qmax
304     
305     
306    def residuals(self, fn):
307        """ @param fn: function that return model value
308            @return residuals
309        """
310        res=[]
311       
312        for i in range(len(self.y_bins)):
313            for j in range(len(self.x_bins)):
314                radius = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
315                if self.qmin <= radius and radius <= self.qmax:
316                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
317                            /self.err_image[j][i] )
318       
319        return numpy.array(res)
320       
321         
322    def residuals_deriv(self, model, pars=[]):
323        """
324            @return residuals derivatives .
325            @note: in this case just return empty array
326        """
327        return []
328   
329class sansAssembly:
330    """
331         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
332    """
333    def __init__(self,paramlist,Model=None , Data=None):
334        """
335            @param Model: the model wrapper fro sans -model
336            @param Data: the data wrapper for sans data
337        """
338        self.model = Model
339        self.data  = Data
340        self.paramlist=paramlist
341        self.res=[]
342    def chisq(self, params):
343        """
344            Calculates chi^2
345            @param params: list of parameter values
346            @return: chi^2
347        """
348        sum = 0
349        for item in self.res:
350            sum += item*item
351       
352        return sum/ len(self.res)
353   
354    def __call__(self,params):
355        """
356            Compute residuals
357            @param params: value of parameters to fit
358        """
359        #import thread
360        self.model.setParams(self.paramlist,params)
361        self.res= self.data.residuals(self.model.eval)
362        #print "residuals",thread.get_ident()
363        return self.res
364   
365class FitEngine:
366    def __init__(self):
367        """
368            Base class for scipy and park fit engine
369        """
370        #List of parameter names to fit
371        self.paramList=[]
372        #Dictionnary of fitArrange element (fit problems)
373        self.fitArrangeDict={}
374       
375    def _concatenateData(self, listdata=[]):
376        """ 
377            _concatenateData method concatenates each fields of all data contains ins listdata.
378            @param listdata: list of data
379            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
380             of data concatenanted
381            @raise: if listdata is empty  will return None
382            @raise: if data in listdata don't contain dy field ,will create an error
383            during fitting
384        """
385        #TODO: we have to refactor the way we handle data.
386        # We should move away from plottables and move towards the Data1D objects
387        # defined in DataLoader. Data1D allows data manipulations, which should be
388        # used to concatenate.
389        # In the meantime we should switch off the concatenation.
390        #if len(listdata)>1:
391        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
392        #return listdata[0]
393       
394        if listdata==[]:
395            raise ValueError, " data list missing"
396        else:
397            xtemp=[]
398            ytemp=[]
399            dytemp=[]
400            self.mini=None
401            self.maxi=None
402               
403            for item in listdata:
404                data=item.data
405                mini,maxi=data.getFitRange()
406                if self.mini==None and self.maxi==None:
407                    self.mini=mini
408                    self.maxi=maxi
409                else:
410                    if mini < self.mini:
411                        self.mini=mini
412                    if self.maxi < maxi:
413                        self.maxi=maxi
414                       
415                   
416                for i in range(len(data.x)):
417                    xtemp.append(data.x[i])
418                    ytemp.append(data.y[i])
419                    if data.dy is not None and len(data.dy)==len(data.y):   
420                        dytemp.append(data.dy[i])
421                    else:
422                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
423            data= Data(x=xtemp,y=ytemp,dy=dytemp)
424            data.setFitRange(self.mini, self.maxi)
425            return data
426       
427       
428    def set_model(self,model,Uid,pars=[]):
429        """
430            set a model on a given uid in the fit engine.
431            @param model: the model to fit
432            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
433            @param pars: the list of parameters to fit
434            @note : pars must contains only name of existing model's paramaters
435        """
436        if len(pars) >0:
437            if model==None:
438                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
439            else:
440                temp=[]
441                for item in pars:
442                    if item in model.model.getParamList():
443                        temp.append(item)
444                        self.paramList.append(item)
445                    else:
446                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
447                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
448                        return
449            #A fitArrange is already created but contains dList only at Uid
450            if self.fitArrangeDict.has_key(Uid):
451                self.fitArrangeDict[Uid].set_model(model)
452                self.fitArrangeDict[Uid].pars= pars
453            else:
454            #no fitArrange object has been create with this Uid
455                fitproblem = FitArrange()
456                fitproblem.set_model(model)
457                fitproblem.pars= pars
458                self.fitArrangeDict[Uid] = fitproblem
459               
460        else:
461            raise ValueError, "park_integration:missing parameters"
462   
463    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
464        """ Receives plottable, creates a list of data to fit,set data
465            in a FitArrange object and adds that object in a dictionary
466            with key Uid.
467            @param data: data added
468            @param Uid: unique key corresponding to a fitArrange object with data
469        """
470        if data.__class__.__name__=='Data2D':
471            fitdata=FitData2D(data)
472        else:
473            fitdata=FitData1D(data, smearer)
474       
475        fitdata.setFitRange(qmin=qmin,qmax=qmax)
476        #A fitArrange is already created but contains model only at Uid
477        if self.fitArrangeDict.has_key(Uid):
478            self.fitArrangeDict[Uid].add_data(fitdata)
479        else:
480        #no fitArrange object has been create with this Uid
481            fitproblem= FitArrange()
482            fitproblem.add_data(fitdata)
483            self.fitArrangeDict[Uid]=fitproblem   
484   
485    def get_model(self,Uid):
486        """
487            @param Uid: Uid is key in the dictionary containing the model to return
488            @return  a model at this uid or None if no FitArrange element was created
489            with this Uid
490        """
491        if self.fitArrangeDict.has_key(Uid):
492            return self.fitArrangeDict[Uid].get_model()
493        else:
494            return None
495   
496    def remove_Fit_Problem(self,Uid):
497        """remove   fitarrange in Uid"""
498        if self.fitArrangeDict.has_key(Uid):
499            del self.fitArrangeDict[Uid]
500           
501    def select_problem_for_fit(self,Uid,value):
502        """
503            select a couple of model and data at the Uid position in dictionary
504            and set in self.selected value to value
505            @param value: the value to allow fitting. can only have the value one or zero
506        """
507        if self.fitArrangeDict.has_key(Uid):
508             self.fitArrangeDict[Uid].set_to_fit( value)
509             
510             
511    def get_problem_to_fit(self,Uid):
512        """
513            return the self.selected value of the fit problem of Uid
514           @param Uid: the Uid of the problem
515        """
516        if self.fitArrangeDict.has_key(Uid):
517             self.fitArrangeDict[Uid].get_to_fit()
518   
519class FitArrange:
520    def __init__(self):
521        """
522            Class FitArrange contains a set of data for a given model
523            to perform the Fit.FitArrange must contain exactly one model
524            and at least one data for the fit to be performed.
525            model: the model selected by the user
526            Ldata: a list of data what the user wants to fit
527           
528        """
529        self.model = None
530        self.dList =[]
531        self.pars=[]
532        #self.selected  is zero when this fit problem is not schedule to fit
533        #self.selected is 1 when schedule to fit
534        self.selected = 0
535       
536    def set_model(self,model):
537        """
538            set_model save a copy of the model
539            @param model: the model being set
540        """
541        self.model = model
542       
543    def add_data(self,data):
544        """
545            add_data fill a self.dList with data to fit
546            @param data: Data to add in the list 
547        """
548        if not data in self.dList:
549            self.dList.append(data)
550           
551    def get_model(self):
552        """ @return: saved model """
553        return self.model   
554     
555    def get_data(self):
556        """ @return:  list of data dList"""
557        #return self.dList
558        return self.dList[0] 
559     
560    def remove_data(self,data):
561        """
562            Remove one element from the list
563            @param data: Data to remove from dList
564        """
565        if data in self.dList:
566            self.dList.remove(data)
567    def set_to_fit (self, value=0):
568        """
569           set self.selected to 0 or 1  for other values raise an exception
570           @param value: integer between 0 or 1
571        """
572        self.selected= value
573       
574    def get_to_fit(self):
575        """
576            @return self.selected value
577        """
578        return self.selected
579   
580
581
582   
Note: See TracBrowser for help on using the repository browser.