source: sasview/src/sans/fit/BumpsFitting.py @ 5044543

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 5044543 was 5044543, checked in by pkienzle, 10 years ago

fix batch fit support for bumps

  • Property mode set to 100644
File size: 10.2 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import numpy
5
6from bumps import fitters
7from bumps.mapper import SerialMapper
8from bumps import parameter
9from bumps.fitproblem import FitProblem
10
11from .AbstractFitEngine import FitEngine
12from .AbstractFitEngine import FResult
13from .expression import compile_constraints
14
15class BumpsMonitor(object):
16    def __init__(self, handler, max_step=0):
17        self.handler = handler
18        self.max_step = max_step
19
20    def config_history(self, history):
21        history.requires(time=1, value=2, point=1, step=1)
22
23    def __call__(self, history):
24        if self.handler is None: return
25        self.handler.progress(history.step[0], self.max_step)
26        if len(history.step)>1 and history.step[1] > history.step[0]:
27            self.handler.improvement()
28        self.handler.update_fit()
29
30class ConvergenceMonitor(object):
31    """
32    ConvergenceMonitor contains population summary statistics to show progress
33    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
34    just a list [ (best, ) ] if population size is 1.
35    """
36    def __init__(self):
37        self.convergence = []
38
39    def config_history(self, history):
40        history.requires(value=1, population_values=1)
41
42    def __call__(self, history):
43        best = history.value[0]
44        try:
45            p = history.population_values[0]
46            n,p = len(p), numpy.sort(p)
47            QI,Qmid, = int(0.2*n),int(0.5*n)
48            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
49        except:
50            self.convergence.append((best, best,best,best,best,best))
51
52
53# Note: currently using bumps parameters for each parameter object so that
54# a SasFitness can be used directly in bumps with the usual semantics.
55# The disadvantage of this technique is that we need to copy every parameter
56# back into the model each time the function is evaluated.  We could instead
57# define reference parameters for each sans parameter, but then we would not
58# be able to express constraints using python expressions in the usual way
59# from bumps, and would instead need to use string expressions.
60class SasFitness(object):
61    """
62    Wrap SAS model as a bumps fitness object
63    """
64    def __init__(self, model, data, fitted=[], constraints={},
65                 initial_values=None, **kw):
66        self.name = model.name
67        self.model = model.model
68        self.data = data
69        self._define_pars()
70        self._init_pars(kw)
71        if initial_values is not None:
72            self._reset_pars(fitted, initial_values)
73        self.constraints = dict(constraints)
74        self.set_fitted(fitted)
75        self.update()
76
77    def _reset_pars(self, names, values):
78        for k,v in zip(names, values):
79            self._pars[k].value = v
80
81    def _define_pars(self):
82        self._pars = {}
83        for k in self.model.getParamList():
84            name = ".".join((self.name,k))
85            value = self.model.getParam(k)
86            bounds = self.model.details.get(k,["",None,None])[1:3]
87            self._pars[k] = parameter.Parameter(value=value, bounds=bounds,
88                                                fixed=True, name=name)
89        #print parameter.summarize(self._pars.values())
90
91    def _init_pars(self, kw):
92        for k,v in kw.items():
93            # dispersion parameters initialized with _field instead of .field
94            if k.endswith('_width'): k = k[:-6]+'.width'
95            elif k.endswith('_npts'): k = k[:-5]+'.npts'
96            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas'
97            elif k.endswith('_type'): k = k[:-5]+'.type'
98            if k not in self._pars:
99                formatted_pars = ", ".join(sorted(self._pars.keys()))
100                raise KeyError("invalid parameter %r for %s--use one of: %s"
101                               %(k, self.model, formatted_pars))
102            if '.' in k and not k.endswith('.width'):
103                self.model.setParam(k, v)
104            elif isinstance(v, parameter.BaseParameter):
105                self._pars[k] = v
106            elif isinstance(v, (tuple,list)):
107                low, high = v
108                self._pars[k].value = (low+high)/2
109                self._pars[k].range(low,high)
110            else:
111                self._pars[k].value = v
112
113    def set_fitted(self, param_list):
114        """
115        Flag a set of parameters as fitted parameters.
116        """
117        for k,p in self._pars.items():
118            p.fixed = (k not in param_list or k in self.constraints)
119        self.fitted_par_names = [k for k in param_list if k not in self.constraints]
120        self.computed_par_names = [k for k in param_list if k in self.constraints]
121        self.fitted_pars = [self._pars[k] for k in self.fitted_par_names]
122        self.computed_pars = [self._pars[k] for k in self.computed_par_names]
123
124    # ===== Fitness interface ====
125    def parameters(self):
126        return self._pars
127
128    def update(self):
129        for k,v in self._pars.items():
130            #print "updating",k,v,v.value
131            self.model.setParam(k,v.value)
132        self._dirty = True
133
134    def _recalculate(self):
135        if self._dirty:
136            self._residuals, self._theory = self.data.residuals(self.model.evalDistribution)
137            self._dirty = False
138
139    def numpoints(self):
140        return numpy.sum(self.data.idx) # number of fitted points
141
142    def nllf(self):
143        return 0.5*numpy.sum(self.residuals()**2)
144
145    def theory(self):
146        self._recalculate()
147        return self._theory
148
149    def residuals(self):
150        self._recalculate()
151        return self._residuals
152
153    # Not implementing the data methods for now:
154    #
155    #     resynth_data/restore_data/save/plot
156
157class BumpsFit(FitEngine):
158    """
159    Fit a model using bumps.
160    """
161    def __init__(self):
162        """
163        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
164        with Uid as keys
165        """
166        FitEngine.__init__(self)
167        self.curr_thread = None
168
169    def fit(self, msg_q=None,
170            q=None, handler=None, curr_thread=None,
171            ftol=1.49012e-8, reset_flag=False):
172        # Build collection of bumps fitness calculators
173        models = [SasFitness(model=M.get_model(),
174                             data=M.get_data(),
175                             constraints=M.constraints,
176                             fitted=M.pars,
177                             initial_values=M.vals if reset_flag else None)
178                  for M in self.fit_arrange_dict.values()
179                  if M.get_to_fit()]
180        if len(models) == 0:
181            raise RuntimeError("Nothing to fit")
182        problem = FitProblem(models)
183
184
185        # Build constraint expressions
186        exprs = {}
187        for M in models:
188            exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items())
189        if exprs:
190            symtab = dict((".".join((M.name, k)), p)
191                          for M in models
192                          for k,p in M.parameters().items())
193            constraints = compile_constraints(symtab, exprs)
194        else:
195            constraints = lambda: 0
196
197        # Override model update so that parameter constraints are applied
198        problem._model_update = problem.model_update
199        def model_update():
200            constraints()
201            problem._model_update()
202        problem.model_update = model_update
203
204        # Run the fit
205        result = run_bumps(problem, handler, curr_thread)
206        if handler is not None:
207            handler.update_fit(last=True)
208
209        # TODO: shouldn't reference internal parameters of fit problem
210        varying = problem._parameters
211        # collect the results
212        all_results = []
213        for M in problem.models:
214            fitness = M.fitness
215            fitted_index = [varying.index(p) for p in fitness.fitted_pars]
216            R = FResult(model=fitness.model, data=fitness.data,
217                        param_list=fitness.fitted_par_names+fitness.computed_par_names)
218            R.theory = fitness.theory()
219            R.residuals = fitness.residuals()
220            R.index = fitness.data.idx
221            R.fitter_id = self.fitter_id
222            # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown
223            R.stderr = numpy.hstack((result['stderr'][fitted_index],
224                                     numpy.NaN*numpy.ones(len(fitness.computed_pars))))
225            R.pvec = numpy.hstack((result['value'][fitted_index],
226                                  [p.value for p in fitness.computed_pars]))
227            R.success = result['success']
228            R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index))
229            R.convergence = result['convergence']
230            if result['uncertainty'] is not None:
231                R.uncertainty_state = result['uncertainty']
232            all_results.append(R)
233
234        if q is not None:
235            q.put(all_results)
236            return q
237        else:
238            return all_results
239
240def run_bumps(problem, handler, curr_thread):
241    def abort_test():
242        if curr_thread is None: return False
243        try: curr_thread.isquit()
244        except KeyboardInterrupt:
245            if handler is not None:
246                handler.stop("Fitting: Terminated!!!")
247            return True
248        return False
249
250    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
251    fitclass = fitopts.fitclass
252    options = fitopts.options.copy()
253    max_step = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0)
254    options['monitors'] = [
255        BumpsMonitor(handler, max_step),
256        ConvergenceMonitor(),
257        ]
258    fitdriver = fitters.FitDriver(fitclass, problem=problem,
259                                  abort_test=abort_test, **options)
260    mapper = SerialMapper
261    fitdriver.mapper = mapper.start_mapper(problem, None)
262    #import time; T0 = time.time()
263    try:
264        best, fbest = fitdriver.fit()
265    except:
266        import traceback; traceback.print_exc()
267        raise
268    finally:
269        mapper.stop_mapper(fitdriver.mapper)
270
271
272    convergence_list = options['monitors'][-1].convergence
273    convergence = (2*numpy.asarray(convergence_list)/problem.dof
274                   if convergence_list else numpy.empty((0,1),'d'))
275    return {
276        'value': best,
277        'stderr': fitdriver.stderr(),
278        'success': True, # better success reporting in bumps
279        'convergence': convergence,
280        'uncertainty': getattr(fitdriver.fitter, 'state', None),
281        }
282
Note: See TracBrowser for help on using the repository browser.