source: sasview/park_integration/ParkFitting.py @ 73b1c72

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 73b1c72 was 9e85792, checked in by Gervaise Alina <gervyh@…>, 16 years ago

working ont he 3 rd used cases

  • Property mode set to 100644
File size: 9.9 KB
Line 
1#class Fitting
2import time
3
4import numpy
5import park
6from scipy import optimize
7from park import fit,fitresult
8from park import assembly
9
10from sans.guitools.plottables import Data1D
11#from sans.guitools import plottables
12from Loader import Load
13from park import expression
14class SansParameter(park.Parameter):
15    """
16    SANS model parameters for use in the PARK fitting service.
17    The parameter attribute value is redirected to the underlying
18    parameter value in the SANS model.
19    """
20    def __init__(self, name, model):
21         self._model, self._name = model,name
22         self.set(model.getParam(name))
23    def _getvalue(self): return self._model.getParam(self.name)
24    def _setvalue(self,value): 
25        if numpy.isnan(value):
26            print "setting %s.%s to"%(self._model.name,self.name),value
27        self._model.setParam(self.name, value)
28    value = property(_getvalue,_setvalue)
29    def _getrange(self):
30        lo,hi = self._model.details[self.name][1:]
31        if lo is None: lo = -numpy.inf
32        if hi is None: hi = numpy.inf
33        return lo,hi
34    def _setrange(self,r):
35        self._model.details[self.name][1:] = r
36    range = property(_getrange,_setrange)
37
38class Model(object):
39    """
40        PARK wrapper for SANS models.
41    """
42    def __init__(self, sans_model):
43        self.model = sans_model
44        sansp = sans_model.getParamList()
45        parkp = [SansParameter(p,sans_model) for p in sansp]
46        self.parameterset = park.ParameterSet(sans_model.name,pars=parkp)
47    def eval(self,x):
48        return self.model.run(x)
49   
50class Data(object):
51    """ Wrapper class  for SANS data """
52    def __init__(self, sans_data):
53        self.x= sans_data.x
54        self.y= sans_data.y
55        self.dx= sans_data.dx
56        self.dy= sans_data.dy
57        self.qmin=None
58        self.qmax=None
59       
60    def setFitRange(self,mini=None,maxi=None):
61        """ to set the fit range"""
62        self.qmin=mini
63        self.qmax=maxi
64       
65    def residuals(self, fn):
66       
67        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
68        if self.qmin==None and self.qmax==None: 
69            return (y - fn(x))/dy
70       
71        else:
72            idx = x>=self.qmin & x <= self.qmax
73            return (y[idx] - fn(x[idx]))/dy[idx]
74           
75         
76    def residuals_deriv(self, model, pars=[]):
77        """ Return residual derivatives .in this case just return empty array"""
78        return []
79   
80class FitArrange:
81    def __init__(self):
82        """
83            Store a set of data for a given model to perform the Fit
84            @param model: the model selected by the user
85            @param Ldata: a list of data what the user want to fit
86        """
87        self.model = None
88        self.dList =[]
89       
90    def set_model(self,model):
91        """ set the model """
92        self.model = model
93       
94    def add_data(self,data):
95        """
96            @param data: Data to add in the list
97            fill a self.dataList with data to fit
98        """
99        if not data in self.dList:
100            self.dList.append(data)
101           
102    def get_model(self):
103        """ Return the model"""
104        return self.model   
105     
106    def get_data(self):
107        """ Return list of data"""
108        return self.dList
109     
110    def remove_data(self,data):
111        """
112            Remove one element from the list
113            @param data: Data to remove from the the lsit of data
114        """
115        if data in self.dList:
116            self.dList.remove(data)
117           
118class ParkFit:
119    """
120        Performs the Fit.he user determine what kind of data
121    """
122    def __init__(self,data=[]):
123        #this is a dictionary of FitArrange elements
124        self.fitArrangeList={}
125        #the constraint of the Fit
126        self.constraint =None
127        #Specify the use of scipy or park fit
128        self.fitType =None
129       
130    def createProblem(self,pars={}):
131        """
132            Check the contraint value and specify what kind of fit to use
133            return (M1,D1)
134        """
135        mylist=[]
136        listmodel=[]
137        for k,value in self.fitArrangeList.iteritems():
138            #couple=()
139            sansmodel=value.get_model()
140           
141            #parameters= self.set_param(model,model.name, pars)
142            parkmodel = Model(sansmodel)
143            #print "model created",model.parameterset[0].value,model.parameterset[1].value
144            # Make all parameters fitting parameters
145           
146           
147            for p in parkmodel.parameterset:
148                #p.range([-numpy.inf,numpy.inf])
149                # Convert parameters with initial values into fitted parameters
150                # spanning all possible values.  Parameters which are expressions
151                # will remain as expressions.
152                if p.isfixed():
153                    p.set([-numpy.inf,numpy.inf])
154               
155            Ldata=value.get_data()
156            data=self._concatenateData(Ldata)
157            data1=Data(data)
158           
159            couple=(parkmodel,data1)
160            mylist.append(couple)
161        #print mylist
162        return mylist
163        #return model,data
164   
165    def fit(self,pars=None, qmin=None, qmax=None):
166        """
167             Do the fit
168        """
169   
170        print "starting ParkFit.fit()"
171        modelList=self.createProblem()
172        problem =  park.Assembly(modelList)
173        pars=problem.fit_parameters()
174        print "About to call eval",pars
175        print "initial",[p.value for p in pars]
176        problem.eval()
177        #print "M2.B",problem.parameterset['M2.B'].expression,problem.parameterset['M2.B'].value
178        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
179       
180        #problem[0].parameterset['A'].set([0,1000])
181        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
182        fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1))
183       
184   
185    def set_model(self,model,Uid):
186        """ Set model """
187       
188        if self.fitArrangeList.has_key(Uid):
189            self.fitArrangeList[Uid].set_model(model)
190        else:
191            fitproblem= FitArrange()
192            fitproblem.set_model(model)
193            self.fitArrangeList[Uid]=fitproblem
194       
195    def set_data(self,data,Uid):
196        """ Receive plottable and create a list of data to fit"""
197       
198        if self.fitArrangeList.has_key(Uid):
199            self.fitArrangeList[Uid].add_data(data)
200        else:
201            fitproblem= FitArrange()
202            fitproblem.add_data(data)
203            self.fitArrangeList[Uid]=fitproblem
204           
205    def get_model(self,Uid):
206        """ return list of data"""
207        return self.fitArrangeList[Uid]
208   
209    def set_param(self,model,name, pars):
210        """ Recieve a dictionary of parameter and save it """
211        parameters=[]
212        if model==None:
213            raise ValueError, "Cannot set parameters for empty model"
214        else:
215            model.name=name
216            for key, value in pars.iteritems():
217                param = Parameter(model, key, value)
218                parameters.append(param)
219        return parameters
220   
221    def add_constraint(self, constraint):
222        """ User specify contraint to fit """
223        self.constraint = str(constraint)
224       
225    def get_constraint(self):
226        """ return the contraint value """
227        return self.constraint
228   
229    def set_constraint(self,constraint):
230        """
231            receive a string as a constraint
232            @param constraint: a string used to constraint some parameters to get a
233                specific value
234        """
235        self.constraint= constraint
236    def _concatenateData(self, listdata=[]):
237        """ concatenate each fields of all Data contains ins listdata
238         return data
239        """
240        if listdata==[]:
241            raise ValueError, " data list missing"
242        else:
243            xtemp=[]
244            ytemp=[]
245            dytemp=[]
246            resid=[]
247            resid_deriv=[]
248           
249            for data in listdata:
250                for i in range(len(data.x)):
251                    if not data.x[i] in xtemp:
252                        xtemp.append(data.x[i])
253                       
254                    if not data.y[i] in ytemp:
255                        ytemp.append(data.y[i])
256                       
257                    if not data.dy[i] in dytemp:
258                        dytemp.append(data.dy[i])
259                   
260                   
261            newplottable= Data1D(xtemp,ytemp,None,dytemp)
262            newdata=Data(newplottable)
263           
264            #print "this is new data",newdata.dy
265            return newdata
266class Parameter:
267    """
268        Class to handle model parameters
269    """
270    def __init__(self, model, name, value=None):
271            self.model = model
272            self.name = name
273            if not value==None:
274                self.model.setParam(self.name, value)
275           
276    def set(self, value):
277        """
278            Set the value of the parameter
279        """
280        self.model.setParam(self.name, value)
281
282    def __call__(self):
283        """
284            Return the current value of the parameter
285        """
286        return self.model.getParam(self.name)
287   
288
289     
290if __name__ == "__main__": 
291    load= Load()
292   
293    # test fit one data set one model
294    load.set_filename("testdata_line.txt")
295    load.set_values()
296    data1 = Data1D(x=[], y=[], dx=None,dy=None)
297    data1.name = "data1"
298    load.load_data(data1)
299    fitter =ParkFit()
300   
301    from sans.guitools.LineModel import LineModel
302    model  = LineModel()
303    fitter.set_model(model,1)
304    fitter.set_data(data1,1)
305   
306    print"PARK fit result",fitter.fit({'A':2,'B':1},None,None)
307   
308   
309   
310   
Note: See TracBrowser for help on using the repository browser.