source: sasview/src/sas/sascalc/fit/BumpsFitting.py @ 1fac6c0

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 1fac6c0 was 7988501, checked in by jhbakker, 8 years ago

Data1D class changed to include SESANS Data format

  • Property mode set to 100644
File size: 13.6 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import os
5from datetime import timedelta, datetime
6
7import numpy
8
9from bumps import fitters
10try:
11    from bumps.options import FIT_CONFIG
12    # Default bumps to use the Levenberg-Marquardt optimizer
13    FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id
14    def get_fitter():
15        return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values
16except:
17    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6
18    # Default bumps to use the Levenberg-Marquardt optimizer
19    fitters.FIT_DEFAULT = 'lm'
20    def get_fitter():
21        fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
22        return fitopts.fitclass, fitopts.options.copy()
23
24
25from bumps.mapper import SerialMapper, MPMapper
26from bumps import parameter
27from bumps.fitproblem import FitProblem
28
29from sas.sascalc.fit.AbstractFitEngine import FitEngine
30from sas.sascalc.fit.AbstractFitEngine import FResult
31from sas.sascalc.fit.expression import compile_constraints
32
33class Progress(object):
34    def __init__(self, history, max_step, pars, dof):
35        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
36        # Depending on the time remaining, either display the expected
37        # time of completion, or the amount of time remaining.  Use precision
38        # appropriate for the duration.
39        if remaining_time >= 1800:
40            completion_time = datetime.now() + timedelta(seconds=remaining_time)
41            if remaining_time >= 36000:
42                time = completion_time.strftime('%Y-%m-%d %H:%M')
43            else:
44                time = completion_time.strftime('%H:%M')
45        else:
46            if remaining_time >= 3600:
47                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
48            elif remaining_time >= 60:
49                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
50            else:
51                time = '%ds'%remaining_time
52        chisq = "%.3g"%(2*history.value[0]/dof)
53        step = "%d of %d"%(history.step[0], max_step)
54        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
55        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
56                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
57        self.msg = "".join([header]+parameters)
58
59    def __str__(self):
60        return self.msg
61
62
63class BumpsMonitor(object):
64    def __init__(self, handler, max_step, pars, dof):
65        self.handler = handler
66        self.max_step = max_step
67        self.pars = pars
68        self.dof = dof
69
70    def config_history(self, history):
71        history.requires(time=1, value=2, point=1, step=1)
72
73    def __call__(self, history):
74        if self.handler is None: return
75        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
76        self.handler.progress(history.step[0], self.max_step)
77        if len(history.step)>1 and history.step[1] > history.step[0]:
78            self.handler.improvement()
79        self.handler.update_fit()
80
81class ConvergenceMonitor(object):
82    """
83    ConvergenceMonitor contains population summary statistics to show progress
84    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
85    just a list [ (best, ) ] if population size is 1.
86    """
87    def __init__(self):
88        self.convergence = []
89
90    def config_history(self, history):
91        history.requires(value=1, population_values=1)
92
93    def __call__(self, history):
94        best = history.value[0]
95        try:
96            p = history.population_values[0]
97            n,p = len(p), numpy.sort(p)
98            QI,Qmid, = int(0.2*n),int(0.5*n)
99            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
100        except:
101            self.convergence.append((best, best,best,best,best,best))
102
103
104# Note: currently using bumps parameters for each parameter object so that
105# a SasFitness can be used directly in bumps with the usual semantics.
106# The disadvantage of this technique is that we need to copy every parameter
107# back into the model each time the function is evaluated.  We could instead
108# define reference parameters for each sas parameter, but then we would not
109# be able to express constraints using python expressions in the usual way
110# from bumps, and would instead need to use string expressions.
111class SasFitness(object):
112    """
113    Wrap SAS model as a bumps fitness object
114    """
115    def __init__(self, model, data, fitted=[], constraints={},
116                 initial_values=None, **kw):
117        self.name = model.name
118        self.model = model.model
119        self.data = data
120        if self.data.smearer is not None:
121            self.data.smearer.model = self.model
122        self._define_pars()
123        self._init_pars(kw)
124        if initial_values is not None:
125            self._reset_pars(fitted, initial_values)
126        self.constraints = dict(constraints)
127        self.set_fitted(fitted)
128        self.update()
129
130    def _reset_pars(self, names, values):
131        for k,v in zip(names, values):
132            self._pars[k].value = v
133
134    def _define_pars(self):
135        self._pars = {}
136        for k in self.model.getParamList():
137            name = ".".join((self.name,k))
138            value = self.model.getParam(k)
139            bounds = self.model.details.get(k,["",None,None])[1:3]
140            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
141                                                fixed=True, name=name)
142        #print parameter.summarize(self._pars.values())
143
144    def _init_pars(self, kw):
145        for k,v in kw.items():
146            # dispersion parameters initialized with _field instead of .field
147            if k.endswith('_width'): k = k[:-6]+'.width'
148            elif k.endswith('_npts'): k = k[:-5]+'.npts'
149            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
150            elif k.endswith('_type'): k = k[:-5]+'.type'
151            if k not in self._pars:
152                formatted_pars = ", ".join(sorted(self._pars.keys()))
153                raise KeyError("invalid parameter %r for %s--use one of: %s"
154                               %(k, self.model, formatted_pars))
155            if '.' in k and not k.endswith('.width'):
156                self.model.setParam(k, v)
157            elif isinstance(v, parameter.BaseParameter):
158                self._pars[k] = v
159            elif isinstance(v, (tuple,list)):
160                low, high = v
161                self._pars[k].value = (low+high)/2
162                self._pars[k].range(low,high)
163            else:
164                self._pars[k].value = v
165
166    def set_fitted(self, param_list):
167        """
168        Flag a set of parameters as fitted parameters.
169        """
170        for k,p in self._pars.items():
171            p.fixed = (k not in param_list or k in self.constraints)
172        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
173        self.computed_par_names = [k for k in param_list if k in self.constraints]
174        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
175        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
176
177    # ===== Fitness interface ====
178    def parameters(self):
179        return self._pars
180
181    def update(self):
182        for k,v in self._pars.items():
183            #print "updating",k,v,v.value
184            self.model.setParam(k,v.value)
185        self._dirty = True
186
187    def _recalculate(self):
188        if self._dirty:
189            self._residuals, self._theory \
190                = self.data.residuals(self.model.evalDistribution)
191            self._dirty = False
192
193    def numpoints(self):
194        return numpy.sum(self.data.idx) # number of fitted points
195
196    def nllf(self):
197        return 0.5*numpy.sum(self.residuals()**2)
198
199    def theory(self):
200        self._recalculate()
201        return self._theory
202
203    def residuals(self):
204        self._recalculate()
205        return self._residuals
206
207    # Not implementing the data methods for now:
208    #
209    #     resynth_data/restore_data/save/plot
210
211class ParameterExpressions(object):
212    def __init__(self, models):
213        self.models = models
214        self._setup()
215
216    def _setup(self):
217        exprs = {}
218        for M in self.models:
219            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
220        if exprs:
221            symtab = dict((".".join((M.name, k)), p)
222                          for M in self.models
223                          for k,p in M.parameters().items())
224            self.update = compile_constraints(symtab, exprs)
225        else:
226            self.update = lambda: 0
227
228    def __call__(self):
229        self.update()
230
231    def __getstate__(self):
232        return self.models
233
234    def __setstate__(self, state):
235        self.models = state
236        self._setup()
237
238class BumpsFit(FitEngine):
239    """
240    Fit a model using bumps.
241    """
242    def __init__(self):
243        """
244        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
245        with Uid as keys
246        """
247        FitEngine.__init__(self)
248        self.curr_thread = None
249
250    def fit(self, msg_q=None,
251            q=None, handler=None, curr_thread=None,
252            ftol=1.49012e-8, reset_flag=False):
253        # Build collection of bumps fitness calculators
254        models = [SasFitness(model=M.get_model(),
255                             data=M.get_data(),
256                             constraints=M.constraints,
257                             fitted=M.pars,
258                             initial_values=M.vals if reset_flag else None)
259                  for M in self.fit_arrange_dict.values()
260                  if M.get_to_fit()]
261        if len(models) == 0:
262            raise RuntimeError("Nothing to fit")
263        problem = FitProblem(models)
264
265        # TODO: need better handling of parameter expressions and bounds constraints
266        # so that they are applied during polydispersity calculations.  This
267        # will remove the immediate need for the setp_hook in bumps, though
268        # bumps may still need something similar, such as a sane class structure
269        # which allows a subclass to override setp.
270        problem.setp_hook = ParameterExpressions(models)
271
272        # Run the fit
273        result = run_bumps(problem, handler, curr_thread)
274        if handler is not None:
275            handler.update_fit(last=True)
276
277        # TODO: shouldn't reference internal parameters of fit problem
278        varying = problem._parameters
279        # collect the results
280        all_results = []
281        for M in problem.models:
282            fitness = M.fitness
283            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
284            param_list = fitness.fitted_par_names + fitness.computed_par_names
285            R = FResult(model=fitness.model, data=fitness.data,
286                        param_list=param_list)
287            R.theory = fitness.theory()
288            R.residuals = fitness.residuals()
289            R.index = fitness.data.idx
290            R.fitter_id = self.fitter_id
291            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
292            R.success = result['success']
293            if R.success:
294                R.stderr = numpy.hstack((result['stderr'][fitted_index],
295                                         numpy.NaN*numpy.ones(len(fitness.computed_pars))))
296                R.pvec = numpy.hstack((result['value'][fitted_index],
297                                      [p.value for p in fitness.computed_pars]))
298                R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
299            else:
300                R.stderr = numpy.NaN*numpy.ones(len(param_list))
301                R.pvec = numpy.asarray( [p.value for p in fitness.fitted_pars+fitness.computed_pars])
302                R.fitness = numpy.NaN
303            R.convergence = result['convergence']
304            if result['uncertainty'] is not None:
305                R.uncertainty_state = result['uncertainty']
306            all_results.append(R)
307
308        if q is not None:
309            q.put(all_results)
310            return q
311        else:
312            return all_results
313
314def run_bumps(problem, handler, curr_thread):
315    def abort_test():
316        if curr_thread is None: return False
317        try: curr_thread.isquit()
318        except KeyboardInterrupt:
319            if handler is not None:
320                handler.stop("Fitting: Terminated!!!")
321            return True
322        return False
323
324    fitclass, options = get_fitter()
325    steps = options.get('steps', 0)
326    if steps == 0:
327        pop = options.get('pop',0)*len(problem._parameters)
328        samples = options.get('samples', 0)
329        steps = (samples+pop-1)/pop if pop != 0 else samples
330    max_step = steps + options.get('burn', 0)
331    pars = [p.name for p in problem._parameters]
332    #x0 = numpy.asarray([p.value for p in problem._parameters])
333    options['monitors'] = [
334        BumpsMonitor(handler, max_step, pars, problem.dof),
335        ConvergenceMonitor(),
336        ]
337    fitdriver = fitters.FitDriver(fitclass, problem=problem,
338                                  abort_test=abort_test, **options)
339    omp_threads = int(os.environ.get('OMP_NUM_THREADS','0'))
340    mapper = MPMapper if omp_threads == 1 else SerialMapper
341    fitdriver.mapper = mapper.start_mapper(problem, None)
342    #import time; T0 = time.time()
343    try:
344        best, fbest = fitdriver.fit()
345    except:
346        import traceback; traceback.print_exc()
347        raise
348    finally:
349        mapper.stop_mapper(fitdriver.mapper)
350
351
352    convergence_list = options['monitors'][-1].convergence
353    convergence = (2*numpy.asarray(convergence_list)/problem.dof
354                   if convergence_list else numpy.empty((0,1),'d'))
355
356    success = best is not None
357    return {
358        'value': best if success else None,
359        'stderr': fitdriver.stderr() if success else None,
360        'success': success,
361        'convergence': convergence,
362        'uncertainty': getattr(fitdriver.fitter, 'state', None),
363        }
364
Note: See TracBrowser for help on using the repository browser.