source: sasview/src/sas/sascalc/pr/fit/BumpsFitting.py @ cb9640f

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249
Last change on this file since cb9640f was 3e6829d, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

default to floating point division in pr

  • Property mode set to 100644
File size: 13.6 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4from __future__ import division
5
6import os
7from datetime import timedelta, datetime
8
9import numpy as np
10
11from bumps import fitters
12try:
13    from bumps.options import FIT_CONFIG
14    # Default bumps to use the Levenberg-Marquardt optimizer
15    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
16    def get_fitter():
17        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
18except:
19    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
20    # Default bumps to use the Levenberg-Marquardt optimizer
21    fitters.FIT_DEFAULT = 'lm'
22    def get_fitter():
23        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
24        return fitopts.fitclass, fitopts.options.copy()
25
26
27from bumps.mapper import SerialMapper, MPMapper
28from bumps import parameter
29from bumps.fitproblem import FitProblem
30
31
32from sas.sascalc.fit.AbstractFitEngine import FitEngine
33from sas.sascalc.fit.AbstractFitEngine import FResult
34from sas.sascalc.fit.expression import compile_constraints
35
36class Progress(object):
37    def __init__(self, history, max_step, pars, dof):
38        remaining_time = int(history.time[0]*(max_step/history.step[0]-1))
39        # Depending on the time remaining, either display the expected
40        # time of completion, or the amount of time remaining.  Use precision
41        # appropriate for the duration.
42        if remaining_time >= 1800:
43            completion_time = datetime.now() + timedelta(seconds=remaining_time)
44            if remaining_time >= 36000:
45                time = completion_time.strftime('%Y-%m-%d %H:%M')
46            else:
47                time = completion_time.strftime('%H:%M')
48        else:
49            if remaining_time >= 3600:
50                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
51            elif remaining_time >= 60:
52                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
53            else:
54                time = '%ds'%remaining_time
55        chisq = "%.3g"%(2*history.value[0]/dof)
56        step = "%d of %d"%(history.step[0], max_step)
57        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
58        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
59                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
60        self.msg = "".join([header]+parameters)
61
62    def __str__(self):
63        return self.msg
64
65
66class BumpsMonitor(object):
67    def __init__(self, handler, max_step, pars, dof):
68        self.handler = handler
69        self.max_step = max_step
70        self.pars = pars
71        self.dof = dof
72
73    def config_history(self, history):
74        history.requires(time=1, value=2, point=1, step=1)
75
76    def __call__(self, history):
77        if self.handler is None: return
78        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
79        self.handler.progress(history.step[0], self.max_step)
80        if len(history.step)>1 and history.step[1] > history.step[0]:
81            self.handler.improvement()
82        self.handler.update_fit()
83
84class ConvergenceMonitor(object):
85    """
86    ConvergenceMonitor contains population summary statistics to show progress
87    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
88    just a list [ (best, ) ] if population size is 1.
89    """
90    def __init__(self):
91        self.convergence = []
92
93    def config_history(self, history):
94        history.requires(value=1, population_values=1)
95
96    def __call__(self, history):
97        best = history.value[0]
98        try:
99            p = history.population_values[0]
100            n,p = len(p), np.sort(p)
101            QI,Qmid, = int(0.2*n),int(0.5*n)
102            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
103        except:
104            self.convergence.append((best, best,best,best,best,best))
105
106
107# Note: currently using bumps parameters for each parameter object so that
108# a SasFitness can be used directly in bumps with the usual semantics.
109# The disadvantage of this technique is that we need to copy every parameter
110# back into the model each time the function is evaluated.  We could instead
111# define reference parameters for each sas parameter, but then we would not
112# be able to express constraints using python expressions in the usual way
113# from bumps, and would instead need to use string expressions.
114class SasFitness(object):
115    """
116    Wrap SAS model as a bumps fitness object
117    """
118    def __init__(self, model, data, fitted=[], constraints={},
119                 initial_values=None, **kw):
120        self.name = model.name
121        self.model = model.model
122        self.data = data
123        if self.data.smearer is not None:
124            self.data.smearer.model = self.model
125        self._define_pars()
126        self._init_pars(kw)
127        if initial_values is not None:
128            self._reset_pars(fitted, initial_values)
129        self.constraints = dict(constraints)
130        self.set_fitted(fitted)
131        self.update()
132
133    def _reset_pars(self, names, values):
134        for k,v in zip(names, values):
135            self._pars[k].value = v
136
137    def _define_pars(self):
138        self._pars = {}
139        for k in self.model.getParamList():
140            name = ".".join((self.name,k))
141            value = self.model.getParam(k)
142            bounds = self.model.details.get(k,["",None,None])[1:3]
143            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
144                                                fixed=True, name=name)
145        #print parameter.summarize(self._pars.values())
146
147    def _init_pars(self, kw):
148        for k,v in kw.items():
149            # dispersion parameters initialized with _field instead of .field
150            if k.endswith('_width'): k = k[:-6]+'.width'
151            elif k.endswith('_npts'): k = k[:-5]+'.npts'
152            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
153            elif k.endswith('_type'): k = k[:-5]+'.type'
154            if k not in self._pars:
155                formatted_pars = ", ".join(sorted(self._pars.keys()))
156                raise KeyError("invalid parameter %r for %s--use one of: %s"
157                               %(k, self.model, formatted_pars))
158            if '.' in k and not k.endswith('.width'):
159                self.model.setParam(k, v)
160            elif isinstance(v, parameter.BaseParameter):
161                self._pars[k] = v
162            elif isinstance(v, (tuple,list)):
163                low, high = v
164                self._pars[k].value = (low+high)/2
165                self._pars[k].range(low,high)
166            else:
167                self._pars[k].value = v
168
169    def set_fitted(self, param_list):
170        """
171        Flag a set of parameters as fitted parameters.
172        """
173        for k,p in self._pars.items():
174            p.fixed = (k not in param_list or k in self.constraints)
175        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
176        self.computed_par_names = [k for k in param_list if k in self.constraints]
177        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
178        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
179
180    # ===== Fitness interface ====
181    def parameters(self):
182        return self._pars
183
184    def update(self):
185        for k,v in self._pars.items():
186            #print "updating",k,v,v.value
187            self.model.setParam(k,v.value)
188        self._dirty = True
189
190    def _recalculate(self):
191        if self._dirty:
192            self._residuals, self._theory \
193                = self.data.residuals(self.model.evalDistribution)
194            self._dirty = False
195
196    def numpoints(self):
197        return np.sum(self.data.idx) # number of fitted points
198
199    def nllf(self):
200        return 0.5*np.sum(self.residuals()**2)
201
202    def theory(self):
203        self._recalculate()
204        return self._theory
205
206    def residuals(self):
207        self._recalculate()
208        return self._residuals
209
210    # Not implementing the data methods for now:
211    #
212    #     resynth_data/restore_data/save/plot
213
214class ParameterExpressions(object):
215    def __init__(self, models):
216        self.models = models
217        self._setup()
218
219    def _setup(self):
220        exprs = {}
221        for M in self.models:
222            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
223        if exprs:
224            symtab = dict((".".join((M.name, k)), p)
225                          for M in self.models
226                          for k,p in M.parameters().items())
227            self.update = compile_constraints(symtab, exprs)
228        else:
229            self.update = lambda: 0
230
231    def __call__(self):
232        self.update()
233
234    def __getstate__(self):
235        return self.models
236
237    def __setstate__(self, state):
238        self.models = state
239        self._setup()
240
241class BumpsFit(FitEngine):
242    """
243    Fit a model using bumps.
244    """
245    def __init__(self):
246        """
247        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
248        with Uid as keys
249        """
250        FitEngine.__init__(self)
251        self.curr_thread = None
252
253    def fit(self, msg_q=None,
254            q=None, handler=None, curr_thread=None,
255            ftol=1.49012e-8, reset_flag=False):
256        # Build collection of bumps fitness calculators
257        models = [SasFitness(model=M.get_model(),
258                             data=M.get_data(),
259                             constraints=M.constraints,
260                             fitted=M.pars,
261                             initial_values=M.vals if reset_flag else None)
262                  for M in self.fit_arrange_dict.values()
263                  if M.get_to_fit()]
264        if len(models) == 0:
265            raise RuntimeError("Nothing to fit")
266        problem = FitProblem(models)
267
268        # TODO: need better handling of parameter expressions and bounds constraints
269        # so that they are applied during polydispersity calculations.  This
270        # will remove the immediate need for the setp_hook in bumps, though
271        # bumps may still need something similar, such as a sane class structure
272        # which allows a subclass to override setp.
273        problem.setp_hook = ParameterExpressions(models)
274
275        # Run the fit
276        result = run_bumps(problem, handler, curr_thread)
277        if handler is not None:
278            handler.update_fit(last=True)
279
280        # TODO: shouldn't reference internal parameters of fit problem
281        varying = problem._parameters
282        # collect the results
283        all_results = []
284        for M in problem.models:
285            fitness = M.fitness
286            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
287            param_list = fitness.fitted_par_names + fitness.computed_par_names
288            R = FResult(model=fitness.model, data=fitness.data,
289                        param_list=param_list)
290            R.theory = fitness.theory()
291            R.residuals = fitness.residuals()
292            R.index = fitness.data.idx
293            R.fitter_id = self.fitter_id
294            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
295            R.success = result['success']
296            if R.success:
297                R.stderr = np.hstack((result['stderr'][fitted_index],
298                                      np.NaN*np.ones(len(fitness.computed_pars))))
299                R.pvec = np.hstack((result['value'][fitted_index],
300                                      [p.value for p in fitness.computed_pars]))
301                R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
302            else:
303                R.stderr = np.NaN*np.ones(len(param_list))
304                R.pvec = np.asarray( [p.value for p in fitness.fitted_pars+fitness.computed_pars])
305                R.fitness = np.NaN
306            R.convergence = result['convergence']
307            if result['uncertainty'] is not None:
308                R.uncertainty_state = result['uncertainty']
309            all_results.append(R)
310
311        if q is not None:
312            q.put(all_results)
313            return q
314        else:
315            return all_results
316
317def run_bumps(problem, handler, curr_thread):
318    def abort_test():
319        if curr_thread is None: return False
320        try: curr_thread.isquit()
321        except KeyboardInterrupt:
322            if handler is not None:
323                handler.stop("Fitting: Terminated!!!")
324            return True
325        return False
326
327    fitclass, options = get_fitter()
328    steps = options.get('steps', 0)
329    if steps == 0:
330        pop = options.get('pop',0)*len(problem._parameters)
331        samples = options.get('samples', 0)
332        steps = (samples+pop-1)/pop if pop != 0 else samples
333    max_step = steps + options.get('burn', 0)
334    pars = [p.name for p in problem._parameters]
335    #x0 = np.asarray([p.value for p in problem._parameters])
336    options['monitors'] = [
337        BumpsMonitor(handler, max_step, pars, problem.dof),
338        ConvergenceMonitor(),
339        ]
340    fitdriver = fitters.FitDriver(fitclass, problem=problem,
341                                  abort_test=abort_test, **options)
342    omp_threads = int(os.environ.get('OMP_NUM_THREADS','0'))
343    mapper = MPMapper if omp_threads == 1 else SerialMapper
344    fitdriver.mapper = mapper.start_mapper(problem, None)
345    #import time; T0 = time.time()
346    try:
347        best, fbest = fitdriver.fit()
348    except:
349        import traceback; traceback.print_exc()
350        raise
351    finally:
352        mapper.stop_mapper(fitdriver.mapper)
353
354
355    convergence_list = options['monitors'][-1].convergence
356    convergence = (2*np.asarray(convergence_list)/problem.dof
357                   if convergence_list else np.empty((0,1),'d'))
358
359    success = best is not None
360    return {
361        'value': best if success else None,
362        'stderr': fitdriver.stderr() if success else None,
363        'success': success,
364        'convergence': convergence,
365        'uncertainty': getattr(fitdriver.fitter, 'state', None),
366        }
367
Note: See TracBrowser for help on using the repository browser.