source: sasview/src/sas/sascalc/pr/fit/BumpsFitting.py @ 7af652d

Branches: magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1249
Last change on this file since 7af652d was 3e6829d, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

default to floating point division in pr

  • Property mode set to 100644
File size: 13.6 KB
RevLine 
[6fe5100]1"""
2BumpsFitting module runs the bumps optimizer.
3"""
[3e6829d]4from __future__ import division
5
[249a7c6]6import os
[35086c3]7from datetime import timedelta, datetime
8
[9a5097c]9import numpy as np
[6fe5100]10
11from bumps import fitters
try:
    from bumps.options import FIT_CONFIG
    # Default bumps to use the Levenberg-Marquardt optimizer
    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id

    def get_fitter():
        """
        Return ``(fitclass, options)`` for the selected bumps fitter.

        *options* is a copy, so callers may add entries to it (run_bumps
        adds its monitors) without altering the shared FIT_CONFIG settings.
        """
        return FIT_CONFIG.selected_fitter, dict(FIT_CONFIG.selected_values)
except (ImportError, AttributeError):
    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
    # Older bumps has no bumps.options/FIT_CONFIG, so fall back to the
    # module-level FIT_OPTIONS table.
    # Default bumps to use the Levenberg-Marquardt optimizer
    fitters.FIT_DEFAULT = 'lm'

    def get_fitter():
        """
        Return ``(fitclass, options)`` for the default bumps fitter, with
        *options* copied so callers may modify it freely.
        """
        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
        return fitopts.fitclass, fitopts.options.copy()
