source: sasview/src/sans/fit/ParkFitting.py @ 8d074d9

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 8d074d9 was 8d074d9, checked in by pkienzle, 10 years ago

refactor fit internals, enabling disperser parameters

  • Property mode set to 100644
File size: 21.7 KB
Line 
1
2
3
"""
ParkFitting module contains the SansParameter, Model, Data,
FitArrange, ParkFit and Parameter classes. All listed classes work together
to perform a simple fit with the park optimizer.
"""
9#import time
10import numpy
11import math
12from  numpy.linalg.linalg import LinAlgError
13#import park
14from park import fit
15from park import fitresult
16from  park.fitresult import FitParameter
17import park.simplex
18from park.assembly import Assembly
19from park.assembly import Part
20from park.fitmc import FitSimplex
21import park.fitmc
22from park.fit import Fitter
23from park.formatnum import format_uncertainty
24from sans.fit.AbstractFitEngine import FitEngine
25from sans.fit.AbstractFitEngine import FResult
26
class SansParameter(park.Parameter):
    """
    Adapter exposing one SANS model parameter as a PARK parameter.

    Reads and writes of ``value`` go through the wrapped model's
    ``getParam``/``setParam``; reads and writes of ``range`` go through
    ``model.details[name][1:3]``.
    """
    def __init__(self, name, model, data):
        """
        :param name: the name of the model parameter
        :param model: the sans model that owns the parameter
        :param data: the data associated with the model; stored unchanged

        """
        park.Parameter.__init__(self, name)
        self.data = data
        self.model = model
        # seed the park parameter with the value currently held by the model
        self.set(model.getParam(name))

        # TODO: model is missing parameter ranges for dispersion parameters
        if name not in model.details:
            # no entry yet: create one with unset (None) lower/upper bounds
            model.details[name] = ["", None, None]

    def get_name(self):
        """
        Return the parameter name as given by ``self._getname()``.
        """
        return self._getname()

    def _getvalue(self):
        """
        Getter behind the ``value`` property.

        :return: the value the model holds for ``self.name``

        """
        return self.model.getParam(self.name)

    def _setvalue(self, value):
        """
        Setter behind the ``value`` property.

        :param value: the new value, written to the model under ``self.name``

        """
        self.model.setParam(self.name, value)

    value = property(_getvalue, _setvalue)

    def _getrange(self):
        """
        Getter behind the ``range`` property.

        A bound stored as None is reported as -inf (lower) or +inf (upper).

        :return: tuple ``(lower, upper)``
        :raise ValueError: if the lower bound is greater than the upper bound

        """
        lower, upper = self.model.details[self.name][1:3]
        lower = -numpy.inf if lower is None else lower
        upper = numpy.inf if upper is None else upper
        if lower > upper:
            raise ValueError("wrong fit range for parameters")
        return lower, upper

    def _setrange(self, r):
        """
        Setter behind the ``range`` property.

        :param r: the new bounds, stored in slots 1 and 2 of the model's
            ``details`` entry for this parameter

        """
        self.model.details[self.name][1:3] = r

    range = property(_getrange, _setrange)
