source: sasmodels/sasmodels/compare_many.py @ f72d70a

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since f72d70a was f72d70a, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

multi-compare: slightly nicer interface

  • Property mode set to 100755
File size: 9.4 KB
RevLine 
[216a9e1]1#!/usr/bin/env python
[d15a908]2"""
3Program to compare results from many random parameter sets for a given model.
4
5The result is a comma separated value (CSV) table that can be redirected
6from standard output into a file and loaded into a spreadsheet.
7
8The models are compared for each parameter set and if the difference is
9greater than expected for that precision, the parameter set is labeled
10as bad and written to the output, along with the random seed used to
11generate that parameter value.  This seed can be used with :mod:`compare`
12to reload and display the details of the model.
13"""
[a7f909a]14from __future__ import print_function
15
[216a9e1]16import sys
[7cf2cfd]17import traceback
[216a9e1]18
[7ae2b7f]19import numpy as np  # type: ignore
[216a9e1]20
[e922c5d]21from . import core
[f3bd37f]22from .compare import (randomize_pars, suppress_pd, make_data,
[ce346b6]23                      make_engine, get_pars, columnize,
[a7f909a]24                      constrain_pars, constrain_new_to_old)
[216a9e1]25
[f3bd37f]26MODELS = core.list_models()
27
def calc_stats(target, value, index):
    """
    Calculate statistics between the target value and the computed value.

    *target* and *value* are the vectors being compared, with the
    difference normalized by *target* to get relative error.  Only
    the elements listed in *index* are used, though index may be
    an empty slice defined by *slice(None, None)*.

    Returns:

        *maxrel* the maximum relative difference

        *rel95* the relative difference with the 5% biggest differences ignored

        *maxabs* the maximum absolute difference for the 5% biggest differences

        *maxval* the maximum value for the 5% biggest differences
    """
    abs_err = np.abs(value - target)[index]
    rel_err = abs_err / target[index]
    # Rank the points by relative error so the worst 5% can be examined
    # separately from the bulk of the comparison.
    order = np.argsort(rel_err)
    cut95 = int(len(rel_err) * 0.95)
    maxrel = rel_err.max()
    rel95 = rel_err[order[cut95]]
    worst = order[cut95:]
    maxabs = abs_err[worst].max()
    maxval = value[worst].max()
    return maxrel, rel95, maxabs, maxval
[216a9e1]57
def print_column_headers(pars, parts):
    """
    Generate column headers for the differences and for the parameters,
    and print them to standard output.
    """
    # NOTE(review): labels say "90% rel" while calc_stats cuts at 95%;
    # the text is kept verbatim since it is part of the CSV output format.
    stats = ['Max rel err', '95% rel err',
             'Max abs err above 90% rel', 'Max value above 90% rel']
    groups = ['']
    for part in parts:
        groups += [part] + [''] * (len(stats) - 1)
    groups.append("Parameters")
    columns = ['Seed'] + stats * len(parts) + sorted(pars.keys())
    print(','.join('"%s"' % label for label in groups))
    print(','.join('"%s"' % label for label in columns))
72
# Target 'good' value for various precision levels.
# Keys match the engine names accepted by make_engine; a trailing '!'
# presumably selects the DLL engine rather than the GPU — verify in compare.py.
PRECISION = {
    'fast': 1e-3,      # reduced-precision GPU math: loosest tolerance
    'half': 1e-3,
    'single': 5e-5,    # 32-bit float
    'double': 5e-14,   # 64-bit float
    'single!': 5e-5,
    'double!': 5e-14,
    'quad!': 5e-18,    # extended precision, DLL only
    'sasview': 5e-14,  # sasview reference computes in double
}
def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5,
                     base='sasview', comp='double'):
    r"""
    Compare the model under different calculation engines.

    *name* is the name of the model.

    *data* is the data object giving $q, \Delta q$ calculation points.

    *index* is the active set of points.

    *N* is the number of comparisons to make.

    *mono* is True if polydispersity should be suppressed for the test.

    *cutoff* is the polydispersity weight cutoff to make the calculation
    a little bit faster.

    *base* and *comp* are the names of the calculation engines to compare.
    """

    is_2d = hasattr(data, 'qx_data')
    model_info = core.load_model_info(name)
    pars = get_pars(model_info, use_demo=True)
    header = ('\n"Model","%s","Count","%d","Dimension","%s"'
              % (name, N, "2D" if is_2d else "1D"))
    if not mono:
        header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    if is_2d:
        if not model_info.parameters.has_2d:
            print(',"1-D only"')
            return

    # Some not very clean macros for evaluating the models and checking the
    # results.  They freely use variables from the current scope, even some
    # which have not been defined yet, complete with abuse of mutable lists
    # to allow them to update values in the current scope since nonlocal
    # declarations are not available in python 2.7.
    def try_model(fn, pars):
        """
        Return the model evaluated at *pars*.  If there is an exception,
        print it and return NaN of the right shape.
        """
        try:
            result = fn(**pars)
        except Exception:
            traceback.print_exc()
            print("when comparing %s for %d"%(name, seed))
            # Return an all-NaN vector matching the data shape so the
            # statistics still compute and the parameter set is flagged bad.
            # (np.nan, not np.NaN: the uppercase alias was removed in NumPy 2.)
            if hasattr(data, 'qx_data'):
                result = np.nan*data.data
            else:
                result = np.nan*data.x
        return result
    def check_model(pars):
        """
        Run the two calculators against *pars*, returning statistics
        on the differences.  See :func:`calc_stats` for the list of stats.
        """
        base_value = try_model(calc_base, pars)
        comp_value = try_model(calc_comp, pars)
        stats = calc_stats(base_value, comp_value, index)
        max_diff[0] = max(max_diff[0], stats[0])
        good[0] = good[0] and (stats[0] < expected)
        return list(stats)

    try:
        calc_base = make_engine(model_info, data, base, cutoff)
        calc_comp = make_engine(model_info, data, comp, cutoff)
    except Exception as exc:
        #raise
        print('"Error: %s"'%str(exc).replace('"', "'"))
        print('"good","%d of %d","max diff",%g' % (0, N, np.nan))
        return
    expected = max(PRECISION[base], PRECISION[comp])

    num_good = 0
    first = True
    max_diff = [0]
    for k in range(N):
        print("Model %s %d"%(name, k+1), file=sys.stderr)
        # Record the seed so a failing parameter set can be replayed with
        # the compare script.  randint needs an int bound; a float 1e6
        # is rejected by modern numpy.
        seed = np.random.randint(int(1e6))
        pars_i = randomize_pars(model_info, pars, seed)
        constrain_pars(model_info, pars_i)
        if 'sasview' in (base, comp):
            constrain_new_to_old(model_info, pars_i)
        if mono:
            pars_i = suppress_pd(pars_i)

        good = [True]
        columns = check_model(pars_i)
        columns += [v for _, v in sorted(pars_i.items())]
        if first:
            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))]
            print_column_headers(pars_i, labels)
            first = False
        if good[0]:
            num_good += 1
        else:
            print(("%d,"%seed)+','.join("%s"%v for v in columns))
    print('"good","%d of %d","max diff",%g'%(num_good, N, max_diff[0]))
