source: sasmodels/sasmodels/compare_many.py @ 319ab14

Last change on this file was 319ab14, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

allow comparison between double/quad precision and sasview

#!/usr/bin/env python
from __future__ import print_function

import sys
import traceback

import numpy as np

from . import core
from .kernelcl import environment
from .compare import (MODELS, randomize_model, suppress_pd, eval_sasview,
                      eval_opencl, eval_ctypes, make_data, get_demo_pars,
                      columnize, constrain_pars)

def calc_stats(target, value, index):
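    """
    Return (maxrel, rel95, maxabs, maxval) comparing value against target,
    where maxrel is the maximum relative error, rel95 is the 95th percentile
    relative error, and maxabs/maxval are the maximum absolute error and
    maximum value over the 5% of points with the worst relative error.
    """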
    resid = abs(value - target)[index]
    relerr = resid/target[index]
    srel = np.argsort(relerr)
    p95 = int(len(relerr)*0.95)
    maxrel = np.max(relerr)
    rel95 = relerr[srel[p95]]
    maxabs = np.max(resid[srel[p95:]])
    maxval = np.max(value[srel[p95:]])
    return maxrel, rel95, maxabs, maxval

def print_column_headers(pars, parts):
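    """
    Print the CSV group and column headers for the statistics produced by
    calc_stats, one group of stats columns per entry in parts, followed by
    the sorted parameter names.
    """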
    stats = list('Max rel err|95% rel err|Max abs err above 95% rel|Max value above 95% rel'.split('|'))
    groups = ['']
    for p in parts:
        groups.append(p)
        groups.extend(['']*(len(stats)-1))
    groups.append("Parameters")
    columns = ['Seed'] + stats*len(parts) + list(sorted(pars.keys()))
    print(','.join('"%s"' % c for c in groups))
    print(','.join('"%s"' % c for c in columns))

def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5,
                     precision='double'):
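    """
    Compare the named model evaluated at the requested precision against a
    long double reference for N randomly generated parameter sets, printing
    a CSV row of error statistics for each set that fails the tolerance
    check and a final good/total summary.
    """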
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)
    header = '\n"Model","%s","Count","%d"' % (name, N)
    if not mono: header += ',"Cutoff",%g' % (cutoff,)
    print(header)

    # Some not very clean macros for evaluating the models and checking the
    # results.  They freely use variables from the current scope, even some
    # which have not been defined yet, complete with abuse of mutable lists
    # to allow them to update values in the current scope since nonlocal
    # declarations are not available in Python 2.7.
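    # For example, max_diff = [0] below lets check_model() update the
    # enclosing scope with "max_diff[0] = ..."; in Python 3 this could be a
    # plain variable with a "nonlocal max_diff" declaration instead.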
    def try_model(fn, *args, **kw):
        try:
            result, _ = fn(model_definition, pars_i, data, *args, **kw)
        except KeyboardInterrupt:
            raise
        except Exception:
            print(traceback.format_exc(), file=sys.stderr)
            print("when comparing", name, "for seed", seed, file=sys.stderr)
            if hasattr(data, 'qx_data'):
                result = np.nan*data.data
            else:
                result = np.nan*data.x
        return result
    def check_model(label, target, value, acceptable):
        stats = calc_stats(target, value, index)
        columns.extend(stats)
        labels.append(label)
        max_diff[0] = max(max_diff[0], stats[0])
        good[0] = good[0] and (stats[0] < acceptable)

    num_good = 0
    first = True
    max_diff = [0]
    for k in range(N):
        print(name, k, file=sys.stderr)
        pars_i, seed = randomize_model(pars)
        constrain_pars(model_definition, pars_i)
        if mono: suppress_pd(pars_i)

        good = [True]
        labels = []
        columns = []
        #target = try_model(eval_sasview)
        target = try_model(eval_opencl, dtype='longdouble', cutoff=cutoff)
        if precision == 'single':
            value = try_model(eval_opencl, dtype='single', cutoff=cutoff)
            check_model('GPU single', target, value, 5e-5)
            single_value = value  # remember for single/double comparison
        elif precision == 'double':
            if environment().has_double:
                label = 'GPU double'
                value = try_model(eval_opencl, dtype='double', cutoff=cutoff)
            else:
                label = 'CPU double'
                value = try_model(eval_ctypes, dtype='double', cutoff=cutoff)
            check_model(label, target, value, 5e-14)
            double_value = value  # remember for single/double comparison
        elif precision == 'quad':
            value = try_model(eval_opencl, dtype='longdouble', cutoff=cutoff)
            check_model('CPU quad', target, value, 5e-14)
        if 0:  # debug toggle: needs both single_value and double_value set above
            check_model('single/double', double_value, single_value, 5e-5)

        columns += [v for _, v in sorted(pars_i.items())]
        if first:
            print_column_headers(pars_i, labels)
            first = False
        if good[0]:
            num_good += 1
        else:
            print(("%d," % seed) + ','.join("%g" % v for v in columns))
    print('"good","%d/%d","max diff",%g' % (num_good, N, max_diff[0]))


def print_usage():
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) (single|double|quad)")


def print_models():
    print(columnize(MODELS, indent="  "))


def print_help():
    print_usage()
    print("""\

MODEL is the name of the model, or "all" to run all of the models in
alphabetical order.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models.  For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate.  If it starts with "1d", then
it is a 1-dimensional problem, with log-spaced Q points from 1e-3 to 1.0.
If it starts with "2d", then it is a 2-dimensional problem, with linearly
spaced Q points from -1.0 to 1.0 in each dimension.  The usual values are
"1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored.  Use "mono" for monodisperse models.  The
choice of polydisperse parameters and the number of points in the distribution
are set in the compare.py defaults for each model.

PRECISION is the floating point precision to use for the comparison: "single",
"double" or "quad".

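For example, to check the cylinder model with 100 monodisperse parameter
sets on a 100-point 1-D grid at double precision:

    compare_many.py cylinder 100 1d100 mono double
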
Available models:
""")
    print_models()

def main():
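    """Parse the command line and run the requested model comparisons."""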
    if len(sys.argv) != 6:
        print_help()
        sys.exit(1)

    model = sys.argv[1]
    if model not in MODELS and model != "all":
        print('Bad model %s.  Use "all" or one of:' % model)
        print_models()
        sys.exit(1)
    try:
        count = int(sys.argv[2])
        is2D = sys.argv[3].startswith('2d')
        assert sys.argv[3][1] == 'd'
        Nq = int(sys.argv[3][2:])
        mono = sys.argv[4] == 'mono'
        cutoff = float(sys.argv[4]) if not mono else 0
        precision = sys.argv[5]
    except Exception:
        traceback.print_exc()
        print_usage()
        sys.exit(1)

    data, index = make_data(qmax=1.0, is2D=is2D, Nq=Nq)
    model_list = [model] if model != "all" else MODELS
    for model in model_list:
        compare_instance(model, data, index, N=count, mono=mono,
                         cutoff=cutoff, precision=precision)

if __name__ == "__main__":
    main()