source: sasview/park_integration/AbstractFitEngine.py @ a96d246

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since a96d246 was a96d246, checked in by Gervaise Alina <gervyh@…>, 15 years ago

change orientation of eval.dist result

  • Property mode set to 100644
File size: 21.3 KB
RevLine 
[72c7d31]1import logging, sys
[54c21f50]2import park,numpy,math, copy
[48882d1]3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
[ca6d914]11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
[48882d1]18         
[ca6d914]19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
[48882d1]25   
[ca6d914]26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
[48882d1]31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
[ca6d914]36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
[c79ee796]40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
[05f14dd]47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
[48882d1]50        return lo,hi
51   
52    def _setrange(self,r):
[ca6d914]53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
[48882d1]57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
[a9e04aa]59   
60class Model(park.Model):
[48882d1]61    """
62        PARK wrapper for SANS models.
63    """
[388309d]64    def __init__(self, sans_model, **kw):
[ca6d914]65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
[a9e04aa]68        park.Model.__init__(self, **kw)
[48882d1]69        self.model = sans_model
[ca6d914]70        self.name = sans_model.name
71        #list of parameters names
[48882d1]72        self.sansp = sans_model.getParamList()
[ca6d914]73        #list of park parameter
[48882d1]74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
[ca6d914]75        #list of parameterset
[48882d1]76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
[ca6d914]78 
79 
[48882d1]80    def getParams(self,fitparams):
[ca6d914]81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
[48882d1]85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
[ca6d914]94   
[e71440c]95    def setParams(self,paramlist, params):
[ca6d914]96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
[e71440c]100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
[ca6d914]108 
[48882d1]109    def eval(self,x):
[ca6d914]110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
[d8a2e31]114        try:
[fd0d30fd]115                return self.model.evalDistribution(x)
[d8a2e31]116        except:
[fd0d30fd]117                raise
[a9e04aa]118
[b64fa56]119   
[7d0c1a8]120class FitData1D(object):
121    """ Wrapper class  for SANS data """
[b461b6d7]122    def __init__(self,sans_data1d, smearer=None):
[7d0c1a8]123        """
124            Data can be initital with a data (sans plottable)
125            or with vectors.
[109e60ab]126           
127            self.smearer is an object of class QSmearer or SlitSmearer
128            that will smear the theory data (slit smearing or resolution
129            smearing) when set.
130           
131            The proper way to set the smearing object would be to
132            do the following:
133           
134            from DataLoader.qsmearing import smear_selection
135            fitdata1d = FitData1D(some_data)
136            fitdata1d.smearer = smear_selection(some_data)
137           
138            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
139           
140            Setting it back to None will turn smearing off.
141           
[7d0c1a8]142        """
[b461b6d7]143       
144        self.smearer = smearer
145     
[109e60ab]146        # Initialize from Data1D object
[7d0c1a8]147        self.data=sans_data1d
[fd0d30fd]148        self.x= numpy.array(sans_data1d.x)
149        self.y= numpy.array(sans_data1d.y)
[72c7d31]150        self.dx= sans_data1d.dx
[fd0d30fd]151        if sans_data1d.dy ==None or sans_data1d.dy==[]:
152            self.dy= numpy.zeros(len(y)) 
153        else:
154            self.dy= numpy.asarray(sans_data1d.dy)
155     
156        # For fitting purposes, replace zero errors by 1
157        #TODO: check validity for the rare case where only
158        # a few points have zero errors
159        self.dy[self.dy==0]=1
[109e60ab]160       
161        ## Min Q-value
[4bd557d]162        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
163        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
[773806e]164            self.qmin = min(self.data.x[self.data.x!=0])
165        else:                             
166            self.qmin= min (self.data.x)
[109e60ab]167        ## Max Q-value
[20d30e9]168        self.qmax= max (self.data.x)
[058b2d7]169       
[72c7d31]170        # Range used for input to smearing
171        self._qmin_unsmeared = self.qmin
172        self._qmax_unsmeared = self.qmax
[fd0d30fd]173        # Identify the bin range for the unsmeared and smeared spaces
174        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
175        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
176 
[72c7d31]177       
178       
[20d30e9]179    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]180        """ to set the fit range"""
[09975cbb]181        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
[773806e]182        #ToDo: Fix this.
[90db8e8]183        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
[773806e]184            self.qmin = min(self.data.x[self.data.x!=0])
185        elif qmin!=None:                       
186            self.qmin = qmin           
187
[eef2e0ed]188        if qmax !=None:
189            self.qmax = qmax
[72c7d31]190           
191        # Range used for input to smearing
192        self._qmin_unsmeared = self.qmin
193        self._qmax_unsmeared = self.qmax   
194       
195        # Determine the range needed in unsmeared-Q to cover
196        # the smeared Q range
197        #TODO: use the smearing matrix to determine which
198        # bin range to use
199        if self.smearer.__class__.__name__ == 'SlitSmearer':
200            self._qmin_unsmeared = min(self.data.x)
201            self._qmax_unsmeared = max(self.data.x)
202        elif self.smearer.__class__.__name__ == 'QSmearer':
203            # Take 3 sigmas as the offset between smeared and unsmeared space
204            try:
205                offset = 3.0*max(self.smearer.width)
206                self._qmin_unsmeared = max([min(self.data.x), self.qmin-offset])
207                self._qmax_unsmeared = min([max(self.data.x), self.qmax+offset])
208            except:
209                logging.error("FitData1D.setFitRange: %s" % sys.exc_value)
[fd0d30fd]210        # Identify the bin range for the unsmeared and smeared spaces
211        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
212        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
213 
[7d0c1a8]214       
215    def getFitRange(self):
216        """
217            @return the range of data.x to fit
218        """
219        return self.qmin, self.qmax
[72c7d31]220       
[7d0c1a8]221    def residuals(self, fn):
[72c7d31]222        """
223            Compute residuals.
224           
225            If self.smearer has been set, use if to smear
226            the data before computing chi squared.
227           
228            @param fn: function that return model value
229            @return residuals
[109e60ab]230        """
231        # Compute theory data f(x)
[fd0d30fd]232        fx= numpy.zeros(len(self.x))
[72c7d31]233        _first_bin = None
234        _last_bin  = None
[fd0d30fd]235       
[7e752fe]236        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
[fd0d30fd]237       
238       
239        for i_x in range(len(self.x)):
240            if self.idx_unsmeared[i_x]==True:
241                # Identify first and last bin
242                #TODO: refactor this to pass q-values to the smearer
243                # and let it figure out which bin range to use
244                if _first_bin is None:
245                    _first_bin = i_x
246                else:
247                    _last_bin  = i_x
248               
[d5b488b]249        ## Smear theory data
[109e60ab]250        if self.smearer is not None:
[72c7d31]251            fx = self.smearer(fx, _first_bin, _last_bin)
252       
[d5b488b]253        ## Sanity check
[fd0d30fd]254        if numpy.size(self.dy)!= numpy.size(fx):
255            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
256                                                                              numpy.size(fx))
257                                                                             
258        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
[72c7d31]259     
260 
261       
[7d0c1a8]262    def residuals_deriv(self, model, pars=[]):
263        """
264            @return residuals derivatives .
265            @note: in this case just return empty array
266        """
267        return []
268   
269   
270class FitData2D(object):
271    """ Wrapper class  for SANS data """
272    def __init__(self,sans_data2d):
273        """
274            Data can be initital with a data (sans plottable)
275            or with vectors.
276        """
277        self.data=sans_data2d
[415bc97]278        self.image = sans_data2d.data
279        self.err_image = sans_data2d.err_data
[d8a2e31]280        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
[a96d246]281                                         [len(sans_data2d.x_bins),1])
[d8a2e31]282        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
[a96d246]283                                          [1,len(sans_data2d.y_bins)])
[d8a2e31]284       
[20d30e9]285        x = max(self.data.xmin, self.data.xmax)
286        y = max(self.data.ymin, self.data.ymax)
287       
288        ## fitting range
[773806e]289        self.qmin = 1e-16
[20d30e9]290        self.qmax = math.sqrt(x*x +y*y)
[70bf68c]291        ## new error image for fitting purpose
292        if self.err_image== None or self.err_image ==[]:
293            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
294        else:
295            self.res_err_image = copy.deepcopy(self.err_image)
296        self.res_err_image[self.err_image==0]=1
[d8a2e31]297       
298        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
299        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
[7d0c1a8]300       
[20d30e9]301       
302    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]303        """ to set the fit range"""
[773806e]304        if qmin==0.0:
305            self.qmin = 1e-16
306        elif qmin!=None:                       
307            self.qmin = qmin           
[eef2e0ed]308        if qmax!=None:
309            self.qmax= qmax
[20d30e9]310     
[7d0c1a8]311       
312    def getFitRange(self):
313        """
314            @return the range of data.x to fit
315        """
[20d30e9]316        return self.qmin, self.qmax
[7d0c1a8]317     
[d8a2e31]318    def residuals(self, fn): 
[fd0d30fd]319       
[1943097]320        res=self.index_model*(self.image - fn([self.x_bins_array,
321                             self.y_bins_array]))/self.res_err_image
[7f81665]322        return res.ravel() 
[0e51519]323       
[fd0d30fd]324 
[7d0c1a8]325    def residuals_deriv(self, model, pars=[]):
326        """
327            @return residuals derivatives .
328            @note: in this case just return empty array
329        """
330        return []
[48882d1]331   
[4bd557d]332class FitAbort(Exception):
333    """
334        Exception raise to stop the fit
335    """
336    print"Creating fit abort Exception"
337
338
[70bf68c]339class SansAssembly:
[ca6d914]340    """
341         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
342    """
[4bd557d]343    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
[ca6d914]344        """
345            @param Model: the model wrapper fro sans -model
346            @param Data: the data wrapper for sans data
347        """
348        self.model = Model
349        self.data  = Data
[e71440c]350        self.paramlist=paramlist
[4bd557d]351        self.curr_thread= curr_thread
[ca6d914]352        self.res=[]
[4bd557d]353        self.func_name="Functor"
[48882d1]354    def chisq(self, params):
355        """
356            Calculates chi^2
357            @param params: list of parameter values
358            @return: chi^2
359        """
360        sum = 0
361        for item in self.res:
362            sum += item*item
[4bd557d]363        if len(self.res)==0:
364            return None
[26cb768]365        return sum/ len(self.res)
[20d30e9]366   
[48882d1]367    def __call__(self,params):
[ca6d914]368        """
369            Compute residuals
370            @param params: value of parameters to fit
371        """
[681f0dc]372        #import thread
[e71440c]373        self.model.setParams(self.paramlist,params)
[48882d1]374        self.res= self.data.residuals(self.model.eval)
[24b8d5c]375        #if self.curr_thread != None :
376        #    try:
377        #        self.curr_thread.isquit()
378        #    except:
379        #        raise FitAbort,"stop leastsqr optimizer"   
[48882d1]380        return self.res
381   
[4c718654]382class FitEngine:
[ee5b04c]383    def __init__(self):
[ca6d914]384        """
385            Base class for scipy and park fit engine
386        """
387        #List of parameter names to fit
[ee5b04c]388        self.paramList=[]
[ca6d914]389        #Dictionnary of fitArrange element (fit problems)
390        self.fitArrangeDict={}
391       
[4c718654]392    def _concatenateData(self, listdata=[]):
393        """ 
394            _concatenateData method concatenates each fields of all data contains ins listdata.
395            @param listdata: list of data
[ca6d914]396            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
397             of data concatenanted
[4c718654]398            @raise: if listdata is empty  will return None
399            @raise: if data in listdata don't contain dy field ,will create an error
400            during fitting
401        """
[109e60ab]402        #TODO: we have to refactor the way we handle data.
403        # We should move away from plottables and move towards the Data1D objects
404        # defined in DataLoader. Data1D allows data manipulations, which should be
405        # used to concatenate.
406        # In the meantime we should switch off the concatenation.
407        #if len(listdata)>1:
408        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
409        #return listdata[0]
410       
[4c718654]411        if listdata==[]:
412            raise ValueError, " data list missing"
413        else:
414            xtemp=[]
415            ytemp=[]
416            dytemp=[]
[48882d1]417            self.mini=None
418            self.maxi=None
[4c718654]419               
[7d0c1a8]420            for item in listdata:
421                data=item.data
[48882d1]422                mini,maxi=data.getFitRange()
423                if self.mini==None and self.maxi==None:
424                    self.mini=mini
425                    self.maxi=maxi
426                else:
427                    if mini < self.mini:
428                        self.mini=mini
429                    if self.maxi < maxi:
430                        self.maxi=maxi
431                       
432                   
[4c718654]433                for i in range(len(data.x)):
434                    xtemp.append(data.x[i])
435                    ytemp.append(data.y[i])
436                    if data.dy is not None and len(data.dy)==len(data.y):   
437                        dytemp.append(data.dy[i])
438                    else:
[ee5b04c]439                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
[20d30e9]440            data= Data(x=xtemp,y=ytemp,dy=dytemp)
[48882d1]441            data.setFitRange(self.mini, self.maxi)
442            return data
[ca6d914]443       
444       
445    def set_model(self,model,Uid,pars=[]):
446        """
447            set a model on a given uid in the fit engine.
448            @param model: the model to fit
449            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
450            @param pars: the list of parameters to fit
451            @note : pars must contains only name of existing model's paramaters
452        """
[f44dbc7]453        if len(pars) >0:
[6831a99]454            if model==None:
[f44dbc7]455                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
[6831a99]456            else:
[aed7c57]457                temp=[]
[ca6d914]458                for item in pars:
459                    if item in model.model.getParamList():
[aed7c57]460                        temp.append(item)
[ca6d914]461                        self.paramList.append(item)
462                    else:
463                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
464                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
465                        return
[6831a99]466            #A fitArrange is already created but contains dList only at Uid
[ca6d914]467            if self.fitArrangeDict.has_key(Uid):
468                self.fitArrangeDict[Uid].set_model(model)
[aed7c57]469                self.fitArrangeDict[Uid].pars= pars
[6831a99]470            else:
471            #no fitArrange object has been create with this Uid
[48882d1]472                fitproblem = FitArrange()
[6831a99]473                fitproblem.set_model(model)
[aed7c57]474                fitproblem.pars= pars
[ca6d914]475                self.fitArrangeDict[Uid] = fitproblem
[aed7c57]476               
[d4b0687]477        else:
[6831a99]478            raise ValueError, "park_integration:missing parameters"
[48882d1]479   
[20d30e9]480    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
[d4b0687]481        """ Receives plottable, creates a list of data to fit,set data
482            in a FitArrange object and adds that object in a dictionary
483            with key Uid.
484            @param data: data added
485            @param Uid: unique key corresponding to a fitArrange object with data
[ca6d914]486        """
[f2817bb]487        if data.__class__.__name__=='Data2D':
[f8ce013]488            fitdata=FitData2D(data)
489        else:
[b461b6d7]490            fitdata=FitData1D(data, smearer)
[20d30e9]491       
492        fitdata.setFitRange(qmin=qmin,qmax=qmax)
[d4b0687]493        #A fitArrange is already created but contains model only at Uid
[ca6d914]494        if self.fitArrangeDict.has_key(Uid):
[f8ce013]495            self.fitArrangeDict[Uid].add_data(fitdata)
[d4b0687]496        else:
497        #no fitArrange object has been create with this Uid
498            fitproblem= FitArrange()
[f8ce013]499            fitproblem.add_data(fitdata)
[ca6d914]500            self.fitArrangeDict[Uid]=fitproblem   
[20d30e9]501   
[d4b0687]502    def get_model(self,Uid):
503        """
504            @param Uid: Uid is key in the dictionary containing the model to return
505            @return  a model at this uid or None if no FitArrange element was created
506            with this Uid
507        """
[ca6d914]508        if self.fitArrangeDict.has_key(Uid):
509            return self.fitArrangeDict[Uid].get_model()
[d4b0687]510        else:
511            return None
512   
513    def remove_Fit_Problem(self,Uid):
514        """remove   fitarrange in Uid"""
[ca6d914]515        if self.fitArrangeDict.has_key(Uid):
516            del self.fitArrangeDict[Uid]
[a9e04aa]517           
518    def select_problem_for_fit(self,Uid,value):
519        """
520            select a couple of model and data at the Uid position in dictionary
521            and set in self.selected value to value
522            @param value: the value to allow fitting. can only have the value one or zero
523        """
524        if self.fitArrangeDict.has_key(Uid):
525             self.fitArrangeDict[Uid].set_to_fit( value)
[eef2e0ed]526             
527             
[a9e04aa]528    def get_problem_to_fit(self,Uid):
529        """
530            return the self.selected value of the fit problem of Uid
531           @param Uid: the Uid of the problem
532        """
533        if self.fitArrangeDict.has_key(Uid):
534             self.fitArrangeDict[Uid].get_to_fit()
[4c718654]535   
[d4b0687]536class FitArrange:
537    def __init__(self):
538        """
539            Class FitArrange contains a set of data for a given model
540            to perform the Fit.FitArrange must contain exactly one model
541            and at least one data for the fit to be performed.
542            model: the model selected by the user
543            Ldata: a list of data what the user wants to fit
544           
545        """
546        self.model = None
547        self.dList =[]
[aed7c57]548        self.pars=[]
[a9e04aa]549        #self.selected  is zero when this fit problem is not schedule to fit
550        #self.selected is 1 when schedule to fit
551        self.selected = 0
[d4b0687]552       
553    def set_model(self,model):
554        """
555            set_model save a copy of the model
556            @param model: the model being set
557        """
558        self.model = model
559       
560    def add_data(self,data):
561        """
562            add_data fill a self.dList with data to fit
563            @param data: Data to add in the list 
564        """
565        if not data in self.dList:
566            self.dList.append(data)
567           
568    def get_model(self):
569        """ @return: saved model """
570        return self.model   
571     
572    def get_data(self):
573        """ @return:  list of data dList"""
[7d0c1a8]574        #return self.dList
575        return self.dList[0] 
[d4b0687]576     
577    def remove_data(self,data):
578        """
579            Remove one element from the list
580            @param data: Data to remove from dList
581        """
582        if data in self.dList:
583            self.dList.remove(data)
[a9e04aa]584    def set_to_fit (self, value=0):
585        """
586           set self.selected to 0 or 1  for other values raise an exception
587           @param value: integer between 0 or 1
588        """
589        self.selected= value
590       
591    def get_to_fit(self):
592        """
593            @return self.selected value
594        """
595        return self.selected
Note: See TracBrowser for help on using the repository browser.