source: sasview/park_integration/AbstractFitEngine.py @ 6aa47df

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 6aa47df was e71440c, checked in by Gervaise Alina <gervyh@…>, 16 years ago

changes on setparams to fix fit

  • Property mode set to 100644
File size: 12.3 KB
Line 
1
2import park,numpy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        lo,hi = self._model.details[self.name][1:]
41        if lo is None: lo = -numpy.inf
42        if hi is None: hi = numpy.inf
43        return lo,hi
44   
45    def _setrange(self,r):
46        """
47            override _setrange of park parameter
48            @param r: the value of the range to set
49        """
50        self._model.details[self.name][1:] = r
51    range = property(_getrange,_setrange)
52
53
54class Model(object):
55    """
56        PARK wrapper for SANS models.
57    """
58    def __init__(self, sans_model, **kw):
59        """
60            @param sans_model: the sans model to wrap using park interface
61        """
62        self.model = sans_model
63        self.name = sans_model.name
64        #list of parameters names
65        self.sansp = sans_model.getParamList()
66        #list of park parameter
67        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
68        #list of parameterset
69        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
70        self.pars=[]
71 
72 
73    def getParams(self,fitparams):
74        """
75            return a list of value of paramter to fit
76            @param fitparams: list of paramaters name to fit
77        """
78        list=[]
79        self.pars=[]
80        self.pars=fitparams
81        for item in fitparams:
82            for element in self.parkp:
83                 if element.name ==str(item):
84                     list.append(element.value)
85        return list
86   
87   
88    def setParams(self,paramlist, params):
89        """
90            Set value for parameters to fit
91            @param params: list of value for parameters to fit
92        """
93        try:
94            for i in range(len(self.parkp)):
95                for j in range(len(paramlist)):
96                    if self.parkp[i].name==paramlist[j]:
97                        self.parkp[i].value = params[j]
98                        self.model.setParam(self.parkp[i].name,params[j])
99        except:
100            raise
101 
102    def eval(self,x):
103        """
104            override eval method of park model.
105            @param x: the x value used to compute a function
106        """
107        return self.model.runXY(x)
108   
109   
110class Data(object):
111    """ Wrapper class  for SANS data """
112    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
113        """
114            Data can be initital with a data (sans plottable)
115            or with vectors.
116        """
117        if  sans_data !=None:
118            self.x= sans_data.x
119            self.y= sans_data.y
120            self.dx= sans_data.dx
121            self.dy= sans_data.dy
122           
123        elif (x!=None and y!=None and dy!=None):
124                self.x=x
125                self.y=y
126                self.dx=dx
127                self.dy=dy
128        else:
129            raise ValueError,\
130            "Data is missing x, y or dy, impossible to compute residuals later on"
131        self.qmin=None
132        self.qmax=None
133       
134       
135    def setFitRange(self,mini=None,maxi=None):
136        """ to set the fit range"""
137        self.qmin=mini
138        self.qmax=maxi
139       
140       
141    def getFitRange(self):
142        """
143            @return the range of data.x to fit
144        """
145        return self.qmin, self.qmax
146     
147     
148    def residuals(self, fn):
149        """ @param fn: function that return model value
150            @return residuals
151        """
152        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
153        if self.qmin==None and self.qmax==None: 
154            fx =numpy.asarray([fn(v) for v in x])
155            return (y - fx)/dy
156        else:
157            idx = (x>=self.qmin) & (x <= self.qmax)
158            fx = numpy.asarray([fn(item)for item in x[idx ]])
159            return (y[idx] - fx)/dy[idx]
160       
161    def residuals_deriv(self, model, pars=[]):
162        """
163            @return residuals derivatives .
164            @note: in this case just return empty array
165        """
166        return []
167   
168class sansAssembly:
169    """
170         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
171    """
172    def __init__(self,paramlist,Model=None , Data=None):
173        """
174            @param Model: the model wrapper fro sans -model
175            @param Data: the data wrapper for sans data
176        """
177        self.model = Model
178        self.data  = Data
179        self.paramlist=paramlist
180        self.res=[]
181    def chisq(self, params):
182        """
183            Calculates chi^2
184            @param params: list of parameter values
185            @return: chi^2
186        """
187        sum = 0
188        for item in self.res:
189            sum += item*item
190        return sum
191    def __call__(self,params):
192        """
193            Compute residuals
194            @param params: value of parameters to fit
195        """
196        self.model.setParams(self.paramlist,params)
197        self.res= self.data.residuals(self.model.eval)
198        return self.res
199   
200class FitEngine:
201    def __init__(self):
202        """
203            Base class for scipy and park fit engine
204        """
205        #List of parameter names to fit
206        self.paramList=[]
207        #Dictionnary of fitArrange element (fit problems)
208        self.fitArrangeDict={}
209       
210    def _concatenateData(self, listdata=[]):
211        """ 
212            _concatenateData method concatenates each fields of all data contains ins listdata.
213            @param listdata: list of data
214            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
215             of data concatenanted
216            @raise: if listdata is empty  will return None
217            @raise: if data in listdata don't contain dy field ,will create an error
218            during fitting
219        """
220        if listdata==[]:
221            raise ValueError, " data list missing"
222        else:
223            xtemp=[]
224            ytemp=[]
225            dytemp=[]
226            self.mini=None
227            self.maxi=None
228               
229            for data in listdata:
230                mini,maxi=data.getFitRange()
231                if self.mini==None and self.maxi==None:
232                    self.mini=mini
233                    self.maxi=maxi
234                else:
235                    if mini < self.mini:
236                        self.mini=mini
237                    if self.maxi < maxi:
238                        self.maxi=maxi
239                       
240                   
241                for i in range(len(data.x)):
242                    xtemp.append(data.x[i])
243                    ytemp.append(data.y[i])
244                    if data.dy is not None and len(data.dy)==len(data.y):   
245                        dytemp.append(data.dy[i])
246                    else:
247                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
248            data= Data(x=xtemp,y=ytemp,dy=dytemp)
249            data.setFitRange(self.mini, self.maxi)
250            return data
251       
252       
253    def set_model(self,model,Uid,pars=[]):
254        """
255            set a model on a given uid in the fit engine.
256            @param model: the model to fit
257            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
258            @param pars: the list of parameters to fit
259            @note : pars must contains only name of existing model's paramaters
260        """
261        if len(pars) >0:
262            if model==None:
263                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
264            else:
265                for item in pars:
266                    if item in model.model.getParamList():
267                        self.paramList.append(item)
268                    else:
269                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
270                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
271                        return
272            #A fitArrange is already created but contains dList only at Uid
273            if self.fitArrangeDict.has_key(Uid):
274                self.fitArrangeDict[Uid].set_model(model)
275            else:
276            #no fitArrange object has been create with this Uid
277                fitproblem = FitArrange()
278                fitproblem.set_model(model)
279                self.fitArrangeDict[Uid] = fitproblem
280        else:
281            raise ValueError, "park_integration:missing parameters"
282   
283    def set_data(self,data,Uid,qmin=None,qmax=None):
284        """ Receives plottable, creates a list of data to fit,set data
285            in a FitArrange object and adds that object in a dictionary
286            with key Uid.
287            @param data: data added
288            @param Uid: unique key corresponding to a fitArrange object with data
289        """
290        if qmin !=None and qmax !=None:
291            data.setFitRange(mini=qmin,maxi=qmax)
292        #A fitArrange is already created but contains model only at Uid
293        if self.fitArrangeDict.has_key(Uid):
294            self.fitArrangeDict[Uid].add_data(data)
295        else:
296        #no fitArrange object has been create with this Uid
297            fitproblem= FitArrange()
298            fitproblem.add_data(data)
299            self.fitArrangeDict[Uid]=fitproblem   
300   
301    def get_model(self,Uid):
302        """
303            @param Uid: Uid is key in the dictionary containing the model to return
304            @return  a model at this uid or None if no FitArrange element was created
305            with this Uid
306        """
307        if self.fitArrangeDict.has_key(Uid):
308            return self.fitArrangeDict[Uid].get_model()
309        else:
310            return None
311   
312    def remove_Fit_Problem(self,Uid):
313        """remove   fitarrange in Uid"""
314        if self.fitArrangeDict.has_key(Uid):
315            del self.fitArrangeDict[Uid]
316
317   
318class FitArrange:
319    def __init__(self):
320        """
321            Class FitArrange contains a set of data for a given model
322            to perform the Fit.FitArrange must contain exactly one model
323            and at least one data for the fit to be performed.
324            model: the model selected by the user
325            Ldata: a list of data what the user wants to fit
326           
327        """
328        self.model = None
329        self.dList =[]
330       
331    def set_model(self,model):
332        """
333            set_model save a copy of the model
334            @param model: the model being set
335        """
336        self.model = model
337       
338    def add_data(self,data):
339        """
340            add_data fill a self.dList with data to fit
341            @param data: Data to add in the list 
342        """
343        if not data in self.dList:
344            self.dList.append(data)
345           
346    def get_model(self):
347        """ @return: saved model """
348        return self.model   
349     
350    def get_data(self):
351        """ @return:  list of data dList"""
352        return self.dList
353     
354    def remove_data(self,data):
355        """
356            Remove one element from the list
357            @param data: Data to remove from dList
358        """
359        if data in self.dList:
360            self.dList.remove(data)
361   
362
363
364   
Note: See TracBrowser for help on using the repository browser.