source: sasview/park_integration/AbstractFitEngine.py @ 48882d1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 48882d1 was 48882d1, checked in by Gervaise Alina <gervyh@…>, 16 years ago

park fitting with new model and new data

  • Property mode set to 100644
File size: 9.6 KB
Line 
1
2import park,numpy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11         self._model, self._name = model,name
12         self.set(model.getParam(name))
13         
14    def _getvalue(self): return self._model.getParam(self.name)
15   
16    def _setvalue(self,value): 
17        self._model.setParam(self.name, value)
18       
19    value = property(_getvalue,_setvalue)
20   
21    def _getrange(self):
22        lo,hi = self._model.details[self.name][1:]
23        if lo is None: lo = -numpy.inf
24        if hi is None: hi = numpy.inf
25        return lo,hi
26   
27    def _setrange(self,r):
28        self._model.details[self.name][1:] = r
29    range = property(_getrange,_setrange)
30
31
32class Model(object):
33    """
34        PARK wrapper for SANS models.
35    """
36    def __init__(self, sans_model):
37        self.model = sans_model
38        #print "ParkFitting:sans model",self.model
39        self.sansp = sans_model.getParamList()
40        #print "ParkFitting: sans model parameter list",sansp
41        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
42        #print "ParkFitting: park model parameter ",self.parkp
43        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
44        self.pars=[]
45       
46    def getParams(self,fitparams):
47        list=[]
48        self.pars=[]
49        self.pars=fitparams
50        for item in fitparams:
51            for element in self.parkp:
52                 if element.name ==str(item):
53                     list.append(element.value)
54        #print "abstractfitengine: getparams",list
55        return list
56   
57    def setParams(self, params):
58        list=[]
59        for item in self.parkp:
60            list.append(item.name)
61        list.sort()
62        for i in range(len(params)):
63            #self.parkp[i].value = params[i]
64            #print "abstractfitengine: set-params",list[i],params[i]
65           
66            self.model.setParam(list[i],params[i])
67 
68    def eval(self,x):
69        #print "eval",self.parameterset[0].value,self.parameterset[1].value
70        return self.model.runXY(x)
71       
72
73class Data(object):
74    """ Wrapper class  for SANS data """
75    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
76       
77        if  sans_data !=None:
78            self.x= sans_data.x
79            self.y= sans_data.y
80            self.dx= sans_data.dx
81            self.dy= sans_data.dy
82           
83        elif (x!=None and y!=None and dy!=None):
84                self.x=x
85                self.y=y
86                self.dx=dx
87                self.dy=dy
88        else:
89            raise ValueError,\
90            "Data is missing x, y or dy, impossible to compute residuals later on"
91        self.qmin=None
92        self.qmax=None
93       
94    def setFitRange(self,mini=None,maxi=None):
95        """ to set the fit range"""
96        self.qmin=mini
97        self.qmax=maxi
98    def getFitRange(self):
99         return self.qmin, self.qmax
100    def residuals(self, fn):
101        """ @param fn: function that return model value
102            @return residuals
103        """
104        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
105        if self.qmin==None and self.qmax==None: 
106            fx =[fn(v) for v in x]
107            return (y - fx)/dy
108        else:
109            idx = (x>=self.qmin) & (x <= self.qmax)
110            fx = [fn(item)for item in x[idx ]]
111            return (y[idx] - fx)/dy[idx]
112         
113           
114         
115    def residuals_deriv(self, model, pars=[]):
116        """
117            @return residuals derivatives .
118            @note: in this case just return empty array
119        """
120        return []
121   
122class sansAssembly:
123    def __init__(self,Model=None , Data=None):
124       self.model = Model
125       self.data  = Data
126       self.res=[]
127    def chisq(self, params):
128        """
129            Calculates chi^2
130            @param params: list of parameter values
131            @return: chi^2
132        """
133        sum = 0
134        for item in self.res:
135            sum += item*item
136        return sum
137    def __call__(self,params):
138        self.model.setParams(params)
139        self.res= self.data.residuals(self.model.eval)
140        return self.res
141   
142class FitEngine:
143    def __init__(self):
144        self.paramList=[]
145    def _concatenateData(self, listdata=[]):
146        """ 
147            _concatenateData method concatenates each fields of all data contains ins listdata.
148            @param listdata: list of data
149           
150            @return Data:
151               
152            @raise: if listdata is empty  will return None
153            @raise: if data in listdata don't contain dy field ,will create an error
154            during fitting
155        """
156        if listdata==[]:
157            raise ValueError, " data list missing"
158        else:
159            xtemp=[]
160            ytemp=[]
161            dytemp=[]
162            self.mini=None
163            self.maxi=None
164               
165            for data in listdata:
166                mini,maxi=data.getFitRange()
167                if self.mini==None and self.maxi==None:
168                    self.mini=mini
169                    self.maxi=maxi
170                else:
171                    if mini < self.mini:
172                        self.mini=mini
173                    if self.maxi < maxi:
174                        self.maxi=maxi
175                       
176                   
177                for i in range(len(data.x)):
178                    xtemp.append(data.x[i])
179                    ytemp.append(data.y[i])
180                    if data.dy is not None and len(data.dy)==len(data.y):   
181                        dytemp.append(data.dy[i])
182                    else:
183                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
184            #return xtemp, ytemp,dytemp
185            data= Data(x=xtemp,y=ytemp,dy=dytemp)
186            data.setFitRange(self.mini, self.maxi)
187            return data
188    def set_model(self,model,name,Uid,pars=[]):
189        if len(pars) >0:
190            self.paramList = []
191            if model==None:
192                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
193            else:
194                model.name = name
195                self.paramList=pars
196            #A fitArrange is already created but contains dList only at Uid
197            if self.fitArrangeList.has_key(Uid):
198                self.fitArrangeList[Uid].set_model(model)
199            else:
200            #no fitArrange object has been create with this Uid
201                fitproblem = FitArrange()
202                fitproblem.set_model(model)
203                self.fitArrangeList[Uid] = fitproblem
204        else:
205            raise ValueError, "park_integration:missing parameters"
206   
207    def set_data(self,data,Uid,qmin=None,qmax=None):
208        """ Receives plottable, creates a list of data to fit,set data
209            in a FitArrange object and adds that object in a dictionary
210            with key Uid.
211            @param data: data added
212            @param Uid: unique key corresponding to a fitArrange object with data
213            """
214        if qmin !=None and qmax !=None:
215            data.setFitRange(mini=qmin,maxi=qmax)
216        #A fitArrange is already created but contains model only at Uid
217        if self.fitArrangeList.has_key(Uid):
218            self.fitArrangeList[Uid].add_data(data)
219        else:
220        #no fitArrange object has been create with this Uid
221            fitproblem= FitArrange()
222            fitproblem.add_data(data)
223            self.fitArrangeList[Uid]=fitproblem   
224   
225    def get_model(self,Uid):
226        """
227            @param Uid: Uid is key in the dictionary containing the model to return
228            @return  a model at this uid or None if no FitArrange element was created
229            with this Uid
230        """
231        if self.fitArrangeList.has_key(Uid):
232            return self.fitArrangeList[Uid].get_model()
233        else:
234            return None
235   
236    def remove_Fit_Problem(self,Uid):
237        """remove   fitarrange in Uid"""
238        if self.fitArrangeList.has_key(Uid):
239            del self.fitArrangeList[Uid]
240
241   
242class FitArrange:
243    def __init__(self):
244        """
245            Class FitArrange contains a set of data for a given model
246            to perform the Fit.FitArrange must contain exactly one model
247            and at least one data for the fit to be performed.
248            model: the model selected by the user
249            Ldata: a list of data what the user wants to fit
250           
251        """
252        self.model = None
253        self.dList =[]
254       
255    def set_model(self,model):
256        """
257            set_model save a copy of the model
258            @param model: the model being set
259        """
260        self.model = model
261       
262    def add_data(self,data):
263        """
264            add_data fill a self.dList with data to fit
265            @param data: Data to add in the list 
266        """
267        if not data in self.dList:
268            self.dList.append(data)
269           
270    def get_model(self):
271        """ @return: saved model """
272        return self.model   
273     
274    def get_data(self):
275        """ @return:  list of data dList"""
276        return self.dList
277     
278    def remove_data(self,data):
279        """
280            Remove one element from the list
281            @param data: Data to remove from dList
282        """
283        if data in self.dList:
284            self.dList.remove(data)
285   
286
287
288   
Note: See TracBrowser for help on using the repository browser.