source: sasview/src/sans/fit/BumpsFitting.py @ 880f170

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 880f170 was 35086c3, checked in by pkienzle, 10 years ago

show fit progress for bumps

  • Property mode set to 100644
File size: 11.8 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4from datetime import timedelta, datetime
5
6import numpy
7
8from bumps import fitters
9from bumps.mapper import SerialMapper
10from bumps import parameter
11from bumps.fitproblem import FitProblem
12
13from .AbstractFitEngine import FitEngine
14from .AbstractFitEngine import FResult
15from .expression import compile_constraints
16
17class Progress(object):
18    def __init__(self, history, max_step, pars, dof):
19        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
20        # Depending on the time remaining, either display the expected
21        # time of completion, or the amount of time remaining.  Use precision
22        # appropriate for the duration.
23        if remaining_time >= 1800:
24            completion_time = datetime.now() + timedelta(seconds=remaining_time)
25            if remaining_time >= 36000:
26                time = completion_time.strftime('%Y-%m-%d %H:%M')
27            else:
28                time = completion_time.strftime('%H:%M')
29        else:
30            if remaining_time >= 3600:
31                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
32            elif remaining_time >= 60:
33                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
34            else:
35                time = '%ds'%remaining_time
36        chisq = "%.3g"%(2*history.value[0]/dof)
37        step = "%d of %d"%(history.step[0], max_step)
38        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
39        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
40                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
41        self.msg = "".join([header]+parameters)
42
43    def __str__(self):
44        return self.msg
45
46
47class BumpsMonitor(object):
48    def __init__(self, handler, max_step, pars, dof):
49        self.handler = handler
50        self.max_step = max_step
51        self.pars = pars
52        self.dof = dof
53
54    def config_history(self, history):
55        history.requires(time=1, value=2, point=1, step=1)
56
57    def __call__(self, history):
58        if self.handler is None: return
59        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
60        self.handler.progress(history.step[0], self.max_step)
61        if len(history.step)>1 and history.step[1] > history.step[0]:
62            self.handler.improvement()
63        self.handler.update_fit()
64
65class ConvergenceMonitor(object):
66    """
67    ConvergenceMonitor contains population summary statistics to show progress
68    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
69    just a list [ (best, ) ] if population size is 1.
70    """
71    def __init__(self):
72        self.convergence = []
73
74    def config_history(self, history):
75        history.requires(value=1, population_values=1)
76
77    def __call__(self, history):
78        best = history.value[0]
79        try:
80            p = history.population_values[0]
81            n,p = len(p), numpy.sort(p)
82            QI,Qmid, = int(0.2*n),int(0.5*n)
83            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
84        except:
85            self.convergence.append((best, best,best,best,best,best))
86
87
88# Note: currently using bumps parameters for each parameter object so that
89# a SasFitness can be used directly in bumps with the usual semantics.
90# The disadvantage of this technique is that we need to copy every parameter
91# back into the model each time the function is evaluated.  We could instead
92# define reference parameters for each sans parameter, but then we would not
93# be able to express constraints using python expressions in the usual way
94# from bumps, and would instead need to use string expressions.
95class SasFitness(object):
96    """
97    Wrap SAS model as a bumps fitness object
98    """
99    def __init__(self, model, data, fitted=[], constraints={},
100                 initial_values=None, **kw):
101        self.name = model.name
102        self.model = model.model
103        self.data = data
104        self._define_pars()
105        self._init_pars(kw)
106        if initial_values is not None:
107            self._reset_pars(fitted, initial_values)
108        self.constraints = dict(constraints)
109        self.set_fitted(fitted)
110        self.update()
111
112    def _reset_pars(self, names, values):
113        for k,v in zip(names, values):
114            self._pars[k].value = v
115
116    def _define_pars(self):
117        self._pars = {}
118        for k in self.model.getParamList():
119            name = ".".join((self.name,k))
120            value = self.model.getParam(k)
121            bounds = self.model.details.get(k,["",None,None])[1:3]
122            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
123                                                fixed=True, name=name)
124        #print parameter.summarize(self._pars.values())
125
126    def _init_pars(self, kw):
127        for k,v in kw.items():
128            # dispersion parameters initialized with _field instead of .field
129            if k.endswith('_width'): k = k[:-6]+'.width'
130            elif k.endswith('_npts'): k = k[:-5]+'.npts'
131            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
132            elif k.endswith('_type'): k = k[:-5]+'.type'
133            if k not in self._pars:
134                formatted_pars = ", ".join(sorted(self._pars.keys()))
135                raise KeyError("invalid parameter %r for %s--use one of: %s"
136                               %(k, self.model, formatted_pars))
137            if '.' in k and not k.endswith('.width'):
138                self.model.setParam(k, v)
139            elif isinstance(v, parameter.BaseParameter):
140                self._pars[k] = v
141            elif isinstance(v, (tuple,list)):
142                low, high = v
143                self._pars[k].value = (low+high)/2
144                self._pars[k].range(low,high)
145            else:
146                self._pars[k].value = v
147
148    def set_fitted(self, param_list):
149        """
150        Flag a set of parameters as fitted parameters.
151        """
152        for k,p in self._pars.items():
153            p.fixed = (k not in param_list or k in self.constraints)
154        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
155        self.computed_par_names = [k for k in param_list if k in self.constraints]
156        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
157        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
158
159    # ===== Fitness interface ====
160    def parameters(self):
161        return self._pars
162
163    def update(self):
164        for k,v in self._pars.items():
165            #print "updating",k,v,v.value
166            self.model.setParam(k,v.value)
167        self._dirty = True
168
169    def _recalculate(self):
170        if self._dirty:
171            self._residuals, self._theory = self.data.residuals(self.model.evalDistribution)
172            self._dirty = False
173
174    def numpoints(self):
175        return numpy.sum(self.data.idx) # number of fitted points
176
177    def nllf(self):
178        return 0.5*numpy.sum(self.residuals()**2)
179
180    def theory(self):
181        self._recalculate()
182        return self._theory
183
184    def residuals(self):
185        self._recalculate()
186        return self._residuals
187
188    # Not implementing the data methods for now:
189    #
190    #     resynth_data/restore_data/save/plot
191
192class BumpsFit(FitEngine):
193    """
194    Fit a model using bumps.
195    """
196    def __init__(self):
197        """
198        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
199        with Uid as keys
200        """
201        FitEngine.__init__(self)
202        self.curr_thread = None
203
204    def fit(self, msg_q=None,
205            q=None, handler=None, curr_thread=None,
206            ftol=1.49012e-8, reset_flag=False):
207        # Build collection of bumps fitness calculators
208        models = [SasFitness(model=M.get_model(),
209                             data=M.get_data(),
210                             constraints=M.constraints,
211                             fitted=M.pars,
212                             initial_values=M.vals if reset_flag else None)
213                  for M in self.fit_arrange_dict.values()
214                  if M.get_to_fit()]
215        if len(models) == 0:
216            raise RuntimeError("Nothing to fit")
217        problem = FitProblem(models)
218
219
220        # Build constraint expressions
221        exprs = {}
222        for M in models:
223            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
224        if exprs:
225            symtab = dict((".".join((M.name, k)), p)
226                          for M in models
227                          for k,p in M.parameters().items())
228            constraints = compile_constraints(symtab, exprs)
229        else:
230            constraints = lambda: 0
231
232        # Override model update so that parameter constraints are applied
233        problem._model_update = problem.model_update
234        def model_update():
235            constraints()
236            problem._model_update()
237        problem.model_update = model_update
238
239        # Run the fit
240        result = run_bumps(problem, handler, curr_thread)
241        if handler is not None:
242            handler.update_fit(last=True)
243
244        # TODO: shouldn't reference internal parameters of fit problem
245        varying = problem._parameters
246        # collect the results
247        all_results = []
248        for M in problem.models:
249            fitness = M.fitness
250            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
251            R = FResult(model=fitness.model, data=fitness.data,
252                        param_list=fitness.fitted_par_names+fitness.computed_par_names)
253            R.theory = fitness.theory()
254            R.residuals = fitness.residuals()
255            R.index = fitness.data.idx
256            R.fitter_id = self.fitter_id
257            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
258            R.stderr = numpy.hstack((result['stderr'][fitted_index],
259                                     numpy.NaN*numpy.ones(len(fitness.computed_pars))))
260            R.pvec = numpy.hstack((result['value'][fitted_index],
261                                  [p.value for p in fitness.computed_pars]))
262            R.success = result['success']
263            R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
264            R.convergence = result['convergence']
265            if result['uncertainty'] is not None:
266                R.uncertainty_state = result['uncertainty']
267            all_results.append(R)
268
269        if q is not None:
270            q.put(all_results)
271            return q
272        else:
273            return all_results
274
275def run_bumps(problem, handler, curr_thread):
276    def abort_test():
277        if curr_thread is None: return False
278        try: curr_thread.isquit()
279        except KeyboardInterrupt:
280            if handler is not None:
281                handler.stop("Fitting: Terminated!!!")
282            return True
283        return False
284
285    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
286    fitclass = fitopts.fitclass
287    options = fitopts.options.copy()
288    max_step = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0)
289    pars = [p.name for p in problem._parameters]
290    options['monitors'] = [
291        BumpsMonitor(handler, max_step, pars, problem.dof),
292        ConvergenceMonitor(),
293        ]
294    fitdriver = fitters.FitDriver(fitclass, problem=problem,
295                                  abort_test=abort_test, **options)
296    mapper = SerialMapper
297    fitdriver.mapper = mapper.start_mapper(problem, None)
298    #import time; T0 = time.time()
299    try:
300        best, fbest = fitdriver.fit()
301    except:
302        import traceback; traceback.print_exc()
303        raise
304    finally:
305        mapper.stop_mapper(fitdriver.mapper)
306
307
308    convergence_list = options['monitors'][-1].convergence
309    convergence = (2*numpy.asarray(convergence_list)/problem.dof
310                   if convergence_list else numpy.empty((0,1),'d'))
311    return {
312        'value': best,
313        'stderr': fitdriver.stderr(),
314        'success': True, # better success reporting in bumps
315        'convergence': convergence,
316        'uncertainty': getattr(fitdriver.fitter, 'state', None),
317        }
318
Note: See TracBrowser for help on using the repository browser.