source: sasview/park_integration/ParkFitting.py @ 1c94a9f1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 1c94a9f1 was cf3b781, checked in by Gervaise Alina <gervyh@…>, 16 years ago

need more tests.but usecase 3 implemented

  • Property mode set to 100644
File size: 10.1 KB
Line 
1#class Fitting
2import time
3
4import numpy
5import park
6from scipy import optimize
7from park import fit,fitresult
8from park import assembly
9from park.fitmc import FitSimplex, FitMC
10
11from sans.guitools.plottables import Data1D
12#from sans.guitools import plottables
13from Loader import Load
14from park import expression
15class SansParameter(park.Parameter):
16    """
17    SANS model parameters for use in the PARK fitting service.
18    The parameter attribute value is redirected to the underlying
19    parameter value in the SANS model.
20    """
21    def __init__(self, name, model):
22         self._model, self._name = model,name
23         self.set(model.getParam(name))
24    def _getvalue(self): return self._model.getParam(self.name)
25    def _setvalue(self,value): 
26        if numpy.isnan(value):
27            print "setting %s.%s to"%(self._model.name,self.name),value
28        self._model.setParam(self.name, value)
29    value = property(_getvalue,_setvalue)
30    def _getrange(self):
31        lo,hi = self._model.details[self.name][1:]
32        if lo is None: lo = -numpy.inf
33        if hi is None: hi = numpy.inf
34        return lo,hi
35    def _setrange(self,r):
36        self._model.details[self.name][1:] = r
37    range = property(_getrange,_setrange)
38
39class Model(object):
40    """
41        PARK wrapper for SANS models.
42    """
43    def __init__(self, sans_model):
44        self.model = sans_model
45        sansp = sans_model.getParamList()
46        parkp = [SansParameter(p,sans_model) for p in sansp]
47        self.parameterset = park.ParameterSet(sans_model.name,pars=parkp)
48    def eval(self,x):
49        return self.model.run(x)
50   
51class Data(object):
52    """ Wrapper class  for SANS data """
53    def __init__(self, sans_data):
54        self.x= sans_data.x
55        self.y= sans_data.y
56        self.dx= sans_data.dx
57        self.dy= sans_data.dy
58        self.qmin=None
59        self.qmax=None
60       
61    def setFitRange(self,mini=None,maxi=None):
62        """ to set the fit range"""
63        self.qmin=mini
64        self.qmax=maxi
65       
66    def residuals(self, fn):
67       
68        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
69        if self.qmin==None and self.qmax==None: 
70            self.fx = fn(x)
71            return (y - fn(x))/dy
72       
73        else:
74            self.fx = fn(x[idx])
75            idx = x>=self.qmin & x <= self.qmax
76            return (y[idx] - fn(x[idx]))/dy[idx]
77           
78         
79    def residuals_deriv(self, model, pars=[]):
80        """ Return residual derivatives .in this case just return empty array"""
81        return []
82   
83class FitArrange:
84    def __init__(self):
85        """
86            Store a set of data for a given model to perform the Fit
87            @param model: the model selected by the user
88            @param Ldata: a list of data what the user want to fit
89        """
90        self.model = None
91        self.dList =[]
92       
93    def set_model(self,model):
94        """ set the model """
95        self.model = model
96       
97    def add_data(self,data):
98        """
99            @param data: Data to add in the list
100            fill a self.dataList with data to fit
101        """
102        if not data in self.dList:
103            self.dList.append(data)
104           
105    def get_model(self):
106        """ Return the model"""
107        return self.model   
108     
109    def get_data(self):
110        """ Return list of data"""
111        return self.dList
112     
113    def remove_data(self,data):
114        """
115            Remove one element from the list
116            @param data: Data to remove from the the lsit of data
117        """
118        if data in self.dList:
119            self.dList.remove(data)
120    def remove_model(self):
121        """ Remove model """
122        model=None
123    def remove_datalist(self):
124        self.dList=[]
125           
126class ParkFit:
127    """
128        Performs the Fit.he user determine what kind of data
129    """
130    def __init__(self,data=[]):
131        #this is a dictionary of FitArrange elements
132        self.fitArrangeList={}
133        #the constraint of the Fit
134        self.constraint =None
135        #Specify the use of scipy or park fit
136        self.fitType =None
137       
138    def createProblem(self,pars={}):
139        """
140            Check the contraint value and specify what kind of fit to use
141            return (M1,D1)
142        """
143        mylist=[]
144        listmodel=[]
145        for k,value in self.fitArrangeList.iteritems():
146            #couple=()
147            sansmodel=value.get_model()
148           
149            #parameters= self.set_param(model,model.name, pars)
150            parkmodel = Model(sansmodel)
151            #print "model created",model.parameterset[0].value,model.parameterset[1].value
152            # Make all parameters fitting parameters
153           
154           
155            for p in parkmodel.parameterset:
156                #p.range([-numpy.inf,numpy.inf])
157                # Convert parameters with initial values into fitted parameters
158                # spanning all possible values.  Parameters which are expressions
159                # will remain as expressions.
160                if p.isfixed():
161                    p.set([-numpy.inf,numpy.inf])
162               
163            Ldata=value.get_data()
164            data=self._concatenateData(Ldata)
165            data1=Data(data)
166           
167            couple=(parkmodel,data1)
168            mylist.append(couple)
169        #print mylist
170        self.problem =  park.Assembly(mylist)
171        #return model,data
172   
173    def fit(self,pars=None, qmin=None, qmax=None):
174        """
175             Do the fit
176        """
177
178        self.createProblem(pars)
179        print "starting ParkFit.fit()"
180        #problem[0].model.parameterset['A'].set([1,5])
181        #problem[0].model.parameterset['B'].set([1,5])
182        pars=self.problem.fit_parameters()
183        print "About to call eval",pars
184        print "initial",[p.value for p in pars]
185        self.problem.eval()
186        #print "M2.B",problem.parameterset['M2.B'].expression,problem.parameterset['M2.B'].value
187        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
188       
189        #problem[0].parameterset['A'].set([0,1000])
190        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted
191
192        localfit = FitSimplex()
193        localfit.ftol = 1e-8
194        fitter = FitMC(localfit=localfit)
195
196        result = fit.fit(self.problem,
197                         fitter=fitter,
198                         handler= fitresult.ConsoleUpdate(improvement_delta=0.1))
199        pvec = result.pvec
200        cov = self.problem.cov(pvec)
201        return result.fitness,pvec,numpy.sqrt(numpy.diag(cov))
202
203   
204    def set_model(self,model,Uid):
205        """ Set model """
206       
207        if self.fitArrangeList.has_key(Uid):
208            self.fitArrangeList[Uid].set_model(model)
209        else:
210            fitproblem= FitArrange()
211            fitproblem.set_model(model)
212            self.fitArrangeList[Uid]=fitproblem
213       
214    def set_data(self,data,Uid):
215        """ Receive plottable and create a list of data to fit"""
216       
217        if self.fitArrangeList.has_key(Uid):
218            self.fitArrangeList[Uid].add_data(data)
219        else:
220            fitproblem= FitArrange()
221            fitproblem.add_data(data)
222            self.fitArrangeList[Uid]=fitproblem
223           
224    def get_model(self,Uid):
225        """ return list of data"""
226        return self.fitArrangeList[Uid]
227   
228    def set_param(self,model,name, pars):
229        """ Recieve a dictionary of parameter and save it """
230        parameters=[]
231        if model==None:
232            raise ValueError, "Cannot set parameters for empty model"
233        else:
234            model.name=name
235            for key, value in pars.iteritems():
236                param = Parameter(model, key, value)
237                parameters.append(param)
238        return parameters
239   
240    def remove_data(self,Uid,data=None):
241        """ remove one or all data"""
242        if data==None:# remove all element in data list
243            if self.fitArrangeList.has_key(Uid):
244                self.fitArrangeList[Uid].remove_datalist()
245        else:
246            if self.fitArrangeList.has_key(Uid):
247                self.fitArrangeList[Uid].remove_data(data)
248               
249    def remove_model(self,Uid):
250        """ remove model """
251        if self.fitArrangeList.has_key(Uid):
252            self.fitArrangeList[Uid].remove_model()
253               
254               
255    def _concatenateData(self, listdata=[]):
256        """ concatenate each fields of all Data contains ins listdata
257         return data
258        """
259        if listdata==[]:
260            raise ValueError, " data list missing"
261        else:
262            xtemp=[]
263            ytemp=[]
264            dytemp=[]
265            resid=[]
266            resid_deriv=[]
267           
268            for data in listdata:
269                for i in range(len(data.x)):
270                    if not data.x[i] in xtemp:
271                        xtemp.append(data.x[i])
272                       
273                    if not data.y[i] in ytemp:
274                        ytemp.append(data.y[i])
275                       
276                    if not data.dy[i] in dytemp:
277                        dytemp.append(data.dy[i])
278                   
279                   
280            newplottable= Data1D(xtemp,ytemp,None,dytemp)
281            newdata=Data(newplottable)
282           
283            #print "this is new data",newdata.dy
284            return newdata
285class Parameter:
286    """
287        Class to handle model parameters
288    """
289    def __init__(self, model, name, value=None):
290            self.model = model
291            self.name = name
292            if not value==None:
293                self.model.setParam(self.name, value)
294           
295    def set(self, value):
296        """
297            Set the value of the parameter
298        """
299        self.model.setParam(self.name, value)
300
301    def __call__(self):
302        """
303            Return the current value of the parameter
304        """
305        return self.model.getParam(self.name)
306   
307
308
309   
310   
311   
Note: See TracBrowser for help on using the repository browser.