source: sasview/park_integration/ScipyFitting.py @ 64bc8e3

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 64bc8e3 was 5011193, checked in by Jae Cho <jhjcho@…>, 13 years ago

revert previous commit

  • Property mode set to 100644
File size: 9.0 KB
RevLine 
[aa36f96]1
2
[792db7d5]3"""
[aa36f96]4ScipyFitting module contains FitArrange , ScipyFit,
5Parameter classes.All listed classes work together to perform a
6simple fit with scipy optimizer.
[792db7d5]7"""
[61cb28d]8
[88b5e83]9import numpy 
[511c6810]10import sys
[7705306]11from scipy import optimize
12
[b2f25dc5]13from sans.fit.AbstractFitEngine import FitEngine
14from sans.fit.AbstractFitEngine import SansAssembly
[511c6810]15from sans.fit.AbstractFitEngine import FitAbort
[a3fc33d]16IS_MAC = True
17if sys.platform.count("win32") > 0:
18    IS_MAC = False
19   
[e0072082]20class fitresult(object):
[48882d1]21    """
[aa36f96]22    Storing fit result
[48882d1]23    """
[c4d6900]24    def __init__(self, model=None, param_list=None):
[89f3b66]25        self.calls = None
26        self.fitness = None
27        self.chisqr = None
28        self.pvec = None
29        self.cov = None
30        self.info = None
31        self.mesg = None
32        self.success = None
33        self.stderr = None
[e0072082]34        self.parameters = None
[a3fc33d]35        self.is_mac = IS_MAC
[e0072082]36        self.model = model
[c4d6900]37        self.param_list = param_list
[d603001]38        self.iterations = 0
[e0072082]39     
40    def set_model(self, model):
[aa36f96]41        """
42        """
[e0072082]43        self.model = model
44       
[90c9cdf]45    def set_fitness(self, fitness):
[aa36f96]46        """
47        """
[90c9cdf]48        self.fitness = fitness
49       
[e0072082]50    def __str__(self):
[aa36f96]51        """
52        """
[b2f25dc5]53        if self.pvec == None and self.model is None and self.param_list is None:
[e0072082]54            return "No results"
55        n = len(self.model.parameterset)
[d603001]56        self.iterations += 1
[e0072082]57        result_param = zip(xrange(n), self.model.parameterset)
[db427ec]58        msg1 = ["[Iteration #: %s ]" % self.iterations]
59        msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))]
[a3fc33d]60        if not self.is_mac:
61            msg2 = ["P%-3d  %s......|.....%s" % \
[db427ec]62                (p[0], p[1], p[1].value)\
[a3fc33d]63                  for p in result_param if p[1].name in self.param_list]
64            msg =  msg1 + msg3 + msg2
65        else:
[db427ec]66            msg = msg1 + msg3
67        msg = "\n".join(msg)
[a3fc33d]68        return msg
[48882d1]69   
[e0072082]70    def print_summary(self):
[aa36f96]71        """
72        """
[e0072082]73        print self   
[88b5e83]74
[4c718654]75class ScipyFit(FitEngine):
[7705306]76    """
[aa36f96]77    ScipyFit performs the Fit.This class can be used as follow:
78    #Do the fit SCIPY
79    create an engine: engine = ScipyFit()
80    Use data must be of type plottable
81    Use a sans model
82   
83    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
84    is saved in FitArrange object.
85    engine.set_data(data,Uid)
86   
87    Set model parameter "M1"= model.name add {model.parameter.name:value}.
88   
89    :note: Set_param() if used must always preceded set_model()
90         for the fit to be performed.In case of Scipyfit set_param is called in
91         fit () automatically.
92   
93    engine.set_param( model,"M1", {'A':2,'B':4})
94   
95    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
96    is save in FitArrange object.
97    engine.set_model(model,Uid)
98   
99    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
100    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
[7705306]101    """
[792db7d5]102    def __init__(self):
103        """
[b2f25dc5]104        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
[aa36f96]105        with Uid as keys
[792db7d5]106        """
[b2f25dc5]107        FitEngine.__init__(self)
108        self.fit_arrange_dict = {}
109        self.param_list = []
[c4d6900]110        self.curr_thread = None
[d9dc518]111    #def fit(self, *args, **kw):
112    #    return profile(self._fit, *args, **kw)
[393f0f3]113
[93de635d]114    def fit(self, q=None, handler=None, curr_thread=None, ftol=1.49012e-8):
[aa36f96]115        """
116        """
[89f3b66]117        fitproblem = []
[c4d6900]118        for fproblem in self.fit_arrange_dict.itervalues():
[89f3b66]119            if fproblem.get_to_fit() == 1:
[393f0f3]120                fitproblem.append(fproblem)
[89f3b66]121        if len(fitproblem) > 1 : 
[e0072082]122            msg = "Scipy can't fit more than a single fit problem at a time."
123            raise RuntimeError, msg
[a9e04aa]124            return
[89f3b66]125        elif len(fitproblem) == 0 : 
[a9e04aa]126            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
127            return
128   
[89f3b66]129        listdata = []
[393f0f3]130        model = fitproblem[0].get_model()
131        listdata = fitproblem[0].get_data()
[792db7d5]132        # Concatenate dList set (contains one or more data)before fitting
[e0072082]133        data = listdata
[852354c8]134       
[89f3b66]135        self.curr_thread = curr_thread
[93de635d]136        ftol = ftol
[852354c8]137       
138        # Check the initial value if it is within range
139        self._check_param_range(model)
140       
141        result = fitresult(model=model, param_list=self.param_list)
142        if handler is not None:
143            handler.set_result(result=result)
[fd6b789]144        #try:
[852354c8]145        functor = SansAssembly(self.param_list, model, data, handler=handler,
146                         fitresult=result, curr_thread= self.curr_thread)
[511c6810]147        try:
[db427ec]148            out, cov_x, _, mesg, success = optimize.leastsq(functor,
[c4d6900]149                                            model.get_params(self.param_list),
[852354c8]150                                                    ftol=ftol,
[c4d6900]151                                                    full_output=1,
152                                                    warning=True)
[acfff8b]153        except KeyboardInterrupt:
154            msg = "Fitting: Terminated!!!"
155            handler.error(msg)
156            raise KeyboardInterrupt, msg #<= more stable
157            #less stable below
158            """
159            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
[852354c8]160                if handler is not None:
[acfff8b]161                    msg = "Fitting: Terminated!!!"
162                    handler.error(msg)
[852354c8]163                    result = handler.get_result()
164                    return result
[511c6810]165            else:
166                raise
[acfff8b]167            """
[e0e22f2c]168        except:
169            raise
170       
[c4d6900]171        chisqr = functor.chisq()
[fd6b789]172        if cov_x is not None and numpy.isfinite(cov_x).all():
173            stderr = numpy.sqrt(numpy.diag(cov_x))
174        else:
[e0072082]175            stderr = None
[511c6810]176
[852354c8]177        if not (numpy.isnan(out).any()) and (cov_x != None):
178            result.fitness = chisqr
179            result.stderr  = stderr
180            result.pvec = out
181            result.success = success
182            if q is not None:
183                q.put(result)
184                return q
185            return result
186       
187        # Error will be present to the client, not here
188        #else: 
189        #    raise ValueError, "SVD did not converge" + str(mesg)
190       
191    def _check_param_range(self, model):
192        """
193        Check parameter range and set the initial value inside
194        if it is out of range.
195       
196        : model: park model object
197        """
198        is_outofbound = False
199        # loop through parameterset
200        for p in model.parameterset:       
201            param_name = p.get_name()
202            # proceed only if the parameter name is in the list of fitting
203            if param_name in self.param_list:
204                # if the range was defined, check the range
205                if numpy.isfinite(p.range[0]):
206                    if p.value <= p.range[0]: 
207                        # 10 % backing up from the border if not zero
208                        # for Scipy engine to work properly.
209                        shift = self._get_zero_shift(p.range[0])
210                        new_value = p.range[0] + shift
211                        p.value =  new_value
212                        is_outofbound = True
213                if numpy.isfinite(p.range[1]):
214                    if p.value >= p.range[1]:
215                        shift = self._get_zero_shift(p.range[1])
216                        # 10 % backing up from the border if not zero
217                        # for Scipy engine to work properly.
218                        new_value = p.range[1] - shift
219                        # Check one more time if the new value goes below
220                        # the low bound, If so, re-evaluate the value
221                        # with the mean of the range.
222                        if numpy.isfinite(p.range[0]):
223                            if new_value < p.range[0]:
224                                new_value = (p.range[0] + p.range[1]) / 2.0
225                        # Todo:
226                        # Need to think about when both min and max are same.
227                        p.value =  new_value
228                        is_outofbound = True
229                       
230        return is_outofbound
231   
232    def _get_zero_shift(self, range):
233        """
234        Get 10% shift of the param value = 0 based on the range value
235       
236        : param range: min or max value of the bounds
237        """
238        if range == 0:
239            shift = 0.1
240        else:
241            shift = 0.1 * range
242           
243        return shift
244   
[e0072082]245   
[c4d6900]246#def profile(fn, *args, **kw):
247#    import cProfile, pstats, os
248#    global call_result
249#   def call():
250#        global call_result
251#        call_result = fn(*args, **kw)
252#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
253#    stats = pstats.Stats('profile.out')
254#    stats.sort_stats('time')
255#    stats.sort_stats('calls')
256#    stats.print_stats()
257#    os.unlink('profile.out')
258#    return call_result
[9c648c7]259
[48882d1]260     
Note: See TracBrowser for help on using the repository browser.