98
99
class ParkModel(park.Model):
    """
    PARK model built around a SANS model.

    Each name returned by the SANS model's ``getParamList()`` is wrapped in a
    SansParameter, and the wrappers are collected in a park ParameterSet
    named after the SANS model.
    """
    def __init__(self, sans_model, sans_data=None, **kw):
        """
        :param sans_model: the sans model to wrap using park interface
        :param sans_data: data associated with the model; kept as
            ``self.data`` and handed to every SansParameter
        :param kw: keyword arguments forwarded to ``park.Model.__init__``

        """
        park.Model.__init__(self, **kw)
        self.model = sans_model
        self.name = sans_model.name
        self.data = sans_data
        # names of the model parameters
        self.sansp = sans_model.getParamList()
        # one park-side wrapper per model parameter
        self.parkp = [SansParameter(par_name, sans_model, sans_data)
                      for par_name in self.sansp]
        # parameter set grouping all the wrappers
        self.parameterset = park.ParameterSet(sans_model.name,
                                              pars=self.parkp)
        # names of the parameters selected for fitting
        self.pars = []

    def get_params(self, fitparams):
        """
        Return the current values of the parameters to fit.

        The given names are also remembered in ``self.pars``.

        :param fitparams: list of names of the parameters to fit
        :return: list of values, in the order of ``fitparams``; names that
            match no wrapped parameter contribute nothing

        """
        self.pars = fitparams
        return [wrapped.value
                for wanted in fitparams
                for wrapped in self.parkp
                if wrapped.name == str(wanted)]

    def set_params(self, paramlist, params):
        """
        Set values for the parameters to fit.

        :param paramlist: list of parameter names
        :param params: list of values, position-matched with ``paramlist``

        """
        for wrapped in self.parkp:
            for pos, wanted in enumerate(paramlist):
                if wrapped.name == wanted:
                    wrapped.value = params[pos]
                    self.model.setParam(wrapped.name, params[pos])

    def eval(self, x):
        """
        Evaluate the wrapped model at x.

        :param x: the x value(s) handed to the model's ``evalDistribution``
        :return: whatever ``evalDistribution`` returns

        """
        return self.model.evalDistribution(x)

    def eval_derivs(self, x, pars=[]):
        """
        Derivatives are not provided by this wrapper.

        :param x: ignored
        :param pars: ignored (names of the parameters for which derivatives
            would be desired)
        :return: always an empty list

        """
        return []
176
177
class SansFitResult(fitresult.FitResult):
    """
    park FitResult carrying two extra attributes, ``theory`` and ``inputs``.
    """
    def __init__(self, *args, **kwargs):
        """
        Forward every argument to ``FitResult.__init__`` and initialise the
        extra attributes (``theory`` to None, ``inputs`` to an empty list).
        """
        fitresult.FitResult.__init__(self, *args, **kwargs)
        self.theory = None
        self.inputs = []
183       
class SansFitSimplex(FitSimplex):
    """
    Local minimizer using Nelder-Mead simplex algorithm.

    Simplex is robust and derivative free, though not very efficient.

    This class wraps the bounds constrained Nelder-Mead simplex
    implementation for `park.simplex.simplex`.
    """
    # Size of the initial simplex; this is a portion between 0 and 1
    radius = 0.05
    # Stop when simplex vertices are within xtol of each other
    xtol = 1
    # Stop when vertex values are within ftol of each other
    ftol = 5e-5
    # Maximum number of iterations before fit terminates
    maxiter = None

    def __init__(self, ftol=5e-5):
        """
        :param ftol: per-instance value for the ``ftol`` stopping tolerance
        """
        self.ftol = ftol

    def fit(self, fitness, x0):
        """
        Run the fit.

        :param fitness: object providing ``fit_parameters()``; it is also
            passed to ``park.simplex.simplex`` as the function to minimize
        :param x0: starting point handed to the simplex routine
        :return: a SansFitResult holding copies of the fitted parameters
        """
        self.cancel = False
        pars = fitness.fit_parameters()
        bounds = numpy.array([par.range for par in pars]).T
        result = park.simplex.simplex(fitness, x0,
                                      bounds=bounds,
                                      radius=self.radius,
                                      xtol=self.xtol,
                                      ftol=self.ftol,
                                      maxiter=self.maxiter,
                                      abort_test=self._iscancelled)
        # Need to make our own copy of the fit results so that the
        # values don't get stomped on by the next fit iteration.
        fitpars = [SansFitParameter(par.name, par.range, best,
                                    par.model, par.data)
                   for par, best in zip(pars, result.x)]
        res = SansFitResult(fitpars, result.calls, result.fx)
        res.inputs = [(par.model, par.data)
                      for par, _ in zip(pars, result.x)]
        # Compute the parameter uncertainties from the jacobian
        res.calc_cov(fitness)
        return res
