source: sasview/park_integration/src/sans/fit/ScipyFitting.py @ ee53c72

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since ee53c72 was 425e49ca, checked in by Jae Cho <jhjcho@…>, 13 years ago

fixing batch plot problems

  • Property mode set to 100644
File size: 9.6 KB
Line 
1
2
3"""
4ScipyFitting module contains FitArrange , ScipyFit,
5Parameter classes.All listed classes work together to perform a
6simple fit with scipy optimizer.
7"""
8
9import numpy 
10import sys
11
12
13from sans.fit.AbstractFitEngine import FitEngine
14from sans.fit.AbstractFitEngine import SansAssembly
15from sans.fit.AbstractFitEngine import FitAbort
16from sans.fit.AbstractFitEngine import Model
17IS_MAC = True
18if sys.platform.count("win32") > 0:
19    IS_MAC = False
20   
21class fitresult(object):
22    """
23    Storing fit result
24    """
25    def __init__(self, model=None, param_list=None, data=None):
26        self.calls = None
27        self.fitness = None
28        self.chisqr = None
29        self.pvec = None
30        self.cov = None
31        self.info = None
32        self.mesg = None
33        self.success = None
34        self.stderr = None
35        self.parameters = None
36        self.is_mac = IS_MAC
37        if issubclass(model.__class__, Model):
38            model = model.model
39        self.model = model
40        self.data = data
41        self.theory = None
42        self.param_list = param_list
43        self.iterations = 0
44       
45        self.inputs = [(self.model, self.data)]
46     
47    def set_model(self, model):
48        """
49        """
50        self.model = model
51       
52    def set_fitness(self, fitness):
53        """
54        """
55        self.fitness = fitness
56       
57    def __str__(self):
58        """
59        """
60        if self.pvec == None and self.model is None and self.param_list is None:
61            return "No results"
62        n = len(self.model.parameterset)
63        self.iterations += 1
64        result_param = zip(xrange(n), self.model.parameterset)
65        msg1 = ["[Iteration #: %s ]" % self.iterations]
66        msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))]
67        if not self.is_mac:
68            msg2 = ["P%-3d  %s......|.....%s" % \
69                (p[0], p[1], p[1].value)\
70                  for p in result_param if p[1].name in self.param_list]
71            msg =  msg1 + msg3 + msg2
72        else:
73            msg = msg1 + msg3
74        msg = "\n".join(msg)
75        return msg
76   
77    def print_summary(self):
78        """
79        """
80        print self   
81
82class ScipyFit(FitEngine):
83    """
84    ScipyFit performs the Fit.This class can be used as follow:
85    #Do the fit SCIPY
86    create an engine: engine = ScipyFit()
87    Use data must be of type plottable
88    Use a sans model
89   
90    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
91    is saved in FitArrange object.
92    engine.set_data(data,Uid)
93   
94    Set model parameter "M1"= model.name add {model.parameter.name:value}.
95   
96    :note: Set_param() if used must always preceded set_model()
97         for the fit to be performed.In case of Scipyfit set_param is called in
98         fit () automatically.
99   
100    engine.set_param( model,"M1", {'A':2,'B':4})
101   
102    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
103    is save in FitArrange object.
104    engine.set_model(model,Uid)
105   
106    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
107    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
108    """
109    def __init__(self):
110        """
111        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
112        with Uid as keys
113        """
114        FitEngine.__init__(self)
115        self.fit_arrange_dict = {}
116        self.param_list = []
117        self.curr_thread = None
118    #def fit(self, *args, **kw):
119    #    return profile(self._fit, *args, **kw)
120
121    def fit(self, q=None, handler=None, curr_thread=None, ftol=1.49012e-8):
122        """
123        """
124        fitproblem = []
125        for fproblem in self.fit_arrange_dict.itervalues():
126            if fproblem.get_to_fit() == 1:
127                fitproblem.append(fproblem)
128        if len(fitproblem) > 1 : 
129            msg = "Scipy can't fit more than a single fit problem at a time."
130            raise RuntimeError, msg
131            return
132        elif len(fitproblem) == 0 : 
133            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
134            return
135   
136        listdata = []
137        model = fitproblem[0].get_model()
138        listdata = fitproblem[0].get_data()
139        # Concatenate dList set (contains one or more data)before fitting
140        data = listdata
141       
142        self.curr_thread = curr_thread
143        ftol = ftol
144       
145        # Check the initial value if it is within range
146        self._check_param_range(model)
147       
148        result = fitresult(model=model, data=data.sans_data, param_list=self.param_list)
149        if handler is not None:
150            handler.set_result(result=result)
151        try:
152            # This import must be here; otherwise it will be confused when more
153            # than one thread exist.
154            from scipy import optimize
155           
156            functor = SansAssembly(self.param_list, model, data, handler=handler,\
157                         fitresult=result, curr_thread= curr_thread)
158            out, cov_x, _, mesg, success = optimize.leastsq(functor,
159                                            model.get_params(self.param_list),
160                                                    ftol=ftol,
161                                                    full_output=1,
162                                                    warning=True)
163
164        except KeyboardInterrupt:
165            msg = "Fitting: Terminated!!!"
166            handler.error(msg)
167            raise KeyboardInterrupt, msg #<= more stable
168            #less stable below
169            """
170            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
171                if handler is not None:
172                    msg = "Fitting: Terminated!!!"
173                    handler.error(msg)
174                    result = handler.get_result()
175                    return result
176            else:
177                raise
178            """
179        except:
180            raise
181        chisqr = functor.chisq()
182        if cov_x is not None and numpy.isfinite(cov_x).all():
183            stderr = numpy.sqrt(numpy.diag(cov_x))
184        else:
185            stderr = None
186
187        if not (numpy.isnan(out).any()) and (cov_x != None):
188            result.fitness = chisqr
189            result.stderr  = stderr
190            result.pvec = out
191            result.success = success
192            result.theory = functor.theory
193            #print "scipy", result.inputs
194            if q is not None:
195                q.put(result)
196                return q
197            if success < 1 or success > 5:
198                result = None
199            return result
200        else:
201            return None
202        # Error will be present to the client, not here
203        #else: 
204        #    raise ValueError, "SVD did not converge" + str(mesg)
205       
206    def _check_param_range(self, model):
207        """
208        Check parameter range and set the initial value inside
209        if it is out of range.
210       
211        : model: park model object
212        """
213        is_outofbound = False
214        # loop through parameterset
215        for p in model.parameterset:       
216            param_name = p.get_name()
217            # proceed only if the parameter name is in the list of fitting
218            if param_name in self.param_list:
219                # if the range was defined, check the range
220                if numpy.isfinite(p.range[0]):
221                    if p.value <= p.range[0]: 
222                        # 10 % backing up from the border if not zero
223                        # for Scipy engine to work properly.
224                        shift = self._get_zero_shift(p.range[0])
225                        new_value = p.range[0] + shift
226                        p.value =  new_value
227                        is_outofbound = True
228                if numpy.isfinite(p.range[1]):
229                    if p.value >= p.range[1]:
230                        shift = self._get_zero_shift(p.range[1])
231                        # 10 % backing up from the border if not zero
232                        # for Scipy engine to work properly.
233                        new_value = p.range[1] - shift
234                        # Check one more time if the new value goes below
235                        # the low bound, If so, re-evaluate the value
236                        # with the mean of the range.
237                        if numpy.isfinite(p.range[0]):
238                            if new_value < p.range[0]:
239                                new_value = (p.range[0] + p.range[1]) / 2.0
240                        # Todo:
241                        # Need to think about when both min and max are same.
242                        p.value =  new_value
243                        is_outofbound = True
244                       
245        return is_outofbound
246   
247    def _get_zero_shift(self, range):
248        """
249        Get 10% shift of the param value = 0 based on the range value
250       
251        : param range: min or max value of the bounds
252        """
253        if range == 0:
254            shift = 0.1
255        else:
256            shift = 0.1 * range
257           
258        return shift
259   
260   
261#def profile(fn, *args, **kw):
262#    import cProfile, pstats, os
263#    global call_result
264#   def call():
265#        global call_result
266#        call_result = fn(*args, **kw)
267#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
268#    stats = pstats.Stats('profile.out')
269#    stats.sort_stats('time')
270#    stats.sort_stats('calls')
271#    stats.print_stats()
272#    os.unlink('profile.out')
273#    return call_result
274
275     
Note: See TracBrowser for help on using the repository browser.