source: sasview/src/sas/fit/BumpsFitting.py @ 573e7034

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 573e7034 was fd5ac0d, checked in by krzywon, 10 years ago

I have completed the removal of all SANS references.
I will build, run, and run all unit tests before pushing.

  • Property mode set to 100644
File size: 12.4 KB
Rev Line
[6fe5100]1"""
2BumpsFitting module runs the bumps optimizer.
3"""
[249a7c6]4import os
[35086c3]5from datetime import timedelta, datetime
6
[6fe5100]7import numpy
8
9from bumps import fitters
[249a7c6]10from bumps.mapper import SerialMapper, MPMapper
[e3efa6b3]11from bumps import parameter
12from bumps.fitproblem import FitProblem
[6fe5100]13
[fd5ac0d]14from sas.fit.AbstractFitEngine import FitEngine
15from sas.fit.AbstractFitEngine import FResult
16from sas.fit.expression import compile_constraints
[6fe5100]17
[35086c3]18class Progress(object):
19    def __init__(self, history, max_step, pars, dof):
20        remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1))
21        # Depending on the time remaining, either display the expected
22        # time of completion, or the amount of time remaining.  Use precision
23        # appropriate for the duration.
24        if remaining_time >= 1800:
25            completion_time = datetime.now() + timedelta(seconds=remaining_time)
26            if remaining_time >= 36000:
27                time = completion_time.strftime('%Y-%m-%d %H:%M')
28            else:
29                time = completion_time.strftime('%H:%M')
30        else:
31            if remaining_time >= 3600:
32                time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60)
33            elif remaining_time >= 60:
34                time = '%dm %ds'%(remaining_time//60, remaining_time%60)
35            else:
36                time = '%ds'%remaining_time
37        chisq = "%.3g"%(2*history.value[0]/dof)
38        step = "%d of %d"%(history.step[0], max_step)
39        header = "=== Steps: %s  chisq: %s  ETA: %s\n"%(step, chisq, time)
40        parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | "))
41                      for i,(k,v) in enumerate(zip(pars,history.point[0]))]
42        self.msg = "".join([header]+parameters)
43
44    def __str__(self):
45        return self.msg
46
47
class BumpsMonitor(object):
    """
    Bumps monitor that forwards fit progress to a GUI handler.

    On every update it publishes a Progress summary, reports the step
    count, signals an improvement when the objective value decreased since
    the previous update, and asks the handler to refresh the display.
    When *handler* is None the monitor does nothing.
    """
    def __init__(self, handler, max_step, pars, dof):
        self.handler = handler
        self.max_step = max_step
        self.pars = pars
        self.dof = dof

    def config_history(self, history):
        # value=2 keeps the current and previous objective values so that
        # _improved() can compare them.
        history.requires(time=1, value=2, point=1, step=1)

    @staticmethod
    def _improved(history):
        """
        Return True when the newest value is lower than the previous one.

        history.value[0] is the newest entry and history.value[1] the one
        before it (lower is better: Progress reports 2*value/dof as chisq).
        """
        values = history.value
        return len(values) > 1 and values[1] > values[0]

    def __call__(self, history):
        if self.handler is None:
            return
        self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof))
        self.handler.progress(history.step[0], self.max_step)
        # Previously this compared history.step, but only one step is
        # retained and steps only increase, so improvement() never fired.
        if self._improved(history):
            self.handler.improvement()
        self.handler.update_fit()
65
class ConvergenceMonitor(object):
    """
    ConvergenceMonitor contains population summary statistics to show progress
    of the fit.  This is a list [ (best, 0%, 20%, 50%, 80%, 100%) ] with one
    entry per monitor update, the percentages being quantiles of the sorted
    population values.  When the fitter provides no usable population (no
    population_values, or an empty one) every slot of the entry is best, so
    all entries have the same length.
    """
    def __init__(self):
        self.convergence = []

    def config_history(self, history):
        history.requires(value=1, population_values=1)

    def __call__(self, history):
        best = history.value[0]
        try:
            values = history.population_values[0]
            n = len(values)
            ordered = numpy.sort(values)
            q20, q50 = int(0.2*n), int(0.5*n)
            row = (best, ordered[0], ordered[q20], ordered[q50],
                   ordered[-1-q20], ordered[-1])
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are no longer swallowed; missing or empty populations still
            # fall back to repeating best.
            row = (best, best, best, best, best, best)
        self.convergence.append(row)
