source: sasview/park_integration/AbstractFitEngine.py @ 4043c96

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 4043c96 was 4bb2917, checked in by Mathieu Doucet <doucetm@…>, 15 years ago

park_integration: refactor code using new smearing code.

  • Property mode set to 100644
File size: 20.6 KB
RevLine 
[72c7d31]1import logging, sys
[54c21f50]2import park,numpy,math, copy
[48882d1]3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
[ca6d914]11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
[48882d1]18         
[ca6d914]19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
[48882d1]25   
[ca6d914]26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
[48882d1]31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
[ca6d914]36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
[920a6e5]40        #if not  self.name in self._model.getDispParamList():
41        lo,hi = self._model.details[self.name][1:]
42        if lo is None: lo = -numpy.inf
43        if hi is None: hi = numpy.inf
44        #else:
45            #lo,hi = self._model.details[self.name][1:]
46            #if lo is None: lo = -numpy.inf
47            #if hi is None: hi = numpy.inf
[05f14dd]48        if lo >= hi:
49            raise ValueError,"wrong fit range for parameters"
50       
[48882d1]51        return lo,hi
52   
53    def _setrange(self,r):
[ca6d914]54        """
55            override _setrange of park parameter
56            @param r: the value of the range to set
57        """
[48882d1]58        self._model.details[self.name][1:] = r
59    range = property(_getrange,_setrange)
[a9e04aa]60   
61class Model(park.Model):
[48882d1]62    """
63        PARK wrapper for SANS models.
64    """
[388309d]65    def __init__(self, sans_model, **kw):
[ca6d914]66        """
67            @param sans_model: the sans model to wrap using park interface
68        """
[a9e04aa]69        park.Model.__init__(self, **kw)
[48882d1]70        self.model = sans_model
[ca6d914]71        self.name = sans_model.name
72        #list of parameters names
[48882d1]73        self.sansp = sans_model.getParamList()
[ca6d914]74        #list of park parameter
[48882d1]75        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
[ca6d914]76        #list of parameterset
[48882d1]77        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
78        self.pars=[]
[ca6d914]79 
80 
[48882d1]81    def getParams(self,fitparams):
[ca6d914]82        """
83            return a list of value of paramter to fit
84            @param fitparams: list of paramaters name to fit
85        """
[48882d1]86        list=[]
87        self.pars=[]
88        self.pars=fitparams
89        for item in fitparams:
90            for element in self.parkp:
91                 if element.name ==str(item):
92                     list.append(element.value)
93        return list
94   
[ca6d914]95   
[e71440c]96    def setParams(self,paramlist, params):
[ca6d914]97        """
98            Set value for parameters to fit
99            @param params: list of value for parameters to fit
100        """
[e71440c]101        try:
102            for i in range(len(self.parkp)):
103                for j in range(len(paramlist)):
104                    if self.parkp[i].name==paramlist[j]:
105                        self.parkp[i].value = params[j]
106                        self.model.setParam(self.parkp[i].name,params[j])
107        except:
108            raise
[ca6d914]109 
[48882d1]110    def eval(self,x):
[ca6d914]111        """
112            override eval method of park model.
113            @param x: the x value used to compute a function
114        """
[d8a2e31]115        try:
[fd0d30fd]116                return self.model.evalDistribution(x)
[d8a2e31]117        except:
[fd0d30fd]118                raise
[a9e04aa]119
[b64fa56]120   
[7d0c1a8]121class FitData1D(object):
122    """ Wrapper class  for SANS data """
[b461b6d7]123    def __init__(self,sans_data1d, smearer=None):
[7d0c1a8]124        """
125            Data can be initital with a data (sans plottable)
126            or with vectors.
[109e60ab]127           
128            self.smearer is an object of class QSmearer or SlitSmearer
129            that will smear the theory data (slit smearing or resolution
130            smearing) when set.
131           
132            The proper way to set the smearing object would be to
133            do the following:
134           
135            from DataLoader.qsmearing import smear_selection
136            fitdata1d = FitData1D(some_data)
137            fitdata1d.smearer = smear_selection(some_data)
138           
139            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
140           
141            Setting it back to None will turn smearing off.
142           
[7d0c1a8]143        """
[b461b6d7]144       
145        self.smearer = smearer
146     
[109e60ab]147        # Initialize from Data1D object
[7d0c1a8]148        self.data=sans_data1d
[fd0d30fd]149        self.x= numpy.array(sans_data1d.x)
150        self.y= numpy.array(sans_data1d.y)
[72c7d31]151        self.dx= sans_data1d.dx
[fd0d30fd]152        if sans_data1d.dy ==None or sans_data1d.dy==[]:
153            self.dy= numpy.zeros(len(y)) 
154        else:
155            self.dy= numpy.asarray(sans_data1d.dy)
156     
157        # For fitting purposes, replace zero errors by 1
158        #TODO: check validity for the rare case where only
159        # a few points have zero errors
160        self.dy[self.dy==0]=1
[109e60ab]161       
162        ## Min Q-value
[4bd557d]163        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
164        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
[773806e]165            self.qmin = min(self.data.x[self.data.x!=0])
166        else:                             
167            self.qmin= min (self.data.x)
[109e60ab]168        ## Max Q-value
[20d30e9]169        self.qmax= max (self.data.x)
[058b2d7]170       
[72c7d31]171        # Range used for input to smearing
172        self._qmin_unsmeared = self.qmin
173        self._qmax_unsmeared = self.qmax
[fd0d30fd]174        # Identify the bin range for the unsmeared and smeared spaces
175        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
176        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
177 
[72c7d31]178       
179       
[20d30e9]180    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]181        """ to set the fit range"""
[09975cbb]182        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
[773806e]183        #ToDo: Fix this.
[90db8e8]184        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
[773806e]185            self.qmin = min(self.data.x[self.data.x!=0])
186        elif qmin!=None:                       
187            self.qmin = qmin           
188
[eef2e0ed]189        if qmax !=None:
190            self.qmax = qmax
[72c7d31]191           
[4bb2917]192        # Determine the range needed in unsmeared-Q to cover
193        # the smeared Q range
[72c7d31]194        self._qmin_unsmeared = self.qmin
195        self._qmax_unsmeared = self.qmax   
196       
[4bb2917]197        self._first_unsmeared_bin = 0
198        self._last_unsmeared_bin  = len(self.data.x)-1
199       
200        if self.smearer!=None:
201            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
202            self._qmin_unsmeared = self.data.x[self._first_unsmeared_bin]
203            self._qmax_unsmeared = self.data.x[self._last_unsmeared_bin]
204           
[fd0d30fd]205        # Identify the bin range for the unsmeared and smeared spaces
206        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
207        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
208 
[7d0c1a8]209       
210    def getFitRange(self):
211        """
212            @return the range of data.x to fit
213        """
214        return self.qmin, self.qmax
[72c7d31]215       
[7d0c1a8]216    def residuals(self, fn):
[72c7d31]217        """
218            Compute residuals.
219           
220            If self.smearer has been set, use if to smear
221            the data before computing chi squared.
222           
223            @param fn: function that return model value
224            @return residuals
[109e60ab]225        """
226        # Compute theory data f(x)
[fd0d30fd]227        fx= numpy.zeros(len(self.x))
[7e752fe]228        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
[fd0d30fd]229       
[d5b488b]230        ## Smear theory data
[109e60ab]231        if self.smearer is not None:
[4bb2917]232            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
[72c7d31]233       
[d5b488b]234        ## Sanity check
[fd0d30fd]235        if numpy.size(self.dy)!= numpy.size(fx):
236            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
237                                                                              numpy.size(fx))
238                                                                             
239        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
[72c7d31]240     
241 
242       
[7d0c1a8]243    def residuals_deriv(self, model, pars=[]):
244        """
245            @return residuals derivatives .
246            @note: in this case just return empty array
247        """
248        return []
249   
250   
251class FitData2D(object):
252    """ Wrapper class  for SANS data """
253    def __init__(self,sans_data2d):
254        """
255            Data can be initital with a data (sans plottable)
256            or with vectors.
257        """
258        self.data=sans_data2d
[415bc97]259        self.image = sans_data2d.data
260        self.err_image = sans_data2d.err_data
[d8a2e31]261        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
[f1c79d2]262                                         [1,len(sans_data2d.x_bins)])
[d8a2e31]263        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
[f1c79d2]264                                          [len(sans_data2d.y_bins),1])
[d8a2e31]265       
[20d30e9]266        x = max(self.data.xmin, self.data.xmax)
267        y = max(self.data.ymin, self.data.ymax)
268       
269        ## fitting range
[773806e]270        self.qmin = 1e-16
[20d30e9]271        self.qmax = math.sqrt(x*x +y*y)
[70bf68c]272        ## new error image for fitting purpose
273        if self.err_image== None or self.err_image ==[]:
274            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
275        else:
276            self.res_err_image = copy.deepcopy(self.err_image)
277        self.res_err_image[self.err_image==0]=1
[d8a2e31]278       
279        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
280        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
[7d0c1a8]281       
[20d30e9]282       
283    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]284        """ to set the fit range"""
[773806e]285        if qmin==0.0:
286            self.qmin = 1e-16
287        elif qmin!=None:                       
288            self.qmin = qmin           
[eef2e0ed]289        if qmax!=None:
290            self.qmax= qmax
[20d30e9]291     
[7d0c1a8]292       
293    def getFitRange(self):
294        """
295            @return the range of data.x to fit
296        """
[20d30e9]297        return self.qmin, self.qmax
[7d0c1a8]298     
[d8a2e31]299    def residuals(self, fn): 
[fd0d30fd]300       
[1943097]301        res=self.index_model*(self.image - fn([self.x_bins_array,
302                             self.y_bins_array]))/self.res_err_image
[7f81665]303        return res.ravel() 
[0e51519]304       
[fd0d30fd]305 
[7d0c1a8]306    def residuals_deriv(self, model, pars=[]):
307        """
308            @return residuals derivatives .
309            @note: in this case just return empty array
310        """
311        return []
[48882d1]312   
[4bd557d]313class FitAbort(Exception):
314    """
315        Exception raise to stop the fit
316    """
317    print"Creating fit abort Exception"
318
319
[70bf68c]320class SansAssembly:
[ca6d914]321    """
322         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
323    """
[4bd557d]324    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
[ca6d914]325        """
326            @param Model: the model wrapper fro sans -model
327            @param Data: the data wrapper for sans data
328        """
329        self.model = Model
330        self.data  = Data
[e71440c]331        self.paramlist=paramlist
[4bd557d]332        self.curr_thread= curr_thread
[ca6d914]333        self.res=[]
[4bd557d]334        self.func_name="Functor"
[48882d1]335    def chisq(self, params):
336        """
337            Calculates chi^2
338            @param params: list of parameter values
339            @return: chi^2
340        """
341        sum = 0
342        for item in self.res:
343            sum += item*item
[4bd557d]344        if len(self.res)==0:
345            return None
[26cb768]346        return sum/ len(self.res)
[20d30e9]347   
[48882d1]348    def __call__(self,params):
[ca6d914]349        """
350            Compute residuals
351            @param params: value of parameters to fit
352        """
[681f0dc]353        #import thread
[e71440c]354        self.model.setParams(self.paramlist,params)
[48882d1]355        self.res= self.data.residuals(self.model.eval)
[24b8d5c]356        #if self.curr_thread != None :
357        #    try:
358        #        self.curr_thread.isquit()
359        #    except:
360        #        raise FitAbort,"stop leastsqr optimizer"   
[48882d1]361        return self.res
362   
[4c718654]363class FitEngine:
[ee5b04c]364    def __init__(self):
[ca6d914]365        """
366            Base class for scipy and park fit engine
367        """
368        #List of parameter names to fit
[ee5b04c]369        self.paramList=[]
[ca6d914]370        #Dictionnary of fitArrange element (fit problems)
371        self.fitArrangeDict={}
372       
[4c718654]373    def _concatenateData(self, listdata=[]):
374        """ 
375            _concatenateData method concatenates each fields of all data contains ins listdata.
376            @param listdata: list of data
[ca6d914]377            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
378             of data concatenanted
[4c718654]379            @raise: if listdata is empty  will return None
380            @raise: if data in listdata don't contain dy field ,will create an error
381            during fitting
382        """
[109e60ab]383        #TODO: we have to refactor the way we handle data.
384        # We should move away from plottables and move towards the Data1D objects
385        # defined in DataLoader. Data1D allows data manipulations, which should be
386        # used to concatenate.
387        # In the meantime we should switch off the concatenation.
388        #if len(listdata)>1:
389        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
390        #return listdata[0]
391       
[4c718654]392        if listdata==[]:
393            raise ValueError, " data list missing"
394        else:
395            xtemp=[]
396            ytemp=[]
397            dytemp=[]
[48882d1]398            self.mini=None
399            self.maxi=None
[4c718654]400               
[7d0c1a8]401            for item in listdata:
402                data=item.data
[48882d1]403                mini,maxi=data.getFitRange()
404                if self.mini==None and self.maxi==None:
405                    self.mini=mini
406                    self.maxi=maxi
407                else:
408                    if mini < self.mini:
409                        self.mini=mini
410                    if self.maxi < maxi:
411                        self.maxi=maxi
412                       
413                   
[4c718654]414                for i in range(len(data.x)):
415                    xtemp.append(data.x[i])
416                    ytemp.append(data.y[i])
417                    if data.dy is not None and len(data.dy)==len(data.y):   
418                        dytemp.append(data.dy[i])
419                    else:
[ee5b04c]420                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
[20d30e9]421            data= Data(x=xtemp,y=ytemp,dy=dytemp)
[48882d1]422            data.setFitRange(self.mini, self.maxi)
423            return data
[ca6d914]424       
425       
426    def set_model(self,model,Uid,pars=[]):
427        """
428            set a model on a given uid in the fit engine.
429            @param model: the model to fit
430            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
431            @param pars: the list of parameters to fit
432            @note : pars must contains only name of existing model's paramaters
433        """
[f44dbc7]434        if len(pars) >0:
[6831a99]435            if model==None:
[f44dbc7]436                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
[6831a99]437            else:
[aed7c57]438                temp=[]
[ca6d914]439                for item in pars:
440                    if item in model.model.getParamList():
[aed7c57]441                        temp.append(item)
[ca6d914]442                        self.paramList.append(item)
443                    else:
444                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
445                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
446                        return
[6831a99]447            #A fitArrange is already created but contains dList only at Uid
[ca6d914]448            if self.fitArrangeDict.has_key(Uid):
449                self.fitArrangeDict[Uid].set_model(model)
[aed7c57]450                self.fitArrangeDict[Uid].pars= pars
[6831a99]451            else:
452            #no fitArrange object has been create with this Uid
[48882d1]453                fitproblem = FitArrange()
[6831a99]454                fitproblem.set_model(model)
[aed7c57]455                fitproblem.pars= pars
[ca6d914]456                self.fitArrangeDict[Uid] = fitproblem
[aed7c57]457               
[d4b0687]458        else:
[6831a99]459            raise ValueError, "park_integration:missing parameters"
[48882d1]460   
[20d30e9]461    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
[d4b0687]462        """ Receives plottable, creates a list of data to fit,set data
463            in a FitArrange object and adds that object in a dictionary
464            with key Uid.
465            @param data: data added
466            @param Uid: unique key corresponding to a fitArrange object with data
[ca6d914]467        """
[f2817bb]468        if data.__class__.__name__=='Data2D':
[f8ce013]469            fitdata=FitData2D(data)
470        else:
[b461b6d7]471            fitdata=FitData1D(data, smearer)
[20d30e9]472       
473        fitdata.setFitRange(qmin=qmin,qmax=qmax)
[d4b0687]474        #A fitArrange is already created but contains model only at Uid
[ca6d914]475        if self.fitArrangeDict.has_key(Uid):
[f8ce013]476            self.fitArrangeDict[Uid].add_data(fitdata)
[d4b0687]477        else:
478        #no fitArrange object has been create with this Uid
479            fitproblem= FitArrange()
[f8ce013]480            fitproblem.add_data(fitdata)
[ca6d914]481            self.fitArrangeDict[Uid]=fitproblem   
[20d30e9]482   
[d4b0687]483    def get_model(self,Uid):
484        """
485            @param Uid: Uid is key in the dictionary containing the model to return
486            @return  a model at this uid or None if no FitArrange element was created
487            with this Uid
488        """
[ca6d914]489        if self.fitArrangeDict.has_key(Uid):
490            return self.fitArrangeDict[Uid].get_model()
[d4b0687]491        else:
492            return None
493   
494    def remove_Fit_Problem(self,Uid):
495        """remove   fitarrange in Uid"""
[ca6d914]496        if self.fitArrangeDict.has_key(Uid):
497            del self.fitArrangeDict[Uid]
[a9e04aa]498           
499    def select_problem_for_fit(self,Uid,value):
500        """
501            select a couple of model and data at the Uid position in dictionary
502            and set in self.selected value to value
503            @param value: the value to allow fitting. can only have the value one or zero
504        """
505        if self.fitArrangeDict.has_key(Uid):
506             self.fitArrangeDict[Uid].set_to_fit( value)
[eef2e0ed]507             
508             
[a9e04aa]509    def get_problem_to_fit(self,Uid):
510        """
511            return the self.selected value of the fit problem of Uid
512           @param Uid: the Uid of the problem
513        """
514        if self.fitArrangeDict.has_key(Uid):
515             self.fitArrangeDict[Uid].get_to_fit()
[4c718654]516   
[d4b0687]517class FitArrange:
518    def __init__(self):
519        """
520            Class FitArrange contains a set of data for a given model
521            to perform the Fit.FitArrange must contain exactly one model
522            and at least one data for the fit to be performed.
523            model: the model selected by the user
524            Ldata: a list of data what the user wants to fit
525           
526        """
527        self.model = None
528        self.dList =[]
[aed7c57]529        self.pars=[]
[a9e04aa]530        #self.selected  is zero when this fit problem is not schedule to fit
531        #self.selected is 1 when schedule to fit
532        self.selected = 0
[d4b0687]533       
534    def set_model(self,model):
535        """
536            set_model save a copy of the model
537            @param model: the model being set
538        """
539        self.model = model
540       
541    def add_data(self,data):
542        """
543            add_data fill a self.dList with data to fit
544            @param data: Data to add in the list 
545        """
546        if not data in self.dList:
547            self.dList.append(data)
548           
549    def get_model(self):
550        """ @return: saved model """
551        return self.model   
552     
553    def get_data(self):
554        """ @return:  list of data dList"""
[7d0c1a8]555        #return self.dList
556        return self.dList[0] 
[d4b0687]557     
558    def remove_data(self,data):
559        """
560            Remove one element from the list
561            @param data: Data to remove from dList
562        """
563        if data in self.dList:
564            self.dList.remove(data)
[a9e04aa]565    def set_to_fit (self, value=0):
566        """
567           set self.selected to 0 or 1  for other values raise an exception
568           @param value: integer between 0 or 1
569        """
570        self.selected= value
571       
572    def get_to_fit(self):
573        """
574            @return self.selected value
575        """
576        return self.selected
Note: See TracBrowser for help on using the repository browser.