source: sasmodels/sasmodels/compare.py @ b514adf

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since b514adf was b514adf, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

set constraints so multi_compare has fewer spurious errors

  • Property mode set to 100755
File size: 16.6 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9import traceback
10
11import numpy as np
12
13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
17from . import core
18from . import kerneldll
19from . import generate
20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
24
25# List of available models
26MODELS = [basename(f)[:-3]
27          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
28
29# CRUFT python 2.6
30if not hasattr(datetime.timedelta, 'total_seconds'):
31    def delay(dt):
32        """Return number date-time delta as number seconds"""
33        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
34else:
35    def delay(dt):
36        """Return number date-time delta as number seconds"""
37        return dt.total_seconds()
38
39
40def tic():
41    """
42    Timer function.
43
44    Use "toc=tic()" to start the clock and "toc()" to measure
45    a time interval.
46    """
47    then = datetime.datetime.now()
48    return lambda: delay(datetime.datetime.now() - then)
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54
55    Note: this function does not use the sasview package
56    """
57    if hasattr(data, 'qx_data'):
58        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
59        data.mask = (q < radius)
60        if outer is not None:
61            data.mask |= (q >= outer)
62    else:
63        data.mask = (data.x < radius)
64        if outer is not None:
65            data.mask |= (data.x >= outer)
66
67
68def sasview_model(model_definition, **pars):
69    """
70    Load a sasview model given the model name.
71    """
72    # convert model parameters from sasmodel form to sasview form
73    #print "old",sorted(pars.items())
74    modelname, pars = revert_model(model_definition, pars)
75    #print "new",sorted(pars.items())
76    sas = __import__('sas.models.'+modelname)
77    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
78    if ModelClass is None:
79        raise ValueError("could not find model %r in sas.models"%modelname)
80    model = ModelClass()
81
82    for k,v in pars.items():
83        if k.endswith("_pd"):
84            model.dispersion[k[:-3]]['width'] = v
85        elif k.endswith("_pd_n"):
86            model.dispersion[k[:-5]]['npts'] = v
87        elif k.endswith("_pd_nsigma"):
88            model.dispersion[k[:-10]]['nsigmas'] = v
89        elif k.endswith("_pd_type"):
90            model.dispersion[k[:-8]]['type'] = v
91        else:
92            model.setParam(k, v)
93    return model
94
95def randomize(p, v):
96    """
97    Randomizing parameter.
98
99    Guess the parameter type from name.
100    """
101    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
102        return v
103    elif any(s in p for s in ('theta','phi','psi')):
104        # orientation in [-180,180], orientation pd in [0,45]
105        if p.endswith('_pd'):
106            return 45*np.random.rand()
107        else:
108            return 360*np.random.rand() - 180
109    elif 'sld' in p:
110        # sld in in [-0.5,10]
111        return 10.5*np.random.rand() - 0.5
112    elif p.endswith('_pd'):
113        # length pd in [0,1]
114        return np.random.rand()
115    else:
116        # values from 0 to 2*x for all other parameters
117        return 2*np.random.rand()*(v if v != 0 else 1)
118
119def randomize_model(pars, seed=None):
120    if seed is None:
121        seed = np.random.randint(1e9)
122    np.random.seed(seed)
123    # Note: the sort guarantees order of calls to random number generator
124    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
125
126    return pars, seed
127
128def constrain_pars(model_definition, pars):
129    name = model_definition.name
130    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
131        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
132    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
133        pars['radius'],pars['bell_radius'] = pars['bell_radius'],pars['radius']
134
135    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
136    if name == 'guinier':
137        #q_max = 0.2  # mid q maximum
138        q_max = 1.0  # high q maximum
139        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
140        pars['rg'] = min(pars['rg'],rg_max)
141
142    # These constraints are only needed for comparison to sasview
143    if name in ('teubner_strey','broad_peak'):
144        del pars['scale']
145    if name in ('guinier',):
146        del pars['background']
147    if getattr(model_definition, 'category', None) == 'structure-factor':
148        del pars['scale'], pars['background']
149
150
151def parlist(pars):
152    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
153
154def suppress_pd(pars):
155    """
156    Suppress theta_pd for now until the normalization is resolved.
157
158    May also suppress complete polydispersity of the model to test
159    models more quickly.
160    """
161    for p in pars:
162        if p.endswith("_pd"): pars[p] = 0
163
164def eval_sasview(model_definition, pars, data, Nevals=1):
165    from sas.models.qsmearing import smear_selection
166    model = sasview_model(model_definition, **pars)
167    smearer = smear_selection(data, model=model)
168    value = None  # silence the linter
169    toc = tic()
170    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
171        if hasattr(data, 'qx_data'):
172            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
173            index = ((~data.mask) & (~np.isnan(data.data))
174                     & (q >= data.qmin) & (q <= data.qmax))
175            if smearer is not None:
176                smearer.model = model  # because smear_selection has a bug
177                smearer.accuracy = data.accuracy
178                smearer.set_index(index)
179                value = smearer.get_value()
180            else:
181                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
182        else:
183            value = model.evalDistribution(data.x)
184            if smearer is not None:
185                value = smearer(value)
186    average_time = toc()*1000./Nevals
187    return value, average_time
188
189def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
190    try:
191        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
192    except Exception,exc:
193        print exc
194        print "... trying again with single precision"
195        model = core.load_model(model_definition, dtype='single', platform="ocl")
196    calculator = DirectModel(data, model, cutoff=cutoff)
197    value = None  # silence the linter
198    toc = tic()
199    for _ in range(max(Nevals, 1)):  # force at least one eval
200        value = calculator(**pars)
201    average_time = toc()*1000./Nevals
202    return value, average_time
203
204
205def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
206    model = core.load_model(model_definition, dtype=dtype, platform="dll")
207    calculator = DirectModel(data, model, cutoff=cutoff)
208    value = None  # silence the linter
209    toc = tic()
210    for _ in range(max(Nevals, 1)):  # force at least one eval
211        value = calculator(**pars)
212    average_time = toc()*1000./Nevals
213    return value, average_time
214
215
216def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
217    if is2D:
218        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
219        data.accuracy = accuracy
220        set_beam_stop(data, 0.004)
221        index = ~data.mask
222    else:
223        if view == 'log':
224            qmax = math.log10(qmax)
225            q = np.logspace(qmax-3, qmax, Nq)
226        else:
227            q = np.linspace(0.001*qmax, qmax, Nq)
228        data = empty_data1D(q, resolution=resolution)
229        index = slice(None, None)
230    return data, index
231
232def compare(name, pars, Ncpu, Nocl, opts, set_pars):
233    model_definition = core.load_model_definition(name)
234
235    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
236
237    opt_values = dict(split
238                      for s in opts for split in ((s.split('='),))
239                      if len(split) == 2)
240    # Sort out data
241    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
242    Nq = int(opt_values.get('-Nq', '128'))
243    res = float(opt_values.get('-res', '0'))
244    accuracy = opt_values.get('-accuracy', 'Low')
245    is2D = "-2d" in opts
246    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
247
248
249    # modelling accuracy is determined by dtype and cutoff
250    dtype = 'double' if '-double' in opts else 'single'
251    cutoff = float(opt_values.get('-cutoff','1e-5'))
252
253    # randomize parameters
254    #pars.update(set_pars)  # set value before random to control range
255    if '-random' in opts or '-random' in opt_values:
256        seed = int(opt_values['-random']) if '-random' in opt_values else None
257        pars, seed = randomize_model(pars, seed=seed)
258        print "Randomize using -random=%i"%seed
259    pars.update(set_pars)  # set value after random to control value
260    constrain_pars(model_definition, pars)
261
262    # parameter selection
263    if '-mono' in opts:
264        suppress_pd(pars)
265    if '-pars' in opts:
266        print "pars",parlist(pars)
267
268    # OpenCl calculation
269    if Nocl > 0:
270        ocl, ocl_time = eval_opencl(model_definition, pars, data,
271                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
272        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
273        #print "ocl", ocl
274        #print max(ocl), min(ocl)
275
276    # ctypes/sasview calculation
277    if Ncpu > 0 and "-ctypes" in opts:
278        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
279                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
280        comp = "ctypes"
281        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
282    elif Ncpu > 0:
283        try:
284            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
285            comp = "sasview"
286            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background'])
287            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
288            #print "sasview",cpu
289        except ImportError:
290            traceback.print_exc()
291            Ncpu = 0
292
293    # Compare, but only if computing both forms
294    if Nocl > 0 and Ncpu > 0:
295        #print "speedup %.2g"%(cpu_time/ocl_time)
296        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
297        #cpu *= max(ocl/cpu)
298        resid = (ocl - cpu)
299        relerr = resid/cpu
300        #bad = (relerr>1e-4)
301        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
302        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
303        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
304
305    # Plot if requested
306    if '-noplot' in opts: return
307    import matplotlib.pyplot as plt
308    if Ncpu > 0:
309        if Nocl > 0: plt.subplot(131)
310        plot_theory(data, cpu, view=view, plot_data=False)
311        plt.title("%s t=%.1f ms"%(comp,cpu_time))
312        #cbar_title = "log I"
313    if Nocl > 0:
314        if Ncpu > 0: plt.subplot(132)
315        plot_theory(data, ocl, view=view, plot_data=False)
316        plt.title("opencl t=%.1f ms"%ocl_time)
317        #cbar_title = "log I"
318    if Ncpu > 0 and Nocl > 0:
319        plt.subplot(133)
320        if '-abs' in opts:
321            err,errstr,errview = resid, "abs err", "linear"
322        else:
323            err,errstr,errview = abs(relerr), "rel err", "log"
324        #err,errstr = ocl/cpu,"ratio"
325        plot_theory(data, None, resid=err, view=errview, plot_data=False)
326        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
327        #cbar_title = errstr if errview=="linear" else "log "+errstr
328    #if is2D:
329    #    h = plt.colorbar()
330    #    h.ax.set_title(cbar_title)
331
332    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
333        plt.figure()
334        v = relerr
335        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
336        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
337        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
338        plt.ylabel('P(err)')
339        plt.title('Comparison of single and double precision models for %s'%name)
340
341    plt.show()
342
343def _print_stats(label, err):
344    sorted_err = np.sort(abs(err))
345    p50 = int((len(err)-1)*0.50)
346    p98 = int((len(err)-1)*0.98)
347    data = [
348        "max:%.3e"%sorted_err[-1],
349        "median:%.3e"%sorted_err[p50],
350        "98%%:%.3e"%sorted_err[p98],
351        "rms:%.3e"%np.sqrt(np.mean(err**2)),
352        "zero-offset:%+.3e"%np.mean(err),
353        ]
354    print label,"  ".join(data)
355
356
357
358# ===========================================================================
359#
360USAGE="""
361usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
362
363Compare the speed and value for a model between the SasView original and the
364OpenCL rewrite.
365
366model is the name of the model to compare (see below).
367Nopencl is the number of times to run the OpenCL model (default=5)
368Nsasview is the number of times to run the Sasview model (default=1)
369
370Options (* for default):
371
372    -plot*/-noplot plots or suppress the plot of the model
373    -single*/-double uses double precision for comparison
374    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
375    -Nq=128 sets the number of Q points in the data set
376    -1d*/-2d computes 1d or 2d data
377    -preset*/-random[=seed] preset or random parameters
378    -mono/-poly* force monodisperse/polydisperse
379    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
380    -cutoff=1e-5* cutoff value for including a point in polydispersity
381    -pars/-nopars* prints the parameter set or not
382    -abs/-rel* plot relative or absolute error
383    -linear/-log/-q4 intensity scaling
384    -hist/-nohist* plot histogram of relative error
385    -res=0 sets the resolution width dQ/Q if calculating with resolution
386    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
387
388Key=value pairs allow you to set specific values to any of the model
389parameters.
390
391Available models:
392"""
393
394
395NAME_OPTIONS = set([
396    'plot','noplot',
397    'single','double',
398    'lowq','midq','highq','exq',
399    '2d','1d',
400    'preset','random',
401    'poly','mono',
402    'sasview','ctypes',
403    'nopars','pars',
404    'rel','abs',
405    'linear', 'log', 'q4',
406    'hist','nohist',
407    ])
408VALUE_OPTIONS = [
409    # Note: random is both a name option and a value option
410    'cutoff', 'random', 'Nq', 'res', 'accuracy',
411    ]
412
413def columnize(L, indent="", width=79):
414    column_width = max(len(w) for w in L) + 1
415    num_columns = (width - len(indent)) // column_width
416    num_rows = len(L) // num_columns
417    L = L + [""] * (num_rows*num_columns - len(L))
418    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
419    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
420             for row in zip(*columns)]
421    output = indent + ("\n"+indent).join(lines)
422    return output
423
424
425def get_demo_pars(model_definition):
426    info = generate.make_info(model_definition)
427    pars = dict((p[0],p[2]) for p in info['parameters'])
428    pars.update(info['demo'])
429    return pars
430
431def main():
432    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
433    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
434    models = "\n    ".join("%-15s"%v for v in MODELS)
435    if len(args) == 0:
436        print(USAGE)
437        print(columnize(MODELS, indent="  "))
438        sys.exit(1)
439    if args[0] not in MODELS:
440        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
441        sys.exit(1)
442
443    invalid = [o[1:] for o in opts
444               if o[1:] not in NAME_OPTIONS
445                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
446    if invalid:
447        print "Invalid options: %s"%(", ".join(invalid))
448        sys.exit(1)
449
450    # Get demo parameters from model definition, or use default parameters
451    # if model does not define demo parameters
452    name = args[0]
453    model_definition = core.load_model_definition(name)
454    pars = get_demo_pars(model_definition)
455
456    Nopencl = int(args[1]) if len(args) > 1 else 5
457    Nsasview = int(args[2]) if len(args) > 2 else 1
458
459    # Fill in default polydispersity parameters
460    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
461    for p in pds:
462        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
463        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
464
465    # Fill in parameters given on the command line
466    set_pars = {}
467    for arg in args[3:]:
468        k,v = arg.split('=')
469        if k not in pars:
470            # extract base name without distribution
471            s = set(p.split('_pd')[0] for p in pars)
472            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
473            sys.exit(1)
474        set_pars[k] = float(v) if not v.endswith('type') else v
475
476    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
477
478if __name__ == "__main__":
479    main()
Note: See TracBrowser for help on using the repository browser.