source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ ccfe03b

ESS_GUI_bumps_abstraction
Last change on this file since ccfe03b was ccfe03b, checked in by ibressler, 5 years ago

FittingOptions: now with tooltips again

  • extracted from generated FittingOptions UI file
  • moved to bumps related interface code
    • could be provided by bumps directly instead
    • alternatively, extract the tooltips from the documentation programmatically, perhaps?
  • Property mode set to 100644
File size: 18.3 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import os
5from datetime import timedelta, datetime
6import traceback
7
8import numpy as np
9
10from bumps import fitters
try:
    from bumps.options import FIT_CONFIG
    # Default bumps to use the Levenberg-Marquardt optimizer
    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
    def get_fitter():
        """
        Return the ``(selected_fitter, selected_values)`` pair currently
        held by the shared bumps ``FIT_CONFIG`` object.
        """
        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
except ImportError:
    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
    # Default bumps to use the Levenberg-Marquardt optimizer
    fitters.FIT_DEFAULT = 'lm'
    def get_fitter():
        """
        Return the fitter class and a copy of its options for the entry of
        ``fitters.FIT_OPTIONS`` named by ``fitters.FIT_DEFAULT``.

        This variant is only defined when ``bumps.options`` cannot be
        imported.
        """
        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
        # Copy so that callers may modify the options without touching the
        # dictionary stored in bumps.
        return fitopts.fitclass, fitopts.options.copy()
