source: sasview/src/sans/fit/BumpsFitting.py @ fb7180c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since fb7180c was fb7180c, checked in by pkienzle, 10 years ago

enable park constrained fit test

  • Property mode set to 100644
File size: 12.5 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import time
5
6import numpy
7
8from bumps import fitters
9from bumps.mapper import SerialMapper
10
11from sans.fit.AbstractFitEngine import FitEngine
12from sans.fit.AbstractFitEngine import FResult
13
14class BumpsMonitor(object):
15    def __init__(self, handler, max_step=0):
16        self.handler = handler
17        self.max_step = max_step
18
19    def config_history(self, history):
20        history.requires(time=1, value=2, point=1, step=1)
21
22    def __call__(self, history):
23        self.handler.progress(history.step[0], self.max_step)
24        if len(history.step)>1 and history.step[1] > history.step[0]:
25            self.handler.improvement()
26        self.handler.update_fit()
27
28class ConvergenceMonitor(object):
29    """
30    ConvergenceMonitor contains population summary statistics to show progress
31    of the fit.  This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or
32    just a list [ (best, ) ] if population size is 1.
33    """
34    def __init__(self):
35        self.convergence = []
36
37    def config_history(self, history):
38        history.requires(value=1, population_values=1)
39
40    def __call__(self, history):
41        best = history.value[0]
42        try:
43            p = history.population_values[0]
44            n,p = len(p), numpy.sort(p)
45            QI,Qmid, = int(0.2*n),int(0.5*n)
46            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1]))
47        except:
48            self.convergence.append((best, ))
49
50class SasProblem(object):
51    """
52    Wrap the SAS model in a form that can be understood by bumps.
53    """
54    def __init__(self, param_list, model=None, data=None, fitresult=None,
55                 handler=None, curr_thread=None, msg_q=None):
56        """
57        :param Model: the model wrapper fro sans -model
58        :param Data: the data wrapper for sans data
59        """
60        self.model = model
61        self.data = data
62        self.param_list = param_list
63        self.res = None
64        self.theory = None
65
66    @property
67    def name(self):
68        return self.model.name
69
70    @property
71    def dof(self):
72        return self.data.num_points - len(self.param_list)
73
74    def summarize(self):
75        """
76        Return a stylized list of parameter names and values with range bars
77        suitable for printing.
78        """
79        output = []
80        bounds = self.bounds()
81        for i,p in enumerate(self.getp()):
82            name = self.param_list[i]
83            low,high = bounds[:,i]
84            range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"),
85                              ("%g]"%high if numpy.isfinite(high) else "inf)")))
86            if not numpy.isfinite(p):
87                bar = "*invalid* "
88            else:
89                bar = ['.']*10
90                if numpy.isfinite(high-low):
91                    position = int(9.999999999 * float(p-low)/float(high-low))
92                    if position < 0: bar[0] = '<'
93                    elif position > 9: bar[9] = '>'
94                    else: bar[position] = '|'
95                bar = "".join(bar)
96            output.append("%40s %s %10g in %s"%(name,bar,p,range))
97        return "\n".join(output)
98
99    def nllf(self, p=None):
100        residuals = self.residuals(p)
101        return 0.5*numpy.sum(residuals**2)
102
103    def setp(self, p):
104        for k,v in zip(self.param_list, p):
105            self.model.setParam(k,v)
106        #self.model.set_params(self.param_list, params)
107
108    def getp(self):
109        return numpy.array([self.model.getParam(k) for k in self.param_list])
110        #return numpy.asarray(self.model.get_params(self.param_list))
111
112    def bounds(self):
113        return numpy.array([self._getrange(p) for p in self.param_list]).T
114
115    def labels(self):
116        return self.param_list
117
118    def _getrange(self, p):
119        """
120        Override _getrange of park parameter
121        return the range of parameter
122        """
123        lo, hi = self.model.details[p][1:3]
124        if lo is None: lo = -numpy.inf
125        if hi is None: hi = numpy.inf
126        return lo, hi
127
128    def randomize(self, n):
129        p = self.getp()
130        # since randn is symmetric and random, doesn't matter
131        # point value is negative.
132        # TODO: throw in bounds checking!
133        return numpy.random.randn(n, len(self.param_list))*p + p
134
135    def chisq(self):
136        """
137        Calculates chi^2
138
139        :param params: list of parameter values
140
141        :return: chi^2
142
143        """
144        return numpy.sum(self.res**2)/self.dof
145
146    def residuals(self, params=None):
147        """
148        Compute residuals
149        :param params: value of parameters to fit
150        """
151        if params is not None: self.setp(params)
152        #import thread
153        #print "params", params
154        self.res, self.theory = self.data.residuals(self.model.evalDistribution)
155        return self.res
156
157BOUNDS_PENALTY = 1e6 # cost for going out of bounds on unbounded fitters
158class MonitoredSasProblem(SasProblem):
159    """
160    SAS problem definition for optimizers which do not have monitoring or bounds.
161    """
162    def __init__(self, param_list, model=None, data=None, fitresult=None,
163                 handler=None, curr_thread=None, msg_q=None, update_rate=1):
164        """
165        :param Model: the model wrapper fro sans -model
166        :param Data: the data wrapper for sans data
167        """
168        SasProblem.__init__(self, param_list, model, data)
169        self.msg_q = msg_q
170        self.curr_thread = curr_thread
171        self.handler = handler
172        self.fitresult = fitresult
173        #self.last_update = time.time()
174        #self.func_name = "Functor"
175        #self.name = "Fill in proper name!"
176
177    def residuals(self, p):
178        """
179        Cost function for scipy.optimize.leastsq, which does not have a monitor
180        built into the algorithm, and instead relies on a monitor built into
181        the cost function.
182        """
183        # Note: technically, unbounded fitters and unmonitored fitters are
184        self.setp(p)
185
186        # Compute penalty for being out of bounds which increases the farther
187        # you get out of bounds.  This allows derivative following algorithms
188        # to point back toward the feasible region.
189        penalty = self.bounds_penalty()
190        if penalty > 0:
191            self.theory = numpy.ones(self.data.num_points)
192            self.res = self.theory*(penalty/self.data.num_points) + BOUNDS_PENALTY
193            return self.res
194
195        # If no penalty, then we are not out of bounds and we can use the
196        # normal residual calculation
197        SasProblem.residuals(self, p)
198
199        # send update to the application
200        if True:
201            #self.fitresult.set_model(model=self.model)
202            # copy residuals into fit results
203            self.fitresult.residuals = self.res+0
204            self.fitresult.iterations += 1
205            self.fitresult.theory = self.theory+0
206
207            self.fitresult.p = numpy.array(p) # force copy, and coversion to array
208            self.fitresult.set_fitness(fitness=self.chisq())
209            if self.msg_q is not None:
210                self.msg_q.put(self.fitresult)
211
212            if self.handler is not None:
213                self.handler.set_result(result=self.fitresult)
214                self.handler.update_fit()
215
216            if self.curr_thread != None:
217                try:
218                    self.curr_thread.isquit()
219                except:
220                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
221                    #msg += "fitting may cause a 'Functor error message' "
222                    #msg += "being recorded in the log file....."
223                    #self.handler.stop(msg)
224                    raise
225
226        return self.res
227
228    def bounds_penalty(self):
229        from numpy import sum, where
230        p, bounds = self.getp(), self.bounds()
231        return (sum(where(p<bounds[:,0], bounds[:,0]-p, 0)**2)
232              + sum(where(p>bounds[:,1], bounds[:,1]-p, 0)**2) )
233
234class BumpsFit(FitEngine):
235    """
236    Fit a model using bumps.
237    """
238    def __init__(self):
239        """
240        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
241        with Uid as keys
242        """
243        FitEngine.__init__(self)
244        self.curr_thread = None
245
246    def fit(self, msg_q=None,
247            q=None, handler=None, curr_thread=None,
248            ftol=1.49012e-8, reset_flag=False):
249        """
250        """
251        fitproblem = []
252        for fproblem in self.fit_arrange_dict.itervalues():
253            if fproblem.get_to_fit() == 1:
254                fitproblem.append(fproblem)
255        if len(fitproblem) > 1 :
256            msg = "Bumps can't fit more than a single fit problem at a time."
257            raise RuntimeError, msg
258        elif len(fitproblem) == 0 :
259            raise RuntimeError, "No problem scheduled for fitting."
260        model = fitproblem[0].get_model()
261        if reset_flag:
262            # reset the initial value; useful for batch
263            for name in fitproblem[0].pars:
264                ind = fitproblem[0].pars.index(name)
265                model.setParam(name, fitproblem[0].vals[ind])
266        data = fitproblem[0].get_data()
267
268        self.curr_thread = curr_thread
269
270        result = FResult(model=model, data=data, param_list=self.param_list)
271        result.pars = fitproblem[0].pars
272        result.fitter_id = self.fitter_id
273        result.index = data.idx
274        if handler is not None:
275            handler.set_result(result=result)
276
277        if True: # bumps
278            problem = SasProblem(param_list=self.param_list,
279                                 model=model.model,
280                                 data=data)
281            run_bumps(problem, result, ftol,
282                      handler, curr_thread, msg_q)
283        else: # scipy levenburg marquardt
284            problem = SasProblem(param_list=self.param_list,
285                                 model=model.model,
286                                 data=data,
287                                 handler=handler,
288                                 fitresult=result,
289                                 curr_thread=curr_thread,
290                                 msg_q=msg_q)
291            run_levenburg_marquardt(problem, result, ftol)
292
293        if handler is not None:
294            handler.update_fit(last=True)
295        if q is not None:
296            q.put(result)
297            return q
298        #if success < 1 or success > 5:
299        #    result.fitness = None
300        return [result]
301
302def run_bumps(problem, result, ftol, handler, curr_thread, msg_q):
303    def abort_test():
304        if curr_thread is None: return False
305        try: curr_thread.isquit()
306        except KeyboardInterrupt:
307            if handler is not None:
308                handler.stop("Fitting: Terminated!!!")
309            return True
310        return False
311
312    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
313    fitclass = fitopts.fitclass
314    options = fitopts.options.copy()
315    max_steps = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0)
316    if 'monitors' not in options:
317        options['monitors'] = [BumpsMonitor(handler, max_steps)]
318    options['monitors'] += [ ConvergenceMonitor() ]
319    options['ftol'] = ftol
320    fitdriver = fitters.FitDriver(fitclass, problem=problem,
321                                  abort_test=abort_test, **options)
322    mapper = SerialMapper
323    fitdriver.mapper = mapper.start_mapper(problem, None)
324    try:
325        best, fbest = fitdriver.fit()
326    except:
327        import traceback; traceback.print_exc()
328        raise
329    finally:
330        mapper.stop_mapper(fitdriver.mapper)
331    #print "best,fbest",best,fbest,problem.dof
332    result.fitness = 2*fbest/problem.dof
333    #print "fitness",result.fitness
334    result.stderr  = fitdriver.stderr()
335    result.pvec = best
336    # TODO: track success better
337    result.success = True
338    result.theory = problem.theory
339    # For the convergence plot
340    pop = numpy.asarray(options['monitors'][-1].convergence)
341    result.convergence = 2*pop/problem.dof
342    # Bumps uncertainty state
343    try: result.uncertainty_state = fitdriver.fitter.state
344    except AttributeError: pass
345
346def run_levenburg_marquardt(problem, result, ftol):
347    # This import must be here; otherwise it will be confused when more
348    # than one thread exist.
349    from scipy import optimize
350
351    out, cov_x, _, mesg, success = optimize.leastsq(problem.residuals,
352                                                    problem.getp(),
353                                                    ftol=ftol,
354                                                    full_output=1)
355    if cov_x is not None and numpy.isfinite(cov_x).all():
356        stderr = numpy.sqrt(numpy.diag(cov_x))
357    else:
358        stderr = []
359    result.fitness = problem.chisq()
360    result.stderr  = stderr
361    result.pvec = out
362    result.success = success
363    result.theory = problem.theory
364
Note: See TracBrowser for help on using the repository browser.