Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/compare_many.py

    rd15a908 r6458608  
    11#!/usr/bin/env python 
    2 """ 
    3 Program to compare results from many random parameter sets for a given model. 
    4  
    5 The result is a comma separated value (CSV) table that can be redirected 
    6 from standard output into a file and loaded into a spreadsheet. 
    7  
    8 The models are compared for each parameter set and if the difference is 
    9 greater than expected for that precision, the parameter set is labeled 
    10 as bad and written to the output, along with the random seed used to 
    11 generate that parameter value.  This seed can be used with :mod:`compare` 
    12 to reload and display the details of the model. 
    13 """ 
    142from __future__ import print_function 
    153 
     
    2614 
    2715def calc_stats(target, value, index): 
    28     """ 
    29     Calculate statistics between the target value and the computed value. 
    30  
    31     *target* and *value* are the vectors being compared, with the 
    32     difference normalized by *target* to get relative error.  Only 
    33     the elements listed in *index* are used, though index may be 
    34     and empty slice defined by *slice(None, None)*. 
    35  
    36     Returns: 
    37  
    38         *maxrel* the maximum relative difference 
    39  
    40         *rel95* the relative difference with the 5% biggest differences ignored 
    41  
    42         *maxabs* the maximum absolute difference for the 5% biggest differences 
    43  
    44         *maxval* the maximum value for the 5% biggest differences 
    45     """ 
    4616    resid = abs(value-target)[index] 
    4717    relerr = resid/target[index] 
    48     sorted_rel_index = np.argsort(relerr) 
     18    srel = np.argsort(relerr) 
    4919    #p90 = int(len(relerr)*0.90) 
    5020    p95 = int(len(relerr)*0.95) 
    5121    maxrel = np.max(relerr) 
    52     rel95 = relerr[sorted_rel_index[p95]] 
    53     maxabs = np.max(resid[sorted_rel_index[p95:]]) 
    54     maxval = np.max(value[sorted_rel_index[p95:]]) 
    55     return maxrel, rel95, maxabs, maxval 
     22    rel95 = relerr[srel[p95]] 
     23    maxabs = np.max(resid[srel[p95:]]) 
     24    maxval = np.max(value[srel[p95:]]) 
     25    return maxrel,rel95,maxabs,maxval 
    5626 
    5727def print_column_headers(pars, parts): 
    58     """ 
    59     Generate column headers for the differences and for the parameters, 
    60     and print them to standard output. 
    61     """ 
    6228    stats = list('Max rel err|95% rel err|Max abs err above 90% rel|Max value above 90% rel'.split('|')) 
    6329    groups = [''] 
     
    7036    print(','.join('"%s"'%c for c in columns)) 
    7137 
    72 # Target 'good' value for various precision levels. 
    7338PRECISION = { 
    7439    'fast': 1e-3, 
     
    8348def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5, 
    8449                     base='sasview', comp='double'): 
    85     r""" 
    86     Compare the model under different calculation engines. 
    87  
    88     *name* is the name of the model. 
    89  
    90     *data* is the data object giving $q, \Delta q$ calculation points. 
    91  
    92     *index* is the active set of points. 
    93  
    94     *N* is the number of comparisons to make. 
    95  
    96     *cutoff* is the polydispersity weight cutoff to make the calculation 
    97     a little bit faster. 
    98  
    99     *base* and *comp* are the names of the calculation engines to compare. 
    100     """ 
    101  
    102     is_2d = hasattr(data, 'qx_data') 
     50    is2D = hasattr(data, 'qx_data') 
    10351    model_definition = core.load_model_definition(name) 
    10452    pars = get_demo_pars(model_definition) 
    10553    header = ('\n"Model","%s","Count","%d","Dimension","%s"' 
    106               % (name, N, "2D" if is_2d else "1D")) 
     54              % (name, N, "2D" if is2D else "1D")) 
    10755    if not mono: header += ',"Cutoff",%g'%(cutoff,) 
    10856    print(header) 
    10957 
    110     if is_2d: 
     58    if is2D: 
    11159        info = generate.make_info(model_definition) 
    11260        partype = info['partype'] 
     
    12169    # declarations are not available in python 2.7. 
    12270    def try_model(fn, pars): 
    123         """ 
    124         Return the model evaluated at *pars*.  If there is an exception, 
    125         print it and return NaN of the right shape. 
    126         """ 
    12771        try: 
    12872            result = fn(**pars) 
     
    13882        return result 
    13983    def check_model(pars): 
    140         """ 
    141         Run the two calculators against *pars*, returning statistics 
    142         on the differences.  See :func:`calc_stats` for the list of stats. 
    143         """ 
    14484        base_value = try_model(calc_base, pars) 
    14585        comp_value = try_model(calc_comp, pars) 
     
    168108        good = [True] 
    169109        columns = check_model(pars_i) 
    170         columns += [v for _, v in sorted(pars_i.items())] 
     110        columns += [v for _,v in sorted(pars_i.items())] 
    171111        if first: 
    172112            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))] 
     
    181121 
    182122def print_usage(): 
    183     """ 
    184     Print the command usage string. 
    185     """ 
    186123    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) (single|double|quad)") 
    187124 
    188125 
    189126def print_models(): 
    190     """ 
    191     Print the list of available models in columns. 
    192     """ 
    193127    print(columnize(MODELS, indent="  ")) 
    194128 
    195129 
    196130def print_help(): 
    197     """ 
    198     Print usage string, the option description and the list of available models. 
    199     """ 
    200131    print_usage() 
    201132    print("""\ 
     
    227158 
    228159def main(): 
    229     """ 
    230     Main program. 
    231     """ 
    232     if len(sys.argv) not in (6, 7): 
     160    if len(sys.argv) not in (6,7): 
    233161        print_help() 
    234162        sys.exit(1) 
     
    254182 
    255183    data, index = make_data({'qmax':1.0, 'is2d':is2D, 'nq':Nq, 'res':0., 
    256                              'accuracy': 'Low', 'view':'log'}) 
     184                              'accuracy': 'Low', 'view':'log'}) 
    257185    model_list = [model] if model != "all" else MODELS 
    258186    for model in model_list: 
Note: See TracChangeset for help on using the changeset viewer.