source: sasview/src/sans/fit/ScipyFitting.py @ 8d074d9

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 8d074d9 was 8d074d9, checked in by pkienzle, 10 years ago

refactor fit internals, enabling disperser parameters

  • Property mode set to 100644
File size: 10.6 KB
RevLine 
[792db7d5]1"""
[aa36f96]2ScipyFitting module contains FitArrange , ScipyFit,
3Parameter classes.All listed classes work together to perform a
4simple fit with scipy optimizer.
[792db7d5]5"""
[511c6810]6import sys
[6fe5100]7import copy
[2446b66]8
[6fe5100]9import numpy 
[7705306]10
[b2f25dc5]11from sans.fit.AbstractFitEngine import FitEngine
[6fe5100]12from sans.fit.AbstractFitEngine import FResult
13
14class SansAssembly:
15    """
16    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
17    """
18    def __init__(self, paramlist, model=None, data=None, fitresult=None,
19                 handler=None, curr_thread=None, msg_q=None):
20        """
21        :param Model: the model wrapper fro sans -model
22        :param Data: the data wrapper for sans data
23
24        """
25        self.model = model
26        self.data = data
27        self.paramlist = paramlist
28        self.msg_q = msg_q
29        self.curr_thread = curr_thread
30        self.handler = handler
31        self.fitresult = fitresult
32        self.res = []
33        self.true_res = []
34        self.func_name = "Functor"
35        self.theory = None
36
37    def chisq(self):
38        """
39        Calculates chi^2
40
41        :param params: list of parameter values
42
43        :return: chi^2
44
45        """
46        total = 0
47        for item in self.true_res:
48            total += item * item
49        if len(self.true_res) == 0:
50            return None
[95d58d3]51        return total / (len(self.true_res) - len(self.paramlist))
[6fe5100]52
53    def __call__(self, params):
54        """
55            Compute residuals
56            :param params: value of parameters to fit
57        """
58        #import thread
59        self.model.set_params(self.paramlist, params)
60        #print "params", params
61        self.true_res, theory = self.data.residuals(self.model.eval)
62        self.theory = copy.deepcopy(theory)
63        # check parameters range
64        if self.check_param_range():
65            # if the param value is outside of the bound
66            # just silent return res = inf
67            return self.res
68        self.res = self.true_res
69
70        if self.fitresult is not None:
71            self.fitresult.set_model(model=self.model)
72            self.fitresult.residuals = self.true_res
73            self.fitresult.iterations += 1
74            self.fitresult.theory = theory
75
76            #fitness = self.chisq(params=params)
77            fitness = self.chisq()
78            self.fitresult.pvec = params
79            self.fitresult.set_fitness(fitness=fitness)
80            if self.msg_q is not None:
81                self.msg_q.put(self.fitresult)
82
83            if self.handler is not None:
84                self.handler.set_result(result=self.fitresult)
85                self.handler.update_fit()
86
87            if self.curr_thread != None:
88                try:
89                    self.curr_thread.isquit()
90                except:
91                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
92                    #msg += "fitting may cause a 'Functor error message' "
93                    #msg += "being recorded in the log file....."
94                    #self.handler.stop(msg)
95                    raise
96
97        return self.res
98
99    def check_param_range(self):
100        """
101        Check the lower and upper bound of the parameter value
102        and set res to the inf if the value is outside of the
103        range
104        :limitation: the initial values must be within range.
105        """
106
107        #time.sleep(0.01)
108        is_outofbound = False
109        # loop through the fit parameters
110        model = self.model.model
111        for p in self.paramlist:
112            value = model.getParam(p)
113            low,high = model.details[p][1:3]
114            if low is not None and numpy.isfinite(low):
115                if p.value == 0:
116                    # This value works on Scipy
117                    # Do not change numbers below
118                    value = _SMALLVALUE
119                # For leastsq, it needs a bit step back from the boundary
120                val = low - value * _SMALLVALUE
121                if value < val:
122                    self.res *= 1e+6
123                    is_outofbound = True
124                    break
125            if high is not None and numpy.isfinite(high):
126                # This value works on Scipy
127                # Do not change numbers below
128                if value == 0:
129                    value = _SMALLVALUE
130                # For leastsq, it needs a bit step back from the boundary
131                val = high + value * _SMALLVALUE
132                if value > val:
133                    self.res *= 1e+6
134                    is_outofbound = True
135                    break
136
137        return is_outofbound
[88b5e83]138
[4c718654]139class ScipyFit(FitEngine):
[7705306]140    """
[aa36f96]141    ScipyFit performs the Fit.This class can be used as follow:
142    #Do the fit SCIPY
143    create an engine: engine = ScipyFit()
144    Use data must be of type plottable
145    Use a sans model
146   
147    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
148    is saved in FitArrange object.
149    engine.set_data(data,Uid)
150   
151    Set model parameter "M1"= model.name add {model.parameter.name:value}.
152   
153    :note: Set_param() if used must always preceded set_model()
154         for the fit to be performed.In case of Scipyfit set_param is called in
155         fit () automatically.
156   
157    engine.set_param( model,"M1", {'A':2,'B':4})
158   
159    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
160    is save in FitArrange object.
161    engine.set_model(model,Uid)
162   
163    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
164    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
[7705306]165    """
[792db7d5]166    def __init__(self):
167        """
[b2f25dc5]168        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
[aa36f96]169        with Uid as keys
[792db7d5]170        """
[b2f25dc5]171        FitEngine.__init__(self)
[c4d6900]172        self.curr_thread = None
[d9dc518]173    #def fit(self, *args, **kw):
174    #    return profile(self._fit, *args, **kw)
[393f0f3]175
[ba7dceb]176    def fit(self, msg_q=None,
177            q=None, handler=None, curr_thread=None, 
[7db52f1]178            ftol=1.49012e-8, reset_flag=False):
[aa36f96]179        """
180        """
[89f3b66]181        fitproblem = []
[c4d6900]182        for fproblem in self.fit_arrange_dict.itervalues():
[89f3b66]183            if fproblem.get_to_fit() == 1:
[393f0f3]184                fitproblem.append(fproblem)
[89f3b66]185        if len(fitproblem) > 1 : 
[e0072082]186            msg = "Scipy can't fit more than a single fit problem at a time."
187            raise RuntimeError, msg
[6fe5100]188        elif len(fitproblem) == 0 :
[a9e04aa]189            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
[393f0f3]190        model = fitproblem[0].get_model()
[7db52f1]191        if reset_flag:
192            # reset the initial value; useful for batch
193            for name in fitproblem[0].pars:
194                ind = fitproblem[0].pars.index(name)
195                model.model.setParam(name, fitproblem[0].vals[ind])
196        listdata = []
[393f0f3]197        listdata = fitproblem[0].get_data()
[792db7d5]198        # Concatenate dList set (contains one or more data)before fitting
[e0072082]199        data = listdata
[852354c8]200       
[89f3b66]201        self.curr_thread = curr_thread
[93de635d]202        ftol = ftol
[852354c8]203       
204        # Check the initial value if it is within range
[6fe5100]205        _check_param_range(model.model, self.param_list)
[852354c8]206       
[95d58d3]207        result = FResult(model=model.model, data=data, param_list=self.param_list)
[06e7c26]208        result.fitter_id = self.fitter_id
[852354c8]209        if handler is not None:
210            handler.set_result(result=result)
[6fe5100]211        functor = SansAssembly(paramlist=self.param_list,
212                               model=model,
213                               data=data,
214                               handler=handler,
215                               fitresult=result,
216                               curr_thread=curr_thread,
217                               msg_q=msg_q)
[511c6810]218        try:
[2446b66]219            # This import must be here; otherwise it will be confused when more
220            # than one thread exist.
221            from scipy import optimize
222           
[db427ec]223            out, cov_x, _, mesg, success = optimize.leastsq(functor,
[c4d6900]224                                            model.get_params(self.param_list),
[6fe5100]225                                            ftol=ftol,
226                                            full_output=1)
[2d0756a5]227        except:
[acfff8b]228            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
[852354c8]229                if handler is not None:
[acfff8b]230                    msg = "Fitting: Terminated!!!"
[986da97]231                    handler.stop(msg)
[2d0756a5]232                    raise KeyboardInterrupt, msg
[511c6810]233            else:
[2d0756a5]234                raise
[c4d6900]235        chisqr = functor.chisq()
[15f68ce]236
[fd6b789]237        if cov_x is not None and numpy.isfinite(cov_x).all():
238            stderr = numpy.sqrt(numpy.diag(cov_x))
239        else:
[15f68ce]240            stderr = []
[d8661fb]241           
242        result.index = data.idx
[15f68ce]243        result.fitness = chisqr
244        result.stderr  = stderr
245        result.pvec = out
246        result.success = success
247        result.theory = functor.theory
[fe10df5]248        if handler is not None:
[cc694d0]249            handler.set_result(result=result)
[ee19117]250            handler.update_fit(last=True)
[15f68ce]251        if q is not None:
252            q.put(result)
253            return q
254        if success < 1 or success > 5:
255            result.fitness = None
[444c900e]256        return [result]
[15f68ce]257
[852354c8]258       
[6fe5100]259def _check_param_range(model, param_list):
260    """
261    Check parameter range and set the initial value inside
262    if it is out of range.
263
264    : model: park model object
265    """
266    # loop through parameterset
267    for p in param_list:
268        value = model.getParam(p)
[8d074d9]269        low,high = model.details.setdefault(p,["",None,None])[1:3]
[6fe5100]270        # if the range was defined, check the range
271        if low is not None and value <= low:
272            value = low + _get_zero_shift(low)
273        if high is not None and value > high:
274            value = high - _get_zero_shift(high)
275            # Check one more time if the new value goes below
276            # the low bound, If so, re-evaluate the value
277            # with the mean of the range.
278            if low is not None and value < low:
279                value = 0.5 * (low+high)
280        model.setParam(p, value)
281
282def _get_zero_shift(limit):
283    """
284    Get 10% shift of the param value = 0 based on the range value
285
286    : param range: min or max value of the bounds
287    """
[042f065]288    return 0.1 * (limit if limit != 0.0 else 1.0)
[6fe5100]289
[e0072082]290   
[c4d6900]291#def profile(fn, *args, **kw):
292#    import cProfile, pstats, os
293#    global call_result
294#   def call():
295#        global call_result
296#        call_result = fn(*args, **kw)
297#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
298#    stats = pstats.Stats('profile.out')
299#    stats.sort_stats('time')
300#    stats.sort_stats('calls')
301#    stats.print_stats()
302#    os.unlink('profile.out')
303#    return call_result
[9c648c7]304
[48882d1]305     
Note: See TracBrowser for help on using the repository browser.