source: sasview/park_integration/AbstractFitEngine.py @ 72a90bd

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 72a90bd was 24b8d5c, checked in by Gervaise Alina <gervyh@…>, 16 years ago

remove stop fit option

  • Property mode set to 100644
File size: 21.3 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
211       
212        ## Min Q-value
213        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
214        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
218        ## Max Q-value
219        self.qmax= max (self.data.x)
220       
221       
222    def setFitRange(self,qmin=None,qmax=None):
223        """ to set the fit range"""
224       
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        #ToDo: Fix this.
227        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
228            self.qmin = min(self.data.x[self.data.x!=0])
229        elif qmin!=None:                       
230            self.qmin = qmin           
231
232        if qmax !=None:
233            self.qmax = qmax
234       
235       
236    def getFitRange(self):
237        """
238            @return the range of data.x to fit
239        """
240        return self.qmin, self.qmax
241     
242     
243    def residuals(self, fn):
244        """
245            Compute residuals.
246           
247            If self.smearer has been set, use if to smear
248            the data before computing chi squared.
249           
250            @param fn: function that return model value
251            @return residuals
252        """
253        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
254        if self.dy ==None or self.dy==[]:
255            dy= numpy.zeros(len(y)) 
256        else:
257            dy= copy.deepcopy(self.dy)
258            dy= numpy.asarray(dy)
259     
260        dy[dy==0]=1
261        #idx = (x>=self.qmin) & (x <= self.qmax)
262 
263        # Compute theory data f(x)
264        #fx = numpy.zeros(len(x))
265        tempy=[]
266        tempfx=[]
267        tempdy=[]
268        #fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
269        for i_x in  range(len(x)):
270            try:
271                if self.qmin <=x[i_x] and x[i_x]<=self.qmax:
272                    value= fn(x[i_x])
273                    tempfx.append( value)
274                    tempy.append(y[i_x])
275                    tempdy.append(dy[i_x])
276            except:
277                ## skip error for model.run(x)
278                pass
279                 
280        ## Smear theory data
281        if self.smearer is not None:
282            tempfx = self.smearer(tempfx)
283        newy= numpy.asarray(tempy)
284        newfx= numpy.asarray(tempfx)
285        newdy= numpy.asarray(tempdy)
286       
287       
288        ## Sanity check
289        if numpy.size(newdy)!= numpy.size(newfx):
290            raise RuntimeError, "FitData1D: invalid error array"
291        #return (y[idx] - fx[idx])/dy[idx]
292       
293        return (newy- newfx)/newdy
294     
295 
296       
297    def residuals_deriv(self, model, pars=[]):
298        """
299            @return residuals derivatives .
300            @note: in this case just return empty array
301        """
302        return []
303   
304   
305class FitData2D(object):
306    """ Wrapper class  for SANS data """
307    def __init__(self,sans_data2d):
308        """
309            Data can be initital with a data (sans plottable)
310            or with vectors.
311        """
312        self.data=sans_data2d
313        self.image = sans_data2d.data
314        self.err_image = sans_data2d.err_data
315        self.x_bins= sans_data2d.x_bins
316        self.y_bins= sans_data2d.y_bins
317       
318        x = max(self.data.xmin, self.data.xmax)
319        y = max(self.data.ymin, self.data.ymax)
320       
321        ## fitting range
322        self.qmin = 1e-16
323        self.qmax = math.sqrt(x*x +y*y)
324       
325       
326       
327    def setFitRange(self,qmin=None,qmax=None):
328        """ to set the fit range"""
329        if qmin==0.0:
330            self.qmin = 1e-16
331        elif qmin!=None:                       
332            self.qmin = qmin           
333        if qmax!=None:
334            self.qmax= qmax
335     
336       
337    def getFitRange(self):
338        """
339            @return the range of data.x to fit
340        """
341        return self.qmin, self.qmax
342     
343     
344    def residuals(self, fn):
345        """ @param fn: function that return model value
346            @return residuals
347        """
348        res=[]
349        if self.err_image== None or self.err_image ==[]:
350            err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
351        else:
352            err_image = copy.deepcopy(self.err_image)
353           
354        err_image[err_image==0]=1
355        for i in range(len(self.x_bins)):
356            for j in range(len(self.y_bins)):
357                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
358                radius= math.sqrt(temp)
359                if self.qmin <= radius and radius <= self.qmax:
360                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
361                            /err_image[j][i] )
362       
363        return numpy.array(res)
364       
365         
366    def residuals_deriv(self, model, pars=[]):
367        """
368            @return residuals derivatives .
369            @note: in this case just return empty array
370        """
371        return []
372   
373class FitAbort(Exception):
374    """
375        Exception raise to stop the fit
376    """
377    print"Creating fit abort Exception"
378
379
380class sansAssembly:
381    """
382         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
383    """
384    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
385        """
386            @param Model: the model wrapper fro sans -model
387            @param Data: the data wrapper for sans data
388        """
389        self.model = Model
390        self.data  = Data
391        self.paramlist=paramlist
392        self.curr_thread= curr_thread
393        self.res=[]
394        self.func_name="Functor"
395    def chisq(self, params):
396        """
397            Calculates chi^2
398            @param params: list of parameter values
399            @return: chi^2
400        """
401        sum = 0
402        for item in self.res:
403            sum += item*item
404        if len(self.res)==0:
405            return None
406        return sum/ len(self.res)
407   
408    def __call__(self,params):
409        """
410            Compute residuals
411            @param params: value of parameters to fit
412        """
413        #import thread
414        self.model.setParams(self.paramlist,params)
415        self.res= self.data.residuals(self.model.eval)
416        #if self.curr_thread != None :
417        #    try:
418        #        self.curr_thread.isquit()
419        #    except:
420        #        raise FitAbort,"stop leastsqr optimizer"   
421        return self.res
422   
423class FitEngine:
424    def __init__(self):
425        """
426            Base class for scipy and park fit engine
427        """
428        #List of parameter names to fit
429        self.paramList=[]
430        #Dictionnary of fitArrange element (fit problems)
431        self.fitArrangeDict={}
432       
433    def _concatenateData(self, listdata=[]):
434        """ 
435            _concatenateData method concatenates each fields of all data contains ins listdata.
436            @param listdata: list of data
437            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
438             of data concatenanted
439            @raise: if listdata is empty  will return None
440            @raise: if data in listdata don't contain dy field ,will create an error
441            during fitting
442        """
443        #TODO: we have to refactor the way we handle data.
444        # We should move away from plottables and move towards the Data1D objects
445        # defined in DataLoader. Data1D allows data manipulations, which should be
446        # used to concatenate.
447        # In the meantime we should switch off the concatenation.
448        #if len(listdata)>1:
449        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
450        #return listdata[0]
451       
452        if listdata==[]:
453            raise ValueError, " data list missing"
454        else:
455            xtemp=[]
456            ytemp=[]
457            dytemp=[]
458            self.mini=None
459            self.maxi=None
460               
461            for item in listdata:
462                data=item.data
463                mini,maxi=data.getFitRange()
464                if self.mini==None and self.maxi==None:
465                    self.mini=mini
466                    self.maxi=maxi
467                else:
468                    if mini < self.mini:
469                        self.mini=mini
470                    if self.maxi < maxi:
471                        self.maxi=maxi
472                       
473                   
474                for i in range(len(data.x)):
475                    xtemp.append(data.x[i])
476                    ytemp.append(data.y[i])
477                    if data.dy is not None and len(data.dy)==len(data.y):   
478                        dytemp.append(data.dy[i])
479                    else:
480                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
481            data= Data(x=xtemp,y=ytemp,dy=dytemp)
482            data.setFitRange(self.mini, self.maxi)
483            return data
484       
485       
486    def set_model(self,model,Uid,pars=[]):
487        """
488            set a model on a given uid in the fit engine.
489            @param model: the model to fit
490            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
491            @param pars: the list of parameters to fit
492            @note : pars must contains only name of existing model's paramaters
493        """
494        if len(pars) >0:
495            if model==None:
496                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
497            else:
498                temp=[]
499                for item in pars:
500                    if item in model.model.getParamList():
501                        temp.append(item)
502                        self.paramList.append(item)
503                    else:
504                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
505                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
506                        return
507            #A fitArrange is already created but contains dList only at Uid
508            if self.fitArrangeDict.has_key(Uid):
509                self.fitArrangeDict[Uid].set_model(model)
510                self.fitArrangeDict[Uid].pars= pars
511            else:
512            #no fitArrange object has been create with this Uid
513                fitproblem = FitArrange()
514                fitproblem.set_model(model)
515                fitproblem.pars= pars
516                self.fitArrangeDict[Uid] = fitproblem
517               
518        else:
519            raise ValueError, "park_integration:missing parameters"
520   
521    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
522        """ Receives plottable, creates a list of data to fit,set data
523            in a FitArrange object and adds that object in a dictionary
524            with key Uid.
525            @param data: data added
526            @param Uid: unique key corresponding to a fitArrange object with data
527        """
528        if data.__class__.__name__=='Data2D':
529            fitdata=FitData2D(data)
530        else:
531            fitdata=FitData1D(data, smearer)
532       
533        fitdata.setFitRange(qmin=qmin,qmax=qmax)
534        #A fitArrange is already created but contains model only at Uid
535        if self.fitArrangeDict.has_key(Uid):
536            self.fitArrangeDict[Uid].add_data(fitdata)
537        else:
538        #no fitArrange object has been create with this Uid
539            fitproblem= FitArrange()
540            fitproblem.add_data(fitdata)
541            self.fitArrangeDict[Uid]=fitproblem   
542   
543    def get_model(self,Uid):
544        """
545            @param Uid: Uid is key in the dictionary containing the model to return
546            @return  a model at this uid or None if no FitArrange element was created
547            with this Uid
548        """
549        if self.fitArrangeDict.has_key(Uid):
550            return self.fitArrangeDict[Uid].get_model()
551        else:
552            return None
553   
554    def remove_Fit_Problem(self,Uid):
555        """remove   fitarrange in Uid"""
556        if self.fitArrangeDict.has_key(Uid):
557            del self.fitArrangeDict[Uid]
558           
559    def select_problem_for_fit(self,Uid,value):
560        """
561            select a couple of model and data at the Uid position in dictionary
562            and set in self.selected value to value
563            @param value: the value to allow fitting. can only have the value one or zero
564        """
565        if self.fitArrangeDict.has_key(Uid):
566             self.fitArrangeDict[Uid].set_to_fit( value)
567             
568             
569    def get_problem_to_fit(self,Uid):
570        """
571            return the self.selected value of the fit problem of Uid
572           @param Uid: the Uid of the problem
573        """
574        if self.fitArrangeDict.has_key(Uid):
575             self.fitArrangeDict[Uid].get_to_fit()
576   
577class FitArrange:
578    def __init__(self):
579        """
580            Class FitArrange contains a set of data for a given model
581            to perform the Fit.FitArrange must contain exactly one model
582            and at least one data for the fit to be performed.
583            model: the model selected by the user
584            Ldata: a list of data what the user wants to fit
585           
586        """
587        self.model = None
588        self.dList =[]
589        self.pars=[]
590        #self.selected  is zero when this fit problem is not schedule to fit
591        #self.selected is 1 when schedule to fit
592        self.selected = 0
593       
594    def set_model(self,model):
595        """
596            set_model save a copy of the model
597            @param model: the model being set
598        """
599        self.model = model
600       
601    def add_data(self,data):
602        """
603            add_data fill a self.dList with data to fit
604            @param data: Data to add in the list 
605        """
606        if not data in self.dList:
607            self.dList.append(data)
608           
609    def get_model(self):
610        """ @return: saved model """
611        return self.model   
612     
613    def get_data(self):
614        """ @return:  list of data dList"""
615        #return self.dList
616        return self.dList[0] 
617     
618    def remove_data(self,data):
619        """
620            Remove one element from the list
621            @param data: Data to remove from dList
622        """
623        if data in self.dList:
624            self.dList.remove(data)
625    def set_to_fit (self, value=0):
626        """
627           set self.selected to 0 or 1  for other values raise an exception
628           @param value: integer between 0 or 1
629        """
630        self.selected= value
631       
632    def get_to_fit(self):
633        """
634            @return self.selected value
635        """
636        return self.selected
637   
638
639
640   
Note: See TracBrowser for help on using the repository browser.