source: sasmodels/sasmodels/compare.py @ 2bebe2b

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 2bebe2b was 2bebe2b, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

slightly cleaner error message from compare if sasview is not available on python path

  • Property mode set to 100755
File size: 16.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9import traceback
10
11import numpy as np
12
13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
17from . import core
18from . import kerneldll
19from . import generate
20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
24
25# List of available models
26MODELS = [basename(f)[:-3]
27          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
28
29# CRUFT python 2.6
30if not hasattr(datetime.timedelta, 'total_seconds'):
31    def delay(dt):
32        """Return number date-time delta as number seconds"""
33        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
34else:
35    def delay(dt):
36        """Return number date-time delta as number seconds"""
37        return dt.total_seconds()
38
39
40def tic():
41    """
42    Timer function.
43
44    Use "toc=tic()" to start the clock and "toc()" to measure
45    a time interval.
46    """
47    then = datetime.datetime.now()
48    return lambda: delay(datetime.datetime.now() - then)
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54
55    Note: this function does not use the sasview package
56    """
57    if hasattr(data, 'qx_data'):
58        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
59        data.mask = (q < radius)
60        if outer is not None:
61            data.mask |= (q >= outer)
62    else:
63        data.mask = (data.x < radius)
64        if outer is not None:
65            data.mask |= (data.x >= outer)
66
67
68def sasview_model(model_definition, **pars):
69    """
70    Load a sasview model given the model name.
71    """
72    # convert model parameters from sasmodel form to sasview form
73    #print "old",sorted(pars.items())
74    modelname, pars = revert_model(model_definition, pars)
75    #print "new",sorted(pars.items())
76    sas = __import__('sas.models.'+modelname)
77    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
78    if ModelClass is None:
79        raise ValueError("could not find model %r in sas.models"%modelname)
80    model = ModelClass()
81
82    for k,v in pars.items():
83        if k.endswith("_pd"):
84            model.dispersion[k[:-3]]['width'] = v
85        elif k.endswith("_pd_n"):
86            model.dispersion[k[:-5]]['npts'] = v
87        elif k.endswith("_pd_nsigma"):
88            model.dispersion[k[:-10]]['nsigmas'] = v
89        elif k.endswith("_pd_type"):
90            model.dispersion[k[:-8]]['type'] = v
91        else:
92            model.setParam(k, v)
93    return model
94
95def randomize(p, v):
96    """
97    Randomizing parameter.
98
99    Guess the parameter type from name.
100    """
101    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
102        return v
103    elif any(s in p for s in ('theta','phi','psi')):
104        # orientation in [-180,180], orientation pd in [0,45]
105        if p.endswith('_pd'):
106            return 45*np.random.rand()
107        else:
108            return 360*np.random.rand() - 180
109    elif 'sld' in p:
110        # sld in in [-0.5,10]
111        return 10.5*np.random.rand() - 0.5
112    elif p.endswith('_pd'):
113        # length pd in [0,1]
114        return np.random.rand()
115    else:
116        # values from 0 to 2*x for all other parameters
117        return 2*np.random.rand()*(v if v != 0 else 1)
118
119def randomize_model(pars, seed=None):
120    if seed is None:
121        seed = np.random.randint(1e9)
122    np.random.seed(seed)
123    # Note: the sort guarantees order of calls to random number generator
124    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
125
126    return pars, seed
127
128def constrain_pars(model_definition, pars):
129    name = model_definition.name
130    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
131        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
132    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
133        pars['radius'],pars['bell_radius'] = pars['bell_radius'],pars['radius']
134
135    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
136    if name == 'guinier':
137        #q_max = 0.2  # mid q maximum
138        q_max = 1.0  # high q maximum
139        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
140        pars['rg'] = min(pars['rg'],rg_max)
141
142    # These constraints are only needed for comparison to sasview
143    if name in ('teubner_strey','broad_peak'):
144        del pars['scale']
145    if name in ('guinier',):
146        del pars['background']
147    if getattr(model_definition, 'category', None) == 'structure-factor':
148        del pars['scale'], pars['background']
149
150
151def parlist(pars):
152    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
153
154def suppress_pd(pars):
155    """
156    Suppress theta_pd for now until the normalization is resolved.
157
158    May also suppress complete polydispersity of the model to test
159    models more quickly.
160    """
161    for p in pars:
162        if p.endswith("_pd"): pars[p] = 0
163
164def eval_sasview(model_definition, pars, data, Nevals=1):
165    import sas
166    from sas.models.qsmearing import smear_selection
167    model = sasview_model(model_definition, **pars)
168    smearer = smear_selection(data, model=model)
169    value = None  # silence the linter
170    toc = tic()
171    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
172        if hasattr(data, 'qx_data'):
173            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
174            index = ((~data.mask) & (~np.isnan(data.data))
175                     & (q >= data.qmin) & (q <= data.qmax))
176            if smearer is not None:
177                smearer.model = model  # because smear_selection has a bug
178                smearer.accuracy = data.accuracy
179                smearer.set_index(index)
180                value = smearer.get_value()
181            else:
182                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
183        else:
184            value = model.evalDistribution(data.x)
185            if smearer is not None:
186                value = smearer(value)
187    average_time = toc()*1000./Nevals
188    return value, average_time
189
190def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
191    try:
192        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
193    except Exception,exc:
194        print exc
195        print "... trying again with single precision"
196        model = core.load_model(model_definition, dtype='single', platform="ocl")
197    calculator = DirectModel(data, model, cutoff=cutoff)
198    value = None  # silence the linter
199    toc = tic()
200    for _ in range(max(Nevals, 1)):  # force at least one eval
201        value = calculator(**pars)
202    average_time = toc()*1000./Nevals
203    return value, average_time
204
205
206def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
207    model = core.load_model(model_definition, dtype=dtype, platform="dll")
208    calculator = DirectModel(data, model, cutoff=cutoff)
209    value = None  # silence the linter
210    toc = tic()
211    for _ in range(max(Nevals, 1)):  # force at least one eval
212        value = calculator(**pars)
213    average_time = toc()*1000./Nevals
214    return value, average_time
215
216
217def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
218    if is2D:
219        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
220        data.accuracy = accuracy
221        set_beam_stop(data, 0.004)
222        index = ~data.mask
223    else:
224        if view == 'log':
225            qmax = math.log10(qmax)
226            q = np.logspace(qmax-3, qmax, Nq)
227        else:
228            q = np.linspace(0.001*qmax, qmax, Nq)
229        data = empty_data1D(q, resolution=resolution)
230        index = slice(None, None)
231    return data, index
232
233def compare(name, pars, Ncpu, Nocl, opts, set_pars):
234    model_definition = core.load_model_definition(name)
235
236    view = ('linear' if '-linear' in opts
237            else 'log' if '-log' in opts
238            else 'q4' if '-q4' in opts
239            else 'log')
240
241    opt_values = dict(split
242                      for s in opts for split in ((s.split('='),))
243                      if len(split) == 2)
244    # Sort out data
245    qmax = (10.0 if '-exq' in opts
246            else 1.0 if '-highq' in opts
247            else 0.2 if '-midq' in opts
248            else 0.05)
249    Nq = int(opt_values.get('-Nq', '128'))
250    res = float(opt_values.get('-res', '0'))
251    accuracy = opt_values.get('-accuracy', 'Low')
252    is2D = "-2d" in opts
253    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
254
255
256    # modelling accuracy is determined by dtype and cutoff
257    dtype = ('longdouble' if '-quad' in opts
258             else 'double' if '-double' in opts
259             else 'single')
260    cutoff = float(opt_values.get('-cutoff','1e-5'))
261
262    # randomize parameters
263    #pars.update(set_pars)  # set value before random to control range
264    if '-random' in opts or '-random' in opt_values:
265        seed = int(opt_values['-random']) if '-random' in opt_values else None
266        pars, seed = randomize_model(pars, seed=seed)
267        print "Randomize using -random=%i"%seed
268    pars.update(set_pars)  # set value after random to control value
269    constrain_pars(model_definition, pars)
270
271    # parameter selection
272    if '-mono' in opts:
273        suppress_pd(pars)
274    if '-pars' in opts:
275        print "pars",parlist(pars)
276
277    # OpenCl calculation
278    if Nocl > 0:
279        ocl, ocl_time = eval_opencl(model_definition, pars, data,
280                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
281        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
282        #print "ocl", ocl
283        #print max(ocl), min(ocl)
284
285    # ctypes/sasview calculation
286    if Ncpu > 0 and "-ctypes" in opts:
287        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
288                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
289        comp = "ctypes"
290        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
291    elif Ncpu > 0:
292        try:
293            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
294            comp = "sasview"
295            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background'])
296            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
297            #print "sasview",cpu
298        except ImportError:
299            traceback.print_exc()
300            Ncpu = 0
301
302    # Compare, but only if computing both forms
303    if Nocl > 0 and Ncpu > 0:
304        #print "speedup %.2g"%(cpu_time/ocl_time)
305        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
306        #cpu *= max(ocl/cpu)
307        resid = (ocl - cpu)
308        relerr = resid/cpu
309        #bad = (relerr>1e-4)
310        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
311        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
312        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
313
314    # Plot if requested
315    if '-noplot' in opts: return
316    import matplotlib.pyplot as plt
317    if Ncpu > 0:
318        if Nocl > 0: plt.subplot(131)
319        plot_theory(data, cpu, view=view, plot_data=False)
320        plt.title("%s t=%.1f ms"%(comp,cpu_time))
321        #cbar_title = "log I"
322    if Nocl > 0:
323        if Ncpu > 0: plt.subplot(132)
324        plot_theory(data, ocl, view=view, plot_data=False)
325        plt.title("opencl t=%.1f ms"%ocl_time)
326        #cbar_title = "log I"
327    if Ncpu > 0 and Nocl > 0:
328        plt.subplot(133)
329        if '-abs' in opts:
330            err,errstr,errview = resid, "abs err", "linear"
331        else:
332            err,errstr,errview = abs(relerr), "rel err", "log"
333        #err,errstr = ocl/cpu,"ratio"
334        plot_theory(data, None, resid=err, view=errview, plot_data=False)
335        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
336        #cbar_title = errstr if errview=="linear" else "log "+errstr
337    #if is2D:
338    #    h = plt.colorbar()
339    #    h.ax.set_title(cbar_title)
340
341    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
342        plt.figure()
343        v = relerr
344        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
345        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
346        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
347        plt.ylabel('P(err)')
348        plt.title('Comparison of single and double precision models for %s'%name)
349
350    plt.show()
351
352def _print_stats(label, err):
353    sorted_err = np.sort(abs(err))
354    p50 = int((len(err)-1)*0.50)
355    p98 = int((len(err)-1)*0.98)
356    data = [
357        "max:%.3e"%sorted_err[-1],
358        "median:%.3e"%sorted_err[p50],
359        "98%%:%.3e"%sorted_err[p98],
360        "rms:%.3e"%np.sqrt(np.mean(err**2)),
361        "zero-offset:%+.3e"%np.mean(err),
362        ]
363    print label,"  ".join(data)
364
365
366
367# ===========================================================================
368#
369USAGE="""
370usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
371
372Compare the speed and value for a model between the SasView original and the
373OpenCL rewrite.
374
375model is the name of the model to compare (see below).
376Nopencl is the number of times to run the OpenCL model (default=5)
377Nsasview is the number of times to run the Sasview model (default=1)
378
379Options (* for default):
380
381    -plot*/-noplot plots or suppress the plot of the model
382    -single*/-double/-quad use single/double/quad precision for comparison
383    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
384    -Nq=128 sets the number of Q points in the data set
385    -1d*/-2d computes 1d or 2d data
386    -preset*/-random[=seed] preset or random parameters
387    -mono/-poly* force monodisperse/polydisperse
388    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
389    -cutoff=1e-5* cutoff value for including a point in polydispersity
390    -pars/-nopars* prints the parameter set or not
391    -abs/-rel* plot relative or absolute error
392    -linear/-log/-q4 intensity scaling
393    -hist/-nohist* plot histogram of relative error
394    -res=0 sets the resolution width dQ/Q if calculating with resolution
395    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
396
397Key=value pairs allow you to set specific values to any of the model
398parameters.
399
400Available models:
401"""
402
403
404NAME_OPTIONS = set([
405    'plot','noplot',
406    'single','double','longdouble',
407    'lowq','midq','highq','exq',
408    '2d','1d',
409    'preset','random',
410    'poly','mono',
411    'sasview','ctypes',
412    'nopars','pars',
413    'rel','abs',
414    'linear', 'log', 'q4',
415    'hist','nohist',
416    ])
417VALUE_OPTIONS = [
418    # Note: random is both a name option and a value option
419    'cutoff', 'random', 'Nq', 'res', 'accuracy',
420    ]
421
422def columnize(L, indent="", width=79):
423    column_width = max(len(w) for w in L) + 1
424    num_columns = (width - len(indent)) // column_width
425    num_rows = len(L) // num_columns
426    L = L + [""] * (num_rows*num_columns - len(L))
427    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
428    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
429             for row in zip(*columns)]
430    output = indent + ("\n"+indent).join(lines)
431    return output
432
433
434def get_demo_pars(model_definition):
435    info = generate.make_info(model_definition)
436    pars = dict((p[0],p[2]) for p in info['parameters'])
437    pars.update(info['demo'])
438    return pars
439
440def main():
441    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
442    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
443    models = "\n    ".join("%-15s"%v for v in MODELS)
444    if len(args) == 0:
445        print(USAGE)
446        print(columnize(MODELS, indent="  "))
447        sys.exit(1)
448    if args[0] not in MODELS:
449        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
450        sys.exit(1)
451
452    invalid = [o[1:] for o in opts
453               if o[1:] not in NAME_OPTIONS
454                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
455    if invalid:
456        print "Invalid options: %s"%(", ".join(invalid))
457        sys.exit(1)
458
459    # Get demo parameters from model definition, or use default parameters
460    # if model does not define demo parameters
461    name = args[0]
462    model_definition = core.load_model_definition(name)
463    pars = get_demo_pars(model_definition)
464
465    Nopencl = int(args[1]) if len(args) > 1 else 5
466    Nsasview = int(args[2]) if len(args) > 2 else 1
467
468    # Fill in default polydispersity parameters
469    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
470    for p in pds:
471        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
472        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
473
474    # Fill in parameters given on the command line
475    set_pars = {}
476    for arg in args[3:]:
477        k,v = arg.split('=')
478        if k not in pars:
479            # extract base name without distribution
480            s = set(p.split('_pd')[0] for p in pars)
481            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
482            sys.exit(1)
483        set_pars[k] = float(v) if not v.endswith('type') else v
484
485    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
486
487if __name__ == "__main__":
488    main()
Note: See TracBrowser for help on using the repository browser.