source: sasview/src/sans/fit/ScipyFitting.py @ 8d074d9

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 8d074d9 was 8d074d9, checked in by pkienzle, 10 years ago

refactor fit internals, enabling disperser parameters

  • Property mode set to 100644
File size: 10.6 KB
Line 
1"""
2ScipyFitting module contains FitArrange , ScipyFit,
3Parameter classes.All listed classes work together to perform a
4simple fit with scipy optimizer.
5"""
6import sys
7import copy
8
9import numpy 
10
11from sans.fit.AbstractFitEngine import FitEngine
12from sans.fit.AbstractFitEngine import FResult
13
14class SansAssembly:
15    """
16    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
17    """
18    def __init__(self, paramlist, model=None, data=None, fitresult=None,
19                 handler=None, curr_thread=None, msg_q=None):
20        """
21        :param Model: the model wrapper fro sans -model
22        :param Data: the data wrapper for sans data
23
24        """
25        self.model = model
26        self.data = data
27        self.paramlist = paramlist
28        self.msg_q = msg_q
29        self.curr_thread = curr_thread
30        self.handler = handler
31        self.fitresult = fitresult
32        self.res = []
33        self.true_res = []
34        self.func_name = "Functor"
35        self.theory = None
36
37    def chisq(self):
38        """
39        Calculates chi^2
40
41        :param params: list of parameter values
42
43        :return: chi^2
44
45        """
46        total = 0
47        for item in self.true_res:
48            total += item * item
49        if len(self.true_res) == 0:
50            return None
51        return total / (len(self.true_res) - len(self.paramlist))
52
53    def __call__(self, params):
54        """
55            Compute residuals
56            :param params: value of parameters to fit
57        """
58        #import thread
59        self.model.set_params(self.paramlist, params)
60        #print "params", params
61        self.true_res, theory = self.data.residuals(self.model.eval)
62        self.theory = copy.deepcopy(theory)
63        # check parameters range
64        if self.check_param_range():
65            # if the param value is outside of the bound
66            # just silent return res = inf
67            return self.res
68        self.res = self.true_res
69
70        if self.fitresult is not None:
71            self.fitresult.set_model(model=self.model)
72            self.fitresult.residuals = self.true_res
73            self.fitresult.iterations += 1
74            self.fitresult.theory = theory
75
76            #fitness = self.chisq(params=params)
77            fitness = self.chisq()
78            self.fitresult.pvec = params
79            self.fitresult.set_fitness(fitness=fitness)
80            if self.msg_q is not None:
81                self.msg_q.put(self.fitresult)
82
83            if self.handler is not None:
84                self.handler.set_result(result=self.fitresult)
85                self.handler.update_fit()
86
87            if self.curr_thread != None:
88                try:
89                    self.curr_thread.isquit()
90                except:
91                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
92                    #msg += "fitting may cause a 'Functor error message' "
93                    #msg += "being recorded in the log file....."
94                    #self.handler.stop(msg)
95                    raise
96
97        return self.res
98
99    def check_param_range(self):
100        """
101        Check the lower and upper bound of the parameter value
102        and set res to the inf if the value is outside of the
103        range
104        :limitation: the initial values must be within range.
105        """
106
107        #time.sleep(0.01)
108        is_outofbound = False
109        # loop through the fit parameters
110        model = self.model.model
111        for p in self.paramlist:
112            value = model.getParam(p)
113            low,high = model.details[p][1:3]
114            if low is not None and numpy.isfinite(low):
115                if p.value == 0:
116                    # This value works on Scipy
117                    # Do not change numbers below
118                    value = _SMALLVALUE
119                # For leastsq, it needs a bit step back from the boundary
120                val = low - value * _SMALLVALUE
121                if value < val:
122                    self.res *= 1e+6
123                    is_outofbound = True
124                    break
125            if high is not None and numpy.isfinite(high):
126                # This value works on Scipy
127                # Do not change numbers below
128                if value == 0:
129                    value = _SMALLVALUE
130                # For leastsq, it needs a bit step back from the boundary
131                val = high + value * _SMALLVALUE
132                if value > val:
133                    self.res *= 1e+6
134                    is_outofbound = True
135                    break
136
137        return is_outofbound
138
139class ScipyFit(FitEngine):
140    """
141    ScipyFit performs the Fit.This class can be used as follow:
142    #Do the fit SCIPY
143    create an engine: engine = ScipyFit()
144    Use data must be of type plottable
145    Use a sans model
146   
147    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
148    is saved in FitArrange object.
149    engine.set_data(data,Uid)
150   
151    Set model parameter "M1"= model.name add {model.parameter.name:value}.
152   
153    :note: Set_param() if used must always preceded set_model()
154         for the fit to be performed.In case of Scipyfit set_param is called in
155         fit () automatically.
156   
157    engine.set_param( model,"M1", {'A':2,'B':4})
158   
159    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
160    is save in FitArrange object.
161    engine.set_model(model,Uid)
162   
163    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
164    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
165    """
166    def __init__(self):
167        """
168        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
169        with Uid as keys
170        """
171        FitEngine.__init__(self)
172        self.curr_thread = None
173    #def fit(self, *args, **kw):
174    #    return profile(self._fit, *args, **kw)
175
176    def fit(self, msg_q=None,
177            q=None, handler=None, curr_thread=None, 
178            ftol=1.49012e-8, reset_flag=False):
179        """
180        """
181        fitproblem = []
182        for fproblem in self.fit_arrange_dict.itervalues():
183            if fproblem.get_to_fit() == 1:
184                fitproblem.append(fproblem)
185        if len(fitproblem) > 1 : 
186            msg = "Scipy can't fit more than a single fit problem at a time."
187            raise RuntimeError, msg
188        elif len(fitproblem) == 0 :
189            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
190        model = fitproblem[0].get_model()
191        if reset_flag:
192            # reset the initial value; useful for batch
193            for name in fitproblem[0].pars:
194                ind = fitproblem[0].pars.index(name)
195                model.model.setParam(name, fitproblem[0].vals[ind])
196        listdata = []
197        listdata = fitproblem[0].get_data()
198        # Concatenate dList set (contains one or more data)before fitting
199        data = listdata
200       
201        self.curr_thread = curr_thread
202        ftol = ftol
203       
204        # Check the initial value if it is within range
205        _check_param_range(model.model, self.param_list)
206       
207        result = FResult(model=model.model, data=data, param_list=self.param_list)
208        result.fitter_id = self.fitter_id
209        if handler is not None:
210            handler.set_result(result=result)
211        functor = SansAssembly(paramlist=self.param_list,
212                               model=model,
213                               data=data,
214                               handler=handler,
215                               fitresult=result,
216                               curr_thread=curr_thread,
217                               msg_q=msg_q)
218        try:
219            # This import must be here; otherwise it will be confused when more
220            # than one thread exist.
221            from scipy import optimize
222           
223            out, cov_x, _, mesg, success = optimize.leastsq(functor,
224                                            model.get_params(self.param_list),
225                                            ftol=ftol,
226                                            full_output=1)
227        except:
228            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
229                if handler is not None:
230                    msg = "Fitting: Terminated!!!"
231                    handler.stop(msg)
232                    raise KeyboardInterrupt, msg
233            else:
234                raise
235        chisqr = functor.chisq()
236
237        if cov_x is not None and numpy.isfinite(cov_x).all():
238            stderr = numpy.sqrt(numpy.diag(cov_x))
239        else:
240            stderr = []
241           
242        result.index = data.idx
243        result.fitness = chisqr
244        result.stderr  = stderr
245        result.pvec = out
246        result.success = success
247        result.theory = functor.theory
248        if handler is not None:
249            handler.set_result(result=result)
250            handler.update_fit(last=True)
251        if q is not None:
252            q.put(result)
253            return q
254        if success < 1 or success > 5:
255            result.fitness = None
256        return [result]
257
258       
259def _check_param_range(model, param_list):
260    """
261    Check parameter range and set the initial value inside
262    if it is out of range.
263
264    : model: park model object
265    """
266    # loop through parameterset
267    for p in param_list:
268        value = model.getParam(p)
269        low,high = model.details.setdefault(p,["",None,None])[1:3]
270        # if the range was defined, check the range
271        if low is not None and value <= low:
272            value = low + _get_zero_shift(low)
273        if high is not None and value > high:
274            value = high - _get_zero_shift(high)
275            # Check one more time if the new value goes below
276            # the low bound, If so, re-evaluate the value
277            # with the mean of the range.
278            if low is not None and value < low:
279                value = 0.5 * (low+high)
280        model.setParam(p, value)
281
282def _get_zero_shift(limit):
283    """
284    Get 10% shift of the param value = 0 based on the range value
285
286    : param range: min or max value of the bounds
287    """
288    return 0.1 * (limit if limit != 0.0 else 1.0)
289
290   
291#def profile(fn, *args, **kw):
292#    import cProfile, pstats, os
293#    global call_result
294#   def call():
295#        global call_result
296#        call_result = fn(*args, **kw)
297#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
298#    stats = pstats.Stats('profile.out')
299#    stats.sort_stats('time')
300#    stats.sort_stats('calls')
301#    stats.print_stats()
302#    os.unlink('profile.out')
303#    return call_result
304
305     
Note: See TracBrowser for help on using the repository browser.