25
26
[249a7c6]27from bumps.mapper import SerialMapper, MPMapper
[e3efa6b3]28from bumps import parameter
29from bumps.fitproblem import FitProblem
[7945367]30
[6fe5100]31
[b699768]32from sas.sascalc.fit.AbstractFitEngine import FitEngine
33from sas.sascalc.fit.AbstractFitEngine import FResult
34from sas.sascalc.fit.expression import compile_constraints
[6fe5100]35
class Progress(object):
    """
    Text report describing the state of a bumps fit at its latest step.

    The report starts with a header line giving the step count, the reduced
    chi-square (``2*value/dof``) and an ETA.  It is followed by the current
    value of each fitted parameter, three per line.  The text is stored in
    ``self.msg`` and returned by ``str()``.

    *history* is the bumps monitor history; only index 0 of its ``time``,
    ``step``, ``value`` and ``point`` traces is read.  *max_step* is the
    expected total number of steps.  *pars* holds the parameter names in the
    same order as ``history.point[0]``.  *dof* is the number of degrees of
    freedom used to reduce chi-square.
    """
    def __init__(self, history, max_step, pars, dof):
        step = history.step[0]
        # Extrapolate linearly from the time spent on the steps done so far.
        remaining_time = int(history.time[0]*(max_step/step - 1))
        chisq = "%.3g" % (2*history.value[0]/dof)
        progress = "%d of %d" % (step, max_step)
        header = "=== Steps: %s  chisq: %s  ETA: %s\n" % (
            progress, chisq, self._format_eta(remaining_time))
        parameters = ["%15s: %-10.3g%s" % (k, v, ("\n" if i % 3 == 2 else " | "))
                      for i, (k, v) in enumerate(zip(pars, history.point[0]))]
        self.msg = "".join([header] + parameters)

    @staticmethod
    def _format_eta(remaining_time):
        """
        Format *remaining_time* (seconds) for the progress header.

        Fits expected to take 30 minutes or more show the wall-clock
        completion time, prefixed by the date when 10 hours or more away.
        Shorter fits show the time remaining as minutes and seconds.
        """
        if remaining_time >= 1800:
            completion_time = datetime.now() + timedelta(seconds=remaining_time)
            if remaining_time >= 36000:
                return completion_time.strftime('%Y-%m-%d %H:%M')
            return completion_time.strftime('%H:%M')
        # Below 1800 s the hour count is always zero, so minutes and seconds
        # are enough.
        if remaining_time >= 60:
            return '%dm %ds' % (remaining_time//60, remaining_time % 60)
        return '%ds' % remaining_time

    def __str__(self):
        return self.msg
64
65
class BumpsMonitor(object):
    """
    Bumps monitor that forwards fit progress to a SasView fit handler.

    On every monitored step it:
    - sends the handler a Progress report and the step counter;
    - signals an improvement when the latest value is lower than the
      previous one;
    - asks the handler to refresh.

    With *handler* None the monitor does nothing.
    """
    def __init__(self, handler, max_step, pars, dof):
        self.handler = handler
        self.max_step = max_step
        self.pars = pars
        self.dof = dof

    def config_history(self, history):
        # Two values are kept so __call__ can compare the latest value with
        # the previous one; only the latest time/point/step are needed.
        history.requires(time=1, value=2, point=1, step=1)

    def __call__(self, history):
        handler = self.handler
        if handler is None:
            return
        handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
        handler.progress(history.step[0], self.max_step)
        # Only one step is retained (see config_history), and steps only
        # increase, so a step comparison can never detect progress.  Compare
        # the retained values instead: index 0 is taken to be the most recent
        # entry, as the step/point reads above assume.
        values = history.value
        if len(values) > 1 and values[0] < values[1]:
            handler.improvement()
        handler.update_fit()
83
class ConvergenceMonitor(object):
    """
    ConvergenceMonitor collects population summary statistics that show the
    progress of the fit.

    After each step ``self.convergence`` gains one tuple
    ``(best, min, 20%, 50%, 80%, max)`` computed from the sorted population
    values.  When no usable population is available (the lookup or
    statistics raise), the tuple is ``(best,)*6`` instead.
    """
    def __init__(self):
        self.convergence = []

    def config_history(self, history):
        history.requires(value=1, population_values=1)

    def __call__(self, history):
        best = history.value[0]
        try:
            values = history.population_values[0]
            n = len(values)
            values = np.sort(values)
            q_low, q_mid = int(0.2*n), int(0.5*n)
            stats = (best, values[0], values[q_low], values[q_mid],
                     values[-1-q_low], values[-1])
        except Exception:
            # Best-effort: fall back to the best value alone.  Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            stats = (best, best, best, best, best, best)
        self.convergence.append(stats)
[ed4aef2]105
[e3efa6b3]106
# Note: currently using bumps parameters for each parameter object so that
# a SasFitness can be used directly in bumps with the usual semantics.
# The disadvantage of this technique is that we need to copy every parameter
# back into the model each time the function is evaluated.  We could instead
# define reference parameters for each sas parameter, but then we would not
# be able to express constraints using python expressions in the usual way
# from bumps, and would instead need to use string expressions.
class SasFitness(object):
    """
    Wrap SAS model as a bumps fitness object.

    Arguments:
    - *model* must provide ``name`` and the wrapped ``model``.
    - *data* must provide ``smearer``, ``idx`` and ``residuals(fn)``.
    - *fitted* lists the parameter names to vary.
    - *constraints* maps parameter names to expressions; constrained
      parameters are computed rather than fitted.
    - *initial_values*, when not None, are assigned in order to the
      parameters named in *fitted*.
    - Remaining keyword arguments initialize parameters (see _init_pars).
    """
    # Keyword suffixes accepted for dispersion settings and the dotted suffix
    # each is renamed to.  The first matching suffix wins.
    _DISPERSION_SUFFIXES = (
        ('_width', '.width'),
        ('_npts', '.npts'),
        ('_nsigmas', '.nsigmas'),
        ('_type', '.type'),
    )

    def __init__(self, model, data, fitted=None, constraints=None,
                 initial_values=None, **kw):
        # None defaults instead of shared mutable [] / {} defaults.
        if fitted is None:
            fitted = []
        self.name = model.name
        self.model = model.model
        self.data = data
        if self.data.smearer is not None:
            self.data.smearer.model = self.model
        self._define_pars()
        self._init_pars(kw)
        if initial_values is not None:
            self._reset_pars(fitted, initial_values)
        self.constraints = dict(constraints) if constraints is not None else {}
        self.set_fitted(fitted)
        self.update()

    def _reset_pars(self, names, values):
        """Assign *values* to the parameters named in *names*, pairwise."""
        for k, v in zip(names, values):
            self._pars[k].value = v

    def _define_pars(self):
        """
        Create one fixed bumps Parameter per model parameter.

        Each parameter is named ``<fitness name>.<parameter>``.  Its starting
        value comes from the model, and its bounds from ``model.details``.
        """
        self._pars = {}
        for k in self.model.getParamList():
            name = ".".join((self.name, k))
            value = self.model.getParam(k)
            # Entries 1:3 of details[k] are used as the bounds (presumably
            # [units, min, max] -- confirm); a missing entry gives
            # (None, None).
            bounds = self.model.details.get(k, ["", None, None])[1:3]
            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
                                                fixed=True, name=name)

    def _init_pars(self, kw):
        """
        Initialize parameters from keyword arguments.

        Dispersion keywords may be spelled ``<par>_width``, ``<par>_npts``,
        ``<par>_nsigmas`` or ``<par>_type`` and are renamed to the dotted
        form.  Dotted names other than ``.width`` are passed straight to the
        model's setParam.  Any other value may be:
        - a bumps parameter, used as is;
        - a ``(low, high)`` pair, giving a fit range that starts at the
          midpoint;
        - a plain value.

        Raises KeyError for a name that is not a known parameter.
        """
        for k, v in kw.items():
            # dispersion parameters initialized with _field instead of .field
            for suffix, dotted in self._DISPERSION_SUFFIXES:
                if k.endswith(suffix):
                    # Slice by len(suffix) rather than a hand-counted offset
                    # so every suffix is removed completely.
                    k = k[:-len(suffix)] + dotted
                    break
            if k not in self._pars:
                formatted_pars = ", ".join(sorted(self._pars.keys()))
                raise KeyError("invalid parameter %r for %s--use one of: %s"
                               % (k, self.model, formatted_pars))
            if '.' in k and not k.endswith('.width'):
                self.model.setParam(k, v)
            elif isinstance(v, parameter.BaseParameter):
                self._pars[k] = v
            elif isinstance(v, (tuple, list)):
                low, high = v
                self._pars[k].value = (low + high)/2
                self._pars[k].range(low, high)
            else:
                self._pars[k].value = v

    def set_fitted(self, param_list):
        """
        Flag a set of parameters as fitted parameters.

        Names in *param_list* that also appear in ``self.constraints`` are
        treated as computed parameters.  All other parameters are fixed.
        """
        for k, p in self._pars.items():
            p.fixed = (k not in param_list or k in self.constraints)
        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
        self.computed_par_names = [k for k in param_list if k in self.constraints]
        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
        self.computed_pars = [self._pars[k] for k in self.computed_par_names]

    # ===== Fitness interface ====
    def parameters(self):
        """Return the name -> bumps Parameter dictionary."""
        return self._pars

    def update(self):
        """Copy every parameter value into the model and mark results stale."""
        for k, v in self._pars.items():
            self.model.setParam(k, v.value)
        self._dirty = True

    def _recalculate(self):
        # Re-evaluate residuals and theory only when parameters changed.
        if self._dirty:
            self._residuals, self._theory \
                = self.data.residuals(self.model.evalDistribution)
            self._dirty = False

    def numpoints(self):
        """Return the number of fitted data points (sum of data.idx)."""
        return np.sum(self.data.idx) # number of fitted points

    def nllf(self):
        """Return half the sum of squared residuals."""
        return 0.5*np.sum(self.residuals()**2)

    def theory(self):
        """Return the model evaluated on the data, recomputing if stale."""
        self._recalculate()
        return self._theory

    def residuals(self):
        """Return the residuals, recomputing if stale."""
        self._recalculate()
        return self._residuals

    # Not implementing the data methods for now:
    #
    #     resynth_data/restore_data/save/plot
