source: sasview/park_integration/src/sans/fit/ScipyFitting.py @ dcc93e4

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since dcc93e4 was d91d2c9, checked in by Gervaise Alina <gervyh@…>, 13 years ago

remove comment

  • Property mode set to 100644
File size: 9.5 KB
RevLine 
[aa36f96]1
2
[792db7d5]3"""
[aa36f96]4ScipyFitting module contains FitArrange , ScipyFit,
5Parameter classes.All listed classes work together to perform a
6simple fit with scipy optimizer.
[792db7d5]7"""
[61cb28d]8
[88b5e83]9import numpy 
[511c6810]10import sys
[2446b66]11
[7705306]12
[b2f25dc5]13from sans.fit.AbstractFitEngine import FitEngine
14from sans.fit.AbstractFitEngine import SansAssembly
[511c6810]15from sans.fit.AbstractFitEngine import FitAbort
[634ca14]16from sans.fit.AbstractFitEngine import Model
[a3fc33d]17IS_MAC = True
18if sys.platform.count("win32") > 0:
19    IS_MAC = False
20   
[e0072082]21class fitresult(object):
[48882d1]22    """
[aa36f96]23    Storing fit result
[48882d1]24    """
[634ca14]25    def __init__(self, model=None, param_list=None, data=None):
[89f3b66]26        self.calls = None
27        self.fitness = None
28        self.chisqr = None
29        self.pvec = None
30        self.cov = None
31        self.info = None
32        self.mesg = None
33        self.success = None
34        self.stderr = None
[e0072082]35        self.parameters = None
[a3fc33d]36        self.is_mac = IS_MAC
[634ca14]37        if issubclass(model.__class__, Model):
38            model = model.model
[e0072082]39        self.model = model
[634ca14]40        self.data = data
[c4d6900]41        self.param_list = param_list
[d603001]42        self.iterations = 0
[634ca14]43       
44        self.inputs = [(self.model, self.data)]
[e0072082]45     
46    def set_model(self, model):
[aa36f96]47        """
48        """
[e0072082]49        self.model = model
50       
[90c9cdf]51    def set_fitness(self, fitness):
[aa36f96]52        """
53        """
[90c9cdf]54        self.fitness = fitness
55       
[e0072082]56    def __str__(self):
[aa36f96]57        """
58        """
[b2f25dc5]59        if self.pvec == None and self.model is None and self.param_list is None:
[e0072082]60            return "No results"
61        n = len(self.model.parameterset)
[d603001]62        self.iterations += 1
[e0072082]63        result_param = zip(xrange(n), self.model.parameterset)
[db427ec]64        msg1 = ["[Iteration #: %s ]" % self.iterations]
65        msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))]
[a3fc33d]66        if not self.is_mac:
67            msg2 = ["P%-3d  %s......|.....%s" % \
[db427ec]68                (p[0], p[1], p[1].value)\
[a3fc33d]69                  for p in result_param if p[1].name in self.param_list]
70            msg =  msg1 + msg3 + msg2
71        else:
[db427ec]72            msg = msg1 + msg3
73        msg = "\n".join(msg)
[a3fc33d]74        return msg
[48882d1]75   
[e0072082]76    def print_summary(self):
[aa36f96]77        """
78        """
[e0072082]79        print self   
[88b5e83]80
[4c718654]81class ScipyFit(FitEngine):
[7705306]82    """
[aa36f96]83    ScipyFit performs the Fit.This class can be used as follow:
84    #Do the fit SCIPY
85    create an engine: engine = ScipyFit()
86    Use data must be of type plottable
87    Use a sans model
88   
89    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
90    is saved in FitArrange object.
91    engine.set_data(data,Uid)
92   
93    Set model parameter "M1"= model.name add {model.parameter.name:value}.
94   
95    :note: Set_param() if used must always preceded set_model()
96         for the fit to be performed.In case of Scipyfit set_param is called in
97         fit () automatically.
98   
99    engine.set_param( model,"M1", {'A':2,'B':4})
100   
101    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
102    is save in FitArrange object.
103    engine.set_model(model,Uid)
104   
105    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
106    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
[7705306]107    """
[792db7d5]108    def __init__(self):
109        """
[b2f25dc5]110        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
[aa36f96]111        with Uid as keys
[792db7d5]112        """
[b2f25dc5]113        FitEngine.__init__(self)
114        self.fit_arrange_dict = {}
115        self.param_list = []
[c4d6900]116        self.curr_thread = None
[d9dc518]117    #def fit(self, *args, **kw):
118    #    return profile(self._fit, *args, **kw)
[393f0f3]119
[93de635d]120    def fit(self, q=None, handler=None, curr_thread=None, ftol=1.49012e-8):
[aa36f96]121        """
122        """
[89f3b66]123        fitproblem = []
[c4d6900]124        for fproblem in self.fit_arrange_dict.itervalues():
[89f3b66]125            if fproblem.get_to_fit() == 1:
[393f0f3]126                fitproblem.append(fproblem)
[89f3b66]127        if len(fitproblem) > 1 : 
[e0072082]128            msg = "Scipy can't fit more than a single fit problem at a time."
129            raise RuntimeError, msg
[a9e04aa]130            return
[89f3b66]131        elif len(fitproblem) == 0 : 
[a9e04aa]132            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
133            return
134   
[89f3b66]135        listdata = []
[393f0f3]136        model = fitproblem[0].get_model()
137        listdata = fitproblem[0].get_data()
[792db7d5]138        # Concatenate dList set (contains one or more data)before fitting
[e0072082]139        data = listdata
[852354c8]140       
[89f3b66]141        self.curr_thread = curr_thread
[93de635d]142        ftol = ftol
[852354c8]143       
144        # Check the initial value if it is within range
145        self._check_param_range(model)
146       
[634ca14]147        result = fitresult(model=model, data=data.sans_data, param_list=self.param_list)
[852354c8]148        if handler is not None:
149            handler.set_result(result=result)
[511c6810]150        try:
[2446b66]151            # This import must be here; otherwise it will be confused when more
152            # than one thread exist.
153            from scipy import optimize
154           
155            functor = SansAssembly(self.param_list, model, data, handler=handler,\
156                         fitresult=result, curr_thread= curr_thread)
[db427ec]157            out, cov_x, _, mesg, success = optimize.leastsq(functor,
[c4d6900]158                                            model.get_params(self.param_list),
[852354c8]159                                                    ftol=ftol,
[c4d6900]160                                                    full_output=1,
161                                                    warning=True)
[acfff8b]162        except KeyboardInterrupt:
163            msg = "Fitting: Terminated!!!"
164            handler.error(msg)
165            raise KeyboardInterrupt, msg #<= more stable
166            #less stable below
167            """
168            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
[852354c8]169                if handler is not None:
[acfff8b]170                    msg = "Fitting: Terminated!!!"
171                    handler.error(msg)
[852354c8]172                    result = handler.get_result()
173                    return result
[511c6810]174            else:
175                raise
[acfff8b]176            """
[e0e22f2c]177        except:
178            raise
[c4d6900]179        chisqr = functor.chisq()
[fd6b789]180        if cov_x is not None and numpy.isfinite(cov_x).all():
181            stderr = numpy.sqrt(numpy.diag(cov_x))
182        else:
[e0072082]183            stderr = None
[511c6810]184
[852354c8]185        if not (numpy.isnan(out).any()) and (cov_x != None):
186            result.fitness = chisqr
187            result.stderr  = stderr
188            result.pvec = out
189            result.success = success
[d91d2c9]190            #print "scipy", result.inputs
[a15da09]191            if q is not None:
[852354c8]192                q.put(result)
193                return q
[4fb520d]194            if success < 1 or success > 5:
[120d9f6]195                result = None
[852354c8]196            return result
[120d9f6]197        else:
198            return None
[852354c8]199        # Error will be present to the client, not here
200        #else: 
201        #    raise ValueError, "SVD did not converge" + str(mesg)
202       
203    def _check_param_range(self, model):
204        """
205        Check parameter range and set the initial value inside
206        if it is out of range.
207       
208        : model: park model object
209        """
210        is_outofbound = False
211        # loop through parameterset
212        for p in model.parameterset:       
213            param_name = p.get_name()
214            # proceed only if the parameter name is in the list of fitting
215            if param_name in self.param_list:
216                # if the range was defined, check the range
217                if numpy.isfinite(p.range[0]):
218                    if p.value <= p.range[0]: 
219                        # 10 % backing up from the border if not zero
220                        # for Scipy engine to work properly.
221                        shift = self._get_zero_shift(p.range[0])
222                        new_value = p.range[0] + shift
223                        p.value =  new_value
224                        is_outofbound = True
225                if numpy.isfinite(p.range[1]):
226                    if p.value >= p.range[1]:
227                        shift = self._get_zero_shift(p.range[1])
228                        # 10 % backing up from the border if not zero
229                        # for Scipy engine to work properly.
230                        new_value = p.range[1] - shift
231                        # Check one more time if the new value goes below
232                        # the low bound, If so, re-evaluate the value
233                        # with the mean of the range.
234                        if numpy.isfinite(p.range[0]):
235                            if new_value < p.range[0]:
236                                new_value = (p.range[0] + p.range[1]) / 2.0
237                        # Todo:
238                        # Need to think about when both min and max are same.
239                        p.value =  new_value
240                        is_outofbound = True
241                       
242        return is_outofbound
243   
244    def _get_zero_shift(self, range):
245        """
246        Get 10% shift of the param value = 0 based on the range value
247       
248        : param range: min or max value of the bounds
249        """
250        if range == 0:
251            shift = 0.1
252        else:
253            shift = 0.1 * range
254           
255        return shift
256   
[e0072082]257   
[c4d6900]258#def profile(fn, *args, **kw):
259#    import cProfile, pstats, os
260#    global call_result
261#   def call():
262#        global call_result
263#        call_result = fn(*args, **kw)
264#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
265#    stats = pstats.Stats('profile.out')
266#    stats.sort_stats('time')
267#    stats.sort_stats('calls')
268#    stats.print_stats()
269#    os.unlink('profile.out')
270#    return call_result
[9c648c7]271
[48882d1]272     
Note: See TracBrowser for help on using the repository browser.