[216a9e1] | 1 | #!/usr/bin/env python |
---|
[d15a908] | 2 | """ |
---|
| 3 | Program to compare results from many random parameter sets for a given model. |
---|
| 4 | |
---|
| 5 | The result is a comma separated value (CSV) table that can be redirected |
---|
| 6 | from standard output into a file and loaded into a spreadsheet. |
---|
| 7 | |
---|
| 8 | The models are compared for each parameter set and if the difference is |
---|
| 9 | greater than expected for that precision, the parameter set is labeled |
---|
| 10 | as bad and written to the output, along with the random seed used to |
---|
| 11 | generate that parameter value. This seed can be used with :mod:`compare` |
---|
| 12 | to reload and display the details of the model. |
---|
| 13 | """ |
---|
[a7f909a] | 14 | from __future__ import print_function |
---|
| 15 | |
---|
[216a9e1] | 16 | import sys |
---|
[7cf2cfd] | 17 | import traceback |
---|
[216a9e1] | 18 | |
---|
[7ae2b7f] | 19 | import numpy as np # type: ignore |
---|
[216a9e1] | 20 | |
---|
[e922c5d] | 21 | from . import core |
---|
[f3bd37f] | 22 | from .compare import (randomize_pars, suppress_pd, make_data, |
---|
[ce346b6] | 23 | make_engine, get_pars, columnize, |
---|
[32398dc] | 24 | constrain_pars) |
---|
[216a9e1] | 25 | |
---|
# Model names returned by core.list_models() with no filter (presumably every
# available model -- confirm in sasmodels.core). main() checks the MODEL
# argument against this list before treating it as a model-type filter, and
# print_models() displays it.
MODELS = core.list_models()
---|
| 27 | |
---|
def calc_stats(target, value, index):
    """
    Calculate statistics between the target value and the computed value.

    *target* and *value* are the vectors being compared, with the
    difference normalized by *target* to get relative error.  Only
    the elements selected by *index* are used; pass *slice(None, None)*
    to use every element.

    Returns:

    *maxrel* the maximum relative difference

    *rel95* the relative difference with the 5% biggest differences ignored

    *maxabs* the maximum absolute difference for the 5% biggest differences

    *maxval* the maximum value for the 5% biggest differences

    If *index* selects no points, all four statistics are NaN.
    """
    # Restrict both vectors to the active points up front so that every
    # position computed below refers to the same subset.
    target_sel = target[index]
    value_sel = value[index]
    resid = abs(value_sel - target_sel)
    relerr = resid/target_sel
    if relerr.size == 0:
        # Nothing to compare; np.max would raise on an empty array.
        return np.nan, np.nan, np.nan, np.nan
    sorted_rel_index = np.argsort(relerr)
    # Position of the 95th percentile; int() truncates, so p95 < len(relerr).
    p95 = int(len(relerr)*0.95)
    worst = sorted_rel_index[p95:]
    maxrel = np.max(relerr)
    rel95 = relerr[sorted_rel_index[p95]]
    maxabs = np.max(resid[worst])
    # sorted_rel_index holds positions within the selected subset, so it must
    # index value_sel rather than the full value vector.
    maxval = np.max(value_sel[worst])
    return maxrel, rel95, maxabs, maxval
---|
[216a9e1] | 57 | |
---|
def print_column_headers(pars, parts):
    """
    Generate column headers for the differences and for the parameters,
    and print them to standard output.

    Two quoted CSV rows are printed.  The first is a group row: an empty
    cell above "Seed", each label from *parts* above its block of
    statistics columns, then "Parameters".  The second names the individual
    columns: "Seed", the four statistics returned by :func:`calc_stats`
    repeated once per entry in *parts*, then the keys of *pars* in sorted
    order.
    """
    # calc_stats splits the points at the 95th percentile of relative error,
    # so the labels must say 95%, not 90%.
    stats = ['Max rel err', '95% rel err',
             'Max abs err above 95% rel', 'Max value above 95% rel']
    groups = ['']
    for part in parts:
        groups.append(part)
        groups.extend(['']*(len(stats)-1))
    groups.append("Parameters")
    columns = ['Seed'] + stats*len(parts) + sorted(pars)
    print(','.join('"%s"'%c for c in groups))
    print(','.join('"%s"'%c for c in columns))
