source: sasmodels/sasmodels/compare.py @ 5d316e9

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 5d316e9 was 5d316e9, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

support fast and loose single precision and half precision

  • Property mode set to 100755
File size: 18.2 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[7cf2cfd]8import datetime
[5753e4e]9import traceback
[87985ca]10
[1726b21]11import numpy as np
[473183c]12
# Directory holding this script.  It is placed at the very front of the
# module search path so that it is searched before anything else.
ROOT = dirname(__file__)
sys.path[0:0] = [ROOT]
15
16
[e922c5d]17from . import core
18from . import kerneldll
[cd3dba0]19from . import generate
[e922c5d]20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
[750ffa5]23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]24
# Names of the available models: each models/[a-zA-Z]*.py file below ROOT,
# in sorted order, with the trailing ".py" removed.
MODELS = [
    basename(model_path)[:-3]
    for model_path in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))
]
[d547f16]28
# CRUFT python 2.6: timedelta.total_seconds() is not available there, so
# fall back to computing the number of seconds from the delta's fields.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the date-time delta *dt* as a number of seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the date-time delta *dt* as a number of seconds."""
        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
38
39
def tic():
    """
    Start a timer.

    Returns a function which, when called, reports the number of seconds
    elapsed since tic() was called.  Use "toc=tic()" to start the clock and
    "toc()" to measure a time interval.
    """
    start = datetime.datetime.now()

    def toc():
        return delay(datetime.datetime.now() - start)

    return toc
49
50
def set_beam_stop(data, radius, outer=None):
    """
    Mask out the points of *data* with |q| below *radius*.

    If *outer* is given, points with |q| >= *outer* are masked as well,
    leaving an annulus.  For 2D data |q| is computed from qx_data and
    qy_data; for 1D data the x values are used directly.  The resulting
    boolean array is stored in data.mask.

    Note: this function does not use the sasview package
    """
    if hasattr(data, 'qx_data'):
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q = data.x
    mask = (q < radius)
    if outer is not None:
        mask |= (q >= outer)
    data.mask = mask
