Changeset caeb06d in sasmodels


Ignore:
Timestamp:
Jan 21, 2016 5:50:44 PM (9 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
9cfcac8
Parents:
eb46451
Message:

reduce lint

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare.py

    reb46451 rcaeb06d  
    11#!/usr/bin/env python 
    22# -*- coding: utf-8 -*- 
     3""" 
     4Program to compare models using different compute engines. 
     5 
     6This program lets you compare results between OpenCL and DLL versions 
     7of the code and between precision (half, fast, single, double, quad), 
     8where fast precision is single precision using native functions for 
     9trig, etc., and may not be completely IEEE 754 compliant.  This lets 
     10make sure that the model calculations are stable, or if you need to 
     11tag the model as double precision only.   
     12 
     13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the  
     14sasmodels root to see the command line options. 
     15 
     16Note that there is no way within sasmodels to select between an  
     17OpenCL CPU device and a GPU device, but you can do so by setting the  
     18PYOPENCL_CTX environment variable ahead of time.  Start a python 
     19interpreter and enter:: 
     20 
     21    import pyopencl as cl 
     22    cl.create_some_context() 
     23 
     24This will prompt you to select from the available OpenCL devices 
     25and tell you which string to use for the PYOPENCL_CTX variable. 
     26On Windows you will need to remove the quotes. 
     27""" 
     28 
     29from __future__ import print_function 
     30 
     31USAGE = """ 
     32usage: compare.py model N1 N2 [options...] [key=val] 
     33 
     34Compare the speed and value for a model between the SasView original and the 
     35sasmodels rewrite. 
     36 
     37model is the name of the model to compare (see below). 
     38N1 is the number of times to run sasmodels (default=1). 
     39N2 is the number times to run sasview (default=1). 
     40 
     41Options (* for default): 
     42 
     43    -plot*/-noplot plots or suppress the plot of the model 
     44    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0 
     45    -nq=128 sets the number of Q points in the data set 
     46    -1d*/-2d computes 1d or 2d data 
     47    -preset*/-random[=seed] preset or random parameters 
     48    -mono/-poly* force monodisperse/polydisperse 
     49    -cutoff=1e-5* cutoff value for including a point in polydispersity 
     50    -pars/-nopars* prints the parameter set or not 
     51    -abs/-rel* plot relative or absolute error 
     52    -linear/-log*/-q4 intensity scaling 
     53    -hist/-nohist* plot histogram of relative error 
     54    -res=0 sets the resolution width dQ/Q if calculating with resolution 
     55    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh 
     56    -edit starts the parameter explorer 
     57 
     58Any two calculation engines can be selected for comparison: 
     59 
     60    -single/-double/-half/-fast sets an OpenCL calculation engine 
     61    -single!/-double!/-quad! sets an OpenMP calculation engine 
     62    -sasview sets the sasview calculation engine 
     63 
     64The default is -single -sasview.  Note that the interpretation of quad 
     65precision depends on architecture, and may vary from 64-bit to 128-bit, 
     66with 80-bit floats being common (1e-19 precision). 
     67 
     68Key=value pairs allow you to set specific values for the model parameters. 
     69""" 
     70 
     71# Update docs with command line usage string.   This is separate from the usual 
     72# doc string so that we can display it at run time if there is an error. 
     73# lin 
     74__doc__ = __doc__ + """ 
     75Program description 
     76------------------- 
     77 
     78""" + USAGE 
     79 
     80 
    381 
    482import sys 
     
    25103# List of available models 
    26104MODELS = [basename(f)[:-3] 
    27           for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))] 
     105          for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))] 
    28106 
    29107# CRUFT python 2.6 
     
    76154    elif p.endswith('_pd_type'): 
    77155        return v 
    78     elif any(s in p for s in ('theta','phi','psi')): 
     156    elif any(s in p for s in ('theta', 'phi', 'psi')): 
    79157        # orientation in [-180,180], orientation pd in [0,45] 
    80158        if p.endswith('_pd'): 
    81             return [0,45] 
     159            return [0, 45] 
    82160        else: 
    83161            return [-180, 180] 
     
    97175        return [2*v, -2*v] 
    98176    else: 
    99         return [0, (2*v if v>0 else 1)] 
     177        return [0, (2*v if v > 0 else 1)] 
    100178 
    101179def _randomize_one(p, v): 
    102180    """ 
    103     Randomizing parameter. 
    104     """ 
    105     if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')): 
     181    Randomize a single parameter. 
     182    """ 
     183    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')): 
    106184        return v 
    107185    else: 
     
    109187 
    110188def randomize_pars(pars, seed=None): 
     189    """ 
     190    Generate random values for all of the parameters. 
     191 
     192    Valid ranges for the random number generator are guessed from the name of 
     193    the parameter; this will not account for constraints such as cap radius 
     194    greater than cylinder radius in the capped_cylinder model, so 
     195    :func:`constrain_pars` needs to be called afterward.. 
     196    """ 
    111197    np.random.seed(seed) 
    112198    # Note: the sort guarantees order `of calls to random number generator 
    113     pars = dict((p,_randomize_one(p,v)) 
    114                 for p,v in sorted(pars.items())) 
     199    pars = dict((p, _randomize_one(p, v)) 
     200                for p, v in sorted(pars.items())) 
    115201    return pars 
    116202 
     
    118204    """ 
    119205    Restrict parameters to valid values. 
     206 
     207    This includes model specific code for models such as capped_cylinder 
     208    which need to support within model constraints (cap radius more than 
     209    cylinder radius in this case). 
    120210    """ 
    121211    name = model_definition.name 
    122212    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']: 
    123         pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius'] 
     213        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius'] 
    124214    if name == 'barbell' and pars['bell_radius'] < pars['radius']: 
    125         pars['radius'],pars['bell_radius'] = pars['bell_radius'],pars['radius'] 
     215        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius'] 
    126216 
    127217    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff) 
     
    130220        q_max = 1.0  # high q maximum 
    131221        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max 
    132         pars['rg'] = min(pars['rg'],rg_max) 
     222        pars['rg'] = min(pars['rg'], rg_max) 
    133223 
    134224    if name == 'rpa': 
     
    144234 
    145235def parlist(pars): 
    146     return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items())) 
     236    """ 
     237    Format the parameter list for printing. 
     238    """ 
     239    return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items())) 
    147240 
    148241def suppress_pd(pars): 
     
    159252 
    160253def eval_sasview(model_definition, data): 
     254    """ 
     255    Return a model calculator using the SasView fitting engine. 
     256    """ 
    161257    # importing sas here so that the error message will be that sas failed to 
    162258    # import rather than the more obscure smear_selection not imported error 
     
    166262    # convert model parameters from sasmodel form to sasview form 
    167263    #print("old",sorted(pars.items())) 
    168     modelname, pars = revert_model(model_definition, {}) 
    169     #print("new",sorted(pars.items())) 
     264    modelname, _pars = revert_model(model_definition, {}) 
     265    #print("new",sorted(_pars.items())) 
    170266    sas = __import__('sas.models.'+modelname) 
    171     ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None) 
     267    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None) 
    172268    if ModelClass is None: 
    173269        raise ValueError("could not find model %r in sas.models"%modelname) 
     
    192288 
    193289    def calculator(**pars): 
     290        """ 
     291        Sasview calculator for model. 
     292        """ 
    194293        # paying for parameter conversion each time to keep life simple, if not fast 
    195294        _, pars = revert_model(model_definition, pars) 
    196         for k,v in pars.items(): 
     295        for k, v in pars.items(): 
    197296            parts = k.split('.')  # polydispersity components 
    198297            if len(parts) == 2: 
     
    217316} 
    218317def eval_opencl(model_definition, data, dtype='single', cutoff=0.): 
     318    """ 
     319    Return a model calculator using the OpenCL calculation engine. 
     320    """ 
    219321    try: 
    220322        model = core.load_model(model_definition, dtype=dtype, platform="ocl") 
     
    229331 
    230332def eval_ctypes(model_definition, data, dtype='double', cutoff=0.): 
    231     if dtype=='quad': 
     333    if dtype == 'quad': 
    232334        dtype = 'longdouble' 
    233335    model = core.load_model(model_definition, dtype=dtype, platform="dll") 
     
    237339 
    238340def time_calculation(calculator, pars, Nevals=1): 
     341    """ 
     342    Compute the average calculation time over N evaluations. 
     343 
     344    An additional call is generated without polydispersity in order to 
     345    initialize the calculation engine, and make the average more stable. 
     346    """ 
    239347    # initialize the code so time is more accurate 
    240348    value = calculator(**suppress_pd(pars)) 
     
    246354 
    247355def make_data(opts): 
     356    """ 
     357    Generate an empty dataset, used with the model to set Q points 
     358    and resolution. 
     359 
     360    *opts* contains the options, with 'qmax', 'nq', 'res', 
     361    'accuracy', 'is2d' and 'view' parsed from the command line. 
     362    """ 
    248363    qmax, nq, res = opts['qmax'], opts['nq'], opts['res'] 
    249364    if opts['is2d']: 
     
    263378 
    264379def make_engine(model_definition, data, dtype, cutoff): 
     380    """ 
     381    Generate the appropriate calculation engine for the given datatype. 
     382 
     383    Datatypes with '!' appended are evaluated using external C DLLs rather 
     384    than OpenCL. 
     385    """ 
    265386    if dtype == 'sasview': 
    266387        return eval_sasview(model_definition, data) 
     
    273394 
    274395def compare(opts, limits=None): 
     396    """ 
     397    Preform a comparison using options from the command line. 
     398 
     399    *limits* are the limits on the values to use, either to set the y-axis 
     400    for 1D or to set the colormap scale for 2D.  If None, then they are 
     401    inferred from the data and returned. When exploring using Bumps, 
     402    the limits are set when the model is initially called, and maintained 
     403    as the values are adjusted, making it easier to see the effects of the 
     404    parameters. 
     405    """ 
    275406    Nbase, Ncomp = opts['N1'], opts['N2'] 
    276407    pars = opts['pars'] 
     
    304435        resid = (base_value - comp_value) 
    305436        relerr = resid/comp_value 
    306         _print_stats("|%s - %s|"%(base.engine,comp.engine)+(" "*(3+len(comp.engine))), resid) 
    307         _print_stats("|(%s - %s) / %s|"%(base.engine,comp.engine,comp.engine), relerr) 
     437        _print_stats("|%s-%s|"%(base.engine, comp.engine) + (" "*(3+len(comp.engine))), 
     438                     resid) 
     439        _print_stats("|(%s-%s)/%s|"%(base.engine, comp.engine, comp.engine), 
     440                     relerr) 
    308441 
    309442    # Plot if requested 
     
    329462        if Nbase > 0: plt.subplot(132) 
    330463        plot_theory(data, comp_value, view=view, plot_data=False, limits=limits) 
    331         plt.title("%s t=%.1f ms"%(comp.engine,comp_time)) 
     464        plt.title("%s t=%.1f ms"%(comp.engine, comp_time)) 
    332465        #cbar_title = "log I" 
    333466    if Ncomp > 0 and Nbase > 0: 
    334467        plt.subplot(133) 
    335468        if '-abs' in opts: 
    336             err,errstr,errview = resid, "abs err", "linear" 
     469            err, errstr, errview = resid, "abs err", "linear" 
    337470        else: 
    338             err,errstr,errview = abs(relerr), "rel err", "log" 
     471            err, errstr, errview = abs(relerr), "rel err", "log" 
    339472        #err,errstr = base/comp,"ratio" 
    340473        plot_theory(data, None, resid=err, view=errview, plot_data=False) 
     
    348481        plt.figure() 
    349482        v = relerr 
    350         v[v==0] = 0.5*np.min(np.abs(v[v!=0])) 
    351         plt.hist(np.log10(np.abs(v)), normed=1, bins=50); 
    352         plt.xlabel('log10(err), err = |(%s - %s) / %s|'%(base.engine, comp.engine, comp.engine)); 
     483        v[v == 0] = 0.5*np.min(np.abs(v[v != 0])) 
     484        plt.hist(np.log10(np.abs(v)), normed=1, bins=50) 
     485        plt.xlabel('log10(err), err = |(%s - %s) / %s|' 
     486                   % (base.engine, comp.engine, comp.engine)) 
    353487        plt.ylabel('P(err)') 
    354488        plt.title('Distribution of relative error between calculation engines') 
     
    370504        "zero-offset:%+.3e"%np.mean(err), 
    371505        ] 
    372     print(label+"  ".join(data)) 
     506    print(label+"  "+"  ".join(data)) 
    373507 
    374508 
     
    376510# =========================================================================== 
    377511# 
    378 USAGE=""" 
    379 usage: compare.py model N1 N2 [options...] [key=val] 
    380  
    381 Compare the speed and value for a model between the SasView original and the 
    382 sasmodels rewrite. 
    383  
    384 model is the name of the model to compare (see below). 
    385 N1 is the number of times to run sasmodels (default=1). 
    386 N2 is the number times to run sasview (default=1). 
    387  
    388 Options (* for default): 
    389  
    390     -plot*/-noplot plots or suppress the plot of the model 
    391     -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0 
    392     -nq=128 sets the number of Q points in the data set 
    393     -1d*/-2d computes 1d or 2d data 
    394     -preset*/-random[=seed] preset or random parameters 
    395     -mono/-poly* force monodisperse/polydisperse 
    396     -cutoff=1e-5* cutoff value for including a point in polydispersity 
    397     -pars/-nopars* prints the parameter set or not 
    398     -abs/-rel* plot relative or absolute error 
    399     -linear/-log*/-q4 intensity scaling 
    400     -hist/-nohist* plot histogram of relative error 
    401     -res=0 sets the resolution width dQ/Q if calculating with resolution 
    402     -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh 
    403     -edit starts the parameter explorer 
    404  
    405 Any two calculation engines can be selected for comparison: 
    406  
    407     -single/-double/-half/-fast sets an OpenCL calculation engine 
    408     -single!/-double!/-quad! sets an OpenMP calculation engine 
    409     -sasview sets the sasview calculation engine 
    410  
    411 The default is -single -sasview.  Note that the interpretation of quad 
    412 precision depends on architecture, and may vary from 64-bit to 128-bit, 
    413 with 80-bit floats being common (1e-19 precision). 
    414  
    415 Key=value pairs allow you to set specific values for the model parameters. 
    416  
    417 Available models: 
    418 """ 
    419  
    420  
    421512NAME_OPTIONS = set([ 
    422513    'plot', 'noplot', 
     
    439530 
    440531def columnize(L, indent="", width=79): 
     532    """ 
     533    Format a list of strings into columns for printing. 
     534    """ 
    441535    column_width = max(len(w) for w in L) + 1 
    442536    num_columns = (width - len(indent)) // column_width 
     
    451545 
    452546def get_demo_pars(model_definition): 
     547    """ 
     548    Extract demo parameters from the model definition. 
     549    """ 
    453550    info = generate.make_info(model_definition) 
    454551    # Get the default values for the parameters 
     
    468565 
    469566def parse_opts(): 
    470     flags = [arg for arg in sys.argv[1:] if arg.startswith('-')] 
    471     values = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' in arg] 
    472     args = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' not in arg] 
     567    """ 
     568    Parse command line options. 
     569    """ 
     570    flags = [arg for arg in sys.argv[1:] 
     571             if arg.startswith('-')] 
     572    values = [arg for arg in sys.argv[1:] 
     573              if not arg.startswith('-') and '=' in arg] 
     574    args = [arg for arg in sys.argv[1:] 
     575            if not arg.startswith('-') and '=' not in arg] 
    473576    models = "\n    ".join("%-15s"%v for v in MODELS) 
    474577    if len(args) == 0: 
    475578        print(USAGE) 
     579        print("\nAvailable models:") 
    476580        print(columnize(MODELS, indent="  ")) 
    477581        sys.exit(1) 
    478582    if args[0] not in MODELS: 
    479         print("Model %r not available. Use one of:\n    %s"%(args[0],models)) 
     583        print("Model %r not available. Use one of:\n    %s"%(args[0], models)) 
    480584        sys.exit(1) 
    481585    if len(args) > 3: 
     
    484588    invalid = [o[1:] for o in flags 
    485589               if o[1:] not in NAME_OPTIONS 
    486                   and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)] 
     590                   and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)] 
    487591    if invalid: 
    488592        print("Invalid options: %s"%(", ".join(invalid))) 
Note: See TracChangeset for help on using the changeset viewer.