source: sasmodels/sasmodels/compare_many.py @ ec7e360

Branches/tags containing this file: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since ec7e360 was ec7e360, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor option processing for compare.py, allowing more flexible selection of calculation engines

  • Property mode set to 100755
File size: 6.1 KB
Line 
1#!/usr/bin/env python
2import sys
3import traceback
4
5import numpy as np
6
7from . import core
8from .kernelcl import environment
9from .compare import (MODELS, randomize_pars, suppress_pd, eval_sasview,
10                      eval_opencl, eval_ctypes, make_data, get_demo_pars,
11                      columnize, constrain_pars, constrain_new_to_old,
12                      make_engine)
13
def calc_stats(target, value, index):
    """Summarize the error between *value* and *target* at the points *index*.

    Returns a 4-tuple ``(maxrel, rel95, maxabs, maxval)``: the worst
    relative error, the 95th-percentile relative error, and the worst
    absolute error and *value* entry among the points whose relative
    error lies above the 95th percentile.
    """
    abs_err = np.abs(value - target)[index]
    rel_err = abs_err / target[index]
    # Rank the points by relative error so the tail above the 95th
    # percentile can be examined separately.
    order = np.argsort(rel_err)
    cutoff = int(len(rel_err) * 0.95)
    tail = order[cutoff:]
    return (np.max(rel_err), rel_err[order[cutoff]],
            np.max(abs_err[tail]), np.max(value[tail]))
25
def print_column_headers(pars, parts):
    """Print the two CSV header rows for the comparison table.

    *pars* is the parameter dict whose sorted names label the trailing
    columns; *parts* holds one label per engine-pair comparison, each of
    which contributes the four statistics columns from calc_stats().
    """
    # calc_stats() reports the tail above the 95th percentile (see its
    # p95 cutoff), so label the columns accordingly; the old headers
    # still said 90% from before the cutoff was moved.
    stats = 'Max rel err|95% rel err|Max abs err above 95% rel|Max value above 95% rel'.split('|')
    # First header row: one group label per comparison, padded so it
    # spans that comparison's four statistics columns.
    groups = ['']
    for p in parts:
        groups.append(p)
        groups.extend([''] * (len(stats) - 1))
    groups.append("Parameters")
    columns = ['Seed'] + stats * len(parts) + list(sorted(pars.keys()))
    print(','.join('"%s"' % c for c in groups))
    print(','.join('"%s"' % c for c in columns))
36
# Expected relative-error tolerance for each calculation engine, keyed by
# the engine name given on the command line.  compare_instance() accepts a
# comparison as passing when its max relative error is below the larger of
# the two engines' tolerances.
# NOTE(review): the trailing '!' and the 'fast'/'half' variants appear to
# select different builds/precisions via make_engine() in compare.py —
# confirm there; the mapping is not visible from this file.
PRECISION = {
    'fast': 1e-3,
    'half': 1e-3,
    'single': 5e-5,
    'double': 5e-14,
    'single!': 5e-5,
    'double!': 5e-14,
    'quad!': 5e-18,
    'sasview': 5e-14,
}
def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5,
                     base='sasview', comp='double'):
    """Compare the *base* and *comp* engines for model *name* on *N*
    randomized parameter sets, printing a CSV report to stdout.

    *data* and *index* come from make_data().  *mono* suppresses
    polydispersity in the randomized parameters; *cutoff* is passed to
    make_engine() (per the help text, it trims low-weight distribution
    points).  A parameter set "fails" when its max relative error is not
    below the larger of the two engines' PRECISION tolerances; only
    failing sets get a detail row.  Prints a good/total summary at the end.
    """
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)
    header = '\n"Model","%s","Count","%d"'%(name, N)
    if not mono: header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    # Some not very clean macros for evaluating the models and checking the
    # results.  They freely use variables from the current scope, even some
    # which have not been defined yet, complete with abuse of mutable lists
    # to allow them to update values in the current scope since nonlocal
    # declarations are not available in python 2.7.
    def try_model(fn, pars):
        # Evaluate fn(**pars); on failure (other than Ctrl-C) print the
        # traceback and return an all-NaN array shaped like the data so
        # the statistics still compute and flag the set as bad.
        try:
            result = fn(**pars)
        except KeyboardInterrupt:
            raise
        except:
            traceback.print_exc()
            # `seed` is read from the enclosing loop below (see the
            # "macros" note above) — it exists by the time this runs.
            print("when comparing %s for %d"%(name, seed))
            if hasattr(data, 'qx_data'):
                result = np.NaN*data.data
            else:
                result = np.NaN*data.x
        return result
    def check_model(pars):
        # Run one parameter set through both engines and fold the stats
        # into the shared max_diff/good accumulators (mutable one-element
        # lists standing in for nonlocal variables).
        base_value = try_model(calc_base, pars)
        comp_value = try_model(calc_comp, pars)
        stats = calc_stats(base_value, comp_value, index)
        max_diff[0] = max(max_diff[0], stats[0])
        good[0] = good[0] and (stats[0] < expected)
        return list(stats)


    calc_base = make_engine(model_definition, data, base, cutoff)
    calc_comp = make_engine(model_definition, data, comp, cutoff)
    # Pass threshold: the looser of the two engines' expected precisions.
    expected = max(PRECISION[base], PRECISION[comp])

    num_good = 0
    first = True
    max_diff = [0]
    for k in range(N):
        print("%s %d"%(name, k))
        seed = np.random.randint(1e6)
        pars_i = randomize_pars(pars, seed)
        # Keep the random draw inside each model's valid domain, and
        # restricted to what the old sasview models can represent.
        constrain_pars(model_definition, pars_i)
        constrain_new_to_old(model_definition, pars_i)
        if mono:
            pars_i = suppress_pd(pars_i)

        good = [True]
        columns = check_model(pars_i)
        columns += [v for _,v in sorted(pars_i.items())]
        # Headers are printed lazily so they can include the randomized
        # parameter names from the first set.
        if first:
            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))]
            print_column_headers(pars_i, labels)
            first = False
        if good[0]:
            num_good += 1
        else:
            # Detail row only for failures: seed, stats, then parameters.
            print(("%d,"%seed)+','.join("%s"%v for v in columns))
    print('"good","%d/%d","max diff",%g'%(num_good, N, max_diff[0]))
