1 | #!/usr/bin/env python |
---|
2 | import sys |
---|
3 | import traceback |
---|
4 | |
---|
5 | import numpy as np |
---|
6 | |
---|
7 | from . import core |
---|
8 | from .kernelcl import environment |
---|
9 | from .compare import (MODELS, randomize_pars, suppress_pd, eval_sasview, |
---|
10 | eval_opencl, eval_ctypes, make_data, get_demo_pars, |
---|
11 | columnize, constrain_pars, constrain_new_to_old, |
---|
12 | make_engine) |
---|
13 | |
---|
def calc_stats(target, value, index):
    """Summarize the error of *value* relative to *target*.

    Only the points selected by *index* take part in the comparison.

    Returns a tuple ``(maxrel, rel95, maxabs, maxval)``:

    * maxrel -- largest relative error ``|value - target| / target``;
    * rel95  -- relative error at the 95th percentile;
    * maxabs -- largest absolute error among the points whose relative
      error is at or above the 95th percentile;
    * maxval -- largest *value* among those same points.

    If *index* selects no points, all four statistics are NaN.
    """
    target_sel = target[index]
    value_sel = value[index]
    resid = abs(value_sel - target_sel)
    relerr = resid/target_sel
    if relerr.size == 0:
        # Nothing to compare; np.max/argsort indexing would fail on empty input.
        return np.nan, np.nan, np.nan, np.nan
    srel = np.argsort(relerr)
    p95 = int(len(relerr)*0.95)
    maxrel = np.max(relerr)
    rel95 = relerr[srel[p95]]
    # srel holds positions into the *selected* points, so every array it
    # indexes must already be restricted by index.
    maxabs = np.max(resid[srel[p95:]])
    maxval = np.max(value_sel[srel[p95:]])
    return maxrel, rel95, maxabs, maxval
---|
25 | |
---|
def print_column_headers(pars, parts):
    """Print the two CSV header rows of a comparison table.

    The first row groups the columns: an empty cell over "Seed", one group
    label per entry of *parts* spanning its statistics columns, then
    "Parameters".  The second row names each column: "Seed", the statistics
    for every part, and the parameter names of *pars* in sorted order.
    """
    # Labels match calc_stats, which reports the worst points at or above
    # the 95th percentile of relative error.
    stats = ['Max rel err', '95% rel err',
             'Max abs err above 95% rel', 'Max value above 95% rel']
    groups = ['']
    for part in parts:
        groups.append(part)
        groups.extend([''] * (len(stats) - 1))
    groups.append("Parameters")
    columns = ['Seed'] + stats*len(parts) + sorted(pars)
    print(','.join('"%s"' % c for c in groups))
    print(','.join('"%s"' % c for c in columns))
---|
36 | |
---|
# Largest acceptable maximum relative error for each engine/precision name.
# compare_instance uses the looser of the two engines being compared.
PRECISION = dict([
    ('fast', 1e-3),
    ('half', 1e-3),
    ('single', 5e-5),
    ('double', 5e-14),
    ('single!', 5e-5),
    ('double!', 5e-14),
    ('quad!', 5e-18),
    ('sasview', 5e-14),
])
---|
def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5,
                     base='sasview', comp='double'):
    """Compare two calculation engines on random parameter sets for one model.

    Loads model *name*, then for each of *N* trials:

    * draws a random seed;
    * randomizes and constrains the demo parameters, stripping
      polydispersity when *mono* is true;
    * evaluates the model on *data* with both the *base* and *comp* engines.

    The points selected by *index* are compared with calc_stats.  A trial
    passes when its maximum relative error is below the looser of
    PRECISION[base] and PRECISION[comp].  Only failing trials are printed,
    as CSV rows.  A final summary line reports the number of passing trials
    and the largest relative error seen.

    If an engine raises, the traceback is printed and that engine's result
    is replaced by NaN, which makes the trial fail the precision check.
    """
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)
    header = '\n"Model","%s","Count","%d"'%(name, N)
    if not mono:
        header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    # Some not very clean macros for evaluating the models and checking the
    # results. They freely use variables from the current scope, even some
    # which have not been defined yet, complete with abuse of mutable lists
    # to allow them to update values in the current scope since nonlocal
    # declarations are not available in python 2.7.
    def try_model(fn, pars):
        """Return fn(**pars), or an all-NaN result if the engine raises."""
        try:
            result = fn(**pars)
        except Exception:
            # KeyboardInterrupt/SystemExit are not Exceptions, so they still
            # propagate; any engine error is reported and turned into NaNs.
            traceback.print_exc()
            print("when comparing %s for %d"%(name, seed))
            # np.nan (not np.NaN, which NumPy 2.0 removed).
            if hasattr(data, 'qx_data'):
                result = np.nan*data.data
            else:
                result = np.nan*data.x
        return result

    def check_model(pars):
        """Run both engines on pars, update max_diff/good, return the stats."""
        base_value = try_model(calc_base, pars)
        comp_value = try_model(calc_comp, pars)
        stats = calc_stats(base_value, comp_value, index)
        max_diff[0] = max(max_diff[0], stats[0])
        good[0] = good[0] and (stats[0] < expected)
        return list(stats)

    calc_base = make_engine(model_definition, data, base, cutoff)
    calc_comp = make_engine(model_definition, data, comp, cutoff)
    expected = max(PRECISION[base], PRECISION[comp])

    num_good = 0
    first = True
    max_diff = [0]
    for k in range(N):
        print("%s %d"%(name, k))
        # Integer bound: same range as the old 1e6, without relying on
        # numpy implicitly casting a float bound.
        seed = np.random.randint(1000000)
        pars_i = randomize_pars(pars, seed)
        constrain_pars(model_definition, pars_i)
        constrain_new_to_old(model_definition, pars_i)
        if mono:
            pars_i = suppress_pd(pars_i)

        good = [True]
        columns = check_model(pars_i)
        columns += [v for _, v in sorted(pars_i.items())]
        if first:
            # Headers depend on the parameter names, known only after the
            # first parameter set has been generated.
            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))]
            print_column_headers(pars_i, labels)
            first = False
        if good[0]:
            num_good += 1
        else:
            print(("%d,"%seed)+','.join("%s"%v for v in columns))
    print('"good","%d/%d","max diff",%g'%(num_good, N, max_diff[0]))
