source: sasmodels/sasmodels/compare_many.py @ cd3dba0

core_shell_microgels costrafo411 magnetic_model release_v0.94 release_v0.95 ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since cd3dba0 was cd3dba0, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

improve compare.py so that parameters can be constrained to valid values

  • Property mode set to 100755
File size: 5.5 KB
Line 
1#!/usr/bin/env python
2import sys
3import traceback
4
5import numpy as np
6
7from . import core
8from .kernelcl import environment
9from .compare import (MODELS, randomize_model, suppress_pd, eval_sasview,
10                      eval_opencl, eval_ctypes, make_data, get_demo_pars,
11                      columnize, constrain_pars)
12
def get_stats(target, value, index):
    """
    Summarize the discrepancy between *value* and the reference *target*.

    *index* selects the points to compare (a slice or boolean mask applied
    to both arrays).  Relative error is ``|value - target| / |target|``
    over the selected points.

    Returns ``(maxrel, rel95, maxabs, maxval)``:

    * maxrel -- largest relative error;
    * rel95 -- relative error at the 95th percentile;
    * maxabs -- largest absolute error among the points at or above the
      95th percentile of relative error;
    * maxval -- largest selected *value* among those same points.

    NaN inputs propagate into the statistics (so a failed evaluation shows
    up as NaN rather than as agreement).
    """
    # Select the compared points once, so every later index refers to the
    # same subset.
    target = target[index]
    value = value[index]
    resid = abs(value - target)
    # Divide by |target| so a negative reference cannot yield a negative
    # "relative error" that would look like perfect agreement.
    relerr = resid / abs(target)
    order = np.argsort(relerr)
    #p90 = int(len(relerr)*0.90)
    p95 = int(len(relerr)*0.95)
    worst = order[p95:]  # positions at/above the 95th percentile
    maxrel = np.max(relerr)
    rel95 = relerr[order[p95]]
    maxabs = np.max(resid[worst])
    # `worst` holds positions within the selected subset, so it must index
    # the selected values, not the full array.
    maxval = np.max(value[worst])
    return maxrel, rel95, maxabs, maxval
24
def print_column_headers(pars, parts):
    """
    Print the two CSV header rows for the comparison table.

    The first row labels groups of columns: an empty cell over "Seed", one
    group per entry in *parts* (each spanning the four statistics returned
    by get_stats), then "Parameters".  The second row names each column:
    "Seed", the statistics repeated for each part, then the keys of *pars*
    in sorted order.
    """
    # These labels describe get_stats, which cuts at the 95th percentile.
    stats = ['Max rel err', '95% rel err',
             'Max abs err above 95% rel', 'Max value above 95% rel']
    groups = ['']
    for part in parts:
        groups.append(part)
        groups.extend([''] * (len(stats) - 1))
    groups.append("Parameters")
    columns = ['Seed'] + stats * len(parts) + sorted(pars)
    print(','.join('"%s"' % c for c in groups))
    print(','.join('"%s"' % c for c in columns))
