source: sasview/park_integration/ParkFitting.py @ ee5b04c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since ee5b04c was ee5b04c, checked in by Gervaise Alina <gervyh@…>, 16 years ago

comment added

  • Property mode set to 100644
File size: 8.2 KB
Line 
1"""
2    @organization: The ParkFitting module defines the SansParameter, Model, Data
3    and ParkFit classes (FitArrange and Parameter are imported from AbstractFitEngine).
4    Together these classes perform a simple fit with the park optimizer.
5"""
6import time
7import numpy
8
9import park
10from park import fit,fitresult
11from park import assembly
12from park.fitmc import FitSimplex, FitMC
13
14from sans.guitools.plottables import Data1D
15from Loader import Load
16from AbstractFitEngine import FitEngine, Parameter, FitArrange
class SansParameter(park.Parameter):
    """
        PARK parameter backed by a parameter of a SANS model.

        Reads and writes of ``value`` and ``range`` are forwarded to the
        wrapped SANS model instead of being stored on this object.
    """
    def __init__(self, name, model):
        """
            @param name: name of the parameter inside the SANS model
            @param model: SANS model that owns the parameter
        """
        self._model = model
        self._name = name
        # Hand the model's current value to park.Parameter.set().
        self.set(model.getParam(name))

    def _getvalue(self):
        """ @return the value currently held by the SANS model """
        # NOTE(review): self.name is presumably provided by park.Parameter
        # from self._name -- confirm against the park package.
        return self._model.getParam(self.name)

    def _setvalue(self, value):
        """ @param value: new value written straight to the SANS model """
        self._model.setParam(self.name, value)

    value = property(_getvalue, _setvalue)

    def _getrange(self):
        """
            @return (lo, hi) taken from the model details; a bound stored
            as None is reported as -inf / +inf respectively
        """
        lower, upper = self._model.details[self.name][1:]
        return (-numpy.inf if lower is None else lower,
                numpy.inf if upper is None else upper)

    def _setrange(self, r):
        """ @param r: new bounds, stored after the first details entry """
        self._model.details[self.name][1:] = r

    range = property(_getrange, _setrange)
43
44
class Model(object):
    """
        Adapter exposing a SANS model through the attributes PARK uses
        (``parameterset`` and ``eval``).
    """
    def __init__(self, sans_model):
        """
            @param sans_model: SANS model to wrap; each name returned by
            its getParamList() becomes one SansParameter
        """
        self.model = sans_model
        park_pars = []
        for par_name in sans_model.getParamList():
            park_pars.append(SansParameter(par_name, sans_model))
        self.parameterset = park.ParameterSet(sans_model.name,
                                              pars=park_pars)

    def eval(self, x):
        """
            @param x: point(s) at which the model is evaluated
            @return whatever the wrapped model's run(x) returns
        """
        return self.model.run(x)
62   
class Data(object):
    """
        Wrapper class for SANS data.

        Holds x, y, dx, dy and an optional fit range [qmin, qmax] that
        restricts which points contribute to the residuals.
    """
    def __init__(self, x=None, y=None, dy=None, dx=None, sans_data=None):
        """
            @param x, y, dy: data points and the uncertainty on y;
            all three are required when sans_data is not given
            @param dx: uncertainty on x (optional, only stored)
            @param sans_data: object with x, y, dx, dy attributes; when
            given it takes precedence over the individual arrays
            @raise ValueError: if neither sans_data nor all of x, y, dy
            are supplied
        """
        # Identity tests are required here: "array != None" is evaluated
        # elementwise by numpy and cannot be used as a truth value.
        if sans_data is not None:
            self.x = sans_data.x
            self.y = sans_data.y
            self.dx = sans_data.dx
            self.dy = sans_data.dy
        elif x is not None and y is not None and dy is not None:
            self.x = x
            self.y = y
            self.dx = dx
            self.dy = dy
        else:
            raise ValueError(
                "Data is missing x, y or dy, impossible to compute residuals later on")
        self.qmin = None
        self.qmax = None

    def setFitRange(self, mini=None, maxi=None):
        """
            Set the fit range.
            @param mini: lower bound on x (inclusive), None for no bound
            @param maxi: upper bound on x (inclusive), None for no bound
        """
        self.qmin = mini
        self.qmax = maxi

    def residuals(self, fn):
        """
            Compute (y - fn(x)) / dy over the points inside the fit range.
            The model values are also kept in self.fx.

            @param fn: function that returns the model value for an
            array of x
            @return residuals for the selected points
        """
        x, y, dy = [numpy.asarray(v) for v in (self.x, self.y, self.dy)]
        if self.qmin is None and self.qmax is None:
            self.fx = fn(x)
            return (y - self.fx) / dy

        # Build the selection mask one bound at a time so that a range
        # with a single bound set is handled as well.
        idx = numpy.ones(x.shape, dtype=bool)
        if self.qmin is not None:
            idx &= (x >= self.qmin)
        if self.qmax is not None:
            idx &= (x <= self.qmax)
        self.fx = fn(x[idx])
        return (y[idx] - self.fx) / dy[idx]

    def residuals_deriv(self, model, pars=[]):
        """
            @return residuals derivatives.
            @note: not implemented; both arguments are ignored and an
            empty list is always returned
        """
        return []