24
25
26from bumps.mapper import SerialMapper, MPMapper
27from bumps import parameter
28from bumps.fitproblem import FitProblem
29
30
31from sas.sascalc.fit.AbstractFitEngine import FitEngine
32from sas.sascalc.fit.AbstractFitEngine import FResult
33from sas.sascalc.fit.expression import compile_constraints
34
35# converted from FittingOptionsUI.py file
# Tooltip texts (Qt rich text) for the fitting option widgets, keyed by
# "<option>_<fitter id>".
toolTips = dict(
    # --- DREAM ---
    samples_dream="<html><head/><body><p>Number of points to be drawn from the Markov chain.</p></body></html>",
    burn_dream="<html><head/><body><p>The number of iterations required for the Markov chain to converge to the equilibrium distribution.</p></body></html>",
    pop_dream="<html><head/><body><p>The size of the population.</p></body></html>",
    init_dream="<html><head/><body><p><span style=\" font-style:italic;\">Initializer</span> determines how the population will be initialized. The options are as follows:</p><p><span style=\" font-style:italic;\">eps</span> (epsilon ball), in which the entire initial population is chosen at random from within a tiny hypersphere centered about the initial point</p><p><span style=\" font-style:italic;\">lhs</span> (latin hypersquare), which chops the bounds within each dimension in <span style=\" font-weight:600;\">k</span> equal sized chunks where <span style=\" font-weight:600;\">k</span> is the size of the population and makes sure that each parameter has at least one value within each chunk across the population.</p><p><span style=\" font-style:italic;\">cov</span> (covariance matrix), in which the uncertainty is estimated using the covariance matrix at the initial point, and points are selected at random from the corresponding gaussian ellipsoid</p><p><span style=\" font-style:italic;\">random</span> (uniform random), in which the points are selected at random within the bounds of the parameters</p></body></html>",
    thin_dream="<html><head/><body><p>The amount of thinning to use when collecting the population.</p></body></html>",
    steps_dream="<html><head/><body><p>Determines the number of iterations to use for drawing samples after burn in.</p></body></html>",
    # --- Levenberg-Marquardt ---
    steps_lm="<html><head/><body><p>The number of gradient steps to take.</p></body></html>",
    ftol_lm="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected.</p></body></html>",
    xtol_lm="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected.</p></body></html>",
    # --- Quasi-Newton ---
    steps_newton="<html><head/><body><p>The number of gradient steps to take.</p></body></html>",
    starts_newton="<html><head/><body><p>Value that tells the optimizer to restart a given number of times. Each time it restarts it uses a random starting point.</p></body></html>",
    ftol_newton="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected.</p></body></html>",
    xtol_newton="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected.</p></body></html>",
    # --- Differential evolution ---
    steps_de="<html><head/><body><p>The number of iterations.</p></body></html>",
    # CR is the crossover ratio, not the population size.
    CR_de="<html><head/><body><p>Crossover ratio: determines what proportion of the dimensions to update at each step.</p></body></html>",
    F_de="<html><head/><body><p>Determines how much to scale each difference vector before adding it to the candidate point.</p></body></html>",
    ftol_de="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected.</p></body></html>",
    xtol_de="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected.</p></body></html>",
    # --- Nelder-Mead simplex ---
    steps_amoeba="<html><head/><body><p>The number of simplex update iterations to perform.</p></body></html>",
    starts_amoeba="<html><head/><body><p>Tells the optimizer to restart a given number of times. Each time it restarts it uses a random starting point.</p></body></html>",
    radius_amoeba="<html><head/><body><p>The initial size of the simplex, as a portion of the bounds defining the parameter space.</p></body></html>",
    ftol_amoeba="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected. </p></body></html>",
    xtol_amoeba="<html><head/><body><p>Used to determine when the fit has reached the point where no significant improvement is expected. </p></body></html>",
)
61
class Progress(object):
    """
    Printable snapshot of the state of a running fit.

    The message consists of a header line giving the step counter, the
    normalized chisq and an estimate of the time to completion, followed by
    the current parameter values laid out three per line.

    :param history: object exposing ``time``, ``step``, ``value`` and
        ``point`` sequences; only element ``[0]`` of each is read.
    :param max_step: expected total number of steps.
    :param pars: parameter names, in the same order as ``history.point[0]``.
    :param dof: degrees of freedom used to normalize chisq.
    """
    def __init__(self, history, max_step, pars, dof):
        current_step = history.step[0]
        if current_step > 0:
            # Linear extrapolation of the elapsed time.  The estimate goes
            # negative once the step counter passes max_step, so clamp at 0.
            remaining_time = max(
                int(history.time[0]*(float(max_step)/current_step - 1)), 0)
            time = self._format_eta(remaining_time)
        else:
            # No completed step yet: nothing to extrapolate from (and the
            # division above would fail).
            time = "unknown"
        chisq = "%.3g"%(2*history.value[0]/dof)
        step = "%d of %d"%(current_step, max_step)
        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
        # Three "name: value" cells per row, separated by " | ".
        parameters = ["%15s: %-10.3g%s"%(k, v, ("\n" if i%3 == 2 else " | "))
                      for i, (k, v) in enumerate(zip(pars, history.point[0]))]
        self.msg = "".join([header]+parameters)

    @staticmethod
    def _format_eta(remaining_time):
        """
        Format *remaining_time* (seconds, non-negative integer).

        Half an hour or more is shown as the expected wall-clock completion
        time (with the date when ten hours or more away); anything shorter
        is shown as the amount of time remaining.
        """
        if remaining_time >= 1800:
            completion_time = datetime.now() + timedelta(seconds=remaining_time)
            if remaining_time >= 36000:
                return completion_time.strftime('%Y-%m-%d %H:%M')
            return completion_time.strftime('%H:%M')
        if remaining_time >= 60:
            return '%dm %ds'%(remaining_time//60, remaining_time%60)
        return '%ds'%remaining_time

    def __str__(self):
        return self.msg
90
91
class BumpsMonitor(object):
    """
    Monitor that forwards fit progress to a handler object.

    :param handler: object providing ``set_result``, ``progress``,
        ``improvement`` and ``update_fit``; ``None`` disables reporting.
    :param max_step: expected total number of steps.
    :param pars: names of the fitted parameters.
    :param dof: degrees of freedom, passed on to :class:`Progress`.
    """
    def __init__(self, handler, max_step, pars, dof):
        self.handler = handler
        self.max_step = max_step
        self.pars = pars
        self.dof = dof

    def config_history(self, history):
        """
        Request the history traces this monitor reads.  Two ``value``
        entries are kept so that the latest value can be compared with the
        previous one.
        """
        history.requires(time=1, value=2, point=1, step=1)

    def __call__(self, history):
        """
        Push the current state of the fit to the handler.
        """
        if self.handler is None:
            return
        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
        self.handler.progress(history.step[0], self.max_step)
        # Compare the two most recent values.  Only one step entry is
        # requested in config_history, so a test on history.step could never
        # see a second element.  Index 0 is read as the current entry
        # elsewhere in this class; index 1 is taken to be the one before it.
        if len(history.value) > 1 and history.value[1] > history.value[0]:
            self.handler.improvement()
        self.handler.update_fit()
109
class ConvergenceMonitor(object):
    """
    Collect per-step summary statistics of the population to show the
    progress of the fit.

    ``convergence`` is a list with one 6-tuple per call.  Each tuple holds
    the best value followed by five entries of the sorted population
    values: the smallest, the one at index ``int(0.2*n)``, the one at index
    ``int(0.5*n)``, the one ``int(0.2*n)`` places from the end, and the
    largest.  When the population values cannot be read or summarized, all
    six entries are the best value.
    """
    def __init__(self):
        self.convergence = []

    def config_history(self, history):
        """Request the history traces read by this monitor."""
        history.requires(value=1, population_values=1)

    def __call__(self, history):
        """Append the summary row for the current step."""
        best = history.value[0]
        try:
            population = history.population_values[0]
            size = len(population)
            ordered = np.sort(population)
            lower_idx = int(0.2*size)
            middle_idx = int(0.5*size)
            row = (best, ordered[0], ordered[lower_idx], ordered[middle_idx],
                   ordered[-1-lower_idx], ordered[-1])
        except Exception:
            # No usable population: report the best value in every slot.
            row = (best,)*6
        self.convergence.append(row)
131
132
133# Note: currently using bumps parameters for each parameter object so that
134# a SasFitness can be used directly in bumps with the usual semantics.
135# The disadvantage of this technique is that we need to copy every parameter
136# back into the model each time the function is evaluated.  We could instead
137# define reference parameters for each sas parameter, but then we would not
138# be able to express constraints using python expressions in the usual way
139# from bumps, and would instead need to use string expressions.
class SasFitness(object):
    """
    Wrap SAS model as a bumps fitness object

    :param model: wrapper object providing ``name`` and the underlying SAS
        model as ``model``.
    :param data: data object providing ``smearer``, ``idx`` and
        ``residuals(fn)``.
    :param fitted: names of the parameters to fit (default: none).
    :param constraints: mapping of parameter name to constraint expression
        (default: none).  The mapping is copied.
    :param initial_values: optional values assigned to the *fitted*
        parameters, in the same order as *fitted*.
    :param kw: additional parameter initializers, see :meth:`_init_pars`.
    """
    def __init__(self, model, data, fitted=None, constraints=None,
                 initial_values=None, **kw):
        # None sentinels instead of mutable default arguments, so that no
        # list/dict object is shared between calls.
        if fitted is None:
            fitted = []
        if constraints is None:
            constraints = {}
        self.name = model.name
        self.model = model.model
        self.data = data
        if self.data.smearer is not None:
            self.data.smearer.model = self.model
        self._define_pars()
        self._init_pars(kw)
        if initial_values is not None:
            self._reset_pars(fitted, initial_values)
        self.constraints = dict(constraints)
        self.set_fitted(fitted)
        self.update()

    def _reset_pars(self, names, values):
        """Assign *values* to the parameters called *names*, pairwise."""
        for k, v in zip(names, values):
            self._pars[k].value = v

    def _define_pars(self):
        """
        Create one fixed bumps parameter per model parameter, named
        ``<fitness name>.<parameter name>``, taking its value from the model
        and its bounds from entries 1 and 2 of ``model.details[k]``
        (``None, None`` when the model has no details for it).
        """
        self._pars = {}
        for k in self.model.getParamList():
            name = ".".join((self.name, k))
            value = self.model.getParam(k)
            bounds = self.model.details.get(k, ["", None, None])[1:3]
            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
                                                fixed=True, name=name)

    def _init_pars(self, kw):
        """
        Apply the keyword initializers given to the constructor.

        For each ``name: v`` pair: dotted names other than ``*.width`` are
        set directly on the model; a bumps parameter replaces the existing
        one; a ``(low, high)`` pair sets the range and uses the midpoint as
        value; anything else is used as the value.

        :raises KeyError: if a name does not match a model parameter.
        """
        for k, v in kw.items():
            # dispersion parameters initialized with _field instead of .field
            if k.endswith('_width'):
                k = k[:-6]+'.width'
            elif k.endswith('_npts'):
                k = k[:-5]+'.npts'
            elif k.endswith('_nsigmas'):
                k = k[:-7]+'.nsigmas'
            elif k.endswith('_type'):
                k = k[:-5]+'.type'
            if k not in self._pars:
                formatted_pars = ", ".join(sorted(self._pars.keys()))
                raise KeyError("invalid parameter %r for %s--use one of: %s"
                               %(k, self.model, formatted_pars))
            if '.' in k and not k.endswith('.width'):
                self.model.setParam(k, v)
            elif isinstance(v, parameter.BaseParameter):
                self._pars[k] = v
            elif isinstance(v, (tuple, list)):
                low, high = v
                self._pars[k].value = (low+high)/2
                self._pars[k].range(low, high)
            else:
                self._pars[k].value = v

    def set_fitted(self, param_list):
        """
        Flag a set of parameters as fitted parameters.

        Parameters in *param_list* that have a constraint stay fixed and are
        recorded as computed parameters instead.
        """
        for k, p in self._pars.items():
            p.fixed = (k not in param_list or k in self.constraints)
        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
        self.computed_par_names = [k for k in param_list if k in self.constraints]
        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
        self.computed_pars = [self._pars[k] for k in self.computed_par_names]

    # ===== Fitness interface ====
    def parameters(self):
        """Return the dictionary of parameter objects, keyed by name."""
        return self._pars

    def update(self):
        """
        Copy every parameter value into the model and mark the cached
        theory/residuals as stale.
        """
        for k, v in self._pars.items():
            self.model.setParam(k, v.value)
        self._dirty = True

    def _recalculate(self):
        """Recompute residuals and theory if the parameters have changed."""
        if self._dirty:
            self._residuals, self._theory \
                = self.data.residuals(self.model.evalDistribution)
            self._dirty = False

    def numpoints(self):
        """Return the sum of ``data.idx``."""
        return np.sum(self.data.idx) # number of fitted points

    def nllf(self):
        """Return half the sum of squared residuals."""
        return 0.5*np.sum(self.residuals()**2)

    def theory(self):
        """Return the theory values for the current parameters."""
        self._recalculate()
        return self._theory

    def residuals(self):
        """Return the residuals for the current parameters."""
        self._recalculate()
        return self._residuals

    # Not implementing the data methods for now:
    #
    #     resynth_data/restore_data/save/plot
