source: sasview/park_integration/AbstractFitEngine.py @ 7e752fe

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 7e752fe was 7e752fe, checked in by Gervaise Alina <gervyh@…>, 15 years ago

fix residual

  • Property mode set to 100644
File size: 21.3 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        try:
115                return self.model.evalDistribution(x)
116        except:
117                raise
118
119   
120class FitData1D(object):
121    """ Wrapper class  for SANS data """
122    def __init__(self,sans_data1d, smearer=None):
123        """
124            Data can be initital with a data (sans plottable)
125            or with vectors.
126           
127            self.smearer is an object of class QSmearer or SlitSmearer
128            that will smear the theory data (slit smearing or resolution
129            smearing) when set.
130           
131            The proper way to set the smearing object would be to
132            do the following:
133           
134            from DataLoader.qsmearing import smear_selection
135            fitdata1d = FitData1D(some_data)
136            fitdata1d.smearer = smear_selection(some_data)
137           
138            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
139           
140            Setting it back to None will turn smearing off.
141           
142        """
143       
144        self.smearer = smearer
145     
146        # Initialize from Data1D object
147        self.data=sans_data1d
148        self.x= numpy.array(sans_data1d.x)
149        self.y= numpy.array(sans_data1d.y)
150        self.dx= sans_data1d.dx
151        if sans_data1d.dy ==None or sans_data1d.dy==[]:
152            self.dy= numpy.zeros(len(y)) 
153        else:
154            self.dy= numpy.asarray(sans_data1d.dy)
155     
156        # For fitting purposes, replace zero errors by 1
157        #TODO: check validity for the rare case where only
158        # a few points have zero errors
159        self.dy[self.dy==0]=1
160       
161        ## Min Q-value
162        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
163        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
164            self.qmin = min(self.data.x[self.data.x!=0])
165        else:                             
166            self.qmin= min (self.data.x)
167        ## Max Q-value
168        self.qmax= max (self.data.x)
169       
170        # Range used for input to smearing
171        self._qmin_unsmeared = self.qmin
172        self._qmax_unsmeared = self.qmax
173        # Identify the bin range for the unsmeared and smeared spaces
174        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
175        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
176 
177       
178       
179    def setFitRange(self,qmin=None,qmax=None):
180        """ to set the fit range"""
181        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
182        #ToDo: Fix this.
183        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
184            self.qmin = min(self.data.x[self.data.x!=0])
185        elif qmin!=None:                       
186            self.qmin = qmin           
187
188        if qmax !=None:
189            self.qmax = qmax
190           
191        # Range used for input to smearing
192        self._qmin_unsmeared = self.qmin
193        self._qmax_unsmeared = self.qmax   
194       
195        # Determine the range needed in unsmeared-Q to cover
196        # the smeared Q range
197        #TODO: use the smearing matrix to determine which
198        # bin range to use
199        if self.smearer.__class__.__name__ == 'SlitSmearer':
200            self._qmin_unsmeared = min(self.data.x)
201            self._qmax_unsmeared = max(self.data.x)
202        elif self.smearer.__class__.__name__ == 'QSmearer':
203            # Take 3 sigmas as the offset between smeared and unsmeared space
204            try:
205                offset = 3.0*max(self.smearer.width)
206                self._qmin_unsmeared = max([min(self.data.x), self.qmin-offset])
207                self._qmax_unsmeared = min([max(self.data.x), self.qmax+offset])
208            except:
209                logging.error("FitData1D.setFitRange: %s" % sys.exc_value)
210        # Identify the bin range for the unsmeared and smeared spaces
211        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
212        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
213 
214       
215    def getFitRange(self):
216        """
217            @return the range of data.x to fit
218        """
219        return self.qmin, self.qmax
220       
221    def residuals(self, fn):
222        """
223            Compute residuals.
224           
225            If self.smearer has been set, use if to smear
226            the data before computing chi squared.
227           
228            @param fn: function that return model value
229            @return residuals
230        """
231        # Compute theory data f(x)
232        fx= numpy.zeros(len(self.x))
233        _first_bin = None
234        _last_bin  = None
235       
236        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
237       
238       
239        for i_x in range(len(self.x)):
240            if self.idx_unsmeared[i_x]==True:
241                # Identify first and last bin
242                #TODO: refactor this to pass q-values to the smearer
243                # and let it figure out which bin range to use
244                if _first_bin is None:
245                    _first_bin = i_x
246                else:
247                    _last_bin  = i_x
248               
249        ## Smear theory data
250        if self.smearer is not None:
251            fx = self.smearer(fx, _first_bin, _last_bin)
252       
253        ## Sanity check
254        if numpy.size(self.dy)!= numpy.size(fx):
255            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
256                                                                              numpy.size(fx))
257                                                                             
258        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
259     
260 
261       
262    def residuals_deriv(self, model, pars=[]):
263        """
264            @return residuals derivatives .
265            @note: in this case just return empty array
266        """
267        return []
268   
269   
270class FitData2D(object):
271    """ Wrapper class  for SANS data """
272    def __init__(self,sans_data2d):
273        """
274            Data can be initital with a data (sans plottable)
275            or with vectors.
276        """
277        self.data=sans_data2d
278        self.image = sans_data2d.data
279        self.err_image = sans_data2d.err_data
280        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
281                                         [1,len(sans_data2d.x_bins)])
282        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
283                                          [len(sans_data2d.y_bins),1])
284       
285        x = max(self.data.xmin, self.data.xmax)
286        y = max(self.data.ymin, self.data.ymax)
287       
288        ## fitting range
289        self.qmin = 1e-16
290        self.qmax = math.sqrt(x*x +y*y)
291        ## new error image for fitting purpose
292        if self.err_image== None or self.err_image ==[]:
293            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
294        else:
295            self.res_err_image = copy.deepcopy(self.err_image)
296        self.res_err_image[self.err_image==0]=1
297       
298        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
299        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
300       
301       
302    def setFitRange(self,qmin=None,qmax=None):
303        """ to set the fit range"""
304        if qmin==0.0:
305            self.qmin = 1e-16
306        elif qmin!=None:                       
307            self.qmin = qmin           
308        if qmax!=None:
309            self.qmax= qmax
310     
311       
312    def getFitRange(self):
313        """
314            @return the range of data.x to fit
315        """
316        return self.qmin, self.qmax
317     
318    def residuals(self, fn): 
319       
320        res=self.index_model*(self.image - fn([self.y_bins_array,
321                             self.x_bins_array]))/self.res_err_image
322        return res.ravel() 
323       
324 
325    def residuals_deriv(self, model, pars=[]):
326        """
327            @return residuals derivatives .
328            @note: in this case just return empty array
329        """
330        return []
331   
332class FitAbort(Exception):
333    """
334        Exception raise to stop the fit
335    """
336    print"Creating fit abort Exception"
337
338
339class SansAssembly:
340    """
341         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
342    """
343    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
344        """
345            @param Model: the model wrapper fro sans -model
346            @param Data: the data wrapper for sans data
347        """
348        self.model = Model
349        self.data  = Data
350        self.paramlist=paramlist
351        self.curr_thread= curr_thread
352        self.res=[]
353        self.func_name="Functor"
354    def chisq(self, params):
355        """
356            Calculates chi^2
357            @param params: list of parameter values
358            @return: chi^2
359        """
360        sum = 0
361        for item in self.res:
362            sum += item*item
363        if len(self.res)==0:
364            return None
365        return sum/ len(self.res)
366   
367    def __call__(self,params):
368        """
369            Compute residuals
370            @param params: value of parameters to fit
371        """
372        #import thread
373        self.model.setParams(self.paramlist,params)
374        self.res= self.data.residuals(self.model.eval)
375        #if self.curr_thread != None :
376        #    try:
377        #        self.curr_thread.isquit()
378        #    except:
379        #        raise FitAbort,"stop leastsqr optimizer"   
380        return self.res
381   
382class FitEngine:
383    def __init__(self):
384        """
385            Base class for scipy and park fit engine
386        """
387        #List of parameter names to fit
388        self.paramList=[]
389        #Dictionnary of fitArrange element (fit problems)
390        self.fitArrangeDict={}
391       
392    def _concatenateData(self, listdata=[]):
393        """ 
394            _concatenateData method concatenates each fields of all data contains ins listdata.
395            @param listdata: list of data
396            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
397             of data concatenanted
398            @raise: if listdata is empty  will return None
399            @raise: if data in listdata don't contain dy field ,will create an error
400            during fitting
401        """
402        #TODO: we have to refactor the way we handle data.
403        # We should move away from plottables and move towards the Data1D objects
404        # defined in DataLoader. Data1D allows data manipulations, which should be
405        # used to concatenate.
406        # In the meantime we should switch off the concatenation.
407        #if len(listdata)>1:
408        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
409        #return listdata[0]
410       
411        if listdata==[]:
412            raise ValueError, " data list missing"
413        else:
414            xtemp=[]
415            ytemp=[]
416            dytemp=[]
417            self.mini=None
418            self.maxi=None
419               
420            for item in listdata:
421                data=item.data
422                mini,maxi=data.getFitRange()
423                if self.mini==None and self.maxi==None:
424                    self.mini=mini
425                    self.maxi=maxi
426                else:
427                    if mini < self.mini:
428                        self.mini=mini
429                    if self.maxi < maxi:
430                        self.maxi=maxi
431                       
432                   
433                for i in range(len(data.x)):
434                    xtemp.append(data.x[i])
435                    ytemp.append(data.y[i])
436                    if data.dy is not None and len(data.dy)==len(data.y):   
437                        dytemp.append(data.dy[i])
438                    else:
439                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
440            data= Data(x=xtemp,y=ytemp,dy=dytemp)
441            data.setFitRange(self.mini, self.maxi)
442            return data
443       
444       
445    def set_model(self,model,Uid,pars=[]):
446        """
447            set a model on a given uid in the fit engine.
448            @param model: the model to fit
449            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
450            @param pars: the list of parameters to fit
451            @note : pars must contains only name of existing model's paramaters
452        """
453        if len(pars) >0:
454            if model==None:
455                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
456            else:
457                temp=[]
458                for item in pars:
459                    if item in model.model.getParamList():
460                        temp.append(item)
461                        self.paramList.append(item)
462                    else:
463                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
464                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
465                        return
466            #A fitArrange is already created but contains dList only at Uid
467            if self.fitArrangeDict.has_key(Uid):
468                self.fitArrangeDict[Uid].set_model(model)
469                self.fitArrangeDict[Uid].pars= pars
470            else:
471            #no fitArrange object has been create with this Uid
472                fitproblem = FitArrange()
473                fitproblem.set_model(model)
474                fitproblem.pars= pars
475                self.fitArrangeDict[Uid] = fitproblem
476               
477        else:
478            raise ValueError, "park_integration:missing parameters"
479   
480    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
481        """ Receives plottable, creates a list of data to fit,set data
482            in a FitArrange object and adds that object in a dictionary
483            with key Uid.
484            @param data: data added
485            @param Uid: unique key corresponding to a fitArrange object with data
486        """
487        if data.__class__.__name__=='Data2D':
488            fitdata=FitData2D(data)
489        else:
490            fitdata=FitData1D(data, smearer)
491       
492        fitdata.setFitRange(qmin=qmin,qmax=qmax)
493        #A fitArrange is already created but contains model only at Uid
494        if self.fitArrangeDict.has_key(Uid):
495            self.fitArrangeDict[Uid].add_data(fitdata)
496        else:
497        #no fitArrange object has been create with this Uid
498            fitproblem= FitArrange()
499            fitproblem.add_data(fitdata)
500            self.fitArrangeDict[Uid]=fitproblem   
501   
502    def get_model(self,Uid):
503        """
504            @param Uid: Uid is key in the dictionary containing the model to return
505            @return  a model at this uid or None if no FitArrange element was created
506            with this Uid
507        """
508        if self.fitArrangeDict.has_key(Uid):
509            return self.fitArrangeDict[Uid].get_model()
510        else:
511            return None
512   
513    def remove_Fit_Problem(self,Uid):
514        """remove   fitarrange in Uid"""
515        if self.fitArrangeDict.has_key(Uid):
516            del self.fitArrangeDict[Uid]
517           
518    def select_problem_for_fit(self,Uid,value):
519        """
520            select a couple of model and data at the Uid position in dictionary
521            and set in self.selected value to value
522            @param value: the value to allow fitting. can only have the value one or zero
523        """
524        if self.fitArrangeDict.has_key(Uid):
525             self.fitArrangeDict[Uid].set_to_fit( value)
526             
527             
528    def get_problem_to_fit(self,Uid):
529        """
530            return the self.selected value of the fit problem of Uid
531           @param Uid: the Uid of the problem
532        """
533        if self.fitArrangeDict.has_key(Uid):
534             self.fitArrangeDict[Uid].get_to_fit()
535   
536class FitArrange:
537    def __init__(self):
538        """
539            Class FitArrange contains a set of data for a given model
540            to perform the Fit.FitArrange must contain exactly one model
541            and at least one data for the fit to be performed.
542            model: the model selected by the user
543            Ldata: a list of data what the user wants to fit
544           
545        """
546        self.model = None
547        self.dList =[]
548        self.pars=[]
549        #self.selected  is zero when this fit problem is not schedule to fit
550        #self.selected is 1 when schedule to fit
551        self.selected = 0
552       
553    def set_model(self,model):
554        """
555            set_model save a copy of the model
556            @param model: the model being set
557        """
558        self.model = model
559       
560    def add_data(self,data):
561        """
562            add_data fill a self.dList with data to fit
563            @param data: Data to add in the list 
564        """
565        if not data in self.dList:
566            self.dList.append(data)
567           
568    def get_model(self):
569        """ @return: saved model """
570        return self.model   
571     
572    def get_data(self):
573        """ @return:  list of data dList"""
574        #return self.dList
575        return self.dList[0] 
576     
577    def remove_data(self,data):
578        """
579            Remove one element from the list
580            @param data: Data to remove from dList
581        """
582        if data in self.dList:
583            self.dList.remove(data)
584    def set_to_fit (self, value=0):
585        """
586           set self.selected to 0 or 1  for other values raise an exception
587           @param value: integer between 0 or 1
588        """
589        self.selected= value
590       
591    def get_to_fit(self):
592        """
593            @return self.selected value
594        """
595        return self.selected
Note: See TracBrowser for help on using the repository browser.