source: sasview/park_integration/AbstractFitEngine.py @ ceb89ac

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since ceb89ac was 24b8d5c, checked in by Gervaise Alina <gervyh@…>, 16 years ago

remove stop fit option

  • Property mode set to 100644
File size: 21.3 KB
RevLine 
[4c718654]1
[54c21f50]2import park,numpy,math, copy
[48882d1]3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
[ca6d914]11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
[48882d1]18         
[ca6d914]19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
[48882d1]25   
[ca6d914]26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
[48882d1]31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
[ca6d914]36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
[c79ee796]40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
[05f14dd]47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
[48882d1]50        return lo,hi
51   
52    def _setrange(self,r):
[ca6d914]53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
[48882d1]57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
[a9e04aa]59   
60class Model(park.Model):
[48882d1]61    """
62        PARK wrapper for SANS models.
63    """
[388309d]64    def __init__(self, sans_model, **kw):
[ca6d914]65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
[a9e04aa]68        park.Model.__init__(self, **kw)
[48882d1]69        self.model = sans_model
[ca6d914]70        self.name = sans_model.name
71        #list of parameters names
[48882d1]72        self.sansp = sans_model.getParamList()
[ca6d914]73        #list of park parameter
[48882d1]74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
[ca6d914]75        #list of parameterset
[48882d1]76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
[ca6d914]78 
79 
[48882d1]80    def getParams(self,fitparams):
[ca6d914]81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
[48882d1]85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
[ca6d914]94   
[e71440c]95    def setParams(self,paramlist, params):
[ca6d914]96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
[e71440c]100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
[ca6d914]108 
[48882d1]109    def eval(self,x):
[ca6d914]110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
[48882d1]114        return self.model.runXY(x)
[388309d]115   
116   
[a9e04aa]117
118
[48882d1]119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
[ca6d914]122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
[48882d1]126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
[ca6d914]143       
[48882d1]144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
[773806e]146       
147        self.qmin=mini           
[48882d1]148        self.qmax=maxi
[ca6d914]149       
150       
[48882d1]151    def getFitRange(self):
[ca6d914]152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
[48882d1]158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
[ca6d914]164            fx =numpy.asarray([fn(v) for v in x])
[48882d1]165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
[ca6d914]168            fx = numpy.asarray([fn(item)for item in x[idx ]])
[48882d1]169            return (y[idx] - fx)/dy[idx]
[e71440c]170       
[48882d1]171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
[b64fa56]177   
178   
[7d0c1a8]179class FitData1D(object):
180    """ Wrapper class  for SANS data """
[b461b6d7]181    def __init__(self,sans_data1d, smearer=None):
[7d0c1a8]182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
[109e60ab]185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
[7d0c1a8]201        """
[b461b6d7]202       
203        self.smearer = smearer
204     
[109e60ab]205        # Initialize from Data1D object
[7d0c1a8]206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
[109e60ab]211       
212        ## Min Q-value
[4bd557d]213        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
214        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
[773806e]215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
[109e60ab]218        ## Max Q-value
[20d30e9]219        self.qmax= max (self.data.x)
[7d0c1a8]220       
221       
[20d30e9]222    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]223        """ to set the fit range"""
[773806e]224       
[09975cbb]225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
[773806e]226        #ToDo: Fix this.
[90db8e8]227        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
[773806e]228            self.qmin = min(self.data.x[self.data.x!=0])
229        elif qmin!=None:                       
230            self.qmin = qmin           
231
[eef2e0ed]232        if qmax !=None:
233            self.qmax = qmax
[7d0c1a8]234       
235       
236    def getFitRange(self):
237        """
238            @return the range of data.x to fit
239        """
240        return self.qmin, self.qmax
241     
242     
243    def residuals(self, fn):
[109e60ab]244        """
245            Compute residuals.
246           
247            If self.smearer has been set, use if to smear
248            the data before computing chi squared.
249           
250            @param fn: function that return model value
251            @return residuals
252        """
[b64fa56]253        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
254        if self.dy ==None or self.dy==[]:
255            dy= numpy.zeros(len(y)) 
256        else:
[54c21f50]257            dy= copy.deepcopy(self.dy)
258            dy= numpy.asarray(dy)
[d5b488b]259     
[b64fa56]260        dy[dy==0]=1
[d5b488b]261        #idx = (x>=self.qmin) & (x <= self.qmax)
[20d30e9]262 
[109e60ab]263        # Compute theory data f(x)
[d5b488b]264        #fx = numpy.zeros(len(x))
265        tempy=[]
266        tempfx=[]
267        tempdy=[]
268        #fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
269        for i_x in  range(len(x)):
270            try:
271                if self.qmin <=x[i_x] and x[i_x]<=self.qmax:
272                    value= fn(x[i_x])
273                    tempfx.append( value)
274                    tempy.append(y[i_x])
275                    tempdy.append(dy[i_x])
276            except:
277                ## skip error for model.run(x)
278                pass
279                 
280        ## Smear theory data
[109e60ab]281        if self.smearer is not None:
[d5b488b]282            tempfx = self.smearer(tempfx)
283        newy= numpy.asarray(tempy)
284        newfx= numpy.asarray(tempfx)
285        newdy= numpy.asarray(tempdy)
[20d30e9]286       
[d5b488b]287       
288        ## Sanity check
289        if numpy.size(newdy)!= numpy.size(newfx):
[109e60ab]290            raise RuntimeError, "FitData1D: invalid error array"
[d5b488b]291        #return (y[idx] - fx[idx])/dy[idx]
292       
293        return (newy- newfx)/newdy
[109e60ab]294     
[20d30e9]295 
[7d0c1a8]296       
297    def residuals_deriv(self, model, pars=[]):
298        """
299            @return residuals derivatives .
300            @note: in this case just return empty array
301        """
302        return []
303   
304   
305class FitData2D(object):
306    """ Wrapper class  for SANS data """
307    def __init__(self,sans_data2d):
308        """
309            Data can be initital with a data (sans plottable)
310            or with vectors.
311        """
312        self.data=sans_data2d
[415bc97]313        self.image = sans_data2d.data
314        self.err_image = sans_data2d.err_data
[7d0c1a8]315        self.x_bins= sans_data2d.x_bins
316        self.y_bins= sans_data2d.y_bins
317       
[20d30e9]318        x = max(self.data.xmin, self.data.xmax)
319        y = max(self.data.ymin, self.data.ymax)
320       
321        ## fitting range
[773806e]322        self.qmin = 1e-16
[20d30e9]323        self.qmax = math.sqrt(x*x +y*y)
[7d0c1a8]324       
325       
[20d30e9]326       
327    def setFitRange(self,qmin=None,qmax=None):
[7d0c1a8]328        """ to set the fit range"""
[773806e]329        if qmin==0.0:
330            self.qmin = 1e-16
331        elif qmin!=None:                       
332            self.qmin = qmin           
[eef2e0ed]333        if qmax!=None:
334            self.qmax= qmax
[20d30e9]335     
[7d0c1a8]336       
337    def getFitRange(self):
338        """
339            @return the range of data.x to fit
340        """
[20d30e9]341        return self.qmin, self.qmax
[7d0c1a8]342     
343     
344    def residuals(self, fn):
345        """ @param fn: function that return model value
346            @return residuals
347        """
348        res=[]
[b64fa56]349        if self.err_image== None or self.err_image ==[]:
[54c21f50]350            err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
351        else:
352            err_image = copy.deepcopy(self.err_image)
353           
354        err_image[err_image==0]=1
[fff74cb]355        for i in range(len(self.x_bins)):
356            for j in range(len(self.y_bins)):
357                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
358                radius= math.sqrt(temp)
[20d30e9]359                if self.qmin <= radius and radius <= self.qmax:
360                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
[54c21f50]361                            /err_image[j][i] )
[0e51519]362       
363        return numpy.array(res)
364       
365         
[7d0c1a8]366    def residuals_deriv(self, model, pars=[]):
367        """
368            @return residuals derivatives .
369            @note: in this case just return empty array
370        """
371        return []
[48882d1]372   
[4bd557d]373class FitAbort(Exception):
374    """
375        Exception raise to stop the fit
376    """
377    print"Creating fit abort Exception"
378
379
[48882d1]380class sansAssembly:
[ca6d914]381    """
382         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
383    """
[4bd557d]384    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
[ca6d914]385        """
386            @param Model: the model wrapper fro sans -model
387            @param Data: the data wrapper for sans data
388        """
389        self.model = Model
390        self.data  = Data
[e71440c]391        self.paramlist=paramlist
[4bd557d]392        self.curr_thread= curr_thread
[ca6d914]393        self.res=[]
[4bd557d]394        self.func_name="Functor"
[48882d1]395    def chisq(self, params):
396        """
397            Calculates chi^2
398            @param params: list of parameter values
399            @return: chi^2
400        """
401        sum = 0
402        for item in self.res:
403            sum += item*item
[4bd557d]404        if len(self.res)==0:
405            return None
[26cb768]406        return sum/ len(self.res)
[20d30e9]407   
[48882d1]408    def __call__(self,params):
[ca6d914]409        """
410            Compute residuals
411            @param params: value of parameters to fit
412        """
[681f0dc]413        #import thread
[e71440c]414        self.model.setParams(self.paramlist,params)
[48882d1]415        self.res= self.data.residuals(self.model.eval)
[24b8d5c]416        #if self.curr_thread != None :
417        #    try:
418        #        self.curr_thread.isquit()
419        #    except:
420        #        raise FitAbort,"stop leastsqr optimizer"   
[48882d1]421        return self.res
422   
[4c718654]423class FitEngine:
[ee5b04c]424    def __init__(self):
[ca6d914]425        """
426            Base class for scipy and park fit engine
427        """
428        #List of parameter names to fit
[ee5b04c]429        self.paramList=[]
[ca6d914]430        #Dictionnary of fitArrange element (fit problems)
431        self.fitArrangeDict={}
432       
[4c718654]433    def _concatenateData(self, listdata=[]):
434        """ 
435            _concatenateData method concatenates each fields of all data contains ins listdata.
436            @param listdata: list of data
[ca6d914]437            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
438             of data concatenanted
[4c718654]439            @raise: if listdata is empty  will return None
440            @raise: if data in listdata don't contain dy field ,will create an error
441            during fitting
442        """
[109e60ab]443        #TODO: we have to refactor the way we handle data.
444        # We should move away from plottables and move towards the Data1D objects
445        # defined in DataLoader. Data1D allows data manipulations, which should be
446        # used to concatenate.
447        # In the meantime we should switch off the concatenation.
448        #if len(listdata)>1:
449        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
450        #return listdata[0]
451       
[4c718654]452        if listdata==[]:
453            raise ValueError, " data list missing"
454        else:
455            xtemp=[]
456            ytemp=[]
457            dytemp=[]
[48882d1]458            self.mini=None
459            self.maxi=None
[4c718654]460               
[7d0c1a8]461            for item in listdata:
462                data=item.data
[48882d1]463                mini,maxi=data.getFitRange()
464                if self.mini==None and self.maxi==None:
465                    self.mini=mini
466                    self.maxi=maxi
467                else:
468                    if mini < self.mini:
469                        self.mini=mini
470                    if self.maxi < maxi:
471                        self.maxi=maxi
472                       
473                   
[4c718654]474                for i in range(len(data.x)):
475                    xtemp.append(data.x[i])
476                    ytemp.append(data.y[i])
477                    if data.dy is not None and len(data.dy)==len(data.y):   
478                        dytemp.append(data.dy[i])
479                    else:
[ee5b04c]480                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
[20d30e9]481            data= Data(x=xtemp,y=ytemp,dy=dytemp)
[48882d1]482            data.setFitRange(self.mini, self.maxi)
483            return data
[ca6d914]484       
485       
486    def set_model(self,model,Uid,pars=[]):
487        """
488            set a model on a given uid in the fit engine.
489            @param model: the model to fit
490            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
491            @param pars: the list of parameters to fit
492            @note : pars must contains only name of existing model's paramaters
493        """
[f44dbc7]494        if len(pars) >0:
[6831a99]495            if model==None:
[f44dbc7]496                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
[6831a99]497            else:
[aed7c57]498                temp=[]
[ca6d914]499                for item in pars:
500                    if item in model.model.getParamList():
[aed7c57]501                        temp.append(item)
[ca6d914]502                        self.paramList.append(item)
503                    else:
504                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
505                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
506                        return
[6831a99]507            #A fitArrange is already created but contains dList only at Uid
[ca6d914]508            if self.fitArrangeDict.has_key(Uid):
509                self.fitArrangeDict[Uid].set_model(model)
[aed7c57]510                self.fitArrangeDict[Uid].pars= pars
[6831a99]511            else:
512            #no fitArrange object has been create with this Uid
[48882d1]513                fitproblem = FitArrange()
[6831a99]514                fitproblem.set_model(model)
[aed7c57]515                fitproblem.pars= pars
[ca6d914]516                self.fitArrangeDict[Uid] = fitproblem
[aed7c57]517               
[d4b0687]518        else:
[6831a99]519            raise ValueError, "park_integration:missing parameters"
[48882d1]520   
[20d30e9]521    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
[d4b0687]522        """ Receives plottable, creates a list of data to fit,set data
523            in a FitArrange object and adds that object in a dictionary
524            with key Uid.
525            @param data: data added
526            @param Uid: unique key corresponding to a fitArrange object with data
[ca6d914]527        """
[f2817bb]528        if data.__class__.__name__=='Data2D':
[f8ce013]529            fitdata=FitData2D(data)
530        else:
[b461b6d7]531            fitdata=FitData1D(data, smearer)
[20d30e9]532       
533        fitdata.setFitRange(qmin=qmin,qmax=qmax)
[d4b0687]534        #A fitArrange is already created but contains model only at Uid
[ca6d914]535        if self.fitArrangeDict.has_key(Uid):
[f8ce013]536            self.fitArrangeDict[Uid].add_data(fitdata)
[d4b0687]537        else:
538        #no fitArrange object has been create with this Uid
539            fitproblem= FitArrange()
[f8ce013]540            fitproblem.add_data(fitdata)
[ca6d914]541            self.fitArrangeDict[Uid]=fitproblem   
[20d30e9]542   
[d4b0687]543    def get_model(self,Uid):
544        """
545            @param Uid: Uid is key in the dictionary containing the model to return
546            @return  a model at this uid or None if no FitArrange element was created
547            with this Uid
548        """
[ca6d914]549        if self.fitArrangeDict.has_key(Uid):
550            return self.fitArrangeDict[Uid].get_model()
[d4b0687]551        else:
552            return None
553   
554    def remove_Fit_Problem(self,Uid):
555        """remove   fitarrange in Uid"""
[ca6d914]556        if self.fitArrangeDict.has_key(Uid):
557            del self.fitArrangeDict[Uid]
[a9e04aa]558           
559    def select_problem_for_fit(self,Uid,value):
560        """
561            select a couple of model and data at the Uid position in dictionary
562            and set in self.selected value to value
563            @param value: the value to allow fitting. can only have the value one or zero
564        """
565        if self.fitArrangeDict.has_key(Uid):
566             self.fitArrangeDict[Uid].set_to_fit( value)
[eef2e0ed]567             
568             
[a9e04aa]569    def get_problem_to_fit(self,Uid):
570        """
571            return the self.selected value of the fit problem of Uid
572           @param Uid: the Uid of the problem
573        """
574        if self.fitArrangeDict.has_key(Uid):
575             self.fitArrangeDict[Uid].get_to_fit()
[4c718654]576   
[d4b0687]577class FitArrange:
578    def __init__(self):
579        """
580            Class FitArrange contains a set of data for a given model
581            to perform the Fit.FitArrange must contain exactly one model
582            and at least one data for the fit to be performed.
583            model: the model selected by the user
584            Ldata: a list of data what the user wants to fit
585           
586        """
587        self.model = None
588        self.dList =[]
[aed7c57]589        self.pars=[]
[a9e04aa]590        #self.selected  is zero when this fit problem is not schedule to fit
591        #self.selected is 1 when schedule to fit
592        self.selected = 0
[d4b0687]593       
594    def set_model(self,model):
595        """
596            set_model save a copy of the model
597            @param model: the model being set
598        """
599        self.model = model
600       
601    def add_data(self,data):
602        """
603            add_data fill a self.dList with data to fit
604            @param data: Data to add in the list 
605        """
606        if not data in self.dList:
607            self.dList.append(data)
608           
609    def get_model(self):
610        """ @return: saved model """
611        return self.model   
612     
613    def get_data(self):
614        """ @return:  list of data dList"""
[7d0c1a8]615        #return self.dList
616        return self.dList[0] 
[d4b0687]617     
618    def remove_data(self,data):
619        """
620            Remove one element from the list
621            @param data: Data to remove from dList
622        """
623        if data in self.dList:
624            self.dList.remove(data)
[a9e04aa]625    def set_to_fit (self, value=0):
626        """
627           set self.selected to 0 or 1  for other values raise an exception
628           @param value: integer between 0 or 1
629        """
630        self.selected= value
631       
632    def get_to_fit(self):
633        """
634            @return self.selected value
635        """
636        return self.selected
[94b44293]637   
[4c718654]638
639
640   
Note: See TracBrowser for help on using the repository browser.