---|
| 72 | |
---|
# Largest maximum relative error at which two engines still count as agreeing,
# keyed by precision name.  Per the help text, the plain names are the GPU
# precisions and names ending in "!" select the DLL build.
PRECISION = dict([
    # names without "!"
    ('fast', 1e-3),
    ('half', 1e-3),
    ('single', 5e-5),
    ('double', 5e-14),
    # names with "!" (DLL)
    ('single!', 5e-5),
    ('double!', 5e-14),
    ('quad!', 5e-18),
])
---|
def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5,
                     base='single', comp='double'):
    r"""
    Compare the model under different calculation engines.

    *name* is the name of the model.

    *data* is the data object giving $q, \Delta q$ calculation points.

    *index* is the active set of points.

    *N* is the number of comparisons to make.

    *mono*, if True, passes each random parameter set through
    :func:`suppress_pd` before evaluation (monodisperse comparison).

    *cutoff* is the polydispersity weight cutoff to make the calculation
    a little bit faster.

    *base* and *comp* are the names of the calculation engines to compare.

    Output is CSV on standard output.  It consists of:

    - a header row for the model;
    - column headers before the first parameter set;
    - one row (seed, statistics, parameter values) for every parameter set
      whose maximum relative error is not below the tolerance for the two
      precisions;
    - a final summary row.

    Progress messages go to standard error.
    """
    is_2d = hasattr(data, 'qx_data')
    model_info = core.load_model_info(name)
    pars = get_pars(model_info, use_demo=True)
    header = ('\n"Model","%s","Count","%d","Dimension","%s"'
              % (name, N, "2D" if is_2d else "1D"))
    if not mono:
        header += ',"Cutoff",%g'%(cutoff,)
    print(header)

    if is_2d and not model_info.parameters.has_2d:
        print(',"1-D only"')
        return

    def try_model(calculator, par_set, seed):
        """
        Return *calculator* evaluated at *par_set*.  If it raises, print the
        traceback together with the model name and *seed*, and return NaN of
        the right shape so the set is reported as bad instead of aborting.
        """
        try:
            return calculator(**par_set)
        except Exception:
            traceback.print_exc()
            print("when comparing %s for %d"%(name, seed))
            # np.nan rather than np.NaN: the alias was removed in NumPy 2.0.
            return np.nan*(data.data if is_2d else data.x)

    try:
        calc_base = make_engine(model_info, data, base, cutoff)
        calc_comp = make_engine(model_info, data, comp, cutoff)
    except Exception as exc:
        print('"Error: %s"'%str(exc).replace('"', "'"))
        print('"good","%d of %d","max diff",%g' % (0, N, np.nan))
        return
    # The two engines are judged against the looser of their tolerances.
    expected = max(PRECISION[base], PRECISION[comp])

    num_good = 0
    max_diff = 0
    for k in range(N):
        print("Model %s %d"%(name, k+1), file=sys.stderr)
        # Draw a seed and reseed with it, so the parameter set can be
        # regenerated from the seed written to the report.
        seed = np.random.randint(1000000)
        np.random.seed(seed)
        pars_i = randomize_pars(model_info, pars)
        constrain_pars(model_info, pars_i)
        if mono:
            pars_i = suppress_pd(pars_i)

        base_value = try_model(calc_base, pars_i, seed)
        comp_value = try_model(calc_comp, pars_i, seed)
        stats = calc_stats(base_value, comp_value, index)
        max_diff = max(max_diff, stats[0])
        columns = list(stats) + [v for _, v in sorted(pars_i.items())]
        if k == 0:
            labels = [" vs. ".join((calc_base.engine, calc_comp.engine))]
            print_column_headers(pars_i, labels)
        if stats[0] < expected:
            num_good += 1
        else:
            print(("%d,"%seed)+','.join("%s"%v for v in columns))
    print('"good","%d of %d","max diff",%g'%(num_good, N, max_diff))
---|
[7cf2cfd] | 183 | |
---|
| 184 | |
---|
def print_usage():
    """
    Print the command usage string to standard error.

    The trailing arguments are optional but positional: a precision can only
    be given after CUTOFF (or "mono"), and a second precision only after the
    first.  :func:`print_help` describes each argument.
    """
    print("usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) "
          "[(CUTOFF|mono) [PRECISION [PRECISION]]]",
          file=sys.stderr)
