source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ 3f89c0e

ticket-1094-headless
Last change on this file since 3f89c0e was 0aeba4e, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

code cleanup for bumps fitting wrapper

  • Property mode set to 100644
File size: 14.5 KB
RevLine 
[6fe5100]1"""
2BumpsFitting module runs the bumps optimizer.
3"""
[0aeba4e]4from __future__ import print_function
5
[249a7c6]6import os
[35086c3]7from datetime import timedelta, datetime
[1a5d5f2]8import traceback
[35086c3]9
[9a5097c]10import numpy as np
[6fe5100]11
12from bumps import fitters
[0aeba4e]13
[7945367]14try:
15    from bumps.options import FIT_CONFIG
[0aeba4e]16    # Preserve bumps default fitter in case someone wants it later
17    BUMPS_DEFAULT_FITTER = FIT_CONFIG.selected_id
[7945367]18    # Default bumps to use the Levenberg-Marquardt optimizer
19    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
20    def get_fitter():
21        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
[1386b2f]22except ImportError:
[7945367]23    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
[0aeba4e]24    # Preserve bumps default fitter in case someone wants it later
25    BUMPS_DEFAULT_FITTER = fitters.FIT_DEFAULT
[7945367]26    # Default bumps to use the Levenberg-Marquardt optimizer
27    fitters.FIT_DEFAULT = 'lm'
28    def get_fitter():
29        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
30        return fitopts.fitclass, fitopts.options.copy()
31
32
[249a7c6]33from bumps.mapper import SerialMapper, MPMapper
[e3efa6b3]34from bumps import parameter
35from bumps.fitproblem import FitProblem
[7945367]36
[345e7e4]37
[b699768]38from sas.sascalc.fit.AbstractFitEngine import FitEngine
39from sas.sascalc.fit.AbstractFitEngine import FResult
40from sas.sascalc.fit.expression import compile_constraints
[6fe5100]41
[35086c3]42class Progress(object):
43    def __init__(self, history, max_step, pars, dof):
44        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
45        # Depending on the time remaining, either display the expected
46        # time of completion, or the amount of time remaining.  Use precision
47        # appropriate for the duration.
48        if remaining_time >= 1800:
49            completion_time = datetime.now() + timedelta(seconds=remaining_time)
50            if remaining_time >= 36000:
51                time = completion_time.strftime('%Y-%m-%d %H:%M')
52            else:
53                time = completion_time.strftime('%H:%M')
54        else:
55            if remaining_time >= 3600:
56                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
57            elif remaining_time >= 60:
58                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
59            else:
60                time = '%ds'%remaining_time
61        chisq = "%.3g"%(2*history.value[0]/dof)
62        step = "%d of %d"%(history.step[0], max_step)
63        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
64        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
[1386b2f]65                      for i, (k, v) in enumerate(zip(pars, history.point[0]))]
[35086c3]66        self.msg = "".join([header]+parameters)
67
68    def __str__(self):
69        return self.msg
70
71
[85f17f6]72class BumpsMonitor(object):
[35086c3]73    def __init__(self, handler, max_step, pars, dof):
[85f17f6]74        self.handler = handler
75        self.max_step = max_step
[35086c3]76        self.pars = pars
77        self.dof = dof
[ed4aef2]78
[85f17f6]79    def config_history(self, history):
80        history.requires(time=1, value=2, point=1, step=1)
[ed4aef2]81
[85f17f6]82    def __call__(self, history):
[e3efa6b3]83        if self.handler is None: return
[35086c3]84        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
[85f17f6]85        self.handler.progress(history.step[0], self.max_step)
[1386b2f]86        if len(history.step) > 1 and history.step[1] > history.step[0]:
[85f17f6]87            self.handler.improvement()
88        self.handler.update_fit()
89
[ed4aef2]90class ConvergenceMonitor(object):
91    """
92    ConvergenceMonitor contains population summary statistics to show progress
93    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
94    just a list [ (best, ) ] if population size is 1.
95    """
96    def __init__(self):
97        self.convergence = []
98
99    def config_history(self, history):
100        history.requires(value=1, population_values=1)
101
102    def __call__(self, history):
103        best = history.value[0]
104        try:
105            p = history.population_values[0]
[1386b2f]106            n, p = len(p), np.sort(p)
107            QI, Qmid = int(0.2*n), int(0.5*n)
108            self.convergence.append((best, p[0], p[QI], p[Qmid], p[-1-QI], p[-1]))
109        except Exception:
110            self.convergence.append((best, best, best, best, best, best))
[ed4aef2]111
[e3efa6b3]112
[4e9f227]113# Note: currently using bumps parameters for each parameter object so that
114# a SasFitness can be used directly in bumps with the usual semantics.
115# The disadvantage of this technique is that we need to copy every parameter
116# back into the model each time the function is evaluated.  We could instead
[fd5ac0d]117# define reference parameters for each sas parameter, but then we would not
[4e9f227]118# be able to express constraints using python expressions in the usual way
119# from bumps, and would instead need to use string expressions.
[e3efa6b3]120class SasFitness(object):
[6fe5100]121    """
[e3efa6b3]122    Wrap SAS model as a bumps fitness object
[6fe5100]123    """
[5044543]124    def __init__(self, model, data, fitted=[], constraints={},
125                 initial_values=None, **kw):
[4e9f227]126        self.name = model.name
127        self.model = model.model
[6fe5100]128        self.data = data
[9f7fbd9]129        if self.data.smearer is not None:
130            self.data.smearer.model = self.model
[e3efa6b3]131        self._define_pars()
132        self._init_pars(kw)
[5044543]133        if initial_values is not None:
134            self._reset_pars(fitted, initial_values)
[0aeba4e]135        #print("constraints", constraints)
[4e9f227]136        self.constraints = dict(constraints)
[e3efa6b3]137        self.set_fitted(fitted)
[5044543]138        self.update()
139
140    def _reset_pars(self, names, values):
[1386b2f]141        for k, v in zip(names, values):
[5044543]142            self._pars[k].value = v
[e3efa6b3]143
144    def _define_pars(self):
145        self._pars = {}
146        for k in self.model.getParamList():
[1386b2f]147            name = ".".join((self.name, k))
[e3efa6b3]148            value = self.model.getParam(k)
[1386b2f]149            bounds = self.model.details.get(k, ["", None, None])[1:3]
[e3efa6b3]150            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
151                                                fixed=True, name=name)
[4e9f227]152        #print parameter.summarize(self._pars.values())
[e3efa6b3]153
154    def _init_pars(self, kw):
[1386b2f]155        for k, v in kw.items():
[e3efa6b3]156            # dispersion parameters initialized with _field instead of .field
[1386b2f]157            if k.endswith('_width'):
158                k = k[:-6]+'.width'
159            elif k.endswith('_npts'):
160                k = k[:-5]+'.npts'
161            elif k.endswith('_nsigmas'):
162                k = k[:-7]+'.nsigmas'
163            elif k.endswith('_type'):
164                k = k[:-5]+'.type'
[e3efa6b3]165            if k not in self._pars:
166                formatted_pars = ", ".join(sorted(self._pars.keys()))
167                raise KeyError("invalid parameter %r for %s--use one of: %s"
168                               %(k, self.model, formatted_pars))
169            if '.' in k and not k.endswith('.width'):
170                self.model.setParam(k, v)
171            elif isinstance(v, parameter.BaseParameter):
172                self._pars[k] = v
[1386b2f]173            elif isinstance(v, (tuple, list)):
[e3efa6b3]174                low, high = v
175                self._pars[k].value = (low+high)/2
[1386b2f]176                self._pars[k].range(low, high)
[95d58d3]177            else:
[e3efa6b3]178                self._pars[k].value = v
179
180    def set_fitted(self, param_list):
[6fe5100]181        """
[e3efa6b3]182        Flag a set of parameters as fitted parameters.
[6fe5100]183        """
[1386b2f]184        for k, p in self._pars.items():
[4e9f227]185            p.fixed = (k not in param_list or k in self.constraints)
[4a0dc427]186        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
[bf5e985]187        self.computed_par_names = [k for k in param_list if k in self.constraints]
188        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
189        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
[6fe5100]190
[e3efa6b3]191    # ===== Fitness interface ====
192    def parameters(self):
193        return self._pars
[6fe5100]194
[e3efa6b3]195    def update(self):
[1386b2f]196        for k, v in self._pars.items():
[4e9f227]197            #print "updating",k,v,v.value
[1386b2f]198            self.model.setParam(k, v.value)
[e3efa6b3]199        self._dirty = True
[6fe5100]200
[e3efa6b3]201    def _recalculate(self):
202        if self._dirty:
[9f7fbd9]203            self._residuals, self._theory \
204                = self.data.residuals(self.model.evalDistribution)
[e3efa6b3]205            self._dirty = False
[6fe5100]206
[e3efa6b3]207    def numpoints(self):
[9a5097c]208        return np.sum(self.data.idx) # number of fitted points
[6fe5100]209
[e3efa6b3]210    def nllf(self):
[9a5097c]211        return 0.5*np.sum(self.residuals()**2)
[e3efa6b3]212
213    def theory(self):
214        self._recalculate()
215        return self._theory
216
217    def residuals(self):
218        self._recalculate()
219        return self._residuals
220
221    # Not implementing the data methods for now:
222    #
223    #     resynth_data/restore_data/save/plot
[6fe5100]224
[191c648]225class ParameterExpressions(object):
226    def __init__(self, models):
227        self.models = models
228        self._setup()
229
230    def _setup(self):
231        exprs = {}
[0aeba4e]232        for model in self.models:
233            exprs.update((".".join((model.name, k)), v)
234                         for k, v in model.constraints.items())
[191c648]235        if exprs:
[0aeba4e]236            symtab = dict((".".join((model.name, k)), p)
237                          for model in self.models
238                          for k, p in model.parameters().items())
[191c648]239            self.update = compile_constraints(symtab, exprs)
240        else:
241            self.update = lambda: 0
242
243    def __call__(self):
244        self.update()
245
246    def __getstate__(self):
247        return self.models
248
249    def __setstate__(self, state):
250        self.models = state
251        self._setup()
252
[6fe5100]253class BumpsFit(FitEngine):
254    """
255    Fit a model using bumps.
256    """
257    def __init__(self):
258        """
259        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
260        with Uid as keys
261        """
262        FitEngine.__init__(self)
263        self.curr_thread = None
264
265    def fit(self, msg_q=None,
266            q=None, handler=None, curr_thread=None,
267            ftol=1.49012e-8, reset_flag=False):
[e3efa6b3]268        # Build collection of bumps fitness calculators
[bf5e985]269        models = [SasFitness(model=M.get_model(),
270                             data=M.get_data(),
271                             constraints=M.constraints,
[5044543]272                             fitted=M.pars,
273                             initial_values=M.vals if reset_flag else None)
[bf5e985]274                  for M in self.fit_arrange_dict.values()
275                  if M.get_to_fit()]
[233c121]276        if len(models) == 0:
277            raise RuntimeError("Nothing to fit")
[e3efa6b3]278        problem = FitProblem(models)
279
[191c648]280        # TODO: need better handling of parameter expressions and bounds constraints
281        # so that they are applied during polydispersity calculations.  This
282        # will remove the immediate need for the setp_hook in bumps, though
283        # bumps may still need something similar, such as a sane class structure
284        # which allows a subclass to override setp.
285        problem.setp_hook = ParameterExpressions(models)
[4e9f227]286
[e3efa6b3]287        # Run the fit
288        result = run_bumps(problem, handler, curr_thread)
[6fe5100]289        if handler is not None:
290            handler.update_fit(last=True)
[e3efa6b3]291
[eff93b8]292        # TODO: shouldn't reference internal parameters of fit problem
[e3efa6b3]293        varying = problem._parameters
294        # collect the results
295        all_results = []
296        for M in problem.models:
297            fitness = M.fitness
298            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
[e1442d4]299            param_list = fitness.fitted_par_names + fitness.computed_par_names
[e3efa6b3]300            R = FResult(model=fitness.model, data=fitness.data,
[e1442d4]301                        param_list=param_list)
[e3efa6b3]302            R.theory = fitness.theory()
303            R.residuals = fitness.residuals()
[5044543]304            R.index = fitness.data.idx
[e3efa6b3]305            R.fitter_id = self.fitter_id
[eff93b8]306            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
[e3efa6b3]307            R.success = result['success']
[e1442d4]308            if R.success:
[1a5d5f2]309                if result['stderr'] is None:
[9a5097c]310                    R.stderr = np.NaN*np.ones(len(param_list))
[1a5d5f2]311                else:
[9a5097c]312                    R.stderr = np.hstack((result['stderr'][fitted_index],
313                                          np.NaN*np.ones(len(fitness.computed_pars))))
314                R.pvec = np.hstack((result['value'][fitted_index],
[1386b2f]315                                    [p.value for p in fitness.computed_pars]))
[9a5097c]316                R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
[e1442d4]317            else:
[9a5097c]318                R.stderr = np.NaN*np.ones(len(param_list))
[1386b2f]319                R.pvec = np.asarray([p.value for p in fitness.fitted_pars+fitness.computed_pars])
[9a5097c]320                R.fitness = np.NaN
[e3efa6b3]321            R.convergence = result['convergence']
322            if result['uncertainty'] is not None:
323                R.uncertainty_state = result['uncertainty']
324            all_results.append(R)
[1a5d5f2]325        all_results[0].mesg = result['errors']
[e3efa6b3]326
[6fe5100]327        if q is not None:
[e3efa6b3]328            q.put(all_results)
[6fe5100]329            return q
[e3efa6b3]330        else:
331            return all_results
[6fe5100]332
[e3efa6b3]333def run_bumps(problem, handler, curr_thread):
[85f17f6]334    def abort_test():
335        if curr_thread is None: return False
336        try: curr_thread.isquit()
337        except KeyboardInterrupt:
338            if handler is not None:
339                handler.stop("Fitting: Terminated!!!")
340            return True
341        return False
342
[7945367]343    fitclass, options = get_fitter()
344    steps = options.get('steps', 0)
345    if steps == 0:
[1386b2f]346        pop = options.get('pop', 0)*len(problem._parameters)
[7945367]347        samples = options.get('samples', 0)
348        steps = (samples+pop-1)/pop if pop != 0 else samples
349    max_step = steps + options.get('burn', 0)
[35086c3]350    pars = [p.name for p in problem._parameters]
[9a5097c]351    #x0 = np.asarray([p.value for p in problem._parameters])
[e3efa6b3]352    options['monitors'] = [
[35086c3]353        BumpsMonitor(handler, max_step, pars, problem.dof),
[e3efa6b3]354        ConvergenceMonitor(),
355        ]
[95d58d3]356    fitdriver = fitters.FitDriver(fitclass, problem=problem,
[042f065]357                                  abort_test=abort_test, **options)
[1386b2f]358    omp_threads = int(os.environ.get('OMP_NUM_THREADS', '0'))
[e1442d4]359    mapper = MPMapper if omp_threads == 1 else SerialMapper
[6fe5100]360    fitdriver.mapper = mapper.start_mapper(problem, None)
[233c121]361    #import time; T0 = time.time()
[6fe5100]362    try:
363        best, fbest = fitdriver.fit()
[1a5d5f2]364        errors = []
365    except Exception as exc:
[9a5097c]366        best, fbest = None, np.NaN
[1a30720]367        errors = [str(exc), traceback.format_exc()]
[95d58d3]368    finally:
369        mapper.stop_mapper(fitdriver.mapper)
[e3efa6b3]370
371
372    convergence_list = options['monitors'][-1].convergence
[9a5097c]373    convergence = (2*np.asarray(convergence_list)/problem.dof
[1386b2f]374                   if convergence_list else np.empty((0, 1), 'd'))
[e1442d4]375
376    success = best is not None
[1a5d5f2]377    try:
378        stderr = fitdriver.stderr() if success else None
379    except Exception as exc:
380        errors.append(str(exc))
381        errors.append(traceback.format_exc())
382        stderr = None
[e3efa6b3]383    return {
[e1442d4]384        'value': best if success else None,
[1a5d5f2]385        'stderr': stderr,
[e1442d4]386        'success': success,
[e3efa6b3]387        'convergence': convergence,
388        'uncertainty': getattr(fitdriver.fitter, 'state', None),
[1a5d5f2]389        'errors': '\n'.join(errors),
[e3efa6b3]390        }
Note: See TracBrowser for help on using the repository browser.