source: sasview/src/sas/fit/BumpsFitting.py @ 4666660

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 4666660 was 9f7fbd9, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

rewrite slit smearing for usans

  • Property mode set to 100644
File size: 12.7 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import os
5from datetime import timedelta, datetime
6
7import numpy
8
9from bumps import fitters
10from bumps.mapper import SerialMapper, MPMapper
11from bumps import parameter
12
13# TODO: remove globals from interface to bumps options!
14# Default bumps to use the levenberg-marquardt optimizer
15from bumps.fitproblem import FitProblem
16fitters.FIT_DEFAULT = 'lm'
17
18from sas.fit.AbstractFitEngine import FitEngine
19from sas.fit.AbstractFitEngine import FResult
20from sas.fit.expression import compile_constraints
21
22class Progress(object):
23    def __init__(self, history, max_step, pars, dof):
24        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
25        # Depending on the time remaining, either display the expected
26        # time of completion, or the amount of time remaining.  Use precision
27        # appropriate for the duration.
28        if remaining_time >= 1800:
29            completion_time = datetime.now() + timedelta(seconds=remaining_time)
30            if remaining_time >= 36000:
31                time = completion_time.strftime('%Y-%m-%d %H:%M')
32            else:
33                time = completion_time.strftime('%H:%M')
34        else:
35            if remaining_time >= 3600:
36                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
37            elif remaining_time >= 60:
38                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
39            else:
40                time = '%ds'%remaining_time
41        chisq = "%.3g"%(2*history.value[0]/dof)
42        step = "%d of %d"%(history.step[0], max_step)
43        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
44        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
45                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
46        self.msg = "".join([header]+parameters)
47
48    def __str__(self):
49        return self.msg
50
51
52class BumpsMonitor(object):
53    def __init__(self, handler, max_step, pars, dof):
54        self.handler = handler
55        self.max_step = max_step
56        self.pars = pars
57        self.dof = dof
58
59    def config_history(self, history):
60        history.requires(time=1, value=2, point=1, step=1)
61
62    def __call__(self, history):
63        if self.handler is None: return
64        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
65        self.handler.progress(history.step[0], self.max_step)
66        if len(history.step)>1 and history.step[1] > history.step[0]:
67            self.handler.improvement()
68        self.handler.update_fit()
69
70class ConvergenceMonitor(object):
71    """
72    ConvergenceMonitor contains population summary statistics to show progress
73    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
74    just a list [ (best, ) ] if population size is 1.
75    """
76    def __init__(self):
77        self.convergence = []
78
79    def config_history(self, history):
80        history.requires(value=1, population_values=1)
81
82    def __call__(self, history):
83        best = history.value[0]
84        try:
85            p = history.population_values[0]
86            n,p = len(p), numpy.sort(p)
87            QI,Qmid, = int(0.2*n),int(0.5*n)
88            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
89        except:
90            self.convergence.append((best, best,best,best,best,best))
91
92
93# Note: currently using bumps parameters for each parameter object so that
94# a SasFitness can be used directly in bumps with the usual semantics.
95# The disadvantage of this technique is that we need to copy every parameter
96# back into the model each time the function is evaluated.  We could instead
97# define reference parameters for each sas parameter, but then we would not
98# be able to express constraints using python expressions in the usual way
99# from bumps, and would instead need to use string expressions.
100class SasFitness(object):
101    """
102    Wrap SAS model as a bumps fitness object
103    """
104    def __init__(self, model, data, fitted=[], constraints={},
105                 initial_values=None, **kw):
106        self.name = model.name
107        self.model = model.model
108        self.data = data
109        if self.data.smearer is not None:
110            self.data.smearer.model = self.model
111        self._define_pars()
112        self._init_pars(kw)
113        if initial_values is not None:
114            self._reset_pars(fitted, initial_values)
115        self.constraints = dict(constraints)
116        self.set_fitted(fitted)
117        self.update()
118
119    def _reset_pars(self, names, values):
120        for k,v in zip(names, values):
121            self._pars[k].value = v
122
123    def _define_pars(self):
124        self._pars = {}
125        for k in self.model.getParamList():
126            name = ".".join((self.name,k))
127            value = self.model.getParam(k)
128            bounds = self.model.details.get(k,["",None,None])[1:3]
129            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
130                                                fixed=True, name=name)
131        #print parameter.summarize(self._pars.values())
132
133    def _init_pars(self, kw):
134        for k,v in kw.items():
135            # dispersion parameters initialized with _field instead of .field
136            if k.endswith('_width'): k = k[:-6]+'.width'
137            elif k.endswith('_npts'): k = k[:-5]+'.npts'
138            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
139            elif k.endswith('_type'): k = k[:-5]+'.type'
140            if k not in self._pars:
141                formatted_pars = ", ".join(sorted(self._pars.keys()))
142                raise KeyError("invalid parameter %r for %s--use one of: %s"
143                               %(k, self.model, formatted_pars))
144            if '.' in k and not k.endswith('.width'):
145                self.model.setParam(k, v)
146            elif isinstance(v, parameter.BaseParameter):
147                self._pars[k] = v
148            elif isinstance(v, (tuple,list)):
149                low, high = v
150                self._pars[k].value = (low+high)/2
151                self._pars[k].range(low,high)
152            else:
153                self._pars[k].value = v
154
155    def set_fitted(self, param_list):
156        """
157        Flag a set of parameters as fitted parameters.
158        """
159        for k,p in self._pars.items():
160            p.fixed = (k not in param_list or k in self.constraints)
161        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
162        self.computed_par_names = [k for k in param_list if k in self.constraints]
163        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
164        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
165
166    # ===== Fitness interface ====
167    def parameters(self):
168        return self._pars
169
170    def update(self):
171        for k,v in self._pars.items():
172            #print "updating",k,v,v.value
173            self.model.setParam(k,v.value)
174        self._dirty = True
175
176    def _recalculate(self):
177        if self._dirty:
178            self._residuals, self._theory \
179                = self.data.residuals(self.model.evalDistribution)
180            self._dirty = False
181
182    def numpoints(self):
183        return numpy.sum(self.data.idx) # number of fitted points
184
185    def nllf(self):
186        return 0.5*numpy.sum(self.residuals()**2)
187
188    def theory(self):
189        self._recalculate()
190        return self._theory
191
192    def residuals(self):
193        self._recalculate()
194        return self._residuals
195
196    # Not implementing the data methods for now:
197    #
198    #     resynth_data/restore_data/save/plot
199
200class ParameterExpressions(object):
201    def __init__(self, models):
202        self.models = models
203        self._setup()
204
205    def _setup(self):
206        exprs = {}
207        for M in self.models:
208            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
209        if exprs:
210            symtab = dict((".".join((M.name, k)), p)
211                          for M in self.models
212                          for k,p in M.parameters().items())
213            self.update = compile_constraints(symtab, exprs)
214        else:
215            self.update = lambda: 0
216
217    def __call__(self):
218        self.update()
219
220    def __getstate__(self):
221        return self.models
222
223    def __setstate__(self, state):
224        self.models = state
225        self._setup()
226
227class BumpsFit(FitEngine):
228    """
229    Fit a model using bumps.
230    """
231    def __init__(self):
232        """
233        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
234        with Uid as keys
235        """
236        FitEngine.__init__(self)
237        self.curr_thread = None
238
239    def fit(self, msg_q=None,
240            q=None, handler=None, curr_thread=None,
241            ftol=1.49012e-8, reset_flag=False):
242        # Build collection of bumps fitness calculators
243        models = [SasFitness(model=M.get_model(),
244                             data=M.get_data(),
245                             constraints=M.constraints,
246                             fitted=M.pars,
247                             initial_values=M.vals if reset_flag else None)
248                  for M in self.fit_arrange_dict.values()
249                  if M.get_to_fit()]
250        if len(models) == 0:
251            raise RuntimeError("Nothing to fit")
252        problem = FitProblem(models)
253
254        # TODO: need better handling of parameter expressions and bounds constraints
255        # so that they are applied during polydispersity calculations.  This
256        # will remove the immediate need for the setp_hook in bumps, though
257        # bumps may still need something similar, such as a sane class structure
258        # which allows a subclass to override setp.
259        problem.setp_hook = ParameterExpressions(models)
260
261        # Run the fit
262        result = run_bumps(problem, handler, curr_thread)
263        if handler is not None:
264            handler.update_fit(last=True)
265
266        # TODO: shouldn't reference internal parameters of fit problem
267        varying = problem._parameters
268        # collect the results
269        all_results = []
270        for M in problem.models:
271            fitness = M.fitness
272            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
273            R = FResult(model=fitness.model, data=fitness.data,
274                        param_list=fitness.fitted_par_names+fitness.computed_par_names)
275            R.theory = fitness.theory()
276            R.residuals = fitness.residuals()
277            R.index = fitness.data.idx
278            R.fitter_id = self.fitter_id
279            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
280            R.stderr = numpy.hstack((result['stderr'][fitted_index],
281                                     numpy.NaN*numpy.ones(len(fitness.computed_pars))))
282            R.pvec = numpy.hstack((result['value'][fitted_index],
283                                  [p.value for p in fitness.computed_pars]))
284            R.success = result['success']
285            R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
286            R.convergence = result['convergence']
287            if result['uncertainty'] is not None:
288                R.uncertainty_state = result['uncertainty']
289            all_results.append(R)
290
291        if q is not None:
292            q.put(all_results)
293            return q
294        else:
295            return all_results
296
297def run_bumps(problem, handler, curr_thread):
298    def abort_test():
299        if curr_thread is None: return False
300        try: curr_thread.isquit()
301        except KeyboardInterrupt:
302            if handler is not None:
303                handler.stop("Fitting: Terminated!!!")
304            return True
305        return False
306
307    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
308    fitclass = fitopts.fitclass
309    options = fitopts.options.copy()
310    max_step = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0)
311    pars = [p.name for p in problem._parameters]
312    options['monitors'] = [
313        BumpsMonitor(handler, max_step, pars, problem.dof),
314        ConvergenceMonitor(),
315        ]
316    fitdriver = fitters.FitDriver(fitclass, problem=problem,
317                                  abort_test=abort_test, **options)
318    omp_threads = int(os.environ.get('OMP_NUM_THREADS','0'))
319    mapper = MPMapper if omp_threads == 1 else SerialMapper       
320    fitdriver.mapper = mapper.start_mapper(problem, None)
321    #import time; T0 = time.time()
322    try:
323        best, fbest = fitdriver.fit()
324    except:
325        import traceback; traceback.print_exc()
326        raise
327    finally:
328        mapper.stop_mapper(fitdriver.mapper)
329
330
331    convergence_list = options['monitors'][-1].convergence
332    convergence = (2*numpy.asarray(convergence_list)/problem.dof
333                   if convergence_list else numpy.empty((0,1),'d'))
334    return {
335        'value': best,
336        'stderr': fitdriver.stderr(),
337        'success': True, # better success reporting in bumps
338        'convergence': convergence,
339        'uncertainty': getattr(fitdriver.fitter, 'state', None),
340        }
341
Note: See TracBrowser for help on using the repository browser.