source: sasview/park_integration/ScipyFitting.py @ a3fc33d

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since a3fc33d was a3fc33d, checked in by Jae Cho <jhjcho@…>, 13 years ago

removed fit_string on MAC: 'cause shown a problem on MAC

  • Property mode set to 100644
File size: 9.0 KB
Line 
1
2
3"""
4ScipyFitting module contains FitArrange , ScipyFit,
5Parameter classes.All listed classes work together to perform a
6simple fit with scipy optimizer.
7"""
8
9import numpy 
10import sys
11from scipy import optimize
12
13from sans.fit.AbstractFitEngine import FitEngine
14from sans.fit.AbstractFitEngine import SansAssembly
15from sans.fit.AbstractFitEngine import FitAbort
16IS_MAC = True
17if sys.platform.count("win32") > 0:
18    IS_MAC = False
19   
20class fitresult(object):
21    """
22    Storing fit result
23    """
24    def __init__(self, model=None, param_list=None):
25        self.calls = None
26        self.fitness = None
27        self.chisqr = None
28        self.pvec = None
29        self.cov = None
30        self.info = None
31        self.mesg = None
32        self.success = None
33        self.stderr = None
34        self.parameters = None
35        self.is_mac = IS_MAC
36        self.model = model
37        self.param_list = param_list
38        self.iterations = 0
39     
40    def set_model(self, model):
41        """
42        """
43        self.model = model
44       
45    def set_fitness(self, fitness):
46        """
47        """
48        self.fitness = fitness
49       
50    def __str__(self):
51        """
52        """
53        if self.pvec == None and self.model is None and self.param_list is None:
54            return "No results"
55        n = len(self.model.parameterset)
56        self.iterations += 1
57        result_param = zip(xrange(n), self.model.parameterset)
58        if not self.is_mac:
59            msg1 = ["[Iteration #: %s ]" % self.iterations]
60            msg2 = ["P%-3d  %s......|.....%s" % \
61                (p[0], p[1], p[1].value)\
62                  for p in result_param if p[1].name in self.param_list]
63           
64            msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))]
65            msg =  msg1 + msg3 + msg2
66            msg = "\n".join(msg)
67        else:
68            msg = ''
69        return msg
70   
71    def print_summary(self):
72        """
73        """
74        print self   
75
76class ScipyFit(FitEngine):
77    """
78    ScipyFit performs the Fit.This class can be used as follow:
79    #Do the fit SCIPY
80    create an engine: engine = ScipyFit()
81    Use data must be of type plottable
82    Use a sans model
83   
84    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
85    is saved in FitArrange object.
86    engine.set_data(data,Uid)
87   
88    Set model parameter "M1"= model.name add {model.parameter.name:value}.
89   
90    :note: Set_param() if used must always preceded set_model()
91         for the fit to be performed.In case of Scipyfit set_param is called in
92         fit () automatically.
93   
94    engine.set_param( model,"M1", {'A':2,'B':4})
95   
96    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
97    is save in FitArrange object.
98    engine.set_model(model,Uid)
99   
100    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
101    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
102    """
103    def __init__(self):
104        """
105        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
106        with Uid as keys
107        """
108        FitEngine.__init__(self)
109        self.fit_arrange_dict = {}
110        self.param_list = []
111        self.curr_thread = None
112    #def fit(self, *args, **kw):
113    #    return profile(self._fit, *args, **kw)
114
115    def fit(self, q=None, handler=None, curr_thread=None, ftol=1.49012e-8):
116        """
117        """
118        fitproblem = []
119        for fproblem in self.fit_arrange_dict.itervalues():
120            if fproblem.get_to_fit() == 1:
121                fitproblem.append(fproblem)
122        if len(fitproblem) > 1 : 
123            msg = "Scipy can't fit more than a single fit problem at a time."
124            raise RuntimeError, msg
125            return
126        elif len(fitproblem) == 0 : 
127            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
128            return
129   
130        listdata = []
131        model = fitproblem[0].get_model()
132        listdata = fitproblem[0].get_data()
133        # Concatenate dList set (contains one or more data)before fitting
134        data = listdata
135       
136        self.curr_thread = curr_thread
137        ftol = ftol
138       
139        # Check the initial value if it is within range
140        self._check_param_range(model)
141       
142        result = fitresult(model=model, param_list=self.param_list)
143        if handler is not None:
144            handler.set_result(result=result)
145        #try:
146        functor = SansAssembly(self.param_list, model, data, handler=handler,
147                         fitresult=result, curr_thread= self.curr_thread)
148        try:
149                out, cov_x, _, mesg, success = optimize.leastsq(functor,
150                                            model.get_params(self.param_list),
151                                                    ftol=ftol,
152                                                    full_output=1,
153                                                    warning=True)
154        except KeyboardInterrupt:
155            msg = "Fitting: Terminated!!!"
156            handler.error(msg)
157            raise KeyboardInterrupt, msg #<= more stable
158            #less stable below
159            """
160            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
161                if handler is not None:
162                    msg = "Fitting: Terminated!!!"
163                    handler.error(msg)
164                    result = handler.get_result()
165                    return result
166            else:
167                raise
168            """
169        except:
170            raise
171       
172        chisqr = functor.chisq()
173        if cov_x is not None and numpy.isfinite(cov_x).all():
174            stderr = numpy.sqrt(numpy.diag(cov_x))
175        else:
176            stderr = None
177
178        if not (numpy.isnan(out).any()) and (cov_x != None):
179            result.fitness = chisqr
180            result.stderr  = stderr
181            result.pvec = out
182            result.success = success
183            if q is not None:
184                q.put(result)
185                return q
186            return result
187       
188        # Error will be present to the client, not here
189        #else: 
190        #    raise ValueError, "SVD did not converge" + str(mesg)
191       
192    def _check_param_range(self, model):
193        """
194        Check parameter range and set the initial value inside
195        if it is out of range.
196       
197        : model: park model object
198        """
199        is_outofbound = False
200        # loop through parameterset
201        for p in model.parameterset:       
202            param_name = p.get_name()
203            # proceed only if the parameter name is in the list of fitting
204            if param_name in self.param_list:
205                # if the range was defined, check the range
206                if numpy.isfinite(p.range[0]):
207                    if p.value <= p.range[0]: 
208                        # 10 % backing up from the border if not zero
209                        # for Scipy engine to work properly.
210                        shift = self._get_zero_shift(p.range[0])
211                        new_value = p.range[0] + shift
212                        p.value =  new_value
213                        is_outofbound = True
214                if numpy.isfinite(p.range[1]):
215                    if p.value >= p.range[1]:
216                        shift = self._get_zero_shift(p.range[1])
217                        # 10 % backing up from the border if not zero
218                        # for Scipy engine to work properly.
219                        new_value = p.range[1] - shift
220                        # Check one more time if the new value goes below
221                        # the low bound, If so, re-evaluate the value
222                        # with the mean of the range.
223                        if numpy.isfinite(p.range[0]):
224                            if new_value < p.range[0]:
225                                new_value = (p.range[0] + p.range[1]) / 2.0
226                        # Todo:
227                        # Need to think about when both min and max are same.
228                        p.value =  new_value
229                        is_outofbound = True
230                       
231        return is_outofbound
232   
233    def _get_zero_shift(self, range):
234        """
235        Get 10% shift of the param value = 0 based on the range value
236       
237        : param range: min or max value of the bounds
238        """
239        if range == 0:
240            shift = 0.1
241        else:
242            shift = 0.1 * range
243           
244        return shift
245   
246   
247#def profile(fn, *args, **kw):
248#    import cProfile, pstats, os
249#    global call_result
250#   def call():
251#        global call_result
252#        call_result = fn(*args, **kw)
253#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
254#    stats = pstats.Stats('profile.out')
255#    stats.sort_stats('time')
256#    stats.sort_stats('calls')
257#    stats.print_stats()
258#    os.unlink('profile.out')
259#    return call_result
260
261     
Note: See TracBrowser for help on using the repository browser.