source: sasview/park_integration/docs/ParkFitting.py @ 1d02586

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 1d02586 was 1d02586, checked in by Gervaise Alina <gervyh@…>, 16 years ago
  • Property mode set to 100644
File size: 9.4 KB
Line 
1#class Fitting
2import time
3
4import numpy
5import park
6from scipy import optimize
7from park import fit,fitresult
8from park import assembly
9
10from sans.guitools.plottables import Data1D
11#from sans.guitools import plottables
12from Loader import Load
13
14class SansParameter(park.Parameter):
15    """
16    SANS model parameters for use in the PARK fitting service.
17    The parameter attribute value is redirected to the underlying
18    parameter value in the SANS model.
19    """
20    def __init__(self, name, model):
21         self._model, self._name = model,name
22    def _getvalue(self): return self._model.getParam(self.name)
23    def _setvalue(self,value): self._model.setParam(self.name, value)
24    value = property(_getvalue,_setvalue)
25    def _getrange(self):
26        lo,hi = self._model.details[self.name][1:]
27        if lo is None: lo = -numpy.inf
28        if hi is None: hi = numpy.inf
29        return lo,hi
30    def _setrange(self,r):
31        self._model.details[self.name][1:] = r
32    range = property(_getrange,_setrange)
33
34class Model(object):
35    """
36        PARK wrapper for SANS models.
37    """
38    def __init__(self, sans_model):
39        self.model = sans_model
40        sansp = sans_model.getParamList()
41        parkp = [SansParameter(p,sans_model) for p in sansp]
42        self.parameterset = park.ParameterSet(sans_model.name,pars=parkp)
43    def eval(self,x):
44        return self.model.run(x)
45   
46class Data(object):
47    """ Wrapper class  for SANS data """
48    def __init__(self, sans_data):
49        self.x= sans_data.x
50        self.y= sans_data.y
51        self.dx= sans_data.dx
52        self.dy= sans_data.dy
53        self.qmin=None
54        self.qmax=None
55       
56    def setFitRange(self,mini=None,maxi=None):
57        """ to set the fit range"""
58        self.qmin=mini
59        self.qmax=maxi
60       
61    def residuals(self, fn):
62       
63        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
64        if self.qmin==None and self.qmax==None: 
65            return (y - fn(x))/dy
66       
67        else:
68            idx = x>=self.qmin & x <= self.qmax
69            return (y[idx] - fn(x[idx]))/dy[idx]
70           
71         
72    def residuals_deriv(self, model, pars=[]):
73        """ Return residual derivatives .in this case just return empty array"""
74        return []
75   
76class FitArrange:
77    def __init__(self):
78        """
79            Store a set of data for a given model to perform the Fit
80            @param model: the model selected by the user
81            @param Ldata: a list of data what the user want to fit
82        """
83        self.model = None
84        self.dList =[]
85       
86    def set_model(self,model):
87        """ set the model """
88        self.model = model
89       
90    def add_data(self,data):
91        """
92            @param data: Data to add in the list
93            fill a self.dataList with data to fit
94        """
95        if not data in self.dList:
96            self.dList.append(data)
97           
98    def get_model(self):
99        """ Return the model"""
100        return self.model   
101     
102    def get_data(self):
103        """ Return list of data"""
104        return self.dList
105     
106    def remove_data(self,data):
107        """
108            Remove one element from the list
109            @param data: Data to remove from the the lsit of data
110        """
111        if data in self.dList:
112            self.dList.remove(data)
113           
114class ParkFit:
115    """
116        Performs the Fit.he user determine what kind of data
117    """
118    def __init__(self,data=[]):
119        #this is a dictionary of FitArrange elements
120        self.fitArrangeList={}
121        #the constraint of the Fit
122        self.constraint =None
123        #Specify the use of scipy or park fit
124        self.fitType =None
125       
126    def createProblem(self,pars={}):
127        """
128            Check the contraint value and specify what kind of fit to use
129            return (M1,D1)
130        """
131        mylist=[]
132        for k,value in self.fitArrangeList.iteritems():
133            couple=()
134            model=value.get_model()
135            parameters= self.set_param(model, pars)
136            model = Model(model)
137            #print "model created",model.parameterset[0].value,model.parameterset[1].value
138            # Make all parameters fitting parameters
139            for p in model.parameterset:
140                p.set([-numpy.inf,numpy.inf])
141                #p.set([-10,10])
142            Ldata=value.get_data()
143            data=self._concatenateData(Ldata)
144            #print "this data",data
145            #print "data.residuals in createProblem",Ldata[0].residuals
146            #print "data.residuals in createProblem",data.residuals
147            #couple1=(model,Ldata[0])
148            #mylist.append(couple1)
149            couple=(model,data)
150            mylist.append(couple)
151        #print mylist
152        return mylist
153        #return model,data
154   
155    def fit(self,pars, qmin=None, qmax=None):
156        """
157             Do the fit
158        """
159       
160        modelList=self.createProblem(pars)
161        #model,data=self.createProblem()
162        #fitness=assembly.Fitness(model,data)
163       
164        problem =  park.Assembly(modelList)
165        print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
166        #problem[0].parameterset['A'].set([0,1000])
167        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
168        fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1))
169        #fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1))
170       
171   
172    def set_model(self,model,Uid):
173        """ Set model """
174       
175        if self.fitArrangeList.has_key(Uid):
176            self.fitArrangeList[Uid].set_model(model)
177        else:
178            fitproblem= FitArrange()
179            fitproblem.set_model(model)
180            self.fitArrangeList[Uid]=fitproblem
181       
182    def set_data(self,data,Uid):
183        """ Receive plottable and create a list of data to fit"""
184        data=Data(data)
185        if self.fitArrangeList.has_key(Uid):
186            self.fitArrangeList[Uid].add_data(data)
187        else:
188            fitproblem= FitArrange()
189            fitproblem.add_data(data)
190            self.fitArrangeList[Uid]=fitproblem
191           
192    def get_model(self,Uid):
193        """ return list of data"""
194        return self.fitArrangeList[Uid]
195   
196    def set_param(self,model, pars):
197        """ Recieve a dictionary of parameter and save it """
198        parameters=[]
199        if model==None:
200            raise ValueError, "Cannot set parameters for empty model"
201        else:
202            #for key ,value in pars:
203            for key, value in pars.iteritems():
204                param = Parameter(model, key, value)
205                parameters.append(param)
206        return parameters
207   
208    def add_constraint(self, constraint):
209        """ User specify contraint to fit """
210        self.constraint = str(constraint)
211       
212    def get_constraint(self):
213        """ return the contraint value """
214        return self.constraint
215   
216    def set_constraint(self,constraint):
217        """
218            receive a string as a constraint
219            @param constraint: a string used to constraint some parameters to get a
220                specific value
221        """
222        self.constraint= constraint
223    def _concatenateData(self, listdata=[]):
224        """ concatenate each fields of all Data contains ins listdata
225         return data
226        """
227        if listdata==[]:
228            raise ValueError, " data list missing"
229        else:
230            xtemp=[]
231            ytemp=[]
232            dytemp=[]
233            resid=[]
234            resid_deriv=[]
235           
236            for data in listdata:
237                for i in range(len(data.x)):
238                    if not data.x[i] in xtemp:
239                        xtemp.append(data.x[i])
240                       
241                    if not data.y[i] in ytemp:
242                        ytemp.append(data.y[i])
243                       
244                    if not data.dy[i] in dytemp:
245                        dytemp.append(data.dy[i])
246                   
247                   
248            newplottable= Data1D(xtemp,ytemp,None,dytemp)
249            newdata=Data(newplottable)
250           
251            #print "this is new data",newdata.dy
252            return newdata
253class Parameter:
254    """
255        Class to handle model parameters
256    """
257    def __init__(self, model, name, value=None):
258            self.model = model
259            self.name = name
260            if not value==None:
261                self.model.setParam(self.name, value)
262           
263    def set(self, value):
264        """
265            Set the value of the parameter
266        """
267        self.model.setParam(self.name, value)
268
269    def __call__(self):
270        """
271            Return the current value of the parameter
272        """
273        return self.model.getParam(self.name)
274   
275
276     
277if __name__ == "__main__": 
278    load= Load()
279   
280    # test fit one data set one model
281    load.set_filename("testdata_line.txt")
282    load.set_values()
283    data1 = Data1D(x=[], y=[], dx=None,dy=None)
284    data1.name = "data1"
285    load.load_data(data1)
286    fitter =ParkFit()
287   
288    from sans.guitools.LineModel import LineModel
289    model  = LineModel()
290    fitter.set_model(model,1)
291    fitter.set_data(data1,1)
292   
293    print"PARK fit result",fitter.fit({'A':2,'B':1},None,None)
294   
295   
296   
297   
Note: See TracBrowser for help on using the repository browser.