239
240    # Not implementing the data methods for now:
241    #
242    #     resynth_data/restore_data/save/plot
243
class ParameterExpressions(object):
    """
    Callable which applies the constraint expressions of a set of fitness
    objects.

    Constraint and parameter names are qualified as
    ``<fitness name>.<parameter name>``.  Only the list of models is
    pickled; the update function is rebuilt on unpickling.
    """
    def __init__(self, models):
        self.models = models
        self._setup()

    def _setup(self):
        """Build ``self.update`` from the constraints of all the models."""
        expressions = {}
        for fitness in self.models:
            for key, expr in fitness.constraints.items():
                expressions[".".join((fitness.name, key))] = expr

        if not expressions:
            # Nothing is constrained, so updating is a no-op.
            self.update = lambda: 0
            return

        symbols = {}
        for fitness in self.models:
            for key, par in fitness.parameters().items():
                symbols[".".join((fitness.name, key))] = par
        self.update = compile_constraints(symbols, expressions)

    def __call__(self):
        self.update()

    def __getstate__(self):
        return self.models

    def __setstate__(self, state):
        self.models = state
        self._setup()
270
class BumpsFit(FitEngine):
    """
    Fit a model using bumps.
    """
    def __init__(self):
        """
        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
        with Uid as keys
        """
        FitEngine.__init__(self)
        self.curr_thread = None

    def fit(self, msg_q=None,
            q=None, handler=None, curr_thread=None,
            ftol=1.49012e-8, reset_flag=False):
        """
        Fit every arrangement in ``self.fit_arrange_dict`` flagged for
        fitting and return one ``FResult`` per model.

        :param msg_q: not used by this method.
        :param q: optional queue; when given, the list of results is put on
            it and the queue itself is returned.
        :param handler: optional progress handler passed on to the fit; its
            ``update_fit(last=True)`` is called when the fit has finished.
        :param curr_thread: optional thread object polled to abort the fit.
        :param ftol: not used by this method.
        :param reset_flag: if true, start the fitted parameters from the
            values stored in each arrangement rather than from the model.

        :return: list of ``FResult``, or *q* if a queue was given.
        :raises RuntimeError: if no arrangement is flagged for fitting.
        """
        # Build collection of bumps fitness calculators
        models = [SasFitness(model=M.get_model(),
                             data=M.get_data(),
                             constraints=M.constraints,
                             fitted=M.pars,
                             initial_values=M.vals if reset_flag else None)
                  for M in self.fit_arrange_dict.values()
                  if M.get_to_fit()]
        if not models:
            raise RuntimeError("Nothing to fit")
        problem = FitProblem(models)

        # TODO: need better handling of parameter expressions and bounds constraints
        # so that they are applied during polydispersity calculations.  This
        # will remove the immediate need for the setp_hook in bumps, though
        # bumps may still need something similar, such as a sane class structure
        # which allows a subclass to override setp.
        problem.setp_hook = ParameterExpressions(models)

        # Run the fit
        result = run_bumps(problem, handler, curr_thread)
        if handler is not None:
            handler.update_fit(last=True)

        # TODO: shouldn't reference internal parameters of fit problem
        varying = problem._parameters
        # collect the results
        all_results = []
        for M in problem.models:
            fitness = M.fitness
            # Position of each of this model's fitted parameters within the
            # problem-wide parameter vector.
            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
            param_list = fitness.fitted_par_names + fitness.computed_par_names
            R = FResult(model=fitness.model, data=fitness.data,
                        param_list=param_list)
            R.theory = fitness.theory()
            R.residuals = fitness.residuals()
            R.index = fitness.data.idx
            R.fitter_id = self.fitter_id
            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
            R.success = result['success']
            # Note: np.nan rather than the np.NaN alias, which was removed
            # in NumPy 2.0.
            if R.success:
                if result['stderr'] is None:
                    R.stderr = np.full(len(param_list), np.nan)
                else:
                    # Computed (constrained) parameters have no uncertainty.
                    R.stderr = np.hstack((result['stderr'][fitted_index],
                                          np.full(len(fitness.computed_pars), np.nan)))
                R.pvec = np.hstack((result['value'][fitted_index],
                                    [p.value for p in fitness.computed_pars]))
                R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
            else:
                R.stderr = np.full(len(param_list), np.nan)
                R.pvec = np.asarray([p.value for p in fitness.fitted_pars+fitness.computed_pars])
                R.fitness = np.nan
            R.convergence = result['convergence']
            if result['uncertainty'] is not None:
                R.uncertainty_state = result['uncertainty']
            all_results.append(R)
        # Error text of the whole fit is attached to the first result only.
        all_results[0].mesg = result['errors']

        if q is not None:
            q.put(all_results)
            return q
        else:
            return all_results