225     
class SansFitter(Fitter):
    """
    park Fitter whose ``fit`` seeds the handler's result with
    SansFitParameter objects before delegating to ``self._fit``.
    """
    def fit(self, fitness, handler):
        """
        Global optimizer.

        This function should return immediately

        :param fitness: object providing ``fit_parameters()``
        :param handler: fit handler; its ``done`` flag is cleared and its
            ``result`` is initialised here, and it is kept as ``self.handler``
        """
        # Determine initial value and bounds
        pars = fitness.fit_parameters()
        bounds = numpy.array([par.range for par in pars]).T
        x0 = [par.value for par in pars]

        # Initialize the monitor and results.
        # Need to make our own copy of the fit results so that the
        # values don't get stomped on by the next fit iteration.
        handler.done = False
        self.handler = handler
        start_pars = [SansFitParameter(par.name, par.range, start,
                                       par.model, par.data)
                      for par, start in zip(pars, x0)]
        handler.result = fitresult.FitResult(start_pars, 0, numpy.nan)

        # Run the fit (fit should perform _progress and _improvement updates)
        # This function may return before the fit is complete.
        self._fit(fitness, x0, bounds)
253       
class SansFitMC(SansFitter):
    """
    Monte Carlo optimizer.

    This implements `park.fit.Fitter`.
    """
    # class-level defaults; both are replaced per instance by __init__
    localfit = SansFitSimplex()
    start_points = 10

    def __init__(self, localfit, start_points=10):
        """
        :param localfit: local optimizer handed to ``park.fitmc.fitmc``
        :param start_points: number of start points handed to
            ``park.fitmc.fitmc``
        """
        self.localfit = localfit
        self.start_points = start_points

    def _fit(self, objective, x0, bounds):
        """
        Run a monte carlo fit.

        This procedure maps a local optimizer across a set of initial points.

        :param objective: the fitness object to minimize
        :param x0: initial parameter values
        :param bounds: parameter bounds
        :raise ValueError: if ``park.fitmc.fitmc`` raises any ordinary error
        """
        try:
            park.fitmc.fitmc(objective, x0, bounds, self.localfit,
                             self.start_points, self.handler)
        except Exception:
            # Only ordinary errors are reported as a failed fit;
            # KeyboardInterrupt and SystemExit are left to propagate instead
            # of being disguised as a convergence problem.
            raise ValueError("Fit did not converge.\n")
277       
class SansPart(Part):
    """
    Part of a fitting assembly that also exposes its model and data.

    A part holds the model itself and the associated data.  Here the part
    is built from an indexable ``fitness`` whose item 0 is the model and
    item 1 is the data.

    The attributes below are described as in the documentation of the park
    Part this class derives from:

    fitness (Fitness)
        object implementing the `park.assembly.Fitness` interface.  In
        particular, fitness should provide a parameterset attribute
        containing a ParameterSet and a residuals method returning a vector
        of residuals.
    weight (dimensionless)
        weight for the model.  See comments in assembly.py for details.
    isfitted (boolean)
        True if the model residuals should be included in the fit.
        The model parameters may still be used in parameter
        expressions, but there will be no comparison to the data.
    residuals (vector)
        Residuals for the model if they have been calculated, or None
    degrees_of_freedom
        Number of residuals minus number of fitted parameters.
        Degrees of freedom for individual models does not make
        sense in the presence of expressions combining models,
        particularly in the case where a model has many parameters
        but no data or many computed parameters.  The degrees of
        freedom for the model is set to be at least one.
    chisq
        sum(residuals**2); use chisq/degrees_of_freedom to
        get the reduced chisq value.
    """

    def __init__(self, fitness, weight=1., isfitted=True):
        """
        :param fitness: indexable pair, model first and data second
        :param weight: weight for the model, forwarded to ``Part.__init__``
        :param isfitted: forwarded to ``Part.__init__``
        """
        Part.__init__(self, fitness=fitness, weight=weight,
                      isfitted=isfitted)

        # keep direct handles on the two halves of the fitness pair
        self.model = fitness[0]
        self.data = fitness[1]