109
110           
class ParkFit(FitEngine):
    """
        ParkFit performs the fit with the park optimizer. Usage, as
        described for the FitEngine interface:

        Create an engine: engine = ParkFit()
        Data must be of type plottable; the model must be a sans model.

        Add data; it is saved in a FitArrange object stored in the
        fitArrangeList dictionary under the key Uid:
        engine.set_data(data,Uid)

        Set model parameters, "M1" being the model name and the
        dictionary mapping {model.parameter.name:value}:
        engine.set_param( model,"M1", {'A':2,'B':4})
        @note: set_param(), if used, must always precede set_model()
             for the fit to be performed.

        Add the model; it is saved in the FitArrange object stored
        under the key Uid:
        engine.set_model(model,Uid)

        Run the fit:
        chisqr, out, cov, result = engine.fit(qmin, qmax)
        @note: qmin and qmax are currently ignored by fit(); they are
        only accepted to keep the ScipyFit and ParkFit interfaces alike.
    """
    def __init__(self, data=[]):
        """
            Creates a dictionary (self.fitArrangeList={}) of FitArrange
            elements with Uid as keys, and an empty parameter-name list.
            @param data: unused
        """
        self.fitArrangeList = {}
        self.paramList = []

    def createProblem(self):
        """
            Build self.problem from self.fitArrangeList = {Uid: FitArrange}.

            For every FitArrange the sans model is wrapped in a park
            Model and its data in a park Data; the resulting
            (parkmodel, parkdata) couples are assembled with
            self.problem = park.Assembly([(parkmodel, parkdata), ...])
        """
        print("ParkFitting: In createproblem")
        couples = []
        for arrange in self.fitArrangeList.values():
            # wrap sans model
            parkmodel = Model(arrange.get_model())
            for p in parkmodel.parameterset:
                # A fixed parameter that was requested in paramList is
                # given an unbounded range -- presumably this frees it
                # for fitting; confirm against park.Parameter.set().
                if p.isfixed() and p._getname() in self.paramList:
                    p.set([-numpy.inf, numpy.inf])
            x, y, dy = self._concatenateData(arrange.get_data())
            # wrap sans data
            parkdata = Data(x, y, dy, None)
            couples.append((parkmodel, parkdata))
        self.problem = park.Assembly(couples)

    def fit(self, qmin=None, qmax=None):
        """
            Performs the fit with the park.fit module on the problem
            built by createProblem(), using a FitMC fitter with a
            FitSimplex local optimizer.

            @param qmin: The minimum value of data's range to be fit
            @param qmax: The maximum value of data's range to be fit
            @note: qmin and qmax are not used here; they are only there
            to keep the ScipyFit and ParkFit interfaces the same.
            @return result.fitness: Value of the goodness of fit metric
            @return result.pvec: list of parameters with the best value
            found during fitting
            @return result.cov: Covariance matrix
            @return result: the complete park fit result object
            @raise ValueError: if park returns no result
        """
        self.createProblem()
        # The return value was never used; the call is retained in case
        # park relies on it to set up the assembly -- TODO confirm.
        self.problem.fit_parameters()
        self.problem.eval()

        localfit = FitSimplex()
        localfit.ftol = 1e-8
        fitter = FitMC(localfit=localfit)

        result = fit.fit(self.problem,
                         fitter=fitter,
                         handler=fitresult.ConsoleUpdate(improvement_delta=0.1))
        print("ParkFitting: result %s" % (result,))
        if result is None:
            raise ValueError("SVD did not converge")
        return result.fitness, result.pvec, result.cov, result
217           
218       
219       
220   
221   
Note: See TracBrowser for help on using the repository browser.