[ed4aef2]87
[e3efa6b3]88
# Note: currently using bumps parameters for each parameter object so that
# a SasFitness can be used directly in bumps with the usual semantics.
# The disadvantage of this technique is that we need to copy every parameter
# back into the model each time the function is evaluated.  We could instead
# define reference parameters for each sas parameter, but then we would not
# be able to express constraints using python expressions in the usual way
# from bumps, and would instead need to use string expressions.
class SasFitness(object):
    """
    Wrap SAS model as a bumps fitness object.

    :param model: wrapper exposing ``name`` and the underlying ``model``
        (queried via getParamList, getParam, setParam, details and
        evalDistribution).
    :param data: data object providing ``residuals(fn)`` returning
        ``(residuals, theory)`` and an ``idx`` selection of fitted points.
    :param fitted: names of the parameters to fit (default: none).
    :param constraints: mapping of parameter name -> constraint expression;
        constrained parameters are computed, not fitted.
    :param initial_values: optional values for *fitted*, applied in order.
    :param kw: initial parameter settings; see _init_pars.
    """
    # Underscore spellings accepted for dispersion keywords, and the dotted
    # parameter names they refer to.
    _DISPERSION_SUFFIXES = (
        ('_width', '.width'),
        ('_npts', '.npts'),
        ('_nsigmas', '.nsigmas'),
        ('_type', '.type'),
    )

    def __init__(self, model, data, fitted=None, constraints=None,
                 initial_values=None, **kw):
        # None defaults avoid sharing one mutable default object across calls.
        if fitted is None:
            fitted = []
        if constraints is None:
            constraints = {}
        self.name = model.name
        self.model = model.model
        self.data = data
        self._define_pars()
        self._init_pars(kw)
        if initial_values is not None:
            self._reset_pars(fitted, initial_values)
        self.constraints = dict(constraints)
        self.set_fitted(fitted)
        self.update()

    def _reset_pars(self, names, values):
        """Assign values to the named parameters, pairwise (extra items ignored)."""
        for k, v in zip(names, values):
            self._pars[k].value = v

    def _define_pars(self):
        """
        Create a fixed bumps Parameter for every model parameter, named
        "<model name>.<parameter>" with bounds taken from details[k][1:3].
        """
        self._pars = {}
        for k in self.model.getParamList():
            name = ".".join((self.name, k))
            value = self.model.getParam(k)
            bounds = self.model.details.get(k, ["", None, None])[1:3]
            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
                                                fixed=True, name=name)

    def _init_pars(self, kw):
        """
        Apply keyword initial settings to the parameters.

        Dispersion keywords spelled with an underscore (``radius_width``)
        are mapped to the dotted parameter name (``radius.width``).  Dotted
        settings other than ``.width`` go straight to the model via
        setParam.  Otherwise a value may be a bumps parameter (replacing the
        default one), a ``(low, high)`` pair (sets the range and starts at
        the midpoint), or a plain value.

        :raises KeyError: for a name that is not a model parameter.
        """
        for k, v in kw.items():
            for suffix, dotted in self._DISPERSION_SUFFIXES:
                if k.endswith(suffix):
                    # Slice by the actual suffix length; the old hard-coded
                    # k[:-7] for the 8-character '_nsigmas' left a stray '_'.
                    k = k[:-len(suffix)] + dotted
                    break
            if k not in self._pars:
                formatted_pars = ", ".join(sorted(self._pars.keys()))
                raise KeyError("invalid parameter %r for %s--use one of: %s"
                               %(k, self.model, formatted_pars))
            if '.' in k and not k.endswith('.width'):
                self.model.setParam(k, v)
            elif isinstance(v, parameter.BaseParameter):
                self._pars[k] = v
            elif isinstance(v, (tuple, list)):
                low, high = v
                # 2.0 keeps the midpoint exact for integer bounds on Python 2.
                self._pars[k].value = (low + high)/2.0
                self._pars[k].range(low, high)
            else:
                self._pars[k].value = v

    def set_fitted(self, param_list):
        """
        Flag a set of parameters as fitted parameters.

        Parameters in *param_list* that also have a constraint stay fixed
        and are listed as computed rather than fitted.
        """
        for k, p in self._pars.items():
            p.fixed = (k not in param_list or k in self.constraints)
        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
        self.computed_par_names = [k for k in param_list if k in self.constraints]
        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
        self.computed_pars = [self._pars[k] for k in self.computed_par_names]

    # ===== Fitness interface ====
    def parameters(self):
        """Return the name -> bumps Parameter mapping (the internal dict)."""
        return self._pars

    def update(self):
        """Copy every parameter value into the model and mark results stale."""
        for k, v in self._pars.items():
            self.model.setParam(k, v.value)
        self._dirty = True

    def _recalculate(self):
        # Re-evaluate residuals/theory only after update() changed the model.
        if self._dirty:
            self._residuals, self._theory = self.data.residuals(self.model.evalDistribution)
            self._dirty = False

    def numpoints(self):
        """Number of fitted points (sum of data.idx)."""
        return numpy.sum(self.data.idx)

    def nllf(self):
        """Negative log likelihood: half the sum of squared residuals."""
        return 0.5*numpy.sum(self.residuals()**2)

    def theory(self):
        self._recalculate()
        return self._theory

    def residuals(self):
        self._recalculate()
        return self._residuals

    # Not implementing the data methods for now:
    #
    #     resynth_data/restore_data/save/plot
