source: sasview/src/sas/fit/ScipyFitting.py @ 827484cf

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 827484cf was fd5ac0d, checked in by krzywon, 10 years ago

I have completed the removal of all SANS references.
I will build, run, and run all unit tests before pushing.

  • Property mode set to 100644
File size: 10.6 KB
RevLine 
[792db7d5]1"""
[aa36f96]2ScipyFitting module contains FitArrange , ScipyFit,
3Parameter classes.All listed classes work together to perform a
4simple fit with scipy optimizer.
[792db7d5]5"""
[511c6810]6import sys
[6fe5100]7import copy
[2446b66]8
[6fe5100]9import numpy 
[7705306]10
[79492222]11from sas.fit.AbstractFitEngine import FitEngine
12from sas.fit.AbstractFitEngine import FResult
[6fe5100]13
[1792311]14_SMALLVALUE = 1.0e-10
15
[fd5ac0d]16class SasAssembly:
[6fe5100]17    """
[fd5ac0d]18    Sas Assembly class a class wrapper to be call in optimizer.leastsq method
[6fe5100]19    """
20    def __init__(self, paramlist, model=None, data=None, fitresult=None,
21                 handler=None, curr_thread=None, msg_q=None):
22        """
[79492222]23        :param Model: the model wrapper fro sas -model
24        :param Data: the data wrapper for sas data
[6fe5100]25
26        """
27        self.model = model
28        self.data = data
29        self.paramlist = paramlist
30        self.msg_q = msg_q
31        self.curr_thread = curr_thread
32        self.handler = handler
33        self.fitresult = fitresult
34        self.res = []
35        self.true_res = []
36        self.func_name = "Functor"
37        self.theory = None
38
39    def chisq(self):
40        """
41        Calculates chi^2
42
43        :param params: list of parameter values
44
45        :return: chi^2
46
47        """
48        total = 0
49        for item in self.true_res:
50            total += item * item
51        if len(self.true_res) == 0:
52            return None
[95d58d3]53        return total / (len(self.true_res) - len(self.paramlist))
[6fe5100]54
55    def __call__(self, params):
56        """
57            Compute residuals
58            :param params: value of parameters to fit
59        """
60        #import thread
61        self.model.set_params(self.paramlist, params)
62        #print "params", params
63        self.true_res, theory = self.data.residuals(self.model.eval)
64        self.theory = copy.deepcopy(theory)
65        # check parameters range
66        if self.check_param_range():
67            # if the param value is outside of the bound
68            # just silent return res = inf
69            return self.res
70        self.res = self.true_res
71
72        if self.fitresult is not None:
73            self.fitresult.set_model(model=self.model)
74            self.fitresult.residuals = self.true_res
75            self.fitresult.iterations += 1
76            self.fitresult.theory = theory
77
78            #fitness = self.chisq(params=params)
79            fitness = self.chisq()
80            self.fitresult.pvec = params
81            self.fitresult.set_fitness(fitness=fitness)
82            if self.msg_q is not None:
83                self.msg_q.put(self.fitresult)
84
85            if self.handler is not None:
86                self.handler.set_result(result=self.fitresult)
87                self.handler.update_fit()
88
89            if self.curr_thread != None:
90                try:
91                    self.curr_thread.isquit()
92                except:
93                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
94                    #msg += "fitting may cause a 'Functor error message' "
95                    #msg += "being recorded in the log file....."
96                    #self.handler.stop(msg)
97                    raise
98
99        return self.res
100
101    def check_param_range(self):
102        """
103        Check the lower and upper bound of the parameter value
104        and set res to the inf if the value is outside of the
105        range
106        :limitation: the initial values must be within range.
107        """
108
109        #time.sleep(0.01)
110        is_outofbound = False
111        # loop through the fit parameters
112        model = self.model.model
113        for p in self.paramlist:
114            value = model.getParam(p)
115            low,high = model.details[p][1:3]
116            if low is not None and numpy.isfinite(low):
117                if p.value == 0:
118                    # This value works on Scipy
119                    # Do not change numbers below
120                    value = _SMALLVALUE
121                # For leastsq, it needs a bit step back from the boundary
122                val = low - value * _SMALLVALUE
123                if value < val:
124                    self.res *= 1e+6
125                    is_outofbound = True
126                    break
127            if high is not None and numpy.isfinite(high):
128                # This value works on Scipy
129                # Do not change numbers below
130                if value == 0:
131                    value = _SMALLVALUE
132                # For leastsq, it needs a bit step back from the boundary
133                val = high + value * _SMALLVALUE
134                if value > val:
135                    self.res *= 1e+6
136                    is_outofbound = True
137                    break
138
139        return is_outofbound
[88b5e83]140
[4c718654]141class ScipyFit(FitEngine):
[7705306]142    """
[aa36f96]143    ScipyFit performs the Fit.This class can be used as follow:
144    #Do the fit SCIPY
145    create an engine: engine = ScipyFit()
146    Use data must be of type plottable
[79492222]147    Use a sas model
[aa36f96]148   
149    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
150    is saved in FitArrange object.
151    engine.set_data(data,Uid)
152   
153    Set model parameter "M1"= model.name add {model.parameter.name:value}.
154   
155    :note: Set_param() if used must always preceded set_model()
156         for the fit to be performed.In case of Scipyfit set_param is called in
157         fit () automatically.
158   
159    engine.set_param( model,"M1", {'A':2,'B':4})
160   
161    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
162    is save in FitArrange object.
163    engine.set_model(model,Uid)
164   
165    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
166    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
[7705306]167    """
[792db7d5]168    def __init__(self):
169        """
[b2f25dc5]170        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
[aa36f96]171        with Uid as keys
[792db7d5]172        """
[b2f25dc5]173        FitEngine.__init__(self)
[c4d6900]174        self.curr_thread = None
[d9dc518]175    #def fit(self, *args, **kw):
176    #    return profile(self._fit, *args, **kw)
[393f0f3]177
[ba7dceb]178    def fit(self, msg_q=None,
179            q=None, handler=None, curr_thread=None, 
[7db52f1]180            ftol=1.49012e-8, reset_flag=False):
[aa36f96]181        """
182        """
[89f3b66]183        fitproblem = []
[c4d6900]184        for fproblem in self.fit_arrange_dict.itervalues():
[89f3b66]185            if fproblem.get_to_fit() == 1:
[393f0f3]186                fitproblem.append(fproblem)
[89f3b66]187        if len(fitproblem) > 1 : 
[e0072082]188            msg = "Scipy can't fit more than a single fit problem at a time."
189            raise RuntimeError, msg
[6fe5100]190        elif len(fitproblem) == 0 :
[a9e04aa]191            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
[393f0f3]192        model = fitproblem[0].get_model()
[1792311]193        pars = fitproblem[0].pars
[7db52f1]194        if reset_flag:
195            # reset the initial value; useful for batch
196            for name in fitproblem[0].pars:
197                ind = fitproblem[0].pars.index(name)
198                model.model.setParam(name, fitproblem[0].vals[ind])
199        listdata = []
[393f0f3]200        listdata = fitproblem[0].get_data()
[792db7d5]201        # Concatenate dList set (contains one or more data)before fitting
[e0072082]202        data = listdata
[852354c8]203       
[89f3b66]204        self.curr_thread = curr_thread
[93de635d]205        ftol = ftol
[852354c8]206       
207        # Check the initial value if it is within range
[1792311]208        _check_param_range(model.model, pars)
[852354c8]209       
[1792311]210        result = FResult(model=model.model, data=data, param_list=pars)
[06e7c26]211        result.fitter_id = self.fitter_id
[852354c8]212        if handler is not None:
213            handler.set_result(result=result)
[fd5ac0d]214        functor = SasAssembly(paramlist=pars,
[6fe5100]215                               model=model,
216                               data=data,
217                               handler=handler,
218                               fitresult=result,
219                               curr_thread=curr_thread,
220                               msg_q=msg_q)
[511c6810]221        try:
[2446b66]222            # This import must be here; otherwise it will be confused when more
223            # than one thread exist.
224            from scipy import optimize
225           
[db427ec]226            out, cov_x, _, mesg, success = optimize.leastsq(functor,
[1792311]227                                            model.get_params(pars),
[6fe5100]228                                            ftol=ftol,
229                                            full_output=1)
[2d0756a5]230        except:
[acfff8b]231            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
[852354c8]232                if handler is not None:
[acfff8b]233                    msg = "Fitting: Terminated!!!"
[986da97]234                    handler.stop(msg)
[2d0756a5]235                    raise KeyboardInterrupt, msg
[511c6810]236            else:
[2d0756a5]237                raise
[c4d6900]238        chisqr = functor.chisq()
[15f68ce]239
[fd6b789]240        if cov_x is not None and numpy.isfinite(cov_x).all():
241            stderr = numpy.sqrt(numpy.diag(cov_x))
242        else:
[15f68ce]243            stderr = []
[d8661fb]244           
245        result.index = data.idx
[15f68ce]246        result.fitness = chisqr
247        result.stderr  = stderr
248        result.pvec = out
249        result.success = success
250        result.theory = functor.theory
[fe10df5]251        if handler is not None:
[cc694d0]252            handler.set_result(result=result)
[ee19117]253            handler.update_fit(last=True)
[15f68ce]254        if q is not None:
255            q.put(result)
256            return q
257        if success < 1 or success > 5:
258            result.fitness = None
[444c900e]259        return [result]
[15f68ce]260
[852354c8]261       
[1792311]262def _check_param_range(model, pars):
[6fe5100]263    """
264    Check parameter range and set the initial value inside
265    if it is out of range.
266
267    : model: park model object
268    """
269    # loop through parameterset
[1792311]270    for p in pars:
[6fe5100]271        value = model.getParam(p)
[8d074d9]272        low,high = model.details.setdefault(p,["",None,None])[1:3]
[6fe5100]273        # if the range was defined, check the range
274        if low is not None and value <= low:
275            value = low + _get_zero_shift(low)
276        if high is not None and value > high:
277            value = high - _get_zero_shift(high)
278            # Check one more time if the new value goes below
279            # the low bound, If so, re-evaluate the value
280            # with the mean of the range.
281            if low is not None and value < low:
282                value = 0.5 * (low+high)
283        model.setParam(p, value)
284
285def _get_zero_shift(limit):
286    """
287    Get 10% shift of the param value = 0 based on the range value
288
289    : param range: min or max value of the bounds
290    """
[042f065]291    return 0.1 * (limit if limit != 0.0 else 1.0)
[6fe5100]292
[e0072082]293   
[c4d6900]294#def profile(fn, *args, **kw):
295#    import cProfile, pstats, os
296#    global call_result
297#   def call():
298#        global call_result
299#        call_result = fn(*args, **kw)
300#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
301#    stats = pstats.Stats('profile.out')
302#    stats.sort_stats('time')
303#    stats.sort_stats('calls')
304#    stats.print_stats()
305#    os.unlink('profile.out')
306#    return call_result
[9c648c7]307
[48882d1]308     
Note: See TracBrowser for help on using the repository browser.