source: sasmodels/compare_many.py @ 3b4243d

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 3b4243d was 7cf2cfd, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

refactor compare.py so that bumps/sasview not required for simple tests

  • Property mode set to 100755
File size: 5.8 KB
RevLine 
[216a9e1]1#!/usr/bin/env python
2import sys
[7cf2cfd]3import traceback
[216a9e1]4
5import numpy as np
6
[7cf2cfd]7from sasmodels import core
[f786ff3]8from sasmodels.kernelcl import environment
[216a9e1]9from compare import (MODELS, randomize_model, suppress_pd, eval_sasview,
[7cf2cfd]10                     eval_opencl, eval_ctypes, make_data, get_demo_pars,
11                     columnize)
[216a9e1]12
[34756fd]13def get_stats(target, value, index):
[216a9e1]14    resid = abs(value-target)[index]
15    relerr = resid/target[index]
16    srel = np.argsort(relerr)
[7cf2cfd]17    #p90 = int(len(relerr)*0.90)
[216a9e1]18    p95 = int(len(relerr)*0.95)
19    maxrel = np.max(relerr)
20    rel95 = relerr[srel[p95]]
21    maxabs = np.max(resid[srel[p95:]])
22    maxval = np.max(value[srel[p95:]])
23    return maxrel,rel95,maxabs,maxval
24
def print_column_headers(pars, parts):
    """
    Print the two CSV header rows for the comparison table.

    *parts* is the list of comparison labels; each label spans the four
    statistic columns produced by get_stats.  *pars* supplies the parameter
    names, printed in sorted order to match the sorted order in which the
    parameter values are appended to each data row.

    The first row holds the group labels ("" above the seed, then each part,
    then "Parameters"); the second row holds the individual column names.
    """
    # get_stats reports the max abs error and max value over the points at or
    # above the 95th percentile of relative error, so label them that way.
    stats = ['Max rel err', '95% rel err',
             'Max abs err above 95% rel', 'Max value above 95% rel']
    groups = ['']
    for p in parts:
        groups.append(p)
        groups.extend(['']*(len(stats)-1))
    groups.append("Parameters")
    columns = ['Seed'] + stats*len(parts) + sorted(pars.keys())
    print(','.join('"%s"'%c for c in groups))
    print(','.join('"%s"'%c for c in columns))
