source: sasview/park_integration/test/ParkFitting.py @ fbc51ef

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since fbc51ef was fbc51ef, checked in by Gervaise Alina <gervyh@…>, 16 years ago

nothing much changed

  • Property mode set to 100644
File size: 9.5 KB
Line 
1#class Fitting
2import time
3
4import numpy
5import park
6from scipy import optimize
7from park import fit,fitresult
8from park import assembly
9
10from sans.guitools.plottables import Data1D
11#from sans.guitools import plottables
12from Loader import Load
13
14class SansParameter(park.Parameter):
15    """
16    SANS model parameters for use in the PARK fitting service.
17    The parameter attribute value is redirected to the underlying
18    parameter value in the SANS model.
19    """
20    def __init__(self, name, model):
21         self._model, self._name = model,name
22    def _getvalue(self): return self._model.getParam(self.name)
23    def _setvalue(self,value): self._model.setParam(self.name, value)
24    value = property(_getvalue,_setvalue)
25    def _getrange(self):
26        lo,hi = self._model.details[self.name][1:]
27        if lo is None: lo = -numpy.inf
28        if hi is None: hi = numpy.inf
29        return lo,hi
30    def _setrange(self,r):
31        self._model.details[self.name][1:] = r
32    range = property(_getrange,_setrange)
33
34class Model(object):
35    """
36        PARK wrapper for SANS models.
37    """
38    def __init__(self, sans_model):
39        self.model = sans_model
40        sansp = sans_model.getParamList()
41        parkp = [SansParameter(p,sans_model) for p in sansp]
42        self.parameterset = park.ParameterSet(sans_model.name,pars=parkp)
43    def eval(self,x):
44        return self.model.run(x)
45   
46class Data(object):
47    """ Wrapper class  for SANS data """
48    def __init__(self, sans_data):
49        self.x= sans_data.x
50        self.y= sans_data.y
51        self.dx= sans_data.dx
52        self.dy= sans_data.dy
53        self.qmin=None
54        self.qmax=None
55       
56    def setFitRange(self,mini=None,maxi=None):
57        """ to set the fit range"""
58        self.qmin=mini
59        self.qmax=maxi
60       
61    def residuals(self, fn):
62       
63        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
64        if self.qmin==None and self.qmax==None: 
65            return (y - fn(x))/dy
66       
67        else:
68            idx = x>=self.qmin & x <= self.qmax
69            return (y[idx] - fn(x[idx]))/dy[idx]
70           
71         
72    def residuals_deriv(self, model, pars=[]):
73        """ Return residual derivatives .in this case just return empty array"""
74        return []
75   
76class FitArrange:
77    def __init__(self):
78        """
79            Store a set of data for a given model to perform the Fit
80            @param model: the model selected by the user
81            @param Ldata: a list of data what the user want to fit
82        """
83        self.model = None
84        self.dList =[]
85       
86    def set_model(self,model):
87        """ set the model """
88        self.model = model
89       
90    def add_data(self,data):
91        """
92            @param data: Data to add in the list
93            fill a self.dataList with data to fit
94        """
95        if not data in self.dList:
96            self.dList.append(data)
97           
98    def get_model(self):
99        """ Return the model"""
100        return self.model   
101     
102    def get_data(self):
103        """ Return list of data"""
104        return self.dList
105     
106    def remove_data(self,data):
107        """
108            Remove one element from the list
109            @param data: Data to remove from the the lsit of data
110        """
111        if data in self.dList:
112            self.dList.remove(data)
113           
114class ParkFit:
115    """
116        Performs the Fit.he user determine what kind of data
117    """
118    def __init__(self,data=[]):
119        #this is a dictionary of FitArrange elements
120        self.fitArrangeList={}
121        #the constraint of the Fit
122        self.constraint =None
123        #Specify the use of scipy or park fit
124        self.fitType =None
125       
126    def createProblem(self,pars={}):
127        """
128            Check the contraint value and specify what kind of fit to use
129            return (M1,D1)
130        """
131        mylist=[]
132        for k,value in self.fitArrangeList.iteritems():
133            couple=()
134            model=value.get_model()
135            parameters= self.set_param(model, pars)
136            model = Model(model)
137            #print "model created",model.parameterset[0].value,model.parameterset[1].value
138            # Make all parameters fitting parameters
139            for p in model.parameterset:
140                p.set([-numpy.inf,numpy.inf])
141                #p.set([-10,10])
142            Ldata=value.get_data()
143            data=self._concatenateData(Ldata)
144            #print "this data",data
145            #print "data.residuals in createProblem",Ldata[0].residuals
146            #print "data.residuals in createProblem",data.residuals
147            #couple1=(model,Ldata[0])
148            #mylist.append(couple1)
149            couple=(model,data)
150            mylist.append(couple)
151        #print mylist
152        return mylist
153        #return model,data
154   
155    def fit(self,pars, qmin=None, qmax=None):
156        """
157             Do the fit
158        """
159       
160        modelList=self.createProblem(pars)
161        #model,data=self.createProblem()
162        #fitness=assembly.Fitness(model,data)
163       
164        problem =  park.Assembly(modelList)
165        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
166        #problem[0].parameterset['A'].set([0,1000])
167        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
168        fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1))
169        #return fit.fit(problem)
170        #fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1))
171       
172   
173    def set_model(self,model,Uid):
174        """ Set model """
175       
176        if self.fitArrangeList.has_key(Uid):
177            self.fitArrangeList[Uid].set_model(model)
178        else:
179            fitproblem= FitArrange()
180            fitproblem.set_model(model)
181            self.fitArrangeList[Uid]=fitproblem
182       
183    def set_data(self,data,Uid):
184        """ Receive plottable and create a list of data to fit"""
185        data=Data(data)
186        if self.fitArrangeList.has_key(Uid):
187            self.fitArrangeList[Uid].add_data(data)
188        else:
189            fitproblem= FitArrange()
190            fitproblem.add_data(data)
191            self.fitArrangeList[Uid]=fitproblem
192           
193    def get_model(self,Uid):
194        """ return list of data"""
195        return self.fitArrangeList[Uid]
196   
197    def set_param(self,model, pars):
198        """ Recieve a dictionary of parameter and save it """
199        parameters=[]
200        if model==None:
201            raise ValueError, "Cannot set parameters for empty model"
202        else:
203            #for key ,value in pars:
204            for key, value in pars.iteritems():
205                param = Parameter(model, key, value)
206                parameters.append(param)
207        return parameters
208   
209    def add_constraint(self, constraint):
210        """ User specify contraint to fit """
211        self.constraint = str(constraint)
212       
213    def get_constraint(self):
214        """ return the contraint value """
215        return self.constraint
216   
217    def set_constraint(self,constraint):
218        """
219            receive a string as a constraint
220            @param constraint: a string used to constraint some parameters to get a
221                specific value
222        """
223        self.constraint= constraint
224    def _concatenateData(self, listdata=[]):
225        """ concatenate each fields of all Data contains ins listdata
226         return data
227        """
228        if listdata==[]:
229            raise ValueError, " data list missing"
230        else:
231            xtemp=[]
232            ytemp=[]
233            dytemp=[]
234            resid=[]
235            resid_deriv=[]
236           
237            for data in listdata:
238                for i in range(len(data.x)):
239                    if not data.x[i] in xtemp:
240                        xtemp.append(data.x[i])
241                       
242                    if not data.y[i] in ytemp:
243                        ytemp.append(data.y[i])
244                       
245                    if not data.dy[i] in dytemp:
246                        dytemp.append(data.dy[i])
247                   
248                   
249            newplottable= Data1D(xtemp,ytemp,None,dytemp)
250            newdata=Data(newplottable)
251           
252            #print "this is new data",newdata.dy
253            return newdata
254class Parameter:
255    """
256        Class to handle model parameters
257    """
258    def __init__(self, model, name, value=None):
259            self.model = model
260            self.name = name
261            if not value==None:
262                self.model.setParam(self.name, value)
263           
264    def set(self, value):
265        """
266            Set the value of the parameter
267        """
268        self.model.setParam(self.name, value)
269
270    def __call__(self):
271        """
272            Return the current value of the parameter
273        """
274        return self.model.getParam(self.name)
275   
276
277     
278if __name__ == "__main__": 
279    load= Load()
280   
281    # test fit one data set one model
282    load.set_filename("testdata_line.txt")
283    load.set_values()
284    data1 = Data1D(x=[], y=[], dx=None,dy=None)
285    data1.name = "data1"
286    load.load_data(data1)
287    fitter =ParkFit()
288   
289    from sans.guitools.LineModel import LineModel
290    model  = LineModel()
291    fitter.set_model(model,1)
292    fitter.set_data(data1,1)
293   
294    print"PARK fit result \n",fitter.fit({'A':2,'B':1},None,None)
295   
296   
297   
298   
Note: See TracBrowser for help on using the repository browser.