source: sasview/park_integration/AbstractFitEngine.py @ 5f96484

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 5f96484 was ca6d914, checked in by Gervaise Alina <gervyh@…>, 16 years ago

some bugs fixed

  • Property mode set to 100644
File size: 12.1 KB
Line 
1
2import park,numpy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        lo,hi = self._model.details[self.name][1:]
41        if lo is None: lo = -numpy.inf
42        if hi is None: hi = numpy.inf
43        return lo,hi
44   
45    def _setrange(self,r):
46        """
47            override _setrange of park parameter
48            @param r: the value of the range to set
49        """
50        self._model.details[self.name][1:] = r
51    range = property(_getrange,_setrange)
52
53
54class Model(object):
55    """
56        PARK wrapper for SANS models.
57    """
58    def __init__(self, sans_model, **kw):
59        """
60            @param sans_model: the sans model to wrap using park interface
61        """
62        self.model = sans_model
63        self.name = sans_model.name
64        #list of parameters names
65        self.sansp = sans_model.getParamList()
66        #list of park parameter
67        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
68        #list of parameterset
69        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
70        self.pars=[]
71 
72 
73    def getParams(self,fitparams):
74        """
75            return a list of value of paramter to fit
76            @param fitparams: list of paramaters name to fit
77        """
78        list=[]
79        self.pars=[]
80        self.pars=fitparams
81        for item in fitparams:
82            for element in self.parkp:
83                 if element.name ==str(item):
84                     list.append(element.value)
85        return list
86   
87   
88    def setParams(self, params):
89        """
90            Set value for parameters to fit
91            @param params: list of value for parameters to fit
92        """
93        list=[]
94        for item in self.parkp:
95            list.append(item.name)
96        list.sort()
97        for i in range(len(params)):
98            self.parkp[i].value = params[i]
99            self.model.setParam(list[i],params[i])
100 
101 
102    def eval(self,x):
103        """
104            override eval method of park model.
105            @param x: the x value used to compute a function
106        """
107        return self.model.runXY(x)
108   
109   
110class Data(object):
111    """ Wrapper class  for SANS data """
112    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
113        """
114            Data can be initital with a data (sans plottable)
115            or with vectors.
116        """
117        if  sans_data !=None:
118            self.x= sans_data.x
119            self.y= sans_data.y
120            self.dx= sans_data.dx
121            self.dy= sans_data.dy
122           
123        elif (x!=None and y!=None and dy!=None):
124                self.x=x
125                self.y=y
126                self.dx=dx
127                self.dy=dy
128        else:
129            raise ValueError,\
130            "Data is missing x, y or dy, impossible to compute residuals later on"
131        self.qmin=None
132        self.qmax=None
133       
134       
135    def setFitRange(self,mini=None,maxi=None):
136        """ to set the fit range"""
137        self.qmin=mini
138        self.qmax=maxi
139       
140       
141    def getFitRange(self):
142        """
143            @return the range of data.x to fit
144        """
145        return self.qmin, self.qmax
146     
147     
148    def residuals(self, fn):
149        """ @param fn: function that return model value
150            @return residuals
151        """
152        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
153        if self.qmin==None and self.qmax==None: 
154            fx =numpy.asarray([fn(v) for v in x])
155            return (y - fx)/dy
156        else:
157            idx = (x>=self.qmin) & (x <= self.qmax)
158            fx = numpy.asarray([fn(item)for item in x[idx ]])
159            return (y[idx] - fx)/dy[idx]
160         
161           
162         
163    def residuals_deriv(self, model, pars=[]):
164        """
165            @return residuals derivatives .
166            @note: in this case just return empty array
167        """
168        return []
169   
170class sansAssembly:
171    """
172         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
173    """
174    def __init__(self,Model=None , Data=None):
175        """
176            @param Model: the model wrapper fro sans -model
177            @param Data: the data wrapper for sans data
178        """
179        self.model = Model
180        self.data  = Data
181        self.res=[]
182    def chisq(self, params):
183        """
184            Calculates chi^2
185            @param params: list of parameter values
186            @return: chi^2
187        """
188        sum = 0
189        for item in self.res:
190            sum += item*item
191        return sum
192    def __call__(self,params):
193        """
194            Compute residuals
195            @param params: value of parameters to fit
196        """
197        self.model.setParams(params)
198        self.res= self.data.residuals(self.model.eval)
199        return self.res
200   
201class FitEngine:
202    def __init__(self):
203        """
204            Base class for scipy and park fit engine
205        """
206        #List of parameter names to fit
207        self.paramList=[]
208        #Dictionnary of fitArrange element (fit problems)
209        self.fitArrangeDict={}
210       
211    def _concatenateData(self, listdata=[]):
212        """ 
213            _concatenateData method concatenates each fields of all data contains ins listdata.
214            @param listdata: list of data
215            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
216             of data concatenanted
217            @raise: if listdata is empty  will return None
218            @raise: if data in listdata don't contain dy field ,will create an error
219            during fitting
220        """
221        if listdata==[]:
222            raise ValueError, " data list missing"
223        else:
224            xtemp=[]
225            ytemp=[]
226            dytemp=[]
227            self.mini=None
228            self.maxi=None
229               
230            for data in listdata:
231                mini,maxi=data.getFitRange()
232                if self.mini==None and self.maxi==None:
233                    self.mini=mini
234                    self.maxi=maxi
235                else:
236                    if mini < self.mini:
237                        self.mini=mini
238                    if self.maxi < maxi:
239                        self.maxi=maxi
240                       
241                   
242                for i in range(len(data.x)):
243                    xtemp.append(data.x[i])
244                    ytemp.append(data.y[i])
245                    if data.dy is not None and len(data.dy)==len(data.y):   
246                        dytemp.append(data.dy[i])
247                    else:
248                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
249            data= Data(x=xtemp,y=ytemp,dy=dytemp)
250            data.setFitRange(self.mini, self.maxi)
251            return data
252       
253       
254    def set_model(self,model,Uid,pars=[]):
255        """
256            set a model on a given uid in the fit engine.
257            @param model: the model to fit
258            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
259            @param pars: the list of parameters to fit
260            @note : pars must contains only name of existing model's paramaters
261        """
262        if len(pars) >0:
263            if model==None:
264                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
265            else:
266                for item in pars:
267                    if item in model.model.getParamList():
268                        self.paramList.append(item)
269                    else:
270                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
271                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
272                        return
273            #A fitArrange is already created but contains dList only at Uid
274            if self.fitArrangeDict.has_key(Uid):
275                self.fitArrangeDict[Uid].set_model(model)
276            else:
277            #no fitArrange object has been create with this Uid
278                fitproblem = FitArrange()
279                fitproblem.set_model(model)
280                self.fitArrangeDict[Uid] = fitproblem
281        else:
282            raise ValueError, "park_integration:missing parameters"
283   
284    def set_data(self,data,Uid,qmin=None,qmax=None):
285        """ Receives plottable, creates a list of data to fit,set data
286            in a FitArrange object and adds that object in a dictionary
287            with key Uid.
288            @param data: data added
289            @param Uid: unique key corresponding to a fitArrange object with data
290        """
291        if qmin !=None and qmax !=None:
292            data.setFitRange(mini=qmin,maxi=qmax)
293        #A fitArrange is already created but contains model only at Uid
294        if self.fitArrangeDict.has_key(Uid):
295            self.fitArrangeDict[Uid].add_data(data)
296        else:
297        #no fitArrange object has been create with this Uid
298            fitproblem= FitArrange()
299            fitproblem.add_data(data)
300            self.fitArrangeDict[Uid]=fitproblem   
301   
302    def get_model(self,Uid):
303        """
304            @param Uid: Uid is key in the dictionary containing the model to return
305            @return  a model at this uid or None if no FitArrange element was created
306            with this Uid
307        """
308        if self.fitArrangeDict.has_key(Uid):
309            return self.fitArrangeDict[Uid].get_model()
310        else:
311            return None
312   
313    def remove_Fit_Problem(self,Uid):
314        """remove   fitarrange in Uid"""
315        if self.fitArrangeDict.has_key(Uid):
316            del self.fitArrangeDict[Uid]
317
318   
319class FitArrange:
320    def __init__(self):
321        """
322            Class FitArrange contains a set of data for a given model
323            to perform the Fit.FitArrange must contain exactly one model
324            and at least one data for the fit to be performed.
325            model: the model selected by the user
326            Ldata: a list of data what the user wants to fit
327           
328        """
329        self.model = None
330        self.dList =[]
331       
332    def set_model(self,model):
333        """
334            set_model save a copy of the model
335            @param model: the model being set
336        """
337        self.model = model
338       
339    def add_data(self,data):
340        """
341            add_data fill a self.dList with data to fit
342            @param data: Data to add in the list 
343        """
344        if not data in self.dList:
345            self.dList.append(data)
346           
347    def get_model(self):
348        """ @return: saved model """
349        return self.model   
350     
351    def get_data(self):
352        """ @return:  list of data dList"""
353        return self.dList
354     
355    def remove_data(self,data):
356        """
357            Remove one element from the list
358            @param data: Data to remove from dList
359        """
360        if data in self.dList:
361            self.dList.remove(data)
362   
363
364
365   
Note: See TracBrowser for help on using the repository browser.