source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ 7677b4d

Last change on this file since 7677b4d was 345e7e4, checked in by GitHub <noreply@…>, 8 years ago

Revert "Jurtest2"

  • Property mode set to 100644
File size: 14.1 KB
RevLine 
[6fe5100]1"""
2BumpsFitting module runs the bumps optimizer.
3"""
[249a7c6]4import os
[35086c3]5from datetime import timedelta, datetime
[1a5d5f2]6import traceback
[35086c3]7
[6fe5100]8import numpy
9
10from bumps import fitters
[7945367]11try:
12    from bumps.options import FIT_CONFIG
13    # Default bumps to use the Levenberg-Marquardt optimizer
14    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
15    def get_fitter():
16        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
17except:
18    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
19    # Default bumps to use the Levenberg-Marquardt optimizer
20    fitters.FIT_DEFAULT = 'lm'
21    def get_fitter():
22        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
23        return fitopts.fitclass, fitopts.options.copy()
24
25
[249a7c6]26from bumps.mapper import SerialMapper, MPMapper
[e3efa6b3]27from bumps import parameter
28from bumps.fitproblem import FitProblem
[7945367]29
[345e7e4]30
[b699768]31from sas.sascalc.fit.AbstractFitEngine import FitEngine
32from sas.sascalc.fit.AbstractFitEngine import FResult
33from sas.sascalc.fit.expression import compile_constraints
[6fe5100]34
[35086c3]35class Progress(object):
36    def __init__(self, history, max_step, pars, dof):
37        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
38        # Depending on the time remaining, either display the expected
39        # time of completion, or the amount of time remaining.  Use precision
40        # appropriate for the duration.
41        if remaining_time >= 1800:
42            completion_time = datetime.now() + timedelta(seconds=remaining_time)
43            if remaining_time >= 36000:
44                time = completion_time.strftime('%Y-%m-%d %H:%M')
45            else:
46                time = completion_time.strftime('%H:%M')
47        else:
48            if remaining_time >= 3600:
49                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
50            elif remaining_time >= 60:
51                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
52            else:
53                time = '%ds'%remaining_time
54        chisq = "%.3g"%(2*history.value[0]/dof)
55        step = "%d of %d"%(history.step[0], max_step)
56        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
57        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
58                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
59        self.msg = "".join([header]+parameters)
60
61    def __str__(self):
62        return self.msg
63
64
[85f17f6]65class BumpsMonitor(object):
[35086c3]66    def __init__(self, handler, max_step, pars, dof):
[85f17f6]67        self.handler = handler
68        self.max_step = max_step
[35086c3]69        self.pars = pars
70        self.dof = dof
[ed4aef2]71
[85f17f6]72    def config_history(self, history):
73        history.requires(time=1, value=2, point=1, step=1)
[ed4aef2]74
[85f17f6]75    def __call__(self, history):
[e3efa6b3]76        if self.handler is None: return
[35086c3]77        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
[85f17f6]78        self.handler.progress(history.step[0], self.max_step)
79        if len(history.step)>1 and history.step[1] > history.step[0]:
80            self.handler.improvement()
81        self.handler.update_fit()
82
[ed4aef2]83class ConvergenceMonitor(object):
84    """
85    ConvergenceMonitor contains population summary statistics to show progress
86    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
87    just a list [ (best, ) ] if population size is 1.
88    """
89    def __init__(self):
90        self.convergence = []
91
92    def config_history(self, history):
93        history.requires(value=1, population_values=1)
94
95    def __call__(self, history):
96        best = history.value[0]
97        try:
98            p = history.population_values[0]
99            n,p = len(p), numpy.sort(p)
100            QI,Qmid, = int(0.2*n),int(0.5*n)
101            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
102        except:
[e3efa6b3]103            self.convergence.append((best, best,best,best,best,best))
[ed4aef2]104
[e3efa6b3]105
[4e9f227]106# Note: currently using bumps parameters for each parameter object so that
107# a SasFitness can be used directly in bumps with the usual semantics.
108# The disadvantage of this technique is that we need to copy every parameter
109# back into the model each time the function is evaluated.  We could instead
[fd5ac0d]110# define reference parameters for each sas parameter, but then we would not
[4e9f227]111# be able to express constraints using python expressions in the usual way
112# from bumps, and would instead need to use string expressions.
[e3efa6b3]113class SasFitness(object):
[6fe5100]114    """
[e3efa6b3]115    Wrap SAS model as a bumps fitness object
[6fe5100]116    """
[5044543]117    def __init__(self, model, data, fitted=[], constraints={},
118                 initial_values=None, **kw):
[4e9f227]119        self.name = model.name
120        self.model = model.model
[6fe5100]121        self.data = data
[9f7fbd9]122        if self.data.smearer is not None:
123            self.data.smearer.model = self.model
[e3efa6b3]124        self._define_pars()
125        self._init_pars(kw)
[5044543]126        if initial_values is not None:
127            self._reset_pars(fitted, initial_values)
[4e9f227]128        self.constraints = dict(constraints)
[e3efa6b3]129        self.set_fitted(fitted)
[5044543]130        self.update()
131
132    def _reset_pars(self, names, values):
133        for k,v in zip(names, values):
134            self._pars[k].value = v
[e3efa6b3]135
136    def _define_pars(self):
137        self._pars = {}
138        for k in self.model.getParamList():
139            name = ".".join((self.name,k))
140            value = self.model.getParam(k)
141            bounds = self.model.details.get(k,["",None,None])[1:3]
142            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
143                                                fixed=True, name=name)
[4e9f227]144        #print parameter.summarize(self._pars.values())
[e3efa6b3]145
146    def _init_pars(self, kw):
147        for k,v in kw.items():
148            # dispersion parameters initialized with _field instead of .field
149            if k.endswith('_width'): k = k[:-6]+'.width'
150            elif k.endswith('_npts'): k = k[:-5]+'.npts'
151            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
152            elif k.endswith('_type'): k = k[:-5]+'.type'
153            if k not in self._pars:
154                formatted_pars = ", ".join(sorted(self._pars.keys()))
155                raise KeyError("invalid parameter %r for %s--use one of: %s"
156                               %(k, self.model, formatted_pars))
157            if '.' in k and not k.endswith('.width'):
158                self.model.setParam(k, v)
159            elif isinstance(v, parameter.BaseParameter):
160                self._pars[k] = v
161            elif isinstance(v, (tuple,list)):
162                low, high = v
163                self._pars[k].value = (low+high)/2
164                self._pars[k].range(low,high)
[95d58d3]165            else:
[e3efa6b3]166                self._pars[k].value = v
167
168    def set_fitted(self, param_list):
[6fe5100]169        """
[e3efa6b3]170        Flag a set of parameters as fitted parameters.
[6fe5100]171        """
[e3efa6b3]172        for k,p in self._pars.items():
[4e9f227]173            p.fixed = (k not in param_list or k in self.constraints)
[4a0dc427]174        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
[bf5e985]175        self.computed_par_names = [k for k in param_list if k in self.constraints]
176        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
177        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
[6fe5100]178
[e3efa6b3]179    # ===== Fitness interface ====
180    def parameters(self):
181        return self._pars
[6fe5100]182
[e3efa6b3]183    def update(self):
184        for k,v in self._pars.items():
[4e9f227]185            #print "updating",k,v,v.value
[e3efa6b3]186            self.model.setParam(k,v.value)
187        self._dirty = True
[6fe5100]188
[e3efa6b3]189    def _recalculate(self):
190        if self._dirty:
[9f7fbd9]191            self._residuals, self._theory \
192                = self.data.residuals(self.model.evalDistribution)
[e3efa6b3]193            self._dirty = False
[6fe5100]194
[e3efa6b3]195    def numpoints(self):
196        return numpy.sum(self.data.idx) # number of fitted points
[6fe5100]197
[e3efa6b3]198    def nllf(self):
199        return 0.5*numpy.sum(self.residuals()**2)
200
201    def theory(self):
202        self._recalculate()
203        return self._theory
204
205    def residuals(self):
206        self._recalculate()
207        return self._residuals
208
209    # Not implementing the data methods for now:
210    #
211    #     resynth_data/restore_data/save/plot
[6fe5100]212
[191c648]213class ParameterExpressions(object):
214    def __init__(self, models):
215        self.models = models
216        self._setup()
217
218    def _setup(self):
219        exprs = {}
220        for M in self.models:
221            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
222        if exprs:
223            symtab = dict((".".join((M.name, k)), p)
224                          for M in self.models
225                          for k,p in M.parameters().items())
226            self.update = compile_constraints(symtab, exprs)
227        else:
228            self.update = lambda: 0
229
230    def __call__(self):
231        self.update()
232
233    def __getstate__(self):
234        return self.models
235
236    def __setstate__(self, state):
237        self.models = state
238        self._setup()
239
[6fe5100]240class BumpsFit(FitEngine):
241    """
242    Fit a model using bumps.
243    """
244    def __init__(self):
245        """
246        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
247        with Uid as keys
248        """
249        FitEngine.__init__(self)
250        self.curr_thread = None
251
252    def fit(self, msg_q=None,
253            q=None, handler=None, curr_thread=None,
254            ftol=1.49012e-8, reset_flag=False):
[e3efa6b3]255        # Build collection of bumps fitness calculators
[bf5e985]256        models = [SasFitness(model=M.get_model(),
257                             data=M.get_data(),
258                             constraints=M.constraints,
[5044543]259                             fitted=M.pars,
260                             initial_values=M.vals if reset_flag else None)
[bf5e985]261                  for M in self.fit_arrange_dict.values()
262                  if M.get_to_fit()]
[233c121]263        if len(models) == 0:
264            raise RuntimeError("Nothing to fit")
[e3efa6b3]265        problem = FitProblem(models)
266
[191c648]267        # TODO: need better handling of parameter expressions and bounds constraints
268        # so that they are applied during polydispersity calculations.  This
269        # will remove the immediate need for the setp_hook in bumps, though
270        # bumps may still need something similar, such as a sane class structure
271        # which allows a subclass to override setp.
272        problem.setp_hook = ParameterExpressions(models)
[4e9f227]273
[e3efa6b3]274        # Run the fit
275        result = run_bumps(problem, handler, curr_thread)
[6fe5100]276        if handler is not None:
277            handler.update_fit(last=True)
[e3efa6b3]278
[eff93b8]279        # TODO: shouldn't reference internal parameters of fit problem
[e3efa6b3]280        varying = problem._parameters
281        # collect the results
282        all_results = []
283        for M in problem.models:
284            fitness = M.fitness
285            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
[e1442d4]286            param_list = fitness.fitted_par_names + fitness.computed_par_names
[e3efa6b3]287            R = FResult(model=fitness.model, data=fitness.data,
[e1442d4]288                        param_list=param_list)
[e3efa6b3]289            R.theory = fitness.theory()
290            R.residuals = fitness.residuals()
[5044543]291            R.index = fitness.data.idx
[e3efa6b3]292            R.fitter_id = self.fitter_id
[eff93b8]293            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
[e3efa6b3]294            R.success = result['success']
[e1442d4]295            if R.success:
[1a5d5f2]296                if result['stderr'] is None:
297                    R.stderr = numpy.NaN*numpy.ones(len(param_list))
298                else:
299                    R.stderr = numpy.hstack((result['stderr'][fitted_index],
300                                             numpy.NaN*numpy.ones(len(fitness.computed_pars))))
[e1442d4]301                R.pvec = numpy.hstack((result['value'][fitted_index],
302                                      [p.value for p in fitness.computed_pars]))
303                R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
304            else:
305                R.stderr = numpy.NaN*numpy.ones(len(param_list))
306                R.pvec = numpy.asarray( [p.value for p in fitness.fitted_pars+fitness.computed_pars])
307                R.fitness = numpy.NaN
[e3efa6b3]308            R.convergence = result['convergence']
309            if result['uncertainty'] is not None:
310                R.uncertainty_state = result['uncertainty']
311            all_results.append(R)
[1a5d5f2]312        all_results[0].mesg = result['errors']
[e3efa6b3]313
[6fe5100]314        if q is not None:
[e3efa6b3]315            q.put(all_results)
[6fe5100]316            return q
[e3efa6b3]317        else:
318            return all_results
[6fe5100]319
[e3efa6b3]320def run_bumps(problem, handler, curr_thread):
[85f17f6]321    def abort_test():
322        if curr_thread is None: return False
323        try: curr_thread.isquit()
324        except KeyboardInterrupt:
325            if handler is not None:
326                handler.stop("Fitting: Terminated!!!")
327            return True
328        return False
329
[7945367]330    fitclass, options = get_fitter()
331    steps = options.get('steps', 0)
332    if steps == 0:
333        pop = options.get('pop',0)*len(problem._parameters)
334        samples = options.get('samples', 0)
335        steps = (samples+pop-1)/pop if pop != 0 else samples
336    max_step = steps + options.get('burn', 0)
[35086c3]337    pars = [p.name for p in problem._parameters]
[e1442d4]338    #x0 = numpy.asarray([p.value for p in problem._parameters])
[e3efa6b3]339    options['monitors'] = [
[35086c3]340        BumpsMonitor(handler, max_step, pars, problem.dof),
[e3efa6b3]341        ConvergenceMonitor(),
342        ]
[95d58d3]343    fitdriver = fitters.FitDriver(fitclass, problem=problem,
[042f065]344                                  abort_test=abort_test, **options)
[249a7c6]345    omp_threads = int(os.environ.get('OMP_NUM_THREADS','0'))
[e1442d4]346    mapper = MPMapper if omp_threads == 1 else SerialMapper
[6fe5100]347    fitdriver.mapper = mapper.start_mapper(problem, None)
[233c121]348    #import time; T0 = time.time()
[6fe5100]349    try:
350        best, fbest = fitdriver.fit()
[1a5d5f2]351        errors = []
352    except Exception as exc:
353        best, fbest = None, numpy.NaN
354        errors = [str(exc), traceback.traceback.format_exc()]
[95d58d3]355    finally:
356        mapper.stop_mapper(fitdriver.mapper)
[e3efa6b3]357
358
359    convergence_list = options['monitors'][-1].convergence
360    convergence = (2*numpy.asarray(convergence_list)/problem.dof
361                   if convergence_list else numpy.empty((0,1),'d'))
[e1442d4]362
363    success = best is not None
[1a5d5f2]364    try:
365        stderr = fitdriver.stderr() if success else None
366    except Exception as exc:
367        errors.append(str(exc))
368        errors.append(traceback.format_exc())
369        stderr = None
[e3efa6b3]370    return {
[e1442d4]371        'value': best if success else None,
[1a5d5f2]372        'stderr': stderr,
[e1442d4]373        'success': success,
[e3efa6b3]374        'convergence': convergence,
375        'uncertainty': getattr(fitdriver.fitter, 'state', None),
[1a5d5f2]376        'errors': '\n'.join(errors),
[e3efa6b3]377        }
[6fe5100]378
Note: See TracBrowser for help on using the repository browser.