source: sasview/park_integration/AbstractFitEngine.py @ fb8d4050

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since fb8d4050 was a9e04aa, checked in by Gervaise Alina <gervyh@…>, 16 years ago

added fitting selection methods

  • Property mode set to 100644
File size: 13.5 KB
Line 
1
2import park,numpy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        lo,hi = self._model.details[self.name][1:]
41        if lo is None: lo = -numpy.inf
42        if hi is None: hi = numpy.inf
43        return lo,hi
44   
45    def _setrange(self,r):
46        """
47            override _setrange of park parameter
48            @param r: the value of the range to set
49        """
50        self._model.details[self.name][1:] = r
51    range = property(_getrange,_setrange)
52   
53class Model(park.Model):
54    """
55        PARK wrapper for SANS models.
56    """
57    def __init__(self, sans_model, **kw):
58        """
59            @param sans_model: the sans model to wrap using park interface
60        """
61        park.Model.__init__(self, **kw)
62        self.model = sans_model
63        self.name = sans_model.name
64        #list of parameters names
65        self.sansp = sans_model.getParamList()
66        #list of park parameter
67        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
68        #list of parameterset
69        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
70        self.pars=[]
71 
72 
73    def getParams(self,fitparams):
74        """
75            return a list of value of paramter to fit
76            @param fitparams: list of paramaters name to fit
77        """
78        list=[]
79        self.pars=[]
80        self.pars=fitparams
81        for item in fitparams:
82            for element in self.parkp:
83                 if element.name ==str(item):
84                     list.append(element.value)
85        return list
86   
87   
88    def setParams(self,paramlist, params):
89        """
90            Set value for parameters to fit
91            @param params: list of value for parameters to fit
92        """
93        try:
94            for i in range(len(self.parkp)):
95                for j in range(len(paramlist)):
96                    if self.parkp[i].name==paramlist[j]:
97                        self.parkp[i].value = params[j]
98                        self.model.setParam(self.parkp[i].name,params[j])
99        except:
100            raise
101 
102    def eval(self,x):
103        """
104            override eval method of park model.
105            @param x: the x value used to compute a function
106        """
107        return self.model.runXY(x)
108   
109   
110
111
112class Data(object):
113    """ Wrapper class  for SANS data """
114    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
115        """
116            Data can be initital with a data (sans plottable)
117            or with vectors.
118        """
119        if  sans_data !=None:
120            self.x= sans_data.x
121            self.y= sans_data.y
122            self.dx= sans_data.dx
123            self.dy= sans_data.dy
124           
125        elif (x!=None and y!=None and dy!=None):
126                self.x=x
127                self.y=y
128                self.dx=dx
129                self.dy=dy
130        else:
131            raise ValueError,\
132            "Data is missing x, y or dy, impossible to compute residuals later on"
133        self.qmin=None
134        self.qmax=None
135       
136       
137    def setFitRange(self,mini=None,maxi=None):
138        """ to set the fit range"""
139        self.qmin=mini
140        self.qmax=maxi
141       
142       
143    def getFitRange(self):
144        """
145            @return the range of data.x to fit
146        """
147        return self.qmin, self.qmax
148     
149     
150    def residuals(self, fn):
151        """ @param fn: function that return model value
152            @return residuals
153        """
154        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
155        if self.qmin==None and self.qmax==None: 
156            fx =numpy.asarray([fn(v) for v in x])
157            return (y - fx)/dy
158        else:
159            idx = (x>=self.qmin) & (x <= self.qmax)
160            fx = numpy.asarray([fn(item)for item in x[idx ]])
161            return (y[idx] - fx)/dy[idx]
162       
163    def residuals_deriv(self, model, pars=[]):
164        """
165            @return residuals derivatives .
166            @note: in this case just return empty array
167        """
168        return []
169   
170class sansAssembly:
171    """
172         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
173    """
174    def __init__(self,paramlist,Model=None , Data=None):
175        """
176            @param Model: the model wrapper fro sans -model
177            @param Data: the data wrapper for sans data
178        """
179        self.model = Model
180        self.data  = Data
181        self.paramlist=paramlist
182        self.res=[]
183    def chisq(self, params):
184        """
185            Calculates chi^2
186            @param params: list of parameter values
187            @return: chi^2
188        """
189        sum = 0
190        for item in self.res:
191            sum += item*item
192        return sum
193    def __call__(self,params):
194        """
195            Compute residuals
196            @param params: value of parameters to fit
197        """
198        self.model.setParams(self.paramlist,params)
199        self.res= self.data.residuals(self.model.eval)
200        return self.res
201   
202class FitEngine:
203    def __init__(self):
204        """
205            Base class for scipy and park fit engine
206        """
207        #List of parameter names to fit
208        self.paramList=[]
209        #Dictionnary of fitArrange element (fit problems)
210        self.fitArrangeDict={}
211       
212    def _concatenateData(self, listdata=[]):
213        """ 
214            _concatenateData method concatenates each fields of all data contains ins listdata.
215            @param listdata: list of data
216            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
217             of data concatenanted
218            @raise: if listdata is empty  will return None
219            @raise: if data in listdata don't contain dy field ,will create an error
220            during fitting
221        """
222        if listdata==[]:
223            raise ValueError, " data list missing"
224        else:
225            xtemp=[]
226            ytemp=[]
227            dytemp=[]
228            self.mini=None
229            self.maxi=None
230               
231            for data in listdata:
232                mini,maxi=data.getFitRange()
233                if self.mini==None and self.maxi==None:
234                    self.mini=mini
235                    self.maxi=maxi
236                else:
237                    if mini < self.mini:
238                        self.mini=mini
239                    if self.maxi < maxi:
240                        self.maxi=maxi
241                       
242                   
243                for i in range(len(data.x)):
244                    xtemp.append(data.x[i])
245                    ytemp.append(data.y[i])
246                    if data.dy is not None and len(data.dy)==len(data.y):   
247                        dytemp.append(data.dy[i])
248                    else:
249                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
250            data= Data(x=xtemp,y=ytemp,dy=dytemp)
251            data.setFitRange(self.mini, self.maxi)
252            return data
253       
254       
255    def set_model(self,model,Uid,pars=[]):
256        """
257            set a model on a given uid in the fit engine.
258            @param model: the model to fit
259            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
260            @param pars: the list of parameters to fit
261            @note : pars must contains only name of existing model's paramaters
262        """
263        if len(pars) >0:
264            if model==None:
265                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
266            else:
267                for item in pars:
268                    if item in model.model.getParamList():
269                        self.paramList.append(item)
270                    else:
271                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
272                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
273                        return
274            #A fitArrange is already created but contains dList only at Uid
275            if self.fitArrangeDict.has_key(Uid):
276                self.fitArrangeDict[Uid].set_model(model)
277            else:
278            #no fitArrange object has been create with this Uid
279                fitproblem = FitArrange()
280                fitproblem.set_model(model)
281                self.fitArrangeDict[Uid] = fitproblem
282        else:
283            raise ValueError, "park_integration:missing parameters"
284   
285    def set_data(self,data,Uid,qmin=None,qmax=None):
286        """ Receives plottable, creates a list of data to fit,set data
287            in a FitArrange object and adds that object in a dictionary
288            with key Uid.
289            @param data: data added
290            @param Uid: unique key corresponding to a fitArrange object with data
291        """
292        if qmin !=None and qmax !=None:
293            data.setFitRange(mini=qmin,maxi=qmax)
294        #A fitArrange is already created but contains model only at Uid
295        if self.fitArrangeDict.has_key(Uid):
296            self.fitArrangeDict[Uid].add_data(data)
297        else:
298        #no fitArrange object has been create with this Uid
299            fitproblem= FitArrange()
300            fitproblem.add_data(data)
301            self.fitArrangeDict[Uid]=fitproblem   
302   
303    def get_model(self,Uid):
304        """
305            @param Uid: Uid is key in the dictionary containing the model to return
306            @return  a model at this uid or None if no FitArrange element was created
307            with this Uid
308        """
309        if self.fitArrangeDict.has_key(Uid):
310            return self.fitArrangeDict[Uid].get_model()
311        else:
312            return None
313   
314    def remove_Fit_Problem(self,Uid):
315        """remove   fitarrange in Uid"""
316        if self.fitArrangeDict.has_key(Uid):
317            del self.fitArrangeDict[Uid]
318           
319    def select_problem_for_fit(self,Uid,value):
320        """
321            select a couple of model and data at the Uid position in dictionary
322            and set in self.selected value to value
323            @param value: the value to allow fitting. can only have the value one or zero
324        """
325        if self.fitArrangeDict.has_key(Uid):
326             self.fitArrangeDict[Uid].set_to_fit( value)
327    def get_problem_to_fit(self,Uid):
328        """
329            return the self.selected value of the fit problem of Uid
330           @param Uid: the Uid of the problem
331        """
332        if self.fitArrangeDict.has_key(Uid):
333             self.fitArrangeDict[Uid].get_to_fit()
334   
335class FitArrange:
336    def __init__(self):
337        """
338            Class FitArrange contains a set of data for a given model
339            to perform the Fit.FitArrange must contain exactly one model
340            and at least one data for the fit to be performed.
341            model: the model selected by the user
342            Ldata: a list of data what the user wants to fit
343           
344        """
345        self.model = None
346        self.dList =[]
347        #self.selected  is zero when this fit problem is not schedule to fit
348        #self.selected is 1 when schedule to fit
349        self.selected = 0
350       
351    def set_model(self,model):
352        """
353            set_model save a copy of the model
354            @param model: the model being set
355        """
356        self.model = model
357       
358    def add_data(self,data):
359        """
360            add_data fill a self.dList with data to fit
361            @param data: Data to add in the list 
362        """
363        if not data in self.dList:
364            self.dList.append(data)
365           
366    def get_model(self):
367        """ @return: saved model """
368        return self.model   
369     
370    def get_data(self):
371        """ @return:  list of data dList"""
372        return self.dList
373     
374    def remove_data(self,data):
375        """
376            Remove one element from the list
377            @param data: Data to remove from dList
378        """
379        if data in self.dList:
380            self.dList.remove(data)
381    def set_to_fit (self, value=0):
382        """
383           set self.selected to 0 or 1  for other values raise an exception
384           @param value: integer between 0 or 1
385        """
386        self.selected= value
387       
388    def get_to_fit(self):
389        """
390            @return self.selected value
391        """
392        return self.selected
393   
394
395
396   
Note: See TracBrowser for help on using the repository browser.