319
class SansFitParameter(FitParameter):
    """
    Fit result for an individual parameter.

    Extends the park FitParameter with the model and data the parameter
    belongs to.
    """
    def __init__(self, name, range, value, model, data):
        """
        :param name: parameter name, forwarded to ``FitParameter.__init__``
        :param range: parameter bounds, forwarded to
            ``FitParameter.__init__``
        :param value: parameter value, forwarded to ``FitParameter.__init__``
        :param model: model owning the parameter (may be None)
        :param data: data associated with the model (may be None)
        """
        FitParameter.__init__(self, name, range, value)
        self.model = model
        self.data = data

    def summarize(self):
        """
        Return parameter range string.

        E.g.,  "       Gold .....|.... 5.2043 in [2,7]"

        The ten-slot bar gets a ``|`` marking where the value sits inside
        the range; no marker is drawn unless both bounds are finite.  The
        model name and data name (or "None") are appended to the string.
        """
        bar = ['.'] * 10
        lo, hi = self.range
        if numpy.isfinite(lo) and numpy.isfinite(hi):
            if hi > lo:
                # float() keeps Python 2 from using integer division when
                # the value and both bounds happen to be ints
                portion = float(self.value - lo) / (hi - lo)
            else:
                # degenerate range (lo == hi): nothing to divide by, so the
                # marker goes in the first slot
                portion = 0.
            if portion < 0:
                portion = 0.
            elif portion >= 1:
                # stay just below 1 so the marker index remains in the bar
                portion = 0.99999999
            barpos = int(math.floor(portion * len(bar)))
            bar[barpos] = '|'
        bar = "".join(bar)
        lostr = "[%g" % lo if numpy.isfinite(lo) else "(-inf"
        histr = "%g]" % hi if numpy.isfinite(hi) else "inf)"
        valstr = format_uncertainty(self.value, self.stderr)
        model_name = str(None)
        if self.model is not None:
            model_name = self.model.name
        data_name = str(None)
        if self.data is not None:
            data_name = self.data.name

        return "%25s %s %s in %s,%s, %s, %s" % (self.name, bar, valstr,
                                                lostr, histr,
                                                model_name, data_name)

    def __repr__(self):
        """
        Return the string form of the instance's class.
        """
        return str(self.__class__)
