source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ 41d6187

Last change on this file since 41d6187 was 1386b2f, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

lint

  • Property mode set to 100644
File size: 14.1 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import os
5from datetime import timedelta, datetime
6import traceback
7
8import numpy as np
9
10from bumps import fitters
11try:
12    from bumps.options import FIT_CONFIG
13    # Default bumps to use the Levenberg-Marquardt optimizer
14    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
15    def get_fitter():
16        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
17except ImportError:
18    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
19    # Default bumps to use the Levenberg-Marquardt optimizer
20    fitters.FIT_DEFAULT = 'lm'
21    def get_fitter():
22        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
23        return fitopts.fitclass, fitopts.options.copy()
24
25
26from bumps.mapper import SerialMapper, MPMapper
27from bumps import parameter
28from bumps.fitproblem import FitProblem
29
30
31from sas.sascalc.fit.AbstractFitEngine import FitEngine
32from sas.sascalc.fit.AbstractFitEngine import FResult
33from sas.sascalc.fit.expression import compile_constraints
34
35class Progress(object):
36    def __init__(self, history, max_step, pars, dof):
37        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
38        # Depending on the time remaining, either display the expected
39        # time of completion, or the amount of time remaining.  Use precision
40        # appropriate for the duration.
41        if remaining_time >= 1800:
42            completion_time = datetime.now() + timedelta(seconds=remaining_time)
43            if remaining_time >= 36000:
44                time = completion_time.strftime('%Y-%m-%d %H:%M')
45            else:
46                time = completion_time.strftime('%H:%M')
47        else:
48            if remaining_time >= 3600:
49                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
50            elif remaining_time >= 60:
51                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
52            else:
53                time = '%ds'%remaining_time
54        chisq = "%.3g"%(2*history.value[0]/dof)
55        step = "%d of %d"%(history.step[0], max_step)
56        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
57        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
58                      for i, (k, v) in enumerate(zip(pars, history.point[0]))]
59        self.msg = "".join([header]+parameters)
60
61    def __str__(self):
62        return self.msg
63
64
65class BumpsMonitor(object):
66    def __init__(self, handler, max_step, pars, dof):
67        self.handler = handler
68        self.max_step = max_step
69        self.pars = pars
70        self.dof = dof
71
72    def config_history(self, history):
73        history.requires(time=1, value=2, point=1, step=1)
74
75    def __call__(self, history):
76        if self.handler is None: return
77        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
78        self.handler.progress(history.step[0], self.max_step)
79        if len(history.step) > 1 and history.step[1] > history.step[0]:
80            self.handler.improvement()
81        self.handler.update_fit()
82
83class ConvergenceMonitor(object):
84    """
85    ConvergenceMonitor contains population summary statistics to show progress
86    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
87    just a list [ (best, ) ] if population size is 1.
88    """
89    def __init__(self):
90        self.convergence = []
91
92    def config_history(self, history):
93        history.requires(value=1, population_values=1)
94
95    def __call__(self, history):
96        best = history.value[0]
97        try:
98            p = history.population_values[0]
99            n, p = len(p), np.sort(p)
100            QI, Qmid = int(0.2*n), int(0.5*n)
101            self.convergence.append((best, p[0], p[QI], p[Qmid], p[-1-QI], p[-1]))
102        except Exception:
103            self.convergence.append((best, best, best, best, best, best))
104
105
106# Note: currently using bumps parameters for each parameter object so that
107# a SasFitness can be used directly in bumps with the usual semantics.
108# The disadvantage of this technique is that we need to copy every parameter
109# back into the model each time the function is evaluated.  We could instead
110# define reference parameters for each sas parameter, but then we would not
111# be able to express constraints using python expressions in the usual way
112# from bumps, and would instead need to use string expressions.
113class SasFitness(object):
114    """
115    Wrap SAS model as a bumps fitness object
116    """
117    def __init__(self, model, data, fitted=[], constraints={},
118                 initial_values=None, **kw):
119        self.name = model.name
120        self.model = model.model
121        self.data = data
122        if self.data.smearer is not None:
123            self.data.smearer.model = self.model
124        self._define_pars()
125        self._init_pars(kw)
126        if initial_values is not None:
127            self._reset_pars(fitted, initial_values)
128        self.constraints = dict(constraints)
129        self.set_fitted(fitted)
130        self.update()
131
132    def _reset_pars(self, names, values):
133        for k, v in zip(names, values):
134            self._pars[k].value = v
135
136    def _define_pars(self):
137        self._pars = {}
138        for k in self.model.getParamList():
139            name = ".".join((self.name, k))
140            value = self.model.getParam(k)
141            bounds = self.model.details.get(k, ["", None, None])[1:3]
142            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
143                                                fixed=True, name=name)
144        #print parameter.summarize(self._pars.values())
145
146    def _init_pars(self, kw):
147        for k, v in kw.items():
148            # dispersion parameters initialized with _field instead of .field
149            if k.endswith('_width'):
150                k = k[:-6]+'.width'
151            elif k.endswith('_npts'):
152                k = k[:-5]+'.npts'
153            elif k.endswith('_nsigmas'):
154                k = k[:-7]+'.nsigmas'
155            elif k.endswith('_type'):
156                k = k[:-5]+'.type'
157            if k not in self._pars:
158                formatted_pars = ", ".join(sorted(self._pars.keys()))
159                raise KeyError("invalid parameter %r for %s--use one of: %s"
160                               %(k, self.model, formatted_pars))
161            if '.' in k and not k.endswith('.width'):
162                self.model.setParam(k, v)
163            elif isinstance(v, parameter.BaseParameter):
164                self._pars[k] = v
165            elif isinstance(v, (tuple, list)):
166                low, high = v
167                self._pars[k].value = (low+high)/2
168                self._pars[k].range(low, high)
169            else:
170                self._pars[k].value = v
171
172    def set_fitted(self, param_list):
173        """
174        Flag a set of parameters as fitted parameters.
175        """
176        for k, p in self._pars.items():
177            p.fixed = (k not in param_list or k in self.constraints)
178        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
179        self.computed_par_names = [k for k in param_list if k in self.constraints]
180        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
181        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
182
183    # ===== Fitness interface ====
184    def parameters(self):
185        return self._pars
186
187    def update(self):
188        for k, v in self._pars.items():
189            #print "updating",k,v,v.value
190            self.model.setParam(k, v.value)
191        self._dirty = True
192
193    def _recalculate(self):
194        if self._dirty:
195            self._residuals, self._theory \
196                = self.data.residuals(self.model.evalDistribution)
197            self._dirty = False
198
199    def numpoints(self):
200        return np.sum(self.data.idx) # number of fitted points
201
202    def nllf(self):
203        return 0.5*np.sum(self.residuals()**2)
204
205    def theory(self):
206        self._recalculate()
207        return self._theory
208
209    def residuals(self):
210        self._recalculate()
211        return self._residuals
212
213    # Not implementing the data methods for now:
214    #
215    #     resynth_data/restore_data/save/plot
216
217class ParameterExpressions(object):
218    def __init__(self, models):
219        self.models = models
220        self._setup()
221
222    def _setup(self):
223        exprs = {}
224        for M in self.models:
225            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
226        if exprs:
227            symtab = dict((".".join((M.name, k)), p)
228                          for M in self.models
229                          for k, p in M.parameters().items())
230            self.update = compile_constraints(symtab, exprs)
231        else:
232            self.update = lambda: 0
233
234    def __call__(self):
235        self.update()
236
237    def __getstate__(self):
238        return self.models
239
240    def __setstate__(self, state):
241        self.models = state
242        self._setup()
243
244class BumpsFit(FitEngine):
245    """
246    Fit a model using bumps.
247    """
248    def __init__(self):
249        """
250        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
251        with Uid as keys
252        """
253        FitEngine.__init__(self)
254        self.curr_thread = None
255
256    def fit(self, msg_q=None,
257            q=None, handler=None, curr_thread=None,
258            ftol=1.49012e-8, reset_flag=False):
259        # Build collection of bumps fitness calculators
260        models = [SasFitness(model=M.get_model(),
261                             data=M.get_data(),
262                             constraints=M.constraints,
263                             fitted=M.pars,
264                             initial_values=M.vals if reset_flag else None)
265                  for M in self.fit_arrange_dict.values()
266                  if M.get_to_fit()]
267        if len(models) == 0:
268            raise RuntimeError("Nothing to fit")
269        problem = FitProblem(models)
270
271        # TODO: need better handling of parameter expressions and bounds constraints
272        # so that they are applied during polydispersity calculations.  This
273        # will remove the immediate need for the setp_hook in bumps, though
274        # bumps may still need something similar, such as a sane class structure
275        # which allows a subclass to override setp.
276        problem.setp_hook = ParameterExpressions(models)
277
278        # Run the fit
279        result = run_bumps(problem, handler, curr_thread)
280        if handler is not None:
281            handler.update_fit(last=True)
282
283        # TODO: shouldn't reference internal parameters of fit problem
284        varying = problem._parameters
285        # collect the results
286        all_results = []
287        for M in problem.models:
288            fitness = M.fitness
289            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
290            param_list = fitness.fitted_par_names + fitness.computed_par_names
291            R = FResult(model=fitness.model, data=fitness.data,
292                        param_list=param_list)
293            R.theory = fitness.theory()
294            R.residuals = fitness.residuals()
295            R.index = fitness.data.idx
296            R.fitter_id = self.fitter_id
297            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
298            R.success = result['success']
299            if R.success:
300                if result['stderr'] is None:
301                    R.stderr = np.NaN*np.ones(len(param_list))
302                else:
303                    R.stderr = np.hstack((result['stderr'][fitted_index],
304                                          np.NaN*np.ones(len(fitness.computed_pars))))
305                R.pvec = np.hstack((result['value'][fitted_index],
306                                    [p.value for p in fitness.computed_pars]))
307                R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
308            else:
309                R.stderr = np.NaN*np.ones(len(param_list))
310                R.pvec = np.asarray([p.value for p in fitness.fitted_pars+fitness.computed_pars])
311                R.fitness = np.NaN
312            R.convergence = result['convergence']
313            if result['uncertainty'] is not None:
314                R.uncertainty_state = result['uncertainty']
315            all_results.append(R)
316        all_results[0].mesg = result['errors']
317
318        if q is not None:
319            q.put(all_results)
320            return q
321        else:
322            return all_results
323
324def run_bumps(problem, handler, curr_thread):
325    def abort_test():
326        if curr_thread is None: return False
327        try: curr_thread.isquit()
328        except KeyboardInterrupt:
329            if handler is not None:
330                handler.stop("Fitting: Terminated!!!")
331            return True
332        return False
333
334    fitclass, options = get_fitter()
335    steps = options.get('steps', 0)
336    if steps == 0:
337        pop = options.get('pop', 0)*len(problem._parameters)
338        samples = options.get('samples', 0)
339        steps = (samples+pop-1)/pop if pop != 0 else samples
340    max_step = steps + options.get('burn', 0)
341    pars = [p.name for p in problem._parameters]
342    #x0 = np.asarray([p.value for p in problem._parameters])
343    options['monitors'] = [
344        BumpsMonitor(handler, max_step, pars, problem.dof),
345        ConvergenceMonitor(),
346        ]
347    fitdriver = fitters.FitDriver(fitclass, problem=problem,
348                                  abort_test=abort_test, **options)
349    omp_threads = int(os.environ.get('OMP_NUM_THREADS', '0'))
350    mapper = MPMapper if omp_threads == 1 else SerialMapper
351    fitdriver.mapper = mapper.start_mapper(problem, None)
352    #import time; T0 = time.time()
353    try:
354        best, fbest = fitdriver.fit()
355        errors = []
356    except Exception as exc:
357        best, fbest = None, np.NaN
358        errors = [str(exc), traceback.format_exc()]
359    finally:
360        mapper.stop_mapper(fitdriver.mapper)
361
362
363    convergence_list = options['monitors'][-1].convergence
364    convergence = (2*np.asarray(convergence_list)/problem.dof
365                   if convergence_list else np.empty((0, 1), 'd'))
366
367    success = best is not None
368    try:
369        stderr = fitdriver.stderr() if success else None
370    except Exception as exc:
371        errors.append(str(exc))
372        errors.append(traceback.format_exc())
373        stderr = None
374    return {
375        'value': best if success else None,
376        'stderr': stderr,
377        'success': success,
378        'convergence': convergence,
379        'uncertainty': getattr(fitdriver.fitter, 'state', None),
380        'errors': '\n'.join(errors),
381        }
Note: See TracBrowser for help on using the repository browser.