source: sasmodels/sasmodels/compare_many.py @ d86f0fc

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since d86f0fc was 2d81cfe, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

lint

  • Property mode set to 100755
File size: 9.4 KB
Line 
1#!/usr/bin/env python
2"""
3Program to compare results from many random parameter sets for a given model.
4
5The result is a comma separated value (CSV) table that can be redirected
6from standard output into a file and loaded into a spreadsheet.
7
8The models are compared for each parameter set and if the difference is
9greater than expected for that precision, the parameter set is labeled
10as bad and written to the output, along with the random seed used to
11generate that parameter value.  This seed can be used with :mod:`compare`
12to reload and display the details of the model.
13"""
14from __future__ import print_function
15
16import sys
17import traceback
18
19import numpy as np  # type: ignore
20
21from . import core
22from .compare import (randomize_pars, suppress_pd, make_data,
23                      make_engine, get_pars, columnize,
24                      constrain_pars)
25
# Model names returned by core.list_models() with no type filter.  A MODEL
# argument found in this list is used as-is; anything else is treated as a
# model type and expanded with core.list_models(target) in main().
MODELS = core.list_models()
27
def calc_stats(target, value, index):
    """
    Calculate statistics between the target value and the computed value.

    *target* and *value* are the vectors being compared, with the
    difference normalized by the magnitude of *target* to get relative
    error.  Only the elements listed in *index* are used, though *index*
    may be the full slice *slice(None, None)* to use every element.

    Returns:

        *maxrel* the maximum relative difference

        *rel95* the relative difference with the 5% biggest differences ignored

        *maxabs* the maximum absolute difference for the 5% biggest differences

        *maxval* the maximum value for the 5% biggest differences

    The "5% biggest differences" are the selected points ranked at or above
    the 95th percentile by relative difference.  All four statistics are
    computed over the selected points only.
    """
    # Select once so that every array below shares the same positions;
    # indexing the full *value* with subset positions picks the wrong points.
    target_sel = target[index]
    value_sel = value[index]
    resid = abs(value_sel - target_sel)
    # Use |target| so a negative target cannot turn a large error into a
    # negative (and therefore apparently tiny) relative error.
    relerr = resid/abs(target_sel)
    order = np.argsort(relerr)
    p95 = int(len(relerr)*0.95)
    tail = order[p95:]  # positions of the largest 5% of relative errors
    maxrel = np.max(relerr)
    rel95 = relerr[order[p95]]
    maxabs = np.max(resid[tail])
    maxval = np.max(value_sel[tail])
    return maxrel, rel95, maxabs, maxval
57
def print_column_headers(pars, parts):
    """
    Generate column headers for the differences and for the parameters,
    and print them to standard output.

    Two CSV rows are printed.  The first is a group row: one label from
    *parts* over each part's block of statistics columns, then "Parameters".
    The second names each column: "Seed", the :func:`calc_stats` statistics
    repeated once per entry in *parts*, then the sorted keys of *pars*.
    """
    # Labels must match calc_stats, which splits at the 95th percentile.
    stats = ['Max rel err', '95% rel err',
             'Max abs err above 95% rel', 'Max value above 95% rel']
    groups = ['']  # blank group cell above the "Seed" column
    for p in parts:
        groups.append(p)
        groups.extend(['']*(len(stats)-1))
    groups.append("Parameters")
    columns = ['Seed'] + stats*len(parts) + sorted(pars.keys())
    print(','.join('"%s"'%c for c in groups))
    print(','.join('"%s"'%c for c in columns))
72
# Target 'good' value for various precision levels.  A comparison passes when
# its maximum relative error is below the larger tolerance of the two
# precisions being compared.  Names without a trailing '!' are GPU
# precisions; names ending in '!' are DLL precisions.
PRECISION = dict([
    # GPU precisions
    ('fast', 1e-3),
    ('half', 1e-3),
    ('single', 5e-5),
    ('double', 5e-14),
    # DLL precisions
    ('single!', 5e-5),
    ('double!', 5e-14),
    ('quad!', 5e-18),
])


