source: sasview/park_integration/AbstractFitEngine.py @ 88b5e83

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 88b5e83 was 88b5e83, checked in by Gervaise Alina <gervyh@…>, 16 years ago

working on stop button

  • Property mode set to 100644
File size: 20.7 KB
Line 
1
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
211       
212        ## Min Q-value
213        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
214        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
218        ## Max Q-value
219        self.qmax= max (self.data.x)
220       
221       
222    def setFitRange(self,qmin=None,qmax=None):
223        """ to set the fit range"""
224       
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        #ToDo: Fix this.
227        if qmin==0.0:
228            self.qmin = min(self.data.x[self.data.x!=0])
229        elif qmin!=None:                       
230            self.qmin = qmin           
231
232        if qmax !=None:
233            self.qmax = qmax
234       
235       
236    def getFitRange(self):
237        """
238            @return the range of data.x to fit
239        """
240        return self.qmin, self.qmax
241     
242     
243    def residuals(self, fn):
244        """
245            Compute residuals.
246           
247            If self.smearer has been set, use if to smear
248            the data before computing chi squared.
249           
250            @param fn: function that return model value
251            @return residuals
252        """
253        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
254        if self.dy ==None or self.dy==[]:
255            dy= numpy.zeros(len(y)) 
256        else:
257            dy= copy.deepcopy(self.dy)
258            dy= numpy.asarray(dy)
259        dy[dy==0]=1
260        idx = (x>=self.qmin) & (x <= self.qmax)
261 
262        # Compute theory data f(x)
263        fx = numpy.zeros(len(x))
264        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
265       
266        # Smear theory data
267     
268        if self.smearer is not None:
269            fx = self.smearer(fx)
270       
271        # Sanity check
272        if numpy.size(dy) < numpy.size(x):
273            raise RuntimeError, "FitData1D: invalid error array"
274        return (y[idx] - fx[idx])/dy[idx]
275     
276 
277       
278    def residuals_deriv(self, model, pars=[]):
279        """
280            @return residuals derivatives .
281            @note: in this case just return empty array
282        """
283        return []
284   
285   
286class FitData2D(object):
287    """ Wrapper class  for SANS data """
288    def __init__(self,sans_data2d):
289        """
290            Data can be initital with a data (sans plottable)
291            or with vectors.
292        """
293        self.data=sans_data2d
294        self.image = sans_data2d.data
295        self.err_image = sans_data2d.err_data
296        self.x_bins= sans_data2d.x_bins
297        self.y_bins= sans_data2d.y_bins
298       
299        x = max(self.data.xmin, self.data.xmax)
300        y = max(self.data.ymin, self.data.ymax)
301       
302        ## fitting range
303        self.qmin = 1e-16
304        self.qmax = math.sqrt(x*x +y*y)
305       
306       
307       
308    def setFitRange(self,qmin=None,qmax=None):
309        """ to set the fit range"""
310        if qmin==0.0:
311            self.qmin = 1e-16
312        elif qmin!=None:                       
313            self.qmin = qmin           
314        if qmax!=None:
315            self.qmax= qmax
316     
317       
318    def getFitRange(self):
319        """
320            @return the range of data.x to fit
321        """
322        return self.qmin, self.qmax
323     
324     
325    def residuals(self, fn):
326        """ @param fn: function that return model value
327            @return residuals
328        """
329        res=[]
330        if self.err_image== None or self.err_image ==[]:
331            err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
332        else:
333            err_image = copy.deepcopy(self.err_image)
334           
335        err_image[err_image==0]=1
336        for i in range(len(self.x_bins)):
337            for j in range(len(self.y_bins)):
338                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
339                radius= math.sqrt(temp)
340                if self.qmin <= radius and radius <= self.qmax:
341                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
342                            /err_image[j][i] )
343       
344        return numpy.array(res)
345       
346         
347    def residuals_deriv(self, model, pars=[]):
348        """
349            @return residuals derivatives .
350            @note: in this case just return empty array
351        """
352        return []
353   
354class FitAbort(Exception):
355    """
356        Exception raise to stop the fit
357    """
358    print"Creating fit abort Exception"
359
360
361class sansAssembly:
362    """
363         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
364    """
365    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
366        """
367            @param Model: the model wrapper fro sans -model
368            @param Data: the data wrapper for sans data
369        """
370        self.model = Model
371        self.data  = Data
372        self.paramlist=paramlist
373        self.curr_thread= curr_thread
374        self.res=[]
375        self.func_name="Functor"
376    def chisq(self, params):
377        """
378            Calculates chi^2
379            @param params: list of parameter values
380            @return: chi^2
381        """
382        sum = 0
383        for item in self.res:
384            sum += item*item
385        if len(self.res)==0:
386            return None
387        return sum/ len(self.res)
388   
389    def __call__(self,params):
390        """
391            Compute residuals
392            @param params: value of parameters to fit
393        """
394        #import thread
395        self.model.setParams(self.paramlist,params)
396        self.res= self.data.residuals(self.model.eval)
397        if self.curr_thread != None :
398            try:
399                self.curr_thread.isquit()
400            except:
401                raise FitAbort,"stop leastsqr optimizer"
402               
403        return self.res
404   
405class FitEngine:
406    def __init__(self):
407        """
408            Base class for scipy and park fit engine
409        """
410        #List of parameter names to fit
411        self.paramList=[]
412        #Dictionnary of fitArrange element (fit problems)
413        self.fitArrangeDict={}
414       
415    def _concatenateData(self, listdata=[]):
416        """ 
417            _concatenateData method concatenates each fields of all data contains ins listdata.
418            @param listdata: list of data
419            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
420             of data concatenanted
421            @raise: if listdata is empty  will return None
422            @raise: if data in listdata don't contain dy field ,will create an error
423            during fitting
424        """
425        #TODO: we have to refactor the way we handle data.
426        # We should move away from plottables and move towards the Data1D objects
427        # defined in DataLoader. Data1D allows data manipulations, which should be
428        # used to concatenate.
429        # In the meantime we should switch off the concatenation.
430        #if len(listdata)>1:
431        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
432        #return listdata[0]
433       
434        if listdata==[]:
435            raise ValueError, " data list missing"
436        else:
437            xtemp=[]
438            ytemp=[]
439            dytemp=[]
440            self.mini=None
441            self.maxi=None
442               
443            for item in listdata:
444                data=item.data
445                mini,maxi=data.getFitRange()
446                if self.mini==None and self.maxi==None:
447                    self.mini=mini
448                    self.maxi=maxi
449                else:
450                    if mini < self.mini:
451                        self.mini=mini
452                    if self.maxi < maxi:
453                        self.maxi=maxi
454                       
455                   
456                for i in range(len(data.x)):
457                    xtemp.append(data.x[i])
458                    ytemp.append(data.y[i])
459                    if data.dy is not None and len(data.dy)==len(data.y):   
460                        dytemp.append(data.dy[i])
461                    else:
462                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
463            data= Data(x=xtemp,y=ytemp,dy=dytemp)
464            data.setFitRange(self.mini, self.maxi)
465            return data
466       
467       
468    def set_model(self,model,Uid,pars=[]):
469        """
470            set a model on a given uid in the fit engine.
471            @param model: the model to fit
472            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
473            @param pars: the list of parameters to fit
474            @note : pars must contains only name of existing model's paramaters
475        """
476        if len(pars) >0:
477            if model==None:
478                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
479            else:
480                temp=[]
481                for item in pars:
482                    if item in model.model.getParamList():
483                        temp.append(item)
484                        self.paramList.append(item)
485                    else:
486                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
487                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
488                        return
489            #A fitArrange is already created but contains dList only at Uid
490            if self.fitArrangeDict.has_key(Uid):
491                self.fitArrangeDict[Uid].set_model(model)
492                self.fitArrangeDict[Uid].pars= pars
493            else:
494            #no fitArrange object has been create with this Uid
495                fitproblem = FitArrange()
496                fitproblem.set_model(model)
497                fitproblem.pars= pars
498                self.fitArrangeDict[Uid] = fitproblem
499               
500        else:
501            raise ValueError, "park_integration:missing parameters"
502   
503    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
504        """ Receives plottable, creates a list of data to fit,set data
505            in a FitArrange object and adds that object in a dictionary
506            with key Uid.
507            @param data: data added
508            @param Uid: unique key corresponding to a fitArrange object with data
509        """
510        if data.__class__.__name__=='Data2D':
511            fitdata=FitData2D(data)
512        else:
513            fitdata=FitData1D(data, smearer)
514       
515        fitdata.setFitRange(qmin=qmin,qmax=qmax)
516        #A fitArrange is already created but contains model only at Uid
517        if self.fitArrangeDict.has_key(Uid):
518            self.fitArrangeDict[Uid].add_data(fitdata)
519        else:
520        #no fitArrange object has been create with this Uid
521            fitproblem= FitArrange()
522            fitproblem.add_data(fitdata)
523            self.fitArrangeDict[Uid]=fitproblem   
524   
525    def get_model(self,Uid):
526        """
527            @param Uid: Uid is key in the dictionary containing the model to return
528            @return  a model at this uid or None if no FitArrange element was created
529            with this Uid
530        """
531        if self.fitArrangeDict.has_key(Uid):
532            return self.fitArrangeDict[Uid].get_model()
533        else:
534            return None
535   
536    def remove_Fit_Problem(self,Uid):
537        """remove   fitarrange in Uid"""
538        if self.fitArrangeDict.has_key(Uid):
539            del self.fitArrangeDict[Uid]
540           
541    def select_problem_for_fit(self,Uid,value):
542        """
543            select a couple of model and data at the Uid position in dictionary
544            and set in self.selected value to value
545            @param value: the value to allow fitting. can only have the value one or zero
546        """
547        if self.fitArrangeDict.has_key(Uid):
548             self.fitArrangeDict[Uid].set_to_fit( value)
549             
550             
551    def get_problem_to_fit(self,Uid):
552        """
553            return the self.selected value of the fit problem of Uid
554           @param Uid: the Uid of the problem
555        """
556        if self.fitArrangeDict.has_key(Uid):
557             self.fitArrangeDict[Uid].get_to_fit()
558   
559class FitArrange:
560    def __init__(self):
561        """
562            Class FitArrange contains a set of data for a given model
563            to perform the Fit.FitArrange must contain exactly one model
564            and at least one data for the fit to be performed.
565            model: the model selected by the user
566            Ldata: a list of data what the user wants to fit
567           
568        """
569        self.model = None
570        self.dList =[]
571        self.pars=[]
572        #self.selected  is zero when this fit problem is not schedule to fit
573        #self.selected is 1 when schedule to fit
574        self.selected = 0
575       
576    def set_model(self,model):
577        """
578            set_model save a copy of the model
579            @param model: the model being set
580        """
581        self.model = model
582       
583    def add_data(self,data):
584        """
585            add_data fill a self.dList with data to fit
586            @param data: Data to add in the list 
587        """
588        if not data in self.dList:
589            self.dList.append(data)
590           
591    def get_model(self):
592        """ @return: saved model """
593        return self.model   
594     
595    def get_data(self):
596        """ @return:  list of data dList"""
597        #return self.dList
598        return self.dList[0] 
599     
600    def remove_data(self,data):
601        """
602            Remove one element from the list
603            @param data: Data to remove from dList
604        """
605        if data in self.dList:
606            self.dList.remove(data)
607    def set_to_fit (self, value=0):
608        """
609           set self.selected to 0 or 1  for other values raise an exception
610           @param value: integer between 0 or 1
611        """
612        self.selected= value
613       
614    def get_to_fit(self):
615        """
616            @return self.selected value
617        """
618        return self.selected
619   
620
621
622   
Note: See TracBrowser for help on using the repository browser.