source: sasview/park_integration/AbstractFitEngine.py @ 4043c96

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 4043c96 was 4bb2917, checked in by Mathieu Doucet <doucetm@…>, 15 years ago

park_integration: refactor code using new smearing code.

  • Property mode set to 100644
File size: 20.6 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        #if not  self.name in self._model.getDispParamList():
41        lo,hi = self._model.details[self.name][1:]
42        if lo is None: lo = -numpy.inf
43        if hi is None: hi = numpy.inf
44        #else:
45            #lo,hi = self._model.details[self.name][1:]
46            #if lo is None: lo = -numpy.inf
47            #if hi is None: hi = numpy.inf
48        if lo >= hi:
49            raise ValueError,"wrong fit range for parameters"
50       
51        return lo,hi
52   
53    def _setrange(self,r):
54        """
55            override _setrange of park parameter
56            @param r: the value of the range to set
57        """
58        self._model.details[self.name][1:] = r
59    range = property(_getrange,_setrange)
60   
61class Model(park.Model):
62    """
63        PARK wrapper for SANS models.
64    """
65    def __init__(self, sans_model, **kw):
66        """
67            @param sans_model: the sans model to wrap using park interface
68        """
69        park.Model.__init__(self, **kw)
70        self.model = sans_model
71        self.name = sans_model.name
72        #list of parameters names
73        self.sansp = sans_model.getParamList()
74        #list of park parameter
75        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
76        #list of parameterset
77        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
78        self.pars=[]
79 
80 
81    def getParams(self,fitparams):
82        """
83            return a list of value of paramter to fit
84            @param fitparams: list of paramaters name to fit
85        """
86        list=[]
87        self.pars=[]
88        self.pars=fitparams
89        for item in fitparams:
90            for element in self.parkp:
91                 if element.name ==str(item):
92                     list.append(element.value)
93        return list
94   
95   
96    def setParams(self,paramlist, params):
97        """
98            Set value for parameters to fit
99            @param params: list of value for parameters to fit
100        """
101        try:
102            for i in range(len(self.parkp)):
103                for j in range(len(paramlist)):
104                    if self.parkp[i].name==paramlist[j]:
105                        self.parkp[i].value = params[j]
106                        self.model.setParam(self.parkp[i].name,params[j])
107        except:
108            raise
109 
110    def eval(self,x):
111        """
112            override eval method of park model.
113            @param x: the x value used to compute a function
114        """
115        try:
116                return self.model.evalDistribution(x)
117        except:
118                raise
119
120   
121class FitData1D(object):
122    """ Wrapper class  for SANS data """
123    def __init__(self,sans_data1d, smearer=None):
124        """
125            Data can be initital with a data (sans plottable)
126            or with vectors.
127           
128            self.smearer is an object of class QSmearer or SlitSmearer
129            that will smear the theory data (slit smearing or resolution
130            smearing) when set.
131           
132            The proper way to set the smearing object would be to
133            do the following:
134           
135            from DataLoader.qsmearing import smear_selection
136            fitdata1d = FitData1D(some_data)
137            fitdata1d.smearer = smear_selection(some_data)
138           
139            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
140           
141            Setting it back to None will turn smearing off.
142           
143        """
144       
145        self.smearer = smearer
146     
147        # Initialize from Data1D object
148        self.data=sans_data1d
149        self.x= numpy.array(sans_data1d.x)
150        self.y= numpy.array(sans_data1d.y)
151        self.dx= sans_data1d.dx
152        if sans_data1d.dy ==None or sans_data1d.dy==[]:
153            self.dy= numpy.zeros(len(y)) 
154        else:
155            self.dy= numpy.asarray(sans_data1d.dy)
156     
157        # For fitting purposes, replace zero errors by 1
158        #TODO: check validity for the rare case where only
159        # a few points have zero errors
160        self.dy[self.dy==0]=1
161       
162        ## Min Q-value
163        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
164        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
165            self.qmin = min(self.data.x[self.data.x!=0])
166        else:                             
167            self.qmin= min (self.data.x)
168        ## Max Q-value
169        self.qmax= max (self.data.x)
170       
171        # Range used for input to smearing
172        self._qmin_unsmeared = self.qmin
173        self._qmax_unsmeared = self.qmax
174        # Identify the bin range for the unsmeared and smeared spaces
175        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
176        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
177 
178       
179       
180    def setFitRange(self,qmin=None,qmax=None):
181        """ to set the fit range"""
182        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
183        #ToDo: Fix this.
184        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
185            self.qmin = min(self.data.x[self.data.x!=0])
186        elif qmin!=None:                       
187            self.qmin = qmin           
188
189        if qmax !=None:
190            self.qmax = qmax
191           
192        # Determine the range needed in unsmeared-Q to cover
193        # the smeared Q range
194        self._qmin_unsmeared = self.qmin
195        self._qmax_unsmeared = self.qmax   
196       
197        self._first_unsmeared_bin = 0
198        self._last_unsmeared_bin  = len(self.data.x)-1
199       
200        if self.smearer!=None:
201            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
202            self._qmin_unsmeared = self.data.x[self._first_unsmeared_bin]
203            self._qmax_unsmeared = self.data.x[self._last_unsmeared_bin]
204           
205        # Identify the bin range for the unsmeared and smeared spaces
206        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
207        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
208 
209       
210    def getFitRange(self):
211        """
212            @return the range of data.x to fit
213        """
214        return self.qmin, self.qmax
215       
216    def residuals(self, fn):
217        """
218            Compute residuals.
219           
220            If self.smearer has been set, use if to smear
221            the data before computing chi squared.
222           
223            @param fn: function that return model value
224            @return residuals
225        """
226        # Compute theory data f(x)
227        fx= numpy.zeros(len(self.x))
228        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
229       
230        ## Smear theory data
231        if self.smearer is not None:
232            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
233       
234        ## Sanity check
235        if numpy.size(self.dy)!= numpy.size(fx):
236            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
237                                                                              numpy.size(fx))
238                                                                             
239        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
240     
241 
242       
243    def residuals_deriv(self, model, pars=[]):
244        """
245            @return residuals derivatives .
246            @note: in this case just return empty array
247        """
248        return []
249   
250   
251class FitData2D(object):
252    """ Wrapper class  for SANS data """
253    def __init__(self,sans_data2d):
254        """
255            Data can be initital with a data (sans plottable)
256            or with vectors.
257        """
258        self.data=sans_data2d
259        self.image = sans_data2d.data
260        self.err_image = sans_data2d.err_data
261        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
262                                         [1,len(sans_data2d.x_bins)])
263        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
264                                          [len(sans_data2d.y_bins),1])
265       
266        x = max(self.data.xmin, self.data.xmax)
267        y = max(self.data.ymin, self.data.ymax)
268       
269        ## fitting range
270        self.qmin = 1e-16
271        self.qmax = math.sqrt(x*x +y*y)
272        ## new error image for fitting purpose
273        if self.err_image== None or self.err_image ==[]:
274            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
275        else:
276            self.res_err_image = copy.deepcopy(self.err_image)
277        self.res_err_image[self.err_image==0]=1
278       
279        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
280        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
281       
282       
283    def setFitRange(self,qmin=None,qmax=None):
284        """ to set the fit range"""
285        if qmin==0.0:
286            self.qmin = 1e-16
287        elif qmin!=None:                       
288            self.qmin = qmin           
289        if qmax!=None:
290            self.qmax= qmax
291     
292       
293    def getFitRange(self):
294        """
295            @return the range of data.x to fit
296        """
297        return self.qmin, self.qmax
298     
299    def residuals(self, fn): 
300       
301        res=self.index_model*(self.image - fn([self.x_bins_array,
302                             self.y_bins_array]))/self.res_err_image
303        return res.ravel() 
304       
305 
306    def residuals_deriv(self, model, pars=[]):
307        """
308            @return residuals derivatives .
309            @note: in this case just return empty array
310        """
311        return []
312   
313class FitAbort(Exception):
314    """
315        Exception raise to stop the fit
316    """
317    print"Creating fit abort Exception"
318
319
320class SansAssembly:
321    """
322         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
323    """
324    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
325        """
326            @param Model: the model wrapper fro sans -model
327            @param Data: the data wrapper for sans data
328        """
329        self.model = Model
330        self.data  = Data
331        self.paramlist=paramlist
332        self.curr_thread= curr_thread
333        self.res=[]
334        self.func_name="Functor"
335    def chisq(self, params):
336        """
337            Calculates chi^2
338            @param params: list of parameter values
339            @return: chi^2
340        """
341        sum = 0
342        for item in self.res:
343            sum += item*item
344        if len(self.res)==0:
345            return None
346        return sum/ len(self.res)
347   
348    def __call__(self,params):
349        """
350            Compute residuals
351            @param params: value of parameters to fit
352        """
353        #import thread
354        self.model.setParams(self.paramlist,params)
355        self.res= self.data.residuals(self.model.eval)
356        #if self.curr_thread != None :
357        #    try:
358        #        self.curr_thread.isquit()
359        #    except:
360        #        raise FitAbort,"stop leastsqr optimizer"   
361        return self.res
362   
363class FitEngine:
364    def __init__(self):
365        """
366            Base class for scipy and park fit engine
367        """
368        #List of parameter names to fit
369        self.paramList=[]
370        #Dictionnary of fitArrange element (fit problems)
371        self.fitArrangeDict={}
372       
373    def _concatenateData(self, listdata=[]):
374        """ 
375            _concatenateData method concatenates each fields of all data contains ins listdata.
376            @param listdata: list of data
377            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
378             of data concatenanted
379            @raise: if listdata is empty  will return None
380            @raise: if data in listdata don't contain dy field ,will create an error
381            during fitting
382        """
383        #TODO: we have to refactor the way we handle data.
384        # We should move away from plottables and move towards the Data1D objects
385        # defined in DataLoader. Data1D allows data manipulations, which should be
386        # used to concatenate.
387        # In the meantime we should switch off the concatenation.
388        #if len(listdata)>1:
389        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
390        #return listdata[0]
391       
392        if listdata==[]:
393            raise ValueError, " data list missing"
394        else:
395            xtemp=[]
396            ytemp=[]
397            dytemp=[]
398            self.mini=None
399            self.maxi=None
400               
401            for item in listdata:
402                data=item.data
403                mini,maxi=data.getFitRange()
404                if self.mini==None and self.maxi==None:
405                    self.mini=mini
406                    self.maxi=maxi
407                else:
408                    if mini < self.mini:
409                        self.mini=mini
410                    if self.maxi < maxi:
411                        self.maxi=maxi
412                       
413                   
414                for i in range(len(data.x)):
415                    xtemp.append(data.x[i])
416                    ytemp.append(data.y[i])
417                    if data.dy is not None and len(data.dy)==len(data.y):   
418                        dytemp.append(data.dy[i])
419                    else:
420                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
421            data= Data(x=xtemp,y=ytemp,dy=dytemp)
422            data.setFitRange(self.mini, self.maxi)
423            return data
424       
425       
426    def set_model(self,model,Uid,pars=[]):
427        """
428            set a model on a given uid in the fit engine.
429            @param model: the model to fit
430            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
431            @param pars: the list of parameters to fit
432            @note : pars must contains only name of existing model's paramaters
433        """
434        if len(pars) >0:
435            if model==None:
436                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
437            else:
438                temp=[]
439                for item in pars:
440                    if item in model.model.getParamList():
441                        temp.append(item)
442                        self.paramList.append(item)
443                    else:
444                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
445                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
446                        return
447            #A fitArrange is already created but contains dList only at Uid
448            if self.fitArrangeDict.has_key(Uid):
449                self.fitArrangeDict[Uid].set_model(model)
450                self.fitArrangeDict[Uid].pars= pars
451            else:
452            #no fitArrange object has been create with this Uid
453                fitproblem = FitArrange()
454                fitproblem.set_model(model)
455                fitproblem.pars= pars
456                self.fitArrangeDict[Uid] = fitproblem
457               
458        else:
459            raise ValueError, "park_integration:missing parameters"
460   
461    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
462        """ Receives plottable, creates a list of data to fit,set data
463            in a FitArrange object and adds that object in a dictionary
464            with key Uid.
465            @param data: data added
466            @param Uid: unique key corresponding to a fitArrange object with data
467        """
468        if data.__class__.__name__=='Data2D':
469            fitdata=FitData2D(data)
470        else:
471            fitdata=FitData1D(data, smearer)
472       
473        fitdata.setFitRange(qmin=qmin,qmax=qmax)
474        #A fitArrange is already created but contains model only at Uid
475        if self.fitArrangeDict.has_key(Uid):
476            self.fitArrangeDict[Uid].add_data(fitdata)
477        else:
478        #no fitArrange object has been create with this Uid
479            fitproblem= FitArrange()
480            fitproblem.add_data(fitdata)
481            self.fitArrangeDict[Uid]=fitproblem   
482   
483    def get_model(self,Uid):
484        """
485            @param Uid: Uid is key in the dictionary containing the model to return
486            @return  a model at this uid or None if no FitArrange element was created
487            with this Uid
488        """
489        if self.fitArrangeDict.has_key(Uid):
490            return self.fitArrangeDict[Uid].get_model()
491        else:
492            return None
493   
494    def remove_Fit_Problem(self,Uid):
495        """remove   fitarrange in Uid"""
496        if self.fitArrangeDict.has_key(Uid):
497            del self.fitArrangeDict[Uid]
498           
499    def select_problem_for_fit(self,Uid,value):
500        """
501            select a couple of model and data at the Uid position in dictionary
502            and set in self.selected value to value
503            @param value: the value to allow fitting. can only have the value one or zero
504        """
505        if self.fitArrangeDict.has_key(Uid):
506             self.fitArrangeDict[Uid].set_to_fit( value)
507             
508             
509    def get_problem_to_fit(self,Uid):
510        """
511            return the self.selected value of the fit problem of Uid
512           @param Uid: the Uid of the problem
513        """
514        if self.fitArrangeDict.has_key(Uid):
515             self.fitArrangeDict[Uid].get_to_fit()
516   
517class FitArrange:
518    def __init__(self):
519        """
520            Class FitArrange contains a set of data for a given model
521            to perform the Fit.FitArrange must contain exactly one model
522            and at least one data for the fit to be performed.
523            model: the model selected by the user
524            Ldata: a list of data what the user wants to fit
525           
526        """
527        self.model = None
528        self.dList =[]
529        self.pars=[]
530        #self.selected  is zero when this fit problem is not schedule to fit
531        #self.selected is 1 when schedule to fit
532        self.selected = 0
533       
534    def set_model(self,model):
535        """
536            set_model save a copy of the model
537            @param model: the model being set
538        """
539        self.model = model
540       
541    def add_data(self,data):
542        """
543            add_data fill a self.dList with data to fit
544            @param data: Data to add in the list 
545        """
546        if not data in self.dList:
547            self.dList.append(data)
548           
549    def get_model(self):
550        """ @return: saved model """
551        return self.model   
552     
553    def get_data(self):
554        """ @return:  list of data dList"""
555        #return self.dList
556        return self.dList[0] 
557     
558    def remove_data(self,data):
559        """
560            Remove one element from the list
561            @param data: Data to remove from dList
562        """
563        if data in self.dList:
564            self.dList.remove(data)
565    def set_to_fit (self, value=0):
566        """
567           set self.selected to 0 or 1  for other values raise an exception
568           @param value: integer between 0 or 1
569        """
570        self.selected= value
571       
572    def get_to_fit(self):
573        """
574            @return self.selected value
575        """
576        return self.selected
Note: See TracBrowser for help on using the repository browser.