def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5,
                     base='single', comp='double'):
    r"""
    Compare the model under different calculation engines.

    Output is CSV on standard output:
      - a header row describing the run;
      - column headers, printed with the first parameter set;
      - one row for each parameter set whose maximum relative error is not
        below the tolerance;
      - a closing "good" row giving the number of passing sets and the
        largest maximum relative error seen.

    *name* is the name of the model.

    *data* is the data object giving $q, \Delta q$ calculation points.

    *index* is the active set of points.

    *N* is the number of comparisons to make.

    *mono* is True to suppress polydispersity in the random parameters.

    *cutoff* is the polydispersity weight cutoff to make the calculation
    a little bit faster.

    *base* and *comp* are the names of the calculation engines to compare.
    The tolerance is the larger of their :data:`PRECISION` entries.  If
    either engine cannot be built, or has no :data:`PRECISION` entry, an
    error row is printed and this model is skipped.
    """
    is_2d = hasattr(data, 'qx_data')
    model_info = core.load_model_info(name)
    pars = get_pars(model_info, use_demo=True)
    header = ('\n"Model","%s","Count","%d","Dimension","%s"'
              % (name, N, "2D" if is_2d else "1D"))
    if not mono:
        header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    if is_2d and not model_info.parameters.has_2d:
        print(',"1-D only"')
        return

    # Some not very clean macros for evaluating the models and checking the
    # results.  They freely use variables from the current scope, even some
    # which have not been defined yet (calc_base, calc_comp, expected, seed),
    # complete with abuse of mutable lists (good, max_diff) to allow them to
    # update values in the current scope since nonlocal declarations are not
    # available in python 2.7.
    def try_model(fn, pars):
        """
        Return the model evaluated at *pars*.  If there is an exception,
        print it and return NaN of the right shape.
        """
        try:
            result = fn(**pars)
        except Exception:
            traceback.print_exc()
            print("when comparing %s for %d"%(name, seed))
            # np.nan rather than np.NaN: the alias was removed in NumPy 2.0.
            if is_2d:
                result = np.nan*data.data
            else:
                result = np.nan*data.x
        return result

    def check_model(pars):
        """
        Run the two calculators against *pars*, returning statistics
        on the differences.  See :func:`calc_stats` for the list of stats.
        """
        base_value = try_model(calc_base, pars)
        comp_value = try_model(calc_comp, pars)
        stats = calc_stats(base_value, comp_value, index)
        max_diff[0] = max(max_diff[0], stats[0])
        # A NaN error compares False, so failed evaluations count as bad.
        good[0] = good[0] and (stats[0] < expected)
        return list(stats)

    try:
        calc_base = make_engine(model_info, data, base, cutoff)
        calc_comp = make_engine(model_info, data, comp, cutoff)
        # Look up the tolerance inside the handler.  A precision name with no
        # PRECISION entry would otherwise raise KeyError here and abort the
        # remaining models, not just this one.
        missing = [p for p in (base, comp) if p not in PRECISION]
        if missing:
            raise ValueError("no tolerance in PRECISION for "
                             + ", ".join(missing))
        expected = max(PRECISION[base], PRECISION[comp])
    except Exception as exc:
        print('"Error: %s"'%str(exc).replace('"', "'"))
        print('"good","%d of %d","max diff",%g' % (0, N, np.nan))
        return

    num_good = 0
    first = True
    max_diff = [0]
    for k in range(N):
        print("Model %s %d"%(name, k+1), file=sys.stderr)
        # Draw a seed and reseed with it, so that each parameter set can be
        # regenerated later from the seed printed with it.
        seed = np.random.randint(1000000)
        np.random.seed(seed)
        pars_i = randomize_pars(model_info, pars)
        constrain_pars(model_info, pars_i)
        if mono:
            pars_i = suppress_pd(pars_i)

        good = [True]
        columns = check_model(pars_i)
        columns += [v for _, v in sorted(pars_i.items())]
        if first:
            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))]
            print_column_headers(pars_i, labels)
            first = False
        if good[0]:
            num_good += 1
        else:
            print(("%d,"%seed)+','.join("%s"%v for v in columns))
    print('"good","%d of %d","max diff",%g'%(num_good, N, max_diff[0]))