[6fe5100]192
class ParameterExpressions(object):
    """
    Callable that applies the constraint expressions of a set of models.

    Every model contributes its ``constraints`` mapping; both the
    constrained names and the symbol table of parameters are qualified as
    "<model name>.<parameter>".  Only the model list is pickled; the
    compiled updater is rebuilt after unpickling.
    """
    def __init__(self, models):
        self.models = models
        self._setup()

    def _setup(self):
        expressions = {}
        for fitness in self.models:
            for key, expr in fitness.constraints.items():
                expressions[".".join((fitness.name, key))] = expr

        if not expressions:
            # No constraints anywhere: the updater is a no-op.
            self.update = lambda: 0
            return

        symbols = {}
        for fitness in self.models:
            for key, par in fitness.parameters().items():
                symbols[".".join((fitness.name, key))] = par
        self.update = compile_constraints(symbols, expressions)

    def __call__(self):
        self.update()

    def __getstate__(self):
        # The compiled updater is not picklable; keep just the models.
        return self.models

    def __setstate__(self, state):
        self.models = state
        self._setup()
219
class BumpsFit(FitEngine):
    """
    Fit a model using bumps.
    """
    def __init__(self):
        """
        Initialize the engine.  FitEngine.__init__ presumably creates the
        dictionary self.fit_arrange_dict of FitArrange elements (keyed by
        Uid) that fit() iterates over.
        """
        FitEngine.__init__(self)
        self.curr_thread = None

    def fit(self, msg_q=None,
            q=None, handler=None, curr_thread=None,
            ftol=1.49012e-8, reset_flag=False):
        """
        Fit every FitArrange whose get_to_fit() is true, jointly, with bumps.

        :param msg_q: unused; kept for interface compatibility.
        :param q: optional queue; when given the result list is put on it
            and *q* itself is returned.
        :param handler: optional GUI handler for progress updates.
        :param curr_thread: optional thread; the fit is aborted when its
            isquit() raises KeyboardInterrupt.
        :param ftol: unused; kept for interface compatibility.
        :param reset_flag: when True, start from the stored values M.vals.
        :return: list of FResult, one per fitted model (or *q*, see above).
        :raises RuntimeError: if no model is selected for fitting.
        """
        # Build collection of bumps fitness calculators
        models = [SasFitness(model=M.get_model(),
                             data=M.get_data(),
                             constraints=M.constraints,
                             fitted=M.pars,
                             initial_values=M.vals if reset_flag else None)
                  for M in self.fit_arrange_dict.values()
                  if M.get_to_fit()]
        if not models:
            raise RuntimeError("Nothing to fit")
        problem = FitProblem(models)

        # TODO: need better handling of parameter expressions and bounds constraints
        # so that they are applied during polydispersity calculations.  This
        # will remove the immediate need for the setp_hook in bumps, though
        # bumps may still need something similar, such as a sane class structure
        # which allows a subclass to override setp.
        problem.setp_hook = ParameterExpressions(models)

        # Run the fit
        result = run_bumps(problem, handler, curr_thread)
        if handler is not None:
            handler.update_fit(last=True)

        # TODO: shouldn't reference internal parameters of fit problem
        varying = problem._parameters
        # collect the results
        all_results = []
        for M in problem.models:
            fitness = M.fitness
            # Position of each fitted parameter in the problem-wide vectors.
            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
            n_computed = len(fitness.computed_pars)
            R = FResult(model=fitness.model, data=fitness.data,
                        param_list=fitness.fitted_par_names+fitness.computed_par_names)
            R.theory = fitness.theory()
            R.residuals = fitness.residuals()
            R.index = fitness.data.idx
            R.fitter_id = self.fitter_id
            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
            # Computed (constrained) parameters get NaN stderr.  numpy.nan is
            # used because the numpy.NaN alias was removed in NumPy 2.0.
            R.stderr = numpy.hstack((result['stderr'][fitted_index],
                                     numpy.nan*numpy.ones(n_computed)))
            R.pvec = numpy.hstack((result['value'][fitted_index],
                                  [p.value for p in fitness.computed_pars]))
            R.success = result['success']
            R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
            R.convergence = result['convergence']
            if result['uncertainty'] is not None:
                R.uncertainty_state = result['uncertainty']
            all_results.append(R)

        if q is not None:
            q.put(all_results)
            return q
        else:
            return all_results