[6fe5100]213
class ParameterExpressions(object):
    """
    Callable that re-applies the constraint expressions of a list of
    SasFitness models.

    Expressions are keyed ``<model name>.<parameter>`` and compiled with
    compile_constraints against a symbol table of every model parameter,
    using the same naming.  With no constraints at all, calling is a no-op.
    Pickling stores only the model list; the updater is rebuilt when the
    object is restored.
    """
    def __init__(self, models):
        self.models = models
        self._setup()

    def _setup(self):
        def qualified(fitness, par_name):
            # Global name of a model parameter: "<model name>.<parameter>".
            return ".".join((fitness.name, par_name))

        exprs = {}
        for fitness in self.models:
            for par_name, expr in fitness.constraints.items():
                exprs[qualified(fitness, par_name)] = expr
        if not exprs:
            self.update = lambda: 0
            return

        symtab = {}
        for fitness in self.models:
            for par_name, par in fitness.parameters().items():
                symtab[qualified(fitness, par_name)] = par
        self.update = compile_constraints(symtab, exprs)

    def __call__(self):
        self.update()

    def __getstate__(self):
        # Only the models are saved; __setstate__ rebuilds self.update.
        return self.models

    def __setstate__(self, state):
        self.models = state
        self._setup()
240
class BumpsFit(FitEngine):
    """
    Fit a model using bumps.
    """
    def __init__(self):
        """
        Initialize through FitEngine.__init__, which is expected to create
        ``self.fit_arrange_dict`` (FitArrange elements keyed by uid, read by
        :meth:`fit`), and clear the current thread.
        """
        FitEngine.__init__(self)
        self.curr_thread = None

    def fit(self, msg_q=None,
            q=None, handler=None, curr_thread=None,
            ftol=1.49012e-8, reset_flag=False):
        """
        Fit all FitArrange entries selected for fitting as one bumps problem.

        *q*: optional queue; when given, the results are put on it and *q*
        is returned.
        *handler*: optional progress handler, passed to the bumps monitors.
        *curr_thread*: optional thread polled for a user abort.
        *reset_flag*: when true, each model starts from its saved ``vals``.
        *msg_q* and *ftol* are not used by this engine.

        Returns the list of FResult objects (one per fitted model), or *q*.
        Raises RuntimeError if no model is selected for fitting.
        """
        # Build collection of bumps fitness calculators
        models = [SasFitness(model=M.get_model(),
                             data=M.get_data(),
                             constraints=M.constraints,
                             fitted=M.pars,
                             initial_values=M.vals if reset_flag else None)
                  for M in self.fit_arrange_dict.values()
                  if M.get_to_fit()]
        if not models:
            raise RuntimeError("Nothing to fit")
        problem = FitProblem(models)

        # TODO: need better handling of parameter expressions and bounds constraints
        # so that they are applied during polydispersity calculations.  This
        # will remove the immediate need for the setp_hook in bumps, though
        # bumps may still need something similar, such as a sane class structure
        # which allows a subclass to override setp.
        problem.setp_hook = ParameterExpressions(models)

        # Run the fit
        result = run_bumps(problem, handler, curr_thread)
        if handler is not None:
            handler.update_fit(last=True)

        # TODO: shouldn't reference internal parameters of fit problem
        varying = problem._parameters
        # collect the results
        all_results = []
        for M in problem.models:
            fitness = M.fitness
            # Positions of this model's fitted parameters in the result vectors.
            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
            param_list = fitness.fitted_par_names + fitness.computed_par_names
            R = FResult(model=fitness.model, data=fitness.data,
                        param_list=param_list)
            R.theory = fitness.theory()
            R.residuals = fitness.residuals()
            R.index = fitness.data.idx
            R.fitter_id = self.fitter_id
            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
            R.success = result['success']
            if R.success:
                # Computed (constrained) parameters get no uncertainty.
                # np.nan is used because the np.NaN alias is gone in NumPy 2.
                R.stderr = np.hstack((result['stderr'][fitted_index],
                                      np.full(len(fitness.computed_pars), np.nan)))
                R.pvec = np.hstack((result['value'][fitted_index],
                                    [p.value for p in fitness.computed_pars]))
                R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
            else:
                R.stderr = np.full(len(param_list), np.nan)
                R.pvec = np.asarray([p.value for p in fitness.fitted_pars+fitness.computed_pars])
                R.fitness = np.nan
            R.convergence = result['convergence']
            if result['uncertainty'] is not None:
                R.uncertainty_state = result['uncertainty']
            all_results.append(R)

        if q is not None:
            q.put(all_results)
            return q
        return all_results