183
184
def print_usage():
    """
    Print the command usage string to standard error.

    Optional arguments are shown in brackets.  Without CUTOFF the models are
    run monodisperse, and the two PRECISION arguments (base and comparison)
    default to "single" and "double!".
    """
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) [CUTOFF|mono] [PRECISION [PRECISION]]",
          file=sys.stderr)
191
192
def print_models():
    """
    Write the names in :data:`MODELS` to standard output, laid out in
    columns with a two-space indent.
    """
    table = columnize(MODELS, indent="  ")
    print(table)
198
199
def print_help():
    """
    Print the usage string, a description of each argument, and the list
    of available models.
    """
    print_usage()
    print("""\

MODEL is the name of a model or one of the model types listed in
sasmodels.core.list_models (all, py, c, double, single, opencl, 1d, 2d,
nonmagnetic, magnetic).  Model types can be combined, such as 2d+single.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models.   For a quick check, use "100" and "5" respectively.

The third argument is 1dNQ or 2dNQ, where NQ is the number of Q values to
calculate.  With "1d" it is a 1-dimensional problem, with log spaced Q
points from 1e-3 to 1.0.  With "2d" it is a 2-dimensional problem, with
linearly spaced Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored.  Use "mono" for monodisperse models; this
is also the default when CUTOFF is omitted.  The choice of polydisperse
parameters and the number of points in each distribution are set by the
compare.py defaults for each model.  Polydispersity is given in the
"demo" attribute of each model.

PRECISION is the floating point precision to use for comparisons.  If two
precisions are given, then compare one to the other.  Precision is one of
fast, half, single, double for GPU or single!, double!, quad! for DLL.  If
no precision is given, compare single to double!; if only one precision is
given, compare it to double!.

Available models:
""")
    print_models()

235
def main(argv):
    """
    Main program.

    *argv* is the command line without the program name:
    MODEL COUNT (1dNQ|2dNQ) [CUTOFF|mono] [PRECISION [PRECISION]].

    - With the wrong number of arguments, prints the help text.
    - With an unknown MODEL, prints the available models and model types.
    - When another argument cannot be parsed, prints the traceback and the
      usage line.
    - Otherwise compares every selected model.
    """
    if len(argv) not in (3, 4, 5, 6):
        print_help()
        return

    target = argv[0]
    try:
        model_list = [target] if target in MODELS else core.list_models(target)
    except ValueError:
        print('Bad model %s.  Use model type or one of:' % target, file=sys.stderr)
        print_models()
        print('model types: all, py, c, double, single, opencl, 1d, 2d, nonmagnetic, magnetic')
        return
    try:
        count = int(argv[1])
        # Validate explicitly: an assert is stripped under "python -O", and
        # checking only the 'd' would let e.g. "3d32" run as a 1-D problem.
        dim = argv[2][:2]
        if dim not in ('1d', '2d'):
            raise ValueError("expected 1dNQ or 2dNQ but got %r" % argv[2])
        is_2d = dim == '2d'
        nq = int(argv[2][2:])
        mono = len(argv) <= 3 or argv[3] == 'mono'
        cutoff = float(argv[3]) if not mono else 0
        base = argv[4] if len(argv) > 4 else "single"
        comp = argv[5] if len(argv) > 5 else "double!"
    except Exception:
        traceback.print_exc()
        print_usage()
        return

    data, index = make_data({
        'qmin': 0.001, 'qmax': 1.0, 'is2d': is_2d, 'nq': nq, 'res': 0.,
        'accuracy': 'Low', 'view':'log', 'zero': False
        })
    for model in model_list:
        compare_instance(model, data, index, N=count, mono=mono,
                         cutoff=cutoff, base=base, comp=comp)
273
if __name__ == "__main__":
    # Forward everything after the program name to main().  For a
    # reproducible run, the call can be wrapped in compare.push_seed:
    #     with push_seed(1): main(sys.argv[1:])
    cli_args = sys.argv[1:]
    main(cli_args)
Note: See TracBrowser for help on using the repository browser.