source: sasview/park_integration/AbstractFitEngine.py @ ea290ee

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since ea290ee was 26cb768, checked in by Jae Cho <jhjcho@…>, 16 years ago

fixed 2d fitting bug by correcting bin number input order.

  • Property mode set to 100644
File size: 20.0 KB
Line 
1
2import park,numpy,math
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        lo,hi = self._model.details[self.name][1:]
41        if lo is None: lo = -numpy.inf
42        if hi is None: hi = numpy.inf
43        return lo,hi
44   
45    def _setrange(self,r):
46        """
47            override _setrange of park parameter
48            @param r: the value of the range to set
49        """
50        self._model.details[self.name][1:] = r
51    range = property(_getrange,_setrange)
52   
53class Model(park.Model):
54    """
55        PARK wrapper for SANS models.
56    """
57    def __init__(self, sans_model, **kw):
58        """
59            @param sans_model: the sans model to wrap using park interface
60        """
61        park.Model.__init__(self, **kw)
62        self.model = sans_model
63        self.name = sans_model.name
64        #list of parameters names
65        self.sansp = sans_model.getParamList()
66        #list of park parameter
67        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
68        #list of parameterset
69        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
70        self.pars=[]
71 
72 
73    def getParams(self,fitparams):
74        """
75            return a list of value of paramter to fit
76            @param fitparams: list of paramaters name to fit
77        """
78        list=[]
79        self.pars=[]
80        self.pars=fitparams
81        for item in fitparams:
82            for element in self.parkp:
83                 if element.name ==str(item):
84                     list.append(element.value)
85        return list
86   
87   
88    def setParams(self,paramlist, params):
89        """
90            Set value for parameters to fit
91            @param params: list of value for parameters to fit
92        """
93        try:
94            for i in range(len(self.parkp)):
95                for j in range(len(paramlist)):
96                    if self.parkp[i].name==paramlist[j]:
97                        self.parkp[i].value = params[j]
98                        self.model.setParam(self.parkp[i].name,params[j])
99        except:
100            raise
101 
102    def eval(self,x):
103        """
104            override eval method of park model.
105            @param x: the x value used to compute a function
106        """
107        return self.model.runXY(x)
108   
109   
110
111
112class Data(object):
113    """ Wrapper class  for SANS data """
114    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
115        """
116            Data can be initital with a data (sans plottable)
117            or with vectors.
118        """
119        if  sans_data !=None:
120            self.x= sans_data.x
121            self.y= sans_data.y
122            self.dx= sans_data.dx
123            self.dy= sans_data.dy
124           
125        elif (x!=None and y!=None and dy!=None):
126                self.x=x
127                self.y=y
128                self.dx=dx
129                self.dy=dy
130        else:
131            raise ValueError,\
132            "Data is missing x, y or dy, impossible to compute residuals later on"
133        self.qmin=None
134        self.qmax=None
135       
136       
137    def setFitRange(self,mini=None,maxi=None):
138        """ to set the fit range"""
139        self.qmin=mini
140        self.qmax=maxi
141       
142       
143    def getFitRange(self):
144        """
145            @return the range of data.x to fit
146        """
147        return self.qmin, self.qmax
148     
149     
150    def residuals(self, fn):
151        """ @param fn: function that return model value
152            @return residuals
153        """
154        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
155        if self.qmin==None and self.qmax==None: 
156            fx =numpy.asarray([fn(v) for v in x])
157            return (y - fx)/dy
158        else:
159            idx = (x>=self.qmin) & (x <= self.qmax)
160            fx = numpy.asarray([fn(item)for item in x[idx ]])
161            return (y[idx] - fx)/dy[idx]
162       
163    def residuals_deriv(self, model, pars=[]):
164        """
165            @return residuals derivatives .
166            @note: in this case just return empty array
167        """
168        return []
169class FitData1D(object):
170    """ Wrapper class  for SANS data """
171    def __init__(self,sans_data1d, smearer=None):
172        """
173            Data can be initital with a data (sans plottable)
174            or with vectors.
175           
176            self.smearer is an object of class QSmearer or SlitSmearer
177            that will smear the theory data (slit smearing or resolution
178            smearing) when set.
179           
180            The proper way to set the smearing object would be to
181            do the following:
182           
183            from DataLoader.qsmearing import smear_selection
184            fitdata1d = FitData1D(some_data)
185            fitdata1d.smearer = smear_selection(some_data)
186           
187            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
188           
189            Setting it back to None will turn smearing off.
190           
191        """
192       
193        self.smearer = smearer
194     
195        # Initialize from Data1D object
196        self.data=sans_data1d
197        self.x= sans_data1d.x
198        self.y= sans_data1d.y
199        self.dx= sans_data1d.dx
200        self.dy= sans_data1d.dy
201       
202        ## Min Q-value
203        self.qmin=None
204        ## Max Q-value
205        self.qmax=None
206       
207       
208    def setFitRange(self,qmin=None,qmax=None,ymin=None,ymax=None,):
209        """ to set the fit range"""
210        self.qmin=qmin
211        self.qmax=qmax
212       
213       
214    def getFitRange(self):
215        """
216            @return the range of data.x to fit
217        """
218        return self.qmin, self.qmax
219     
220     
221    def residuals(self, fn):
222        """
223            Compute residuals.
224           
225            If self.smearer has been set, use if to smear
226            the data before computing chi squared.
227           
228            @param fn: function that return model value
229            @return residuals
230        """
231        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
232           
233        # Find entries to consider
234        if self.qmin==None and self.qmax==None:
235            idx = Ellipsis
236        else:
237            idx = (x>=self.qmin) & (x <= self.qmax)
238                       
239        # Compute theory data f(x)
240        fx = numpy.zeros(len(x))
241        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
242       
243        # Smear theory data
244        if self.smearer is not None:
245            fx = self.smearer(fx)
246           
247        # Sanity check
248        if numpy.size(dy) < numpy.size(x):
249            raise RuntimeError, "FitData1D: invalid error array"
250                           
251        return (y[idx] - fx[idx])/dy[idx]
252     
253    def residuals_old(self, fn):
254        """ @param fn: function that return model value
255            @return residuals
256        """
257        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
258        if self.qmin==None and self.qmax==None: 
259            fx =numpy.asarray([fn(v) for v in x])
260            return (y - fx)/dy
261        else:
262            idx = (x>=self.qmin) & (x <= self.qmax)
263            fx = numpy.asarray([fn(item)for item in x[idx ]])
264            return (y[idx] - fx)/dy[idx]
265       
266    def residuals_deriv(self, model, pars=[]):
267        """
268            @return residuals derivatives .
269            @note: in this case just return empty array
270        """
271        return []
272   
273   
274class FitData2D(object):
275    """ Wrapper class  for SANS data """
276    def __init__(self,sans_data2d):
277        """
278            Data can be initital with a data (sans plottable)
279            or with vectors.
280        """
281        self.data=sans_data2d
282        self.image = sans_data2d.data
283        self.err_image = sans_data2d.err_data
284        self.x_bins= sans_data2d.x_bins
285        self.y_bins= sans_data2d.y_bins
286       
287        self.xmin= self.data.xmin
288        self.xmax= self.data.xmax
289        self.ymin= self.data.ymin
290        self.ymax= self.data.ymax
291       
292       
293    def setFitRange(self,qmin=None,qmax=None,ymin=None,ymax=None):
294        """ to set the fit range"""
295        self.xmin= qmin
296        self.xmax= qmax
297        self.ymin= ymin
298        self.ymax= ymax
299       
300    def getFitRange(self):
301        """
302            @return the range of data.x to fit
303        """
304        return self.xmin, self.xmax,self.ymin, self.ymax
305     
306     
307    def residuals(self, fn):
308        """ @param fn: function that return model value
309            @return residuals
310        """
311        res=[]
312        if self.xmin==None:        #Here we define that xmin = qmin >=0 and xmax=qmax>=qmain
313            self.xmin= 0 #self.data.xmin
314        if self.xmax==None:
315            self.xmax= self.data.xmax
316        if self.ymin==None:
317            self.ymin= self.data.ymin
318        if self.ymax==None:
319            self.ymax= self.data.ymax
320        for i in range(len(self.y_bins)):
321            #if self.y_bins[i]>= self.ymin and self.y_bins[i]<= self.ymax:
322            for j in range(len(self.x_bins)):
323                 if math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)>=math.pow(self.xmin,2):
324                     if math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)<=math.pow(self.xmax,2):
325                         #if self.x_bins[j]>= self.xmin and self.x_bins[j]<= self.xmax:               
326                        res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
327                            /self.err_image[j][i] )
328       
329        return numpy.array(res)
330       
331         
332    def residuals_deriv(self, model, pars=[]):
333        """
334            @return residuals derivatives .
335            @note: in this case just return empty array
336        """
337        return []
338   
339class sansAssembly:
340    """
341         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
342    """
343    def __init__(self,paramlist,Model=None , Data=None):
344        """
345            @param Model: the model wrapper fro sans -model
346            @param Data: the data wrapper for sans data
347        """
348        self.model = Model
349        self.data  = Data
350        self.paramlist=paramlist
351        self.res=[]
352    def chisq(self, params):
353        """
354            Calculates chi^2
355            @param params: list of parameter values
356            @return: chi^2
357        """
358        sum = 0
359        for item in self.res:
360            sum += item*item
361        #print "length of data =",len(self.res)
362        return sum/ len(self.res)
363    def __call__(self,params):
364        """
365            Compute residuals
366            @param params: value of parameters to fit
367        """
368        #import thread
369        self.model.setParams(self.paramlist,params)
370        self.res= self.data.residuals(self.model.eval)
371        #print "residuals",thread.get_ident()
372        return self.res
373   
374class FitEngine:
375    def __init__(self):
376        """
377            Base class for scipy and park fit engine
378        """
379        #List of parameter names to fit
380        self.paramList=[]
381        #Dictionnary of fitArrange element (fit problems)
382        self.fitArrangeDict={}
383       
384    def _concatenateData(self, listdata=[]):
385        """ 
386            _concatenateData method concatenates each fields of all data contains ins listdata.
387            @param listdata: list of data
388            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
389             of data concatenanted
390            @raise: if listdata is empty  will return None
391            @raise: if data in listdata don't contain dy field ,will create an error
392            during fitting
393        """
394        #TODO: we have to refactor the way we handle data.
395        # We should move away from plottables and move towards the Data1D objects
396        # defined in DataLoader. Data1D allows data manipulations, which should be
397        # used to concatenate.
398        # In the meantime we should switch off the concatenation.
399        #if len(listdata)>1:
400        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
401        #return listdata[0]
402       
403        if listdata==[]:
404            raise ValueError, " data list missing"
405        else:
406            xtemp=[]
407            ytemp=[]
408            dytemp=[]
409            self.mini=None
410            self.maxi=None
411               
412            for item in listdata:
413                data=item.data
414                mini,maxi=data.getFitRange()
415                if self.mini==None and self.maxi==None:
416                    self.mini=mini
417                    self.maxi=maxi
418                else:
419                    if mini < self.mini:
420                        self.mini=mini
421                    if self.maxi < maxi:
422                        self.maxi=maxi
423                       
424                   
425                for i in range(len(data.x)):
426                    xtemp.append(data.x[i])
427                    ytemp.append(data.y[i])
428                    if data.dy is not None and len(data.dy)==len(data.y):   
429                        dytemp.append(data.dy[i])
430                    else:
431                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
432            data= Data(x=xtemp,y=ytemp,dy=dytemp)
433            data.setFitRange(self.mini, self.maxi)
434            return data
435       
436       
437    def set_model(self,model,Uid,pars=[]):
438        """
439            set a model on a given uid in the fit engine.
440            @param model: the model to fit
441            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
442            @param pars: the list of parameters to fit
443            @note : pars must contains only name of existing model's paramaters
444        """
445        if len(pars) >0:
446            if model==None:
447                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
448            else:
449                for item in pars:
450                    if item in model.model.getParamList():
451                        self.paramList.append(item)
452                    else:
453                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
454                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
455                        return
456            #A fitArrange is already created but contains dList only at Uid
457            if self.fitArrangeDict.has_key(Uid):
458                self.fitArrangeDict[Uid].set_model(model)
459            else:
460            #no fitArrange object has been create with this Uid
461                fitproblem = FitArrange()
462                fitproblem.set_model(model)
463                self.fitArrangeDict[Uid] = fitproblem
464        else:
465            raise ValueError, "park_integration:missing parameters"
466   
467    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None,ymin=None,ymax=None):
468        """ Receives plottable, creates a list of data to fit,set data
469            in a FitArrange object and adds that object in a dictionary
470            with key Uid.
471            @param data: data added
472            @param Uid: unique key corresponding to a fitArrange object with data
473        """
474        if data.__class__.__name__=='Data2D':
475            fitdata=FitData2D(data)
476        else:
477            fitdata=FitData1D(data, smearer)
478       
479        fitdata.setFitRange(qmin=qmin,qmax=qmax, ymin=ymin,ymax=ymax)
480        #A fitArrange is already created but contains model only at Uid
481        if self.fitArrangeDict.has_key(Uid):
482            self.fitArrangeDict[Uid].add_data(fitdata)
483        else:
484        #no fitArrange object has been create with this Uid
485            fitproblem= FitArrange()
486            fitproblem.add_data(fitdata)
487            self.fitArrangeDict[Uid]=fitproblem   
488   
489    def get_model(self,Uid):
490        """
491            @param Uid: Uid is key in the dictionary containing the model to return
492            @return  a model at this uid or None if no FitArrange element was created
493            with this Uid
494        """
495        if self.fitArrangeDict.has_key(Uid):
496            return self.fitArrangeDict[Uid].get_model()
497        else:
498            return None
499   
500    def remove_Fit_Problem(self,Uid):
501        """remove   fitarrange in Uid"""
502        if self.fitArrangeDict.has_key(Uid):
503            del self.fitArrangeDict[Uid]
504           
505    def select_problem_for_fit(self,Uid,value):
506        """
507            select a couple of model and data at the Uid position in dictionary
508            and set in self.selected value to value
509            @param value: the value to allow fitting. can only have the value one or zero
510        """
511        if self.fitArrangeDict.has_key(Uid):
512             self.fitArrangeDict[Uid].set_to_fit( value)
513    def get_problem_to_fit(self,Uid):
514        """
515            return the self.selected value of the fit problem of Uid
516           @param Uid: the Uid of the problem
517        """
518        if self.fitArrangeDict.has_key(Uid):
519             self.fitArrangeDict[Uid].get_to_fit()
520   
521class FitArrange:
522    def __init__(self):
523        """
524            Class FitArrange contains a set of data for a given model
525            to perform the Fit.FitArrange must contain exactly one model
526            and at least one data for the fit to be performed.
527            model: the model selected by the user
528            Ldata: a list of data what the user wants to fit
529           
530        """
531        self.model = None
532        self.dList =[]
533        #self.selected  is zero when this fit problem is not schedule to fit
534        #self.selected is 1 when schedule to fit
535        self.selected = 0
536       
537    def set_model(self,model):
538        """
539            set_model save a copy of the model
540            @param model: the model being set
541        """
542        self.model = model
543       
544    def add_data(self,data):
545        """
546            add_data fill a self.dList with data to fit
547            @param data: Data to add in the list 
548        """
549        if not data in self.dList:
550            self.dList.append(data)
551           
552    def get_model(self):
553        """ @return: saved model """
554        return self.model   
555     
556    def get_data(self):
557        """ @return:  list of data dList"""
558        #return self.dList
559        return self.dList[0] 
560     
561    def remove_data(self,data):
562        """
563            Remove one element from the list
564            @param data: Data to remove from dList
565        """
566        if data in self.dList:
567            self.dList.remove(data)
568    def set_to_fit (self, value=0):
569        """
570           set self.selected to 0 or 1  for other values raise an exception
571           @param value: integer between 0 or 1
572        """
573        self.selected= value
574       
575    def get_to_fit(self):
576        """
577            @return self.selected value
578        """
579        return self.selected
580   
581
582
583   
Note: See TracBrowser for help on using the repository browser.