---|
110 | |
---|
111 | |
---|
def print_usage():
    """Print the one-line command synopsis.

    The trailing precision is optional: main() compares against "sasview"
    when only one precision is given.
    """
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) PRECISION [PRECISION]")
---|
114 | |
---|
115 | |
---|
def print_models():
    """Print every available model name, laid out in columns."""
    table = columnize(MODELS, indent=" ")
    print(table)
---|
118 | |
---|
119 | |
---|
def print_help():
    """Print the usage line, a description of each argument, and the models."""
    print_usage()
    print("""\

MODEL is the name of the model, or "all" for all the models
in alphabetical order.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models. For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate. If it starts with "1d", then
it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
If it starts with "2d" then it is a 2-dimensional problem, with linearly
spaced Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored. Use "mono" for monodisperse models. The
choice of polydisperse parameters and the number of points in the distribution
are set in compare.py defaults for each model.

PRECISION is the floating point precision to use for comparisons. If only
one precision is given, it is compared against sasview. If two precisions
are given, then compare one to the other, ignoring sasview. Recognized
values are fast, half, single, double, single!, double!, quad! and sasview.

Available models:
""")
    print_models()
---|
148 | |
---|
def main():
    """Parse the command line and run the requested comparisons.

    Expects five or six arguments: MODEL COUNT NQ CUTOFF PRECISION
    [PRECISION].  When only one precision is given it is compared against
    "sasview".  On bad arguments it prints help or usage information and
    exits with status 1.
    """
    if len(sys.argv) not in (6, 7):
        print_help()
        sys.exit(1)

    model = sys.argv[1]
    if model not in MODELS and model != "all":
        print('Bad model %s. Use "all" or one of:' % model)
        print_models()
        sys.exit(1)
    try:
        count = int(sys.argv[2])
        dims = sys.argv[3][:2]
        # Explicit check rather than assert, so validation survives python -O
        # and rejects prefixes such as "3d" instead of treating them as 1-D.
        if dims not in ('1d', '2d'):
            raise ValueError("NQ must start with 1d or 2d, not %r" % sys.argv[3])
        is2D = dims == '2d'
        Nq = int(sys.argv[3][2:])
        mono = sys.argv[4] == 'mono'
        cutoff = float(sys.argv[4]) if not mono else 0
        base = sys.argv[5]
        comp = sys.argv[6] if len(sys.argv) > 6 else "sasview"
        # compare_instance looks both names up in PRECISION; fail early with
        # usage information instead of a KeyError part way through a run.
        for precision in (base, comp):
            if precision not in PRECISION:
                raise ValueError("unknown precision %r" % precision)
    except Exception:
        traceback.print_exc()
        print_usage()
        sys.exit(1)

    data, index = make_data({'qmax':1.0, 'is2d':is2D, 'nq':Nq, 'res':0.,
                             'accuracy': 'Low', 'view':'log'})
    model_list = [model] if model != "all" else MODELS
    for name in model_list:
        compare_instance(name, data, index, N=count, mono=mono,
                         cutoff=cutoff, base=base, comp=comp)
---|
179 | |
---|
if __name__ == "__main__":
    # Run the command-line driver only when executed as a script, not on import.
    main()
---|