110
111
def print_usage():
    """Print the one-line command summary.

    main() accepts an optional second engine argument (defaulting to
    "sasview"), so it is shown here as an optional trailing argument;
    the previous usage line omitted it even though the help text
    describes giving two precisions.
    """
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) (single|double|quad) [(single|double|quad)]")
115
def print_models():
    """Print the available model names, formatted into columns."""
    listing = columnize(MODELS, indent="  ")
    print(listing)
119
def print_help():
    """Print the usage line, the full argument descriptions, and the
    list of available model names."""
    print_usage()
    print("""\

MODEL is the model name of the model or "all" for all the models
in alphabetical order.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models.   For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate.  If it starts with "1d", then
it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
If it starts with "2d" then it is a 2-dimensional problem, with linearly
spaced points Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored.  Use "mono" for monodisperse models.  The
choice of polydisperse parameters, and the number of points in the distribution
is set in compare.py defaults for each model.

PRECISION is the floating point precision to use for comparisons.  If two
precisions are given, then compare one to the other, ignoring sasview.

Available models:
""")
    print_models()
148
def main():
    """Parse command-line arguments and run the requested comparisons.

    Expects argv: MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) ENGINE [ENGINE],
    as described by print_help().  Exits with status 1 on bad arguments.
    """
    if len(sys.argv) not in (6, 7):
        print_help()
        sys.exit(1)

    model = sys.argv[1]
    if not (model in MODELS) and (model != "all"):
        # Bug fix: the model name was never substituted into the message
        # (the original printed the literal '%s').
        print('Bad model %s.  Use "all" or one of:' % model)
        print_models()
        sys.exit(1)
    try:
        count = int(sys.argv[2])
        is2D = sys.argv[3].startswith('2d')
        assert sys.argv[3][1] == 'd'
        Nq = int(sys.argv[3][2:])
        mono = sys.argv[4] == 'mono'
        cutoff = float(sys.argv[4]) if not mono else 0
        base = sys.argv[5]
        comp = sys.argv[6] if len(sys.argv) > 6 else "sasview"
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit
        # still propagate; any parse failure shows the usage summary.
        traceback.print_exc()
        print_usage()
        sys.exit(1)

    data, index = make_data({'qmax': 1.0, 'is2d': is2D, 'nq': Nq, 'res': 0.,
                             'accuracy': 'Low', 'view': 'log'})
    model_list = [model] if model != "all" else MODELS
    for model in model_list:
        compare_instance(model, data, index, N=count, mono=mono,
                         cutoff=cutoff, base=base, comp=comp)
179
# Allow the module to be run directly as a command-line script.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.