source: sasview/src/sas/fit/ParkFitting.py @ bc9a0e1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since bc9a0e1 was fd5ac0d, checked in by krzywon, 10 years ago

I have completed the removal of all SANS references.
I will build, run, and run all unit tests before pushing.

  • Property mode set to 100644
File size: 21.7 KB
Line 
1
2
3
4"""
5ParkFitting module contains SasParameter, Model, Data,
6FitArrange, ParkFit, Parameter classes. All listed classes work together
7to perform a simple fit with the park optimizer.
8"""
9#import time
10import numpy
11import math
12from  numpy.linalg.linalg import LinAlgError
13#import park
14from park import fit
15from park import fitresult
16from  park.fitresult import FitParameter
17import park.simplex
18from park.assembly import Assembly
19from park.assembly import Part
20from park.fitmc import FitSimplex
21import park.fitmc
22from park.fit import Fitter
23from park.formatnum import format_uncertainty
24from sas.fit.AbstractFitEngine import FitEngine
25from sas.fit.AbstractFitEngine import FResult
26
class SasParameter(park.Parameter):
    """
    PARK parameter backed by a parameter of a SAS model.

    Reading or writing ``value`` is forwarded to the wrapped SAS model
    (``getParam``/``setParam``), and ``range`` is read from and stored in
    the ``[1:3]`` slice of ``model.details[name]``.
    """
    def __init__(self, name, model, data):
        """
        :param name: the name of the model parameter
        :param model: the sas model that owns the parameter
        :param data: the data object associated with the model; it is only
            stored on the instance
        """
        park.Parameter.__init__(self, name)
        # model must be assigned before self.set(), because the value
        # property below forwards to self.model
        self.model = model
        self.data = data
        # initialise the park parameter from the value held by the model
        self.set(model.getParam(name))

        # TODO: model is missing parameter ranges for dispersion parameters
        if name not in model.details:
            model.details[name] = ["", None, None]

    def _getvalue(self):
        """
        Override the _getvalue of park parameter.

        :return: the model value of the parameter named self.name
        """
        return self.model.getParam(self.name)

    def _setvalue(self, value):
        """
        Override the _setvalue of park parameter.

        :param value: the value to store in the model for this parameter
        """
        self.model.setParam(self.name, value)

    value = property(_getvalue, _setvalue)

    def get_name(self):
        """
        Return the result of the inherited ``_getname()``.
        """
        return self._getname()

    def _getrange(self):
        """
        Override _getrange of park parameter.

        :return: tuple (lo, hi); a missing bound (None) is reported as
            -inf or +inf respectively
        :raise ValueError: if the lower bound is above the upper bound
        """
        lower, upper = self.model.details[self.name][1:3]
        lower = -numpy.inf if lower is None else lower
        upper = numpy.inf if upper is None else upper
        if lower > upper:
            raise ValueError("wrong fit range for parameters")
        return lower, upper

    def _setrange(self, r):
        """
        Override _setrange of park parameter.

        :param r: the (lo, hi) pair to store in the model details
        """
        self.model.details[self.name][1:3] = r

    range = property(_getrange, _setrange)
