source: sasmodels/sasmodels/compare.py @ f4f3919

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since f4f3919 was f4f3919, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

improve performance timing by suppressing startup time

  • Property mode set to 100755
File size: 18.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9import traceback
10
11import numpy as np
12
13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
17from . import core
18from . import kerneldll
19from . import generate
20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model, constrain_new_to_old
23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
24
25# List of available models
26MODELS = [basename(f)[:-3]
27          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
28
29# CRUFT python 2.6
30if not hasattr(datetime.timedelta, 'total_seconds'):
31    def delay(dt):
32        """Return number date-time delta as number seconds"""
33        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
34else:
35    def delay(dt):
36        """Return number date-time delta as number seconds"""
37        return dt.total_seconds()
38
39
40def tic():
41    """
42    Timer function.
43
44    Use "toc=tic()" to start the clock and "toc()" to measure
45    a time interval.
46    """
47    then = datetime.datetime.now()
48    return lambda: delay(datetime.datetime.now() - then)
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54
55    Note: this function does not use the sasview package
56    """
57    if hasattr(data, 'qx_data'):
58        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
59        data.mask = (q < radius)
60        if outer is not None:
61            data.mask |= (q >= outer)
62    else:
63        data.mask = (data.x < radius)
64        if outer is not None:
65            data.mask |= (data.x >= outer)
66
67
68def sasview_model(model_definition, **pars):
69    """
70    Load a sasview model given the model name.
71    """
72    # convert model parameters from sasmodel form to sasview form
73    #print("old",sorted(pars.items()))
74    modelname, pars = revert_model(model_definition, pars)
75    #print("new",sorted(pars.items()))
76    sas = __import__('sas.models.'+modelname)
77    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
78    if ModelClass is None:
79        raise ValueError("could not find model %r in sas.models"%modelname)
80    model = ModelClass()
81
82    for k,v in pars.items():
83        parts = k.split('.')  # polydispersity components
84        if len(parts) == 2:
85            model.dispersion[parts[0]][parts[1]] = v
86        else:
87            model.setParam(k, v)
88    return model
89
90def randomize(p, v):
91    """
92    Randomizing parameter.
93
94    Guess the parameter type from name.
95    """
96    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
97        return v
98    elif any(s in p for s in ('theta','phi','psi')):
99        # orientation in [-180,180], orientation pd in [0,45]
100        if p.endswith('_pd'):
101            return 45*np.random.rand()
102        else:
103            return 360*np.random.rand() - 180
104    elif 'sld' in p:
105        # sld in in [-0.5,10]
106        return 10.5*np.random.rand() - 0.5
107    elif p.endswith('_pd'):
108        # length pd in [0,1]
109        return np.random.rand()
110    else:
111        # values from 0 to 2*x for all other parameters
112        return 2*np.random.rand()*(v if v != 0 else 1)
113
114def randomize_model(pars, seed=None):
115    if seed is None:
116        seed = np.random.randint(1e9)
117    np.random.seed(seed)
118    # Note: the sort guarantees order of calls to random number generator
119    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
120
121    return pars, seed
122
123def constrain_pars(model_definition, pars):
124    """
125    Restrict parameters to valid values.
126    """
127    name = model_definition.name
128    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
129        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
130    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
131        pars['radius'],pars['bell_radius'] = pars['bell_radius'],pars['radius']
132
133    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
134    if name == 'guinier':
135        #q_max = 0.2  # mid q maximum
136        q_max = 1.0  # high q maximum
137        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
138        pars['rg'] = min(pars['rg'],rg_max)
139
140def parlist(pars):
141    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
142
143def suppress_pd(pars):
144    """
145    Suppress theta_pd for now until the normalization is resolved.
146
147    May also suppress complete polydispersity of the model to test
148    models more quickly.
149    """
150    pars = pars.copy()
151    for p in pars:
152        if p.endswith("_pd"): pars[p] = 0
153    return pars
154
155def eval_sasview(model_definition, pars, data, Nevals=1):
156    # importing sas here so that the error message will be that sas failed to
157    # import rather than the more obscure smear_selection not imported error
158    import sas
159    from sas.models.qsmearing import smear_selection
160    model = sasview_model(model_definition, **pars)
161    smearer = smear_selection(data, model=model)
162    value = None  # silence the linter
163    toc = tic()
164    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
165        if hasattr(data, 'qx_data'):
166            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
167            index = ((~data.mask) & (~np.isnan(data.data))
168                     & (q >= data.qmin) & (q <= data.qmax))
169            if smearer is not None:
170                smearer.model = model  # because smear_selection has a bug
171                smearer.accuracy = data.accuracy
172                smearer.set_index(index)
173                value = smearer.get_value()
174            else:
175                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
176        else:
177            value = model.evalDistribution(data.x)
178            if smearer is not None:
179                value = smearer(value)
180    average_time = toc()*1000./Nevals
181    return value, average_time
182
183def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1,
184                cutoff=0., fast=False):
185    try:
186        model = core.load_model(model_definition, dtype=dtype,
187                                platform="ocl", fast=fast)
188    except Exception as exc:
189        print(exc)
190        print("... trying again with single precision")
191        model = core.load_model(model_definition, dtype='single',
192                                platform="ocl", fast=fast)
193    calculator = DirectModel(data, model, cutoff=cutoff)
194    # Run the calculator once before starting the timing loop
195    value = calculator(**suppress_pd(pars))
196    # Now run the timing loop
197    toc = tic()
198    for _ in range(max(Nevals, 1)):  # force at least one eval
199        value = calculator(**pars)
200    average_time = toc()*1000./Nevals
201    return value, average_time
202
203
204def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
205    model = core.load_model(model_definition, dtype=dtype, platform="dll")
206    calculator = DirectModel(data, model, cutoff=cutoff)
207    # Run the calculator once before starting the timing loop
208    value = calculator(**suppress_pd(pars))
209    # Now run the timing loop
210    toc = tic()
211    for _ in range(max(Nevals, 1)):  # force at least one eval
212        value = calculator(**pars)
213    average_time = toc()*1000./Nevals
214    return value, average_time
215
216
217def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
218    if is2D:
219        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
220        data.accuracy = accuracy
221        set_beam_stop(data, 0.004)
222        index = ~data.mask
223    else:
224        if view == 'log':
225            qmax = math.log10(qmax)
226            q = np.logspace(qmax-3, qmax, Nq)
227        else:
228            q = np.linspace(0.001*qmax, qmax, Nq)
229        data = empty_data1D(q, resolution=resolution)
230        index = slice(None, None)
231    return data, index
232
233def compare(name, pars, Ncomp, Nbase, opts, set_pars):
234    model_definition = core.load_model_definition(name)
235
236    view = ('linear' if '-linear' in opts
237            else 'log' if '-log' in opts
238            else 'q4' if '-q4' in opts
239            else 'log')
240
241    opt_values = dict(split
242                      for s in opts for split in ((s.split('='),))
243                      if len(split) == 2)
244    # Sort out data
245    qmax = (10.0 if '-exq' in opts
246            else 1.0 if '-highq' in opts
247            else 0.2 if '-midq' in opts
248            else 0.05)
249    Nq = int(opt_values.get('-Nq', '128'))
250    res = float(opt_values.get('-res', '0'))
251    accuracy = opt_values.get('-accuracy', 'Low')
252    is2D = "-2d" in opts
253    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
254
255
256    # modelling accuracy is determined by dtype and cutoff
257    dtype = ('longdouble' if '-quad' in opts
258             else 'double' if '-double' in opts
259             else 'half' if '-half' in opts
260             else 'single')
261    cutoff = float(opt_values.get('-cutoff','1e-5'))
262    fast = "-fast" in opts and dtype is 'single'
263
264    # randomize parameters
265    #pars.update(set_pars)  # set value before random to control range
266    if '-random' in opts or '-random' in opt_values:
267        seed = int(opt_values['-random']) if '-random' in opt_values else None
268        pars, seed = randomize_model(pars, seed=seed)
269        print("Randomize using -random=%i"%seed)
270    pars.update(set_pars)  # set value after random to control value
271    constrain_pars(model_definition, pars)
272    constrain_new_to_old(model_definition, pars)
273
274    # parameter selection
275    if '-mono' in opts:
276        pars = suppress_pd(pars)
277    if '-pars' in opts:
278        print("pars "+str(parlist(pars)))
279
280    # Base calculation
281    if 0:
282        from sasmodels.models import sphere as target
283        base_name = target.name
284        base, base_time = eval_ctypes(target, pars, data,
285                dtype='longdouble', cutoff=0., Nevals=Ncomp)
286    elif Nbase > 0 and "-ctypes" in opts and "-sasview" in opts:
287        try:
288            base, base_time = eval_sasview(model_definition, pars, data, Ncomp)
289            base_name = "sasview"
290            #print("base/sasview", (base-pars['background'])/(comp-pars['background']))
291            print("sasview t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))
292            #print("sasview",comp)
293        except ImportError:
294            traceback.print_exc()
295            Nbase = 0
296    elif Nbase > 0:
297        base, base_time = eval_opencl(model_definition, pars, data,
298                dtype=dtype, cutoff=cutoff, Nevals=Nbase, fast=fast)
299        base_name = "ocl"
300        print("opencl t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))
301        #print("base " + base)
302        #print(max(base), min(base))
303
304    # Comparison calculation
305    if Ncomp > 0 and "-ctypes" in opts:
306        comp, comp_time = eval_ctypes(model_definition, pars, data,
307                dtype=dtype, cutoff=cutoff, Nevals=Ncomp)
308        comp_name = "ctypes"
309        print("ctypes t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
310    elif Ncomp > 0:
311        try:
312            comp, comp_time = eval_sasview(model_definition, pars, data, Ncomp)
313            comp_name = "sasview"
314            #print("base/sasview", (base-pars['background'])/(comp-pars['background']))
315            print("sasview t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
316            #print("sasview",comp)
317        except ImportError:
318            traceback.print_exc()
319            Ncomp = 0
320
321    # Compare, but only if computing both forms
322    if Nbase > 0 and Ncomp > 0:
323        #print("speedup %.2g"%(comp_time/base_time))
324        #print("max |base/comp|", max(abs(base/comp)), "%.15g"%max(abs(base)), "%.15g"%max(abs(comp)))
325        #comp *= max(base/comp)
326        resid = (base - comp)
327        relerr = resid/comp
328        #bad = (relerr>1e-4)
329        #print(relerr[bad],comp[bad],base[bad],data.qx_data[bad],data.qy_data[bad])
330        _print_stats("|%s-%s|"%(base_name,comp_name)+(" "*(3+len(comp_name))), resid)
331        _print_stats("|(%s-%s)/%s|"%(base_name,comp_name,comp_name), relerr)
332
333    # Plot if requested
334    if '-noplot' in opts: return
335    import matplotlib.pyplot as plt
336    if Ncomp > 0:
337        if Nbase > 0: plt.subplot(131)
338        plot_theory(data, comp, view=view, plot_data=False)
339        plt.title("%s t=%.1f ms"%(comp_name,comp_time))
340        #cbar_title = "log I"
341    if Nbase > 0:
342        if Ncomp > 0: plt.subplot(132)
343        plot_theory(data, base, view=view, plot_data=False)
344        plt.title("%s t=%.1f ms"%(base_name,base_time))
345        #cbar_title = "log I"
346    if Ncomp > 0 and Nbase > 0:
347        plt.subplot(133)
348        if '-abs' in opts:
349            err,errstr,errview = resid, "abs err", "linear"
350        else:
351            err,errstr,errview = abs(relerr), "rel err", "log"
352        #err,errstr = base/comp,"ratio"
353        plot_theory(data, None, resid=err, view=errview, plot_data=False)
354        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
355        #cbar_title = errstr if errview=="linear" else "log "+errstr
356    #if is2D:
357    #    h = plt.colorbar()
358    #    h.ax.set_title(cbar_title)
359
360    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
361        plt.figure()
362        v = relerr
363        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
364        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
365        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
366        plt.ylabel('P(err)')
367        plt.title('Comparison of single and double precision models for %s'%name)
368
369    plt.show()
370
371def _print_stats(label, err):
372    sorted_err = np.sort(abs(err))
373    p50 = int((len(err)-1)*0.50)
374    p98 = int((len(err)-1)*0.98)
375    data = [
376        "max:%.3e"%sorted_err[-1],
377        "median:%.3e"%sorted_err[p50],
378        "98%%:%.3e"%sorted_err[p98],
379        "rms:%.3e"%np.sqrt(np.mean(err**2)),
380        "zero-offset:%+.3e"%np.mean(err),
381        ]
382    print(label+"  ".join(data))
383
384
385
386# ===========================================================================
387#
388USAGE="""
389usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
390
391Compare the speed and value for a model between the SasView original and the
392OpenCL rewrite.
393
394model is the name of the model to compare (see below).
395Nopencl is the number of times to run the OpenCL model (default=5)
396Nsasview is the number of times to run the Sasview model (default=1)
397
398Options (* for default):
399
400    -plot*/-noplot plots or suppress the plot of the model
401    -half/-single*/-double/-quad/-fast sets the calculation precision
402    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
403    -Nq=128 sets the number of Q points in the data set
404    -1d*/-2d computes 1d or 2d data
405    -preset*/-random[=seed] preset or random parameters
406    -mono/-poly* force monodisperse/polydisperse
407    -ctypes/-sasview* selects gpu:cpu, gpu:sasview, or sasview:cpu if both
408    -cutoff=1e-5* cutoff value for including a point in polydispersity
409    -pars/-nopars* prints the parameter set or not
410    -abs/-rel* plot relative or absolute error
411    -linear/-log/-q4 intensity scaling
412    -hist/-nohist* plot histogram of relative error
413    -res=0 sets the resolution width dQ/Q if calculating with resolution
414    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
415
416Key=value pairs allow you to set specific values to any of the model
417parameters.
418
419Available models:
420"""
421
422
423NAME_OPTIONS = set([
424    'plot', 'noplot',
425    'half', 'single', 'double', 'quad', 'fast',
426    'lowq', 'midq', 'highq', 'exq',
427    '2d', '1d',
428    'preset', 'random',
429    'poly', 'mono',
430    'sasview', 'ctypes',
431    'nopars', 'pars',
432    'rel', 'abs',
433    'linear', 'log', 'q4',
434    'hist', 'nohist',
435    ])
436VALUE_OPTIONS = [
437    # Note: random is both a name option and a value option
438    'cutoff', 'random', 'Nq', 'res', 'accuracy',
439    ]
440
441def columnize(L, indent="", width=79):
442    column_width = max(len(w) for w in L) + 1
443    num_columns = (width - len(indent)) // column_width
444    num_rows = len(L) // num_columns
445    L = L + [""] * (num_rows*num_columns - len(L))
446    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
447    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
448             for row in zip(*columns)]
449    output = indent + ("\n"+indent).join(lines)
450    return output
451
452
453def get_demo_pars(model_definition):
454    info = generate.make_info(model_definition)
455    pars = dict((p[0],p[2]) for p in info['parameters'])
456    pars.update(info['demo'])
457    return pars
458
459def main():
460    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
461    popts = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' in arg]
462    args = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' not in arg]
463    models = "\n    ".join("%-15s"%v for v in MODELS)
464    if len(args) == 0:
465        print(USAGE)
466        print(columnize(MODELS, indent="  "))
467        sys.exit(1)
468    if args[0] not in MODELS:
469        print("Model %r not available. Use one of:\n    %s"%(args[0],models))
470        sys.exit(1)
471    if len(args) > 3:
472        print("expected parameters: model Nopencl Nsasview")
473
474    invalid = [o[1:] for o in opts
475               if o[1:] not in NAME_OPTIONS
476                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
477    if invalid:
478        print("Invalid options: %s"%(", ".join(invalid)))
479        sys.exit(1)
480
481    # Get demo parameters from model definition, or use default parameters
482    # if model does not define demo parameters
483    name = args[0]
484    model_definition = core.load_model_definition(name)
485    pars = get_demo_pars(model_definition)
486
487    Ncomp = int(args[1]) if len(args) > 1 else 5
488    Nbase = int(args[2]) if len(args) > 2 else 1
489
490    # Fill in default polydispersity parameters
491    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
492    for p in pds:
493        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
494        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
495
496    # Fill in parameters given on the command line
497    set_pars = {}
498    for arg in popts:
499        k,v = arg.split('=',1)
500        if k not in pars:
501            # extract base name without distribution
502            s = set(p.split('_pd')[0] for p in pars)
503            print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s))))
504            sys.exit(1)
505        set_pars[k] = float(v) if not v.endswith('type') else v
506
507    compare(name, pars, Ncomp, Nbase, opts, set_pars)
508
509if __name__ == "__main__":
510    main()
Note: See TracBrowser for help on using the repository browser.