[6fe5100]289
def run_bumps(problem, handler, curr_thread):
    """
    Run the default bumps fitter on *problem*.

    :param problem: bumps FitProblem to optimize.
    :param handler: optional GUI handler; receives progress through
        BumpsMonitor and a stop message when the fit is aborted.
    :param curr_thread: optional thread; the fit is aborted when its
        isquit() raises KeyboardInterrupt.
    :return: dict with 'value' (best point), 'stderr', 'success' (always
        True for now), 'convergence' (recorded statistics scaled by 2/dof,
        or an empty (0, 1) array) and 'uncertainty' (fitter state or None).
    """
    def abort_test():
        if curr_thread is None:
            return False
        try:
            curr_thread.isquit()
        except KeyboardInterrupt:
            if handler is not None:
                handler.stop("Fitting: Terminated!!!")
            return True
        return False

    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
    fitclass = fitopts.fitclass
    options = fitopts.options.copy()
    max_step = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0)
    pars = [p.name for p in problem._parameters]
    convergence_monitor = ConvergenceMonitor()
    options['monitors'] = [
        BumpsMonitor(handler, max_step, pars, problem.dof),
        convergence_monitor,
        ]
    fitdriver = fitters.FitDriver(fitclass, problem=problem,
                                  abort_test=abort_test, **options)

    # The multiprocessing mapper is used only when OMP_NUM_THREADS is exactly
    # 1.  NOTE(review): presumably to avoid oversubscribing cores when models
    # run OpenMP threads -- confirm.  The variable may be unset, empty, or a
    # nested list such as "4,2" (only the first level matters here); any
    # unparsable value falls back to serial evaluation instead of crashing.
    try:
        omp_threads = int(os.environ.get('OMP_NUM_THREADS', '0').split(',')[0])
    except ValueError:
        omp_threads = 0
    mapper = MPMapper if omp_threads == 1 else SerialMapper
    fitdriver.mapper = mapper.start_mapper(problem, None)
    try:
        best, _ = fitdriver.fit()
    except Exception:
        import traceback
        traceback.print_exc()
        raise
    finally:
        mapper.stop_mapper(fitdriver.mapper)

    convergence_list = convergence_monitor.convergence
    convergence = (2*numpy.asarray(convergence_list)/problem.dof
                   if convergence_list else numpy.empty((0, 1), 'd'))
    return {
        'value': best,
        'stderr': fitdriver.stderr(),
        'success': True, # better success reporting in bumps
        'convergence': convergence,
        'uncertainty': getattr(fitdriver.fitter, 'state', None),
        }
[6fe5100]334
Note: See TracBrowser for help on using the repository browser.