source: sasview/src/sas/sascalc/pr/fit/BumpsFitting.py @ b94889a

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since b94889a was b699768, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 9 years ago

Initial commit of the refactored SasCalc? module.

  • Property mode set to 100644
File size: 13.6 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import os
5from datetime import timedelta, datetime
6
7import numpy
8
9from bumps import fitters
10try:
11    from bumps.options import FIT_CONFIG
12    # Default bumps to use the Levenberg-Marquardt optimizer
13    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
14    def get_fitter():
15        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
16except:
17    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
18    # Default bumps to use the Levenberg-Marquardt optimizer
19    fitters.FIT_DEFAULT = 'lm'
20    def get_fitter():
21        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
22        return fitopts.fitclass, fitopts.options.copy()
23
24
25from bumps.mapper import SerialMapper, MPMapper
26from bumps import parameter
27from bumps.fitproblem import FitProblem
28
29
30from sas.sascalc.fit.AbstractFitEngine import FitEngine
31from sas.sascalc.fit.AbstractFitEngine import FResult
32from sas.sascalc.fit.expression import compile_constraints
33
34class Progress(object):
35    def __init__(self, history, max_step, pars, dof):
36        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
37        # Depending on the time remaining, either display the expected
38        # time of completion, or the amount of time remaining.  Use precision
39        # appropriate for the duration.
40        if remaining_time >= 1800:
41            completion_time = datetime.now() + timedelta(seconds=remaining_time)
42            if remaining_time >= 36000:
43                time = completion_time.strftime('%Y-%m-%d %H:%M')
44            else:
45                time = completion_time.strftime('%H:%M')
46        else:
47            if remaining_time >= 3600:
48                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
49            elif remaining_time >= 60:
50                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
51            else:
52                time = '%ds'%remaining_time
53        chisq = "%.3g"%(2*history.value[0]/dof)
54        step = "%d of %d"%(history.step[0], max_step)
55        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
56        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
57                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
58        self.msg = "".join([header]+parameters)
59
60    def __str__(self):
61        return self.msg
62
63
64class BumpsMonitor(object):
65    def __init__(self, handler, max_step, pars, dof):
66        self.handler = handler
67        self.max_step = max_step
68        self.pars = pars
69        self.dof = dof
70
71    def config_history(self, history):
72        history.requires(time=1, value=2, point=1, step=1)
73
74    def __call__(self, history):
75        if self.handler is None: return
76        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
77        self.handler.progress(history.step[0], self.max_step)
78        if len(history.step)>1 and history.step[1] > history.step[0]:
79            self.handler.improvement()
80        self.handler.update_fit()
81
82class ConvergenceMonitor(object):
83    """
84    ConvergenceMonitor contains population summary statistics to show progress
85    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
86    just a list [ (best, ) ] if population size is 1.
87    """
88    def __init__(self):
89        self.convergence = []
90
91    def config_history(self, history):
92        history.requires(value=1, population_values=1)
93
94    def __call__(self, history):
95        best = history.value[0]
96        try:
97            p = history.population_values[0]
98            n,p = len(p), numpy.sort(p)
99            QI,Qmid, = int(0.2*n),int(0.5*n)
100            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
101        except:
102            self.convergence.append((best, best,best,best,best,best))
103
104
105# Note: currently using bumps parameters for each parameter object so that
106# a SasFitness can be used directly in bumps with the usual semantics.
107# The disadvantage of this technique is that we need to copy every parameter
108# back into the model each time the function is evaluated.  We could instead
109# define reference parameters for each sas parameter, but then we would not
110# be able to express constraints using python expressions in the usual way
111# from bumps, and would instead need to use string expressions.
112class SasFitness(object):
113    """
114    Wrap SAS model as a bumps fitness object
115    """
116    def __init__(self, model, data, fitted=[], constraints={},
117                 initial_values=None, **kw):
118        self.name = model.name
119        self.model = model.model
120        self.data = data
121        if self.data.smearer is not None:
122            self.data.smearer.model = self.model
123        self._define_pars()
124        self._init_pars(kw)
125        if initial_values is not None:
126            self._reset_pars(fitted, initial_values)
127        self.constraints = dict(constraints)
128        self.set_fitted(fitted)
129        self.update()
130
131    def _reset_pars(self, names, values):
132        for k,v in zip(names, values):
133            self._pars[k].value = v
134
135    def _define_pars(self):
136        self._pars = {}
137        for k in self.model.getParamList():
138            name = ".".join((self.name,k))
139            value = self.model.getParam(k)
140            bounds = self.model.details.get(k,["",None,None])[1:3]
141            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
142                                                fixed=True, name=name)
143        #print parameter.summarize(self._pars.values())
144
145    def _init_pars(self, kw):
146        for k,v in kw.items():
147            # dispersion parameters initialized with _field instead of .field
148            if k.endswith('_width'): k = k[:-6]+'.width'
149            elif k.endswith('_npts'): k = k[:-5]+'.npts'
150            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
151            elif k.endswith('_type'): k = k[:-5]+'.type'
152            if k not in self._pars:
153                formatted_pars = ", ".join(sorted(self._pars.keys()))
154                raise KeyError("invalid parameter %r for %s--use one of: %s"
155                               %(k, self.model, formatted_pars))
156            if '.' in k and not k.endswith('.width'):
157                self.model.setParam(k, v)
158            elif isinstance(v, parameter.BaseParameter):
159                self._pars[k] = v
160            elif isinstance(v, (tuple,list)):
161                low, high = v
162                self._pars[k].value = (low+high)/2
163                self._pars[k].range(low,high)
164            else:
165                self._pars[k].value = v
166
167    def set_fitted(self, param_list):
168        """
169        Flag a set of parameters as fitted parameters.
170        """
171        for k,p in self._pars.items():
172            p.fixed = (k not in param_list or k in self.constraints)
173        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
174        self.computed_par_names = [k for k in param_list if k in self.constraints]
175        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
176        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
177
178    # ===== Fitness interface ====
179    def parameters(self):
180        return self._pars
181
182    def update(self):
183        for k,v in self._pars.items():
184            #print "updating",k,v,v.value
185            self.model.setParam(k,v.value)
186        self._dirty = True
187
188    def _recalculate(self):
189        if self._dirty:
190            self._residuals, self._theory \
191                = self.data.residuals(self.model.evalDistribution)
192            self._dirty = False
193
194    def numpoints(self):
195        return numpy.sum(self.data.idx) # number of fitted points
196
197    def nllf(self):
198        return 0.5*numpy.sum(self.residuals()**2)
199
200    def theory(self):
201        self._recalculate()
202        return self._theory
203
204    def residuals(self):
205        self._recalculate()
206        return self._residuals
207
208    # Not implementing the data methods for now:
209    #
210    #     resynth_data/restore_data/save/plot
211
212class ParameterExpressions(object):
213    def __init__(self, models):
214        self.models = models
215        self._setup()
216
217    def _setup(self):
218        exprs = {}
219        for M in self.models:
220            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
221        if exprs:
222            symtab = dict((".".join((M.name, k)), p)
223                          for M in self.models
224                          for k,p in M.parameters().items())
225            self.update = compile_constraints(symtab, exprs)
226        else:
227            self.update = lambda: 0
228
229    def __call__(self):
230        self.update()
231
232    def __getstate__(self):
233        return self.models
234
235    def __setstate__(self, state):
236        self.models = state
237        self._setup()
238
239class BumpsFit(FitEngine):
240    """
241    Fit a model using bumps.
242    """
243    def __init__(self):
244        """
245        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
246        with Uid as keys
247        """
248        FitEngine.__init__(self)
249        self.curr_thread = None
250
251    def fit(self, msg_q=None,
252            q=None, handler=None, curr_thread=None,
253            ftol=1.49012e-8, reset_flag=False):
254        # Build collection of bumps fitness calculators
255        models = [SasFitness(model=M.get_model(),
256                             data=M.get_data(),
257                             constraints=M.constraints,
258                             fitted=M.pars,
259                             initial_values=M.vals if reset_flag else None)
260                  for M in self.fit_arrange_dict.values()
261                  if M.get_to_fit()]
262        if len(models) == 0:
263            raise RuntimeError("Nothing to fit")
264        problem = FitProblem(models)
265
266        # TODO: need better handling of parameter expressions and bounds constraints
267        # so that they are applied during polydispersity calculations.  This
268        # will remove the immediate need for the setp_hook in bumps, though
269        # bumps may still need something similar, such as a sane class structure
270        # which allows a subclass to override setp.
271        problem.setp_hook = ParameterExpressions(models)
272
273        # Run the fit
274        result = run_bumps(problem, handler, curr_thread)
275        if handler is not None:
276            handler.update_fit(last=True)
277
278        # TODO: shouldn't reference internal parameters of fit problem
279        varying = problem._parameters
280        # collect the results
281        all_results = []
282        for M in problem.models:
283            fitness = M.fitness
284            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
285            param_list = fitness.fitted_par_names + fitness.computed_par_names
286            R = FResult(model=fitness.model, data=fitness.data,
287                        param_list=param_list)
288            R.theory = fitness.theory()
289            R.residuals = fitness.residuals()
290            R.index = fitness.data.idx
291            R.fitter_id = self.fitter_id
292            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
293            R.success = result['success']
294            if R.success:
295                R.stderr = numpy.hstack((result['stderr'][fitted_index],
296                                         numpy.NaN*numpy.ones(len(fitness.computed_pars))))
297                R.pvec = numpy.hstack((result['value'][fitted_index],
298                                      [p.value for p in fitness.computed_pars]))
299                R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
300            else:
301                R.stderr = numpy.NaN*numpy.ones(len(param_list))
302                R.pvec = numpy.asarray( [p.value for p in fitness.fitted_pars+fitness.computed_pars])
303                R.fitness = numpy.NaN
304            R.convergence = result['convergence']
305            if result['uncertainty'] is not None:
306                R.uncertainty_state = result['uncertainty']
307            all_results.append(R)
308
309        if q is not None:
310            q.put(all_results)
311            return q
312        else:
313            return all_results
314
315def run_bumps(problem, handler, curr_thread):
316    def abort_test():
317        if curr_thread is None: return False
318        try: curr_thread.isquit()
319        except KeyboardInterrupt:
320            if handler is not None:
321                handler.stop("Fitting: Terminated!!!")
322            return True
323        return False
324
325    fitclass, options = get_fitter()
326    steps = options.get('steps', 0)
327    if steps == 0:
328        pop = options.get('pop',0)*len(problem._parameters)
329        samples = options.get('samples', 0)
330        steps = (samples+pop-1)/pop if pop != 0 else samples
331    max_step = steps + options.get('burn', 0)
332    pars = [p.name for p in problem._parameters]
333    #x0 = numpy.asarray([p.value for p in problem._parameters])
334    options['monitors'] = [
335        BumpsMonitor(handler, max_step, pars, problem.dof),
336        ConvergenceMonitor(),
337        ]
338    fitdriver = fitters.FitDriver(fitclass, problem=problem,
339                                  abort_test=abort_test, **options)
340    omp_threads = int(os.environ.get('OMP_NUM_THREADS','0'))
341    mapper = MPMapper if omp_threads == 1 else SerialMapper
342    fitdriver.mapper = mapper.start_mapper(problem, None)
343    #import time; T0 = time.time()
344    try:
345        best, fbest = fitdriver.fit()
346    except:
347        import traceback; traceback.print_exc()
348        raise
349    finally:
350        mapper.stop_mapper(fitdriver.mapper)
351
352
353    convergence_list = options['monitors'][-1].convergence
354    convergence = (2*numpy.asarray(convergence_list)/problem.dof
355                   if convergence_list else numpy.empty((0,1),'d'))
356
357    success = best is not None
358    return {
359        'value': best if success else None,
360        'stderr': fitdriver.stderr() if success else None,
361        'success': success,
362        'convergence': convergence,
363        'uncertainty': getattr(fitdriver.fitter, 'state', None),
364        }
365
Note: See TracBrowser for help on using the repository browser.