98
99
class ParkModel(park.Model):
    """
    Adapter that exposes a SAS model through the PARK model interface.
    """
    def __init__(self, sas_model, sas_data=None, **kw):
        """
        :param sas_model: the sas model to wrap using park interface
        :param sas_data: the data associated with the model (stored as-is)
        :param kw: keyword arguments forwarded to ``park.Model.__init__``
        """
        park.Model.__init__(self, **kw)
        self.model = sas_model
        self.name = sas_model.name
        self.data = sas_data
        # names of the parameters of the SAS model
        self.sasp = sas_model.getParamList()
        # one SasParameter wrapper per SAS parameter name, in the same order
        self.parkp = []
        for par_name in self.sasp:
            self.parkp.append(SasParameter(par_name, sas_model, sas_data))
        # park parameter set grouping all wrappers under the model name
        self.parameterset = park.ParameterSet(sas_model.name, pars=self.parkp)
        # names of the parameters selected for fitting
        self.pars = []

    def get_params(self, fitparams):
        """
        Return the current values of the parameters to fit.

        The given names are also remembered in ``self.pars``.

        :param fitparams: list of parameter names to fit
        :return: list of values, ordered like ``fitparams``; names that
            match no wrapped parameter contribute nothing
        """
        self.pars = fitparams
        values = []
        for wanted in fitparams:
            key = str(wanted)
            values.extend(par.value for par in self.parkp
                          if par.name == key)
        return values

    def set_params(self, paramlist, params):
        """
        Set the values of the parameters to fit.

        :param paramlist: list of parameter names
        :param params: list of values, ``params[j]`` belonging to
            ``paramlist[j]``
        """
        for parameter in self.parkp:
            for position, fitted_name in enumerate(paramlist):
                if parameter.name == fitted_name:
                    new_value = params[position]
                    parameter.value = new_value
                    self.model.setParam(parameter.name, new_value)

    def eval(self, x):
        """
        Override eval method of park model.

        :param x: the x value used to compute a function
        :return: the result of ``evalDistribution(x)`` on the SAS model
        """
        return self.model.evalDistribution(x)

    def eval_derivs(self, x, pars=[]):
        """
        Derivatives are not provided by this wrapper.

        :param x: ignored
        :param pars: ignored (names of the parameters for which
            derivatives would be desired)
        :return: always an empty list
        """
        return []
176
177
class SasFitResult(fitresult.FitResult):
    """
    park FitResult carrying two extra attributes: ``theory`` and ``inputs``.
    """
    def __init__(self, *args, **kwargs):
        """
        All arguments are forwarded unchanged to ``FitResult.__init__``.
        """
        fitresult.FitResult.__init__(self, *args, **kwargs)
        # filled in by whoever produces the result; empty until then
        self.inputs = []
        self.theory = None
183       
class SasFitSimplex(FitSimplex):
    """
    Local minimizer using Nelder-Mead simplex algorithm.

    Simplex is robust and derivative free, though not very efficient.

    This class wraps the bounds constrained Nelder-Mead simplex
    implementation for `park.simplex.simplex`.
    """
    # Size of the initial simplex; this is a portion between 0 and 1
    radius = 0.05
    # Stop when simplex vertices are within xtol of each other
    # (an earlier value, 1e-4, was left commented out in the original)
    xtol = 1
    # Stop when vertex values are within ftol of each other
    ftol = 5e-5
    # Maximum number of iterations before fit terminates
    maxiter = None

    def __init__(self, ftol=5e-5):
        """
        :param ftol: per-instance value overriding the class-level ftol
        """
        self.ftol = ftol

    def fit(self, fitness, x0):
        """
        Run the fit.

        :param fitness: object providing ``fit_parameters()`` and
            ``calc_cov`` support; passed to ``park.simplex.simplex``
        :param x0: starting point handed to the simplex routine
        :return: a SasFitResult holding copies of the fitted parameters
        """
        self.cancel = False
        pars = fitness.fit_parameters()
        # one row of lower bounds, one row of upper bounds
        bounds = numpy.array([par.range for par in pars]).T
        result = park.simplex.simplex(fitness, x0, bounds=bounds,
                                      radius=self.radius, xtol=self.xtol,
                                      ftol=self.ftol, maxiter=self.maxiter,
                                      abort_test=self._iscancelled)

        # Need to make our own copy of the fit results so that the
        # values don't get stomped on by the next fit iteration.
        fitpars = []
        for position, best in enumerate(result.x):
            source = pars[position]
            fitpars.append(SasFitParameter(source.name, source.range, best,
                                           source.model, source.data))
        res = SasFitResult(fitpars, result.calls, result.fx)
        res.inputs = [(pars[position].model, pars[position].data)
                      for position in range(len(result.x))]
        # Compute the parameter uncertainties from the jacobian
        res.calc_cov(fitness)
        return res
