source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ 0aeba4e

ticket-1094-headless
Last change on this file since 0aeba4e was 0aeba4e, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

code cleanup for bumps fitting wrapper

  • Property mode set to 100644
File size: 14.5 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4from __future__ import print_function
5
6import os
7from datetime import timedelta, datetime
8import traceback
9
10import numpy as np
11
12from bumps import fitters
13
14try:
15    from bumps.options import FIT_CONFIG
16    # Preserve bumps default fitter in case someone wants it later
17    BUMPS_DEFAULT_FITTER = FIT_CONFIG.selected_id
18    # Default bumps to use the Levenberg-Marquardt optimizer
19    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
20    def get_fitter():
21        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
22except ImportError:
23    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
24    # Preserve bumps default fitter in case someone wants it later
25    BUMPS_DEFAULT_FITTER = fitters.FIT_DEFAULT
26    # Default bumps to use the Levenberg-Marquardt optimizer
27    fitters.FIT_DEFAULT = 'lm'
28    def get_fitter():
29        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
30        return fitopts.fitclass, fitopts.options.copy()
31
32
33from bumps.mapper import SerialMapper, MPMapper
34from bumps import parameter
35from bumps.fitproblem import FitProblem
36
37
38from sas.sascalc.fit.AbstractFitEngine import FitEngine
39from sas.sascalc.fit.AbstractFitEngine import FResult
40from sas.sascalc.fit.expression import compile_constraints
41
42class Progress(object):
43    def __init__(self, history, max_step, pars, dof):
44        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
45        # Depending on the time remaining, either display the expected
46        # time of completion, or the amount of time remaining.  Use precision
47        # appropriate for the duration.
48        if remaining_time >= 1800:
49            completion_time = datetime.now() + timedelta(seconds=remaining_time)
50            if remaining_time >= 36000:
51                time = completion_time.strftime('%Y-%m-%d %H:%M')
52            else:
53                time = completion_time.strftime('%H:%M')
54        else:
55            if remaining_time >= 3600:
56                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
57            elif remaining_time >= 60:
58                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
59            else:
60                time = '%ds'%remaining_time
61        chisq = "%.3g"%(2*history.value[0]/dof)
62        step = "%d of %d"%(history.step[0], max_step)
63        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
64        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
65                      for i, (k, v) in enumerate(zip(pars, history.point[0]))]
66        self.msg = "".join([header]+parameters)
67
68    def __str__(self):
69        return self.msg
70
71
72class BumpsMonitor(object):
73    def __init__(self, handler, max_step, pars, dof):
74        self.handler = handler
75        self.max_step = max_step
76        self.pars = pars
77        self.dof = dof
78
79    def config_history(self, history):
80        history.requires(time=1, value=2, point=1, step=1)
81
82    def __call__(self, history):
83        if self.handler is None: return
84        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
85        self.handler.progress(history.step[0], self.max_step)
86        if len(history.step) > 1 and history.step[1] > history.step[0]:
87            self.handler.improvement()
88        self.handler.update_fit()
89
90class ConvergenceMonitor(object):
91    """
92    ConvergenceMonitor contains population summary statistics to show progress
93    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
94    just a list [ (best, ) ] if population size is 1.
95    """
96    def __init__(self):
97        self.convergence = []
98
99    def config_history(self, history):
100        history.requires(value=1, population_values=1)
101
102    def __call__(self, history):
103        best = history.value[0]
104        try:
105            p = history.population_values[0]
106            n, p = len(p), np.sort(p)
107            QI, Qmid = int(0.2*n), int(0.5*n)
108            self.convergence.append((best, p[0], p[QI], p[Qmid], p[-1-QI], p[-1]))
109        except Exception:
110            self.convergence.append((best, best, best, best, best, best))
111
112
113# Note: currently using bumps parameters for each parameter object so that
114# a SasFitness can be used directly in bumps with the usual semantics.
115# The disadvantage of this technique is that we need to copy every parameter
116# back into the model each time the function is evaluated.  We could instead
117# define reference parameters for each sas parameter, but then we would not
118# be able to express constraints using python expressions in the usual way
119# from bumps, and would instead need to use string expressions.
120class SasFitness(object):
121    """
122    Wrap SAS model as a bumps fitness object
123    """
124    def __init__(self, model, data, fitted=[], constraints={},
125                 initial_values=None, **kw):
126        self.name = model.name
127        self.model = model.model
128        self.data = data
129        if self.data.smearer is not None:
130            self.data.smearer.model = self.model
131        self._define_pars()
132        self._init_pars(kw)
133        if initial_values is not None:
134            self._reset_pars(fitted, initial_values)
135        #print("constraints", constraints)
136        self.constraints = dict(constraints)
137        self.set_fitted(fitted)
138        self.update()
139
140    def _reset_pars(self, names, values):
141        for k, v in zip(names, values):
142            self._pars[k].value = v
143
144    def _define_pars(self):
145        self._pars = {}
146        for k in self.model.getParamList():
147            name = ".".join((self.name, k))
148            value = self.model.getParam(k)
149            bounds = self.model.details.get(k, ["", None, None])[1:3]
150            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
151                                                fixed=True, name=name)
152        #print parameter.summarize(self._pars.values())
153
154    def _init_pars(self, kw):
155        for k, v in kw.items():
156            # dispersion parameters initialized with _field instead of .field
157            if k.endswith('_width'):
158                k = k[:-6]+'.width'
159            elif k.endswith('_npts'):
160                k = k[:-5]+'.npts'
161            elif k.endswith('_nsigmas'):
162                k = k[:-7]+'.nsigmas'
163            elif k.endswith('_type'):
164                k = k[:-5]+'.type'
165            if k not in self._pars:
166                formatted_pars = ", ".join(sorted(self._pars.keys()))
167                raise KeyError("invalid parameter %r for %s--use one of: %s"
168                               %(k, self.model, formatted_pars))
169            if '.' in k and not k.endswith('.width'):
170                self.model.setParam(k, v)
171            elif isinstance(v, parameter.BaseParameter):
172                self._pars[k] = v
173            elif isinstance(v, (tuple, list)):
174                low, high = v
175                self._pars[k].value = (low+high)/2
176                self._pars[k].range(low, high)
177            else:
178                self._pars[k].value = v
179
180    def set_fitted(self, param_list):
181        """
182        Flag a set of parameters as fitted parameters.
183        """
184        for k, p in self._pars.items():
185            p.fixed = (k not in param_list or k in self.constraints)
186        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
187        self.computed_par_names = [k for k in param_list if k in self.constraints]
188        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
189        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
190
191    # ===== Fitness interface ====
192    def parameters(self):
193        return self._pars
194
195    def update(self):
196        for k, v in self._pars.items():
197            #print "updating",k,v,v.value
198            self.model.setParam(k, v.value)
199        self._dirty = True
200
201    def _recalculate(self):
202        if self._dirty:
203            self._residuals, self._theory \
204                = self.data.residuals(self.model.evalDistribution)
205            self._dirty = False
206
207    def numpoints(self):
208        return np.sum(self.data.idx) # number of fitted points
209
210    def nllf(self):
211        return 0.5*np.sum(self.residuals()**2)
212
213    def theory(self):
214        self._recalculate()
215        return self._theory
216
217    def residuals(self):
218        self._recalculate()
219        return self._residuals
220
221    # Not implementing the data methods for now:
222    #
223    #     resynth_data/restore_data/save/plot
224
225class ParameterExpressions(object):
226    def __init__(self, models):
227        self.models = models
228        self._setup()
229
230    def _setup(self):
231        exprs = {}
232        for model in self.models:
233            exprs.update((".".join((model.name, k)), v)
234                         for k, v in model.constraints.items())
235        if exprs:
236            symtab = dict((".".join((model.name, k)), p)
237                          for model in self.models
238                          for k, p in model.parameters().items())
239            self.update = compile_constraints(symtab, exprs)
240        else:
241            self.update = lambda: 0
242
243    def __call__(self):
244        self.update()
245
246    def __getstate__(self):
247        return self.models
248
249    def __setstate__(self, state):
250        self.models = state
251        self._setup()
252
253class BumpsFit(FitEngine):
254    """
255    Fit a model using bumps.
256    """
257    def __init__(self):
258        """
259        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
260        with Uid as keys
261        """
262        FitEngine.__init__(self)
263        self.curr_thread = None
264
265    def fit(self, msg_q=None,
266            q=None, handler=None, curr_thread=None,
267            ftol=1.49012e-8, reset_flag=False):
268        # Build collection of bumps fitness calculators
269        models = [SasFitness(model=M.get_model(),
270                             data=M.get_data(),
271                             constraints=M.constraints,
272                             fitted=M.pars,
273                             initial_values=M.vals if reset_flag else None)
274                  for M in self.fit_arrange_dict.values()
275                  if M.get_to_fit()]
276        if len(models) == 0:
277            raise RuntimeError("Nothing to fit")
278        problem = FitProblem(models)
279
280        # TODO: need better handling of parameter expressions and bounds constraints
281        # so that they are applied during polydispersity calculations.  This
282        # will remove the immediate need for the setp_hook in bumps, though
283        # bumps may still need something similar, such as a sane class structure
284        # which allows a subclass to override setp.
285        problem.setp_hook = ParameterExpressions(models)
286
287        # Run the fit
288        result = run_bumps(problem, handler, curr_thread)
289        if handler is not None:
290            handler.update_fit(last=True)
291
292        # TODO: shouldn't reference internal parameters of fit problem
293        varying = problem._parameters
294        # collect the results
295        all_results = []
296        for M in problem.models:
297            fitness = M.fitness
298            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
299            param_list = fitness.fitted_par_names + fitness.computed_par_names
300            R = FResult(model=fitness.model, data=fitness.data,
301                        param_list=param_list)
302            R.theory = fitness.theory()
303            R.residuals = fitness.residuals()
304            R.index = fitness.data.idx
305            R.fitter_id = self.fitter_id
306            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
307            R.success = result['success']
308            if R.success:
309                if result['stderr'] is None:
310                    R.stderr = np.NaN*np.ones(len(param_list))
311                else:
312                    R.stderr = np.hstack((result['stderr'][fitted_index],
313                                          np.NaN*np.ones(len(fitness.computed_pars))))
314                R.pvec = np.hstack((result['value'][fitted_index],
315                                    [p.value for p in fitness.computed_pars]))
316                R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
317            else:
318                R.stderr = np.NaN*np.ones(len(param_list))
319                R.pvec = np.asarray([p.value for p in fitness.fitted_pars+fitness.computed_pars])
320                R.fitness = np.NaN
321            R.convergence = result['convergence']
322            if result['uncertainty'] is not None:
323                R.uncertainty_state = result['uncertainty']
324            all_results.append(R)
325        all_results[0].mesg = result['errors']
326
327        if q is not None:
328            q.put(all_results)
329            return q
330        else:
331            return all_results
332
333def run_bumps(problem, handler, curr_thread):
334    def abort_test():
335        if curr_thread is None: return False
336        try: curr_thread.isquit()
337        except KeyboardInterrupt:
338            if handler is not None:
339                handler.stop("Fitting: Terminated!!!")
340            return True
341        return False
342
343    fitclass, options = get_fitter()
344    steps = options.get('steps', 0)
345    if steps == 0:
346        pop = options.get('pop', 0)*len(problem._parameters)
347        samples = options.get('samples', 0)
348        steps = (samples+pop-1)/pop if pop != 0 else samples
349    max_step = steps + options.get('burn', 0)
350    pars = [p.name for p in problem._parameters]
351    #x0 = np.asarray([p.value for p in problem._parameters])
352    options['monitors'] = [
353        BumpsMonitor(handler, max_step, pars, problem.dof),
354        ConvergenceMonitor(),
355        ]
356    fitdriver = fitters.FitDriver(fitclass, problem=problem,
357                                  abort_test=abort_test, **options)
358    omp_threads = int(os.environ.get('OMP_NUM_THREADS', '0'))
359    mapper = MPMapper if omp_threads == 1 else SerialMapper
360    fitdriver.mapper = mapper.start_mapper(problem, None)
361    #import time; T0 = time.time()
362    try:
363        best, fbest = fitdriver.fit()
364        errors = []
365    except Exception as exc:
366        best, fbest = None, np.NaN
367        errors = [str(exc), traceback.format_exc()]
368    finally:
369        mapper.stop_mapper(fitdriver.mapper)
370
371
372    convergence_list = options['monitors'][-1].convergence
373    convergence = (2*np.asarray(convergence_list)/problem.dof
374                   if convergence_list else np.empty((0, 1), 'd'))
375
376    success = best is not None
377    try:
378        stderr = fitdriver.stderr() if success else None
379    except Exception as exc:
380        errors.append(str(exc))
381        errors.append(traceback.format_exc())
382        stderr = None
383    return {
384        'value': best if success else None,
385        'stderr': stderr,
386        'success': success,
387        'convergence': convergence,
388        'uncertainty': getattr(fitdriver.fitter, 'state', None),
389        'errors': '\n'.join(errors),
390        }
Note: See TracBrowser for help on using the repository browser.