source: sasview/park_integration/AbstractFitEngine.py @ a7abdb1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since a7abdb1 was a7abdb1, checked in by Jae Cho <jhjcho@…>, 15 years ago

fixed a bug in anisotropic 2D fit

  • Property mode set to 100644
File size: 20.4 KB
Line 
1
2import park,numpy,math
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        return lo,hi
48       
49   
50    def _setrange(self,r):
51        """
52            override _setrange of park parameter
53            @param r: the value of the range to set
54        """
55        self._model.details[self.name][1:] = r
56    range = property(_getrange,_setrange)
57   
58class Model(park.Model):
59    """
60        PARK wrapper for SANS models.
61    """
62    def __init__(self, sans_model, **kw):
63        """
64            @param sans_model: the sans model to wrap using park interface
65        """
66        park.Model.__init__(self, **kw)
67        self.model = sans_model
68        self.name = sans_model.name
69        #list of parameters names
70        self.sansp = sans_model.getParamList()
71        #list of park parameter
72        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
73        #list of parameterset
74        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
75        self.pars=[]
76 
77 
78    def getParams(self,fitparams):
79        """
80            return a list of value of paramter to fit
81            @param fitparams: list of paramaters name to fit
82        """
83        list=[]
84        self.pars=[]
85        self.pars=fitparams
86        for item in fitparams:
87            for element in self.parkp:
88                 if element.name ==str(item):
89                     list.append(element.value)
90        return list
91   
92   
93    def setParams(self,paramlist, params):
94        """
95            Set value for parameters to fit
96            @param params: list of value for parameters to fit
97        """
98        try:
99            for i in range(len(self.parkp)):
100                for j in range(len(paramlist)):
101                    if self.parkp[i].name==paramlist[j]:
102                        self.parkp[i].value = params[j]
103                        self.model.setParam(self.parkp[i].name,params[j])
104        except:
105            raise
106 
107    def eval(self,x):
108        """
109            override eval method of park model.
110            @param x: the x value used to compute a function
111        """
112        return self.model.runXY(x)
113   
114   
115
116
117class Data(object):
118    """ Wrapper class  for SANS data """
119    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
120        """
121            Data can be initital with a data (sans plottable)
122            or with vectors.
123        """
124        if  sans_data !=None:
125            self.x= sans_data.x
126            self.y= sans_data.y
127            self.dx= sans_data.dx
128            self.dy= sans_data.dy
129           
130        elif (x!=None and y!=None and dy!=None):
131                self.x=x
132                self.y=y
133                self.dx=dx
134                self.dy=dy
135        else:
136            raise ValueError,\
137            "Data is missing x, y or dy, impossible to compute residuals later on"
138        self.qmin=None
139        self.qmax=None
140       
141       
142    def setFitRange(self,mini=None,maxi=None):
143        """ to set the fit range"""
144        self.qmin=mini
145        self.qmax=maxi
146       
147       
148    def getFitRange(self):
149        """
150            @return the range of data.x to fit
151        """
152        return self.qmin, self.qmax
153     
154     
155    def residuals(self, fn):
156        """ @param fn: function that return model value
157            @return residuals
158        """
159       
160        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
161        if self.qmin==None and self.qmax==None: 
162            fx =numpy.asarray([fn(v) for v in x])
163            return (y - fx)/dy
164        else:
165            idx = (x>=self.qmin) & (x <= self.qmax)
166            fx = numpy.asarray([fn(item)for item in x[idx ]])
167            return (y[idx] - fx)/dy[idx]
168       
169    def residuals_deriv(self, model, pars=[]):
170        """
171            @return residuals derivatives .
172            @note: in this case just return empty array
173        """
174        return []
175class FitData1D(object):
176    """ Wrapper class  for SANS data """
177    def __init__(self,sans_data1d, smearer=None):
178        """
179            Data can be initital with a data (sans plottable)
180            or with vectors.
181           
182            self.smearer is an object of class QSmearer or SlitSmearer
183            that will smear the theory data (slit smearing or resolution
184            smearing) when set.
185           
186            The proper way to set the smearing object would be to
187            do the following:
188           
189            from DataLoader.qsmearing import smear_selection
190            fitdata1d = FitData1D(some_data)
191            fitdata1d.smearer = smear_selection(some_data)
192           
193            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
194           
195            Setting it back to None will turn smearing off.
196           
197        """
198       
199        self.smearer = smearer
200     
201        # Initialize from Data1D object
202        self.data=sans_data1d
203        self.x= sans_data1d.x
204        self.y= sans_data1d.y
205        self.dx= sans_data1d.dx
206        self.dy= sans_data1d.dy
207       
208        ## Min Q-value
209        self.qmin=None
210        ## Max Q-value
211        self.qmax=None
212       
213       
214    def setFitRange(self,qmin=None,qmax=None,ymin=None,ymax=None,):
215        """ to set the fit range"""
216        self.qmin=qmin
217        self.qmax=qmax
218       
219       
220    def getFitRange(self):
221        """
222            @return the range of data.x to fit
223        """
224        return self.qmin, self.qmax
225     
226     
227    def residuals(self, fn):
228        """
229            Compute residuals.
230           
231            If self.smearer has been set, use if to smear
232            the data before computing chi squared.
233           
234            @param fn: function that return model value
235            @return residuals
236        """
237        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
238           
239        # Find entries to consider
240        if self.qmin==None and self.qmax==None:
241            idx = Ellipsis
242        else:
243            idx = (x>=self.qmin) & (x <= self.qmax)
244                       
245        # Compute theory data f(x)
246        fx = numpy.zeros(len(x))
247        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
248       
249        # Smear theory data
250        if self.smearer is not None:
251            fx = self.smearer(fx)
252        # Sanity check
253        if numpy.size(dy) < numpy.size(x):
254            raise RuntimeError, "FitData1D: invalid error array"
255                           
256        return (y[idx] - fx[idx])/dy[idx]
257     
258    def residuals_old(self, fn):
259        """ @param fn: function that return model value
260            @return residuals
261        """
262        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
263        if self.qmin==None and self.qmax==None: 
264            fx =numpy.asarray([fn(v) for v in x])
265            return (y - fx)/dy
266        else:
267            idx = (x>=self.qmin) & (x <= self.qmax)
268            fx = numpy.asarray([fn(item)for item in x[idx ]])
269            return (y[idx] - fx)/dy[idx]
270       
271    def residuals_deriv(self, model, pars=[]):
272        """
273            @return residuals derivatives .
274            @note: in this case just return empty array
275        """
276        return []
277   
278   
279class FitData2D(object):
280    """ Wrapper class  for SANS data """
281    def __init__(self,sans_data2d):
282        """
283            Data can be initital with a data (sans plottable)
284            or with vectors.
285        """
286        self.data=sans_data2d
287        self.image = sans_data2d.data
288        self.err_image = sans_data2d.err_data
289        self.x_bins= sans_data2d.x_bins
290        self.y_bins= sans_data2d.y_bins
291       
292        self.xmin= self.data.xmin
293        self.xmax= self.data.xmax
294        self.ymin= self.data.ymin
295        self.ymax= self.data.ymax
296       
297       
298    def setFitRange(self,qmin=None,qmax=None,ymin=None,ymax=None):
299        """ to set the fit range"""
300        self.xmin= qmin
301        self.xmax= qmax
302        self.ymin= ymin
303        self.ymax= ymax
304       
305    def getFitRange(self):
306        """
307            @return the range of data.x to fit
308        """
309        return self.xmin, self.xmax,self.ymin, self.ymax
310     
311     
312    def residuals(self, fn):
313        """ @param fn: function that return model value
314            @return residuals
315        """
316        res=[]
317        if self.xmin==None:        #Here we define that xmin = qmin >=0 and xmax=qmax>=qmain
318            self.xmin= 0 #self.data.xmin
319        if self.xmax==None:
320            self.xmax= self.data.xmax
321        if self.ymin==None:
322            self.ymin= self.data.ymin
323        if self.ymax==None:
324            self.ymax= self.data.ymax
325        for i in range(len(self.y_bins)):
326            #if self.y_bins[i]>= self.ymin and self.y_bins[i]<= self.ymax:
327            for j in range(len(self.x_bins)):
328                 if math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)>=math.pow(self.xmin,2):
329                     if math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)<=math.pow(self.xmax,2):
330                         #if self.x_bins[j]>= self.xmin and self.x_bins[j]<= self.xmax:               
331                        res.append( (self.image[i][j]- fn([self.x_bins[i],self.y_bins[j]]))\
332                            /self.err_image[i][j] )
333       
334        return numpy.array(res)
335       
336         
337    def residuals_deriv(self, model, pars=[]):
338        """
339            @return residuals derivatives .
340            @note: in this case just return empty array
341        """
342        return []
343   
344class sansAssembly:
345    """
346         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
347    """
348   
349    def __init__(self,paramlist,Model=None , Data=None):
350        """
351            @param Model: the model wrapper fro sans -model
352            @param Data: the data wrapper for sans data
353        """
354        self.model = Model
355        self.data  = Data
356        self.paramlist=paramlist
357        self.res=[]
358    def chisq(self, params):
359        """
360            Calculates chi^2
361            @param params: list of parameter values
362            @return: chi^2
363        """
364        sum = 0
365        for item in self.res:
366            sum += item*item
367        #print "length of data =",len(self.res)
368        return sum/ len(self.res)
369    def __call__(self,params):
370        """
371            Compute residuals
372            @param params: value of parameters to fit
373        """
374        #import thread
375        self.model.setParams(self.paramlist,params)
376        self.res= self.data.residuals(self.model.eval)
377        #print "residuals",thread.get_ident()
378        return self.res
379   
380class FitEngine:
381    def __init__(self):
382        """
383            Base class for scipy and park fit engine
384        """
385        #List of parameter names to fit
386        self.paramList=[]
387        #Dictionnary of fitArrange element (fit problems)
388        self.fitArrangeDict={}
389       
390    def _concatenateData(self, listdata=[]):
391        """ 
392            _concatenateData method concatenates each fields of all data contains ins listdata.
393            @param listdata: list of data
394            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
395             of data concatenanted
396            @raise: if listdata is empty  will return None
397            @raise: if data in listdata don't contain dy field ,will create an error
398            during fitting
399        """
400        #TODO: we have to refactor the way we handle data.
401        # We should move away from plottables and move towards the Data1D objects
402        # defined in DataLoader. Data1D allows data manipulations, which should be
403        # used to concatenate.
404        # In the meantime we should switch off the concatenation.
405        #if len(listdata)>1:
406        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
407        #return listdata[0]
408       
409        if listdata==[]:
410            raise ValueError, " data list missing"
411        else:
412            xtemp=[]
413            ytemp=[]
414            dytemp=[]
415            dxtemp=[]
416            self.mini=None
417            self.maxi=None
418               
419            for item in listdata:
420                data=item.data
421                mini,maxi=data.getFitRange()
422                if self.mini==None and self.maxi==None:
423                    self.mini=mini
424                    self.maxi=maxi
425                else:
426                    if mini < self.mini:
427                        self.mini=mini
428                    if self.maxi < maxi:
429                        self.maxi=maxi
430                       
431                   
432                for i in range(len(data.x)):
433                    xtemp.append(data.x[i])
434                    ytemp.append(data.y[i])
435                    if data.dy is not None and len(data.dy)==len(data.y):   
436                        dytemp.append(data.dy[i])
437                    else:
438                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
439                    if data.dx is not None and len(data.dx)==len(data.x):   
440                        dxtemp.append(data.dx[i])
441                    else:
442                        raise RuntimeError, "Fit._concatenateData: dQ-errors missing"
443            data= Data(x=xtemp,y=ytemp,dy=dytemp, dx=dxtemp)
444            data.setFitRange(self.mini, self.maxi)
445            return data
446       
447       
448    def set_model(self,model,Uid,pars=[]):
449        """
450            set a model on a given uid in the fit engine.
451            @param model: the model to fit
452            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
453            @param pars: the list of parameters to fit
454            @note : pars must contains only name of existing model's paramaters
455        """
456        if len(pars) >0:
457            if model==None:
458                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
459            else:
460                for item in pars:
461                    if item in model.model.getParamList():
462                        self.paramList.append(item)
463                    else:
464                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
465                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
466                        return
467            #A fitArrange is already created but contains dList only at Uid
468            if self.fitArrangeDict.has_key(Uid):
469                self.fitArrangeDict[Uid].set_model(model)
470            else:
471            #no fitArrange object has been create with this Uid
472                fitproblem = FitArrange()
473                fitproblem.set_model(model)
474                self.fitArrangeDict[Uid] = fitproblem
475        else:
476            raise ValueError, "park_integration:missing parameters"
477   
478    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None,ymin=None,ymax=None):
479        """ Receives plottable, creates a list of data to fit,set data
480            in a FitArrange object and adds that object in a dictionary
481            with key Uid.
482            @param data: data added
483            @param Uid: unique key corresponding to a fitArrange object with data
484        """
485        if data.__class__.__name__=='Data2D':
486            fitdata=FitData2D(data)
487        else:
488            fitdata=FitData1D(data, smearer)
489        fitdata.setFitRange(qmin=qmin,qmax=qmax, ymin=ymin,ymax=ymax)
490        #A fitArrange is already created but contains model only at Uid
491        if self.fitArrangeDict.has_key(Uid):
492            self.fitArrangeDict[Uid].add_data(fitdata)
493        else:
494        #no fitArrange object has been create with this Uid
495            fitproblem= FitArrange()
496            fitproblem.add_data(fitdata)
497            self.fitArrangeDict[Uid]=fitproblem   
498
499    def get_model(self,Uid):
500        """
501            @param Uid: Uid is key in the dictionary containing the model to return
502            @return  a model at this uid or None if no FitArrange element was created
503            with this Uid
504        """
505        if self.fitArrangeDict.has_key(Uid):
506            return self.fitArrangeDict[Uid].get_model()
507        else:
508            return None
509   
510    def remove_Fit_Problem(self,Uid):
511        """remove   fitarrange in Uid"""
512        if self.fitArrangeDict.has_key(Uid):
513            del self.fitArrangeDict[Uid]
514           
515    def select_problem_for_fit(self,Uid,value):
516        """
517            select a couple of model and data at the Uid position in dictionary
518            and set in self.selected value to value
519            @param value: the value to allow fitting. can only have the value one or zero
520        """
521        if self.fitArrangeDict.has_key(Uid):
522             self.fitArrangeDict[Uid].set_to_fit( value)
523    def get_problem_to_fit(self,Uid):
524        """
525            return the self.selected value of the fit problem of Uid
526           @param Uid: the Uid of the problem
527        """
528        if self.fitArrangeDict.has_key(Uid):
529             self.fitArrangeDict[Uid].get_to_fit()
530   
531class FitArrange:
532    def __init__(self):
533        """
534            Class FitArrange contains a set of data for a given model
535            to perform the Fit.FitArrange must contain exactly one model
536            and at least one data for the fit to be performed.
537            model: the model selected by the user
538            Ldata: a list of data what the user wants to fit
539           
540        """
541        self.model = None
542        self.dList =[]
543        #self.selected  is zero when this fit problem is not schedule to fit
544        #self.selected is 1 when schedule to fit
545        self.selected = 0
546       
547    def set_model(self,model):
548        """
549            set_model save a copy of the model
550            @param model: the model being set
551        """
552        self.model = model
553       
554    def add_data(self,data):
555        """
556            add_data fill a self.dList with data to fit
557            @param data: Data to add in the list 
558        """
559        if not data in self.dList:
560            self.dList.append(data)
561           
562    def get_model(self):
563        """ @return: saved model """
564        return self.model   
565     
566    def get_data(self):
567        """ @return:  list of data dList"""
568        #return self.dList
569        return self.dList[0] 
570     
571    def remove_data(self,data):
572        """
573            Remove one element from the list
574            @param data: Data to remove from dList
575        """
576        if data in self.dList:
577            self.dList.remove(data)
578    def set_to_fit (self, value=0):
579        """
580           set self.selected to 0 or 1  for other values raise an exception
581           @param value: integer between 0 or 1
582        """
583        self.selected= value
584       
585    def get_to_fit(self):
586        """
587            @return self.selected value
588        """
589        return self.selected
590   
591
592
593   
Note: See TracBrowser for help on using the repository browser.