source: sasmodels/sasmodels/compare_many.py @ b514adf

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since b514adf was b514adf, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

set constraints so multi_compare has fewer spurious errors

  • Property mode set to 100755
File size: 5.8 KB
Line 
1#!/usr/bin/env python
2import sys
3import traceback
4
5import numpy as np
6
7from . import core
8from .kernelcl import environment
9from .compare import (MODELS, randomize_model, suppress_pd, eval_sasview,
10                      eval_opencl, eval_ctypes, make_data, get_demo_pars,
11                      columnize, constrain_pars)
12
def get_stats(target, value, index):
    """
    Summarize the disagreement between *value* and *target* on *index*.

    *target* is the reference array and *value* the array being checked;
    *index* is the selector (mask, slice or index array) applied to both
    before any statistic is computed.  The relative error is
    ``|value - target| / target`` on the selected points.

    Returns the tuple *(maxrel, rel95, maxabs, maxval)*:

    * *maxrel* - largest relative error over the selected points;
    * *rel95* - relative error at the 95th percentile position;
    * *maxabs* - largest absolute error among the points at or above that
      position when ordered by relative error;
    * *maxval* - largest *value* among those same points.
    """
    selected = value[index]
    resid = abs(value-target)[index]
    relerr = resid/target[index]
    srel = np.argsort(relerr)
    p95 = int(len(relerr)*0.95)
    worst = srel[p95:]
    maxrel = np.max(relerr)
    rel95 = relerr[srel[p95]]
    maxabs = np.max(resid[worst])
    # srel holds positions within the *selected* points, so the lookup must
    # go through value[index]; indexing value directly picks unrelated points
    # whenever index is a proper subset.
    maxval = np.max(selected[worst])
    return maxrel, rel95, maxabs, maxval
24
def print_column_headers(pars, parts):
    """
    Print the two CSV header rows for a comparison table.

    *pars* is the parameter dictionary; its sorted keys become the trailing
    column names.  *parts* is the list of comparison labels; each label heads
    a group of four statistics columns.

    The first row printed carries the group labels (blank above "Seed" and
    above the continuation columns of each group) followed by "Parameters";
    the second row carries the individual column names.  Every cell is
    wrapped in double quotes.
    """
    # The statistics are taken at the 95th percentile, so label them as such.
    stats = [
        'Max rel err',
        '95% rel err',
        'Max abs err above 95% rel',
        'Max value above 95% rel',
    ]
    padding = ['']*(len(stats)-1)
    groups = ['']
    for label in parts:
        groups.append(label)
        groups.extend(padding)
    groups.append("Parameters")
    columns = ['Seed'] + stats*len(parts) + sorted(pars.keys())
    print(','.join('"%s"'%c for c in groups))
    print(','.join('"%s"'%c for c in columns))
35
def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5):
    """
    Compare calculators for model *name* on *N* random parameter sets.

    *data* is the data object handed to each evaluator; it must provide
    ``qx_data`` and ``data`` attributes, or else an ``x`` attribute, so that
    a NaN placeholder of the right shape can be built when an evaluator
    fails.  *index* selects the points passed on to :func:`get_stats`.

    If *mono* is true, ``suppress_pd`` is applied to each parameter set;
    otherwise *cutoff* is reported in the header.  *cutoff* is forwarded to
    the non-sasview evaluators in either case.

    Output goes to stdout as CSV: a model header line, the column headers,
    one row for each parameter set that fails the tolerance check, and a
    final ``"good"`` summary line.  Progress and evaluator tracebacks go to
    stderr.
    """
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)
    header = '\n"Model","%s","Count","%d"'%(name, N)
    if not mono:
        header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    def trymodel(fn, *args, **kw):
        # pars_i and seed are resolved in the enclosing scope at call time,
        # so each call sees the values of the current loop iteration.
        try:
            result, _ = fn(model_definition, pars_i, data, *args, **kw)
        except Exception:
            # KeyboardInterrupt and SystemExit are not Exception subclasses
            # and so still propagate; anything else is reported and replaced
            # by NaNs so the remaining parameter sets can still be tried.
            sys.stderr.write(traceback.format_exc() + "\n")
            sys.stderr.write("when comparing %s for seed %s\n" % (name, seed))
            if hasattr(data, 'qx_data'):
                result = np.nan*data.data
            else:
                result = np.nan*data.x
        return result

    num_good = 0
    first = True
    max_diff = 0
    for k in range(N):
        sys.stderr.write("%s %s\n" % (name, k))
        pars_i, seed = randomize_model(pars)
        constrain_pars(model_definition, pars_i)
        if mono:
            suppress_pd(pars_i)

        good = True
        labels = []
        columns = []
        # The literal 0/1 conditions below are hand-edited switches choosing
        # which calculators take part in the comparison.  The final
        # single/double block needs both the GPU single and CPU double
        # switches turned on.
        if 1:
            sasview_value = trymodel(eval_sasview)
        if 0:
            gpu_single_value = trymodel(eval_opencl, dtype='single', cutoff=cutoff)
            stats = get_stats(sasview_value, gpu_single_value, index)
            columns.extend(stats)
            labels.append('GPU single')
            max_diff = max(max_diff, stats[0])
            good = good and (stats[0] < 5e-5)
        if 0 and environment().has_double:
            gpu_double_value = trymodel(eval_opencl, dtype='double', cutoff=cutoff)
            stats = get_stats(sasview_value, gpu_double_value, index)
            columns.extend(stats)
            labels.append('GPU double')
            max_diff = max(max_diff, stats[0])
            good = good and (stats[0] < 1e-12)
        if 1:
            cpu_double_value = trymodel(eval_ctypes, dtype='double', cutoff=cutoff)
            stats = get_stats(sasview_value, cpu_double_value, index)
            columns.extend(stats)
            labels.append('CPU double')
            max_diff = max(max_diff, stats[0])
            good = good and (stats[0] < 1e-12)
        if 0:
            stats = get_stats(cpu_double_value, gpu_single_value, index)
            columns.extend(stats)
            labels.append('single/double')
            max_diff = max(max_diff, stats[0])
            good = good and (stats[0] < 5e-5)

        columns += [v for _, v in sorted(pars_i.items())]
        if first:
            print_column_headers(pars_i, labels)
            first = False
        if good:
            num_good += 1
        else:
            # Only parameter sets that miss the tolerance get a row.
            print(("%d,"%seed)+','.join("%g"%v for v in columns))
    print('"good","%d/%d","max diff",%g'%(num_good, N, max_diff))
