source: sasview/src/sas/fit/ScipyFitting.py @ 386ffe1

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 386ffe1 was 386ffe1, checked in by pkienzle, 9 years ago

remove scipy levenburg marquardt and park from ui

  • Property mode set to 100644
File size: 10.3 KB
Line 
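"""
ScipyFitting module: a simple fit engine built on scipy's leastsq
optimizer. SasAssembly wraps the model and data as the residual function
passed to the optimizer, and ScipyFit drives the fit. The implementation
is currently disabled: it is kept inside an unused string literal.
"""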
6_ = '''
7import sys
8import copy
9
10import numpy
11
12from sas.fit.AbstractFitEngine import FitEngine
13from sas.fit.AbstractFitEngine import FResult
14
# Relative tolerance used by SasAssembly.check_param_range when testing a
# value against its bounds; also substituted there for a value of exactly 0.
_SMALLVALUE = 1e-10
16
class SasAssembly:
    """
    Sas Assembly class: a callable wrapper passed to optimizer.leastsq.

    Each call sets the fitted parameters on the model, evaluates the
    residuals of the data against the model and, when a fit result is
    attached, reports progress to the message queue, handler and thread.
    """
    def __init__(self, paramlist, model=None, data=None, fitresult=None,
                 handler=None, curr_thread=None, msg_q=None):
        """
        :param paramlist: names of the parameters being fitted
        :param model: the model wrapper for the sas model
        :param data: the data wrapper for the sas data
        :param fitresult: optional result object updated on every call
        :param handler: optional handler notified on every call
        :param curr_thread: optional thread object whose isquit() is polled
        :param msg_q: optional queue that receives fitresult on every call
        """
        self.model = model
        self.data = data
        self.paramlist = paramlist
        self.msg_q = msg_q
        self.curr_thread = curr_thread
        self.handler = handler
        self.fitresult = fitresult
        # residuals returned to the optimizer (penalised when out of bounds)
        self.res = []
        # residuals actually computed at the last evaluated point
        self.true_res = []
        self.func_name = "Functor"
        self.theory = None

    def chisq(self):
        """
        Compute the reduced chi^2 of the last evaluated residuals.

        :return: sum(res**2) / (npoints - nparams), or None when there are
            no residuals or the number of degrees of freedom is not positive
            (reduced chi^2 is undefined there; dividing would raise or
            yield inf/negative values).
        """
        npoints = len(self.true_res)
        dof = npoints - len(self.paramlist)
        if npoints == 0 or dof <= 0:
            return None
        total = sum(item * item for item in self.true_res)
        return total / dof

    def __call__(self, params):
        """
        Compute residuals for the given parameter values.

        :param params: values of the parameters to fit, in paramlist order
        :return: the residuals; when a parameter is outside its bounds the
            penalised residuals prepared by check_param_range are returned
            and nothing is reported for this point
        """
        self.model.set_params(self.paramlist, params)
        self.true_res, theory = self.data.residuals(self.model.eval)
        self.theory = copy.deepcopy(theory)
        # check parameters range
        if self.check_param_range():
            # Out of bounds: silently return the penalised residuals.
            return self.res
        self.res = self.true_res

        if self.fitresult is not None:
            self.fitresult.set_model(model=self.model)
            self.fitresult.residuals = self.true_res
            self.fitresult.iterations += 1
            self.fitresult.theory = theory

            fitness = self.chisq()
            self.fitresult.pvec = params
            self.fitresult.set_fitness(fitness=fitness)
            if self.msg_q is not None:
                self.msg_q.put(self.fitresult)

            if self.handler is not None:
                self.handler.set_result(result=self.fitresult)
                self.handler.update_fit()

            if self.curr_thread is not None:
                # NOTE(review): isquit() presumably raises to abort the fit;
                # any exception propagates out of the optimizer unchanged.
                self.curr_thread.isquit()

        return self.res

    def check_param_range(self):
        """
        Check the fitted parameters against their lower and upper bounds.

        A parameter is out of bounds when it lies beyond a finite bound by
        more than a tolerance proportional to its value. In that case
        self.res is replaced by a scaled (x1e6) copy of the previous
        residuals so the optimizer is pushed back inside the range.

        :limitation: the initial values must be within range.
        :return: True if any parameter is out of bounds, False otherwise
        """
        is_outofbound = False
        model = self.model.model
        for name in self.paramlist:
            value = model.getParam(name)
            low, high = model.details[name][1:3]
            if value == 0:
                # This value works on Scipy
                # Do not change numbers below
                value = _SMALLVALUE
            # For leastsq, it needs a bit step back from the boundary
            if low is not None and numpy.isfinite(low):
                if value < low - value * _SMALLVALUE:
                    is_outofbound = True
                    break
            if high is not None and numpy.isfinite(high):
                if value > high + value * _SMALLVALUE:
                    is_outofbound = True
                    break

        if is_outofbound:
            # Scale a copy rather than using "*=": self.res may be the very
            # array stored in fitresult.residuals, and an in-place multiply
            # would corrupt it (and fails outright on a plain list). Fall back
            # to the current residuals when no previous ones exist yet.
            base = self.res if len(self.res) else self.true_res
            self.res = numpy.asarray(base) * 1e+6
        return is_outofbound
141
class ScipyFit(FitEngine):
    """
    ScipyFit performs the fit with scipy.optimize.leastsq. This class can be
    used as follows:

    Create an engine: engine = ScipyFit()
    Data must be of type plottable; the model must be a sas model.

    Add data with a dictionary of FitArrangeDict where Uid is a key and data
    is saved in a FitArrange object.
    engine.set_data(data, Uid)

    Set model parameter "M1" = model.name add {model.parameter.name: value}.

    :note: set_param(), if used, must always be preceded by set_model()
         for the fit to be performed.

    engine.set_param(model, "M1", {'A': 2, 'B': 4})

    Add model with a dictionary of FitArrangeDict{} where Uid is a key and
    model is saved in a FitArrange object.
    engine.set_model(model, Uid)

    engine.fit() fits the single scheduled fit problem and returns [FResult],
    or the queue q (after putting the result on it) when q is given.
    """
    def __init__(self):
        """
        Initialise the FitEngine base and clear the current fitting thread.

        NOTE(review): FitEngine.__init__ is expected to create
        self.fit_arrange_dict (FitArrange elements keyed by Uid), which fit()
        reads.
        """
        FitEngine.__init__(self)
        self.curr_thread = None

    def fit(self, msg_q=None,
            q=None, handler=None, curr_thread=None,
            ftol=1.49012e-8, reset_flag=False):
        """
        Fit the single scheduled fit problem with scipy.optimize.leastsq.

        :param msg_q: optional queue receiving intermediate results
        :param q: optional queue; when given, the final result is put on it
            and the queue itself is returned
        :param handler: optional fit handler notified of progress and of the
            final result
        :param curr_thread: optional thread object polled during the fit
        :param ftol: relative tolerance passed to leastsq
        :param reset_flag: when True, restore the fit parameters to their
            stored initial values first (useful for batch fitting)

        :return: q when a queue is given, otherwise [result]; result.fitness
            is None when leastsq reports failure
        :raise RuntimeError: when zero or more than one problem is scheduled
        :raise KeyboardInterrupt: when the fit is interrupted
        """
        fitproblem = [fproblem for fproblem in self.fit_arrange_dict.values()
                      if fproblem.get_to_fit() == 1]
        if len(fitproblem) > 1:
            msg = "Scipy can't fit more than a single fit problem at a time."
            raise RuntimeError(msg)
        elif not fitproblem:
            raise RuntimeError("No Assembly scheduled for Scipy fitting.")
        problem = fitproblem[0]
        model = problem.get_model()
        pars = problem.pars
        if reset_flag:
            # reset the initial value; useful for batch
            for ind, name in enumerate(pars):
                model.model.setParam(name, problem.vals[ind])
        data = problem.get_data()

        self.curr_thread = curr_thread

        # Check the initial value if it is within range
        _check_param_range(model.model, pars)

        result = FResult(model=model.model, data=data, param_list=pars)
        result.fitter_id = self.fitter_id
        if handler is not None:
            handler.set_result(result=result)
        functor = SasAssembly(paramlist=pars,
                              model=model,
                              data=data,
                              handler=handler,
                              fitresult=result,
                              curr_thread=curr_thread,
                              msg_q=msg_q)
        try:
            # This import must be here; otherwise it will be confused when more
            # than one thread exist.
            from scipy import optimize

            out, cov_x, _infodict, _mesg, success = optimize.leastsq(
                functor, model.get_params(pars), ftol=ftol, full_output=1)
        except KeyboardInterrupt:
            # Catch the interrupt directly: sys.last_type is only set by the
            # interactive interpreter, and the old check could swallow the
            # exception and leave cov_x unbound.
            if handler is not None:
                msg = "Fitting: Terminated!!!"
                handler.stop(msg)
                raise KeyboardInterrupt(msg)
            raise
        chisqr = functor.chisq()

        if cov_x is not None and numpy.isfinite(cov_x).all():
            stderr = numpy.sqrt(numpy.diag(cov_x))
        else:
            stderr = []

        result.index = data.idx
        result.fitness = chisqr
        result.stderr = stderr
        result.pvec = out
        result.success = success
        result.theory = functor.theory
        # leastsq's ier flag: 1-4 mean a solution was found and 5 means the
        # function-call limit was reached (also accepted here). Clear the
        # fitness of a failed fit before reporting, so the handler, the queue
        # and the return value all see the same result.
        if success < 1 or success > 5:
            result.fitness = None
        if handler is not None:
            handler.set_result(result=result)
            handler.update_fit(last=True)
        if q is not None:
            q.put(result)
            return q
        return [result]
261
262       
def _check_param_range(model, pars):
    """
    Move out-of-range initial parameter values back inside their bounds.

    For each name in pars, the value is read with model.getParam and the
    (low, high) bounds with model.details[name][1:3]; a details entry with
    no bounds is created if missing. Only finite bounds are enforced.

    A value at or below the lower bound is moved just above it. A value
    above the upper bound is moved just below it, falling back to the middle
    of the range if that would drop below the lower bound. Each value is
    written back with model.setParam.

    :param model: model object exposing getParam, setParam and details
    :param pars: names of the parameters to check
    """
    for name in pars:
        value = model.getParam(name)
        low, high = model.details.setdefault(name, ["", None, None])[1:3]
        has_low = low is not None and numpy.isfinite(low)
        has_high = high is not None and numpy.isfinite(high)
        # abs(): the shift must point into the range even for negative
        # bounds; "low + 0.1 * low" would push the value further outside.
        if has_low and value <= low:
            value = low + abs(_get_zero_shift(low))
        if has_high and value > high:
            value = high - abs(_get_zero_shift(high))
            # Check one more time if the new value goes below
            # the low bound, If so, re-evaluate the value
            # with the mean of the range.
            if has_low and value < low:
                value = 0.5 * (low + high)
        model.setParam(name, value)
285
def _get_zero_shift(limit):
    """
    Return a small non-negative shift used to move a value inside a bound.

    The shift is 10% of the bound's magnitude, or 0.1 when the bound is 0.
    Using the magnitude keeps the shift pointing into the range for negative
    bounds as well.

    :param limit: min or max value of the bounds
    :return: 0.1 * abs(limit), or 0.1 if limit == 0
    """
    return 0.1 * (abs(limit) if limit != 0.0 else 1.0)
293
294   
295#def profile(fn, *args, **kw):
296#    import cProfile, pstats, os
297#    global call_result
298#   def call():
299#        global call_result
300#        call_result = fn(*args, **kw)
301#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
302#    stats = pstats.Stats('profile.out')
303#    stats.sort_stats('time')
304#    stats.sort_stats('calls')
305#    stats.print_stats()
306#    os.unlink('profile.out')
307#    return call_result
308
309
310'''
Note: See TracBrowser for help on using the repository browser.