Changeset e3efa6b3 in sasview for src


Ignore:
Timestamp:
May 15, 2014 11:23:22 AM (10 years ago)
Author:
pkienzle
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
4e9f227
Parents:
76f132a
Message:

restructure bumps wrapper and add levenberg-marquardt

Location:
src/sans
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/BumpsFitting.py

    r8d074d9 re3efa6b3  
    88from bumps import fitters 
    99from bumps.mapper import SerialMapper 
     10from bumps import parameter 
     11from bumps.fitproblem import FitProblem 
    1012 
    1113from sans.fit.AbstractFitEngine import FitEngine 
     
    2123 
    2224    def __call__(self, history): 
     25        if self.handler is None: return 
    2326        self.handler.progress(history.step[0], self.max_step) 
    2427        if len(history.step)>1 and history.step[1] > history.step[0]: 
     
    4649            self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1])) 
    4750        except: 
    48             self.convergence.append((best, )) 
    49  
    50 class SasProblem(object): 
    51     """ 
    52     Wrap the SAS model in a form that can be understood by bumps. 
    53     """ 
    54     def __init__(self, param_list, model=None, data=None, fitresult=None, 
    55                  handler=None, curr_thread=None, msg_q=None): 
    56         """ 
    57         :param Model: the model wrapper fro sans -model 
    58         :param Data: the data wrapper for sans data 
    59         """ 
     51            self.convergence.append((best, best,best,best,best,best)) 
     52 
     53 
     54class SasFitness(object): 
     55    """ 
     56    Wrap SAS model as a bumps fitness object 
     57    """ 
     58    def __init__(self, name, model, data, fitted=[], **kw): 
     59        self.name = name 
    6060        self.model = model 
    6161        self.data = data 
    62         self.param_list = param_list 
    63         self.res = None 
    64         self.theory = None 
    65  
    66     @property 
    67     def name(self): 
    68         return self.model.name 
    69  
    70     @property 
    71     def dof(self): 
    72         return self.data.num_points - len(self.param_list) 
    73  
    74     def summarize(self): 
    75         """ 
    76         Return a stylized list of parameter names and values with range bars 
    77         suitable for printing. 
    78         """ 
    79         output = [] 
    80         bounds = self.bounds() 
    81         for i,p in enumerate(self.getp()): 
    82             name = self.param_list[i] 
    83             low,high = bounds[:,i] 
    84             range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"), 
    85                               ("%g]"%high if numpy.isfinite(high) else "inf)"))) 
    86             if not numpy.isfinite(p): 
    87                 bar = "*invalid* " 
     62        self._define_pars() 
     63        self._init_pars(kw) 
     64        self.set_fitted(fitted) 
     65        self._dirty = True 
     66 
     67    def _define_pars(self): 
     68        self._pars = {} 
     69        for k in self.model.getParamList(): 
     70            name = ".".join((self.name,k)) 
     71            value = self.model.getParam(k) 
     72            bounds = self.model.details.get(k,["",None,None])[1:3] 
     73            self._pars[k] = parameter.Parameter(value=value, bounds=bounds, 
     74                                                fixed=True, name=name) 
     75 
     76    def _init_pars(self, kw): 
     77        for k,v in kw.items(): 
     78            # dispersion parameters initialized with _field instead of .field 
     79            if k.endswith('_width'): k = k[:-6]+'.width' 
     80            elif k.endswith('_npts'): k = k[:-5]+'.npts' 
     81            elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas' 
     82            elif k.endswith('_type'): k = k[:-5]+'.type' 
     83            if k not in self._pars: 
     84                formatted_pars = ", ".join(sorted(self._pars.keys())) 
     85                raise KeyError("invalid parameter %r for %s--use one of: %s" 
     86                               %(k, self.model, formatted_pars)) 
     87            if '.' in k and not k.endswith('.width'): 
     88                self.model.setParam(k, v) 
     89            elif isinstance(v, parameter.BaseParameter): 
     90                self._pars[k] = v 
     91            elif isinstance(v, (tuple,list)): 
     92                low, high = v 
     93                self._pars[k].value = (low+high)/2 
     94                self._pars[k].range(low,high) 
    8895            else: 
    89                 bar = ['.']*10 
    90                 if numpy.isfinite(high-low): 
    91                     position = int(9.999999999 * float(p-low)/float(high-low)) 
    92                     if position < 0: bar[0] = '<' 
    93                     elif position > 9: bar[9] = '>' 
    94                     else: bar[position] = '|' 
    95                 bar = "".join(bar) 
    96             output.append("%40s %s %10g in %s"%(name,bar,p,range)) 
    97         return "\n".join(output) 
    98  
    99     def nllf(self, p=None): 
    100         residuals = self.residuals(p) 
    101         return 0.5*numpy.sum(residuals**2) 
    102  
    103     def setp(self, p): 
    104         for k,v in zip(self.param_list, p): 
    105             self.model.setParam(k,v) 
    106         #self.model.set_params(self.param_list, params) 
    107  
    108     def getp(self): 
    109         return numpy.array([self.model.getParam(k) for k in self.param_list]) 
    110         #return numpy.asarray(self.model.get_params(self.param_list)) 
    111  
    112     def bounds(self): 
    113         return numpy.array([self._getrange(p) for p in self.param_list]).T 
    114  
    115     def labels(self): 
    116         return self.param_list 
    117  
    118     def _getrange(self, p): 
    119         """ 
    120         Override _getrange of park parameter 
    121         return the range of parameter 
    122         """ 
    123         lo, hi = self.model.details.get(p,["",None,None])[1:3] 
    124         if lo is None: lo = -numpy.inf 
    125         if hi is None: hi = numpy.inf 
    126         return lo, hi 
    127  
    128     def randomize(self, n): 
    129         p = self.getp() 
    130         # since randn is symmetric and random, doesn't matter 
    131         # point value is negative. 
    132         # TODO: throw in bounds checking! 
    133         return numpy.random.randn(n, len(self.param_list))*p + p 
    134  
    135     def chisq(self): 
    136         """ 
    137         Calculates chi^2 
    138  
    139         :param params: list of parameter values 
    140  
    141         :return: chi^2 
    142  
    143         """ 
    144         return numpy.sum(self.res**2)/self.dof 
    145  
    146     def residuals(self, params=None): 
    147         """ 
    148         Compute residuals 
    149         :param params: value of parameters to fit 
    150         """ 
    151         if params is not None: self.setp(params) 
    152         #import thread 
    153         #print "params", params 
    154         self.res, self.theory = self.data.residuals(self.model.evalDistribution) 
    155         return self.res 
    156  
    157 BOUNDS_PENALTY = 1e6 # cost for going out of bounds on unbounded fitters 
    158 class MonitoredSasProblem(SasProblem): 
    159     """ 
    160     SAS problem definition for optimizers which do not have monitoring or bounds. 
    161     """ 
    162     def __init__(self, param_list, model=None, data=None, fitresult=None, 
    163                  handler=None, curr_thread=None, msg_q=None, update_rate=1): 
    164         """ 
    165         :param Model: the model wrapper fro sans -model 
    166         :param Data: the data wrapper for sans data 
    167         """ 
    168         SasProblem.__init__(self, param_list, model, data) 
    169         self.msg_q = msg_q 
    170         self.curr_thread = curr_thread 
    171         self.handler = handler 
    172         self.fitresult = fitresult 
    173         #self.last_update = time.time() 
    174         #self.func_name = "Functor" 
    175         #self.name = "Fill in proper name!" 
    176  
    177     def residuals(self, p): 
    178         """ 
    179         Cost function for scipy.optimize.leastsq, which does not have a monitor 
    180         built into the algorithm, and instead relies on a monitor built into 
    181         the cost function. 
    182         """ 
    183         # Note: technically, unbounded fitters and unmonitored fitters are 
    184         self.setp(p) 
    185  
    186         # Compute penalty for being out of bounds which increases the farther 
    187         # you get out of bounds.  This allows derivative following algorithms 
    188         # to point back toward the feasible region. 
    189         penalty = self.bounds_penalty() 
    190         if penalty > 0: 
    191             self.theory = numpy.ones(self.data.num_points) 
    192             self.res = self.theory*(penalty/self.data.num_points) + BOUNDS_PENALTY 
    193             return self.res 
    194  
    195         # If no penalty, then we are not out of bounds and we can use the 
    196         # normal residual calculation 
    197         SasProblem.residuals(self, p) 
    198  
    199         # send update to the application 
    200         if True: 
    201             #self.fitresult.set_model(model=self.model) 
    202             # copy residuals into fit results 
    203             self.fitresult.residuals = self.res+0 
    204             self.fitresult.iterations += 1 
    205             self.fitresult.theory = self.theory+0 
    206  
    207             self.fitresult.p = numpy.array(p) # force copy, and coversion to array 
    208             self.fitresult.set_fitness(fitness=self.chisq()) 
    209             if self.msg_q is not None: 
    210                 self.msg_q.put(self.fitresult) 
    211  
    212             if self.handler is not None: 
    213                 self.handler.set_result(result=self.fitresult) 
    214                 self.handler.update_fit() 
    215  
    216             if self.curr_thread != None: 
    217                 try: 
    218                     self.curr_thread.isquit() 
    219                 except: 
    220                     #msg = "Fitting: Terminated...       Note: Forcing to stop " 
    221                     #msg += "fitting may cause a 'Functor error message' " 
    222                     #msg += "being recorded in the log file....." 
    223                     #self.handler.stop(msg) 
    224                     raise 
    225  
    226         return self.res 
    227  
    228     def bounds_penalty(self): 
    229         from numpy import sum, where 
    230         p, bounds = self.getp(), self.bounds() 
    231         return (sum(where(p<bounds[:,0], bounds[:,0]-p, 0)**2) 
    232               + sum(where(p>bounds[:,1], bounds[:,1]-p, 0)**2) ) 
     96                self._pars[k].value = v 
     97        self.update() 
     98 
     99    def set_fitted(self, param_list): 
     100        """ 
     101        Flag a set of parameters as fitted parameters. 
     102        """ 
     103        for k,p in self._pars.items(): 
     104            p.fixed = (k not in param_list) 
     105        self.fitted_pars = [self._pars[k] for k in param_list] 
     106        self.fitted_par_names = param_list 
     107 
     108    # ===== Fitness interface ==== 
     109    def parameters(self): 
     110        return self._pars 
     111 
     112    def update(self): 
     113        for k,v in self._pars.items(): 
     114            self.model.setParam(k,v.value) 
     115        self._dirty = True 
     116 
     117    def _recalculate(self): 
     118        if self._dirty: 
     119            self._residuals, self._theory = self.data.residuals(self.model.evalDistribution) 
     120            self._dirty = False 
     121 
     122    def numpoints(self): 
     123        return numpy.sum(self.data.idx) # number of fitted points 
     124 
     125    def nllf(self): 
     126        return 0.5*numpy.sum(self.residuals()**2) 
     127 
     128    def theory(self): 
     129        self._recalculate() 
     130        return self._theory 
     131 
     132    def residuals(self): 
     133        self._recalculate() 
     134        return self._residuals 
     135 
     136    # Not implementing the data methods for now: 
     137    # 
     138    #     resynth_data/restore_data/save/plot 
    233139 
    234140class BumpsFit(FitEngine): 
     
    247153            q=None, handler=None, curr_thread=None, 
    248154            ftol=1.49012e-8, reset_flag=False): 
    249         """ 
    250         """ 
    251         fitproblem = [] 
    252         for fproblem in self.fit_arrange_dict.itervalues(): 
    253             if fproblem.get_to_fit() == 1: 
    254                 fitproblem.append(fproblem) 
    255         if len(fitproblem) > 1 : 
    256             msg = "Bumps can't fit more than a single fit problem at a time." 
    257             raise RuntimeError, msg 
    258         elif len(fitproblem) == 0 : 
    259             raise RuntimeError, "No problem scheduled for fitting." 
    260         model = fitproblem[0].get_model() 
    261         if reset_flag: 
    262             # reset the initial value; useful for batch 
    263             for name in fitproblem[0].pars: 
    264                 ind = fitproblem[0].pars.index(name) 
    265                 model.setParam(name, fitproblem[0].vals[ind]) 
    266         data = fitproblem[0].get_data() 
    267  
    268         self.curr_thread = curr_thread 
    269  
    270         result = FResult(model=model, data=data, param_list=self.param_list) 
    271         result.pars = fitproblem[0].pars 
    272         result.fitter_id = self.fitter_id 
    273         result.index = data.idx 
    274         if handler is not None: 
    275             handler.set_result(result=result) 
    276  
    277         if True: # bumps 
    278             problem = SasProblem(param_list=self.param_list, 
    279                                  model=model.model, 
    280                                  data=data) 
    281             run_bumps(problem, result, ftol, 
    282                       handler, curr_thread, msg_q) 
    283         else: # scipy levenburg marquardt 
    284             problem = SasProblem(param_list=self.param_list, 
    285                                  model=model.model, 
    286                                  data=data, 
    287                                  handler=handler, 
    288                                  fitresult=result, 
    289                                  curr_thread=curr_thread, 
    290                                  msg_q=msg_q) 
    291             run_levenburg_marquardt(problem, result, ftol) 
    292  
     155        # Build collection of bumps fitness calculators 
     156        models = [ SasFitness(name="M%d"%(i+1), 
     157                              model=M.get_model().model, 
     158                              data=M.get_data(), 
     159                              fitted=M.pars) 
     160                   for i,M in enumerate(self.fit_arrange_dict.values()) 
     161                   if M.get_to_fit() == 1 ] 
     162        problem = FitProblem(models) 
     163 
     164        # Run the fit 
     165        result = run_bumps(problem, handler, curr_thread) 
    293166        if handler is not None: 
    294167            handler.update_fit(last=True) 
     168 
     169        # TODO: shouldn't reference internal parameters 
     170        varying = problem._parameters 
     171        # collect the results 
     172        all_results = [] 
     173        for M in problem.models: 
     174            fitness = M.fitness 
     175            fitted_index = [varying.index(p) for p in fitness.fitted_pars] 
     176            R = FResult(model=fitness.model, data=fitness.data, 
     177                        param_list=fitness.fitted_par_names) 
     178            R.theory = fitness.theory() 
     179            R.residuals = fitness.residuals() 
     180            R.fitter_id = self.fitter_id 
     181            R.stderr = result['stderr'][fitted_index] 
     182            R.pvec = result['value'][fitted_index] 
     183            R.success = result['success'] 
     184            R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index)) 
     185            R.convergence = result['convergence'] 
     186            if result['uncertainty'] is not None: 
     187                R.uncertainty_state = result['uncertainty'] 
     188            all_results.append(R) 
     189 
    295190        if q is not None: 
    296             q.put(result) 
     191            q.put(all_results) 
    297192            return q 
    298         #if success < 1 or success > 5: 
    299         #    result.fitness = None 
    300         return [result] 
    301  
    302 def run_bumps(problem, result, ftol, handler, curr_thread, msg_q): 
     193        else: 
     194            return all_results 
     195 
     196def run_bumps(problem, handler, curr_thread): 
    303197    def abort_test(): 
    304198        if curr_thread is None: return False 
     
    313207    fitclass = fitopts.fitclass 
    314208    options = fitopts.options.copy() 
    315     max_steps = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0) 
    316     if 'monitors' not in options: 
    317         options['monitors'] = [BumpsMonitor(handler, max_steps)] 
    318     options['monitors'] += [ ConvergenceMonitor() ] 
    319     options['ftol'] = ftol 
     209    max_step = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0) 
     210    options['monitors'] = [ 
     211        BumpsMonitor(handler, max_step), 
     212        ConvergenceMonitor(), 
     213        ] 
    320214    fitdriver = fitters.FitDriver(fitclass, problem=problem, 
    321215                                  abort_test=abort_test, **options) 
    322216    mapper = SerialMapper  
    323217    fitdriver.mapper = mapper.start_mapper(problem, None) 
     218    import time; T0 = time.time() 
    324219    try: 
    325220        best, fbest = fitdriver.fit() 
     
    329224    finally: 
    330225        mapper.stop_mapper(fitdriver.mapper) 
    331     #print "best,fbest",best,fbest,problem.dof 
    332     result.fitness = 2*fbest/problem.dof 
    333     #print "fitness",result.fitness 
    334     result.stderr  = fitdriver.stderr() 
    335     result.pvec = best 
    336     # TODO: track success better 
    337     result.success = True 
    338     result.theory = problem.theory 
    339     # For the convergence plot 
    340     pop = numpy.asarray(options['monitors'][-1].convergence) 
    341     result.convergence = 2*pop/problem.dof 
    342     # Bumps uncertainty state 
    343     try: result.uncertainty_state = fitdriver.fitter.state 
    344     except AttributeError: pass 
    345  
    346 def run_levenburg_marquardt(problem, result, ftol): 
    347     # This import must be here; otherwise it will be confused when more 
    348     # than one thread exist. 
    349     from scipy import optimize 
    350  
    351     out, cov_x, _, mesg, success = optimize.leastsq(problem.residuals, 
    352                                                     problem.getp(), 
    353                                                     ftol=ftol, 
    354                                                     full_output=1) 
    355     if cov_x is not None and numpy.isfinite(cov_x).all(): 
    356         stderr = numpy.sqrt(numpy.diag(cov_x)) 
    357     else: 
    358         stderr = [] 
    359     result.fitness = problem.chisq() 
    360     result.stderr  = stderr 
    361     result.pvec = out 
    362     result.success = success 
    363     result.theory = problem.theory 
    364  
     226 
     227 
     228    convergence_list = options['monitors'][-1].convergence 
     229    convergence = (2*numpy.asarray(convergence_list)/problem.dof 
     230                   if convergence_list else numpy.empty((0,1),'d')) 
     231    return { 
     232        'value': best, 
     233        'stderr': fitdriver.stderr(), 
     234        'success': True, # better success reporting in bumps 
     235        'convergence': convergence, 
     236        'uncertainty': getattr(fitdriver.fitter, 'state', None), 
     237        } 
     238 
  • src/sans/fit/Fitting.py

    r6fe5100 re3efa6b3  
    3232         
    3333    """   
    34     def __init__(self, engine='scipy'): 
     34    def __init__(self, engine='scipy', *args, **kw): 
    3535        """ 
    3636        """ 
     
    3838        self._engine = None 
    3939        self.fitter_id = None 
    40         self.set_engine(engine) 
     40        self.set_engine(engine, *args, **kw) 
    4141           
    4242    def __setattr__(self, name, value): 
     
    5555            self.__dict__[name] = value 
    5656                 
    57     def set_engine(self, word): 
     57    def set_engine(self, word, *args, **kw): 
    5858        """ 
    5959        Select the type of Fit  
     
    6666        """ 
    6767        try: 
    68             self._engine = ENGINES[word]() 
     68            self._engine = ENGINES[word](*args, **kw) 
    6969        except KeyError, exc: 
    7070            raise KeyError("fit engine should be one of scipy, park or bumps") 
  • src/sans/perspectives/fitting/fit_thread.py

    ra855fec re3efa6b3  
    1818     
    1919    def __init__(self,  
    20                   fn, 
    21                   page_id, 
    22                    handler, 
    23                     batch_outputs, 
    24                     batch_inputs=None,              
    25                   pars=None, 
     20                 fn, 
     21                 page_id, 
     22                 handler, 
     23                 batch_outputs, 
     24                 batch_inputs=None, 
     25                 pars=None, 
    2626                 completefn = None, 
    2727                 updatefn   = None, 
     
    3030                 ftol       = None, 
    3131                 reset_flag = False): 
    32         CalcThread.__init__(self,completefn, 
     32        CalcThread.__init__(self, 
     33                 completefn, 
    3334                 updatefn, 
    3435                 yieldtime, 
     
    8081                list_map_get_attr.append(map_getattr) 
    8182            #from multiprocessing import Pool 
    82             inputs = zip(list_map_get_attr,self.fitter, list_fit_function, 
    83                           list_q, list_q, list_handler,list_curr_thread,list_ftol, 
     83            inputs = zip(list_map_get_attr, self.fitter, list_fit_function, 
     84                         list_q, list_q, list_handler,list_curr_thread,list_ftol, 
    8485                         list_reset_flag) 
    8586            result =  map(map_apply, inputs) 
     
    8788            self.complete(result=result, 
    8889                          batch_inputs=self.batch_inputs, 
    89                            batch_outputs=self.batch_outputs, 
     90                          batch_outputs=self.batch_outputs, 
    9091                          page_id=self.page_id, 
    9192                          pars = self.pars, 
Note: See TracChangeset for help on using the changeset viewer.