source: sasview/src/sas/fit/ScipyFitting.py @ 35ec279

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 35ec279 was 79492222, checked in by krzywon, 10 years ago

Changed the file and folder names to remove all SANS references.

  • Property mode set to 100644
File size: 10.6 KB
Line 
1"""
2ScipyFitting module contains FitArrange , ScipyFit,
3Parameter classes.All listed classes work together to perform a
4simple fit with scipy optimizer.
5"""
6import sys
7import copy
8
9import numpy 
10
11from sas.fit.AbstractFitEngine import FitEngine
12from sas.fit.AbstractFitEngine import FResult
13
14_SMALLVALUE = 1.0e-10
15
16class SansAssembly:
17    """
18    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
19    """
20    def __init__(self, paramlist, model=None, data=None, fitresult=None,
21                 handler=None, curr_thread=None, msg_q=None):
22        """
23        :param Model: the model wrapper fro sas -model
24        :param Data: the data wrapper for sas data
25
26        """
27        self.model = model
28        self.data = data
29        self.paramlist = paramlist
30        self.msg_q = msg_q
31        self.curr_thread = curr_thread
32        self.handler = handler
33        self.fitresult = fitresult
34        self.res = []
35        self.true_res = []
36        self.func_name = "Functor"
37        self.theory = None
38
39    def chisq(self):
40        """
41        Calculates chi^2
42
43        :param params: list of parameter values
44
45        :return: chi^2
46
47        """
48        total = 0
49        for item in self.true_res:
50            total += item * item
51        if len(self.true_res) == 0:
52            return None
53        return total / (len(self.true_res) - len(self.paramlist))
54
55    def __call__(self, params):
56        """
57            Compute residuals
58            :param params: value of parameters to fit
59        """
60        #import thread
61        self.model.set_params(self.paramlist, params)
62        #print "params", params
63        self.true_res, theory = self.data.residuals(self.model.eval)
64        self.theory = copy.deepcopy(theory)
65        # check parameters range
66        if self.check_param_range():
67            # if the param value is outside of the bound
68            # just silent return res = inf
69            return self.res
70        self.res = self.true_res
71
72        if self.fitresult is not None:
73            self.fitresult.set_model(model=self.model)
74            self.fitresult.residuals = self.true_res
75            self.fitresult.iterations += 1
76            self.fitresult.theory = theory
77
78            #fitness = self.chisq(params=params)
79            fitness = self.chisq()
80            self.fitresult.pvec = params
81            self.fitresult.set_fitness(fitness=fitness)
82            if self.msg_q is not None:
83                self.msg_q.put(self.fitresult)
84
85            if self.handler is not None:
86                self.handler.set_result(result=self.fitresult)
87                self.handler.update_fit()
88
89            if self.curr_thread != None:
90                try:
91                    self.curr_thread.isquit()
92                except:
93                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
94                    #msg += "fitting may cause a 'Functor error message' "
95                    #msg += "being recorded in the log file....."
96                    #self.handler.stop(msg)
97                    raise
98
99        return self.res
100
101    def check_param_range(self):
102        """
103        Check the lower and upper bound of the parameter value
104        and set res to the inf if the value is outside of the
105        range
106        :limitation: the initial values must be within range.
107        """
108
109        #time.sleep(0.01)
110        is_outofbound = False
111        # loop through the fit parameters
112        model = self.model.model
113        for p in self.paramlist:
114            value = model.getParam(p)
115            low,high = model.details[p][1:3]
116            if low is not None and numpy.isfinite(low):
117                if p.value == 0:
118                    # This value works on Scipy
119                    # Do not change numbers below
120                    value = _SMALLVALUE
121                # For leastsq, it needs a bit step back from the boundary
122                val = low - value * _SMALLVALUE
123                if value < val:
124                    self.res *= 1e+6
125                    is_outofbound = True
126                    break
127            if high is not None and numpy.isfinite(high):
128                # This value works on Scipy
129                # Do not change numbers below
130                if value == 0:
131                    value = _SMALLVALUE
132                # For leastsq, it needs a bit step back from the boundary
133                val = high + value * _SMALLVALUE
134                if value > val:
135                    self.res *= 1e+6
136                    is_outofbound = True
137                    break
138
139        return is_outofbound
140
141class ScipyFit(FitEngine):
142    """
143    ScipyFit performs the Fit.This class can be used as follow:
144    #Do the fit SCIPY
145    create an engine: engine = ScipyFit()
146    Use data must be of type plottable
147    Use a sas model
148   
149    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
150    is saved in FitArrange object.
151    engine.set_data(data,Uid)
152   
153    Set model parameter "M1"= model.name add {model.parameter.name:value}.
154   
155    :note: Set_param() if used must always preceded set_model()
156         for the fit to be performed.In case of Scipyfit set_param is called in
157         fit () automatically.
158   
159    engine.set_param( model,"M1", {'A':2,'B':4})
160   
161    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
162    is save in FitArrange object.
163    engine.set_model(model,Uid)
164   
165    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
166    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
167    """
168    def __init__(self):
169        """
170        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
171        with Uid as keys
172        """
173        FitEngine.__init__(self)
174        self.curr_thread = None
175    #def fit(self, *args, **kw):
176    #    return profile(self._fit, *args, **kw)
177
178    def fit(self, msg_q=None,
179            q=None, handler=None, curr_thread=None, 
180            ftol=1.49012e-8, reset_flag=False):
181        """
182        """
183        fitproblem = []
184        for fproblem in self.fit_arrange_dict.itervalues():
185            if fproblem.get_to_fit() == 1:
186                fitproblem.append(fproblem)
187        if len(fitproblem) > 1 : 
188            msg = "Scipy can't fit more than a single fit problem at a time."
189            raise RuntimeError, msg
190        elif len(fitproblem) == 0 :
191            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
192        model = fitproblem[0].get_model()
193        pars = fitproblem[0].pars
194        if reset_flag:
195            # reset the initial value; useful for batch
196            for name in fitproblem[0].pars:
197                ind = fitproblem[0].pars.index(name)
198                model.model.setParam(name, fitproblem[0].vals[ind])
199        listdata = []
200        listdata = fitproblem[0].get_data()
201        # Concatenate dList set (contains one or more data)before fitting
202        data = listdata
203       
204        self.curr_thread = curr_thread
205        ftol = ftol
206       
207        # Check the initial value if it is within range
208        _check_param_range(model.model, pars)
209       
210        result = FResult(model=model.model, data=data, param_list=pars)
211        result.fitter_id = self.fitter_id
212        if handler is not None:
213            handler.set_result(result=result)
214        functor = SansAssembly(paramlist=pars,
215                               model=model,
216                               data=data,
217                               handler=handler,
218                               fitresult=result,
219                               curr_thread=curr_thread,
220                               msg_q=msg_q)
221        try:
222            # This import must be here; otherwise it will be confused when more
223            # than one thread exist.
224            from scipy import optimize
225           
226            out, cov_x, _, mesg, success = optimize.leastsq(functor,
227                                            model.get_params(pars),
228                                            ftol=ftol,
229                                            full_output=1)
230        except:
231            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
232                if handler is not None:
233                    msg = "Fitting: Terminated!!!"
234                    handler.stop(msg)
235                    raise KeyboardInterrupt, msg
236            else:
237                raise
238        chisqr = functor.chisq()
239
240        if cov_x is not None and numpy.isfinite(cov_x).all():
241            stderr = numpy.sqrt(numpy.diag(cov_x))
242        else:
243            stderr = []
244           
245        result.index = data.idx
246        result.fitness = chisqr
247        result.stderr  = stderr
248        result.pvec = out
249        result.success = success
250        result.theory = functor.theory
251        if handler is not None:
252            handler.set_result(result=result)
253            handler.update_fit(last=True)
254        if q is not None:
255            q.put(result)
256            return q
257        if success < 1 or success > 5:
258            result.fitness = None
259        return [result]
260
261       
262def _check_param_range(model, pars):
263    """
264    Check parameter range and set the initial value inside
265    if it is out of range.
266
267    : model: park model object
268    """
269    # loop through parameterset
270    for p in pars:
271        value = model.getParam(p)
272        low,high = model.details.setdefault(p,["",None,None])[1:3]
273        # if the range was defined, check the range
274        if low is not None and value <= low:
275            value = low + _get_zero_shift(low)
276        if high is not None and value > high:
277            value = high - _get_zero_shift(high)
278            # Check one more time if the new value goes below
279            # the low bound, If so, re-evaluate the value
280            # with the mean of the range.
281            if low is not None and value < low:
282                value = 0.5 * (low+high)
283        model.setParam(p, value)
284
285def _get_zero_shift(limit):
286    """
287    Get 10% shift of the param value = 0 based on the range value
288
289    : param range: min or max value of the bounds
290    """
291    return 0.1 * (limit if limit != 0.0 else 1.0)
292
293   
294#def profile(fn, *args, **kw):
295#    import cProfile, pstats, os
296#    global call_result
297#   def call():
298#        global call_result
299#        call_result = fn(*args, **kw)
300#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
301#    stats = pstats.Stats('profile.out')
302#    stats.sort_stats('time')
303#    stats.sort_stats('calls')
304#    stats.print_stats()
305#    os.unlink('profile.out')
306#    return call_result
307
308     
Note: See TracBrowser for help on using the repository browser.