source: sasview/park_integration/AbstractFitEngine.py @ 3c939e5

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 3c939e5 was 6963aa3, checked in by Gervaise Alina <gervyh@…>, 15 years ago

remove stop option

  • Property mode set to 100644
File size: 20.7 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
211       
212        ## Min Q-value
213        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
214        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
218        ## Max Q-value
219        self.qmax= max (self.data.x)
220       
221       
222    def setFitRange(self,qmin=None,qmax=None):
223        """ to set the fit range"""
224       
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        #ToDo: Fix this.
227        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
228            self.qmin = min(self.data.x[self.data.x!=0])
229        elif qmin!=None:                       
230            self.qmin = qmin           
231
232        if qmax !=None:
233            self.qmax = qmax
234       
235       
236    def getFitRange(self):
237        """
238            @return the range of data.x to fit
239        """
240        return self.qmin, self.qmax
241     
242     
243    def residuals(self, fn):
244        """
245            Compute residuals.
246           
247            If self.smearer has been set, use if to smear
248            the data before computing chi squared.
249           
250            @param fn: function that return model value
251            @return residuals
252        """
253        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
254        if self.dy ==None or self.dy==[]:
255            dy= numpy.zeros(len(y)) 
256        else:
257            dy= copy.deepcopy(self.dy)
258            dy= numpy.asarray(dy)
259        dy[dy==0]=1
260        idx = (x>=self.qmin) & (x <= self.qmax)
261 
262        # Compute theory data f(x)
263        fx = numpy.zeros(len(x))
264        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
265       
266        # Smear theory data
267     
268        if self.smearer is not None:
269            fx = self.smearer(fx)
270       
271        # Sanity check
272        if numpy.size(dy) < numpy.size(x):
273            raise RuntimeError, "FitData1D: invalid error array"
274        return (y[idx] - fx[idx])/dy[idx]
275     
276 
277       
278    def residuals_deriv(self, model, pars=[]):
279        """
280            @return residuals derivatives .
281            @note: in this case just return empty array
282        """
283        return []
284   
285   
286class FitData2D(object):
287    """ Wrapper class  for SANS data """
288    def __init__(self,sans_data2d):
289        """
290            Data can be initital with a data (sans plottable)
291            or with vectors.
292        """
293        self.data=sans_data2d
294        self.image = sans_data2d.data
295        self.err_image = sans_data2d.err_data
296        self.x_bins= sans_data2d.x_bins
297        self.y_bins= sans_data2d.y_bins
298       
299        x = max(self.data.xmin, self.data.xmax)
300        y = max(self.data.ymin, self.data.ymax)
301       
302        ## fitting range
303        self.qmin = 1e-16
304        self.qmax = math.sqrt(x*x +y*y)
305       
306       
307       
308    def setFitRange(self,qmin=None,qmax=None):
309        """ to set the fit range"""
310        if qmin==0.0:
311            self.qmin = 1e-16
312        elif qmin!=None:                       
313            self.qmin = qmin           
314        if qmax!=None:
315            self.qmax= qmax
316     
317       
318    def getFitRange(self):
319        """
320            @return the range of data.x to fit
321        """
322        return self.qmin, self.qmax
323     
324     
325    def residuals(self, fn):
326        """ @param fn: function that return model value
327            @return residuals
328        """
329        res=[]
330        if self.err_image== None or self.err_image ==[]:
331            err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
332        else:
333            err_image = copy.deepcopy(self.err_image)
334           
335        err_image[err_image==0]=1
336        for i in range(len(self.x_bins)):
337            for j in range(len(self.y_bins)):
338                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
339                radius= math.sqrt(temp)
340                if self.qmin <= radius and radius <= self.qmax:
341                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
342                            /err_image[j][i] )
343       
344        return numpy.array(res)
345       
346         
347    def residuals_deriv(self, model, pars=[]):
348        """
349            @return residuals derivatives .
350            @note: in this case just return empty array
351        """
352        return []
353   
354class FitAbort(Exception):
355    """
356        Exception raise to stop the fit
357    """
358    print"Creating fit abort Exception"
359
360
361class sansAssembly:
362    """
363         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
364    """
365    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
366        """
367            @param Model: the model wrapper fro sans -model
368            @param Data: the data wrapper for sans data
369        """
370        self.model = Model
371        self.data  = Data
372        self.paramlist=paramlist
373        self.curr_thread= curr_thread
374        self.res=[]
375        self.func_name="Functor"
376    def chisq(self, params):
377        """
378            Calculates chi^2
379            @param params: list of parameter values
380            @return: chi^2
381        """
382        sum = 0
383        for item in self.res:
384            sum += item*item
385        if len(self.res)==0:
386            return None
387        return sum/ len(self.res)
388   
389    def __call__(self,params):
390        """
391            Compute residuals
392            @param params: value of parameters to fit
393        """
394        #import thread
395        self.model.setParams(self.paramlist,params)
396        self.res= self.data.residuals(self.model.eval)
397        #if self.curr_thread != None :
398        #    try:
399        #        self.curr_thread.isquit()
400        #    except:
401        #        raise FitAbort,"stop leastsqr optimizer"
402        return self.res
403   
404class FitEngine:
405    def __init__(self):
406        """
407            Base class for scipy and park fit engine
408        """
409        #List of parameter names to fit
410        self.paramList=[]
411        #Dictionnary of fitArrange element (fit problems)
412        self.fitArrangeDict={}
413       
414    def _concatenateData(self, listdata=[]):
415        """ 
416            _concatenateData method concatenates each fields of all data contains ins listdata.
417            @param listdata: list of data
418            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
419             of data concatenanted
420            @raise: if listdata is empty  will return None
421            @raise: if data in listdata don't contain dy field ,will create an error
422            during fitting
423        """
424        #TODO: we have to refactor the way we handle data.
425        # We should move away from plottables and move towards the Data1D objects
426        # defined in DataLoader. Data1D allows data manipulations, which should be
427        # used to concatenate.
428        # In the meantime we should switch off the concatenation.
429        #if len(listdata)>1:
430        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
431        #return listdata[0]
432       
433        if listdata==[]:
434            raise ValueError, " data list missing"
435        else:
436            xtemp=[]
437            ytemp=[]
438            dytemp=[]
439            self.mini=None
440            self.maxi=None
441               
442            for item in listdata:
443                data=item.data
444                mini,maxi=data.getFitRange()
445                if self.mini==None and self.maxi==None:
446                    self.mini=mini
447                    self.maxi=maxi
448                else:
449                    if mini < self.mini:
450                        self.mini=mini
451                    if self.maxi < maxi:
452                        self.maxi=maxi
453                       
454                   
455                for i in range(len(data.x)):
456                    xtemp.append(data.x[i])
457                    ytemp.append(data.y[i])
458                    if data.dy is not None and len(data.dy)==len(data.y):   
459                        dytemp.append(data.dy[i])
460                    else:
461                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
462            data= Data(x=xtemp,y=ytemp,dy=dytemp)
463            data.setFitRange(self.mini, self.maxi)
464            return data
465       
466       
467    def set_model(self,model,Uid,pars=[]):
468        """
469            set a model on a given uid in the fit engine.
470            @param model: the model to fit
471            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
472            @param pars: the list of parameters to fit
473            @note : pars must contains only name of existing model's paramaters
474        """
475        if len(pars) >0:
476            if model==None:
477                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
478            else:
479                temp=[]
480                for item in pars:
481                    if item in model.model.getParamList():
482                        temp.append(item)
483                        self.paramList.append(item)
484                    else:
485                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
486                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
487                        return
488            #A fitArrange is already created but contains dList only at Uid
489            if self.fitArrangeDict.has_key(Uid):
490                self.fitArrangeDict[Uid].set_model(model)
491                self.fitArrangeDict[Uid].pars= pars
492            else:
493            #no fitArrange object has been create with this Uid
494                fitproblem = FitArrange()
495                fitproblem.set_model(model)
496                fitproblem.pars= pars
497                self.fitArrangeDict[Uid] = fitproblem
498               
499        else:
500            raise ValueError, "park_integration:missing parameters"
501   
502    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
503        """ Receives plottable, creates a list of data to fit,set data
504            in a FitArrange object and adds that object in a dictionary
505            with key Uid.
506            @param data: data added
507            @param Uid: unique key corresponding to a fitArrange object with data
508        """
509        if data.__class__.__name__=='Data2D':
510            fitdata=FitData2D(data)
511        else:
512            fitdata=FitData1D(data, smearer)
513       
514        fitdata.setFitRange(qmin=qmin,qmax=qmax)
515        #A fitArrange is already created but contains model only at Uid
516        if self.fitArrangeDict.has_key(Uid):
517            self.fitArrangeDict[Uid].add_data(fitdata)
518        else:
519        #no fitArrange object has been create with this Uid
520            fitproblem= FitArrange()
521            fitproblem.add_data(fitdata)
522            self.fitArrangeDict[Uid]=fitproblem   
523   
524    def get_model(self,Uid):
525        """
526            @param Uid: Uid is key in the dictionary containing the model to return
527            @return  a model at this uid or None if no FitArrange element was created
528            with this Uid
529        """
530        if self.fitArrangeDict.has_key(Uid):
531            return self.fitArrangeDict[Uid].get_model()
532        else:
533            return None
534   
535    def remove_Fit_Problem(self,Uid):
536        """remove   fitarrange in Uid"""
537        if self.fitArrangeDict.has_key(Uid):
538            del self.fitArrangeDict[Uid]
539           
540    def select_problem_for_fit(self,Uid,value):
541        """
542            select a couple of model and data at the Uid position in dictionary
543            and set in self.selected value to value
544            @param value: the value to allow fitting. can only have the value one or zero
545        """
546        if self.fitArrangeDict.has_key(Uid):
547             self.fitArrangeDict[Uid].set_to_fit( value)
548             
549             
550    def get_problem_to_fit(self,Uid):
551        """
552            return the self.selected value of the fit problem of Uid
553           @param Uid: the Uid of the problem
554        """
555        if self.fitArrangeDict.has_key(Uid):
556             self.fitArrangeDict[Uid].get_to_fit()
557   
558class FitArrange:
559    def __init__(self):
560        """
561            Class FitArrange contains a set of data for a given model
562            to perform the Fit.FitArrange must contain exactly one model
563            and at least one data for the fit to be performed.
564            model: the model selected by the user
565            Ldata: a list of data what the user wants to fit
566           
567        """
568        self.model = None
569        self.dList =[]
570        self.pars=[]
571        #self.selected  is zero when this fit problem is not schedule to fit
572        #self.selected is 1 when schedule to fit
573        self.selected = 0
574       
575    def set_model(self,model):
576        """
577            set_model save a copy of the model
578            @param model: the model being set
579        """
580        self.model = model
581       
582    def add_data(self,data):
583        """
584            add_data fill a self.dList with data to fit
585            @param data: Data to add in the list 
586        """
587        if not data in self.dList:
588            self.dList.append(data)
589           
590    def get_model(self):
591        """ @return: saved model """
592        return self.model   
593     
594    def get_data(self):
595        """ @return:  list of data dList"""
596        #return self.dList
597        return self.dList[0] 
598     
599    def remove_data(self,data):
600        """
601            Remove one element from the list
602            @param data: Data to remove from dList
603        """
604        if data in self.dList:
605            self.dList.remove(data)
606    def set_to_fit (self, value=0):
607        """
608           set self.selected to 0 or 1  for other values raise an exception
609           @param value: integer between 0 or 1
610        """
611        self.selected= value
612       
613    def get_to_fit(self):
614        """
615            @return self.selected value
616        """
617        return self.selected
618   
619
620
621   
Note: See TracBrowser for help on using the repository browser.