35
def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5):
    """
    Compare calculation engines for model *name* over *N* random trials.

    Each trial draws a random parameter set starting from the model's demo
    parameters and evaluates the model on *data* with each enabled engine.
    It then compares each result against the sasview reference using
    get_stats over the points selected by *index*.

    Output is CSV on stdout:

    * a model header;
    * the column headers on the first trial;
    * one row (seed, statistics, parameter values) for each trial whose
      relative error exceeds the tolerance or whose evaluation failed;
    * a final "good" count.

    If *mono* is true, polydispersity is suppressed.  Otherwise *cutoff* is
    included in the header; in both cases it is passed to the OpenCL/ctypes
    evaluators.
    """
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(name)
    header = '\n"Model","%s","Count","%d"'%(name, N)
    if not mono:
        header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    # Engines to compare against the sasview reference; the OpenCL checks
    # are disabled by default.
    use_gpu_single = False
    use_gpu_double = False
    use_cpu_double = True
    # Largest relative error accepted before a trial is reported.
    # NOTE(review): single precision cannot reach this tolerance; loosen it
    # if the GPU single check is enabled.
    tolerance = 1e-14

    def trymodel(fn, trial_pars, *args, **kw):
        """
        Evaluate *fn* for *trial_pars* on *data*, returning NaN if it raises.

        The traceback is printed so the failure is visible, but the run
        continues with the remaining engines and trials.
        """
        try:
            result, _ = fn(model_definition, trial_pars, data, *args, **kw)
        except Exception:
            result = np.nan
            traceback.print_exc()
        return result

    num_good = 0
    first = True
    for _ in range(N):
        # Randomize from the demo values on every trial.  Feeding back the
        # previous trial's values made each result depend on the whole history
        # of the run, so the printed seed alone could not reproduce a failure.
        trial_pars, seed = randomize_model(name, pars)
        if mono:
            # Used for its in-place effect; the return value is ignored.
            suppress_pd(trial_pars)

        # Force parameter constraints on a per-model basis.
        if name in ('teubner_strey', 'broad_peak'):
            trial_pars['scale'] = 1.0
        #if name == 'parallelepiped':
        #    trial_pars['a_side'],trial_pars['b_side'],trial_pars['c_side'] = \
        #        sorted([trial_pars['a_side'],trial_pars['b_side'],trial_pars['c_side']])

        good = True
        labels = []
        columns = []
        # The sasview result is the reference for every comparison below.
        sasview_value = trymodel(eval_sasview, trial_pars)
        if use_gpu_single:
            gpu_single_value = trymodel(eval_opencl, trial_pars,
                                        dtype='single', cutoff=cutoff)
            stats = get_stats(sasview_value, gpu_single_value, index)
            columns.extend(stats)
            labels.append('GPU single')
            good = good and (stats[0] < tolerance)
        # Check the flag first so OpenCL is not queried when disabled.
        if use_gpu_double and environment().has_double:
            gpu_double_value = trymodel(eval_opencl, trial_pars,
                                        dtype='double', cutoff=cutoff)
            stats = get_stats(sasview_value, gpu_double_value, index)
            columns.extend(stats)
            labels.append('GPU double')
            good = good and (stats[0] < tolerance)
        if use_cpu_double:
            cpu_double_value = trymodel(eval_ctypes, trial_pars,
                                        dtype='double', cutoff=cutoff)
            stats = get_stats(sasview_value, cpu_double_value, index)
            columns.extend(stats)
            labels.append('CPU double')
            good = good and (stats[0] < tolerance)
        # The single/double comparison needs both of the values it reads.
        if use_gpu_single and use_cpu_double:
            stats = get_stats(cpu_double_value, gpu_single_value, index)
            columns.extend(stats)
            labels.append('single/double')
            good = good and (stats[0] < tolerance)

        # Parameter values in sorted-name order, matching the header columns.
        columns += [v for _, v in sorted(trial_pars.items())]
        if first:
            print_column_headers(trial_pars, labels)
            first = False
        if good:
            num_good += 1
        else:
            print(("%d,"%seed)+','.join("%g"%v for v in columns))
    print('"%d/%d good"'%(num_good, N))
108
109
def print_usage():
    """Print the one-line command-line usage summary to stdout."""
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono)")
112
113
def print_models():
    """Print the MODELS list as formatted by columnize, indented two spaces."""
    listing = columnize(MODELS, indent="  ")
    print(listing)
[216a9e1]116
117
def print_help():
    """
    Print the usage line, a description of each command-line argument,
    and the list of available models.
    """
    print_usage()
    print("""\

MODEL is the name of the model, or "all" for all the models
in alphabetical order.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models.  For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate.  If it starts with "1d", then
it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
If it starts with "2d" then it is a 2-dimensional problem, with linearly
spaced Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored.  Use "mono" for monodisperse models.  The
choice of polydisperse parameters, and the number of points in the distribution
is set in compare.py defaults for each model.

Available models:
""")
    print_models()
143
def main():
    """
    Parse the command line and run the engine comparisons.

    Expects ``MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono)``.  Exits with status 1
    in three cases:

    * no arguments (after printing the help text);
    * an unknown model (after printing the model list);
    * malformed remaining arguments (after printing the usage line).
    """
    if len(sys.argv) == 1:
        print_help()
        sys.exit(1)

    model = sys.argv[1]
    if model not in MODELS and model != "all":
        print('Bad model %s.  Use "all" or one of:' % model)
        print_models()
        sys.exit(1)
    try:
        count = int(sys.argv[2])
        qspec = sys.argv[3]
        # The Q spec is "1d" or "2d" followed by the number of Q points.
        # Validate explicitly rather than with assert, which -O strips.
        if qspec[:2] not in ('1d', '2d'):
            raise ValueError("expected 1dNQ or 2dNQ, got %r" % qspec)
        is2D = qspec.startswith('2d')
        Nq = int(qspec[2:])
        mono = sys.argv[4] == 'mono'
        cutoff = float(sys.argv[4]) if not mono else 0
    except (IndexError, ValueError):
        # Missing arguments raise IndexError; bad numbers raise ValueError.
        print_usage()
        sys.exit(1)

    data, index = make_data(qmax=1.0, is2D=is2D, Nq=Nq)
    model_list = [model] if model != "all" else MODELS
    for name in model_list:
        compare_instance(name, data, index, N=count, mono=mono, cutoff=cutoff)


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.