35
def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5):
    """
    Compare SasView and sasmodels evaluations of model *name* for *N*
    randomly generated parameter sets, printing a CSV-style report.

    Output on stdout:

    * a model header line (including the cutoff when not *mono*);
    * the column headers, printed once after the first parameter set;
    * one row per parameter set whose maximum relative error is not below
      the tolerance (seed, statistics, then the sorted parameter values);
    * a final ``"good/total"`` summary line.

    Progress messages and evaluation tracebacks go to stderr.  *data* and
    *index* come from compare.make_data; *index* selects the points used
    for the statistics.  When *mono* is True, polydispersity is suppressed.
    *cutoff* is passed to the sasmodels kernels.
    """
    # Which comparisons to run (these replace hard-coded if 0/if 1 switches).
    check_gpu_single = False
    check_gpu_double = False
    check_cpu_double = True
    # Compares CPU double against GPU single; needs both of those enabled.
    check_single_vs_double = False
    # A comparison is "good" when its max relative error is below this.
    tolerance = 1e-14

    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)
    header = '\n"Model","%s","Count","%d"' % (name, N)
    if not mono:
        header += ',"Cutoff",%g' % (cutoff,)
    print(header)

    # Loop invariant; environment() is only queried if GPU double is wanted.
    use_gpu_double = check_gpu_double and environment().has_double

    def trymodel(fn, pars_i, seed, **kw):
        """Evaluate *fn*. On failure, report it on stderr and return an
        all-NaN result so that the comparison is flagged as bad."""
        try:
            result, _ = fn(model_definition, pars_i, data, **kw)
        except Exception:
            # KeyboardInterrupt/SystemExit are not Exception subclasses, so
            # they still propagate and abort the run.
            sys.stderr.write(traceback.format_exc())
            sys.stderr.write("when comparing %s for seed %s\n" % (name, seed))
            if hasattr(data, 'qx_data'):
                result = np.nan*data.data
            else:
                result = np.nan*data.x
        return result

    num_good = 0
    first = True
    for k in range(N):
        sys.stderr.write("%s %s\n" % (name, k))
        pars_i, seed = randomize_model(pars)
        constrain_pars(model_definition, pars_i)
        if mono:
            suppress_pd(pars_i)

        # Collect (label, reference, candidate) triples to compare.
        comparisons = []
        sasview_value = trymodel(eval_sasview, pars_i, seed)
        if check_gpu_single:
            gpu_single_value = trymodel(eval_opencl, pars_i, seed,
                                        dtype='single', cutoff=cutoff)
            comparisons.append(('GPU single', sasview_value, gpu_single_value))
        if use_gpu_double:
            gpu_double_value = trymodel(eval_opencl, pars_i, seed,
                                        dtype='double', cutoff=cutoff)
            comparisons.append(('GPU double', sasview_value, gpu_double_value))
        if check_cpu_double:
            cpu_double_value = trymodel(eval_ctypes, pars_i, seed,
                                        dtype='double', cutoff=cutoff)
            comparisons.append(('CPU double', sasview_value, cpu_double_value))
        if check_single_vs_double and check_gpu_single and check_cpu_double:
            comparisons.append(('single/double',
                                cpu_double_value, gpu_single_value))

        labels = [label for label, _, _ in comparisons]
        columns = []
        good = True
        for _, reference, candidate in comparisons:
            stats = get_stats(reference, candidate, index)
            columns.extend(stats)
            # NaN compares False, so a failed evaluation is never "good".
            good = good and (stats[0] < tolerance)

        columns += [v for _, v in sorted(pars_i.items())]
        if first:
            print_column_headers(pars_i, labels)
            first = False
        if good:
            num_good += 1
        else:
            print(("%d," % seed) + ','.join("%g" % v for v in columns))
    print('"%d/%d good"' % (num_good, N))
103
104
def print_usage():
    """Print the one-line command-line synopsis to stdout."""
    # Function-call form works under both Python 2 and Python 3.
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono)")
107
108
def print_models():
    """Print every entry of MODELS, laid out in indented columns."""
    listing = columnize(MODELS, indent="  ")
    print(listing)
111
112
def print_help():
    """Print the usage line, a description of each argument, and the
    list of available models."""
    print_usage()
    print("""\

MODEL is the name of the model, or "all" to run every model in
alphabetical order.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models.  For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate.  If it starts with "1d", then
it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
If it starts with "2d" then it is a 2-dimensional problem, with linearly
spaced Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored.  Use "mono" for monodisperse models.  The
choice of polydisperse parameters and the number of points in the distribution
are set in compare.py defaults for each model.

Available models:
""")
    print_models()
138
def main():
    """
    Command-line entry point: ``compare_many.py MODEL COUNT (1dNQ|2dNQ)
    (CUTOFF|mono)``.

    With no arguments, print the full help and exit with status 1.  If
    MODEL is unknown or the remaining arguments are missing or malformed,
    print a message/usage and exit with status 1.  Otherwise run
    compare_instance for MODEL (or for every model when MODEL is "all").
    """
    if len(sys.argv) == 1:
        print_help()
        sys.exit(1)

    model = sys.argv[1]
    if model not in MODELS and model != "all":
        print('Bad model %s.  Use "all" or one of:' % model)
        print_models()
        sys.exit(1)
    try:
        count = int(sys.argv[2])
        dims = sys.argv[3][:2]
        # Explicit check rather than assert, which is stripped under -O.
        if dims not in ('1d', '2d'):
            raise ValueError("expected 1dNQ or 2dNQ, got %r" % sys.argv[3])
        is2D = (dims == '2d')
        Nq = int(sys.argv[3][2:])
        mono = (sys.argv[4] == 'mono')
        cutoff = 0 if mono else float(sys.argv[4])
    except (IndexError, ValueError):
        # Missing argument or unparsable number.
        print_usage()
        sys.exit(1)

    data, index = make_data(qmax=1.0, is2D=is2D, Nq=Nq)
    model_list = MODELS if model == "all" else [model]
    for model_name in model_list:
        compare_instance(model_name, data, index,
                         N=count, mono=mono, cutoff=cutoff)
164
if __name__ == "__main__":
    # Only run the command-line driver when executed as a script,
    # not when imported as a module.
    main()
Note: See TracBrowser for help on using the repository browser.