225     
class SasFitter(Fitter):
    """
    park Fitter whose initial result is built from SasFitParameter objects.

    Subclasses provide ``_fit(fitness, x0, bounds)``, which performs the
    actual optimisation.
    """
    def fit(self, fitness, handler):
        """
        Global optimizer.

        This function should return immediately

        :param fitness: object providing ``fit_parameters()``; handed on
            to ``self._fit``
        :param handler: fit handler; its ``done`` flag is cleared and its
            ``result`` is initialised here before the fit starts
        """
        # Determine initial value and bounds
        pars = fitness.fit_parameters()
        bounds = numpy.array([p.range for p in pars]).T
        x0 = [p.value for p in pars]

        # Initialize the monitor and results.
        # Need to make our own copy of the fit results so that the
        # values don't get stomped on by the next fit iteration.
        handler.done = False
        self.handler = handler
        fitpars = [SasFitParameter(p.name, p.range, start, p.model, p.data)
                   for p, start in zip(pars, x0)]
        # numpy.nan is the canonical spelling; the numpy.NaN alias was
        # removed in NumPy 2.0.
        handler.result = fitresult.FitResult(fitpars, 0, numpy.nan)

        # Run the fit (fit should perform _progress and _improvement updates)
        # This function may return before the fit is complete.
        self._fit(fitness, x0, bounds)
253       
class SasFitMC(SasFitter):
    """
    Monte Carlo optimizer.

    This implements `park.fit.Fitter`.
    """
    # class-level defaults; both are replaced per instance in __init__
    localfit = SasFitSimplex()
    start_points = 10

    def __init__(self, localfit, start_points=10):
        """
        :param localfit: local optimizer handed to ``park.fitmc.fitmc``
        :param start_points: number of start points handed to
            ``park.fitmc.fitmc``
        """
        self.localfit = localfit
        self.start_points = start_points

    def _fit(self, objective, x0, bounds):
        """
        Run a monte carlo fit.

        This procedure maps a local optimizer across a set of initial points.

        :raise ValueError: if ``park.fitmc.fitmc`` fails with any ordinary
            exception; the text of the underlying error is appended to the
            message
        """
        try:
            park.fitmc.fitmc(objective, x0, bounds, self.localfit,
                             self.start_points, self.handler)
        except Exception as exc:
            # Only ordinary errors are reported as a failed fit, so that
            # KeyboardInterrupt/SystemExit are no longer masked; the cause
            # is kept in the message instead of being discarded.
            raise ValueError("Fit did not converge.\n%s" % (exc,))
277       
class SasPart(Part):
    """
    One (model, data) member of a fitting assembly.

    The part is built from a ``fitness`` whose first two items are the
    model and its data; both are kept on the part as ``model`` and
    ``data`` in addition to whatever ``park.assembly.Part`` stores.

    Attributes described by the park ``Part`` documentation:

    fitness (Fitness)
        object implementing the `park.assembly.Fitness` interface.  In
        particular, fitness should provide a parameterset attribute
        containing a ParameterSet and a residuals method returning a vector
        of residuals.
    weight (dimensionless)
        weight for the model.  See comments in assembly.py for details.
    isfitted (boolean)
        True if the model residuals should be included in the fit.
        The model parameters may still be used in parameter
        expressions, but there will be no comparison to the data.
    residuals (vector)
        Residuals for the model if they have been calculated, or None
    degrees_of_freedom
        Number of residuals minus number of fitted parameters.
        Degrees of freedom for individual models does not make
        sense in the presence of expressions combining models,
        particularly in the case where a model has many parameters
        but no data or many computed parameters.  The degrees of
        freedom for the model is set to be at least one.
    chisq
        sum(residuals**2); use chisq/degrees_of_freedom to
        get the reduced chisq value.
    """

    def __init__(self, fitness, weight=1., isfitted=True):
        """
        :param fitness: indexable whose items 0 and 1 are the model and
            the data
        :param weight: forwarded to ``Part.__init__``
        :param isfitted: forwarded to ``Part.__init__``
        """
        Part.__init__(self, fitness=fitness, weight=weight,
                      isfitted=isfitted)

        # fetch both items first, then publish them on the part
        wrapped_model, wrapped_data = fitness[0], fitness[1]
        self.model = wrapped_model
        self.data = wrapped_data