66
[8a20be5]67
def sasview_model(model_definition, **pars):
    """
    Create the sasview model corresponding to *model_definition*.

    The sasmodels parameters *pars* are first converted to their sasview
    names with revert_model.  Parameters ending in _pd, _pd_n, _pd_nsigma
    and _pd_type set the width, npts, nsigmas and type fields of the
    model's dispersion entry.  All other parameters are set with setParam.

    Raises ValueError if the model class cannot be found in sas.models.
    """
    # convert model parameters from sasmodel form to sasview form
    modelname, pars = revert_model(model_definition, pars)
    sas = __import__('sas.models.' + modelname)
    model_module = getattr(sas.models, modelname, None)
    ModelClass = getattr(model_module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # parameter suffix -> field of the dispersion entry that it controls
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for name, value in pars.items():
        for suffix, field in dispersion_fields:
            if name.endswith(suffix):
                model.dispersion[name[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(name, value)
    return model
94
def randomize(p, v):
    """
    Return a random value for parameter *p*, whose current value is *v*.

    The kind of parameter is guessed from its name:

    * _pd_n, _pd_nsigma and _pd_type parameters are returned unchanged
    * orientations (theta, phi, psi) are drawn from [-180, 180], and their
      polydispersity widths from [0, 45]
    * sld values are drawn from [-0.5, 10]
    * other polydispersity widths are drawn from [0, 1]
    * anything else is drawn from [0, 2*v], or [0, 2] if v is zero
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v
    is_width = p.endswith('_pd')
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        return 45*np.random.rand() if is_width else 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    if is_width:
        return np.random.rand()
    return 2*np.random.rand()*(v if v != 0 else 1)
[87985ca]118
def randomize_model(pars, seed=None):
    """
    Return (randomized_pars, seed), with every parameter in *pars*
    replaced by the value chosen by randomize().

    If *seed* is None a new seed is drawn.  Either way, the seed actually
    used is returned so that the same parameter set can be regenerated.
    """
    seed = np.random.randint(1e9) if seed is None else seed
    np.random.seed(seed)
    # Walk the parameters in sorted order so that a given seed always maps
    # the same random draws onto the same parameters.
    randomized = {}
    for name, value in sorted(pars.items()):
        randomized[name] = randomize(name, value)
    return randomized, seed
127
def constrain_pars(model_definition, pars):
    """
    Adjust the parameter dictionary *pars* in place for the given model.

    * capped_cylinder / barbell: swap radii so the cap/bell radius is at
      least the cylinder radius
    * guinier: limit rg so that I(q) stays above 1e-30 (single precision
      cutoff) up to q = 1.0
    * remove parameters that sasview does not have for these models
      (scale for teubner_strey and broad_peak, background for guinier,
      both for structure factors)
    """
    name = model_definition.name

    outer_radius = {
        'capped_cylinder': 'cap_radius',
        'barbell': 'bell_radius',
    }.get(name)
    if outer_radius is not None and pars[outer_radius] < pars['radius']:
        pars['radius'], pars[outer_radius] = pars[outer_radius], pars['radius']

    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
    if name == 'guinier':
        q_max = 1.0  # high q maximum; a mid q maximum would be 0.2
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)

    # These constraints are only needed for comparison to sasview
    if name in ('teubner_strey', 'broad_peak'):
        del pars['scale']
    if name == 'guinier':
        del pars['background']
    if getattr(model_definition, 'category', None) == 'structure-factor':
        del pars['scale'], pars['background']
149
[216a9e1]150
def parlist(pars):
    """Return the parameters as "name: value" lines, sorted by name."""
    lines = ["%s: %s" % (name, value) for name, value in sorted(pars.items())]
    return "\n".join(lines)
153
def suppress_pd(pars):
    """
    Make the model monodisperse by setting every polydispersity width
    (each parameter whose name ends in "_pd") to zero, in place.

    The other polydispersity settings (_pd_n, _pd_nsigma, _pd_type) are
    left untouched.  Useful for testing models more quickly.
    """
    widths = [name for name in pars if name.endswith("_pd")]
    for name in widths:
        pars[name] = 0
163
def eval_sasview(model_definition, pars, data, Nevals=1):
    """
    Evaluate the sasview implementation of the model on *data*.

    The model is evaluated at least once (*Nevals* times when Nevals >= 1).
    Resolution smearing uses the smearer returned by sasview's
    smear_selection, if any.

    Returns (value, average_time), where average_time is the mean wall
    clock time per evaluation in milliseconds.

    Raises ImportError if the sasview package cannot be imported.
    """
    # importing sas here so that the error message will be that sas failed to
    # import rather than the more obscure smear_selection not imported error
    import sas
    from sas.models.qsmearing import smear_selection
    model = sasview_model(model_definition, **pars)
    smearer = smear_selection(data, model=model)
    is2D = hasattr(data, 'qx_data')
    if is2D:
        # The selected points do not change between evaluations, so compute
        # the selection once, outside of the timing loop.
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
        index = ((~data.mask) & (~np.isnan(data.data))
                 & (q >= data.qmin) & (q <= data.qmax))
    n_evals = max(Nevals, 1)  # make sure there is at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(n_evals):
        if is2D:
            if smearer is not None:
                smearer.model = model  # because smear_selection has a bug
                smearer.accuracy = data.accuracy
                smearer.set_index(index)
                value = smearer.get_value()
            else:
                value = model.evalDistribution([data.qx_data[index],
                                                data.qy_data[index]])
        else:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
    # Average over the evaluations actually performed; dividing by Nevals
    # would fail for Nevals == 0 even though one evaluation was done.
    average_time = toc()*1000./n_evals
    return value, average_time
191
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1,
                cutoff=0., fast=False):
    """
    Evaluate the model with the OpenCL kernel.

    The model is loaded at *dtype* precision (with *fast* passed through to
    core.load_model).  If loading raises, the error is printed and loading
    is retried with single precision.  The model is then evaluated on
    *data* with parameters *pars* at least once (*Nevals* times when
    Nevals >= 1); *cutoff* is passed on to DirectModel.

    Returns (value, average_time), where average_time is the mean wall
    clock time per evaluation in milliseconds.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype,
                                platform="ocl", fast=fast)
    except Exception as exc:
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single',
                                platform="ocl", fast=fast)
    calculator = DirectModel(data, model, cutoff=cutoff)
    n_evals = max(Nevals, 1)  # force at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(n_evals):
        value = calculator(**pars)
    # Average over the evaluations actually performed; dividing by Nevals
    # would fail for Nevals == 0 even though one evaluation was done.
    average_time = toc()*1000./n_evals
    return value, average_time
209
[7cf2cfd]210
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate the model with the compiled DLL kernel (platform "dll").

    The model is loaded at *dtype* precision and evaluated on *data* with
    parameters *pars* at least once (*Nevals* times when Nevals >= 1);
    *cutoff* is passed on to DirectModel.

    Returns (value, average_time), where average_time is the mean wall
    clock time per evaluation in milliseconds.
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    calculator = DirectModel(data, model, cutoff=cutoff)
    n_evals = max(Nevals, 1)  # force at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(n_evals):
        value = calculator(**pars)
    # Average over the evaluations actually performed; dividing by Nevals
    # would fail for Nevals == 0 even though one evaluation was done.
    average_time = toc()*1000./n_evals
    return value, average_time
220
[7cf2cfd]221
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Build an empty data set to evaluate the models on.

    1D data has *Nq* q points up to *qmax*.  They are log spaced over three
    decades when *view* is 'log', otherwise linearly spaced from
    0.001*qmax.  2D data uses an Nq x Nq grid from -qmax to qmax on each
    axis, stores *accuracy* on the data, and masks a beam stop of radius
    0.004.

    Returns (data, index), where index selects the unmasked points.
    """
    if not is2D:
        if view == 'log':
            log_qmax = math.log10(qmax)
            q = np.logspace(log_qmax-3, log_qmax, Nq)
        else:
            q = np.linspace(0.001*qmax, qmax, Nq)
        return empty_data1D(q, resolution=resolution), slice(None, None)

    data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
    data.accuracy = accuracy
    set_beam_stop(data, 0.004)
    return data, ~data.mask
237
def compare(name, pars, Ncomp, Nbase, opts, set_pars):
    """
    Compare two implementations of model *name* and report the differences.

    *pars* is the starting parameter dictionary.  *Ncomp* and *Nbase* are
    the number of evaluations for the comparison and base calculators; zero
    disables that calculator.  *opts* is the list of command line options.
    *set_pars* holds parameter values given on the command line, which are
    applied after any randomization.

    The base calculator is sasview when both -ctypes and -sasview are
    given, otherwise OpenCL.  The comparison calculator is ctypes when
    -ctypes is given, otherwise sasview.  Timings and error statistics are
    printed, and the results are plotted unless -noplot is given.
    """
    model_definition = core.load_model_definition(name)

    view = ('linear' if '-linear' in opts
            else 'log' if '-log' in opts
            else 'q4' if '-q4' in opts
            else 'log')

    # "-key=value" options as a {"-key": "value"} dictionary
    opt_values = dict(split
                      for s in opts for split in ((s.split('='),))
                      if len(split) == 2)
    # Sort out data
    qmax = (10.0 if '-exq' in opts
            else 1.0 if '-highq' in opts
            else 0.2 if '-midq' in opts
            else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = "-2d" in opts
    data, _ = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = ('longdouble' if '-quad' in opts
             else 'double' if '-double' in opts
             else 'half' if '-half' in opts
             else 'single')
    cutoff = float(opt_values.get('-cutoff','1e-5'))
    # -fast is only honoured for single precision.  Compare strings with
    # ==; "is" tests object identity and only works by interning accident.
    fast = "-fast" in opts and dtype == 'single'

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(pars, seed=seed)
        print("Randomize using -random=%i"%seed)
    pars.update(set_pars)  # set value after random to control value
    constrain_pars(model_definition, pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars "+str(parlist(pars)))

    # Base calculation
    if Nbase > 0 and "-ctypes" in opts and "-sasview" in opts:
        try:
            # sasview is the base here, so it runs Nbase times
            base, base_time = eval_sasview(model_definition, pars, data, Nbase)
            base_name = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))
        except ImportError:
            traceback.print_exc()
            Nbase = 0
    elif Nbase > 0:
        base, base_time = eval_opencl(model_definition, pars, data,
                dtype=dtype, cutoff=cutoff, Nevals=Nbase, fast=fast)
        base_name = "ocl"
        print("opencl t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))

    # Comparison calculation
    if Ncomp > 0 and "-ctypes" in opts:
        comp, comp_time = eval_ctypes(model_definition, pars, data,
                dtype=dtype, cutoff=cutoff, Nevals=Ncomp)
        comp_name = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
    elif Ncomp > 0:
        try:
            comp, comp_time = eval_sasview(model_definition, pars, data, Ncomp)
            comp_name = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
        except ImportError:
            traceback.print_exc()
            Ncomp = 0

    # Compare, but only if computing both forms
    if Nbase > 0 and Ncomp > 0:
        resid = (base - comp)
        relerr = resid/comp
        _print_stats("|%s-%s|"%(base_name,comp_name)+(" "*(3+len(comp_name))), resid)
        _print_stats("|(%s-%s)/%s|"%(base_name,comp_name,comp_name), relerr)

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncomp > 0:
        if Nbase > 0:
            plt.subplot(131)
        plot_theory(data, comp, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(comp_name,comp_time))
    if Nbase > 0:
        if Ncomp > 0:
            plt.subplot(132)
        plot_theory(data, base, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(base_name,base_time))
    if Ncomp > 0 and Nbase > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err,errstr,errview = resid, "abs err", "linear"
        else:
            err,errstr,errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, plot_data=False)
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))

    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
        nonzero = relerr[relerr != 0]
        if len(nonzero) == 0:
            print("relative error is zero everywhere; no histogram to show")
        else:
            plt.figure()
            # log10(0) is -inf, so show exact matches at half the smallest
            # nonzero error.  Work on a copy to leave relerr untouched.
            v = np.array(relerr)
            v[v == 0] = 0.5*np.min(np.abs(nonzero))
            # density=True replaces normed=1, which matplotlib 3.1 removed
            plt.hist(np.log10(np.abs(v)), density=True, bins=50)
            plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
            plt.ylabel('P(err)')
            plt.title('Comparison of single and double precision models for %s'%name)

    plt.show()
374
def _print_stats(label, err):
    """
    Print one line of summary statistics for the error vector *err*.

    The line starts with *label*, followed by the maximum, median and 98th
    percentile of |err|, the rms of err and its mean ("zero-offset").
    """
    magnitude = np.sort(abs(err))
    last = len(err) - 1
    median = magnitude[int(last*0.50)]
    pct98 = magnitude[int(last*0.98)]
    fields = (
        "max:%.3e" % magnitude[-1],
        "median:%.3e" % median,
        "98%%:%.3e" % pct98,
        "rms:%.3e" % np.sqrt(np.mean(err**2)),
        "zero-offset:%+.3e" % np.mean(err),
    )
    print(label + "  ".join(fields))
[0763009]387
388
389
# ===========================================================================
#
# Command line help text, printed by main() when no model is given.
USAGE="""
usage: compare.py model [Ncomp] [Nbase] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Ncomp is the number of times to run the comparison model, which is Sasview,
    or ctypes with -ctypes (default=5)
Nbase is the number of times to run the base model, which is OpenCL, or
    Sasview with both -ctypes and -sasview (default=1)

Options (* for default):

    -plot*/-noplot plot or suppress the plot of the model
    -half/-single*/-double/-quad/-fast sets the calculation precision
        (-fast only applies to single precision)
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* selects gpu:cpu, gpu:sasview, or sasview:cpu if both
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:
"""
425
[7cf2cfd]426
# Options given by name alone, as in "-2d".
NAME_OPTIONS = set("""
    plot noplot
    half single double quad fast
    lowq midq highq exq
    2d 1d
    preset random
    poly mono
    sasview ctypes
    nopars pars
    rel abs
    linear log q4
    hist nohist
    """.split())
# Options taking a value, as in "-cutoff=1e-5".
# Note: random is both a name option and a value option
VALUE_OPTIONS = "cutoff random Nq res accuracy".split()
444
def columnize(L, indent="", width=79):
    """
    Format the strings in *L* as a table of columns.

    Entries are filled top to bottom, then left to right.  Each line is
    prefixed by *indent*, and as many columns are used as fit within
    *width* characters.  At least one column is always used, so a single
    entry wider than the page still gets its own line.  Every entry is
    kept.  Returns the table as a single string, or just *indent* if *L*
    is empty.
    """
    entries = list(L)
    if not entries:
        return indent
    column_width = max(len(w) for w in entries) + 1
    # Each column takes column_width characters plus one separator space
    # (no separator after the last column).
    max_columns = max(1, (width - len(indent) + 1) // (column_width + 1))
    # Round up so that a final, partially filled column is not dropped,
    # then use only as many columns as are needed for that many rows.
    num_rows = -(-len(entries) // max_columns)
    num_columns = -(-len(entries) // num_rows)
    entries += [""] * (num_rows*num_columns - len(entries))
    columns = [entries[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    return indent + ("\n"+indent).join(lines)
455
456
def get_demo_pars(model_definition):
    """
    Return the parameter dictionary to use for the model.

    The dictionary starts from the entries of the model info 'parameters'
    table, and the model's 'demo' values are then applied on top.
    """
    info = generate.make_info(model_definition)
    pars = {}
    for parameter in info['parameters']:
        # entry 0 is the key and entry 2 the value (presumably the name and
        # default value -- confirm against generate.make_info)
        pars[parameter[0]] = parameter[2]
    pars.update(info['demo'])
    return pars
462
def main():
    """
    Command line entry point; see USAGE for the accepted arguments.

    Arguments starting with "-" are options, "key=value" arguments set
    model parameters, and the remaining positional arguments are the model
    name followed by the optional Ncomp and Nbase evaluation counts.
    Exits with status 1 on missing or invalid input.
    """
    argv = sys.argv[1:]
    opts = [arg for arg in argv if arg.startswith('-')]
    popts = [arg for arg in argv if not arg.startswith('-') and '=' in arg]
    args = [arg for arg in argv if not arg.startswith('-') and '=' not in arg]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE)
        print(columnize(MODELS, indent="  "))
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(args[0],models))
        sys.exit(1)
    if len(args) > 3:
        # Extra positional arguments are reported but otherwise ignored.
        print("expected parameters: model Ncomp Nbase")

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)

    # Number of evaluations for the comparison and base calculators
    Ncomp = int(args[1]) if len(args) > 1 else 5
    Nbase = int(args[2]) if len(args) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in popts:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s))))
            sys.exit(1)
        # Distribution types are names such as "gaussian"; all other values
        # are numeric.  Decide by the parameter name, not by its value.
        set_pars[k] = v if k.endswith('_pd_type') else float(v)

    compare(name, pars, Ncomp, Nbase, opts, set_pars)


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.