source: sasview/src/sans/fit/BumpsFitting.py @ 6fe5100

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 6fe5100 was 6fe5100, checked in by pkienzle, 10 years ago

Bumps first pass. Fitting works but no pretty pictures

  • Property mode set to 100644
File size: 9.2 KB
Line 
1"""
2BumpsFitting module runs the bumps optimizer.
3"""
4import sys
5import copy
6
7import numpy
8
9from bumps import fitters
10from bumps.mapper import SerialMapper
11
12from sans.fit.AbstractFitEngine import FitEngine
13from sans.fit.AbstractFitEngine import FResult
14
15class SansAssembly(object):
16    """
17    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
18    """
19    def __init__(self, paramlist, model=None, data=None, fitresult=None,
20                 handler=None, curr_thread=None, msg_q=None):
21        """
22        :param Model: the model wrapper fro sans -model
23        :param Data: the data wrapper for sans data
24        """
25        self.model = model
26        self.data = data
27        self.paramlist = paramlist
28        self.msg_q = msg_q
29        self.curr_thread = curr_thread
30        self.handler = handler
31        self.fitresult = fitresult
32        self.res = []
33        self.func_name = "Functor"
34        self.theory = None
35        self.name = "Fill in proper name!"
36
37    @property
38    def dof(self):
39        return self.data.num_points - len(self.paramlist)
40
41    def summarize(self):
42        return "summarize"
43
44    def nllf(self, pvec=None):
45        residuals = self.residuals(pvec)
46        return 0.5*numpy.sum(residuals**2)
47
48    def setp(self, params):
49        self.model.set_params(self.paramlist, params)
50
51    def getp(self):
52        return numpy.asarray(self.model.get_params(self.paramlist))
53
54    def bounds(self):
55        return numpy.array([self._getrange(p) for p in self.paramlist]).T
56
57    def labels(self):
58        return self.paramlist
59
60    def _getrange(self, p):
61        """
62        Override _getrange of park parameter
63        return the range of parameter
64        """
65        lo, hi = self.model.model.details[p][1:3]
66        if lo is None: lo = -numpy.inf
67        if hi is None: hi = numpy.inf
68        return lo, hi
69
70    def randomize(self, n):
71        pvec = self.getp()
72        # since randn is symmetric and random, doesn't matter
73        # point value is negative.
74        # TODO: throw in bounds checking!
75        return numpy.random.randn(n, len(self.paramlist))*pvec + pvec
76
77    def chisq(self):
78        """
79        Calculates chi^2
80
81        :param params: list of parameter values
82
83        :return: chi^2
84
85        """
86        total = 0
87        for item in self.res:
88            total += item * item
89        if len(self.res) == 0:
90            return None
91        return total / len(self.res)
92
93    def residuals(self, params=None):
94        """
95        Compute residuals
96        :param params: value of parameters to fit
97        """
98        if params is not None: self.setp(params)
99        #import thread
100        #print "params", params
101        self.res, self.theory = self.data.residuals(self.model.eval)
102
103        if self.fitresult is not None:
104            self.fitresult.set_model(model=self.model)
105            self.fitresult.residuals = self.res+0
106            self.fitresult.iterations += 1
107            self.fitresult.theory = self.theory+0
108
109            #fitness = self.chisq(params=params)
110            fitness = self.chisq()
111            self.fitresult.pvec = params
112            self.fitresult.set_fitness(fitness=fitness)
113            if self.msg_q is not None:
114                self.msg_q.put(self.fitresult)
115
116            if self.handler is not None:
117                self.handler.set_result(result=self.fitresult)
118                self.handler.update_fit()
119
120            if self.curr_thread != None:
121                try:
122                    self.curr_thread.isquit()
123                except:
124                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
125                    #msg += "fitting may cause a 'Functor error message' "
126                    #msg += "being recorded in the log file....."
127                    #self.handler.stop(msg)
128                    raise
129
130        return self.res
131    __call__ = residuals
132
133    def check_param_range(self):
134        """
135        Check the lower and upper bound of the parameter value
136        and set res to the inf if the value is outside of the
137        range
138        :limitation: the initial values must be within range.
139        """
140
141        #time.sleep(0.01)
142        is_outofbound = False
143        # loop through the fit parameters
144        model = self.model.model
145        for p in self.paramlist:
146            value = model.getParam(p)
147            low,high = model.details[p][1:3]
148            if low is not None and numpy.isfinite(low):
149                if p.value == 0:
150                    # This value works on Scipy
151                    # Do not change numbers below
152                    value = _SMALLVALUE
153                # For leastsq, it needs a bit step back from the boundary
154                val = low - value * _SMALLVALUE
155                if value < val:
156                    self.res *= 1e+6
157                    is_outofbound = True
158                    break
159            if high is not None and numpy.isfinite(high):
160                # This value works on Scipy
161                # Do not change numbers below
162                if value == 0:
163                    value = _SMALLVALUE
164                # For leastsq, it needs a bit step back from the boundary
165                val = high + value * _SMALLVALUE
166                if value > val:
167                    self.res *= 1e+6
168                    is_outofbound = True
169                    break
170
171        return is_outofbound
172
173class BumpsFit(FitEngine):
174    """
175    Fit a model using bumps.
176    """
177    def __init__(self):
178        """
179        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
180        with Uid as keys
181        """
182        FitEngine.__init__(self)
183        self.curr_thread = None
184
185    def fit(self, msg_q=None,
186            q=None, handler=None, curr_thread=None,
187            ftol=1.49012e-8, reset_flag=False):
188        """
189        """
190        fitproblem = []
191        for fproblem in self.fit_arrange_dict.itervalues():
192            if fproblem.get_to_fit() == 1:
193                fitproblem.append(fproblem)
194        if len(fitproblem) > 1 :
195            msg = "Bumps can't fit more than a single fit problem at a time."
196            raise RuntimeError, msg
197        elif len(fitproblem) == 0 :
198            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
199        model = fitproblem[0].get_model()
200        if reset_flag:
201            # reset the initial value; useful for batch
202            for name in fitproblem[0].pars:
203                ind = fitproblem[0].pars.index(name)
204                model.setParam(name, fitproblem[0].vals[ind])
205        listdata = []
206        listdata = fitproblem[0].get_data()
207        # Concatenate dList set (contains one or more data)before fitting
208        data = listdata
209
210        self.curr_thread = curr_thread
211        ftol = ftol
212
213        result = FResult(model=model, data=data, param_list=self.param_list)
214        result.pars = fitproblem[0].pars
215        result.fitter_id = self.fitter_id
216        result.index = data.idx
217        if handler is not None:
218            handler.set_result(result=result)
219        functor = SansAssembly(paramlist=self.param_list,
220                               model=model,
221                               data=data,
222                               handler=handler,
223                               fitresult=result,
224                               curr_thread=curr_thread,
225                               msg_q=msg_q)
226        try:
227            run_bumps(functor, result)
228        except:
229            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
230                if handler is not None:
231                    msg = "Fitting: Terminated!!!"
232                    handler.stop(msg)
233                    raise KeyboardInterrupt, msg
234            else:
235                raise
236
237        if handler is not None:
238            handler.set_result(result=result)
239            handler.update_fit(last=True)
240        if q is not None:
241            q.put(result)
242            return q
243        #if success < 1 or success > 5:
244        #    result.fitness = None
245        return [result]
246
247def run_bumps(problem, result):
248    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT]
249    fitdriver = fitters.FitDriver(fitopts.fitclass, problem=problem, 
250        abort_test=lambda: False, **fitopts.options)
251    mapper = SerialMapper
252    fitdriver.mapper = mapper.start_mapper(problem, None)
253    try:
254        best, fbest = fitdriver.fit()
255    except:
256        import traceback; traceback.print_exc()
257        raise
258    mapper.stop_mapper(fitdriver.mapper)
259    fitdriver.show()
260    #fitdriver.plot()
261    result.fitness = fbest * 2. / len(result.pars) 
262    result.stderr  = numpy.ones(len(result.pars))
263    result.pvec = best
264    result.success = True
265    result.theory = problem.theory
266
267def run_scipy(model, result):
268    # This import must be here; otherwise it will be confused when more
269    # than one thread exist.
270    from scipy import optimize
271
272    out, cov_x, _, mesg, success = optimize.leastsq(functor,
273                                                    model.get_params(self.param_list),
274                                                    ftol=ftol,
275                                                    full_output=1)
276    if cov_x is not None and numpy.isfinite(cov_x).all():
277        stderr = numpy.sqrt(numpy.diag(cov_x))
278    else:
279        stderr = []
280    result.fitness = functor.chisqr()
281    result.stderr  = stderr
282    result.pvec = out
283    result.success = success
284    result.theory = functor.theory
285
Note: See TracBrowser for help on using the repository browser.