source: sasview/park_integration/ParkFitting.py @ 75b40ce

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 75b40ce was df58d26f, checked in by Gervaise Alina <gervyh@…>, 16 years ago

removed one print

  • Property mode set to 100644
File size: 6.9 KB
RevLine 
[792db7d5]1"""
    @organization: The ParkFitting module provides the SansParameter, Model,
    Data and ParkFit classes and uses the FitArrange and Parameter classes
    from AbstractFitEngine. Together they perform a simple fit with the park optimizer.
5"""
[7705306]6import time
7import numpy
[792db7d5]8
[7705306]9import park
10from park import fit,fitresult
11from park import assembly
[cf3b781]12from park.fitmc import FitSimplex, FitMC
[7705306]13
14from sans.guitools.plottables import Data1D
15from Loader import Load
[d4b0687]16from AbstractFitEngine import FitEngine, Parameter, FitArrange
class SansParameter(park.Parameter):
    """
        PARK parameter that proxies one named parameter of a SANS model.

        The 'value' property reads and writes the parameter directly in the
        SANS model (getParam / setParam). The 'range' property reads and
        writes the bounds stored in model.details[name][1:]; a bound stored
        as None is reported as -inf (lower) or +inf (upper).
    """
    def __init__(self, name, model):
        """
            @param name: name of the parameter inside the SANS model
            @param model: SANS model that owns the parameter
        """
        self._model = model
        self._name = name
        # Seed the park parameter (via the inherited set()) with the model's
        # current value for this parameter.
        self.set(model.getParam(name))

    def _getvalue(self):
        """ @return: the parameter's current value in the SANS model """
        return self._model.getParam(self.name)

    def _setvalue(self, value):
        """ Write value into the SANS model """
        self._model.setParam(self.name, value)

    value = property(_getvalue, _setvalue)

    def _getrange(self):
        """
            @return: (lower, upper) bounds from model.details, with None
                replaced by -inf / +inf respectively
        """
        lower, upper = self._model.details[self.name][1:]
        lower = -numpy.inf if lower is None else lower
        upper = numpy.inf if upper is None else upper
        return lower, upper

    def _setrange(self, r):
        """ Store the bounds r into model.details[name][1:] """
        self._model.details[self.name][1:] = r

    range = property(_getrange, _setrange)
43
[792db7d5]44
class Model(object):
    """
        PARK wrapper for SANS models.

        Each name returned by sans_model.getParamList() is wrapped in a
        SansParameter. The wrappers are collected in a park.ParameterSet
        named after the SANS model (sans_model.name).
    """
    def __init__(self, sans_model):
        """ @param sans_model: the SANS model to wrap """
        self.model = sans_model
        wrapped_params = []
        for param_name in sans_model.getParamList():
            wrapped_params.append(SansParameter(param_name, sans_model))
        self.parameterset = park.ParameterSet(sans_model.name,
                                              pars=wrapped_params)

    def eval(self, x):
        """
            Evaluate the wrapped SANS model.
            @param x: input passed unchanged to sans_model.run
            @return: the result of sans_model.run(x)
        """
        result = self.model.run(x)
        return result
57   
class Data(object):
    """
        Wrapper class for SANS data used by the PARK fitting service.

        Holds x, y, dx and dy, either copied from a SANS data object or given
        explicitly. It also holds an optional fit range [qmin, qmax] that
        residuals() applies to x.
    """
    def __init__(self, x=None, y=None, dy=None, dx=None, sans_data=None):
        """
            @param x: x values (ignored when sans_data is given)
            @param y: y values (ignored when sans_data is given)
            @param dy: uncertainties on y (ignored when sans_data is given)
            @param dx: uncertainties on x, optional
            @param sans_data: object with x, y, dx and dy attributes; when
                given it takes precedence over the explicit arguments
            @raise ValueError: if sans_data is None and any of x, y or dy
                is None
        """
        # Identity tests are required: "!= None" on numpy arrays is applied
        # element-wise and cannot be used as a truth value.
        if sans_data is not None:
            self.x = sans_data.x
            self.y = sans_data.y
            self.dx = sans_data.dx
            self.dy = sans_data.dy
        elif x is not None and y is not None and dy is not None:
            self.x = x
            self.y = y
            self.dx = dx
            self.dy = dy
        else:
            raise ValueError(
                "Data is missing x, y or dy, impossible to compute residuals later on")
        self.qmin = None
        self.qmax = None

    def setFitRange(self, mini=None, maxi=None):
        """
            Set the fit range used by residuals().
            @param mini: inclusive lower bound on x; None means unbounded
            @param maxi: inclusive upper bound on x; None means unbounded
        """
        self.qmin = mini
        self.qmax = maxi

    def residuals(self, fn):
        """
            Compute the normalized residuals (y - fn(x)) / dy.

            Only points with qmin <= x <= qmax are used. A bound left at None
            is treated as unbounded. The model values for the points used
            are stored in self.fx.

            @param fn: function that returns the model value(s) for an
                array of x
            @return: numpy array of residuals over the fit range
        """
        x, y, dy = [numpy.asarray(v) for v in (self.x, self.y, self.dy)]
        if self.qmin is None and self.qmax is None:
            self.fx = fn(x)
            return (y - self.fx) / dy

        lower = -numpy.inf if self.qmin is None else self.qmin
        upper = numpy.inf if self.qmax is None else self.qmax
        # Parentheses matter: '&' binds tighter than the comparison operators.
        idx = (x >= lower) & (x <= upper)
        # The mask must exist before the model is evaluated on the subset.
        self.fx = fn(x[idx])
        return (y[idx] - self.fx) / dy[idx]

    def residuals_deriv(self, model, pars=None):
        """
            @return: residuals derivatives.
            @note: in this case just return an empty list; model and pars
                are unused.
        """
        return []
