source: sasview/park_integration/AbstractFitEngine.py @ 1f8accb

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 1f8accb was ca6d914, checked in by Gervaise Alina <gervyh@…>, 16 years ago

some bugs fixed

  • Property mode set to 100644
File size: 12.1 KB
RevLine 
[4c718654]1
[48882d1]2import park,numpy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
[ca6d914]11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
[48882d1]18         
[ca6d914]19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
[48882d1]25   
[ca6d914]26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
[48882d1]31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
[ca6d914]36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
[48882d1]40        lo,hi = self._model.details[self.name][1:]
41        if lo is None: lo = -numpy.inf
42        if hi is None: hi = numpy.inf
43        return lo,hi
44   
45    def _setrange(self,r):
[ca6d914]46        """
47            override _setrange of park parameter
48            @param r: the value of the range to set
49        """
[48882d1]50        self._model.details[self.name][1:] = r
51    range = property(_getrange,_setrange)
52
53
54class Model(object):
55    """
56        PARK wrapper for SANS models.
57    """
[388309d]58    def __init__(self, sans_model, **kw):
[ca6d914]59        """
60            @param sans_model: the sans model to wrap using park interface
61        """
[48882d1]62        self.model = sans_model
[ca6d914]63        self.name = sans_model.name
64        #list of parameters names
[48882d1]65        self.sansp = sans_model.getParamList()
[ca6d914]66        #list of park parameter
[48882d1]67        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
[ca6d914]68        #list of parameterset
[48882d1]69        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
70        self.pars=[]
[ca6d914]71 
72 
[48882d1]73    def getParams(self,fitparams):
[ca6d914]74        """
75            return a list of value of paramter to fit
76            @param fitparams: list of paramaters name to fit
77        """
[48882d1]78        list=[]
79        self.pars=[]
80        self.pars=fitparams
81        for item in fitparams:
82            for element in self.parkp:
83                 if element.name ==str(item):
84                     list.append(element.value)
85        return list
86   
[ca6d914]87   
[48882d1]88    def setParams(self, params):
[ca6d914]89        """
90            Set value for parameters to fit
91            @param params: list of value for parameters to fit
92        """
[48882d1]93        list=[]
94        for item in self.parkp:
95            list.append(item.name)
96        list.sort()
97        for i in range(len(params)):
[388309d]98            self.parkp[i].value = params[i]
[48882d1]99            self.model.setParam(list[i],params[i])
100 
[ca6d914]101 
[48882d1]102    def eval(self,x):
[ca6d914]103        """
104            override eval method of park model.
105            @param x: the x value used to compute a function
106        """
[48882d1]107        return self.model.runXY(x)
[388309d]108   
109   
[48882d1]110class Data(object):
111    """ Wrapper class  for SANS data """
112    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
[ca6d914]113        """
114            Data can be initital with a data (sans plottable)
115            or with vectors.
116        """
[48882d1]117        if  sans_data !=None:
118            self.x= sans_data.x
119            self.y= sans_data.y
120            self.dx= sans_data.dx
121            self.dy= sans_data.dy
122           
123        elif (x!=None and y!=None and dy!=None):
124                self.x=x
125                self.y=y
126                self.dx=dx
127                self.dy=dy
128        else:
129            raise ValueError,\
130            "Data is missing x, y or dy, impossible to compute residuals later on"
131        self.qmin=None
132        self.qmax=None
133       
[ca6d914]134       
[48882d1]135    def setFitRange(self,mini=None,maxi=None):
136        """ to set the fit range"""
137        self.qmin=mini
138        self.qmax=maxi
[ca6d914]139       
140       
[48882d1]141    def getFitRange(self):
[ca6d914]142        """
143            @return the range of data.x to fit
144        """
145        return self.qmin, self.qmax
146     
147     
[48882d1]148    def residuals(self, fn):
149        """ @param fn: function that return model value
150            @return residuals
151        """
152        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
153        if self.qmin==None and self.qmax==None: 
[ca6d914]154            fx =numpy.asarray([fn(v) for v in x])
[48882d1]155            return (y - fx)/dy
156        else:
157            idx = (x>=self.qmin) & (x <= self.qmax)
[ca6d914]158            fx = numpy.asarray([fn(item)for item in x[idx ]])
[48882d1]159            return (y[idx] - fx)/dy[idx]
160         
161           
162         
163    def residuals_deriv(self, model, pars=[]):
164        """
165            @return residuals derivatives .
166            @note: in this case just return empty array
167        """
168        return []
169   
170class sansAssembly:
[ca6d914]171    """
172         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
173    """
[48882d1]174    def __init__(self,Model=None , Data=None):
[ca6d914]175        """
176            @param Model: the model wrapper fro sans -model
177            @param Data: the data wrapper for sans data
178        """
179        self.model = Model
180        self.data  = Data
181        self.res=[]
[48882d1]182    def chisq(self, params):
183        """
184            Calculates chi^2
185            @param params: list of parameter values
186            @return: chi^2
187        """
188        sum = 0
189        for item in self.res:
190            sum += item*item
191        return sum
192    def __call__(self,params):
[ca6d914]193        """
194            Compute residuals
195            @param params: value of parameters to fit
196        """
[48882d1]197        self.model.setParams(params)
198        self.res= self.data.residuals(self.model.eval)
199        return self.res
200   
[4c718654]201class FitEngine:
[ee5b04c]202    def __init__(self):
[ca6d914]203        """
204            Base class for scipy and park fit engine
205        """
206        #List of parameter names to fit
[ee5b04c]207        self.paramList=[]
[ca6d914]208        #Dictionnary of fitArrange element (fit problems)
209        self.fitArrangeDict={}
210       
[4c718654]211    def _concatenateData(self, listdata=[]):
212        """ 
213            _concatenateData method concatenates each fields of all data contains ins listdata.
214            @param listdata: list of data
[ca6d914]215            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
216             of data concatenanted
[4c718654]217            @raise: if listdata is empty  will return None
218            @raise: if data in listdata don't contain dy field ,will create an error
219            during fitting
220        """
221        if listdata==[]:
222            raise ValueError, " data list missing"
223        else:
224            xtemp=[]
225            ytemp=[]
226            dytemp=[]
[48882d1]227            self.mini=None
228            self.maxi=None
[4c718654]229               
230            for data in listdata:
[48882d1]231                mini,maxi=data.getFitRange()
232                if self.mini==None and self.maxi==None:
233                    self.mini=mini
234                    self.maxi=maxi
235                else:
236                    if mini < self.mini:
237                        self.mini=mini
238                    if self.maxi < maxi:
239                        self.maxi=maxi
240                       
241                   
[4c718654]242                for i in range(len(data.x)):
243                    xtemp.append(data.x[i])
244                    ytemp.append(data.y[i])
245                    if data.dy is not None and len(data.dy)==len(data.y):   
246                        dytemp.append(data.dy[i])
247                    else:
[ee5b04c]248                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
[48882d1]249            data= Data(x=xtemp,y=ytemp,dy=dytemp)
250            data.setFitRange(self.mini, self.maxi)
251            return data
[ca6d914]252       
253       
254    def set_model(self,model,Uid,pars=[]):
255        """
256            set a model on a given uid in the fit engine.
257            @param model: the model to fit
258            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
259            @param pars: the list of parameters to fit
260            @note : pars must contains only name of existing model's paramaters
261        """
[f44dbc7]262        if len(pars) >0:
[6831a99]263            if model==None:
[f44dbc7]264                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
[6831a99]265            else:
[ca6d914]266                for item in pars:
267                    if item in model.model.getParamList():
268                        self.paramList.append(item)
269                    else:
270                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
271                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
272                        return
[6831a99]273            #A fitArrange is already created but contains dList only at Uid
[ca6d914]274            if self.fitArrangeDict.has_key(Uid):
275                self.fitArrangeDict[Uid].set_model(model)
[6831a99]276            else:
277            #no fitArrange object has been create with this Uid
[48882d1]278                fitproblem = FitArrange()
[6831a99]279                fitproblem.set_model(model)
[ca6d914]280                self.fitArrangeDict[Uid] = fitproblem
[d4b0687]281        else:
[6831a99]282            raise ValueError, "park_integration:missing parameters"
[48882d1]283   
284    def set_data(self,data,Uid,qmin=None,qmax=None):
[d4b0687]285        """ Receives plottable, creates a list of data to fit,set data
286            in a FitArrange object and adds that object in a dictionary
287            with key Uid.
288            @param data: data added
289            @param Uid: unique key corresponding to a fitArrange object with data
[ca6d914]290        """
[48882d1]291        if qmin !=None and qmax !=None:
292            data.setFitRange(mini=qmin,maxi=qmax)
[d4b0687]293        #A fitArrange is already created but contains model only at Uid
[ca6d914]294        if self.fitArrangeDict.has_key(Uid):
295            self.fitArrangeDict[Uid].add_data(data)
[d4b0687]296        else:
297        #no fitArrange object has been create with this Uid
298            fitproblem= FitArrange()
299            fitproblem.add_data(data)
[ca6d914]300            self.fitArrangeDict[Uid]=fitproblem   
[48882d1]301   
[d4b0687]302    def get_model(self,Uid):
303        """
304            @param Uid: Uid is key in the dictionary containing the model to return
305            @return  a model at this uid or None if no FitArrange element was created
306            with this Uid
307        """
[ca6d914]308        if self.fitArrangeDict.has_key(Uid):
309            return self.fitArrangeDict[Uid].get_model()
[d4b0687]310        else:
311            return None
312   
313    def remove_Fit_Problem(self,Uid):
314        """remove   fitarrange in Uid"""
[ca6d914]315        if self.fitArrangeDict.has_key(Uid):
316            del self.fitArrangeDict[Uid]
[4c718654]317
318   
[d4b0687]319class FitArrange:
320    def __init__(self):
321        """
322            Class FitArrange contains a set of data for a given model
323            to perform the Fit.FitArrange must contain exactly one model
324            and at least one data for the fit to be performed.
325            model: the model selected by the user
326            Ldata: a list of data what the user wants to fit
327           
328        """
329        self.model = None
330        self.dList =[]
331       
332    def set_model(self,model):
333        """
334            set_model save a copy of the model
335            @param model: the model being set
336        """
337        self.model = model
338       
339    def add_data(self,data):
340        """
341            add_data fill a self.dList with data to fit
342            @param data: Data to add in the list 
343        """
344        if not data in self.dList:
345            self.dList.append(data)
346           
347    def get_model(self):
348        """ @return: saved model """
349        return self.model   
350     
351    def get_data(self):
352        """ @return:  list of data dList"""
353        return self.dList
354     
355    def remove_data(self,data):
356        """
357            Remove one element from the list
358            @param data: Data to remove from dList
359        """
360        if data in self.dList:
361            self.dList.remove(data)
[94b44293]362   
[4c718654]363
364
365   
Note: See TracBrowser for help on using the repository browser.