319
class SasFitParameter(FitParameter):
    """
    Fit result for an individual parameter.

    Extends the park FitParameter with the model and data the parameter
    belongs to.
    """
    def __init__(self, name, range, value, model, data):
        """
        :param name: parameter name, forwarded to ``FitParameter``
        :param range: (lo, hi) pair, forwarded to ``FitParameter``
        :param value: parameter value, forwarded to ``FitParameter``
        :param model: model owning the parameter (may be None)
        :param data: data associated with the parameter (may be None)
        """
        FitParameter.__init__(self, name, range, value)
        self.model = model
        self.data = data

    def summarize(self):
        """
        Return parameter range string.

        E.g.,  "       Gold .....|.... 5.2043 in [2,7]"

        The position marker ``|`` is only drawn when both bounds are
        finite, the range has a non-zero width and the relative position
        of the value is a number; otherwise the bar is left as dots.
        The names of the model and data (or "None") are appended.
        """
        bar = ['.'] * 10
        lo, hi = self.range
        # hi != lo guards the division below against a zero-width range
        if numpy.isfinite(lo) and numpy.isfinite(hi) and hi != lo:
            portion = (self.value - lo) / (hi - lo)
            # a NaN value cannot be placed on the bar (floor(nan) raises)
            if not numpy.isnan(portion):
                if portion < 0:
                    portion = 0.
                elif portion >= 1:
                    portion = 0.99999999
                barpos = int(math.floor(portion * len(bar)))
                bar[barpos] = '|'
        bar = "".join(bar)
        lostr = "[%g" % lo if numpy.isfinite(lo) else "(-inf"
        histr = "%g]" % hi if numpy.isfinite(hi) else "inf)"
        valstr = format_uncertainty(self.value, self.stderr)
        model_name = str(None)
        if self.model is not None:
            model_name = self.model.name
        data_name = str(None)
        if self.data is not None:
            data_name = self.data.name

        return "%25s %s %s in %s,%s, %s, %s" % (self.name, bar, valstr,
                                                 lostr, histr,
                                                 model_name, data_name)

    def __repr__(self):
        """Return the string form of the instance's class."""
        return str(self.__class__)
