Changeset 319ab14 in sasmodels for sasmodels/compare_many.py


Ignore:
Timestamp:
Nov 25, 2015 1:12:06 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
0fa687d
Parents:
38d8774
Message:

allow comparison between double/quad precision and sasview

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare_many.py

    rb514adf r319ab14  
    1111                      columnize, constrain_pars) 
    1212 
    13 def get_stats(target, value, index): 
     13def calc_stats(target, value, index): 
    1414    resid = abs(value-target)[index] 
    1515    relerr = resid/target[index] 
     
    3434    print(','.join('"%s"'%c for c in columns)) 
    3535 
    36 def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5): 
     36def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5, 
     37                     precision='double'): 
    3738    model_definition = core.load_model_definition(name) 
    3839    pars = get_demo_pars(model_definition) 
     
    4142    print(header) 
    4243 
    43     def trymodel(fn, *args, **kw): 
     44    # Some not very clean macros for evaluating the models and checking the 
     45    # results.  They freely use variables from the current scope, even some 
     46    # which have not been defined yet, complete with abuse of mutable lists 
     47    # to allow them to update values in the current scope since nonlocal 
     48    # declarations are not available in python 2.7. 
     49    def try_model(fn, *args, **kw): 
    4450        try: 
    4551            result, _ = fn(model_definition, pars_i, data, *args, **kw) 
     
    5460                result = np.NaN*data.x 
    5561        return result 
     62    def check_model(label, target, value, acceptable): 
     63        stats = calc_stats(target, value, index) 
     64        columns.extend(stats) 
     65        labels.append('GPU single') 
     66        max_diff[0] = max(max_diff[0], stats[0]) 
     67        good[0] = good[0] and (stats[0] < acceptable) 
    5668 
    5769    num_good = 0 
    5870    first = True 
    59     max_diff = 0 
     71    max_diff = [0] 
    6072    for k in range(N): 
    6173        print >>sys.stderr, name, k 
     
    6476        if mono: suppress_pd(pars_i) 
    6577 
    66         good = True 
     78        good = [True] 
    6779        labels = [] 
    6880        columns = [] 
    69         if 1: 
    70             sasview_value = trymodel(eval_sasview) 
     81        #target = try_model(eval_sasview) 
     82        target = try_model(eval_opencl, dtype='longdouble', cutoff=cutoff) 
     83        if precision == 'single': 
     84            value = try_model(eval_opencl, dtype='single', cutoff=cutoff) 
     85            check_model('GPU single', target, value, 5e-5) 
     86            single_value = value  # remember for single/double comparison 
     87        elif precision == 'double': 
     88            if environment().has_double: 
     89                label = 'GPU double' 
     90                value = try_model(eval_opencl, dtype='double', cutoff=cutoff) 
     91            else: 
     92                label = 'CPU double' 
     93                value = try_model(eval_ctypes, dtype='double', cutoff=cutoff) 
     94            check_model(label, target, value, 5e-14) 
     95            double_value = value  # remember for single/double comparison 
     96        elif precision == 'quad': 
     97            value = try_model(eval_opencl, dtype='longdouble', cutoff=cutoff) 
     98            check_model('CPU quad', target, value, 5e-14) 
    7199        if 0: 
    72             gpu_single_value = trymodel(eval_opencl, dtype='single', cutoff=cutoff) 
    73             stats = get_stats(sasview_value, gpu_single_value, index) 
    74             columns.extend(stats) 
    75             labels.append('GPU single') 
    76             max_diff = max(max_diff, stats[0]) 
    77             good = good and (stats[0] < 5e-5) 
    78         if 0 and environment().has_double: 
    79             gpu_double_value = trymodel(eval_opencl, dtype='double', cutoff=cutoff) 
    80             stats = get_stats(sasview_value, gpu_double_value, index) 
    81             columns.extend(stats) 
    82             labels.append('GPU double') 
    83             max_diff = max(max_diff, stats[0]) 
    84             good = good and (stats[0] < 1e-12) 
    85         if 1: 
    86             cpu_double_value = trymodel(eval_ctypes, dtype='double', cutoff=cutoff) 
    87             stats = get_stats(sasview_value, cpu_double_value, index) 
    88             columns.extend(stats) 
    89             labels.append('CPU double') 
    90             max_diff = max(max_diff, stats[0]) 
    91             good = good and (stats[0] < 1e-12) 
    92         if 0: 
    93             stats = get_stats(cpu_double_value, gpu_single_value, index) 
    94             columns.extend(stats) 
    95             labels.append('single/double') 
    96             max_diff = max(max_diff, stats[0]) 
    97             good = good and (stats[0] < 5e-5) 
     100            check_model('single/double', double_value, single_value, 5e-5) 
    98101 
    99102        columns += [v for _,v in sorted(pars_i.items())] 
     
    101104            print_column_headers(pars_i, labels) 
    102105            first = False 
    103         if good: 
     106        if good[0]: 
    104107            num_good += 1 
    105108        else: 
    106109            print(("%d,"%seed)+','.join("%g"%v for v in columns)) 
    107     print '"good","%d/%d","max diff",%g'%(num_good, N, max_diff) 
     110    print '"good","%d/%d","max diff",%g'%(num_good, N, max_diff[0]) 
    108111 
    109112 
    110113def print_usage(): 
    111     print "usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono)" 
     114    print "usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) (single|double|quad)" 
    112115 
    113116 
     
    138141is set in compare.py defaults for each model. 
    139142 
     143PRECISION is the floating point precision to use for comparisons. 
     144 
    140145Available models: 
    141146""") 
     
    143148 
    144149def main(): 
    145     if len(sys.argv) == 1: 
     150    if len(sys.argv) != 6: 
    146151        print_help() 
    147152        sys.exit(1) 
     
    159164        mono = sys.argv[4] == 'mono' 
    160165        cutoff = float(sys.argv[4]) if not mono else 0 
     166        precision = sys.argv[5] 
    161167    except: 
     168        traceback.print_exc() 
    162169        print_usage() 
    163170        sys.exit(1) 
     
    166173    model_list = [model] if model != "all" else MODELS 
    167174    for model in model_list: 
    168         compare_instance(model, data, index, N=count, mono=mono, cutoff=cutoff) 
     175        compare_instance(model, data, index, N=count, mono=mono, 
     176                         cutoff=cutoff, precision=precision) 
    169177 
    170178if __name__ == "__main__": 
Note: See TracChangeset for help on using the changeset viewer.