[d4b0687]104
[792db7d5]105           
class ParkFit(FitEngine):
    """
        ParkFit performs the fit. This class can be used as follows:

        #Do the fit with Park
        create an engine: engine = ParkFit()
        Data must be of type plottable.
        Use a sans model.

        Add data to a dictionary of FitArrange objects (FitArrangeList) where
        Uid is the key and the data is saved in the FitArrange object:
        engine.set_data(data, Uid)

        Set model parameters: "M1" = model.name, then add
        {model.parameter.name: value}.
        @note: set_param(), if used, must always precede set_model()
             for the fit to be performed.
        engine.set_param(model, "M1", {'A': 2, 'B': 4})

        Add the model to the dictionary of FitArrange objects where Uid is
        the key and the model is saved in the FitArrange object:
        engine.set_model(model, Uid)

        engine.fit returns chisqr, [model.parameter 1, 2, ..], [[err1....][..err2...]]
        chisqr1, out1, cov1 = engine.fit(qmin, qmax)
        @note: qmin and qmax are currently ignored by fit(); they only keep
        the ScipyFit and ParkFit interfaces the same. Parameter values must
        be set beforehand with set_param().
    """
    def __init__(self, data=None):
        """
            Create an empty dictionary (self.fitArrangeList = {}) that will
            hold FitArrange elements keyed by Uid.

            @param data: unused; kept so the constructor signature stays
                backward compatible
        """
        self.fitArrangeList = {}

    def createProblem(self):
        """
            Build the park problem from self.fitArrangeList = {Uid: FitArrange}.

            For every FitArrange, wrap its SANS model in a park Model and its
            concatenated data in a park Data. Then collect the
            (parkmodel, parkdata) couples into
            self.problem = park.Assembly([(parkmodel, parkdata), ...]).
        """
        couples = []
        for arrange in self.fitArrangeList.values():
            # wrap the sans model
            parkmodel = Model(arrange.get_model())
            # Parameters currently marked fixed are given an unbounded range
            # through park's Parameter.set(); presumably this makes them
            # fitted parameters -- NOTE(review): confirm park semantics.
            for p in parkmodel.parameterset:
                if p.isfixed():
                    p.set([-numpy.inf, numpy.inf])

            x, y, dy = self._concatenateData(arrange.get_data())
            # wrap the sans data (no dx)
            parkdata = Data(x, y, dy, None)
            couples.append((parkmodel, parkdata))

        self.problem = park.Assembly(couples)

    def fit(self, qmin=None, qmax=None):
        """
            Perform the fit with the park.fit module. It can fit one model
            with a set of data, one model with several sets of data, or
            several models, each with its own data and constraints.

            @param qmin: minimum of the data range to fit (currently ignored)
            @param qmax: maximum of the data range to fit (currently ignored)
            @note: qmin and qmax only keep the ScipyFit and ParkFit
                interfaces the same.
            @return result.fitness: value of the goodness-of-fit metric
            @return result.pvec: list of parameters with the best values
                found during fitting
            @return result.cov: covariance matrix
        """
        self.createProblem()
        # The return value is not needed here. The call is kept in case
        # fit_parameters() prepares assembly state that eval() and the fit
        # rely on -- NOTE(review): confirm against park.Assembly.
        self.problem.fit_parameters()
        self.problem.eval()

        # Monte-Carlo restarts of a local simplex search.
        localfit = FitSimplex()
        localfit.ftol = 1e-8
        fitter = FitMC(localfit=localfit)

        result = fit.fit(self.problem,
                         fitter=fitter,
                         handler=fitresult.ConsoleUpdate(improvement_delta=0.1))

        return result.fitness, result.pvec, result.cov
[7705306]196   
[d4b0687]197   
Note: See TracBrowser for help on using the repository browser.