[6fe5100]316
def run_bumps(problem, handler, curr_thread):
    """
    Run the bumps fitter returned by get_fitter on *problem*.

    *handler* (may be None) receives progress reports through a
    BumpsMonitor.  *curr_thread* (may be None) is polled through
    ``isquit()``; a KeyboardInterrupt raised there aborts the fit.

    Returns a dict with keys:
    - 'value': best point, or None when the fit produced none;
    - 'stderr': parameter uncertainties, or None when the fit produced none;
    - 'success': whether a best point was returned;
    - 'convergence': population statistics scaled to 2*value/dof;
    - 'uncertainty': the fitter's ``state`` attribute, or None.
    """
    def abort_test():
        # isquit() signals a user abort by raising KeyboardInterrupt.
        if curr_thread is None:
            return False
        try:
            curr_thread.isquit()
        except KeyboardInterrupt:
            if handler is not None:
                handler.stop("Fitting: Terminated!!!")
            return True
        return False

    fitclass, options = get_fitter()
    # Private copy: the monitors added below must not end up in a shared
    # fitter configuration.
    options = dict(options)
    steps = options.get('steps', 0)
    if steps == 0:
        pop = options.get('pop', 0)*len(problem._parameters)
        samples = options.get('samples', 0)
        # Ceiling division: generations of *pop* points needed to draw
        # *samples* points.  Use // so the step count stays an integer now
        # that the module uses true division.
        steps = (samples + pop - 1)//pop if pop != 0 else samples
    max_step = steps + options.get('burn', 0)
    pars = [p.name for p in problem._parameters]
    options['monitors'] = [
        BumpsMonitor(handler, max_step, pars, problem.dof),
        ConvergenceMonitor(),
        ]
    fitdriver = fitters.FitDriver(fitclass, problem=problem,
                                  abort_test=abort_test, **options)
    omp_threads = int(os.environ.get('OMP_NUM_THREADS', '0'))
    mapper = MPMapper if omp_threads == 1 else SerialMapper
    fitdriver.mapper = mapper.start_mapper(problem, None)
    try:
        best, fbest = fitdriver.fit()
    except Exception:
        import traceback; traceback.print_exc()
        raise
    finally:
        mapper.stop_mapper(fitdriver.mapper)

    convergence_list = options['monitors'][-1].convergence
    convergence = (2*np.asarray(convergence_list)/problem.dof
                   if convergence_list else np.empty((0, 1), 'd'))

    success = best is not None
    return {
        'value': best if success else None,
        'stderr': fitdriver.stderr() if success else None,
        'success': success,
        'convergence': convergence,
        'uncertainty': getattr(fitdriver.fitter, 'state', None),
        }
[6fe5100]367
Note: See TracBrowser for help on using the repository browser.