source: sasview/park_integration/ParkFitting.py @ 6aaf444

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 6aaf444 was 202f93a, checked in by Gervaise Alina <gervyh@…>, 16 years ago

Fix fit: only fit the parameters newly set in the model.

  • Property mode set to 100644
File size: 7.0 KB
Line 
1"""
2    @organization: ParkFitting module contains SansParameter,Model,Data
3    FitArrange, ParkFit,Parameter classes.All listed classes work together to perform a
4    simple fit with park optimizer.
5"""
6import time
7import numpy
8
9import park
10from park import fit,fitresult
11from park import assembly
12from park.fitmc import FitSimplex, FitMC
13
14from sans.guitools.plottables import Data1D
15from Loader import Load
16from AbstractFitEngine import FitEngine, Parameter, FitArrange
class SansParameter(park.Parameter):
    """
        PARK parameter whose state lives in a SANS model.

        The ``value`` attribute is read with model.getParam(name) and
        written with model.setParam(name, value). The ``range`` attribute
        maps to the last two entries of model.details[name]. A bound
        stored as None is reported as -inf (lower) or +inf (upper).
    """
    def __init__(self, name, model):
        # Same assignment order as before: model first, then name.
        self._model = model
        self._name = name
        # Seed the PARK parameter with the model's current value.
        self.set(model.getParam(name))

    def _getvalue(self):
        """ Return the current value of this parameter from the SANS model. """
        sans_model = self._model
        return sans_model.getParam(self.name)

    def _setvalue(self, value):
        """ Push a new value for this parameter into the SANS model. """
        sans_model = self._model
        sans_model.setParam(self.name, value)

    value = property(_getvalue, _setvalue)

    def _getrange(self):
        """
            Return (lower, upper) from model.details[name][1:].
            Missing (None) bounds become -inf / +inf.
        """
        lower, upper = self._model.details[self.name][1:]
        return (-numpy.inf if lower is None else lower,
                numpy.inf if upper is None else upper)

    def _setrange(self, r):
        """ Store the (lower, upper) pair r into model.details[name][1:]. """
        entry = self._model.details[self.name]
        entry[1:] = r

    range = property(_getrange, _setrange)
43
44
class Model(object):
    """
        Adapter that exposes a SANS model to PARK.

        Each name returned by the SANS model's getParamList() is wrapped in
        a SansParameter. The wrapped parameters are gathered into a PARK
        ParameterSet named after the SANS model.
    """
    def __init__(self, sans_model):
        self.model = sans_model
        param_names = sans_model.getParamList()
        wrapped = []
        for pname in param_names:
            wrapped.append(SansParameter(pname, sans_model))
        self.parameterset = park.ParameterSet(sans_model.name, pars=wrapped)

    def eval(self, x):
        """ Evaluate the wrapped SANS model at x through its run() method. """
        sans_model = self.model
        return sans_model.run(x)
58   
class Data(object):
    """
        Wrapper class for SANS data used by PARK to compute residuals.

        The data comes either from an object exposing ``x``, ``y``, ``dx``
        and ``dy`` attributes (``sans_data``) or from the explicit ``x``,
        ``y``, ``dy`` and optional ``dx`` arguments.
    """
    def __init__(self, x=None, y=None, dy=None, dx=None, sans_data=None):
        """
            @param x: x values (compared against qmin/qmax when a fit range is set)
            @param y: measured values
            @param dy: uncertainties on y, used to weight the residuals
            @param dx: uncertainties on x (stored only, not used by residuals)
            @param sans_data: object with x, y, dx and dy attributes; when given
                it takes precedence over the explicit arguments
            @raise ValueError: if sans_data is None and any of x, y or dy is None
        """
        # Use identity tests: ``array != None`` is an element-wise comparison
        # in numpy, and its truth value is ambiguous.
        if sans_data is not None:
            self.x = sans_data.x
            self.y = sans_data.y
            self.dx = sans_data.dx
            self.dy = sans_data.dy
        elif x is not None and y is not None and dy is not None:
            self.x = x
            self.y = y
            self.dx = dx
            self.dy = dy
        else:
            raise ValueError(
                "Data is missing x, y or dy, impossible to compute residuals later on")
        self.qmin = None
        self.qmax = None

    def setFitRange(self, mini=None, maxi=None):
        """
            Set the fit range used by residuals().

            @param mini: lower bound on x (None means unbounded below)
            @param maxi: upper bound on x (None means unbounded above)
        """
        self.qmin = mini
        self.qmax = maxi

    def residuals(self, fn):
        """
            Compute the weighted residuals (y - fn(x)) / dy.

            Only points with qmin <= x <= qmax are used. A bound left as None
            is unbounded on that side, and with both bounds None every point
            is used. The model values for the selected points are stored in
            self.fx.

            @param fn: function that returns the model value(s) for x
            @return: numpy array of residuals for the selected points
        """
        x, y, dy = [numpy.asarray(v) for v in (self.x, self.y, self.dy)]
        if self.qmin is not None or self.qmax is not None:
            lo = -numpy.inf if self.qmin is None else self.qmin
            hi = numpy.inf if self.qmax is None else self.qmax
            # Parentheses are required: ``&`` binds tighter than ``>=``/``<=``.
            idx = (x >= lo) & (x <= hi)
            x, y, dy = x[idx], y[idx], dy[idx]
        # Evaluate the model once and reuse the result for the residuals.
        self.fx = fn(x)
        return (y - self.fx) / dy

    def residuals_deriv(self, model, pars=None):
        """
            @return: residuals derivatives
            @note: not implemented; always returns an empty list
        """
        return []
