source: sasmodels/sasmodels/compare_many.py @ 6458608

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 6458608 was 6458608, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix compare for scale=None

  • Property mode set to 100755
File size: 6.4 KB
Line 
1#!/usr/bin/env python
2from __future__ import print_function
3
4import sys
5import traceback
6
7import numpy as np
8
9from . import core
10from . import generate
11from .compare import (MODELS, randomize_pars, suppress_pd, make_data,
12                      make_engine, get_demo_pars, columnize,
13                      constrain_pars, constrain_new_to_old)
14
def calc_stats(target, value, index):
    """
    Summarize how far *value* deviates from *target*.

    *target* and *value* are arrays of the same shape and *index* selects
    the points to compare (e.g. a boolean mask or a slice); only the
    selected points contribute to the statistics.

    Returns a tuple (maxrel, rel95, maxabs, maxval):

    * maxrel -- largest relative error |value-target|/target
    * rel95 -- relative error at the 95th percentile
    * maxabs -- largest absolute error among the points at or above the
      95th percentile of relative error
    * maxval -- largest *value* among those same points

    If *index* selects no points, all four statistics are NaN.
    """
    resid = abs(value - target)[index]
    relerr = resid/target[index]
    # srel below holds positions within the selected subset, so the values
    # must be selected the same way before being indexed by those positions.
    selected_value = value[index]
    if relerr.size == 0:
        # Nothing to compare; np.max would raise on an empty array.
        return (np.nan,)*4
    srel = np.argsort(relerr)
    p95 = int(len(relerr)*0.95)
    maxrel = np.max(relerr)
    rel95 = relerr[srel[p95]]
    maxabs = np.max(resid[srel[p95:]])
    maxval = np.max(selected_value[srel[p95:]])
    return maxrel, rel95, maxabs, maxval
26
def print_column_headers(pars, parts):
    """
    Print the two CSV header rows for a comparison table.

    The first row labels column groups: an empty cell above the seed column,
    each entry of *parts* above its block of statistics columns, and finally
    "Parameters".  The second row names each column: "Seed", one block of
    statistics per entry in *parts*, then the keys of *pars* in sorted order.
    """
    # These labels describe the tuple returned by calc_stats, whose absolute
    # error and value statistics are taken over the points at or above the
    # 95th percentile of relative error.
    stats = ['Max rel err', '95% rel err',
             'Max abs err above 95% rel', 'Max value above 95% rel']
    groups = ['']
    for part in parts:
        groups.append(part)
        groups.extend([''] * (len(stats) - 1))
    groups.append("Parameters")
    columns = ['Seed'] + stats * len(parts) + sorted(pars.keys())
    print(','.join('"%s"' % c for c in groups))
    print(','.join('"%s"' % c for c in columns))
37
# Relative error tolerated for each engine name.  compare_instance takes the
# larger of the base and comparison entries as its pass threshold for the
# maximum relative error reported by calc_stats.
PRECISION = {
    'fast':    1e-3,
    'half':    1e-3,
    'single':  5e-5,
    'double':  5e-14,
    'single!': 5e-5,
    'double!': 5e-14,
    'quad!':   5e-18,
    'sasview': 5e-14,
}
def _evaluate_model(fn, pars, name, seed, data, is2D):
    """
    Call the model engine *fn* with keyword parameters *pars*.

    If the call raises an Exception (KeyboardInterrupt and SystemExit still
    propagate), the traceback is printed along with the model name and seed.
    An all-NaN array shaped like the data is then returned, so the failure
    shows up as a bad row instead of aborting the whole run.
    """
    try:
        return fn(**pars)
    except Exception:
        traceback.print_exc()
        print("when comparing %s for %d"%(name, seed))
        template = data.data if is2D else data.x
        return np.nan*template