108
109
def print_usage():
    """Print the one-line command synopsis to stdout."""
    # Single-argument print(...) behaves the same as a statement on Python 2
    # and as a function call on Python 3.
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono)")
112
113
def print_models():
    """Print the MODELS list as formatted by columnize with a two-space indent."""
    listing = columnize(MODELS, indent="  ")
    print(listing)
116
117
def print_help():
    """
    Print the full help text to stdout.

    The output is the usage synopsis, a description of each command line
    argument, and finally the list of available models.
    """
    argument_notes = """\

MODEL is the model name of the model or "all" for all the models
in alphabetical order.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models.   For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate.  If it starts with "1d", then
it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
If it starts with "2d" then it is a 2-dimensional problem, with linearly
spaced points Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored.  Use "mono" for monodisperse models.  The
choice of polydisperse parameters, and the number of points in the distribution
is set in compare.py defaults for each model.

Available models:
"""
    print_usage()
    print(argument_notes)
    print_models()
143
def main():
    """
    Command line entry point: parse ``sys.argv`` and run the comparisons.

    Expected arguments are ``MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono)``.  With
    no arguments the full help is printed.  An unknown model name, or a
    missing or malformed COUNT, NQ or CUTOFF argument, prints a short message
    instead.  All of these cases exit with status 1.

    MODEL may be ``all``, in which case every entry of MODELS is compared in
    turn against the same generated data set.
    """
    argv = sys.argv
    if len(argv) == 1:
        print_help()
        sys.exit(1)

    model = argv[1]
    if model != "all" and model not in MODELS:
        print('Bad model %s.  Use "all" or one of:' % model)
        print_models()
        sys.exit(1)
    try:
        count = int(argv[2])
        dim, nq_text = argv[3][:2], argv[3][2:]
        # Validate explicitly rather than with assert, which is stripped
        # under -O, and accept only the two documented prefixes.
        if dim not in ('1d', '2d'):
            raise ValueError("expected 1dNQ or 2dNQ, got %r" % (argv[3],))
        is2D = (dim == '2d')
        Nq = int(nq_text)
        mono = (argv[4] == 'mono')
        cutoff = 0 if mono else float(argv[4])
    except (IndexError, ValueError):
        # IndexError: too few arguments; ValueError: malformed argument.
        print_usage()
        sys.exit(1)

    data, index = make_data(qmax=1.0, is2D=is2D, Nq=Nq)
    model_list = MODELS if model == "all" else [model]
    for model_name in model_list:
        compare_instance(model_name, data, index, N=count, mono=mono, cutoff=cutoff)
169
# Run the command line driver only when this file is the program entry point,
# not when it is imported.
# NOTE(review): the module uses package-relative imports, so it presumably has
# to be launched as a module of its package (python -m ...) - confirm.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.