source: sasview/park_integration/src/sans/fit/ScipyFitting.py @ b5fe787

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since b5fe787 was 7db52f1, checked in by Jae Cho <jhjcho@…>, 13 years ago

now able to reset the model init param values in batch model

  • Property mode set to 100644
File size: 8.1 KB
RevLine 
[aa36f96]1
2
[792db7d5]3"""
ScipyFitting module contains the ScipyFit fit engine, which performs a
simple least-squares fit of a single fit problem with the scipy optimizer
(scipy.optimize.leastsq).
[792db7d5]7"""
[61cb28d]8
[88b5e83]9import numpy 
[511c6810]10import sys
[2446b66]11
[7705306]12
[b2f25dc5]13from sans.fit.AbstractFitEngine import FitEngine
14from sans.fit.AbstractFitEngine import SansAssembly
[511c6810]15from sans.fit.AbstractFitEngine import FitAbort
[634ca14]16from sans.fit.AbstractFitEngine import Model
[444c900e]17from sans.fit.AbstractFitEngine import FResult
[88b5e83]18
class ScipyFit(FitEngine):
    """
    ScipyFit performs a fit with the scipy optimizer (``optimize.leastsq``).

    Usage outline::

        # Do the fit with SciPy
        engine = ScipyFit()

        # Data must be of type plottable; the model must be a sans model.
        # The data are stored in a FitArrange object of fit_arrange_dict,
        # keyed by Uid.
        engine.set_data(data, Uid)

        # Set model parameters: "M1" = model.name, plus a dictionary
        # {model.parameter.name: value}.
        engine.set_param(model, "M1", {'A': 2, 'B': 4})

        # The model is stored in a FitArrange object of fit_arrange_dict,
        # keyed by Uid.
        engine.set_model(model, Uid)

        # Run the fit; see ``fit`` for what is returned.
        results = engine.fit()

    :note: set_param(), if used, must always precede set_model() for the
        fit to be performed.
        NOTE(review): older docs said fit() calls set_param automatically;
        the fit() below does not call it -- confirm against FitEngine.

    Only one fit problem may be scheduled (``get_to_fit() == 1``) at a time.
    """
    def __init__(self):
        """
        Create an empty dictionary (``self.fit_arrange_dict``) of FitArrange
        elements keyed by Uid, an empty list of fitted parameter names
        (``self.param_list``) and no current thread.
        """
        FitEngine.__init__(self)
        self.fit_arrange_dict = {}
        self.param_list = []
        self.curr_thread = None

    def _get_single_fitproblem(self):
        """
        Return the only FitArrange scheduled for fitting.

        :raise RuntimeError: if no fit problem, or more than one, has
            ``get_to_fit() == 1``.
        """
        fitproblem = [fproblem for fproblem in self.fit_arrange_dict.values()
                      if fproblem.get_to_fit() == 1]
        if len(fitproblem) > 1:
            msg = "Scipy can't fit more than a single fit problem at a time."
            raise RuntimeError(msg)
        if not fitproblem:
            raise RuntimeError("No Assembly scheduled for Scipy fitting.")
        return fitproblem[0]

    def fit(self, q=None, handler=None, curr_thread=None,
            ftol=1.49012e-8, reset_flag=False):
        """
        Run ``scipy.optimize.leastsq`` on the single scheduled fit problem.

        :param q: optional queue. When given and the fit produced finite
            parameters with a covariance matrix, the result is put on the
            queue and the queue itself is returned.
        :param handler: optional fit handler. It receives the result object
            via ``set_result`` before fitting, and ``error`` if the fit is
            interrupted.
        :param curr_thread: thread object stored on the engine and forwarded
            to SansAssembly.
        :param ftol: relative tolerance passed to ``optimize.leastsq``.
        :param reset_flag: if True, restore the model parameters to the
            initial values stored on the fit problem (``pars``/``vals``)
            before fitting; useful for batch fitting.

        :return: ``q`` when a queue was supplied and the fit produced a
            usable solution. Otherwise a one-element list: ``[FResult]``,
            or ``[None]`` when leastsq reports failure (``success`` outside
            1..5).
        :raise RuntimeError: if zero or more than one fit problem is
            scheduled.
        :raise KeyboardInterrupt: if the fit is interrupted.
        """
        fitproblem = self._get_single_fitproblem()
        model = fitproblem.get_model()
        if reset_flag:
            # Reset the initial values; useful for batch fitting.
            for ind, name in enumerate(fitproblem.pars):
                model.model.setParam(name, fitproblem.vals[ind])
        # The fit problem holds the (already concatenated) data to fit.
        data = fitproblem.get_data()

        self.curr_thread = curr_thread

        # Make sure the initial values are within their ranges.
        self._check_param_range(model)

        result = FResult(model=model, data=data, param_list=self.param_list)
        if handler is not None:
            handler.set_result(result=result)
        try:
            # This import must be here; otherwise it will be confused when
            # more than one thread exists.
            from scipy import optimize

            functor = SansAssembly(self.param_list, model, data,
                                   handler=handler, fitresult=result,
                                   curr_thread=curr_thread)
            # ``warning`` is not passed: old scipy releases defaulted it to
            # True, and current scipy no longer accepts the keyword.
            out, cov_x, _, _, success = optimize.leastsq(
                functor, model.get_params(self.param_list),
                ftol=ftol, full_output=1)
        except KeyboardInterrupt:
            msg = "Fitting: Terminated!!!"
            # handler is optional (defaults to None).
            if handler is not None:
                handler.error(msg)
            raise KeyboardInterrupt(msg)

        chisqr = functor.chisq()
        if cov_x is not None and numpy.isfinite(cov_x).all():
            stderr = numpy.sqrt(numpy.diag(cov_x))
        else:
            stderr = None

        result.index = data.idx
        # ``is not None``: comparing an ndarray with ``!= None`` is
        # element-wise in modern numpy and cannot be used in a condition.
        if not numpy.isnan(out).any() and cov_x is not None:
            result.fitness = chisqr
            result.stderr = stderr
            result.pvec = out
            result.success = success
            result.theory = functor.theory
            if q is not None:
                q.put(result)
                return q
            # leastsq's ier flag (``success``) is 1-4 when a solution was
            # found. 5 (maximum number of calls reached) is also accepted.
            if success < 1 or success > 5:
                result = None
        # Errors are presented to the client, not raised here.
        return [result]

    def _check_param_range(self, model):
        """
        Move the initial value of each fitted parameter inside its bounds.

        Only parameters whose name is in ``self.param_list`` are examined.

        - A value at or below a finite lower bound is moved up by
          ``_get_zero_shift(low)``.
        - A value at or above a finite upper bound is moved down by
          ``_get_zero_shift(high)``. If that would cross a finite lower
          bound, the midpoint of the range is used instead.

        Values are backed away from the border, rather than clamped onto
        it, so that the Scipy engine works properly.

        :param model: park model object exposing ``parameterset``
        :return: True if at least one parameter value was changed
        """
        is_outofbound = False
        for p in model.parameterset:
            # Only parameters being fitted need a valid starting point.
            if p.get_name() not in self.param_list:
                continue
            low, high = p.range[0], p.range[1]
            if numpy.isfinite(low) and p.value <= low:
                p.value = low + self._get_zero_shift(low)
                is_outofbound = True
            if numpy.isfinite(high) and p.value >= high:
                new_value = high - self._get_zero_shift(high)
                # If backing off the upper bound crosses the lower bound,
                # use the middle of the range instead.
                if numpy.isfinite(low) and new_value < low:
                    new_value = (low + high) / 2.0
                # TODO: decide what to do when min and max are the same.
                p.value = new_value
                is_outofbound = True
        return is_outofbound

    def _get_zero_shift(self, range):
        """
        Return the distance by which a value is backed away from a bound.

        The shift is 10% of the bound's magnitude, or 0.1 when the bound is
        zero. The magnitude (not the signed value) is used so the shift
        always points into the allowed range. With a negative bound, a
        signed shift would push the value further out of bounds.

        :param range: min or max value of the bounds (the name shadows the
            builtin but is kept for interface compatibility)
        :return: non-negative shift
        """
        if range == 0:
            return 0.1
        return 0.1 * abs(range)
[e0072082]205   
# Debugging aid (disabled): run fn(*args, **kw) under cProfile, print the
# collected statistics, and return fn's result.
#def profile(fn, *args, **kw):
#    import cProfile, pstats, os
#    global call_result
#    def call():
#        global call_result
#        call_result = fn(*args, **kw)
#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
#    stats = pstats.Stats('profile.out')
#    stats.sort_stats('time')
#    stats.sort_stats('calls')
#    stats.print_stats()
#    os.unlink('profile.out')
#    return call_result
[9c648c7]219
[48882d1]220     
Note: See TracBrowser for help on using the repository browser.