source: sasview/park_integration/ScipyFitting.py @ 858d6ee

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 858d6ee was e0e22f2c, checked in by Jae Cho <jhjcho@…>, 14 years ago

got param_range error correctly

  • Property mode set to 100644
File size: 8.8 KB
Line 
1
2
3"""
4ScipyFitting module contains FitArrange , ScipyFit,
5Parameter classes.All listed classes work together to perform a
6simple fit with scipy optimizer.
7"""
8
9import numpy 
10import sys
11from scipy import optimize
12
13from sans.fit.AbstractFitEngine import FitEngine
14from sans.fit.AbstractFitEngine import SansAssembly
15from sans.fit.AbstractFitEngine import FitAbort
16
17class fitresult(object):
18    """
19    Storing fit result
20    """
21    def __init__(self, model=None, param_list=None):
22        self.calls = None
23        self.fitness = None
24        self.chisqr = None
25        self.pvec = None
26        self.cov = None
27        self.info = None
28        self.mesg = None
29        self.success = None
30        self.stderr = None
31        self.parameters = None
32        self.model = model
33        self.param_list = param_list
34        self.iterations = 0
35     
36    def set_model(self, model):
37        """
38        """
39        self.model = model
40       
41    def set_fitness(self, fitness):
42        """
43        """
44        self.fitness = fitness
45       
46    def __str__(self):
47        """
48        """
49        if self.pvec == None and self.model is None and self.param_list is None:
50            return "No results"
51        n = len(self.model.parameterset)
52        self.iterations += 1
53        result_param = zip(xrange(n), self.model.parameterset)
54        msg1 = ["[Iteration #: %s ]" % self.iterations]
55        msg2 = ["P%-3d  %s......|.....%s" % \
56                (p[0], p[1], p[1].value)\
57              for p in result_param if p[1].name in self.param_list]
58       
59        msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))]
60        msg =  msg1 + msg3 + msg2
61        return "\n".join(msg)
62   
63    def print_summary(self):
64        """
65        """
66        print self   
67
68class ScipyFit(FitEngine):
69    """
70    ScipyFit performs the Fit.This class can be used as follow:
71    #Do the fit SCIPY
72    create an engine: engine = ScipyFit()
73    Use data must be of type plottable
74    Use a sans model
75   
76    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
77    is saved in FitArrange object.
78    engine.set_data(data,Uid)
79   
80    Set model parameter "M1"= model.name add {model.parameter.name:value}.
81   
82    :note: Set_param() if used must always preceded set_model()
83         for the fit to be performed.In case of Scipyfit set_param is called in
84         fit () automatically.
85   
86    engine.set_param( model,"M1", {'A':2,'B':4})
87   
88    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
89    is save in FitArrange object.
90    engine.set_model(model,Uid)
91   
92    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
93    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
94    """
95    def __init__(self):
96        """
97        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
98        with Uid as keys
99        """
100        FitEngine.__init__(self)
101        self.fit_arrange_dict = {}
102        self.param_list = []
103        self.curr_thread = None
104    #def fit(self, *args, **kw):
105    #    return profile(self._fit, *args, **kw)
106
107    def fit(self, q=None, handler=None, curr_thread=None, ftol=1.49012e-8):
108        """
109        """
110        fitproblem = []
111        for fproblem in self.fit_arrange_dict.itervalues():
112            if fproblem.get_to_fit() == 1:
113                fitproblem.append(fproblem)
114        if len(fitproblem) > 1 : 
115            msg = "Scipy can't fit more than a single fit problem at a time."
116            raise RuntimeError, msg
117            return
118        elif len(fitproblem) == 0 : 
119            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
120            return
121   
122        listdata = []
123        model = fitproblem[0].get_model()
124        listdata = fitproblem[0].get_data()
125        # Concatenate dList set (contains one or more data)before fitting
126        data = listdata
127       
128        self.curr_thread = curr_thread
129        ftol = ftol
130       
131        # Check the initial value if it is within range
132        self._check_param_range(model)
133       
134        result = fitresult(model=model, param_list=self.param_list)
135        if handler is not None:
136            handler.set_result(result=result)
137        #try:
138        functor = SansAssembly(self.param_list, model, data, handler=handler,
139                         fitresult=result, curr_thread= self.curr_thread)
140        try:
141                out, cov_x, _, mesg, success = optimize.leastsq(functor,
142                                            model.get_params(self.param_list),
143                                                    ftol=ftol,
144                                                    full_output=1,
145                                                    warning=True)
146        except KeyboardInterrupt:
147            msg = "Fitting: Terminated!!!"
148            handler.error(msg)
149            raise KeyboardInterrupt, msg #<= more stable
150            #less stable below
151            """
152            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
153                if handler is not None:
154                    msg = "Fitting: Terminated!!!"
155                    handler.error(msg)
156                    result = handler.get_result()
157                    return result
158            else:
159                raise
160            """
161        except:
162            raise
163       
164        chisqr = functor.chisq()
165        if cov_x is not None and numpy.isfinite(cov_x).all():
166            stderr = numpy.sqrt(numpy.diag(cov_x))
167        else:
168            stderr = None
169
170        if not (numpy.isnan(out).any()) and (cov_x != None):
171            result.fitness = chisqr
172            result.stderr  = stderr
173            result.pvec = out
174            result.success = success
175            if q is not None:
176                q.put(result)
177                return q
178            return result
179       
180        # Error will be present to the client, not here
181        #else: 
182        #    raise ValueError, "SVD did not converge" + str(mesg)
183       
184    def _check_param_range(self, model):
185        """
186        Check parameter range and set the initial value inside
187        if it is out of range.
188       
189        : model: park model object
190        """
191        is_outofbound = False
192        # loop through parameterset
193        for p in model.parameterset:       
194            param_name = p.get_name()
195            # proceed only if the parameter name is in the list of fitting
196            if param_name in self.param_list:
197                # if the range was defined, check the range
198                if numpy.isfinite(p.range[0]):
199                    if p.value <= p.range[0]: 
200                        # 10 % backing up from the border if not zero
201                        # for Scipy engine to work properly.
202                        shift = self._get_zero_shift(p.range[0])
203                        new_value = p.range[0] + shift
204                        p.value =  new_value
205                        is_outofbound = True
206                if numpy.isfinite(p.range[1]):
207                    if p.value >= p.range[1]:
208                        shift = self._get_zero_shift(p.range[1])
209                        # 10 % backing up from the border if not zero
210                        # for Scipy engine to work properly.
211                        new_value = p.range[1] - shift
212                        # Check one more time if the new value goes below
213                        # the low bound, If so, re-evaluate the value
214                        # with the mean of the range.
215                        if numpy.isfinite(p.range[0]):
216                            if new_value < p.range[0]:
217                                new_value = (p.range[0] + p.range[1]) / 2.0
218                        # Todo:
219                        # Need to think about when both min and max are same.
220                        p.value =  new_value
221                        is_outofbound = True
222                       
223        return is_outofbound
224   
225    def _get_zero_shift(self, range):
226        """
227        Get 10% shift of the param value = 0 based on the range value
228       
229        : param range: min or max value of the bounds
230        """
231        if range == 0:
232            shift = 0.1
233        else:
234            shift = 0.1 * range
235           
236        return shift
237   
238   
239#def profile(fn, *args, **kw):
240#    import cProfile, pstats, os
241#    global call_result
242#   def call():
243#        global call_result
244#        call_result = fn(*args, **kw)
245#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
246#    stats = pstats.Stats('profile.out')
247#    stats.sort_stats('time')
248#    stats.sort_stats('calls')
249#    stats.print_stats()
250#    os.unlink('profile.out')
251#    return call_result
252
253     
Note: See TracBrowser for help on using the repository browser.