source: sasview/park_integration/AbstractFitEngine.py @ 385d464

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 385d464 was 7d0c1a8, checked in by Gervaise Alina <gervyh@…>, 16 years ago

introduced fitdata1D and fitdata2D class instead of Data class.

  • Property mode set to 100644
File size: 16.8 KB
RevLine 
[4c718654]1
[48882d1]2import park,numpy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
[ca6d914]11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
[48882d1]18         
[ca6d914]19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
[48882d1]25   
[ca6d914]26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
[48882d1]31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
[ca6d914]36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
[48882d1]40        lo,hi = self._model.details[self.name][1:]
41        if lo is None: lo = -numpy.inf
42        if hi is None: hi = numpy.inf
43        return lo,hi
44   
45    def _setrange(self,r):
[ca6d914]46        """
47            override _setrange of park parameter
48            @param r: the value of the range to set
49        """
[48882d1]50        self._model.details[self.name][1:] = r
51    range = property(_getrange,_setrange)
[a9e04aa]52   
53class Model(park.Model):
[48882d1]54    """
55        PARK wrapper for SANS models.
56    """
[388309d]57    def __init__(self, sans_model, **kw):
[ca6d914]58        """
59            @param sans_model: the sans model to wrap using park interface
60        """
[a9e04aa]61        park.Model.__init__(self, **kw)
[48882d1]62        self.model = sans_model
[ca6d914]63        self.name = sans_model.name
64        #list of parameters names
[48882d1]65        self.sansp = sans_model.getParamList()
[ca6d914]66        #list of park parameter
[48882d1]67        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
[ca6d914]68        #list of parameterset
[48882d1]69        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
70        self.pars=[]
[ca6d914]71 
72 
[48882d1]73    def getParams(self,fitparams):
[ca6d914]74        """
75            return a list of value of paramter to fit
76            @param fitparams: list of paramaters name to fit
77        """
[48882d1]78        list=[]
79        self.pars=[]
80        self.pars=fitparams
81        for item in fitparams:
82            for element in self.parkp:
83                 if element.name ==str(item):
84                     list.append(element.value)
85        return list
86   
[ca6d914]87   
[e71440c]88    def setParams(self,paramlist, params):
[ca6d914]89        """
90            Set value for parameters to fit
91            @param params: list of value for parameters to fit
92        """
[e71440c]93        try:
94            for i in range(len(self.parkp)):
95                for j in range(len(paramlist)):
96                    if self.parkp[i].name==paramlist[j]:
97                        self.parkp[i].value = params[j]
98                        self.model.setParam(self.parkp[i].name,params[j])
99        except:
100            raise
[ca6d914]101 
[48882d1]102    def eval(self,x):
[ca6d914]103        """
104            override eval method of park model.
105            @param x: the x value used to compute a function
106        """
[48882d1]107        return self.model.runXY(x)
[388309d]108   
109   
[a9e04aa]110
111
[48882d1]112class Data(object):
113    """ Wrapper class  for SANS data """
114    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
[ca6d914]115        """
116            Data can be initital with a data (sans plottable)
117            or with vectors.
118        """
[48882d1]119        if  sans_data !=None:
120            self.x= sans_data.x
121            self.y= sans_data.y
122            self.dx= sans_data.dx
123            self.dy= sans_data.dy
124           
125        elif (x!=None and y!=None and dy!=None):
126                self.x=x
127                self.y=y
128                self.dx=dx
129                self.dy=dy
130        else:
131            raise ValueError,\
132            "Data is missing x, y or dy, impossible to compute residuals later on"
133        self.qmin=None
134        self.qmax=None
135       
[ca6d914]136       
[48882d1]137    def setFitRange(self,mini=None,maxi=None):
138        """ to set the fit range"""
139        self.qmin=mini
140        self.qmax=maxi
[ca6d914]141       
142       
[48882d1]143    def getFitRange(self):
[ca6d914]144        """
145            @return the range of data.x to fit
146        """
147        return self.qmin, self.qmax
148     
149     
[48882d1]150    def residuals(self, fn):
151        """ @param fn: function that return model value
152            @return residuals
153        """
154        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
155        if self.qmin==None and self.qmax==None: 
[ca6d914]156            fx =numpy.asarray([fn(v) for v in x])
[48882d1]157            return (y - fx)/dy
158        else:
159            idx = (x>=self.qmin) & (x <= self.qmax)
[ca6d914]160            fx = numpy.asarray([fn(item)for item in x[idx ]])
[48882d1]161            return (y[idx] - fx)/dy[idx]
[e71440c]162       
[48882d1]163    def residuals_deriv(self, model, pars=[]):
164        """
165            @return residuals derivatives .
166            @note: in this case just return empty array
167        """
168        return []
[7d0c1a8]169class FitData1D(object):
170    """ Wrapper class  for SANS data """
171    def __init__(self,sans_data1d):
172        """
173            Data can be initital with a data (sans plottable)
174            or with vectors.
175        """
176        self.data=sans_data1d
177        self.x= sans_data1d.x
178        self.y= sans_data1d.y
179        self.dx= sans_data1d.dx
180        self.dy= sans_data1d.dy
181        self.qmin=None
182        self.qmax=None
183       
184       
185    def setFitRange(self,mini=None,maxi=None):
186        """ to set the fit range"""
187        self.qmin=mini
188        self.qmax=maxi
189       
190       
191    def getFitRange(self):
192        """
193            @return the range of data.x to fit
194        """
195        return self.qmin, self.qmax
196     
197     
198    def residuals(self, fn):
199        """ @param fn: function that return model value
200            @return residuals
201        """
202        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
203        if self.qmin==None and self.qmax==None: 
204            fx =numpy.asarray([fn(v) for v in x])
205            return (y - fx)/dy
206        else:
207            idx = (x>=self.qmin) & (x <= self.qmax)
208            fx = numpy.asarray([fn(item)for item in x[idx ]])
209            return (y[idx] - fx)/dy[idx]
210       
211    def residuals_deriv(self, model, pars=[]):
212        """
213            @return residuals derivatives .
214            @note: in this case just return empty array
215        """
216        return []
217   
218   
219class FitData2D(object):
220    """ Wrapper class  for SANS data """
221    def __init__(self,sans_data2d):
222        """
223            Data can be initital with a data (sans plottable)
224            or with vectors.
225        """
226        self.data=sans_data2d
227        self.image = sans_data2d.image
228        self.err_image = sans_data2d.err_image
229        self.x_bins= sans_data2d.x_bins
230        self.y_bins= sans_data2d.y_bins
231       
232        self.qmin= None
233        self.qmax= None
234       
235       
236    def setFitRange(self,mini=None,maxi=None):
237        """ to set the fit range"""
238        self.qmin= mini
239        self.qmax= maxi
240       
241       
242    def getFitRange(self):
243        """
244            @return the range of data.x to fit
245        """
246        return self.qmin, self.qmax
247     
248     
249    def residuals(self, fn):
250        """ @param fn: function that return model value
251            @return residuals
252        """
253        res=[]
254        if self.qmin==None and self.qmax==None: 
255            for i in range(len(self.x_bins)):
256                res.append( (self.image[i][i]- fn([self.x_bins[i],self.y_bins[i]]))\
257                            /self.err_image[i][i] )
258            return numpy.array(res)
259        else:
260            #idx = (x>=self.qmin) & (x <= self.qmax)
261            #fx = numpy.asarray([fn(item)for item in x[idx ]])
262            #return (y[idx] - fx)/dy[idx]
263            for i in range(len(self.x_bins)):
264                res.append( (self.image[i][i]- fn([self.x_bins[i],self.y_bins[i]]))\
265                            /self.err_image[i][i] )
266            return numpy.array(res)
267    def residuals_deriv(self, model, pars=[]):
268        """
269            @return residuals derivatives .
270            @note: in this case just return empty array
271        """
272        return []
[48882d1]273   
274class sansAssembly:
[ca6d914]275    """
276         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
277    """
[e71440c]278    def __init__(self,paramlist,Model=None , Data=None):
[ca6d914]279        """
280            @param Model: the model wrapper fro sans -model
281            @param Data: the data wrapper for sans data
282        """
283        self.model = Model
284        self.data  = Data
[e71440c]285        self.paramlist=paramlist
[ca6d914]286        self.res=[]
[48882d1]287    def chisq(self, params):
288        """
289            Calculates chi^2
290            @param params: list of parameter values
291            @return: chi^2
292        """
293        sum = 0
294        for item in self.res:
295            sum += item*item
296        return sum
297    def __call__(self,params):
[ca6d914]298        """
299            Compute residuals
300            @param params: value of parameters to fit
301        """
[e71440c]302        self.model.setParams(self.paramlist,params)
[48882d1]303        self.res= self.data.residuals(self.model.eval)
304        return self.res
305   
[4c718654]306class FitEngine:
[ee5b04c]307    def __init__(self):
[ca6d914]308        """
309            Base class for scipy and park fit engine
310        """
311        #List of parameter names to fit
[ee5b04c]312        self.paramList=[]
[ca6d914]313        #Dictionnary of fitArrange element (fit problems)
314        self.fitArrangeDict={}
315       
[4c718654]316    def _concatenateData(self, listdata=[]):
317        """ 
318            _concatenateData method concatenates each fields of all data contains ins listdata.
319            @param listdata: list of data
[ca6d914]320            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
321             of data concatenanted
[4c718654]322            @raise: if listdata is empty  will return None
323            @raise: if data in listdata don't contain dy field ,will create an error
324            during fitting
325        """
326        if listdata==[]:
327            raise ValueError, " data list missing"
328        else:
329            xtemp=[]
330            ytemp=[]
331            dytemp=[]
[48882d1]332            self.mini=None
333            self.maxi=None
[4c718654]334               
[7d0c1a8]335            for item in listdata:
336                data=item.data
[48882d1]337                mini,maxi=data.getFitRange()
338                if self.mini==None and self.maxi==None:
339                    self.mini=mini
340                    self.maxi=maxi
341                else:
342                    if mini < self.mini:
343                        self.mini=mini
344                    if self.maxi < maxi:
345                        self.maxi=maxi
346                       
347                   
[4c718654]348                for i in range(len(data.x)):
349                    xtemp.append(data.x[i])
350                    ytemp.append(data.y[i])
351                    if data.dy is not None and len(data.dy)==len(data.y):   
352                        dytemp.append(data.dy[i])
353                    else:
[ee5b04c]354                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
[48882d1]355            data= Data(x=xtemp,y=ytemp,dy=dytemp)
356            data.setFitRange(self.mini, self.maxi)
357            return data
[ca6d914]358       
359       
360    def set_model(self,model,Uid,pars=[]):
361        """
362            set a model on a given uid in the fit engine.
363            @param model: the model to fit
364            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
365            @param pars: the list of parameters to fit
366            @note : pars must contains only name of existing model's paramaters
367        """
[f44dbc7]368        if len(pars) >0:
[6831a99]369            if model==None:
[f44dbc7]370                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
[6831a99]371            else:
[ca6d914]372                for item in pars:
373                    if item in model.model.getParamList():
374                        self.paramList.append(item)
375                    else:
376                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
377                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
378                        return
[6831a99]379            #A fitArrange is already created but contains dList only at Uid
[ca6d914]380            if self.fitArrangeDict.has_key(Uid):
381                self.fitArrangeDict[Uid].set_model(model)
[6831a99]382            else:
383            #no fitArrange object has been create with this Uid
[48882d1]384                fitproblem = FitArrange()
[6831a99]385                fitproblem.set_model(model)
[ca6d914]386                self.fitArrangeDict[Uid] = fitproblem
[d4b0687]387        else:
[6831a99]388            raise ValueError, "park_integration:missing parameters"
[48882d1]389   
390    def set_data(self,data,Uid,qmin=None,qmax=None):
[d4b0687]391        """ Receives plottable, creates a list of data to fit,set data
392            in a FitArrange object and adds that object in a dictionary
393            with key Uid.
394            @param data: data added
395            @param Uid: unique key corresponding to a fitArrange object with data
[ca6d914]396        """
[48882d1]397        if qmin !=None and qmax !=None:
398            data.setFitRange(mini=qmin,maxi=qmax)
[d4b0687]399        #A fitArrange is already created but contains model only at Uid
[ca6d914]400        if self.fitArrangeDict.has_key(Uid):
401            self.fitArrangeDict[Uid].add_data(data)
[d4b0687]402        else:
403        #no fitArrange object has been create with this Uid
404            fitproblem= FitArrange()
405            fitproblem.add_data(data)
[ca6d914]406            self.fitArrangeDict[Uid]=fitproblem   
[48882d1]407   
[d4b0687]408    def get_model(self,Uid):
409        """
410            @param Uid: Uid is key in the dictionary containing the model to return
411            @return  a model at this uid or None if no FitArrange element was created
412            with this Uid
413        """
[ca6d914]414        if self.fitArrangeDict.has_key(Uid):
415            return self.fitArrangeDict[Uid].get_model()
[d4b0687]416        else:
417            return None
418   
419    def remove_Fit_Problem(self,Uid):
420        """remove   fitarrange in Uid"""
[ca6d914]421        if self.fitArrangeDict.has_key(Uid):
422            del self.fitArrangeDict[Uid]
[a9e04aa]423           
424    def select_problem_for_fit(self,Uid,value):
425        """
426            select a couple of model and data at the Uid position in dictionary
427            and set in self.selected value to value
428            @param value: the value to allow fitting. can only have the value one or zero
429        """
430        if self.fitArrangeDict.has_key(Uid):
431             self.fitArrangeDict[Uid].set_to_fit( value)
432    def get_problem_to_fit(self,Uid):
433        """
434            return the self.selected value of the fit problem of Uid
435           @param Uid: the Uid of the problem
436        """
437        if self.fitArrangeDict.has_key(Uid):
438             self.fitArrangeDict[Uid].get_to_fit()
[4c718654]439   
[d4b0687]440class FitArrange:
441    def __init__(self):
442        """
443            Class FitArrange contains a set of data for a given model
444            to perform the Fit.FitArrange must contain exactly one model
445            and at least one data for the fit to be performed.
446            model: the model selected by the user
447            Ldata: a list of data what the user wants to fit
448           
449        """
450        self.model = None
451        self.dList =[]
[a9e04aa]452        #self.selected  is zero when this fit problem is not schedule to fit
453        #self.selected is 1 when schedule to fit
454        self.selected = 0
[d4b0687]455       
456    def set_model(self,model):
457        """
458            set_model save a copy of the model
459            @param model: the model being set
460        """
461        self.model = model
462       
463    def add_data(self,data):
464        """
465            add_data fill a self.dList with data to fit
466            @param data: Data to add in the list 
467        """
468        if not data in self.dList:
469            self.dList.append(data)
470           
471    def get_model(self):
472        """ @return: saved model """
473        return self.model   
474     
475    def get_data(self):
476        """ @return:  list of data dList"""
[7d0c1a8]477        #return self.dList
478        return self.dList[0] 
[d4b0687]479     
480    def remove_data(self,data):
481        """
482            Remove one element from the list
483            @param data: Data to remove from dList
484        """
485        if data in self.dList:
486            self.dList.remove(data)
[a9e04aa]487    def set_to_fit (self, value=0):
488        """
489           set self.selected to 0 or 1  for other values raise an exception
490           @param value: integer between 0 or 1
491        """
492        self.selected= value
493       
494    def get_to_fit(self):
495        """
496            @return self.selected value
497        """
498        return self.selected
[94b44293]499   
[4c718654]500
501
502   
Note: See TracBrowser for help on using the repository browser.