359   
class MyAssembly(Assembly):
    """
    park Assembly built from SasPart objects, with support for aborting
    from a calculation thread and for exposing the last computed theory.
    """
    def __init__(self, models, curr_thread=None):
        """
        Build an assembly from a list of models.

        :param models: iterable of fitness objects, each wrapped in a
            SasPart
        :param curr_thread: optional thread object; when given, its
            ``isquit()`` is polled at every evaluation
        """
        self.parts = [SasPart(m) for m in models]
        self.curr_thread = curr_thread
        self.chisq = None
        self._cancel = False
        self.theory = None
        self._reset()

    def fit_parameters(self):
        """
        Return an alphabetical list of the fitting parameters.

        This function is called once at the beginning of a fit,
        and serves as a convenient place to precalculate what
        can be precalculated such as the set of fitting parameters
        and the parameter expressions evaluator.

        :return: list of SasFitParameter copies of the fitted parameters,
            sorted by parameter path
        """
        self.parameterset.setprefix()
        self._fitparameters = self.parameterset.fitted
        self._restraints = self.parameterset.restrained
        pars = self.parameterset.flatten()
        context = self.parameterset.gather_context()
        self._fitexpression = park.expression.build_eval(pars, context)

        # key-based sort: same ordering as comparing paths pairwise, but
        # does not rely on the Python-2-only cmp() / cmp-function argument
        self._fitparameters.sort(key=lambda p: p.path)

        # Convert to fitparameter objects
        return [SasFitParameter(p.path, p.range, p.value, p.model, p.data)
                for p in self._fitparameters]

    def extend_results_with_calculated_parameters(self, result):
        """
        Extend result from the fit with the calculated parameters.

        Also copies the assembly's last theory onto ``result.theory``.

        :param result: fit result whose ``parameters`` list is extended
            in place
        """
        calcpars = [SasFitParameter(p.path, p.range, p.value, p.model, p.data)
                    for p in self.parameterset.computed]
        result.parameters += calcpars
        result.theory = self.theory

    def eval(self):
        """
        Recalculate the theory functions, and from them, the
        residuals and chisq.

        :note: Call this after the parameters have been updated.

        :return: chisq divided by the degrees of freedom, or ``numpy.inf``
            when the parameters are not feasible or the fit was cancelled
        """
        # Handle abort from a separate thread.
        self._cancel = False
        if self.curr_thread is not None:
            try:
                self.curr_thread.isquit()
            except:
                # NOTE(review): any exception from isquit() is treated as a
                # stop request. Presumably the thread signals this with an
                # exception that is not an Exception subclass (e.g.
                # KeyboardInterrupt) - confirm before narrowing this clause.
                self._cancel = True

        # Evaluate the computed parameters
        try:
            self._fitexpression()
        except NameError:
            pass

        # Check that the resulting parameters are in a feasible region.
        if not self.isfeasible():
            return numpy.inf

        resid = []
        k = len(self._fitparameters)
        for m in self.parts:
            # In order to support abort, need to be able to propagate an
            # external abort signal from self.abort() into an abort signal
            # for the particular model.  Can't see a way to do this which
            # doesn't involve setting a state variable.
            self._current_model = m
            if self._cancel:
                return numpy.inf
            if m.isfitted and m.weight != 0:
                # self.theory ends up holding the theory of the last
                # fitted part
                m.residuals, self.theory = m.fitness.residuals()
                N = len(m.residuals)
                m.degrees_of_freedom = N - k if N > k else 1
                # dividing residuals by N in order to be consistent with Scipy
                m.chisq = numpy.sum(m.residuals**2 / N)
                resid.append(m.weight * m.residuals)
        self.residuals = numpy.hstack(resid)
        N = len(self.residuals)
        # degrees of freedom are kept at least one
        self.degrees_of_freedom = N - k if N > k else 1
        self.chisq = numpy.sum(self.residuals**2)
        return self.chisq / self.degrees_of_freedom