359   
class MyAssembly(Assembly):
    """
    park Assembly made of SansPart objects, with support for cancelling
    the evaluation from a calling thread.
    """
    def __init__(self, models, curr_thread=None):
        """
        Build an assembly from a list of models.

        :param models: iterable of fitness items; each one is wrapped in a
            SansPart
        :param curr_thread: optional thread object whose ``isquit()`` is
            polled at the start of every ``eval`` call
        """
        self.parts = [SansPart(m) for m in models]
        self.curr_thread = curr_thread
        self.chisq = None
        self._cancel = False
        self.theory = None
        self._reset()

    def fit_parameters(self):
        """
        Return an alphabetical list of the fitting parameters.

        This function is called once at the beginning of a fit,
        and serves as a convenient place to precalculate what
        can be precalculated such as the set of fitting parameters
        and the parameter expressions evaluator.

        :return: list of SansFitParameter, ordered by parameter path
        """
        self.parameterset.setprefix()
        self._fitparameters = self.parameterset.fitted
        self._restraints = self.parameterset.restrained
        pars = self.parameterset.flatten()
        context = self.parameterset.gather_context()
        self._fitexpression = park.expression.build_eval(pars, context)

        # key-based sort: same ordering as comparing the paths pairwise
        self._fitparameters.sort(key=lambda p: p.path)

        # Convert to fit parameter objects
        return [SansFitParameter(p.path, p.range, p.value, p.model, p.data)
                for p in self._fitparameters]

    def extend_results_with_calculated_parameters(self, result):
        """
        Extend result from the fit with the calculated parameters.

        :param result: fit result; its ``parameters`` list is extended in
            place and its ``theory`` is set to the assembly's last theory
        """
        calcpars = [SansFitParameter(p.path, p.range, p.value,
                                     p.model, p.data)
                    for p in self.parameterset.computed]
        result.parameters += calcpars
        result.theory = self.theory

    def eval(self):
        """
        Recalculate the theory functions, and from them, the
        residuals and chisq.

        :note: Call this after the parameters have been updated.

        :return: chisq divided by the degrees of freedom, or ``numpy.inf``
            when the evaluation was cancelled or the parameters are not
            feasible
        """
        # Handle abort from a separate thread.
        self._cancel = False
        if self.curr_thread is not None:
            # NOTE(review): the bare except is kept on purpose; isquit()
            # presumably signals a stop request by raising. Confirm the
            # exception type before narrowing this clause.
            try:
                self.curr_thread.isquit()
            except:
                self._cancel = True

        # Evaluate the computed parameters
        try:
            self._fitexpression()
        except NameError:
            # best effort, as in the original code: a NameError raised by
            # the expression evaluator is ignored
            pass

        # Check that the resulting parameters are in a feasible region.
        if not self.isfeasible():
            return numpy.inf

        resid = []
        k = len(self._fitparameters)
        for m in self.parts:
            # In order to support abort, need to be able to propagate an
            # external abort signal from self.abort() into an abort signal
            # for the particular model.  Can't see a way to do this which
            # doesn't involve setting a state variable.
            self._current_model = m
            if self._cancel:
                return numpy.inf
            if m.isfitted and m.weight != 0:
                m.residuals, self.theory = m.fitness.residuals()
                N = len(m.residuals)
                # never report fewer than one degree of freedom
                m.degrees_of_freedom = N - k if N > k else 1
                # dividing residuals by N in order to be consistent with Scipy
                m.chisq = numpy.sum(m.residuals**2 / N)
                resid.append(m.weight * m.residuals)
        self.residuals = numpy.hstack(resid)
        N = len(self.residuals)
        self.degrees_of_freedom = N - k if N > k else 1
        self.chisq = numpy.sum(self.residuals**2)
        return self.chisq / self.degrees_of_freedom
