source: sasview/src/sans/fit/BumpsFitting.py @ 499639c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 499639c was 85f17f6, checked in by pkienzle, 11 years ago

put progress monitor on bumps

  • Property mode set to 100644
File size: 11.5 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import time
5
6import numpy
7
8from bumps import fitters
9from bumps.mapper import SerialMapper
10
11from sans.fit.AbstractFitEngine import FitEngine
12from sans.fit.AbstractFitEngine import FResult
13
14class BumpsMonitor(object):
15    def __init__(self, handler, max_step=0):
16        self.handler = handler
17        self.max_step = max_step
18    def config_history(self, history):
19        history.requires(time=1, value=2, point=1, step=1)
20    def __call__(self, history):
21        self.handler.progress(history.step[0], self.max_step)
22        if len(history.step)>1 and history.step[1] > history.step[0]:
23            self.handler.improvement()
24        self.handler.update_fit()
25
26class SasProblem(object):
27    """
28    Wrap the SAS model in a form that can be understood by bumps.
29    """
30    def __init__(self, param_list, model=None, data=None, fitresult=None,
31                 handler=None, curr_thread=None, msg_q=None):
32        """
33        :param Model: the model wrapper fro sans -model
34        :param Data: the data wrapper for sans data
35        """
36        self.model = model
37        self.data = data
38        self.param_list = param_list
39        self.res = None
40        self.theory = None
41
42    @property
43    def name(self):
44        return self.model.name
45
46    @property
47    def dof(self):
48        return self.data.num_points - len(self.param_list)
49
50    def summarize(self):
51        """
52        Return a stylized list of parameter names and values with range bars
53        suitable for printing.
54        """
55        output = []
56        bounds = self.bounds()
57        for i,p in enumerate(self.getp()):
58            name = self.param_list[i]
59            low,high = bounds[:,i]
60            range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"),
61                              ("%g]"%high if numpy.isfinite(high) else "inf)")))
62            if not numpy.isfinite(p):
63                bar = "*invalid* "
64            else:
65                bar = ['.']*10
66                if numpy.isfinite(high-low):
67                    position = int(9.999999999 * float(p-low)/float(high-low))
68                    if position < 0: bar[0] = '<'
69                    elif position > 9: bar[9] = '>'
70                    else: bar[position] = '|'
71                bar = "".join(bar)
72            output.append("%40s %s %10g in %s"%(name,bar,p,range))
73        return "\n".join(output)
74
75    def nllf(self, p=None):
76        residuals = self.residuals(p)
77        return 0.5*numpy.sum(residuals**2)
78
79    def setp(self, p):
80        for k,v in zip(self.param_list, p):
81            self.model.setParam(k,v)
82        #self.model.set_params(self.param_list, params)
83
84    def getp(self):
85        return numpy.array([self.model.getParam(k) for k in self.param_list])
86        #return numpy.asarray(self.model.get_params(self.param_list))
87
88    def bounds(self):
89        return numpy.array([self._getrange(p) for p in self.param_list]).T
90
91    def labels(self):
92        return self.param_list
93
94    def _getrange(self, p):
95        """
96        Override _getrange of park parameter
97        return the range of parameter
98        """
99        lo, hi = self.model.details[p][1:3]
100        if lo is None: lo = -numpy.inf
101        if hi is None: hi = numpy.inf
102        return lo, hi
103
104    def randomize(self, n):
105        p = self.getp()
106        # since randn is symmetric and random, doesn't matter
107        # point value is negative.
108        # TODO: throw in bounds checking!
109        return numpy.random.randn(n, len(self.param_list))*p + p
110
111    def chisq(self):
112        """
113        Calculates chi^2
114
115        :param params: list of parameter values
116
117        :return: chi^2
118
119        """
120        return numpy.sum(self.res**2)/self.dof
121
122    def residuals(self, params=None):
123        """
124        Compute residuals
125        :param params: value of parameters to fit
126        """
127        if params is not None: self.setp(params)
128        #import thread
129        #print "params", params
130        self.res, self.theory = self.data.residuals(self.model.evalDistribution)
131        return self.res
132
133BOUNDS_PENALTY = 1e6 # cost for going out of bounds on unbounded fitters
134class MonitoredSasProblem(SasProblem):
135    """
136    SAS problem definition for optimizers which do not have monitoring or bounds.
137    """
138    def __init__(self, param_list, model=None, data=None, fitresult=None,
139                 handler=None, curr_thread=None, msg_q=None, update_rate=1):
140        """
141        :param Model: the model wrapper fro sans -model
142        :param Data: the data wrapper for sans data
143        """
144        SasProblem.__init__(self, param_list, model, data)
145        self.msg_q = msg_q
146        self.curr_thread = curr_thread
147        self.handler = handler
148        self.fitresult = fitresult
149        #self.last_update = time.time()
150        #self.func_name = "Functor"
151        #self.name = "Fill in proper name!"
152
153    def residuals(self, p):
154        """
155        Cost function for scipy.optimize.leastsq, which does not have a monitor
156        built into the algorithm, and instead relies on a monitor built into
157        the cost function.
158        """
159        # Note: technically, unbounded fitters and unmonitored fitters are
160        self.setp(x)
161
162        # Compute penalty for being out of bounds which increases the farther
163        # you get out of bounds.  This allows derivative following algorithms
164        # to point back toward the feasible region.
165        penalty = self.bounds_penalty()
166        if penalty > 0:
167            self.theory = numpy.ones(self.data.num_points)
168            self.res = self.theory*(penalty/self.data.num_points) + BOUNDS_PENALTY
169            return self.res
170
171        # If no penalty, then we are not out of bounds and we can use the
172        # normal residual calculation
173        SasProblem.residuals(self, p)
174
175        # send update to the application
176        if True:
177            #self.fitresult.set_model(model=self.model)
178            # copy residuals into fit results
179            self.fitresult.residuals = self.res+0
180            self.fitresult.iterations += 1
181            self.fitresult.theory = self.theory+0
182
183            self.fitresult.p = numpy.array(p) # force copy, and coversion to array
184            self.fitresult.set_fitness(fitness=self.chisq())
185            if self.msg_q is not None:
186                self.msg_q.put(self.fitresult)
187
188            if self.handler is not None:
189                self.handler.set_result(result=self.fitresult)
190                self.handler.update_fit()
191
192            if self.curr_thread != None:
193                try:
194                    self.curr_thread.isquit()
195                except:
196                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
197                    #msg += "fitting may cause a 'Functor error message' "
198                    #msg += "being recorded in the log file....."
199                    #self.handler.stop(msg)
200                    raise
201
202        return self.res
203
204    def bounds_penalty(self):
205        from numpy import sum, where
206        p, bounds = self.getp(), self.bounds()
207        return (sum(where(p<bounds[:,0], bounds[:,0]-p, 0)**2)
208              + sum(where(p>bounds[:,1], bounds[:,1]-p, 0)**2) )
209
210class BumpsFit(FitEngine):
211    """
212    Fit a model using bumps.
213    """
214    def __init__(self):
215        """
216        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
217        with Uid as keys
218        """
219        FitEngine.__init__(self)
220        self.curr_thread = None
221
222    def fit(self, msg_q=None,
223            q=None, handler=None, curr_thread=None,
224            ftol=1.49012e-8, reset_flag=False):
225        """
226        """
227        fitproblem = []
228        for fproblem in self.fit_arrange_dict.itervalues():
229            if fproblem.get_to_fit() == 1:
230                fitproblem.append(fproblem)
231        if len(fitproblem) > 1 :
232            msg = "Bumps can't fit more than a single fit problem at a time."
233            raise RuntimeError, msg
234        elif len(fitproblem) == 0 :
235            raise RuntimeError, "No problem scheduled for fitting."
236        model = fitproblem[0].get_model()
237        if reset_flag:
238            # reset the initial value; useful for batch
239            for name in fitproblem[0].pars:
240                ind = fitproblem[0].pars.index(name)
241                model.setParam(name, fitproblem[0].vals[ind])
242        data = fitproblem[0].get_data()
243
244        self.curr_thread = curr_thread
245
246        result = FResult(model=model, data=data, param_list=self.param_list)
247        result.pars = fitproblem[0].pars
248        result.fitter_id = self.fitter_id
249        result.index = data.idx
250        if handler is not None:
251            handler.set_result(result=result)
252
253        if True: # bumps
254            problem = SasProblem(param_list=self.param_list,
255                                 model=model.model,
256                                 data=data)
257            run_bumps(problem, result, ftol,
258                      handler, curr_thread, msg_q)
259        else:
260            problem = SasProblem(param_list=self.param_list,
261                                 model=model.model,
262                                 data=data,
263                                 handler=handler,
264                                 fitresult=result,
265                                 curr_thread=curr_thread,
266                                 msg_q=msg_q)
267            run_levenburg_marquardt(problem, result, ftol)
268
269        if handler is not None:
270            handler.update_fit(last=True)
271        if q is not None:
272            q.put(result)
273            return q
274        #if success < 1 or success > 5:
275        #    result.fitness = None
276        return [result]
277
278def run_bumps(problem, result, ftol, handler, curr_thread, msg_q):
279    def abort_test():
280        if curr_thread is None: return False
281        try: curr_thread.isquit()
282        except KeyboardInterrupt:
283            if handler is not None:
284                handler.stop("Fitting: Terminated!!!")
285            return True
286        return False
287
288    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
289    fitclass = fitopts.fitclass
290    options = fitopts.options.copy()
291    max_steps = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0)
292    if 'monitors' not in options:
293        options['monitors'] = [BumpsMonitor(handler, max_steps)]
294    options['ftol'] = ftol
295    fitdriver = fitters.FitDriver(fitclass, problem=problem,
296                                  abort_test=abort_test, **options)
297    mapper = SerialMapper
298    fitdriver.mapper = mapper.start_mapper(problem, None)
299    try:
300        best, fbest = fitdriver.fit()
301    except:
302        import traceback; traceback.print_exc()
303        raise
304    finally:
305        mapper.stop_mapper(fitdriver.mapper)
306    #print "best,fbest",best,fbest,problem.dof
307    result.fitness = 2*fbest/problem.dof
308    #print "fitness",result.fitness
309    result.stderr  = fitdriver.stderr()
310    result.pvec = best
311    # TODO: track success better
312    result.success = True
313    result.theory = problem.theory
314
315def run_levenburg_marquardt(problem, result, ftol):
316    # This import must be here; otherwise it will be confused when more
317    # than one thread exist.
318    from scipy import optimize
319
320    out, cov_x, _, mesg, success = optimize.leastsq(problem.residuals,
321                                                    problem.getp(),
322                                                    ftol=ftol,
323                                                    full_output=1)
324    if cov_x is not None and numpy.isfinite(cov_x).all():
325        stderr = numpy.sqrt(numpy.diag(cov_x))
326    else:
327        stderr = []
328    result.fitness = problem.chisq()
329    result.stderr  = stderr
330    result.pvec = out
331    result.success = success
332    result.theory = problem.theory
333
Note: See TracBrowser for help on using the repository browser.