source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ 3a08f369

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 3a08f369 was 9a5097c, checked in by andyfaff, 8 years ago

MAINT: import numpy as np

  • Property mode set to 100644
File size: 14.0 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import os
5from datetime import timedelta, datetime
6import traceback
7
8import numpy as np
9
10from bumps import fitters
11try:
12    from bumps.options import FIT_CONFIG
13    # Default bumps to use the Levenberg-Marquardt optimizer
14    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
15    def get_fitter():
16        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
17except:
18    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
19    # Default bumps to use the Levenberg-Marquardt optimizer
20    fitters.FIT_DEFAULT = 'lm'
21    def get_fitter():
22        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
23        return fitopts.fitclass, fitopts.options.copy()
24
25
26from bumps.mapper import SerialMapper, MPMapper
27from bumps import parameter
28from bumps.fitproblem import FitProblem
29
30
31from sas.sascalc.fit.AbstractFitEngine import FitEngine
32from sas.sascalc.fit.AbstractFitEngine import FResult
33from sas.sascalc.fit.expression import compile_constraints
34
35class Progress(object):
36    def __init__(self, history, max_step, pars, dof):
37        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
38        # Depending on the time remaining, either display the expected
39        # time of completion, or the amount of time remaining.  Use precision
40        # appropriate for the duration.
41        if remaining_time >= 1800:
42            completion_time = datetime.now() + timedelta(seconds=remaining_time)
43            if remaining_time >= 36000:
44                time = completion_time.strftime('%Y-%m-%d %H:%M')
45            else:
46                time = completion_time.strftime('%H:%M')
47        else:
48            if remaining_time >= 3600:
49                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
50            elif remaining_time >= 60:
51                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
52            else:
53                time = '%ds'%remaining_time
54        chisq = "%.3g"%(2*history.value[0]/dof)
55        step = "%d of %d"%(history.step[0], max_step)
56        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
57        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
58                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
59        self.msg = "".join([header]+parameters)
60
61    def __str__(self):
62        return self.msg
63
64
65class BumpsMonitor(object):
66    def __init__(self, handler, max_step, pars, dof):
67        self.handler = handler
68        self.max_step = max_step
69        self.pars = pars
70        self.dof = dof
71
72    def config_history(self, history):
73        history.requires(time=1, value=2, point=1, step=1)
74
75    def __call__(self, history):
76        if self.handler is None: return
77        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
78        self.handler.progress(history.step[0], self.max_step)
79        if len(history.step)>1 and history.step[1] > history.step[0]:
80            self.handler.improvement()
81        self.handler.update_fit()
82
83class ConvergenceMonitor(object):
84    """
85    ConvergenceMonitor contains population summary statistics to show progress
86    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
87    just a list [ (best, ) ] if population size is 1.
88    """
89    def __init__(self):
90        self.convergence = []
91
92    def config_history(self, history):
93        history.requires(value=1, population_values=1)
94
95    def __call__(self, history):
96        best = history.value[0]
97        try:
98            p = history.population_values[0]
99            n,p = len(p), np.sort(p)
100            QI,Qmid, = int(0.2*n),int(0.5*n)
101            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
102        except:
103            self.convergence.append((best, best,best,best,best,best))
104
105
106# Note: currently using bumps parameters for each parameter object so that
107# a SasFitness can be used directly in bumps with the usual semantics.
108# The disadvantage of this technique is that we need to copy every parameter
109# back into the model each time the function is evaluated.  We could instead
110# define reference parameters for each sas parameter, but then we would not
111# be able to express constraints using python expressions in the usual way
112# from bumps, and would instead need to use string expressions.
113class SasFitness(object):
114    """
115    Wrap SAS model as a bumps fitness object
116    """
117    def __init__(self, model, data, fitted=[], constraints={},
118                 initial_values=None, **kw):
119        self.name = model.name
120        self.model = model.model
121        self.data = data
122        if self.data.smearer is not None:
123            self.data.smearer.model = self.model
124        self._define_pars()
125        self._init_pars(kw)
126        if initial_values is not None:
127            self._reset_pars(fitted, initial_values)
128        self.constraints = dict(constraints)
129        self.set_fitted(fitted)
130        self.update()
131
132    def _reset_pars(self, names, values):
133        for k,v in zip(names, values):
134            self._pars[k].value = v
135
136    def _define_pars(self):
137        self._pars = {}
138        for k in self.model.getParamList():
139            name = ".".join((self.name,k))
140            value = self.model.getParam(k)
141            bounds = self.model.details.get(k,["",None,None])[1:3]
142            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
143                                                fixed=True, name=name)
144        #print parameter.summarize(self._pars.values())
145
146    def _init_pars(self, kw):
147        for k,v in kw.items():
148            # dispersion parameters initialized with _field instead of .field
149            if k.endswith('_width'): k = k[:-6]+'.width'
150            elif k.endswith('_npts'): k = k[:-5]+'.npts'
151            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
152            elif k.endswith('_type'): k = k[:-5]+'.type'
153            if k not in self._pars:
154                formatted_pars = ", ".join(sorted(self._pars.keys()))
155                raise KeyError("invalid parameter %r for %s--use one of: %s"
156                               %(k, self.model, formatted_pars))
157            if '.' in k and not k.endswith('.width'):
158                self.model.setParam(k, v)
159            elif isinstance(v, parameter.BaseParameter):
160                self._pars[k] = v
161            elif isinstance(v, (tuple,list)):
162                low, high = v
163                self._pars[k].value = (low+high)/2
164                self._pars[k].range(low,high)
165            else:
166                self._pars[k].value = v
167
168    def set_fitted(self, param_list):
169        """
170        Flag a set of parameters as fitted parameters.
171        """
172        for k,p in self._pars.items():
173            p.fixed = (k not in param_list or k in self.constraints)
174        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
175        self.computed_par_names = [k for k in param_list if k in self.constraints]
176        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
177        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
178
179    # ===== Fitness interface ====
180    def parameters(self):
181        return self._pars
182
183    def update(self):
184        for k,v in self._pars.items():
185            #print "updating",k,v,v.value
186            self.model.setParam(k,v.value)
187        self._dirty = True
188
189    def _recalculate(self):
190        if self._dirty:
191            self._residuals, self._theory \
192                = self.data.residuals(self.model.evalDistribution)
193            self._dirty = False
194
195    def numpoints(self):
196        return np.sum(self.data.idx) # number of fitted points
197
198    def nllf(self):
199        return 0.5*np.sum(self.residuals()**2)
200
201    def theory(self):
202        self._recalculate()
203        return self._theory
204
205    def residuals(self):
206        self._recalculate()
207        return self._residuals
208
209    # Not implementing the data methods for now:
210    #
211    #     resynth_data/restore_data/save/plot
212
213class ParameterExpressions(object):
214    def __init__(self, models):
215        self.models = models
216        self._setup()
217
218    def _setup(self):
219        exprs = {}
220        for M in self.models:
221            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
222        if exprs:
223            symtab = dict((".".join((M.name, k)), p)
224                          for M in self.models
225                          for k,p in M.parameters().items())
226            self.update = compile_constraints(symtab, exprs)
227        else:
228            self.update = lambda: 0
229
230    def __call__(self):
231        self.update()
232
233    def __getstate__(self):
234        return self.models
235
236    def __setstate__(self, state):
237        self.models = state
238        self._setup()
239
240class BumpsFit(FitEngine):
241    """
242    Fit a model using bumps.
243    """
244    def __init__(self):
245        """
246        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
247        with Uid as keys
248        """
249        FitEngine.__init__(self)
250        self.curr_thread = None
251
252    def fit(self, msg_q=None,
253            q=None, handler=None, curr_thread=None,
254            ftol=1.49012e-8, reset_flag=False):
255        # Build collection of bumps fitness calculators
256        models = [SasFitness(model=M.get_model(),
257                             data=M.get_data(),
258                             constraints=M.constraints,
259                             fitted=M.pars,
260                             initial_values=M.vals if reset_flag else None)
261                  for M in self.fit_arrange_dict.values()
262                  if M.get_to_fit()]
263        if len(models) == 0:
264            raise RuntimeError("Nothing to fit")
265        problem = FitProblem(models)
266
267        # TODO: need better handling of parameter expressions and bounds constraints
268        # so that they are applied during polydispersity calculations.  This
269        # will remove the immediate need for the setp_hook in bumps, though
270        # bumps may still need something similar, such as a sane class structure
271        # which allows a subclass to override setp.
272        problem.setp_hook = ParameterExpressions(models)
273
274        # Run the fit
275        result = run_bumps(problem, handler, curr_thread)
276        if handler is not None:
277            handler.update_fit(last=True)
278
279        # TODO: shouldn't reference internal parameters of fit problem
280        varying = problem._parameters
281        # collect the results
282        all_results = []
283        for M in problem.models:
284            fitness = M.fitness
285            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
286            param_list = fitness.fitted_par_names + fitness.computed_par_names
287            R = FResult(model=fitness.model, data=fitness.data,
288                        param_list=param_list)
289            R.theory = fitness.theory()
290            R.residuals = fitness.residuals()
291            R.index = fitness.data.idx
292            R.fitter_id = self.fitter_id
293            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
294            R.success = result['success']
295            if R.success:
296                if result['stderr'] is None:
297                    R.stderr = np.NaN*np.ones(len(param_list))
298                else:
299                    R.stderr = np.hstack((result['stderr'][fitted_index],
300                                          np.NaN*np.ones(len(fitness.computed_pars))))
301                R.pvec = np.hstack((result['value'][fitted_index],
302                                      [p.value for p in fitness.computed_pars]))
303                R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
304            else:
305                R.stderr = np.NaN*np.ones(len(param_list))
306                R.pvec = np.asarray( [p.value for p in fitness.fitted_pars+fitness.computed_pars])
307                R.fitness = np.NaN
308            R.convergence = result['convergence']
309            if result['uncertainty'] is not None:
310                R.uncertainty_state = result['uncertainty']
311            all_results.append(R)
312        all_results[0].mesg = result['errors']
313
314        if q is not None:
315            q.put(all_results)
316            return q
317        else:
318            return all_results
319
320def run_bumps(problem, handler, curr_thread):
321    def abort_test():
322        if curr_thread is None: return False
323        try: curr_thread.isquit()
324        except KeyboardInterrupt:
325            if handler is not None:
326                handler.stop("Fitting: Terminated!!!")
327            return True
328        return False
329
330    fitclass, options = get_fitter()
331    steps = options.get('steps', 0)
332    if steps == 0:
333        pop = options.get('pop',0)*len(problem._parameters)
334        samples = options.get('samples', 0)
335        steps = (samples+pop-1)/pop if pop != 0 else samples
336    max_step = steps + options.get('burn', 0)
337    pars = [p.name for p in problem._parameters]
338    #x0 = np.asarray([p.value for p in problem._parameters])
339    options['monitors'] = [
340        BumpsMonitor(handler, max_step, pars, problem.dof),
341        ConvergenceMonitor(),
342        ]
343    fitdriver = fitters.FitDriver(fitclass, problem=problem,
344                                  abort_test=abort_test, **options)
345    omp_threads = int(os.environ.get('OMP_NUM_THREADS','0'))
346    mapper = MPMapper if omp_threads == 1 else SerialMapper
347    fitdriver.mapper = mapper.start_mapper(problem, None)
348    #import time; T0 = time.time()
349    try:
350        best, fbest = fitdriver.fit()
351        errors = []
352    except Exception as exc:
353        best, fbest = None, np.NaN
354        errors = [str(exc), traceback.format_exc()]
355    finally:
356        mapper.stop_mapper(fitdriver.mapper)
357
358
359    convergence_list = options['monitors'][-1].convergence
360    convergence = (2*np.asarray(convergence_list)/problem.dof
361                   if convergence_list else np.empty((0,1),'d'))
362
363    success = best is not None
364    try:
365        stderr = fitdriver.stderr() if success else None
366    except Exception as exc:
367        errors.append(str(exc))
368        errors.append(traceback.format_exc())
369        stderr = None
370    return {
371        'value': best if success else None,
372        'stderr': stderr,
373        'success': success,
374        'convergence': convergence,
375        'uncertainty': getattr(fitdriver.fitter, 'state', None),
376        'errors': '\n'.join(errors),
377        }
378
Note: See TracBrowser for help on using the repository browser.