source: sasview/park_integration/src/sans/fit/ScipyFitting.py @ 47d9be79

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 47d9be79 was d8661fb, checked in by Gervaise Alina <gervyh@…>, 13 years ago

make sure the fitresult know all information

  • Property mode set to 100644
File size: 7.8 KB
Line 
1
2
3"""
4ScipyFitting module contains FitArrange , ScipyFit,
5Parameter classes.All listed classes work together to perform a
6simple fit with scipy optimizer.
7"""
8
9import numpy 
10import sys
11
12
13from sans.fit.AbstractFitEngine import FitEngine
14from sans.fit.AbstractFitEngine import SansAssembly
15from sans.fit.AbstractFitEngine import FitAbort
16from sans.fit.AbstractFitEngine import Model
17from sans.fit.AbstractFitEngine import FResult
18
19class ScipyFit(FitEngine):
20    """
21    ScipyFit performs the Fit.This class can be used as follow:
22    #Do the fit SCIPY
23    create an engine: engine = ScipyFit()
24    Use data must be of type plottable
25    Use a sans model
26   
27    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
28    is saved in FitArrange object.
29    engine.set_data(data,Uid)
30   
31    Set model parameter "M1"= model.name add {model.parameter.name:value}.
32   
33    :note: Set_param() if used must always preceded set_model()
34         for the fit to be performed.In case of Scipyfit set_param is called in
35         fit () automatically.
36   
37    engine.set_param( model,"M1", {'A':2,'B':4})
38   
39    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
40    is save in FitArrange object.
41    engine.set_model(model,Uid)
42   
43    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
44    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
45    """
46    def __init__(self):
47        """
48        Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements
49        with Uid as keys
50        """
51        FitEngine.__init__(self)
52        self.fit_arrange_dict = {}
53        self.param_list = []
54        self.curr_thread = None
55    #def fit(self, *args, **kw):
56    #    return profile(self._fit, *args, **kw)
57
58    def fit(self, q=None, handler=None, curr_thread=None, ftol=1.49012e-8):
59        """
60        """
61        fitproblem = []
62        for fproblem in self.fit_arrange_dict.itervalues():
63            if fproblem.get_to_fit() == 1:
64                fitproblem.append(fproblem)
65        if len(fitproblem) > 1 : 
66            msg = "Scipy can't fit more than a single fit problem at a time."
67            raise RuntimeError, msg
68            return
69        elif len(fitproblem) == 0 : 
70            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
71            return
72   
73        listdata = []
74        model = fitproblem[0].get_model()
75        listdata = fitproblem[0].get_data()
76        # Concatenate dList set (contains one or more data)before fitting
77        data = listdata
78       
79        self.curr_thread = curr_thread
80        ftol = ftol
81       
82        # Check the initial value if it is within range
83        self._check_param_range(model)
84       
85        result = FResult(model=model, data=data, param_list=self.param_list)
86        if handler is not None:
87            handler.set_result(result=result)
88        try:
89            # This import must be here; otherwise it will be confused when more
90            # than one thread exist.
91            from scipy import optimize
92           
93            functor = SansAssembly(self.param_list, model, data, handler=handler,\
94                         fitresult=result, curr_thread= curr_thread)
95            out, cov_x, _, mesg, success = optimize.leastsq(functor,
96                                            model.get_params(self.param_list),
97                                                    ftol=ftol,
98                                                    full_output=1,
99                                                    warning=True)
100
101        except KeyboardInterrupt:
102            msg = "Fitting: Terminated!!!"
103            handler.error(msg)
104            raise KeyboardInterrupt, msg #<= more stable
105            #less stable below
106            """
107            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt:
108                if handler is not None:
109                    msg = "Fitting: Terminated!!!"
110                    handler.error(msg)
111                    result = handler.get_result()
112                    return result
113            else:
114                raise
115            """
116        except:
117            raise
118        chisqr = functor.chisq()
119        if cov_x is not None and numpy.isfinite(cov_x).all():
120            stderr = numpy.sqrt(numpy.diag(cov_x))
121        else:
122            stderr = None
123           
124        result.index = data.idx
125        if not (numpy.isnan(out).any()) and (cov_x != None):
126            result.fitness = chisqr
127            result.stderr  = stderr
128            result.pvec = out
129            result.success = success
130            result.theory = functor.theory
131            #print "scipy", result.inputs
132            if q is not None:
133                q.put(result)
134                return q
135            if success < 1 or success > 5:
136                result = None
137        return [result]
138        """
139        else:
140            return None
141        """
142        # Error will be present to the client, not here
143        #else: 
144        #    raise ValueError, "SVD did not converge" + str(mesg)
145       
146    def _check_param_range(self, model):
147        """
148        Check parameter range and set the initial value inside
149        if it is out of range.
150       
151        : model: park model object
152        """
153        is_outofbound = False
154        # loop through parameterset
155        for p in model.parameterset:       
156            param_name = p.get_name()
157            # proceed only if the parameter name is in the list of fitting
158            if param_name in self.param_list:
159                # if the range was defined, check the range
160                if numpy.isfinite(p.range[0]):
161                    if p.value <= p.range[0]: 
162                        # 10 % backing up from the border if not zero
163                        # for Scipy engine to work properly.
164                        shift = self._get_zero_shift(p.range[0])
165                        new_value = p.range[0] + shift
166                        p.value =  new_value
167                        is_outofbound = True
168                if numpy.isfinite(p.range[1]):
169                    if p.value >= p.range[1]:
170                        shift = self._get_zero_shift(p.range[1])
171                        # 10 % backing up from the border if not zero
172                        # for Scipy engine to work properly.
173                        new_value = p.range[1] - shift
174                        # Check one more time if the new value goes below
175                        # the low bound, If so, re-evaluate the value
176                        # with the mean of the range.
177                        if numpy.isfinite(p.range[0]):
178                            if new_value < p.range[0]:
179                                new_value = (p.range[0] + p.range[1]) / 2.0
180                        # Todo:
181                        # Need to think about when both min and max are same.
182                        p.value =  new_value
183                        is_outofbound = True
184                       
185        return is_outofbound
186   
187    def _get_zero_shift(self, range):
188        """
189        Get 10% shift of the param value = 0 based on the range value
190       
191        : param range: min or max value of the bounds
192        """
193        if range == 0:
194            shift = 0.1
195        else:
196            shift = 0.1 * range
197           
198        return shift
199   
200   
201#def profile(fn, *args, **kw):
202#    import cProfile, pstats, os
203#    global call_result
204#   def call():
205#        global call_result
206#        call_result = fn(*args, **kw)
207#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
208#    stats = pstats.Stats('profile.out')
209#    stats.sort_stats('time')
210#    stats.sort_stats('calls')
211#    stats.print_stats()
212#    os.unlink('profile.out')
213#    return call_result
214
215     
Note: See TracBrowser for help on using the repository browser.