Changeset ec7e360 in sasmodels for sasmodels/compare_many.py


Ignore:
Timestamp:
Dec 23, 2015 10:17:49 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
e21cc31
Parents:
ce166d3
Message:

refactor option processing for compare.py, allowing more flexible selection of calculation engines

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare_many.py

    rf4f3919 rec7e360  
    77from . import core 
    88from .kernelcl import environment 
    9 from .compare import (MODELS, randomize_model, suppress_pd, eval_sasview, 
     9from .compare import (MODELS, randomize_pars, suppress_pd, eval_sasview, 
    1010                      eval_opencl, eval_ctypes, make_data, get_demo_pars, 
    11                       columnize, constrain_pars, constrain_new_to_old) 
     11                      columnize, constrain_pars, constrain_new_to_old, 
     12                      make_engine) 
    1213 
    1314def calc_stats(target, value, index): 
     
    3435    print(','.join('"%s"'%c for c in columns)) 
    3536 
     37PRECISION = { 
     38    'fast': 1e-3, 
     39    'half': 1e-3, 
     40    'single': 5e-5, 
     41    'double': 5e-14, 
     42    'single!': 5e-5, 
     43    'double!': 5e-14, 
     44    'quad!': 5e-18, 
     45    'sasview': 5e-14, 
     46} 
    3647def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5, 
    37                      precision='double'): 
     48                     base='sasview', comp='double'): 
    3849    model_definition = core.load_model_definition(name) 
    3950    pars = get_demo_pars(model_definition) 
     
    4758    # to allow them to update values in the current scope since nonlocal 
    4859    # declarations are not available in python 2.7. 
    49     def try_model(fn, *args, **kw): 
     60    def try_model(fn, pars): 
    5061        try: 
    51             result, _ = fn(model_definition, pars_i, data, *args, **kw) 
     62            result = fn(**pars) 
    5263        except KeyboardInterrupt: 
    5364            raise 
     
    6071                result = np.NaN*data.x 
    6172        return result 
    62     def check_model(label, target, value, acceptable): 
    63         stats = calc_stats(target, value, index) 
    64         columns.extend(stats) 
    65         labels.append('GPU single') 
     73    def check_model(pars): 
     74        base_value = try_model(calc_base, pars) 
     75        comp_value = try_model(calc_comp, pars) 
     76        stats = calc_stats(base_value, comp_value, index) 
    6677        max_diff[0] = max(max_diff[0], stats[0]) 
    67         good[0] = good[0] and (stats[0] < acceptable) 
     78        good[0] = good[0] and (stats[0] < expected) 
     79        return list(stats) 
     80 
     81 
     82    calc_base = make_engine(model_definition, data, base, cutoff) 
     83    calc_comp = make_engine(model_definition, data, comp, cutoff) 
     84    expected = max(PRECISION[base], PRECISION[comp]) 
    6885 
    6986    num_good = 0 
     
    7289    for k in range(N): 
    7390        print("%s %d"%(name, k)) 
    74         pars_i, seed = randomize_model(pars) 
     91        seed = np.random.randint(1e6) 
     92        pars_i = randomize_pars(pars, seed) 
    7593        constrain_pars(model_definition, pars_i) 
    7694        constrain_new_to_old(model_definition, pars_i) 
     
    7997 
    8098        good = [True] 
    81         labels = [] 
    82         columns = [] 
    83         target = try_model(eval_sasview) 
    84         #target = try_model(eval_ctypes, dtype='double', cutoff=0.) 
    85         #target = try_model(eval_ctypes, dtype='longdouble', cutoff=0.) 
    86         if precision == 'single': 
    87             value = try_model(eval_opencl, dtype='single', cutoff=cutoff) 
    88             check_model('GPU single', target, value, 5e-5) 
    89             single_value = value  # remember for single/double comparison 
    90         elif precision == 'double': 
    91             if environment().has_type('double'): 
    92                 label = 'GPU double' 
    93                 value = try_model(eval_opencl, dtype='double', cutoff=cutoff) 
    94             else: 
    95                 label = 'CPU double' 
    96                 value = try_model(eval_ctypes, dtype='double', cutoff=cutoff) 
    97             check_model(label, target, value, 5e-14) 
    98             double_value = value  # remember for single/double comparison 
    99         elif precision == 'quad': 
    100             value = try_model(eval_opencl, dtype='longdouble', cutoff=cutoff) 
    101             check_model('CPU quad', target, value, 5e-14) 
    102         if 0: 
    103             check_model('single/double', double_value, single_value, 5e-5) 
    104  
     99        columns = check_model(pars_i) 
    105100        columns += [v for _,v in sorted(pars_i.items())] 
    106101        if first: 
     102            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))] 
    107103            print_column_headers(pars_i, labels) 
    108104            first = False 
     
    110106            num_good += 1 
    111107        else: 
    112             print(("%d,"%seed)+','.join("%g"%v for v in columns)) 
     108            print(("%d,"%seed)+','.join("%s"%v for v in columns)) 
    113109    print('"good","%d/%d","max diff",%g'%(num_good, N, max_diff[0])) 
    114110 
     
    144140is set in compare.py defaults for each model. 
    145141 
    146 PRECISION is the floating point precision to use for comparisons. 
     142PRECISION is the floating point precision to use for comparisons.  If two 
     143precisions are given, then compare one to the other, ignoring sasview. 
    147144 
    148145Available models: 
     
    151148 
    152149def main(): 
    153     if len(sys.argv) != 6: 
     150    if len(sys.argv) not in (6,7): 
    154151        print_help() 
    155152        sys.exit(1) 
     
    167164        mono = sys.argv[4] == 'mono' 
    168165        cutoff = float(sys.argv[4]) if not mono else 0 
    169         precision = sys.argv[5] 
     166        base = sys.argv[5] 
     167        comp = sys.argv[6] if len(sys.argv) > 6 else "sasview" 
    170168    except: 
    171169        traceback.print_exc() 
     
    173171        sys.exit(1) 
    174172 
    175     data, index = make_data(qmax=1.0, is2D=is2D, Nq=Nq) 
     173    data, index = make_data({'qmax':1.0, 'is2d':is2D, 'nq':Nq, 'res':0., 
     174                              'accuracy': 'Low', 'view':'log'}) 
    176175    model_list = [model] if model != "all" else MODELS 
    177176    for model in model_list: 
    178177        compare_instance(model, data, index, N=count, mono=mono, 
    179                          cutoff=cutoff, precision=precision) 
     178                         cutoff=cutoff, base=base, comp=comp) 
    180179 
    181180if __name__ == "__main__": 
Note: See TracChangeset for help on using the changeset viewer.