Changeset 852354c8 in sasview


Ignore:
Timestamp:
Mar 29, 2011 9:45:24 AM (14 years ago)
Author:
Jae Cho <jhjcho@…>
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
9466f2d6
Parents:
4b5bd73
Message:

added constraint bounding check

File:
1 edited

Legend:

Unmodified
Added
Removed
  • park_integration/ScipyFitting.py

    r9a608ed r852354c8  
    1414from sans.fit.AbstractFitEngine import SansAssembly 
    1515from sans.fit.AbstractFitEngine import FitAbort 
    16  
    1716 
    1817class fitresult(object): 
     
    5352        self.iterations += 1 
    5453        result_param = zip(xrange(n), self.model.parameterset) 
    55         msg = [" [Iteration #: %s] | P%-3d  %s......|.....%s" % \ 
    56                (self.iterations, p[0], p[1], p[1].value)\ 
     54        msg1 = ["[Iteration #: %s ]" % self.iterations] 
     55        msg2 = ["P%-3d  %s......|.....%s" % \ 
     56                (p[0], p[1], p[1].value)\ 
    5757              for p in result_param if p[1].name in self.param_list] 
    58         msg.append("=== goodness of fit: %s" % (str(self.fitness))) 
     58         
     59        msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))] 
     60        msg =  msg1 + msg3 + msg2 
    5961        return "\n".join(msg) 
    6062     
     
    100102        self.param_list = [] 
    101103        self.curr_thread = None 
    102         self.result = None 
    103104    #def fit(self, *args, **kw): 
    104105    #    return profile(self._fit, *args, **kw) 
     
    124125        # Concatenate dList set (contains one or more data)before fitting 
    125126        data = listdata 
     127        
    126128        self.curr_thread = curr_thread 
    127         self.result = fitresult(model=model, param_list=self.param_list) 
    128         self.handler = handler 
    129         if self.handler is not None: 
    130             self.handler.set_result(result=self.result) 
     129        ftol = curr_thread.ftol 
     130         
     131        # Check the initial value if it is within range 
     132        self._check_param_range(model) 
     133         
     134        result = fitresult(model=model, param_list=self.param_list) 
     135        if handler is not None: 
     136            handler.set_result(result=result) 
    131137        #try: 
    132         functor = SansAssembly(self.param_list, model, data, handler=self.handler, 
    133                          fitresult=self.result, curr_thread= self.curr_thread) 
    134      
     138        functor = SansAssembly(self.param_list, model, data, handler=handler, 
     139                         fitresult=result, curr_thread= self.curr_thread) 
    135140        try: 
    136             out, cov_x, _, mesg, success = optimize.leastsq(functor, 
     141                out, cov_x, _, mesg, success = optimize.leastsq(functor, 
    137142                                            model.get_params(self.param_list), 
    138                                                     ftol = 0.001, 
     143                                                    ftol=ftol, 
    139144                                                    full_output=1, 
    140145                                                    warning=True) 
    141146        except: 
    142147            if hasattr(sys, 'last_type') and sys.last_type == FitAbort: 
    143                 if self.handler is not None: 
     148                if handler is not None: 
    144149                    msg = "Fit Stop!" 
    145150                    #self.handler.error(msg) 
    146                     self.result = self.handler.get_result() 
    147                     return self.result 
     151                    result = handler.get_result() 
     152                    return result 
    148153            else: 
    149154                raise  
     
    154159        else: 
    155160            stderr = None 
    156          
    157         if (out is not None) and not (numpy.isnan(out).any()) \ 
    158             and (cov_x != None): 
    159             self.result.fitness = chisqr 
    160             self.result.stderr  = stderr 
    161             self.result.pvec = out 
    162             self.result.success = success 
    163         else:   
    164             msg = "SVD did not converge " + str(mesg) 
    165             #handler.error(msg) 
    166         return self.result 
    167  
    168         
    169      
    170  
    171  
     161 
     162        if not (numpy.isnan(out).any()) and (cov_x != None): 
     163            result.fitness = chisqr 
     164            result.stderr  = stderr 
     165            result.pvec = out 
     166            result.success = success 
     167            if q is not None: 
     168                q.put(result) 
     169                return q 
     170            return result 
     171         
     172        # Error will be present to the client, not here  
     173        #else:   
     174        #    raise ValueError, "SVD did not converge" + str(mesg) 
     175         
     176    def _check_param_range(self, model): 
     177        """ 
     178        Check parameter range and set the initial value inside  
     179        if it is out of range. 
     180         
     181        : model: park model object 
     182        """ 
     183        is_outofbound = False 
     184        # loop through parameterset 
     185        for p in model.parameterset:         
     186            param_name = p.get_name() 
     187            # proceed only if the parameter name is in the list of fitting 
     188            if param_name in self.param_list: 
     189                # if the range was defined, check the range 
     190                if numpy.isfinite(p.range[0]): 
     191                    if p.value <= p.range[0]:  
     192                        # 10 % backing up from the border if not zero 
     193                        # for Scipy engine to work properly. 
     194                        shift = self._get_zero_shift(p.range[0]) 
     195                        new_value = p.range[0] + shift 
     196                        p.value =  new_value 
     197                        is_outofbound = True 
     198                if numpy.isfinite(p.range[1]): 
     199                    if p.value >= p.range[1]: 
     200                        shift = self._get_zero_shift(p.range[1]) 
     201                        # 10 % backing up from the border if not zero 
     202                        # for Scipy engine to work properly. 
     203                        new_value = p.range[1] - shift 
     204                        # Check one more time if the new value goes below 
     205                        # the low bound, If so, re-evaluate the value  
     206                        # with the mean of the range. 
     207                        if numpy.isfinite(p.range[0]): 
     208                            if new_value < p.range[0]: 
     209                                new_value = (p.range[0] + p.range[1]) / 2.0 
     210                        # Todo:  
     211                        # Need to think about when both min and max are same. 
     212                        p.value =  new_value 
     213                        is_outofbound = True 
     214                         
     215        return is_outofbound 
     216     
     217    def _get_zero_shift(self, range): 
     218        """ 
     219        Get 10% shift of the param value = 0 based on the range value 
     220         
     221        : param range: min or max value of the bounds 
     222        """ 
     223        if range == 0: 
     224            shift = 0.1 
     225        else: 
     226            shift = 0.1 * range 
     227             
     228        return shift 
     229     
     230     
    172231#def profile(fn, *args, **kw): 
    173232#    import cProfile, pstats, os 
Note: See TracChangeset for help on using the changeset viewer.