source: sasview/src/sans/fit/BumpsFitting.py @ 90f49a8

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 90f49a8 was 90f49a8, checked in by pkienzle, 10 years ago

oops…use bumps instead of scipy.leasqr in bumps wrapper

  • Property mode set to 100644
File size: 10.4 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import sys
5
6import numpy
7
8from bumps import fitters
9from bumps.mapper import SerialMapper
10
11from sans.fit.AbstractFitEngine import FitEngine
12from sans.fit.AbstractFitEngine import FResult
13
14class SasProblem(object):
15    """
16    Wrap the SAS model in a form that can be understood by bumps.
17    """
18    def __init__(self, param_list, model=None, data=None, fitresult=None,
19                 handler=None, curr_thread=None, msg_q=None):
20        """
21        :param Model: the model wrapper fro sans -model
22        :param Data: the data wrapper for sans data
23        """
24        self.model = model
25        self.data = data
26        self.param_list = param_list
27        self.msg_q = msg_q
28        self.curr_thread = curr_thread
29        self.handler = handler
30        self.fitresult = fitresult
31        self.res = []
32        self.func_name = "Functor"
33        self.theory = None
34        self.name = "Fill in proper name!"
35
36    @property
37    def dof(self):
38        return self.data.num_points - len(self.param_list)
39
40    def summarize(self):
41        """
42        Return a stylized list of parameter names and values with range bars
43        suitable for printing.
44        """
45        output = []
46        bounds = self.bounds()
47        for i,p in enumerate(self.getp()):
48            name = self.param_list[i]
49            low,high = bounds[:,i]
50            range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"),
51                              ("%g]"%high if numpy.isfinite(high) else "inf)")))
52            if not numpy.isfinite(p):
53                bar = "*invalid* "
54            else:
55                bar = ['.']*10
56                if numpy.isfinite(high-low):
57                    position = int(9.999999999 * float(p-low)/float(high-low))
58                    if position < 0: bar[0] = '<'
59                    elif position > 9: bar[9] = '>'
60                    else: bar[position] = '|'
61                bar = "".join(bar)
62            output.append("%40s %s %10g in %s"%(name,bar,p,range))
63        return "\n".join(output)
64
65    def nllf(self, p=None):
66        residuals = self.residuals(p)
67        return 0.5*numpy.sum(residuals**2)
68
69    def setp(self, p):
70        for k,v in zip(self.param_list, p):
71            self.model.setParam(k,v)
72        #self.model.set_params(self.param_list, params)
73
74    def getp(self):
75        return numpy.array([self.model.getParam(k) for k in self.param_list])
76        #return numpy.asarray(self.model.get_params(self.param_list))
77
78    def bounds(self):
79        return numpy.array([self._getrange(p) for p in self.param_list]).T
80
81    def labels(self):
82        return self.param_list
83
84    def _getrange(self, p):
85        """
86        Override _getrange of park parameter
87        return the range of parameter
88        """
89        lo, hi = self.model.details[p][1:3]
90        if lo is None: lo = -numpy.inf
91        if hi is None: hi = numpy.inf
92        return lo, hi
93
94    def randomize(self, n):
95        p = self.getp()
96        # since randn is symmetric and random, doesn't matter
97        # point value is negative.
98        # TODO: throw in bounds checking!
99        return numpy.random.randn(n, len(self.param_list))*p + p
100
101    def chisq(self):
102        """
103        Calculates chi^2
104
105        :param params: list of parameter values
106
107        :return: chi^2
108
109        """
110        return numpy.sum(self.res**2)/self.dof
111
112    def residuals(self, params=None):
113        """
114        Compute residuals
115        :param params: value of parameters to fit
116        """
117        if params is not None: self.setp(params)
118        #import thread
119        #print "params", params
120        self.res, self.theory = self.data.residuals(self.model.evalDistribution)
121
122        # TODO: this belongs in monitor not residuals calculation
123        if False: # self.fitresult is not None:
124            #self.fitresult.set_model(model=self.model)
125            self.fitresult.residuals = self.res+0
126            self.fitresult.iterations += 1
127            self.fitresult.theory = self.theory+0
128
129            #fitness = self.chisq(params=params)
130            fitness = self.chisq()
131            self.fitresult.p = params
132            self.fitresult.set_fitness(fitness=fitness)
133            if self.msg_q is not None:
134                self.msg_q.put(self.fitresult)
135
136            if self.handler is not None:
137                self.handler.set_result(result=self.fitresult)
138                self.handler.update_fit()
139
140            if self.curr_thread != None:
141                try:
142                    self.curr_thread.isquit()
143                except:
144                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
145                    #msg += "fitting may cause a 'Functor error message' "
146                    #msg += "being recorded in the log file....."
147                    #self.handler.stop(msg)
148                    raise
149
150        return self.res
151    __call__ = residuals
152
153    def _DEAD_check_param_range(self):
154        """
155        Check the lower and upper bound of the parameter value
156        and set res to the inf if the value is outside of the
157        range
158        :limitation: the initial values must be within range.
159        """
160
161        #time.sleep(0.01)
162        is_outofbound = False
163        # loop through the fit parameters
164        model = self.model
165        for p in self.param_list:
166            value = model.getParam(p)
167            low,high = model.details[p][1:3]
168            if low is not None and numpy.isfinite(low):
169                if p.value == 0:
170                    # This value works on Scipy
171                    # Do not change numbers below
172                    value = _SMALLVALUE
173                # For leastsq, it needs a bit step back from the boundary
174                val = low - value * _SMALLVALUE
175                if value < val:
176                    self.res *= 1e+6
177                    is_outofbound = True
178                    break
179            if high is not None and numpy.isfinite(high):
180                # This value works on Scipy
181                # Do not change numbers below
182                if value == 0:
183                    value = _SMALLVALUE
184                # For leastsq, it needs a bit step back from the boundary
185                val = high + value * _SMALLVALUE
186                if value > val:
187                    self.res *= 1e+6
188                    is_outofbound = True
189                    break
190
191        return is_outofbound
192
193class BumpsFit(FitEngine):
194    """
195    Fit a model using bumps.
196    """
197    def __init__(self):
198        """
199        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
200        with Uid as keys
201        """
202        FitEngine.__init__(self)
203        self.curr_thread = None
204
205    def fit(self, msg_q=None,
206            q=None, handler=None, curr_thread=None,
207            ftol=1.49012e-8, reset_flag=False):
208        """
209        """
210        fitproblem = []
211        for fproblem in self.fit_arrange_dict.itervalues():
212            if fproblem.get_to_fit() == 1:
213                fitproblem.append(fproblem)
214        if len(fitproblem) > 1 :
215            msg = "Bumps can't fit more than a single fit problem at a time."
216            raise RuntimeError, msg
217        elif len(fitproblem) == 0 :
218            raise RuntimeError, "No problem scheduled for fitting."
219        model = fitproblem[0].get_model()
220        if reset_flag:
221            # reset the initial value; useful for batch
222            for name in fitproblem[0].pars:
223                ind = fitproblem[0].pars.index(name)
224                model.setParam(name, fitproblem[0].vals[ind])
225        listdata = fitproblem[0].get_data()
226        # Concatenate dList set (contains one or more data)before fitting
227        data = listdata
228
229        self.curr_thread = curr_thread
230
231        result = FResult(model=model, data=data, param_list=self.param_list)
232        result.pars = fitproblem[0].pars
233        result.fitter_id = self.fitter_id
234        result.index = data.idx
235        if handler is not None:
236            handler.set_result(result=result)
237        problem = SasProblem(param_list=self.param_list,
238                              model=model.model,
239                              data=data,
240                              handler=handler,
241                              fitresult=result,
242                              curr_thread=curr_thread,
243                              msg_q=msg_q)
244        try:
245            run_bumps(problem, result, ftol)
246            #run_scipy(problem, result, ftol)
247        except:
248            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
249                if handler is not None:
250                    msg = "Fitting: Terminated!!!"
251                    handler.stop(msg)
252                    raise KeyboardInterrupt, msg
253            else:
254                raise
255
256        if handler is not None:
257            handler.set_result(result=result)
258            handler.update_fit(last=True)
259        if q is not None:
260            q.put(result)
261            return q
262        #if success < 1 or success > 5:
263        #    result.fitness = None
264        return [result]
265
266def run_bumps(problem, result, ftol):
267    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
268    fitclass = fitopts.fitclass
269    options = fitopts.options.copy()
270    options['ftol'] = ftol
271    fitdriver = fitters.FitDriver(fitclass, problem=problem,
272                                  abort_test=lambda: False, **options)
273    mapper = SerialMapper
274    fitdriver.mapper = mapper.start_mapper(problem, None)
275    try:
276        best, fbest = fitdriver.fit()
277    except:
278        import traceback; traceback.print_exc()
279        raise
280    finally:
281        mapper.stop_mapper(fitdriver.mapper)
282    #print "best,fbest",best,fbest,problem.dof
283    result.fitness = 2*fbest/problem.dof
284    #print "fitness",result.fitness
285    result.stderr  = fitdriver.stderr()
286    result.pvec = best
287    # TODO: track success better
288    result.success = True
289    result.theory = problem.theory
290
291def run_scipy(model, result, ftol):
292    # This import must be here; otherwise it will be confused when more
293    # than one thread exist.
294    from scipy import optimize
295
296    out, cov_x, _, mesg, success = optimize.leastsq(model.residuals,
297                                                    model.getp(),
298                                                    ftol=ftol,
299                                                    full_output=1)
300    if cov_x is not None and numpy.isfinite(cov_x).all():
301        stderr = numpy.sqrt(numpy.diag(cov_x))
302    else:
303        stderr = []
304    result.fitness = model.chisq()
305    result.stderr  = stderr
306    result.pvec = out
307    result.success = success
308    result.theory = model.theory
309
Note: See TracBrowser for help on using the repository browser.