451   
class ParkFit(FitEngine):
    """
    ParkFit performs the fit with the park optimizer.

    Intended use, as described by the original documentation of this
    class (``set_data``/``set_model`` are not defined in this class and
    are expected to come from ``FitEngine``)::

        engine = ParkFit()
        engine.set_data(data, Uid)      # data must be of type plottable
        engine.set_model(model, Uid)    # a sas model
        result_list = engine.fit()

    Data and models are stored in FitArrange objects kept in a dictionary
    keyed by Uid (``self.fit_arrange_dict``).

    ..note::

        Set_param(), if used, must always precede set_model() for the fit
        to be performed: ``engine.set_param(model, "M1", {'A':2,'B':4})``
    """
    def __init__(self):
        """
        Create the dictionary of FitArrange elements (keyed by Uid) and
        the parameter list, both initially empty.
        """
        FitEngine.__init__(self)
        self.fit_arrange_dict = {}
        self.param_list = []

    def create_assembly(self, curr_thread, reset_flag=False):
        """
        Build ``self.problem`` from the fit problems scheduled for fitting.

        For every FitArrange in ``self.fit_arrange_dict`` whose
        ``get_to_fit()`` equals 1, a ParkModel is created from its sas
        model and paired with its data; the pairs form a MyAssembly.

        :param curr_thread: thread object handed to the assembly
        :param reset_flag: when True, the parameters to fit are reset to
            the stored initial values first (useful for batch)
        :raise RuntimeError: if no fit problem is scheduled for fitting
        """
        fitproblems = [fproblem
                       for fproblem in self.fit_arrange_dict.values()
                       if fproblem.get_to_fit() == 1]
        if not fitproblems:
            raise RuntimeError("No Assembly scheduled for Park fitting.")

        mylist = []
        for item in fitproblems:
            model = item.get_model()
            parkmodel = ParkModel(model.model, model.data)
            parkmodel.pars = item.pars
            if reset_flag:
                # reset the initial value; useful for batch
                for ind, name in enumerate(item.pars):
                    parkmodel.model.setParam(name, item.vals[ind])

            # set the constraints into the model
            for p, v in item.constraints:
                parkmodel.parameterset[str(p)].set(str(v))

            for p in parkmodel.parameterset:
                # does not allow status change for constraint parameters
                if p.status == 'computed':
                    continue
                if p.get_name() in item.pars:
                    # parameters selected for fit will be between
                    # boundaries
                    p.set(p.range)
                else:
                    p.status = 'fixed'
            mylist.append((parkmodel, item.get_data()))
        self.problem = MyAssembly(models=mylist, curr_thread=curr_thread)

    def fit(self, msg_q=None,
            q=None, handler=None, curr_thread=None,
            ftol=1.49012e-8, reset_flag=False):
        """
        Perform the fit with the park.fit module.

        It can fit one model with one set of data, or several models
        associated with their sets of data and constraints.

        :param msg_q: not used by this method
        :param q: optional queue-like object; when given, the result list
            is put on it and ``q`` itself is returned
        :param handler: fit handler; defaults to
            ``fitresult.ConsoleUpdate(improvement_delta=0.1)``
        :param curr_thread: thread object handed to the assembly
        :param ftol: ftol assigned to the local simplex optimizer
        :param reset_flag: forwarded to ``create_assembly``

        :return: list with one FResult per part of the assembly (or ``q``
            when a queue was given)
        :raise ValueError: when the fit fails with a LinAlgError
        :raise RuntimeError: when park returns no result, or when no fit
            problem is scheduled
        """
        self.create_assembly(curr_thread=curr_thread, reset_flag=reset_flag)
        localfit = SasFitSimplex()
        localfit.ftol = ftol
        localfit.xtol = 1e-6

        # See `park.fitresult.FitHandler` for details.
        fitter = SasFitMC(localfit=localfit, start_points=1)
        if handler is None:
            handler = fitresult.ConsoleUpdate(improvement_delta=0.1)

        try:
            result = fit.fit(self.problem, fitter=fitter, handler=handler)
            # The result must be checked before it is used below.
            if result is None:
                raise RuntimeError("park did not return a fit result")
            self.problem.extend_results_with_calculated_parameters(result)
        except LinAlgError:
            raise ValueError("SVD did not converge")

        result_list = []
        for m in self.problem.parts:
            residuals, theory = m.fitness.residuals()
            small_result = FResult(model=m.model, data=m.data.sas_data)
            small_result.fitter_id = self.fitter_id
            small_result.theory = theory
            small_result.residuals = residuals
            small_result.index = m.data.idx

            # Extract the parameters that are part of this model; make sure
            # they match the fitted parameters for this model, and place them
            # in the same order as they occur in the model.
            pars = {}
            for p in result.parameters:
                if p.model.name == small_result.model.name:
                    # parameter names have the form "<model>.<parameter>"
                    par_name = p.name.split('.', 1)[1]
                    pars[par_name] = (p.value, p.stderr)
            fitted = [pars[name] for name in m.model.pars]
            # zip(*[]) cannot be unpacked into two names
            if fitted:
                v, dv = zip(*fitted)
            else:
                v, dv = (), ()
            small_result.pvec = v
            small_result.stderr = dv
            small_result.param_list = m.model.pars

            # normalize chisq by degrees of freedom; as in MyAssembly.eval
            # the degrees of freedom are kept at least one, which avoids a
            # division by zero or a negative chisq
            dof = len(small_result.residuals) - len(small_result.pvec)
            if dof <= 0:
                dof = 1
            small_result.fitness = numpy.sum(residuals**2) / dof

            result_list.append(small_result)
        if q is not None:
            q.put(result_list)
            return q
        return result_list
617       
Note: See TracBrowser for help on using the repository browser.