350
def run_bumps(problem, handler, curr_thread):
    """
    Run the currently selected bumps fitter on *problem*.

    :param problem: bumps fit problem.
    :param handler: progress handler, or ``None``.
    :param curr_thread: object whose ``isquit()`` raises
        ``KeyboardInterrupt`` when the fit should stop, or ``None``.

    :return: dictionary with keys ``value`` (best point, ``None`` on
        failure), ``stderr``, ``success``, ``convergence``, ``uncertainty``
        (the fitter's ``state`` attribute if it has one) and ``errors``
        (newline-joined error text, empty if there were no errors).
    """
    def abort_test():
        if curr_thread is None:
            return False
        try:
            curr_thread.isquit()
        except KeyboardInterrupt:
            if handler is not None:
                handler.stop("Fitting: Terminated!!!")
            return True
        return False

    fitclass, options = get_fitter()
    steps = options.get('steps', 0)
    if steps == 0:
        pop = options.get('pop', 0)*len(problem._parameters)
        samples = options.get('samples', 0)
        # Ceiling division: enough steps of `pop` points to collect
        # `samples` points.  Floor division is spelled out so that the
        # result is a whole number whether or not "/" is true division.
        steps = int((samples+pop-1)//pop) if pop != 0 else samples
    max_step = steps + options.get('burn', 0)
    pars = [p.name for p in problem._parameters]
    options['monitors'] = [
        BumpsMonitor(handler, max_step, pars, problem.dof),
        ConvergenceMonitor(),
        ]
    fitdriver = fitters.FitDriver(fitclass, problem=problem,
                                  abort_test=abort_test, **options)
    omp_threads = int(os.environ.get('OMP_NUM_THREADS', '0'))
    mapper = MPMapper if omp_threads == 1 else SerialMapper
    fitdriver.mapper = mapper.start_mapper(problem, None)
    try:
        best, fbest = fitdriver.fit()
        errors = []
    except Exception as exc:
        # Report the failure through the result instead of raising.
        # np.nan rather than the np.NaN alias, which was removed in
        # NumPy 2.0.
        best, fbest = None, np.nan
        errors = [str(exc), traceback.format_exc()]
    finally:
        mapper.stop_mapper(fitdriver.mapper)

    # The convergence monitor is the last of the monitors installed above.
    convergence_list = options['monitors'][-1].convergence
    convergence = (2*np.asarray(convergence_list)/problem.dof
                   if convergence_list else np.empty((0, 1), 'd'))

    success = best is not None
    try:
        stderr = fitdriver.stderr() if success else None
    except Exception as exc:
        errors.append(str(exc))
        errors.append(traceback.format_exc())
        stderr = None
    return {
        'value': best if success else None,
        'stderr': stderr,
        'success': success,
        'convergence': convergence,
        'uncertainty': getattr(fitdriver.fitter, 'state', None),
        'errors': '\n'.join(errors),
        }
Note: See TracBrowser for help on using the repository browser.