105
106           
class ParkFit(FitEngine):
    """
        ParkFit performs the fit. This class can be used as follows:

        #Do the fit Park
        create an engine: engine = ParkFit()
        Data must be of type plottable.
        Use a sans model.

        Add data with a dictionary of FitArrangeList where Uid is a key and
        the data is saved in a FitArrange object.
        engine.set_data(data,Uid)

        Set model parameter "M1"= model.name add {model.parameter.name:value}.
        @note: set_param(), if used, must always be called before set_model()
             for the fit to be performed.
        engine.set_param( model,"M1", {'A':2,'B':4})

        Add a model with a dictionary of FitArrangeList{} where Uid is a key
        and the model is saved in a FitArrange object.
        engine.set_model(model,Uid)

        engine.fit returns chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
        chisqr1, out1, cov1=engine.fit(qmin,qmax)
        @note: parameter values are not passed to fit(); the user must call
        set_param() beforehand.
    """
    def __init__(self, data=None):
        """
            Create an empty dictionary (self.fitArrangeList) of FitArrange
            elements keyed by Uid.

            @param data: unused; kept for backward compatibility
        """
        self.fitArrangeList = {}

    def createProblem(self):
        """
            Build the PARK assembly from self.fitArrangeList = {Uid: FitArrange}.

            For each FitArrange:
            - the SANS model is wrapped in a PARK Model;
            - every parameter that is currently fixed but whose name is in
              self.paramList is reset with p.set([-inf, inf]);
            - the FitArrange's data is concatenated and wrapped in a PARK Data.

            The (parkmodel, parkdata) pairs are stored in
            self.problem = park.Assembly([(parkmodel, parkdata), ...]).
        """
        couples = []
        for arrange in self.fitArrangeList.values():
            # wrap the sans model
            parkmodel = Model(arrange.get_model())
            for p in parkmodel.parameterset:
                # NOTE(review): presumably giving a fixed PARK parameter a
                # range via set() turns it into a fitted one - confirm.
                if p.isfixed() and p._getname() in self.paramList:
                    p.set([-numpy.inf, numpy.inf])
            # wrap the sans data
            x, y, dy = self._concatenateData(arrange.get_data())
            couples.append((parkmodel, Data(x, y, dy, None)))
        self.problem = park.Assembly(couples)

    def fit(self, qmin=None, qmax=None):
        """
            Perform the fit with the park.fit module.

            It can fit one model with one set of data, one model with
            several sets of data, or several models with their associated
            sets of data and constraints.

            @param qmin: minimum of the data range to fit (currently ignored)
            @param qmax: maximum of the data range to fit (currently ignored)
            @note: qmin and qmax exist only to keep the ScipyFit and ParkFit
                interfaces the same.
            @return: (result.fitness, result.pvec, result.cov), i.e. the
                goodness-of-fit value, the best parameter values found
                during fitting, and the covariance matrix
        """
        self.createProblem()
        # Called for its side effect: the PARK assembly appears to prepare
        # its fitted-parameter state here before eval() (TODO confirm).
        self.problem.fit_parameters()
        self.problem.eval()

        localfit = FitSimplex()
        localfit.ftol = 1e-8
        fitter = FitMC(localfit=localfit)

        result = fit.fit(self.problem,
                         fitter=fitter,
                         handler=fitresult.ConsoleUpdate(improvement_delta=0.1))
        return result.fitness, result.pvec, result.cov
198   
199   
Note: See TracBrowser for help on using the repository browser.