451   
class ParkFit(FitEngine):
    """
    Fit engine that performs fits with the park optimizer.

    Outline of use (``set_data``, ``set_model`` and ``set_param`` are not
    defined in this class; presumably they are provided by FitEngine)::

        engine = ParkFit()
        engine.set_data(data, Uid)      # data must be of type plottable
        engine.set_model(model, Uid)    # model is a sans model
        results = engine.fit()

    ..note::
       set_param(), if used, must always precede set_model() for the fit
       to be performed, e.g.
       ``engine.set_param(model, "M1", {'A': 2, 'B': 4})``

    ``fit`` returns a list holding one FResult per part of the assembly
    (or the queue it was given, after putting that list on it).
    """
    def __init__(self):
        """
        Create the engine with an empty ``fit_arrange_dict`` (fit
        arrangements keyed by Uid) and an empty ``param_list``.
        """
        FitEngine.__init__(self)
        self.fit_arrange_dict = {}
        self.param_list = []

    def create_assembly(self, curr_thread, reset_flag=False):
        """
        Build ``self.problem``, a MyAssembly, from the fit arrangements
        stored in ``self.fit_arrange_dict`` that are flagged for fitting.

        For every such arrangement a ParkModel is created, its constraints
        are applied, and each non-computed parameter is either given its
        range (when selected for fitting) or marked as fixed.

        :param curr_thread: thread object handed to MyAssembly
        :param reset_flag: when True, model parameters are first reset to
            the arrangement's stored initial values (useful for batch)
        :raise RuntimeError: if no arrangement is flagged for fitting
        """
        fitproblems = [fproblem
                       for fproblem in self.fit_arrange_dict.values()
                       if fproblem.get_to_fit() == 1]
        if not fitproblems:
            raise RuntimeError("No Assembly scheduled for Park fitting.")

        mylist = []
        for item in fitproblems:
            model = item.get_model()
            parkmodel = ParkModel(model.model, model.data)
            parkmodel.pars = item.pars
            if reset_flag:
                # reset the initial value; useful for batch
                for ind, name in enumerate(item.pars):
                    parkmodel.model.setParam(name, item.vals[ind])

            # set the constraints into the model
            for p, v in item.constraints:
                parkmodel.parameterset[str(p)].set(str(v))

            for p in parkmodel.parameterset:
                # does not allow status change for constraint parameters
                if p.status == 'computed':
                    continue
                if p.get_name() in item.pars:
                    # parameters selected for fit will be between boundaries
                    p.set(p.range)
                else:
                    p.status = 'fixed'

            mylist.append((parkmodel, item.get_data()))
        self.problem = MyAssembly(models=mylist, curr_thread=curr_thread)

    def fit(self, msg_q=None,
            q=None, handler=None, curr_thread=None,
            ftol=1.49012e-8, reset_flag=False):
        """
        Performs fit with park.fit module. It can perform a fit with one
        model and a set of data, or with several models associated with
        their sets of data and constraints.

        :param msg_q: not used by this method
        :param q: optional queue; when given, the result list is put on it
            and the queue itself is returned
        :param handler: fit handler passed to park; defaults to a
            ``fitresult.ConsoleUpdate`` with ``improvement_delta=0.1``
        :param curr_thread: thread object passed on to ``create_assembly``
        :param ftol: function tolerance given to the local simplex optimizer
        :param reset_flag: passed on to ``create_assembly``

        :return: list with one FResult per part of the assembly, or ``q``
            when a queue was supplied
        :raise ValueError: if a LinAlgError is raised during the fit
        :raise RuntimeError: if park returns no result, or if no fit
            problem is scheduled
        """
        self.create_assembly(curr_thread=curr_thread, reset_flag=reset_flag)
        localfit = SansFitSimplex(ftol=ftol)
        localfit.xtol = 1e-6

        # See `park.fitresult.FitHandler` for details.
        fitter = SansFitMC(localfit=localfit, start_points=1)
        if handler is None:
            handler = fitresult.ConsoleUpdate(improvement_delta=0.1)

        try:
            result = fit.fit(self.problem, fitter=fitter, handler=handler)
            # check before the result is dereferenced below, otherwise a
            # missing result surfaces as an AttributeError instead
            if result is None:
                raise RuntimeError("park did not return a fit result")
            self.problem.extend_results_with_calculated_parameters(result)
        except LinAlgError:
            raise ValueError("SVD did not converge")

        result_list = []
        for m in self.problem.parts:
            residuals, theory = m.fitness.residuals()
            small_result = FResult(model=m.model, data=m.data.sans_data)
            small_result.fitter_id = self.fitter_id
            small_result.theory = theory
            small_result.residuals = residuals
            small_result.index = m.data.idx

            # Extract the parameters that are part of this model; make sure
            # they match the fitted parameters for this model, and place them
            # in the same order as they occur in the model.
            pars = {}
            for p in result.parameters:
                if p.model.name == small_result.model.name:
                    # p.name is "<model>.<parameter>"; keep the parameter
                    _, par_name = p.name.split('.', 1)
                    pars[par_name] = (p.value, p.stderr)
            v, dv = zip(*[pars[p] for p in m.model.pars])
            small_result.pvec = v
            small_result.stderr = dv
            small_result.param_list = m.model.pars

            # normalize chisq by degrees of freedom; as in MyAssembly.eval
            # the degrees of freedom are never taken below one, which avoids
            # a division by zero or a negative chisq
            n_points = len(small_result.residuals)
            n_fitted = len(small_result.pvec)
            dof = n_points - n_fitted if n_points > n_fitted else 1
            small_result.fitness = numpy.sum(residuals**2) / dof

            result_list.append(small_result)
        if q is not None:
            q.put(result_list)
            return q
        return result_list
616       
Note: See TracBrowser for help on using the repository browser.