def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5,
                     base='sasview', comp='double'):
    """
    Compare two calculation engines for model *name* over *N* random parameter sets.

    For each set, the model is evaluated on *data* with both the *base* and
    *comp* engines.  calc_stats then summarizes the difference over the
    points selected by *index*.

    Output is CSV on stdout:

    * a model header line;
    * column headers;
    * one row for each parameter set whose maximum relative error is not
      below the tolerance max(PRECISION[base], PRECISION[comp]);
    * a closing line with the count of good sets and the largest relative
      error seen.

    Progress goes to stderr.

    For 2D data, a model whose info has neither orientation nor magnetic
    parameters is skipped with a "1-D only" note.  When *mono* is True the
    random parameters are passed through suppress_pd.  *cutoff* is always
    handed to make_engine, and is reported in the header when *mono* is
    False.
    """
    is2D = hasattr(data, 'qx_data')
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)
    header = ('\n"Model","%s","Count","%d","Dimension","%s"'
              % (name, N, "2D" if is2D else "1D"))
    if not mono:
        header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    if is2D:
        info = generate.make_info(model_definition)
        partype = info['partype']
        if not partype['orientation'] and not partype['magnetic']:
            print(',"1-D only"')
            return

    calc_base = make_engine(model_definition, data, base, cutoff)
    calc_comp = make_engine(model_definition, data, comp, cutoff)
    expected = max(PRECISION[base], PRECISION[comp])

    num_good = 0
    max_diff = 0
    for k in range(N):
        print("%s %d"%(name, k), file=sys.stderr)
        # Integer bound: passing a float to randint is deprecated/rejected
        # by newer numpy releases.
        seed = np.random.randint(1000000)
        pars_i = randomize_pars(pars, seed)
        constrain_pars(model_definition, pars_i)
        constrain_new_to_old(model_definition, pars_i)
        if mono:
            pars_i = suppress_pd(pars_i)

        base_value = _evaluate_model(calc_base, pars_i, name, seed, data, is2D)
        comp_value = _evaluate_model(calc_comp, pars_i, name, seed, data, is2D)
        stats = calc_stats(base_value, comp_value, index)
        max_diff = max(max_diff, stats[0])
        good = stats[0] < expected
        columns = list(stats) + [v for _, v in sorted(pars_i.items())]

        if k == 0:
            # Column names depend on the parameter names, so print them once
            # the first parameter set exists.
            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))]
            print_column_headers(pars_i, labels)
        if good:
            num_good += 1
        else:
            print(("%d,"%seed)+','.join("%s"%v for v in columns))
    print('"good","%d/%d","max diff",%g'%(num_good, N, max_diff))
110        columns += [v for _,v in sorted(pars_i.items())]
111        if first:
112            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))]
113            print_column_headers(pars_i, labels)
114            first = False
115        if good[0]:
116            num_good += 1
117        else:
118            print(("%d,"%seed)+','.join("%s"%v for v in columns))
119    print('"good","%d/%d","max diff",%g'%(num_good, N, max_diff[0]))
120
121
122def print_usage():
123    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) (single|double|quad)")
124
125
126def print_models():
127    print(columnize(MODELS, indent="  "))
128
129
130def print_help():
131    print_usage()
132    print("""\
133
134MODEL is the model name of the model or "all" for all the models
135in alphabetical order.
136
137COUNT is the number of randomly generated parameter sets to try. A value
138of "10000" is a reasonable check for monodisperse models, or "100" for
139polydisperse models.   For a quick check, use "100" and "5" respectively.
140
141NQ is the number of Q values to calculate.  If it starts with "1d", then
142it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
143If it starts with "2d" then it is a 2-dimensional problem, with linearly
144spaced points Q points from -1.0 to 1.0 in each dimension. The usual
145values are "1d100" for 1-D and "2d32" for 2-D.
146
147CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
148below the cutoff will be ignored.  Use "mono" for monodisperse models.  The
149choice of polydisperse parameters, and the number of points in the distribution
150is set in compare.py defaults for each model.
151
152PRECISION is the floating point precision to use for comparisons.  If two
153precisions are given, then compare one to the other, ignoring sasview.
154
155Available models:
156""")
157    print_models()
158
159def main():
160    if len(sys.argv) not in (6,7):
161        print_help()
162        sys.exit(1)
163
164    model = sys.argv[1]
165    if not (model in MODELS) and (model != "all"):
166        print('Bad model %s.  Use "all" or one of:'%model)
167        print_models()
168        sys.exit(1)
169    try:
170        count = int(sys.argv[2])
171        is2D = sys.argv[3].startswith('2d')
172        assert sys.argv[3][1] == 'd'
173        Nq = int(sys.argv[3][2:])
174        mono = sys.argv[4] == 'mono'
175        cutoff = float(sys.argv[4]) if not mono else 0
176        base = sys.argv[5]
177        comp = sys.argv[6] if len(sys.argv) > 6 else "sasview"
178    except:
179        traceback.print_exc()
180        print_usage()
181        sys.exit(1)
182
183    data, index = make_data({'qmax':1.0, 'is2d':is2D, 'nq':Nq, 'res':0.,
184                              'accuracy': 'Low', 'view':'log'})
185    model_list = [model] if model != "all" else MODELS
186    for model in model_list:
187        compare_instance(model, data, index, N=count, mono=mono,
188                         cutoff=cutoff, base=base, comp=comp)
189
190if __name__ == "__main__":
191    main()
Note: See TracBrowser for help on using the repository browser.