source: sasview/park_integration/AbstractFitEngine.py @ 12b76cf

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 12b76cf was 12b76cf, checked in by Jae Cho <jhjcho@…>, 15 years ago

updated lo hi due to the changes in model.details

  • Property mode set to 100644
File size: 21.1 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3from DataLoader.data_info import Data1D
4from DataLoader.data_info import Data2D
5class SansParameter(park.Parameter):
6    """
7        SANS model parameters for use in the PARK fitting service.
8        The parameter attribute value is redirected to the underlying
9        parameter value in the SANS model.
10    """
11    def __init__(self, name, model):
12        """
13            @param name: the name of the model parameter
14            @param model: the sans model to wrap as a park model
15        """
16        self._model, self._name = model,name
17        #set the value for the parameter of the given name
18        self.set(model.getParam(name))
19         
20    def _getvalue(self):
21        """
22            override the _getvalue of park parameter
23            @return value the parameter associates with self.name
24        """
25        return self._model.getParam(self.name)
26   
27    def _setvalue(self,value):
28        """
29            override the _setvalue pf park parameter
30            @param value: the value to set on a given parameter
31        """
32        self._model.setParam(self.name, value)
33       
34    value = property(_getvalue,_setvalue)
35   
36    def _getrange(self):
37        """
38            Override _getrange of park parameter
39            return the range of parameter
40        """
41        #if not  self.name in self._model.getDispParamList():
42        lo,hi = self._model.details[self.name][1:3]
43        if lo is None: lo = -numpy.inf
44        if hi is None: hi = numpy.inf
45        #else:
46            #lo,hi = self._model.details[self.name][1:]
47            #if lo is None: lo = -numpy.inf
48            #if hi is None: hi = numpy.inf
49        if lo >= hi:
50            raise ValueError,"wrong fit range for parameters"
51       
52        return lo,hi
53   
54    def _setrange(self,r):
55        """
56            override _setrange of park parameter
57            @param r: the value of the range to set
58        """
59        self._model.details[self.name][1:3] = r
60    range = property(_getrange,_setrange)
61   
62class Model(park.Model):
63    """
64        PARK wrapper for SANS models.
65    """
66    def __init__(self, sans_model, **kw):
67        """
68            @param sans_model: the sans model to wrap using park interface
69        """
70        park.Model.__init__(self, **kw)
71        self.model = sans_model
72        self.name = sans_model.name
73        #list of parameters names
74        self.sansp = sans_model.getParamList()
75        #list of park parameter
76        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
77        #list of parameterset
78        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
79        self.pars=[]
80 
81 
82    def getParams(self,fitparams):
83        """
84            return a list of value of paramter to fit
85            @param fitparams: list of paramaters name to fit
86        """
87        list=[]
88        self.pars=[]
89        self.pars=fitparams
90        for item in fitparams:
91            for element in self.parkp:
92                 if element.name ==str(item):
93                     list.append(element.value)
94        return list
95   
96   
97    def setParams(self,paramlist, params):
98        """
99            Set value for parameters to fit
100            @param params: list of value for parameters to fit
101        """
102        try:
103            for i in range(len(self.parkp)):
104                for j in range(len(paramlist)):
105                    if self.parkp[i].name==paramlist[j]:
106                        self.parkp[i].value = params[j]
107                        self.model.setParam(self.parkp[i].name,params[j])
108        except:
109            raise
110 
111    def eval(self,x):
112        """
113            override eval method of park model.
114            @param x: the x value used to compute a function
115        """
116        try:
117                return self.model.evalDistribution(x)
118        except:
119                raise
120
121   
122class FitData1D(Data1D):
123    """
124        Wrapper class  for SANS data
125        FitData1D inherits from DataLoader.data_info.Data1D. Implements
126        a way to get residuals from data.
127    """
128    def __init__(self,x, y,dx= None, dy=None, smearer=None):
129        Data1D.__init__(self, x=numpy.array(x), y=numpy.array(y), dx=dx, dy=dy)
130        """
131            @param smearer: is an object of class QSmearer or SlitSmearer
132            that will smear the theory data (slit smearing or resolution
133            smearing) when set.
134           
135            The proper way to set the smearing object would be to
136            do the following:
137           
138            from DataLoader.qsmearing import smear_selection
139            smearer = smear_selection(some_data)
140            fitdata1d = FitData1D( x= [1,3,..,],
141                                    y= [3,4,..,8],
142                                    dx=None,
143                                    dy=[1,2...], smearer= smearer)
144           
145            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
146           
147            Setting it back to None will turn smearing off.
148           
149        """
150       
151        self.smearer = smearer
152        if dy ==None or dy==[]:
153            self.dy= numpy.zeros(len(self.y)) 
154        else:
155            self.dy= numpy.asarray(dy)
156     
157        # For fitting purposes, replace zero errors by 1
158        #TODO: check validity for the rare case where only
159        # a few points have zero errors
160        self.dy[self.dy==0]=1
161       
162        ## Min Q-value
163        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
164        if min (self.x) ==0.0 and self.x[0]==0 and not numpy.isfinite(self.y[0]):
165            self.qmin = min(self.x[self.x!=0])
166        else:                             
167            self.qmin= min (self.x)
168        ## Max Q-value
169        self.qmax = max (self.x)
170       
171        # Range used for input to smearing
172        self._qmin_unsmeared = self.qmin
173        self._qmax_unsmeared = self.qmax
174        # Identify the bin range for the unsmeared and smeared spaces
175        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
176        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
177 
178       
179       
180    def setFitRange(self,qmin=None,qmax=None):
181        """ to set the fit range"""
182        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
183        #ToDo: Fix this.
184        if qmin==0.0 and not numpy.isfinite(self.y[qmin]):
185            self.qmin = min(self.x[self.x!=0])
186        elif qmin!=None:                       
187            self.qmin = qmin           
188
189        if qmax !=None:
190            self.qmax = qmax
191           
192        # Determine the range needed in unsmeared-Q to cover
193        # the smeared Q range
194        self._qmin_unsmeared = self.qmin
195        self._qmax_unsmeared = self.qmax   
196       
197        self._first_unsmeared_bin = 0
198        self._last_unsmeared_bin  = len(self.x)-1
199       
200        if self.smearer!=None:
201            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
202            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
203            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
204           
205        # Identify the bin range for the unsmeared and smeared spaces
206        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
207        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
208 
209       
210    def getFitRange(self):
211        """
212            @return the range of data.x to fit
213        """
214        return self.qmin, self.qmax
215       
216    def residuals(self, fn):
217        """
218            Compute residuals.
219           
220            If self.smearer has been set, use if to smear
221            the data before computing chi squared.
222           
223            @param fn: function that return model value
224            @return residuals
225        """
226        # Compute theory data f(x)
227        fx= numpy.zeros(len(self.x))
228        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
229       
230        ## Smear theory data
231        if self.smearer is not None:
232            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
233       
234        ## Sanity check
235        if numpy.size(self.dy)!= numpy.size(fx):
236            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
237                                                                              numpy.size(fx))
238                                                                             
239        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
240     
241 
242       
243    def residuals_deriv(self, model, pars=[]):
244        """
245            @return residuals derivatives .
246            @note: in this case just return empty array
247        """
248        return []
249   
250   
251class FitData2D(Data2D):
252    """ Wrapper class  for SANS data """
253    def __init__(self,sans_data2d ,data=None, err_data=None):
254        Data2D.__init__(self, data= data, err_data= err_data)
255        """
256            Data can be initital with a data (sans plottable)
257            or with vectors.
258        """
259        self.x_bins_array = []
260        self.y_bins_array = []
261        self.res_err_image=[]
262        self.index_model=[]
263        self.qmin= None
264        self.qmax= None
265        self.set_data(sans_data2d )
266       
267       
268    def set_data(self, sans_data2d ):
269        """
270            Determine the correct x_bin and y_bin to fit
271        """
272        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
273                                         [1,len(sans_data2d.x_bins)])
274        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
275                                          [len(sans_data2d.y_bins),1])
276       
277        x_max = max(sans_data2d.xmin, sans_data2d.xmax)
278        y_max = max(sans_data2d.ymin, sans_data2d.ymax)
279       
280        ## fitting range
281        self.qmin = 1e-16
282        self.qmax = math.sqrt(x_max*x_max +y_max*y_max)
283        ## new error image for fitting purpose
284        if self.err_data== None or self.err_data ==[]:
285            self.res_err_data= numpy.zeros(len(self.y_bins),len(self.x_bins))
286        else:
287            self.res_err_data = copy.deepcopy(self.err_data)
288        self.res_err_data[self.res_err_data==0]=1
289       
290        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
291        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
292       
293       
294    def setFitRange(self,qmin=None,qmax=None):
295        """ to set the fit range"""
296        if qmin==0.0:
297            self.qmin = 1e-16
298        elif qmin!=None:                       
299            self.qmin = qmin           
300        if qmax!=None:
301            self.qmax= qmax
302     
303       
304    def getFitRange(self):
305        """
306            @return the range of data.x to fit
307        """
308        return self.qmin, self.qmax
309     
310    def residuals(self, fn): 
311       
312        res=self.index_model*(self.data - fn([self.x_bins_array,
313                             self.y_bins_array]))/self.res_err_data
314        return res.ravel() 
315       
316 
317    def residuals_deriv(self, model, pars=[]):
318        """
319            @return residuals derivatives .
320            @note: in this case just return empty array
321        """
322        return []
323   
324class FitAbort(Exception):
325    """
326        Exception raise to stop the fit
327    """
328    print"Creating fit abort Exception"
329
330
331class SansAssembly:
332    """
333         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
334    """
335    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
336        """
337            @param Model: the model wrapper fro sans -model
338            @param Data: the data wrapper for sans data
339        """
340        self.model = Model
341        self.data  = Data
342        self.paramlist=paramlist
343        self.curr_thread= curr_thread
344        self.res=[]
345        self.func_name="Functor"
346    def chisq(self, params):
347        """
348            Calculates chi^2
349            @param params: list of parameter values
350            @return: chi^2
351        """
352        sum = 0
353        for item in self.res:
354            sum += item*item
355        if len(self.res)==0:
356            return None
357        return sum/ len(self.res)
358   
359    def __call__(self,params):
360        """
361            Compute residuals
362            @param params: value of parameters to fit
363        """
364        #import thread
365        self.model.setParams(self.paramlist,params)
366        self.res= self.data.residuals(self.model.eval)
367        #if self.curr_thread != None :
368        #    try:
369        #        self.curr_thread.isquit()
370        #    except:
371        #        raise FitAbort,"stop leastsqr optimizer"   
372        return self.res
373   
374class FitEngine:
375    def __init__(self):
376        """
377            Base class for scipy and park fit engine
378        """
379        #List of parameter names to fit
380        self.paramList=[]
381        #Dictionnary of fitArrange element (fit problems)
382        self.fitArrangeDict={}
383       
384    def _concatenateData(self, listdata=[]):
385        """ 
386            _concatenateData method concatenates each fields of all data contains ins listdata.
387            @param listdata: list of data
388            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
389             of data concatenanted
390            @raise: if listdata is empty  will return None
391            @raise: if data in listdata don't contain dy field ,will create an error
392            during fitting
393        """
394        #TODO: we have to refactor the way we handle data.
395        # We should move away from plottables and move towards the Data1D objects
396        # defined in DataLoader. Data1D allows data manipulations, which should be
397        # used to concatenate.
398        # In the meantime we should switch off the concatenation.
399        #if len(listdata)>1:
400        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
401        #return listdata[0]
402       
403        if listdata==[]:
404            raise ValueError, " data list missing"
405        else:
406            xtemp=[]
407            ytemp=[]
408            dytemp=[]
409            self.mini=None
410            self.maxi=None
411               
412            for item in listdata:
413                data=item.data
414                mini,maxi=data.getFitRange()
415                if self.mini==None and self.maxi==None:
416                    self.mini=mini
417                    self.maxi=maxi
418                else:
419                    if mini < self.mini:
420                        self.mini=mini
421                    if self.maxi < maxi:
422                        self.maxi=maxi
423                       
424                   
425                for i in range(len(data.x)):
426                    xtemp.append(data.x[i])
427                    ytemp.append(data.y[i])
428                    if data.dy is not None and len(data.dy)==len(data.y):   
429                        dytemp.append(data.dy[i])
430                    else:
431                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
432            data= Data(x=xtemp,y=ytemp,dy=dytemp)
433            data.setFitRange(self.mini, self.maxi)
434            return data
435       
436       
437    def set_model(self,model,Uid,pars=[]):
438        """
439            set a model on a given uid in the fit engine.
440            @param model: the model to fit
441            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
442            @param pars: the list of parameters to fit
443            @note : pars must contains only name of existing model's paramaters
444        """
445        if len(pars) >0:
446            if model==None:
447                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
448            else:
449                temp=[]
450                for item in pars:
451                    if item in model.model.getParamList():
452                        temp.append(item)
453                        self.paramList.append(item)
454                    else:
455                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
456                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
457                        return
458            #A fitArrange is already created but contains dList only at Uid
459            if self.fitArrangeDict.has_key(Uid):
460                self.fitArrangeDict[Uid].set_model(model)
461                self.fitArrangeDict[Uid].pars= pars
462            else:
463            #no fitArrange object has been create with this Uid
464                fitproblem = FitArrange()
465                fitproblem.set_model(model)
466                fitproblem.pars= pars
467                self.fitArrangeDict[Uid] = fitproblem
468               
469        else:
470            raise ValueError, "park_integration:missing parameters"
471   
472    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
473        """ Receives plottable, creates a list of data to fit,set data
474            in a FitArrange object and adds that object in a dictionary
475            with key Uid.
476            @param data: data added
477            @param Uid: unique key corresponding to a fitArrange object with data
478        """
479        if data.__class__.__name__=='Data2D':
480            fitdata=FitData2D(sans_data2d=data, data=data.data, err_data= data.err_data)
481        else:
482            fitdata=FitData1D(x=data.x, y=data.y , dx= data.dx,dy=data.dy,smearer=smearer)
483       
484        fitdata.setFitRange(qmin=qmin,qmax=qmax)
485        #A fitArrange is already created but contains model only at Uid
486        if self.fitArrangeDict.has_key(Uid):
487            self.fitArrangeDict[Uid].add_data(fitdata)
488        else:
489        #no fitArrange object has been create with this Uid
490            fitproblem= FitArrange()
491            fitproblem.add_data(fitdata)
492            self.fitArrangeDict[Uid]=fitproblem   
493   
494    def get_model(self,Uid):
495        """
496            @param Uid: Uid is key in the dictionary containing the model to return
497            @return  a model at this uid or None if no FitArrange element was created
498            with this Uid
499        """
500        if self.fitArrangeDict.has_key(Uid):
501            return self.fitArrangeDict[Uid].get_model()
502        else:
503            return None
504   
505    def remove_Fit_Problem(self,Uid):
506        """remove   fitarrange in Uid"""
507        if self.fitArrangeDict.has_key(Uid):
508            del self.fitArrangeDict[Uid]
509           
510    def select_problem_for_fit(self,Uid,value):
511        """
512            select a couple of model and data at the Uid position in dictionary
513            and set in self.selected value to value
514            @param value: the value to allow fitting. can only have the value one or zero
515        """
516        if self.fitArrangeDict.has_key(Uid):
517             self.fitArrangeDict[Uid].set_to_fit( value)
518             
519             
520    def get_problem_to_fit(self,Uid):
521        """
522            return the self.selected value of the fit problem of Uid
523           @param Uid: the Uid of the problem
524        """
525        if self.fitArrangeDict.has_key(Uid):
526             self.fitArrangeDict[Uid].get_to_fit()
527   
528class FitArrange:
529    def __init__(self):
530        """
531            Class FitArrange contains a set of data for a given model
532            to perform the Fit.FitArrange must contain exactly one model
533            and at least one data for the fit to be performed.
534            model: the model selected by the user
535            Ldata: a list of data what the user wants to fit
536           
537        """
538        self.model = None
539        self.dList =[]
540        self.pars=[]
541        #self.selected  is zero when this fit problem is not schedule to fit
542        #self.selected is 1 when schedule to fit
543        self.selected = 0
544       
545    def set_model(self,model):
546        """
547            set_model save a copy of the model
548            @param model: the model being set
549        """
550        self.model = model
551       
552    def add_data(self,data):
553        """
554            add_data fill a self.dList with data to fit
555            @param data: Data to add in the list 
556        """
557        if not data in self.dList:
558            self.dList.append(data)
559           
560    def get_model(self):
561        """ @return: saved model """
562        return self.model   
563     
564    def get_data(self):
565        """ @return:  list of data dList"""
566        #return self.dList
567        return self.dList[0] 
568     
569    def remove_data(self,data):
570        """
571            Remove one element from the list
572            @param data: Data to remove from dList
573        """
574        if data in self.dList:
575            self.dList.remove(data)
576    def set_to_fit (self, value=0):
577        """
578           set self.selected to 0 or 1  for other values raise an exception
579           @param value: integer between 0 or 1
580        """
581        self.selected= value
582       
583    def get_to_fit(self):
584        """
585            @return self.selected value
586        """
587        return self.selected
Note: See TracBrowser for help on using the repository browser.