Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/ScipyFitting.py

    r5777106 r8d074d9  
    1  
    2  
    31""" 
    42ScipyFitting module contains FitArrange , ScipyFit, 
     
    64simple fit with scipy optimizer. 
    75""" 
     6import sys 
     7import copy 
    88 
    99import numpy  
    10 import sys 
    11  
    1210 
    1311from sans.fit.AbstractFitEngine import FitEngine 
    14 from sans.fit.AbstractFitEngine import SansAssembly 
    15 from sans.fit.AbstractFitEngine import FitAbort 
    16 from sans.fit.AbstractFitEngine import Model 
    17 from sans.fit.AbstractFitEngine import FResult  
     12from sans.fit.AbstractFitEngine import FResult 
     13 
     14class SansAssembly: 
     15    """ 
     16    Sans Assembly class a class wrapper to be call in optimizer.leastsq method 
     17    """ 
     18    def __init__(self, paramlist, model=None, data=None, fitresult=None, 
     19                 handler=None, curr_thread=None, msg_q=None): 
     20        """ 
     21        :param Model: the model wrapper fro sans -model 
     22        :param Data: the data wrapper for sans data 
     23 
     24        """ 
     25        self.model = model 
     26        self.data = data 
     27        self.paramlist = paramlist 
     28        self.msg_q = msg_q 
     29        self.curr_thread = curr_thread 
     30        self.handler = handler 
     31        self.fitresult = fitresult 
     32        self.res = [] 
     33        self.true_res = [] 
     34        self.func_name = "Functor" 
     35        self.theory = None 
     36 
     37    def chisq(self): 
     38        """ 
     39        Calculates chi^2 
     40 
     41        :param params: list of parameter values 
     42 
     43        :return: chi^2 
     44 
     45        """ 
     46        total = 0 
     47        for item in self.true_res: 
     48            total += item * item 
     49        if len(self.true_res) == 0: 
     50            return None 
     51        return total / (len(self.true_res) - len(self.paramlist)) 
     52 
     53    def __call__(self, params): 
     54        """ 
     55            Compute residuals 
     56            :param params: value of parameters to fit 
     57        """ 
     58        #import thread 
     59        self.model.set_params(self.paramlist, params) 
     60        #print "params", params 
     61        self.true_res, theory = self.data.residuals(self.model.eval) 
     62        self.theory = copy.deepcopy(theory) 
     63        # check parameters range 
     64        if self.check_param_range(): 
     65            # if the param value is outside of the bound 
     66            # just silent return res = inf 
     67            return self.res 
     68        self.res = self.true_res 
     69 
     70        if self.fitresult is not None: 
     71            self.fitresult.set_model(model=self.model) 
     72            self.fitresult.residuals = self.true_res 
     73            self.fitresult.iterations += 1 
     74            self.fitresult.theory = theory 
     75 
     76            #fitness = self.chisq(params=params) 
     77            fitness = self.chisq() 
     78            self.fitresult.pvec = params 
     79            self.fitresult.set_fitness(fitness=fitness) 
     80            if self.msg_q is not None: 
     81                self.msg_q.put(self.fitresult) 
     82 
     83            if self.handler is not None: 
     84                self.handler.set_result(result=self.fitresult) 
     85                self.handler.update_fit() 
     86 
     87            if self.curr_thread != None: 
     88                try: 
     89                    self.curr_thread.isquit() 
     90                except: 
     91                    #msg = "Fitting: Terminated...       Note: Forcing to stop " 
     92                    #msg += "fitting may cause a 'Functor error message' " 
     93                    #msg += "being recorded in the log file....." 
     94                    #self.handler.stop(msg) 
     95                    raise 
     96 
     97        return self.res 
     98 
     99    def check_param_range(self): 
     100        """ 
     101        Check the lower and upper bound of the parameter value 
     102        and set res to the inf if the value is outside of the 
     103        range 
     104        :limitation: the initial values must be within range. 
     105        """ 
     106 
     107        #time.sleep(0.01) 
     108        is_outofbound = False 
     109        # loop through the fit parameters 
     110        model = self.model.model 
     111        for p in self.paramlist: 
     112            value = model.getParam(p) 
     113            low,high = model.details[p][1:3] 
     114            if low is not None and numpy.isfinite(low): 
     115                if p.value == 0: 
     116                    # This value works on Scipy 
     117                    # Do not change numbers below 
     118                    value = _SMALLVALUE 
     119                # For leastsq, it needs a bit step back from the boundary 
     120                val = low - value * _SMALLVALUE 
     121                if value < val: 
     122                    self.res *= 1e+6 
     123                    is_outofbound = True 
     124                    break 
     125            if high is not None and numpy.isfinite(high): 
     126                # This value works on Scipy 
     127                # Do not change numbers below 
     128                if value == 0: 
     129                    value = _SMALLVALUE 
     130                # For leastsq, it needs a bit step back from the boundary 
     131                val = high + value * _SMALLVALUE 
     132                if value > val: 
     133                    self.res *= 1e+6 
     134                    is_outofbound = True 
     135                    break 
     136 
     137        return is_outofbound 
    18138 
    19139class ScipyFit(FitEngine): 
     
    50170        """ 
    51171        FitEngine.__init__(self) 
    52         self.fit_arrange_dict = {} 
    53         self.param_list = [] 
    54172        self.curr_thread = None 
    55173    #def fit(self, *args, **kw): 
     
    68186            msg = "Scipy can't fit more than a single fit problem at a time." 
    69187            raise RuntimeError, msg 
    70             return 
    71         elif len(fitproblem) == 0 :  
     188        elif len(fitproblem) == 0 : 
    72189            raise RuntimeError, "No Assembly scheduled for Scipy fitting." 
    73             return 
    74190        model = fitproblem[0].get_model() 
    75191        if reset_flag: 
     
    87203         
    88204        # Check the initial value if it is within range 
    89         self._check_param_range(model) 
     205        _check_param_range(model.model, self.param_list) 
    90206         
    91         result = FResult(model=model, data=data, param_list=self.param_list) 
    92         result.pars = fitproblem[0].pars 
     207        result = FResult(model=model.model, data=data, param_list=self.param_list) 
    93208        result.fitter_id = self.fitter_id 
    94209        if handler is not None: 
    95210            handler.set_result(result=result) 
     211        functor = SansAssembly(paramlist=self.param_list, 
     212                               model=model, 
     213                               data=data, 
     214                               handler=handler, 
     215                               fitresult=result, 
     216                               curr_thread=curr_thread, 
     217                               msg_q=msg_q) 
    96218        try: 
    97219            # This import must be here; otherwise it will be confused when more 
     
    99221            from scipy import optimize 
    100222             
    101             functor = SansAssembly(paramlist=self.param_list,  
    102                                    model=model,  
    103                                    data=data, 
    104                                     handler=handler, 
    105                                     fitresult=result, 
    106                                      curr_thread=curr_thread, 
    107                                      msg_q=msg_q) 
    108223            out, cov_x, _, mesg, success = optimize.leastsq(functor, 
    109224                                            model.get_params(self.param_list), 
    110                                                     ftol=ftol, 
    111                                                     full_output=1) 
     225                                            ftol=ftol, 
     226                                            full_output=1) 
    112227        except: 
    113228            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt: 
     
    142257 
    143258         
    144     def _check_param_range(self, model): 
    145         """ 
    146         Check parameter range and set the initial value inside  
    147         if it is out of range. 
    148          
    149         : model: park model object 
    150         """ 
    151         is_outofbound = False 
    152         # loop through parameterset 
    153         for p in model.parameterset:         
    154             param_name = p.get_name() 
    155             # proceed only if the parameter name is in the list of fitting 
    156             if param_name in self.param_list: 
    157                 # if the range was defined, check the range 
    158                 if numpy.isfinite(p.range[0]): 
    159                     if p.value <= p.range[0]:  
    160                         # 10 % backing up from the border if not zero 
    161                         # for Scipy engine to work properly. 
    162                         shift = self._get_zero_shift(p.range[0]) 
    163                         new_value = p.range[0] + shift 
    164                         p.value =  new_value 
    165                         is_outofbound = True 
    166                 if numpy.isfinite(p.range[1]): 
    167                     if p.value >= p.range[1]: 
    168                         shift = self._get_zero_shift(p.range[1]) 
    169                         # 10 % backing up from the border if not zero 
    170                         # for Scipy engine to work properly. 
    171                         new_value = p.range[1] - shift 
    172                         # Check one more time if the new value goes below 
    173                         # the low bound, If so, re-evaluate the value  
    174                         # with the mean of the range. 
    175                         if numpy.isfinite(p.range[0]): 
    176                             if new_value < p.range[0]: 
    177                                 new_value = (p.range[0] + p.range[1]) / 2.0 
    178                         # Todo:  
    179                         # Need to think about when both min and max are same. 
    180                         p.value =  new_value 
    181                         is_outofbound = True 
    182                          
    183         return is_outofbound 
    184      
    185     def _get_zero_shift(self, range): 
    186         """ 
    187         Get 10% shift of the param value = 0 based on the range value 
    188          
    189         : param range: min or max value of the bounds 
    190         """ 
    191         if range == 0: 
    192             shift = 0.1 
    193         else: 
    194             shift = 0.1 * range 
    195              
    196         return shift 
    197      
     259def _check_param_range(model, param_list): 
     260    """ 
     261    Check parameter range and set the initial value inside 
     262    if it is out of range. 
     263 
     264    : model: park model object 
     265    """ 
     266    # loop through parameterset 
     267    for p in param_list: 
     268        value = model.getParam(p) 
     269        low,high = model.details.setdefault(p,["",None,None])[1:3] 
     270        # if the range was defined, check the range 
     271        if low is not None and value <= low: 
     272            value = low + _get_zero_shift(low) 
     273        if high is not None and value > high: 
     274            value = high - _get_zero_shift(high) 
     275            # Check one more time if the new value goes below 
     276            # the low bound, If so, re-evaluate the value 
     277            # with the mean of the range. 
     278            if low is not None and value < low: 
     279                value = 0.5 * (low+high) 
     280        model.setParam(p, value) 
     281 
     282def _get_zero_shift(limit): 
     283    """ 
     284    Get 10% shift of the param value = 0 based on the range value 
     285 
     286    : param range: min or max value of the bounds 
     287    """ 
     288    return 0.1 * (limit if limit != 0.0 else 1.0) 
     289 
    198290     
    199291#def profile(fn, *args, **kw): 
Note: See TracChangeset for help on using the changeset viewer.