source: sasview/park_integration/AbstractFitEngine.py @ c09ac449

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since c09ac449 was 09975cbb, checked in by Jae Cho <jhjcho@…>, 16 years ago

Skip all Q=0 point on fitting

  • Property mode set to 100644
File size: 20.2 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
211       
212        ## Min Q-value
213        #Skip the Q=0 point.
214        if min (self.data.x) ==0.0:
215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
218        ## Max Q-value
219        self.qmax= max (self.data.x)
220       
221       
222    def setFitRange(self,qmin=None,qmax=None):
223        """ to set the fit range"""
224       
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        #ToDo: Fix this.
227        if qmin==0.0:
228            self.qmin = min(self.data.x[self.data.x!=0])
229        elif qmin!=None:                       
230            self.qmin = qmin           
231
232        if qmax !=None:
233            self.qmax = qmax
234       
235       
236    def getFitRange(self):
237        """
238            @return the range of data.x to fit
239        """
240        return self.qmin, self.qmax
241     
242     
243    def residuals(self, fn):
244        """
245            Compute residuals.
246           
247            If self.smearer has been set, use if to smear
248            the data before computing chi squared.
249           
250            @param fn: function that return model value
251            @return residuals
252        """
253        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
254        if self.dy ==None or self.dy==[]:
255            dy= numpy.zeros(len(y)) 
256        else:
257            dy= copy.deepcopy(self.dy)
258            dy= numpy.asarray(dy)
259        dy[dy==0]=1
260        idx = (x>=self.qmin) & (x <= self.qmax)
261 
262        # Compute theory data f(x)
263        fx = numpy.zeros(len(x))
264        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
265       
266        # Smear theory data
267     
268        if self.smearer is not None:
269            fx = self.smearer(fx)
270       
271        # Sanity check
272        if numpy.size(dy) < numpy.size(x):
273            raise RuntimeError, "FitData1D: invalid error array"
274        return (y[idx] - fx[idx])/dy[idx]
275     
276 
277       
278    def residuals_deriv(self, model, pars=[]):
279        """
280            @return residuals derivatives .
281            @note: in this case just return empty array
282        """
283        return []
284   
285   
286class FitData2D(object):
287    """ Wrapper class  for SANS data """
288    def __init__(self,sans_data2d):
289        """
290            Data can be initital with a data (sans plottable)
291            or with vectors.
292        """
293        self.data=sans_data2d
294        self.image = sans_data2d.data
295        self.err_image = sans_data2d.err_data
296        self.x_bins= sans_data2d.x_bins
297        self.y_bins= sans_data2d.y_bins
298       
299        x = max(self.data.xmin, self.data.xmax)
300        y = max(self.data.ymin, self.data.ymax)
301       
302        ## fitting range
303        self.qmin = 1e-16
304        self.qmax = math.sqrt(x*x +y*y)
305       
306       
307       
308    def setFitRange(self,qmin=None,qmax=None):
309        """ to set the fit range"""
310        if qmin==0.0:
311            self.qmin = 1e-16
312        elif qmin!=None:                       
313            self.qmin = qmin           
314        if qmax!=None:
315            self.qmax= qmax
316     
317       
318    def getFitRange(self):
319        """
320            @return the range of data.x to fit
321        """
322        return self.qmin, self.qmax
323     
324     
325    def residuals(self, fn):
326        """ @param fn: function that return model value
327            @return residuals
328        """
329        res=[]
330        if self.err_image== None or self.err_image ==[]:
331            err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
332        else:
333            err_image = copy.deepcopy(self.err_image)
334           
335        err_image[err_image==0]=1
336        for i in range(len(self.x_bins)):
337            for j in range(len(self.y_bins)):
338                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
339                radius= math.sqrt(temp)
340                if self.qmin <= radius and radius <= self.qmax:
341                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
342                            /err_image[j][i] )
343       
344        return numpy.array(res)
345       
346         
347    def residuals_deriv(self, model, pars=[]):
348        """
349            @return residuals derivatives .
350            @note: in this case just return empty array
351        """
352        return []
353   
354class sansAssembly:
355    """
356         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
357    """
358    def __init__(self,paramlist,Model=None , Data=None):
359        """
360            @param Model: the model wrapper fro sans -model
361            @param Data: the data wrapper for sans data
362        """
363        self.model = Model
364        self.data  = Data
365        self.paramlist=paramlist
366        self.res=[]
367    def chisq(self, params):
368        """
369            Calculates chi^2
370            @param params: list of parameter values
371            @return: chi^2
372        """
373        sum = 0
374        for item in self.res:
375            sum += item*item
376       
377        return sum/ len(self.res)
378   
379    def __call__(self,params):
380        """
381            Compute residuals
382            @param params: value of parameters to fit
383        """
384        #import thread
385        self.model.setParams(self.paramlist,params)
386        self.res= self.data.residuals(self.model.eval)
387        #print "residuals",thread.get_ident()
388        return self.res
389   
390class FitEngine:
391    def __init__(self):
392        """
393            Base class for scipy and park fit engine
394        """
395        #List of parameter names to fit
396        self.paramList=[]
397        #Dictionnary of fitArrange element (fit problems)
398        self.fitArrangeDict={}
399       
400    def _concatenateData(self, listdata=[]):
401        """ 
402            _concatenateData method concatenates each fields of all data contains ins listdata.
403            @param listdata: list of data
404            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
405             of data concatenanted
406            @raise: if listdata is empty  will return None
407            @raise: if data in listdata don't contain dy field ,will create an error
408            during fitting
409        """
410        #TODO: we have to refactor the way we handle data.
411        # We should move away from plottables and move towards the Data1D objects
412        # defined in DataLoader. Data1D allows data manipulations, which should be
413        # used to concatenate.
414        # In the meantime we should switch off the concatenation.
415        #if len(listdata)>1:
416        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
417        #return listdata[0]
418       
419        if listdata==[]:
420            raise ValueError, " data list missing"
421        else:
422            xtemp=[]
423            ytemp=[]
424            dytemp=[]
425            self.mini=None
426            self.maxi=None
427               
428            for item in listdata:
429                data=item.data
430                mini,maxi=data.getFitRange()
431                if self.mini==None and self.maxi==None:
432                    self.mini=mini
433                    self.maxi=maxi
434                else:
435                    if mini < self.mini:
436                        self.mini=mini
437                    if self.maxi < maxi:
438                        self.maxi=maxi
439                       
440                   
441                for i in range(len(data.x)):
442                    xtemp.append(data.x[i])
443                    ytemp.append(data.y[i])
444                    if data.dy is not None and len(data.dy)==len(data.y):   
445                        dytemp.append(data.dy[i])
446                    else:
447                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
448            data= Data(x=xtemp,y=ytemp,dy=dytemp)
449            data.setFitRange(self.mini, self.maxi)
450            return data
451       
452       
453    def set_model(self,model,Uid,pars=[]):
454        """
455            set a model on a given uid in the fit engine.
456            @param model: the model to fit
457            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
458            @param pars: the list of parameters to fit
459            @note : pars must contains only name of existing model's paramaters
460        """
461        if len(pars) >0:
462            if model==None:
463                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
464            else:
465                temp=[]
466                for item in pars:
467                    if item in model.model.getParamList():
468                        temp.append(item)
469                        self.paramList.append(item)
470                    else:
471                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
472                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
473                        return
474            #A fitArrange is already created but contains dList only at Uid
475            if self.fitArrangeDict.has_key(Uid):
476                self.fitArrangeDict[Uid].set_model(model)
477                self.fitArrangeDict[Uid].pars= pars
478            else:
479            #no fitArrange object has been create with this Uid
480                fitproblem = FitArrange()
481                fitproblem.set_model(model)
482                fitproblem.pars= pars
483                self.fitArrangeDict[Uid] = fitproblem
484               
485        else:
486            raise ValueError, "park_integration:missing parameters"
487   
488    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
489        """ Receives plottable, creates a list of data to fit,set data
490            in a FitArrange object and adds that object in a dictionary
491            with key Uid.
492            @param data: data added
493            @param Uid: unique key corresponding to a fitArrange object with data
494        """
495        if data.__class__.__name__=='Data2D':
496            fitdata=FitData2D(data)
497        else:
498            fitdata=FitData1D(data, smearer)
499       
500        fitdata.setFitRange(qmin=qmin,qmax=qmax)
501        #A fitArrange is already created but contains model only at Uid
502        if self.fitArrangeDict.has_key(Uid):
503            self.fitArrangeDict[Uid].add_data(fitdata)
504        else:
505        #no fitArrange object has been create with this Uid
506            fitproblem= FitArrange()
507            fitproblem.add_data(fitdata)
508            self.fitArrangeDict[Uid]=fitproblem   
509   
510    def get_model(self,Uid):
511        """
512            @param Uid: Uid is key in the dictionary containing the model to return
513            @return  a model at this uid or None if no FitArrange element was created
514            with this Uid
515        """
516        if self.fitArrangeDict.has_key(Uid):
517            return self.fitArrangeDict[Uid].get_model()
518        else:
519            return None
520   
521    def remove_Fit_Problem(self,Uid):
522        """remove   fitarrange in Uid"""
523        if self.fitArrangeDict.has_key(Uid):
524            del self.fitArrangeDict[Uid]
525           
526    def select_problem_for_fit(self,Uid,value):
527        """
528            select a couple of model and data at the Uid position in dictionary
529            and set in self.selected value to value
530            @param value: the value to allow fitting. can only have the value one or zero
531        """
532        if self.fitArrangeDict.has_key(Uid):
533             self.fitArrangeDict[Uid].set_to_fit( value)
534             
535             
536    def get_problem_to_fit(self,Uid):
537        """
538            return the self.selected value of the fit problem of Uid
539           @param Uid: the Uid of the problem
540        """
541        if self.fitArrangeDict.has_key(Uid):
542             self.fitArrangeDict[Uid].get_to_fit()
543   
544class FitArrange:
545    def __init__(self):
546        """
547            Class FitArrange contains a set of data for a given model
548            to perform the Fit.FitArrange must contain exactly one model
549            and at least one data for the fit to be performed.
550            model: the model selected by the user
551            Ldata: a list of data what the user wants to fit
552           
553        """
554        self.model = None
555        self.dList =[]
556        self.pars=[]
557        #self.selected  is zero when this fit problem is not schedule to fit
558        #self.selected is 1 when schedule to fit
559        self.selected = 0
560       
561    def set_model(self,model):
562        """
563            set_model save a copy of the model
564            @param model: the model being set
565        """
566        self.model = model
567       
568    def add_data(self,data):
569        """
570            add_data fill a self.dList with data to fit
571            @param data: Data to add in the list 
572        """
573        if not data in self.dList:
574            self.dList.append(data)
575           
576    def get_model(self):
577        """ @return: saved model """
578        return self.model   
579     
580    def get_data(self):
581        """ @return:  list of data dList"""
582        #return self.dList
583        return self.dList[0] 
584     
585    def remove_data(self,data):
586        """
587            Remove one element from the list
588            @param data: Data to remove from dList
589        """
590        if data in self.dList:
591            self.dList.remove(data)
592    def set_to_fit (self, value=0):
593        """
594           set self.selected to 0 or 1  for other values raise an exception
595           @param value: integer between 0 or 1
596        """
597        self.selected= value
598       
599    def get_to_fit(self):
600        """
601            @return self.selected value
602        """
603        return self.selected
604   
605
606
607   
Note: See TracBrowser for help on using the repository browser.