source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ f2724b6

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since f2724b6 was f668101, checked in by jhbakker, 7 years ago

Merge branch 'master' into Jurtest

  • Property mode set to 100644
File size: 14.1 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import os
5from datetime import timedelta, datetime
6import traceback
7
8import numpy
9
10from bumps import fitters
11try:
12    from bumps.options import FIT_CONFIG
13    # Default bumps to use the Levenberg-Marquardt optimizer
14    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
15    def get_fitter():
16        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
17except:
18    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
19    # Default bumps to use the Levenberg-Marquardt optimizer
20    fitters.FIT_DEFAULT = 'lm'
21    def get_fitter():
22        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
23        return fitopts.fitclass, fitopts.options.copy()
24
25
26from bumps.mapper import SerialMapper, MPMapper
27from bumps import parameter
28from bumps.fitproblem import FitProblem
29
30from sas.sascalc.fit.AbstractFitEngine import FitEngine
31from sas.sascalc.fit.AbstractFitEngine import FResult
32from sas.sascalc.fit.expression import compile_constraints
33
34class Progress(object):
35    def __init__(self, history, max_step, pars, dof):
36        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
37        # Depending on the time remaining, either display the expected
38        # time of completion, or the amount of time remaining.  Use precision
39        # appropriate for the duration.
40        if remaining_time >= 1800:
41            completion_time = datetime.now() + timedelta(seconds=remaining_time)
42            if remaining_time >= 36000:
43                time = completion_time.strftime('%Y-%m-%d %H:%M')
44            else:
45                time = completion_time.strftime('%H:%M')
46        else:
47            if remaining_time >= 3600:
48                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
49            elif remaining_time >= 60:
50                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
51            else:
52                time = '%ds'%remaining_time
53        chisq = "%.3g"%(2*history.value[0]/dof)
54        step = "%d of %d"%(history.step[0], max_step)
55        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
56        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
57                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
58        self.msg = "".join([header]+parameters)
59
60    def __str__(self):
61        return self.msg
62
63
64class BumpsMonitor(object):
65    def __init__(self, handler, max_step, pars, dof):
66        self.handler = handler
67        self.max_step = max_step
68        self.pars = pars
69        self.dof = dof
70
71    def config_history(self, history):
72        history.requires(time=1, value=2, point=1, step=1)
73
74    def __call__(self, history):
75        if self.handler is None: return
76        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
77        self.handler.progress(history.step[0], self.max_step)
78        if len(history.step)>1 and history.step[1] > history.step[0]:
79            self.handler.improvement()
80        self.handler.update_fit()
81
82class ConvergenceMonitor(object):
83    """
84    ConvergenceMonitor contains population summary statistics to show progress
85    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
86    just a list [ (best, ) ] if population size is 1.
87    """
88    def __init__(self):
89        self.convergence = []
90
91    def config_history(self, history):
92        history.requires(value=1, population_values=1)
93
94    def __call__(self, history):
95        best = history.value[0]
96        try:
97            p = history.population_values[0]
98            n,p = len(p), numpy.sort(p)
99            QI,Qmid, = int(0.2*n),int(0.5*n)
100            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
101        except:
102            self.convergence.append((best, best,best,best,best,best))
103
104
105# Note: currently using bumps parameters for each parameter object so that
106# a SasFitness can be used directly in bumps with the usual semantics.
107# The disadvantage of this technique is that we need to copy every parameter
108# back into the model each time the function is evaluated.  We could instead
109# define reference parameters for each sas parameter, but then we would not
110# be able to express constraints using python expressions in the usual way
111# from bumps, and would instead need to use string expressions.
112class SasFitness(object):
113    """
114    Wrap SAS model as a bumps fitness object
115    """
116    def __init__(self, model, data, fitted=[], constraints={},
117                 initial_values=None, **kw):
118        self.name = model.name
119        self.model = model.model
120        self.data = data
121        if self.data.smearer is not None:
122            self.data.smearer.model = self.model
123        self._define_pars()
124        self._init_pars(kw)
125        if initial_values is not None:
126            self._reset_pars(fitted, initial_values)
127        self.constraints = dict(constraints)
128        self.set_fitted(fitted)
129        self.update()
130
131    def _reset_pars(self, names, values):
132        for k,v in zip(names, values):
133            self._pars[k].value = v
134
135    def _define_pars(self):
136        self._pars = {}
137        for k in self.model.getParamList():
138            name = ".".join((self.name,k))
139            value = self.model.getParam(k)
140            bounds = self.model.details.get(k,["",None,None])[1:3]
141            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
142                                                fixed=True, name=name)
143        #print parameter.summarize(self._pars.values())
144
145    def _init_pars(self, kw):
146        for k,v in kw.items():
147            # dispersion parameters initialized with _field instead of .field
148            if k.endswith('_width'): k = k[:-6]+'.width'
149            elif k.endswith('_npts'): k = k[:-5]+'.npts'
150            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
151            elif k.endswith('_type'): k = k[:-5]+'.type'
152            if k not in self._pars:
153                formatted_pars = ", ".join(sorted(self._pars.keys()))
154                raise KeyError("invalid parameter %r for %s--use one of: %s"
155                               %(k, self.model, formatted_pars))
156            if '.' in k and not k.endswith('.width'):
157                self.model.setParam(k, v)
158            elif isinstance(v, parameter.BaseParameter):
159                self._pars[k] = v
160            elif isinstance(v, (tuple,list)):
161                low, high = v
162                self._pars[k].value = (low+high)/2
163                self._pars[k].range(low,high)
164            else:
165                self._pars[k].value = v
166
167    def set_fitted(self, param_list):
168        """
169        Flag a set of parameters as fitted parameters.
170        """
171        for k,p in self._pars.items():
172            p.fixed = (k not in param_list or k in self.constraints)
173        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
174        self.computed_par_names = [k for k in param_list if k in self.constraints]
175        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
176        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
177
178    # ===== Fitness interface ====
179    def parameters(self):
180        return self._pars
181
182    def update(self):
183        for k,v in self._pars.items():
184            #print "updating",k,v,v.value
185            self.model.setParam(k,v.value)
186        self._dirty = True
187
188    def _recalculate(self):
189        if self._dirty:
190            self._residuals, self._theory \
191                = self.data.residuals(self.model.evalDistribution)
192            self._dirty = False
193
194    def numpoints(self):
195        return numpy.sum(self.data.idx) # number of fitted points
196
197    def nllf(self):
198        return 0.5*numpy.sum(self.residuals()**2)
199
200    def theory(self):
201        self._recalculate()
202        return self._theory
203
204    def residuals(self):
205        self._recalculate()
206        return self._residuals
207
208    # Not implementing the data methods for now:
209    #
210    #     resynth_data/restore_data/save/plot
211
212class ParameterExpressions(object):
213    def __init__(self, models):
214        self.models = models
215        self._setup()
216
217    def _setup(self):
218        exprs = {}
219        for M in self.models:
220            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
221        if exprs:
222            symtab = dict((".".join((M.name, k)), p)
223                          for M in self.models
224                          for k,p in M.parameters().items())
225            self.update = compile_constraints(symtab, exprs)
226        else:
227            self.update = lambda: 0
228
229    def __call__(self):
230        self.update()
231
232    def __getstate__(self):
233        return self.models
234
235    def __setstate__(self, state):
236        self.models = state
237        self._setup()
238
239class BumpsFit(FitEngine):
240    """
241    Fit a model using bumps.
242    """
243    def __init__(self):
244        """
245        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
246        with Uid as keys
247        """
248        FitEngine.__init__(self)
249        self.curr_thread = None
250
251    def fit(self, msg_q=None,
252            q=None, handler=None, curr_thread=None,
253            ftol=1.49012e-8, reset_flag=False):
254        # Build collection of bumps fitness calculators
255        models = [SasFitness(model=M.get_model(),
256                             data=M.get_data(),
257                             constraints=M.constraints,
258                             fitted=M.pars,
259                             initial_values=M.vals if reset_flag else None)
260                  for M in self.fit_arrange_dict.values()
261                  if M.get_to_fit()]
262        if len(models) == 0:
263            raise RuntimeError("Nothing to fit")
264        problem = FitProblem(models)
265
266        # TODO: need better handling of parameter expressions and bounds constraints
267        # so that they are applied during polydispersity calculations.  This
268        # will remove the immediate need for the setp_hook in bumps, though
269        # bumps may still need something similar, such as a sane class structure
270        # which allows a subclass to override setp.
271        problem.setp_hook = ParameterExpressions(models)
272
273        # Run the fit
274        result = run_bumps(problem, handler, curr_thread)
275        if handler is not None:
276            handler.update_fit(last=True)
277
278        # TODO: shouldn't reference internal parameters of fit problem
279        varying = problem._parameters
280        # collect the results
281        all_results = []
282        for M in problem.models:
283            fitness = M.fitness
284            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
285            param_list = fitness.fitted_par_names + fitness.computed_par_names
286            R = FResult(model=fitness.model, data=fitness.data,
287                        param_list=param_list)
288            R.theory = fitness.theory()
289            R.residuals = fitness.residuals()
290            R.index = fitness.data.idx
291            R.fitter_id = self.fitter_id
292            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
293            R.success = result['success']
294            if R.success:
295                if result['stderr'] is None:
296                    R.stderr = numpy.NaN*numpy.ones(len(param_list))
297                else:
298                    R.stderr = numpy.hstack((result['stderr'][fitted_index],
299                                             numpy.NaN*numpy.ones(len(fitness.computed_pars))))
300                R.pvec = numpy.hstack((result['value'][fitted_index],
301                                      [p.value for p in fitness.computed_pars]))
302                R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
303            else:
304                R.stderr = numpy.NaN*numpy.ones(len(param_list))
305                R.pvec = numpy.asarray( [p.value for p in fitness.fitted_pars+fitness.computed_pars])
306                R.fitness = numpy.NaN
307            R.convergence = result['convergence']
308            if result['uncertainty'] is not None:
309                R.uncertainty_state = result['uncertainty']
310            all_results.append(R)
311        all_results[0].mesg = result['errors']
312
313        if q is not None:
314            q.put(all_results)
315            return q
316        else:
317            return all_results
318
319def run_bumps(problem, handler, curr_thread):
320    def abort_test():
321        if curr_thread is None: return False
322        try: curr_thread.isquit()
323        except KeyboardInterrupt:
324            if handler is not None:
325                handler.stop("Fitting: Terminated!!!")
326            return True
327        return False
328
329    fitclass, options = get_fitter()
330    steps = options.get('steps', 0)
331    if steps == 0:
332        pop = options.get('pop',0)*len(problem._parameters)
333        samples = options.get('samples', 0)
334        steps = (samples+pop-1)/pop if pop != 0 else samples
335    max_step = steps + options.get('burn', 0)
336    pars = [p.name for p in problem._parameters]
337    #x0 = numpy.asarray([p.value for p in problem._parameters])
338    options['monitors'] = [
339        BumpsMonitor(handler, max_step, pars, problem.dof),
340        ConvergenceMonitor(),
341        ]
342    fitdriver = fitters.FitDriver(fitclass, problem=problem,
343                                  abort_test=abort_test, **options)
344    omp_threads = int(os.environ.get('OMP_NUM_THREADS','0'))
345    mapper = MPMapper if omp_threads == 1 else SerialMapper
346    fitdriver.mapper = mapper.start_mapper(problem, None)
347    #import time; T0 = time.time()
348    try:
349        best, fbest = fitdriver.fit()
350        errors = []
351    except Exception as exc:
352        best, fbest = None, numpy.NaN
353        errors = [str(exc), traceback.traceback.format_exc()]
354    finally:
355        mapper.stop_mapper(fitdriver.mapper)
356
357
358    convergence_list = options['monitors'][-1].convergence
359    convergence = (2*numpy.asarray(convergence_list)/problem.dof
360                   if convergence_list else numpy.empty((0,1),'d'))
361
362    success = best is not None
363    try:
364        stderr = fitdriver.stderr() if success else None
365    except Exception as exc:
366        errors.append(str(exc))
367        errors.append(traceback.format_exc())
368        stderr = None
369    return {
370        'value': best if success else None,
371        'stderr': stderr,
372        'success': success,
373        'convergence': convergence,
374        'uncertainty': getattr(fitdriver.fitter, 'state', None),
375        'errors': '\n'.join(errors),
376        }
377
Note: See TracBrowser for help on using the repository browser.