---|
[7cf2cfd] | 191 | |
---|
| 192 | |
---|
def print_models():
    """
    Write the names held in MODELS to standard output, laid out in columns.
    """
    listing = columnize(MODELS, indent=" ")
    print(listing)
---|
[216a9e1] | 198 | |
---|
| 199 | |
---|
def print_help():
    """
    Print the usage string, a description of each command line argument,
    and the list of available models.
    """
    print_usage()
    print("""\

MODEL is the name of the model or one of the model types listed in
sasmodels.core.list_models (all, py, c, double, single, opencl, 1d, 2d,
nonmagnetic, magnetic). Model types can be combined, such as 2d+single.

COUNT is the number of randomly generated parameter sets to try. A value
of "10000" is a reasonable check for monodisperse models, or "100" for
polydisperse models. For a quick check, use "100" and "5" respectively.

NQ is the number of Q values to calculate. If it starts with "1d", then
it is a 1-dimensional problem, with log spaced Q points from 1e-3 to 1.0.
If it starts with "2d" then it is a 2-dimensional problem, with linearly
spaced Q points from -1.0 to 1.0 in each dimension. The usual
values are "1d100" for 1-D and "2d32" for 2-D.

CUTOFF is the cutoff value to use for the polydisperse distribution. Weights
below the cutoff will be ignored. Use "mono" for monodisperse models. The
choice of polydisperse parameters and the number of points in the distribution
are set in compare.py defaults for each model. Polydispersity is given in the
"demo" attribute of each model.

PRECISION is the floating point precision to use for comparisons. If two
precisions are given, then compare one to the other. Precision is one of
fast, single, double for GPU or single!, double!, quad! for DLL. If no
precision is given, then use single and double! respectively. If only one
precision is given, it is compared to double!.

Available models:
""")
    print_models()
---|
| 235 | |
---|
def main(argv):
    """
    Main program.

    *argv* is the command line without the program name:
    ``MODEL COUNT (1dNQ|2dNQ) [(CUTOFF|mono) [BASE [COMP]]]``.

    - If the argument count is wrong, print the help text.
    - If an argument cannot be parsed, print the usage string.
    - Otherwise run :func:`compare_instance` for each selected model; the
      CSV report goes to standard output.
    """
    if len(argv) not in (3, 4, 5, 6):
        print_help()
        return

    target = argv[0]
    try:
        model_list = [target] if target in MODELS else core.list_models(target)
    except ValueError:
        print('Bad model %s. Use model type or one of:' % target, file=sys.stderr)
        print_models()
        print('model types: all, py, c, double, single, opencl, 1d, 2d, nonmagnetic, magnetic')
        return
    try:
        count = int(argv[1])
        # Explicit checks instead of assert, so validation still runs under
        # "python -O" and anything but "1d..."/"2d..." is rejected.
        dim = argv[2][:2]
        if dim not in ('1d', '2d'):
            raise ValueError("expected 1dNQ or 2dNQ, got %r" % argv[2])
        is2D = (dim == '2d')
        Nq = int(argv[2][2:])
        mono = len(argv) <= 3 or argv[3] == 'mono'
        cutoff = float(argv[3]) if not mono else 0
        base = argv[4] if len(argv) > 4 else "single"
        comp = argv[5] if len(argv) > 5 else "double!"
        # compare_instance looks both names up in PRECISION; reject unknown
        # names here instead of failing with a KeyError mid-run.
        for precision in (base, comp):
            if precision not in PRECISION:
                raise ValueError("unknown precision %r; use one of %s"
                                 % (precision, ", ".join(PRECISION)))
    except Exception:
        traceback.print_exc()
        print_usage()
        return

    data, index = make_data({
        'qmin': 0.001, 'qmax': 1.0, 'is2d': is2D, 'nq': Nq, 'res': 0.,
        'accuracy': 'Low', 'view':'log', 'zero': False
        })
    for model in model_list:
        compare_instance(model, data, index, N=count, mono=mono,
                         cutoff=cutoff, base=base, comp=comp)
---|
[216a9e1] | 273 | |
---|
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    # For a repeatable run the call can be wrapped in compare.push_seed
    # (presumably pins the numpy RNG seed -- verify in compare.py):
    #     from .compare import push_seed
    #     with push_seed(1): main(cli_args)
    main(cli_args)
---|