source: sasmodels/sasmodels/compare.py @ dc056b9

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since dc056b9 was dc056b9, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

…and leave a comment about why the unused import should be kept in the code

  • Property mode set to 100755
File size: 17.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9import traceback
10
11import numpy as np
12
13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
17from . import core
18from . import kerneldll
19from . import generate
20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
24
25# List of available models
26MODELS = [basename(f)[:-3]
27          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
28
29# CRUFT python 2.6
30if not hasattr(datetime.timedelta, 'total_seconds'):
31    def delay(dt):
32        """Return number date-time delta as number seconds"""
33        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
34else:
35    def delay(dt):
36        """Return number date-time delta as number seconds"""
37        return dt.total_seconds()
38
39
40def tic():
41    """
42    Timer function.
43
44    Use "toc=tic()" to start the clock and "toc()" to measure
45    a time interval.
46    """
47    then = datetime.datetime.now()
48    return lambda: delay(datetime.datetime.now() - then)
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54
55    Note: this function does not use the sasview package
56    """
57    if hasattr(data, 'qx_data'):
58        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
59        data.mask = (q < radius)
60        if outer is not None:
61            data.mask |= (q >= outer)
62    else:
63        data.mask = (data.x < radius)
64        if outer is not None:
65            data.mask |= (data.x >= outer)
66
67
68def sasview_model(model_definition, **pars):
69    """
70    Load a sasview model given the model name.
71    """
72    # convert model parameters from sasmodel form to sasview form
73    #print "old",sorted(pars.items())
74    modelname, pars = revert_model(model_definition, pars)
75    #print "new",sorted(pars.items())
76    sas = __import__('sas.models.'+modelname)
77    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
78    if ModelClass is None:
79        raise ValueError("could not find model %r in sas.models"%modelname)
80    model = ModelClass()
81
82    for k,v in pars.items():
83        if k.endswith("_pd"):
84            model.dispersion[k[:-3]]['width'] = v
85        elif k.endswith("_pd_n"):
86            model.dispersion[k[:-5]]['npts'] = v
87        elif k.endswith("_pd_nsigma"):
88            model.dispersion[k[:-10]]['nsigmas'] = v
89        elif k.endswith("_pd_type"):
90            model.dispersion[k[:-8]]['type'] = v
91        else:
92            model.setParam(k, v)
93    return model
94
95def randomize(p, v):
96    """
97    Randomizing parameter.
98
99    Guess the parameter type from name.
100    """
101    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
102        return v
103    elif any(s in p for s in ('theta','phi','psi')):
104        # orientation in [-180,180], orientation pd in [0,45]
105        if p.endswith('_pd'):
106            return 45*np.random.rand()
107        else:
108            return 360*np.random.rand() - 180
109    elif 'sld' in p:
110        # sld in in [-0.5,10]
111        return 10.5*np.random.rand() - 0.5
112    elif p.endswith('_pd'):
113        # length pd in [0,1]
114        return np.random.rand()
115    else:
116        # values from 0 to 2*x for all other parameters
117        return 2*np.random.rand()*(v if v != 0 else 1)
118
119def randomize_model(pars, seed=None):
120    if seed is None:
121        seed = np.random.randint(1e9)
122    np.random.seed(seed)
123    # Note: the sort guarantees order of calls to random number generator
124    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
125
126    return pars, seed
127
128def constrain_pars(model_definition, pars):
129    name = model_definition.name
130    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
131        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
132    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
133        pars['radius'],pars['bell_radius'] = pars['bell_radius'],pars['radius']
134
135    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
136    if name == 'guinier':
137        #q_max = 0.2  # mid q maximum
138        q_max = 1.0  # high q maximum
139        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
140        pars['rg'] = min(pars['rg'],rg_max)
141
142    # These constraints are only needed for comparison to sasview
143    if name in ('teubner_strey','broad_peak'):
144        del pars['scale']
145    if name in ('guinier',):
146        del pars['background']
147    if getattr(model_definition, 'category', None) == 'structure-factor':
148        del pars['scale'], pars['background']
149
150
151def parlist(pars):
152    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
153
154def suppress_pd(pars):
155    """
156    Suppress theta_pd for now until the normalization is resolved.
157
158    May also suppress complete polydispersity of the model to test
159    models more quickly.
160    """
161    for p in pars:
162        if p.endswith("_pd"): pars[p] = 0
163
164def eval_sasview(model_definition, pars, data, Nevals=1):
165    # importing sas here so that the error message will be that sas failed to
166    # import rather than the more obscure smear_selection not imported error
167    import sas
168    from sas.models.qsmearing import smear_selection
169    model = sasview_model(model_definition, **pars)
170    smearer = smear_selection(data, model=model)
171    value = None  # silence the linter
172    toc = tic()
173    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
174        if hasattr(data, 'qx_data'):
175            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
176            index = ((~data.mask) & (~np.isnan(data.data))
177                     & (q >= data.qmin) & (q <= data.qmax))
178            if smearer is not None:
179                smearer.model = model  # because smear_selection has a bug
180                smearer.accuracy = data.accuracy
181                smearer.set_index(index)
182                value = smearer.get_value()
183            else:
184                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
185        else:
186            value = model.evalDistribution(data.x)
187            if smearer is not None:
188                value = smearer(value)
189    average_time = toc()*1000./Nevals
190    return value, average_time
191
192def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
193    try:
194        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
195    except Exception,exc:
196        print exc
197        print "... trying again with single precision"
198        model = core.load_model(model_definition, dtype='single', platform="ocl")
199    calculator = DirectModel(data, model, cutoff=cutoff)
200    value = None  # silence the linter
201    toc = tic()
202    for _ in range(max(Nevals, 1)):  # force at least one eval
203        value = calculator(**pars)
204    average_time = toc()*1000./Nevals
205    return value, average_time
206
207
208def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
209    model = core.load_model(model_definition, dtype=dtype, platform="dll")
210    calculator = DirectModel(data, model, cutoff=cutoff)
211    value = None  # silence the linter
212    toc = tic()
213    for _ in range(max(Nevals, 1)):  # force at least one eval
214        value = calculator(**pars)
215    average_time = toc()*1000./Nevals
216    return value, average_time
217
218
219def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
220    if is2D:
221        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
222        data.accuracy = accuracy
223        set_beam_stop(data, 0.004)
224        index = ~data.mask
225    else:
226        if view == 'log':
227            qmax = math.log10(qmax)
228            q = np.logspace(qmax-3, qmax, Nq)
229        else:
230            q = np.linspace(0.001*qmax, qmax, Nq)
231        data = empty_data1D(q, resolution=resolution)
232        index = slice(None, None)
233    return data, index
234
235def compare(name, pars, Ncpu, Nocl, opts, set_pars):
236    model_definition = core.load_model_definition(name)
237
238    view = ('linear' if '-linear' in opts
239            else 'log' if '-log' in opts
240            else 'q4' if '-q4' in opts
241            else 'log')
242
243    opt_values = dict(split
244                      for s in opts for split in ((s.split('='),))
245                      if len(split) == 2)
246    # Sort out data
247    qmax = (10.0 if '-exq' in opts
248            else 1.0 if '-highq' in opts
249            else 0.2 if '-midq' in opts
250            else 0.05)
251    Nq = int(opt_values.get('-Nq', '128'))
252    res = float(opt_values.get('-res', '0'))
253    accuracy = opt_values.get('-accuracy', 'Low')
254    is2D = "-2d" in opts
255    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
256
257
258    # modelling accuracy is determined by dtype and cutoff
259    dtype = ('longdouble' if '-quad' in opts
260             else 'double' if '-double' in opts
261             else 'single')
262    cutoff = float(opt_values.get('-cutoff','1e-5'))
263
264    # randomize parameters
265    #pars.update(set_pars)  # set value before random to control range
266    if '-random' in opts or '-random' in opt_values:
267        seed = int(opt_values['-random']) if '-random' in opt_values else None
268        pars, seed = randomize_model(pars, seed=seed)
269        print "Randomize using -random=%i"%seed
270    pars.update(set_pars)  # set value after random to control value
271    constrain_pars(model_definition, pars)
272
273    # parameter selection
274    if '-mono' in opts:
275        suppress_pd(pars)
276    if '-pars' in opts:
277        print "pars",parlist(pars)
278
279    # OpenCl calculation
280    if Nocl > 0:
281        ocl, ocl_time = eval_opencl(model_definition, pars, data,
282                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
283        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
284        #print "ocl", ocl
285        #print max(ocl), min(ocl)
286
287    # ctypes/sasview calculation
288    if Ncpu > 0 and "-ctypes" in opts:
289        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
290                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
291        comp = "ctypes"
292        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
293    elif Ncpu > 0:
294        try:
295            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
296            comp = "sasview"
297            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background'])
298            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
299            #print "sasview",cpu
300        except ImportError:
301            traceback.print_exc()
302            Ncpu = 0
303
304    # Compare, but only if computing both forms
305    if Nocl > 0 and Ncpu > 0:
306        #print "speedup %.2g"%(cpu_time/ocl_time)
307        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
308        #cpu *= max(ocl/cpu)
309        resid = (ocl - cpu)
310        relerr = resid/cpu
311        #bad = (relerr>1e-4)
312        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
313        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
314        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
315
316    # Plot if requested
317    if '-noplot' in opts: return
318    import matplotlib.pyplot as plt
319    if Ncpu > 0:
320        if Nocl > 0: plt.subplot(131)
321        plot_theory(data, cpu, view=view, plot_data=False)
322        plt.title("%s t=%.1f ms"%(comp,cpu_time))
323        #cbar_title = "log I"
324    if Nocl > 0:
325        if Ncpu > 0: plt.subplot(132)
326        plot_theory(data, ocl, view=view, plot_data=False)
327        plt.title("opencl t=%.1f ms"%ocl_time)
328        #cbar_title = "log I"
329    if Ncpu > 0 and Nocl > 0:
330        plt.subplot(133)
331        if '-abs' in opts:
332            err,errstr,errview = resid, "abs err", "linear"
333        else:
334            err,errstr,errview = abs(relerr), "rel err", "log"
335        #err,errstr = ocl/cpu,"ratio"
336        plot_theory(data, None, resid=err, view=errview, plot_data=False)
337        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
338        #cbar_title = errstr if errview=="linear" else "log "+errstr
339    #if is2D:
340    #    h = plt.colorbar()
341    #    h.ax.set_title(cbar_title)
342
343    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
344        plt.figure()
345        v = relerr
346        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
347        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
348        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
349        plt.ylabel('P(err)')
350        plt.title('Comparison of single and double precision models for %s'%name)
351
352    plt.show()
353
354def _print_stats(label, err):
355    sorted_err = np.sort(abs(err))
356    p50 = int((len(err)-1)*0.50)
357    p98 = int((len(err)-1)*0.98)
358    data = [
359        "max:%.3e"%sorted_err[-1],
360        "median:%.3e"%sorted_err[p50],
361        "98%%:%.3e"%sorted_err[p98],
362        "rms:%.3e"%np.sqrt(np.mean(err**2)),
363        "zero-offset:%+.3e"%np.mean(err),
364        ]
365    print label,"  ".join(data)
366
367
368
369# ===========================================================================
370#
371USAGE="""
372usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
373
374Compare the speed and value for a model between the SasView original and the
375OpenCL rewrite.
376
377model is the name of the model to compare (see below).
378Nopencl is the number of times to run the OpenCL model (default=5)
379Nsasview is the number of times to run the Sasview model (default=1)
380
381Options (* for default):
382
383    -plot*/-noplot plots or suppress the plot of the model
384    -single*/-double/-quad use single/double/quad precision for comparison
385    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
386    -Nq=128 sets the number of Q points in the data set
387    -1d*/-2d computes 1d or 2d data
388    -preset*/-random[=seed] preset or random parameters
389    -mono/-poly* force monodisperse/polydisperse
390    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
391    -cutoff=1e-5* cutoff value for including a point in polydispersity
392    -pars/-nopars* prints the parameter set or not
393    -abs/-rel* plot relative or absolute error
394    -linear/-log/-q4 intensity scaling
395    -hist/-nohist* plot histogram of relative error
396    -res=0 sets the resolution width dQ/Q if calculating with resolution
397    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
398
399Key=value pairs allow you to set specific values to any of the model
400parameters.
401
402Available models:
403"""
404
405
406NAME_OPTIONS = set([
407    'plot','noplot',
408    'single','double','longdouble',
409    'lowq','midq','highq','exq',
410    '2d','1d',
411    'preset','random',
412    'poly','mono',
413    'sasview','ctypes',
414    'nopars','pars',
415    'rel','abs',
416    'linear', 'log', 'q4',
417    'hist','nohist',
418    ])
419VALUE_OPTIONS = [
420    # Note: random is both a name option and a value option
421    'cutoff', 'random', 'Nq', 'res', 'accuracy',
422    ]
423
424def columnize(L, indent="", width=79):
425    column_width = max(len(w) for w in L) + 1
426    num_columns = (width - len(indent)) // column_width
427    num_rows = len(L) // num_columns
428    L = L + [""] * (num_rows*num_columns - len(L))
429    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
430    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
431             for row in zip(*columns)]
432    output = indent + ("\n"+indent).join(lines)
433    return output
434
435
436def get_demo_pars(model_definition):
437    info = generate.make_info(model_definition)
438    pars = dict((p[0],p[2]) for p in info['parameters'])
439    pars.update(info['demo'])
440    return pars
441
442def main():
443    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
444    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
445    models = "\n    ".join("%-15s"%v for v in MODELS)
446    if len(args) == 0:
447        print(USAGE)
448        print(columnize(MODELS, indent="  "))
449        sys.exit(1)
450    if args[0] not in MODELS:
451        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
452        sys.exit(1)
453
454    invalid = [o[1:] for o in opts
455               if o[1:] not in NAME_OPTIONS
456                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
457    if invalid:
458        print "Invalid options: %s"%(", ".join(invalid))
459        sys.exit(1)
460
461    # Get demo parameters from model definition, or use default parameters
462    # if model does not define demo parameters
463    name = args[0]
464    model_definition = core.load_model_definition(name)
465    pars = get_demo_pars(model_definition)
466
467    Nopencl = int(args[1]) if len(args) > 1 else 5
468    Nsasview = int(args[2]) if len(args) > 2 else 1
469
470    # Fill in default polydispersity parameters
471    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
472    for p in pds:
473        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
474        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
475
476    # Fill in parameters given on the command line
477    set_pars = {}
478    for arg in args[3:]:
479        k,v = arg.split('=')
480        if k not in pars:
481            # extract base name without distribution
482            s = set(p.split('_pd')[0] for p in pars)
483            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
484            sys.exit(1)
485        set_pars[k] = float(v) if not v.endswith('type') else v
486
487    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
488
489if __name__ == "__main__":
490    main()
Note: See TracBrowser for help on using the repository browser.