source: sasview/park_integration/AbstractFitEngine.py @ 1ae3fe1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 1ae3fe1 was aed7c57, checked in by Gervaise Alina <gervyh@…>, 15 years ago

change parameter status on parkFit park_integration

  • Property mode set to 100644
File size: 19.5 KB
Line 
1
2import park,numpy,math
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        return lo,hi
48   
49    def _setrange(self,r):
50        """
51            override _setrange of park parameter
52            @param r: the value of the range to set
53        """
54        self._model.details[self.name][1:] = r
55    range = property(_getrange,_setrange)
56   
57class Model(park.Model):
58    """
59        PARK wrapper for SANS models.
60    """
61    def __init__(self, sans_model, **kw):
62        """
63            @param sans_model: the sans model to wrap using park interface
64        """
65        park.Model.__init__(self, **kw)
66        self.model = sans_model
67        self.name = sans_model.name
68        #list of parameters names
69        self.sansp = sans_model.getParamList()
70        #list of park parameter
71        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
72        #list of parameterset
73        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
74        self.pars=[]
75 
76 
77    def getParams(self,fitparams):
78        """
79            return a list of value of paramter to fit
80            @param fitparams: list of paramaters name to fit
81        """
82        list=[]
83        self.pars=[]
84        self.pars=fitparams
85        for item in fitparams:
86            for element in self.parkp:
87                 if element.name ==str(item):
88                     list.append(element.value)
89        return list
90   
91   
92    def setParams(self,paramlist, params):
93        """
94            Set value for parameters to fit
95            @param params: list of value for parameters to fit
96        """
97        try:
98            for i in range(len(self.parkp)):
99                for j in range(len(paramlist)):
100                    if self.parkp[i].name==paramlist[j]:
101                        self.parkp[i].value = params[j]
102                        self.model.setParam(self.parkp[i].name,params[j])
103        except:
104            raise
105 
106    def eval(self,x):
107        """
108            override eval method of park model.
109            @param x: the x value used to compute a function
110        """
111        return self.model.runXY(x)
112   
113   
114
115
116class Data(object):
117    """ Wrapper class  for SANS data """
118    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
119        """
120            Data can be initital with a data (sans plottable)
121            or with vectors.
122        """
123        if  sans_data !=None:
124            self.x= sans_data.x
125            self.y= sans_data.y
126            self.dx= sans_data.dx
127            self.dy= sans_data.dy
128           
129        elif (x!=None and y!=None and dy!=None):
130                self.x=x
131                self.y=y
132                self.dx=dx
133                self.dy=dy
134        else:
135            raise ValueError,\
136            "Data is missing x, y or dy, impossible to compute residuals later on"
137        self.qmin=None
138        self.qmax=None
139       
140       
141    def setFitRange(self,mini=None,maxi=None):
142        """ to set the fit range"""
143        self.qmin=mini
144        self.qmax=maxi
145       
146       
147    def getFitRange(self):
148        """
149            @return the range of data.x to fit
150        """
151        return self.qmin, self.qmax
152     
153     
154    def residuals(self, fn):
155        """ @param fn: function that return model value
156            @return residuals
157        """
158        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
159        if self.qmin==None and self.qmax==None: 
160            fx =numpy.asarray([fn(v) for v in x])
161            return (y - fx)/dy
162        else:
163            idx = (x>=self.qmin) & (x <= self.qmax)
164            fx = numpy.asarray([fn(item)for item in x[idx ]])
165            return (y[idx] - fx)/dy[idx]
166       
167    def residuals_deriv(self, model, pars=[]):
168        """
169            @return residuals derivatives .
170            @note: in this case just return empty array
171        """
172        return []
173class FitData1D(object):
174    """ Wrapper class  for SANS data """
175    def __init__(self,sans_data1d, smearer=None):
176        """
177            Data can be initital with a data (sans plottable)
178            or with vectors.
179           
180            self.smearer is an object of class QSmearer or SlitSmearer
181            that will smear the theory data (slit smearing or resolution
182            smearing) when set.
183           
184            The proper way to set the smearing object would be to
185            do the following:
186           
187            from DataLoader.qsmearing import smear_selection
188            fitdata1d = FitData1D(some_data)
189            fitdata1d.smearer = smear_selection(some_data)
190           
191            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
192           
193            Setting it back to None will turn smearing off.
194           
195        """
196       
197        self.smearer = smearer
198     
199        # Initialize from Data1D object
200        self.data=sans_data1d
201        self.x= sans_data1d.x
202        self.y= sans_data1d.y
203        self.dx= sans_data1d.dx
204        self.dy= sans_data1d.dy
205       
206        ## Min Q-value
207        self.qmin= min (self.data.x)
208        ## Max Q-value
209        self.qmax= max (self.data.x)
210       
211       
212    def setFitRange(self,qmin=None,qmax=None):
213        """ to set the fit range"""
214        self.qmin = qmin
215        self.qmax = qmax
216       
217       
218    def getFitRange(self):
219        """
220            @return the range of data.x to fit
221        """
222        return self.qmin, self.qmax
223     
224     
225    def residuals(self, fn):
226        """
227            Compute residuals.
228           
229            If self.smearer has been set, use if to smear
230            the data before computing chi squared.
231           
232            @param fn: function that return model value
233            @return residuals
234        """
235        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
236           
237        # Find entries to consider
238        if self.qmin==None:
239            self.qmin= min(self.data.x)
240        if  self.qmax==None:
241            self.qmin= max(self.data.x)
242       
243        idx = (x>=self.qmin) & (x <= self.qmax)
244 
245        # Compute theory data f(x)
246        fx = numpy.zeros(len(x))
247        fx[idx] = numpy.asarray([fn(v) for v in x[idx]])
248       
249        # Smear theory data
250     
251        if self.smearer is not None:
252            fx = self.smearer(fx)
253       
254        # Sanity check
255        if numpy.size(dy) < numpy.size(x):
256            raise RuntimeError, "FitData1D: invalid error array"
257                           
258        return (y[idx] - fx[idx])/dy[idx]
259     
260 
261       
262    def residuals_deriv(self, model, pars=[]):
263        """
264            @return residuals derivatives .
265            @note: in this case just return empty array
266        """
267        return []
268   
269   
270class FitData2D(object):
271    """ Wrapper class  for SANS data """
272    def __init__(self,sans_data2d):
273        """
274            Data can be initital with a data (sans plottable)
275            or with vectors.
276        """
277        self.data=sans_data2d
278        self.image = sans_data2d.data
279        self.err_image = sans_data2d.err_data
280        self.x_bins= sans_data2d.x_bins
281        self.y_bins= sans_data2d.y_bins
282       
283        x = max(self.data.xmin, self.data.xmax)
284        y = max(self.data.ymin, self.data.ymax)
285       
286        ## fitting range
287        self.qmin = 0
288        self.qmax = math.sqrt(x*x +y*y)
289       
290       
291       
292    def setFitRange(self,qmin=None,qmax=None):
293        """ to set the fit range"""
294        self.qmin= qmin
295        self.qmax= qmax
296     
297       
298    def getFitRange(self):
299        """
300            @return the range of data.x to fit
301        """
302        return self.qmin, self.qmax
303     
304     
305    def residuals(self, fn):
306        """ @param fn: function that return model value
307            @return residuals
308        """
309        res=[]
310        #Here we define that  qmin >=0 and qmax>=qmain
311        x = max(self.data.xmin, self.data.xmax)
312        y = max(self.data.ymin, self.data.ymax)
313       
314        ## fitting range
315        if self.qmin == None:
316            self.qmin = 0
317        if self.qmax == None:
318            self.qmax = math.sqrt(x*x +y*y)
319       
320       
321        for i in range(len(self.y_bins)):
322            for j in range(len(self.x_bins)):
323                radius = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
324                if self.qmin <= radius and radius <= self.qmax:
325                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
326                            /self.err_image[j][i] )
327       
328        return numpy.array(res)
329       
330         
331    def residuals_deriv(self, model, pars=[]):
332        """
333            @return residuals derivatives .
334            @note: in this case just return empty array
335        """
336        return []
337   
338class sansAssembly:
339    """
340         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
341    """
342    def __init__(self,paramlist,Model=None , Data=None):
343        """
344            @param Model: the model wrapper fro sans -model
345            @param Data: the data wrapper for sans data
346        """
347        self.model = Model
348        self.data  = Data
349        self.paramlist=paramlist
350        self.res=[]
351    def chisq(self, params):
352        """
353            Calculates chi^2
354            @param params: list of parameter values
355            @return: chi^2
356        """
357        sum = 0
358        for item in self.res:
359            sum += item*item
360       
361        return sum/ len(self.res)
362   
363    def __call__(self,params):
364        """
365            Compute residuals
366            @param params: value of parameters to fit
367        """
368        #import thread
369        self.model.setParams(self.paramlist,params)
370        self.res= self.data.residuals(self.model.eval)
371        #print "residuals",thread.get_ident()
372        return self.res
373   
374class FitEngine:
375    def __init__(self):
376        """
377            Base class for scipy and park fit engine
378        """
379        #List of parameter names to fit
380        self.paramList=[]
381        #Dictionnary of fitArrange element (fit problems)
382        self.fitArrangeDict={}
383       
384    def _concatenateData(self, listdata=[]):
385        """ 
386            _concatenateData method concatenates each fields of all data contains ins listdata.
387            @param listdata: list of data
388            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
389             of data concatenanted
390            @raise: if listdata is empty  will return None
391            @raise: if data in listdata don't contain dy field ,will create an error
392            during fitting
393        """
394        #TODO: we have to refactor the way we handle data.
395        # We should move away from plottables and move towards the Data1D objects
396        # defined in DataLoader. Data1D allows data manipulations, which should be
397        # used to concatenate.
398        # In the meantime we should switch off the concatenation.
399        #if len(listdata)>1:
400        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
401        #return listdata[0]
402       
403        if listdata==[]:
404            raise ValueError, " data list missing"
405        else:
406            xtemp=[]
407            ytemp=[]
408            dytemp=[]
409            self.mini=None
410            self.maxi=None
411               
412            for item in listdata:
413                data=item.data
414                mini,maxi=data.getFitRange()
415                if self.mini==None and self.maxi==None:
416                    self.mini=mini
417                    self.maxi=maxi
418                else:
419                    if mini < self.mini:
420                        self.mini=mini
421                    if self.maxi < maxi:
422                        self.maxi=maxi
423                       
424                   
425                for i in range(len(data.x)):
426                    xtemp.append(data.x[i])
427                    ytemp.append(data.y[i])
428                    if data.dy is not None and len(data.dy)==len(data.y):   
429                        dytemp.append(data.dy[i])
430                    else:
431                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
432            data= Data(x=xtemp,y=ytemp,dy=dytemp)
433            data.setFitRange(self.mini, self.maxi)
434            return data
435       
436       
437    def set_model(self,model,Uid,pars=[]):
438        """
439            set a model on a given uid in the fit engine.
440            @param model: the model to fit
441            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
442            @param pars: the list of parameters to fit
443            @note : pars must contains only name of existing model's paramaters
444        """
445        if len(pars) >0:
446            if model==None:
447                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
448            else:
449                temp=[]
450                for item in pars:
451                    if item in model.model.getParamList():
452                        temp.append(item)
453                        self.paramList.append(item)
454                    else:
455                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
456                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
457                        return
458            #A fitArrange is already created but contains dList only at Uid
459            if self.fitArrangeDict.has_key(Uid):
460                self.fitArrangeDict[Uid].set_model(model)
461                self.fitArrangeDict[Uid].pars= pars
462            else:
463            #no fitArrange object has been create with this Uid
464                fitproblem = FitArrange()
465                fitproblem.set_model(model)
466                fitproblem.pars= pars
467                self.fitArrangeDict[Uid] = fitproblem
468               
469        else:
470            raise ValueError, "park_integration:missing parameters"
471   
472    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
473        """ Receives plottable, creates a list of data to fit,set data
474            in a FitArrange object and adds that object in a dictionary
475            with key Uid.
476            @param data: data added
477            @param Uid: unique key corresponding to a fitArrange object with data
478        """
479        if data.__class__.__name__=='Data2D':
480            fitdata=FitData2D(data)
481        else:
482            fitdata=FitData1D(data, smearer)
483       
484        fitdata.setFitRange(qmin=qmin,qmax=qmax)
485        #A fitArrange is already created but contains model only at Uid
486        if self.fitArrangeDict.has_key(Uid):
487            self.fitArrangeDict[Uid].add_data(fitdata)
488        else:
489        #no fitArrange object has been create with this Uid
490            fitproblem= FitArrange()
491            fitproblem.add_data(fitdata)
492            self.fitArrangeDict[Uid]=fitproblem   
493   
494    def get_model(self,Uid):
495        """
496            @param Uid: Uid is key in the dictionary containing the model to return
497            @return  a model at this uid or None if no FitArrange element was created
498            with this Uid
499        """
500        if self.fitArrangeDict.has_key(Uid):
501            return self.fitArrangeDict[Uid].get_model()
502        else:
503            return None
504   
505    def remove_Fit_Problem(self,Uid):
506        """remove   fitarrange in Uid"""
507        if self.fitArrangeDict.has_key(Uid):
508            del self.fitArrangeDict[Uid]
509           
510    def select_problem_for_fit(self,Uid,value):
511        """
512            select a couple of model and data at the Uid position in dictionary
513            and set in self.selected value to value
514            @param value: the value to allow fitting. can only have the value one or zero
515        """
516        if self.fitArrangeDict.has_key(Uid):
517             self.fitArrangeDict[Uid].set_to_fit( value)
518    def get_problem_to_fit(self,Uid):
519        """
520            return the self.selected value of the fit problem of Uid
521           @param Uid: the Uid of the problem
522        """
523        if self.fitArrangeDict.has_key(Uid):
524             self.fitArrangeDict[Uid].get_to_fit()
525   
526class FitArrange:
527    def __init__(self):
528        """
529            Class FitArrange contains a set of data for a given model
530            to perform the Fit.FitArrange must contain exactly one model
531            and at least one data for the fit to be performed.
532            model: the model selected by the user
533            Ldata: a list of data what the user wants to fit
534           
535        """
536        self.model = None
537        self.dList =[]
538        self.pars=[]
539        #self.selected  is zero when this fit problem is not schedule to fit
540        #self.selected is 1 when schedule to fit
541        self.selected = 0
542       
543    def set_model(self,model):
544        """
545            set_model save a copy of the model
546            @param model: the model being set
547        """
548        self.model = model
549       
550    def add_data(self,data):
551        """
552            add_data fill a self.dList with data to fit
553            @param data: Data to add in the list 
554        """
555        if not data in self.dList:
556            self.dList.append(data)
557           
558    def get_model(self):
559        """ @return: saved model """
560        return self.model   
561     
562    def get_data(self):
563        """ @return:  list of data dList"""
564        #return self.dList
565        return self.dList[0] 
566     
567    def remove_data(self,data):
568        """
569            Remove one element from the list
570            @param data: Data to remove from dList
571        """
572        if data in self.dList:
573            self.dList.remove(data)
574    def set_to_fit (self, value=0):
575        """
576           set self.selected to 0 or 1  for other values raise an exception
577           @param value: integer between 0 or 1
578        """
579        self.selected= value
580       
581    def get_to_fit(self):
582        """
583            @return self.selected value
584        """
585        return self.selected
586   
587
588
589   
Note: See TracBrowser for help on using the repository browser.