[7cf2cfd]185
186
def print_usage():
    """
    Print the command usage string.
    """
    usage = ("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ)"
             " (CUTOFF|mono) (single|double|quad)")
    print(usage, file=sys.stderr)
[7cf2cfd]193
194
def print_models():
    """
    Print the list of available models in columns.
    """
    listing = columnize(MODELS, indent="  ")
    print(listing)
[216a9e1]200
201
def print_help():
    """
    Print the usage summary, a description of each command line argument,
    and the list of available models.
    """
    print_usage()
    description = """\

MODEL is the model name of the model or one of the model types listed in
sasmodels.core.list_models (all, py, c, double, single, opencl, 1d, 2d,
nonmagnetic, magnetic).  Model types can be combined, such as 2d+single.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models.   For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate.  If it starts with "1d", then
it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
If it starts with "2d" then it is a 2-dimensional problem, with linearly
spaced points Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored.  Use "mono" for monodisperse models.  The
choice of polydisperse parameters, and the number of points in the distribution
is set in compare.py defaults for each model.  Polydispersity is given in the
"demo" attribute of each model.

PRECISION is the floating point precision to use for comparisons.  If two
precisions are given, then compare one to the other.  Precision is one of
fast, single, double for GPU or single!, double!, quad! for DLL.  If no
precision is given, then use single and double! respectively.

Available models:
"""
    print(description)
    print_models()
237
def main(argv):
    """
    Main program.
    """
    # Anything outside 3..6 arguments cannot be a valid invocation.
    if len(argv) not in (3, 4, 5, 6):
        print_help()
        return

    model_spec = argv[0]
    try:
        if model_spec in MODELS:
            model_list = [model_spec]
        else:
            model_list = core.list_models(model_spec)
    except ValueError:
        print('Bad model %s.  Use model type or one of:' % model_spec, file=sys.stderr)
        print_models()
        print('model types: all, py, c, double, single, opencl, 1d, 2d, nonmagnetic, magnetic')
        return
    try:
        count = int(argv[1])
        is2D = argv[2].startswith('2d')
        assert argv[2][1] == 'd'
        Nq = int(argv[2][2:])
        # With only three arguments polydispersity defaults to suppressed.
        mono = len(argv) <= 3 or argv[3] == 'mono'
        cutoff = 0 if mono else float(argv[3])
        base = argv[4] if len(argv) > 4 else "single"
        comp = argv[5] if len(argv) > 5 else "double!"
    except Exception:
        traceback.print_exc()
        print_usage()
        return

    data, index = make_data({'qmax':1.0, 'is2d':is2D, 'nq':Nq, 'res':0.,
                             'accuracy': 'Low', 'view':'log', 'zero': False})
    for model in model_list:
        compare_instance(model, data, index, N=count, mono=mono,
                         cutoff=cutoff, base=base, comp=comp)
[216a9e1]273
if __name__ == "__main__":
    # Uncomment the two lines below to run with a fixed random seed so a
    # whole comparison run is reproducible:
    #from .compare import push_seed
    #with push_seed(1): main(sys.argv[1:])
    main(sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.