source: sasview/src/sas/fit/ScipyFitting.py @ e5b1f17

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since e5b1f17 was a10364b, checked in by krzywon, 10 years ago

Found the issue with the failing tests in ScipyFitting? and fixed the
issue. The fitter was trying to check the value of a list item but the
check should have been against another list.

  • Property mode set to 100644
File size: 10.3 KB
Line 
1"""
2ScipyFitting module contains FitArrange , ScipyFit,
3Parameter classes.All listed classes work together to perform a
4simple fit with scipy optimizer.
5"""
6import sys
7import copy
8
9import numpy 
10
11from sas.fit.AbstractFitEngine import FitEngine
12from sas.fit.AbstractFitEngine import FResult
13
14_SMALLVALUE = 1.0e-10
15
16class SasAssembly:
17    """
18    Sas Assembly class a class wrapper to be call in optimizer.leastsq method
19    """
20    def __init__(self, paramlist, model=None, data=None, fitresult=None,
21                 handler=None, curr_thread=None, msg_q=None):
22        """
23        :param Model: the model wrapper fro sas -model
24        :param Data: the data wrapper for sas data
25
26        """
27        self.model = model
28        self.data = data
29        self.paramlist = paramlist
30        self.msg_q = msg_q
31        self.curr_thread = curr_thread
32        self.handler = handler
33        self.fitresult = fitresult
34        self.res = []
35        self.true_res = []
36        self.func_name = "Functor"
37        self.theory = None
38
39    def chisq(self):
40        """
41        Calculates chi^2
42
43        :param params: list of parameter values
44
45        :return: chi^2
46
47        """
48        total = 0
49        for item in self.true_res:
50            total += item * item
51        if len(self.true_res) == 0:
52            return None
53        return total / (len(self.true_res) - len(self.paramlist))
54
55    def __call__(self, params):
56        """
57            Compute residuals
58            :param params: value of parameters to fit
59        """
60        #import thread
61        self.model.set_params(self.paramlist, params)
62        #print "params", params
63        self.true_res, theory = self.data.residuals(self.model.eval)
64        self.theory = copy.deepcopy(theory)
65        # check parameters range
66        if self.check_param_range():
67            # if the param value is outside of the bound
68            # just silent return res = inf
69            return self.res
70        self.res = self.true_res
71
72        if self.fitresult is not None:
73            self.fitresult.set_model(model=self.model)
74            self.fitresult.residuals = self.true_res
75            self.fitresult.iterations += 1
76            self.fitresult.theory = theory
77
78            #fitness = self.chisq(params=params)
79            fitness = self.chisq()
80            self.fitresult.pvec = params
81            self.fitresult.set_fitness(fitness=fitness)
82            if self.msg_q is not None:
83                self.msg_q.put(self.fitresult)
84
85            if self.handler is not None:
86                self.handler.set_result(result=self.fitresult)
87                self.handler.update_fit()
88
89            if self.curr_thread != None:
90                try:
91                    self.curr_thread.isquit()
92                except:
93                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
94                    #msg += "fitting may cause a 'Functor error message' "
95                    #msg += "being recorded in the log file....."
96                    #self.handler.stop(msg)
97                    raise
98
99        return self.res
100
101    def check_param_range(self):
102        """
103        Check the lower and upper bound of the parameter value
104        and set res to the inf if the value is outside of the
105        range
106        :limitation: the initial values must be within range.
107        """
108
109        #time.sleep(0.01)
110        is_outofbound = False
111        # loop through the fit parameters
112        model = self.model.model
113        for p in self.paramlist:
114            value = model.getParam(p)
115            low,high = model.details[p][1:3]
116            if low is not None and numpy.isfinite(low):
117                if value == 0:
118                    # This value works on Scipy
119                    # Do not change numbers below
120                    value = _SMALLVALUE
121                # For leastsq, it needs a bit step back from the boundary
122                val = low - value * _SMALLVALUE
123                if value < val:
124                    self.res *= 1e+6
125                    is_outofbound = True
126                    break
127            if high is not None and numpy.isfinite(high):
128                # This value works on Scipy
129                # Do not change numbers below
130                if value == 0:
131                    value = _SMALLVALUE
132                # For leastsq, it needs a bit step back from the boundary
133                val = high + value * _SMALLVALUE
134                if value > val:
135                    self.res *= 1e+6
136                    is_outofbound = True
137                    break
138
139        return is_outofbound
140
141class ScipyFit(FitEngine):
142    """
143    ScipyFit performs the Fit.This class can be used as follow:
144    #Do the fit SCIPY
145    create an engine: engine = ScipyFit()
146    Use data must be of type plottable
147    Use a sas model
148   
149    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
150    is saved in FitArrange object.
151    engine.set_data(data,Uid)
152   
153    Set model parameter "M1"= model.name add {model.parameter.name:value}.
154   
155    :note: Set_param() if used must always preceded set_model()
156         for the fit to be performed.In case of Scipyfit set_param is called in
157         fit () automatically.
158   
159    engine.set_param( model,"M1", {'A':2,'B':4})
160   
161    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
162    is save in FitArrange object.
163    engine.set_model(model,Uid)
164   
165    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
166    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
167    """
168    def __init__(self):
169        """
170        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
171        with Uid as keys
172        """
173        FitEngine.__init__(self)
174        self.curr_thread = None
175    #def fit(self, *args, **kw):
176    #    return profile(self._fit, *args, **kw)
177
178    def fit(self, msg_q=None,
179            q=None, handler=None, curr_thread=None, 
180            ftol=1.49012e-8, reset_flag=False):
181        """
182        """
183        fitproblem = []
184        for fproblem in self.fit_arrange_dict.itervalues():
185            if fproblem.get_to_fit() == 1:
186                fitproblem.append(fproblem)
187        if len(fitproblem) > 1 : 
188            msg = "Scipy can't fit more than a single fit problem at a time."
189            raise RuntimeError, msg
190        elif len(fitproblem) == 0 :
191            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
192        model = fitproblem[0].get_model()
193        pars = fitproblem[0].pars
194        if reset_flag:
195            # reset the initial value; useful for batch
196            for name in fitproblem[0].pars:
197                ind = fitproblem[0].pars.index(name)
198                model.model.setParam(name, fitproblem[0].vals[ind])
199        listdata = []
200        listdata = fitproblem[0].get_data()
201        # Concatenate dList set (contains one or more data)before fitting
202        data = listdata
203       
204        self.curr_thread = curr_thread
205        ftol = ftol
206       
207        # Check the initial value if it is within range
208        _check_param_range(model.model, pars)
209       
210        result = FResult(model=model.model, data=data, param_list=pars)
211        result.fitter_id = self.fitter_id
212        if handler is not None:
213            handler.set_result(result=result)
214        functor = SasAssembly(paramlist=pars,
215                               model=model,
216                               data=data,
217                               handler=handler,
218                               fitresult=result,
219                               curr_thread=curr_thread,
220                               msg_q=msg_q)
221        try:
222            # This import must be here; otherwise it will be confused when more
223            # than one thread exist.
224            from scipy import optimize
225           
226            out, cov_x, _, mesg, success = optimize.leastsq(functor,
227                                            model.get_params(pars),
228                                            ftol=ftol,
229                                            full_output=1)
230        except:
231            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
232                if handler is not None:
233                    msg = "Fitting: Terminated!!!"
234                    handler.stop(msg)
235                    raise KeyboardInterrupt, msg
236            else:
237                raise
238        chisqr = functor.chisq()
239
240        if cov_x is not None and numpy.isfinite(cov_x).all():
241            stderr = numpy.sqrt(numpy.diag(cov_x))
242        else:
243            stderr = []
244           
245        result.index = data.idx
246        result.fitness = chisqr
247        result.stderr  = stderr
248        result.pvec = out
249        result.success = success
250        result.theory = functor.theory
251        if handler is not None:
252            handler.set_result(result=result)
253            handler.update_fit(last=True)
254        if q is not None:
255            q.put(result)
256            return q
257        if success < 1 or success > 5:
258            result.fitness = None
259        return [result]
260
261       
262def _check_param_range(model, pars):
263    """
264    Check parameter range and set the initial value inside
265    if it is out of range.
266
267    : model: park model object
268    """
269    # loop through parameterset
270    for p in pars:
271        value = model.getParam(p)
272        low,high = model.details.setdefault(p,["",None,None])[1:3]
273        # if the range was defined, check the range
274        if low is not None and value <= low:
275            value = low + _get_zero_shift(low)
276        if high is not None and value > high:
277            value = high - _get_zero_shift(high)
278            # Check one more time if the new value goes below
279            # the low bound, If so, re-evaluate the value
280            # with the mean of the range.
281            if low is not None and value < low:
282                value = 0.5 * (low+high)
283        model.setParam(p, value)
284
285def _get_zero_shift(limit):
286    """
287    Get 10% shift of the param value = 0 based on the range value
288
289    : param range: min or max value of the bounds
290    """
291    return 0.1 * (limit if limit != 0.0 else 1.0)
292
293   
294#def profile(fn, *args, **kw):
295#    import cProfile, pstats, os
296#    global call_result
297#   def call():
298#        global call_result
299#        call_result = fn(*args, **kw)
300#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
301#    stats = pstats.Stats('profile.out')
302#    stats.sort_stats('time')
303#    stats.sort_stats('calls')
304#    stats.print_stats()
305#    os.unlink('profile.out')
306#    return call_result
307
308     
Note: See TracBrowser for help on using the repository browser.