source: sasmodels/sasmodels/compare.py @ 5edfe12

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 5edfe12 was 5edfe12, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

support long double kernels for precision limited models

  • Property mode set to 100755
File size: 16.8 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[7cf2cfd]8import datetime
[5753e4e]9import traceback
[87985ca]10
[1726b21]11import numpy as np
[473183c]12
# Directory containing this script.  It is placed at the very front of the
# module search path so modules found there win over installed copies
# (original intent: "Make sure sasmodels is first on the path").
# NOTE(review): ROOT is this file's own directory, not its parent, so this
# alone does not make the "sasmodels" package importable -- confirm intent.
ROOT = dirname(__file__)
sys.path[0:0] = [ROOT]
15
16
[e922c5d]17from . import core
18from . import kerneldll
[cd3dba0]19from . import generate
[e922c5d]20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
# Presumably lets kerneldll build/load single precision DLLs so that the
# ctypes path can be compared at single precision -- confirm in kerneldll.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True

# List of available models: the module name (file name minus ".py") of every
# ROOT/models/*.py file whose name starts with a letter (so __init__.py and
# similar are skipped).  Paths are sorted before the suffix is stripped.
MODELS = list(basename(path)[:-3]
              for path in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py"))))
[d547f16]28
# CRUFT python 2.6: timedelta.total_seconds() only exists from python 2.7 on,
# so fall back to computing it by hand when it is missing.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the length of the datetime.timedelta *dt* in seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the length of the datetime.timedelta *dt* in seconds."""
        # hand-written equivalent of dt.total_seconds()
        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
38
39
def tic():
    """
    Start a wall-clock timer.

    Returns a function which, each time it is called, gives the number of
    seconds elapsed since tic() was called.  Typical use::

        toc = tic()
        ...  # work to be timed
        elapsed = toc()
    """
    start = datetime.datetime.now()

    def toc():
        return delay(datetime.datetime.now() - start)

    return toc
49
50
def set_beam_stop(data, radius, outer=None):
    """
    Mask out points of *data* with |q| < *radius* (the beam stop).

    If *outer* is given, points with |q| >= *outer* are masked as well,
    leaving an annulus.  For 2D data (anything with a ``qx_data``
    attribute) |q| is computed from qx_data and qy_data; otherwise data.x
    is used.  The result replaces ``data.mask``.

    Note: this function does not use the sasview package
    """
    if hasattr(data, 'qx_data'):
        qmag = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        qmag = data.x
    blocked = (qmag < radius)
    if outer is not None:
        blocked |= (qmag >= outer)
    data.mask = blocked
66
[8a20be5]67
def sasview_model(model_definition, **pars):
    """
    Load a sasview model given the model name.

    The sasmodels parameters in *pars* are first converted to their sasview
    names with revert_model.  The matching class is then looked up in
    sas.models and instantiated, and the converted parameters are applied to it:

    * ``<name>_pd``, ``_pd_n``, ``_pd_nsigma`` and ``_pd_type`` go into the
      'width', 'npts', 'nsigmas' and 'type' fields of
      ``model.dispersion[<name>]``
    * everything else is set with ``model.setParam``

    Raises ValueError if sas.models has no model of the converted name.
    """
    # convert model parameters from sasmodel form to sasview form
    modelname, pars = revert_model(model_definition, pars)
    sas = __import__('sas.models.'+modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # (parameter suffix, dispersion field); the suffixes are mutually
    # exclusive, so at most one entry matches any parameter name
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
94
def randomize(p, v):
    """
    Return a random value for parameter *p* whose current value is *v*.

    The kind of parameter is guessed from its name:

    * ``_pd_n``, ``_pd_nsigma``, ``_pd_type``: returned unchanged
    * names containing theta/phi/psi: uniform in [-180, 180), or in
      [0, 45) for their ``_pd`` width
    * names containing sld: uniform in [-0.5, 10)
    * other ``_pd`` widths: uniform in [0, 1)
    * anything else: uniform in [0, 2*v), using v = 1 when v is 0
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v
    is_width = p.endswith('_pd')
    if 'theta' in p or 'phi' in p or 'psi' in p:
        return 45*np.random.rand() if is_width else 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    if is_width:
        return np.random.rand()
    reference = v if v != 0 else 1
    return 2*np.random.rand()*reference
108            return 360*np.random.rand() - 180
109    elif 'sld' in p:
110        # sld in in [-0.5,10]
111        return 10.5*np.random.rand() - 0.5
112    elif p.endswith('_pd'):
113        # length pd in [0,1]
114        return np.random.rand()
115    else:
[b89f519]116        # values from 0 to 2*x for all other parameters
117        return 2*np.random.rand()*(v if v != 0 else 1)
[87985ca]118
def randomize_model(pars, seed=None):
    """
    Return ``(randomized_pars, seed)``.

    Seeds numpy's global random generator with *seed* (a fresh random seed
    is drawn when it is None) and builds a new dict with every parameter of
    *pars* passed through randomize().  The seed is returned so the same
    parameter set can be reproduced later.
    """
    if seed is None:
        seed = np.random.randint(1e9)
    np.random.seed(seed)
    # Visit the parameters in sorted order so that a given seed always hands
    # the same random draws to the same parameters.
    randomized = {}
    for name, value in sorted(pars.items()):
        randomized[name] = randomize(name, value)
    return randomized, seed
127
def constrain_pars(model_definition, pars):
    """
    Adjust *pars* in place so the comparison is well defined for the model.

    * capped_cylinder / barbell: swap the radii when needed so that the
      cap/bell radius is not smaller than the cylinder radius
    * guinier: limit rg so the intensity stays above the single precision
      cutoff of 1e-30 up to q = 1
    * remove parameters only needed for the sasview comparison: scale for
      teubner_strey and broad_peak, background for guinier, and scale and
      background for structure factors (category 'structure-factor').
      Presumably the sasview versions of these models don't take them.

    Parameters that are already absent are simply skipped rather than
    raising KeyError.
    """
    name = model_definition.name
    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']

    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff).
    # Assuming I(q) = scale*exp(-(q*rg)**2/3), the condition I(q_max) > 1e-30
    # rearranges to rg < sqrt(90*ln(10) + 3*ln(scale))/q_max.
    if name == 'guinier':
        q_max = 1.0  # high q maximum (use 0.2 to cover only the mid q range)
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)

    # These constraints are only needed for comparison to sasview
    if name in ('teubner_strey', 'broad_peak'):
        pars.pop('scale', None)
    if name in ('guinier',):
        pars.pop('background', None)
    if getattr(model_definition, 'category', None) == 'structure-factor':
        pars.pop('scale', None)
        pars.pop('background', None)
149
[216a9e1]150
def parlist(pars):
    """Format *pars* as one "name: value" line per parameter, sorted by name."""
    lines = ["%s: %s" % item for item in sorted(pars.items())]
    return "\n".join(lines)
153
def suppress_pd(pars):
    """
    Turn off polydispersity by setting every ``*_pd`` width in *pars* to 0.

    *pars* is modified in place; the ``_pd_n``, ``_pd_nsigma`` and
    ``_pd_type`` entries are left alone.  Used by the -mono option to
    test models more quickly.
    """
    widths = [name for name in pars if name.endswith("_pd")]
    for name in widths:
        pars[name] = 0
163
def eval_sasview(model_definition, pars, data, Nevals=1):
    """
    Evaluate *model_definition* at *pars* on *data* with the original
    SasView model code.

    The model is evaluated *Nevals* times (at least once), with smearing from
    sas.models.qsmearing.smear_selection when a smearer applies.  For 2D data,
    only unmasked, non-NaN points with qmin <= |q| <= qmax are computed.

    Returns (value, average_time): the result of the last evaluation and
    the mean wall-clock time per evaluation in milliseconds.
    """
    from sas.models.qsmearing import smear_selection
    model = sasview_model(model_definition, **pars)
    smearer = smear_selection(data, model=model)
    value = None  # silence the linter
    num_evals = max(Nevals, 1)  # make sure there is at least one eval
    toc = tic()
    for _ in range(num_evals):
        if hasattr(data, 'qx_data'):
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            index = ((~data.mask) & (~np.isnan(data.data))
                     & (q >= data.qmin) & (q <= data.qmax))
            if smearer is not None:
                smearer.model = model  # because smear_selection has a bug
                smearer.accuracy = data.accuracy
                smearer.set_index(index)
                value = smearer.get_value()
            else:
                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
        else:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
    # Average over the evaluations actually performed; dividing by Nevals
    # failed (ZeroDivisionError) or was wrong when Nevals < 1.
    average_time = toc()*1000./num_evals
    return value, average_time
188
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
    """
    Evaluate *model_definition* at *pars* on *data* with the OpenCL kernel.

    The model is loaded for the "ocl" platform at precision *dtype*.  If that
    load raises, the error is printed and the load is retried in single
    precision.  *cutoff* is passed on to DirectModel.  The model is evaluated
    *Nevals* times (at least once).

    Returns (value, average_time): the result of the last evaluation and
    the mean wall-clock time per evaluation in milliseconds.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single', platform="ocl")
    calculator = DirectModel(data, model, cutoff=cutoff)
    value = None  # silence the linter
    num_evals = max(Nevals, 1)  # force at least one eval
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    # average over the evaluations actually performed (safe for Nevals < 1)
    average_time = toc()*1000./num_evals
    return value, average_time
203
[7cf2cfd]204
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate *model_definition* at *pars* on *data* with the compiled
    (ctypes "dll" platform) kernel at precision *dtype*.

    *cutoff* is passed on to DirectModel.  The model is evaluated *Nevals*
    times (at least once).

    Returns (value, average_time): the result of the last evaluation and
    the mean wall-clock time per evaluation in milliseconds.
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    calculator = DirectModel(data, model, cutoff=cutoff)
    value = None  # silence the linter
    num_evals = max(Nevals, 1)  # force at least one eval
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    # average over the evaluations actually performed (safe for Nevals < 1)
    average_time = toc()*1000./num_evals
    return value, average_time
214
[7cf2cfd]215
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Build an empty data set on which the models are evaluated.

    2D: built by empty_data2D from Nq values spanning [-qmax, qmax], with
    the given resolution accuracy and a 0.004 beam stop.
    1D: Nq q values up to *qmax*, log spaced over three decades when *view*
    is 'log', otherwise linearly spaced from 0.001*qmax.

    Returns (data, index), where index selects the usable points: the
    unmasked ones in 2D, everything in 1D.
    """
    if not is2D:
        if view == 'log':
            log_qmax = math.log10(qmax)
            q = np.logspace(log_qmax-3, log_qmax, Nq)
        else:
            q = np.linspace(0.001*qmax, qmax, Nq)
        return empty_data1D(q, resolution=resolution), slice(None, None)

    data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
    data.accuracy = accuracy
    set_beam_stop(data, 0.004)
    return data, ~data.mask
231
def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Run the comparison for model *name* as configured by the command line.

    *pars* holds the starting parameter values (it may be updated in place),
    *Ncpu* and *Nocl* are the number of CPU (sasview or ctypes) and OpenCL
    evaluations to time, *opts* is the list of "-option" strings and
    *set_pars* holds parameter values given on the command line, which take
    precedence over randomized values.

    Prints timings and, when both sides are computed, error statistics.
    Unless -noplot is given, the results are plotted with matplotlib.
    """
    model_definition = core.load_model_definition(name)

    view = ('linear' if '-linear' in opts
            else 'log' if '-log' in opts
            else 'q4' if '-q4' in opts
            else 'log')

    # "-key=value" options; entries without exactly one '=' are flags
    opt_values = dict(s.split('=') for s in opts if s.count('=') == 1)

    # Sort out data
    qmax = (10.0 if '-exq' in opts
            else 1.0 if '-highq' in opts
            else 0.2 if '-midq' in opts
            else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = "-2d" in opts
    data, _ = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = ('longdouble' if '-longdouble' in opts
             else 'double' if '-double' in opts
             else 'single')
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(pars, seed=seed)
        print("Randomize using -random=%i"%seed)
    pars.update(set_pars)  # set value after random to control value
    constrain_pars(model_definition, pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars " + parlist(pars))

    # OpenCl calculation
    if Nocl > 0:
        ocl, ocl_time = eval_opencl(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl)))

    # ctypes/sasview calculation
    if Ncpu > 0 and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))
    elif Ncpu > 0:
        try:
            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
            comp = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))
        except ImportError:
            # sasview not installed: report it and continue with OpenCL only
            traceback.print_exc()
            Ncpu = 0

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        resid = (ocl - cpu)
        relerr = resid/cpu
        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
        _print_stats("|(ocl-%s)/%s|"%(comp, comp), relerr)

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_theory(data, cpu, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(comp, cpu_time))
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_theory(data, ocl, view=view, plot_data=False)
        plt.title("opencl t=%.1f ms"%ocl_time)
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, plot_data=False)
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        # np.abs returns a new array, so relerr itself is left untouched.
        v = np.abs(relerr)
        nonzero = v[v != 0]
        if nonzero.size == 0:
            # log10(0) can't be histogrammed and there is no smallest
            # nonzero error to substitute for the zeros.
            print("skipping -hist: relative error is zero everywhere")
        else:
            plt.figure()
            # exact zeros can't be shown on a log scale; plot them at half
            # the smallest nonzero error instead
            v[v == 0] = 0.5*np.min(nonzero)
            # NOTE(review): matplotlib >= 3.1 removed `normed`; use
            # density=True there.
            plt.hist(np.log10(v), normed=1, bins=50)
            plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
            plt.ylabel('P(err)')
            plt.title('Comparison of single and double precision models for %s'%name)

    plt.show()
350
[0763009]351def _print_stats(label, err):
352    sorted_err = np.sort(abs(err))
353    p50 = int((len(err)-1)*0.50)
354    p98 = int((len(err)-1)*0.98)
355    data = [
356        "max:%.3e"%sorted_err[-1],
357        "median:%.3e"%sorted_err[p50],
358        "98%%:%.3e"%sorted_err[p98],
359        "rms:%.3e"%np.sqrt(np.mean(err**2)),
360        "zero-offset:%+.3e"%np.mean(err),
361        ]
362    print label,"  ".join(data)
363
364
365
# ===========================================================================
#
# Command line help; main() prints it followed by the list of model names.
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the Sasview model (default=1)

Options (* for default):

    -plot*/-noplot plot or suppress the plot of the model
    -single*/-double/-longdouble sets the floating point precision of the models
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* whether cpu is tested using ctypes or sasview
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values for any of the model
parameters.

Available models:
"""
401
[7cf2cfd]402
# Flag options accepted as "-name", grouped as alternatives of one setting.
NAME_OPTIONS = set("""
    plot noplot
    single double longdouble
    lowq midq highq exq
    2d 1d
    preset random
    poly mono
    sasview ctypes
    nopars pars
    rel abs
    linear log q4
    hist nohist
    """.split())

# Options accepted as "-name=value".
# Note: random is both a name option and a value option
VALUE_OPTIONS = ['cutoff', 'random', 'Nq', 'res', 'accuracy']
420
def columnize(L, indent="", width=79):
    """
    Lay out the strings in *L* in columns, filling down each column first.

    Each line starts with *indent* and fits within *width* characters when
    possible; at least one column is always used.  Every entry is padded
    to the widest entry plus one.  Returns the lines joined by newlines,
    or "" if *L* is empty.
    """
    L = list(L)
    if not L:
        return ""
    column_width = max(len(w) for w in L) + 1
    num_columns = max(1, (width - len(indent)) // column_width)
    # Round up so no entry is dropped, then shrink the column count so the
    # final column is not left completely empty.
    num_rows = (len(L) + num_columns - 1) // num_columns
    num_columns = (len(L) + num_rows - 1) // num_rows
    L = L + [""] * (num_rows*num_columns - len(L))
    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    output = indent + ("\n"+indent).join(lines)
    return output
431
432
def get_demo_pars(model_definition):
    """
    Return the starting parameter dict for *model_definition*.

    Starts from column 2 of each entry in the model info 'parameters' table,
    keyed by column 0 (presumably the default value and the name), then
    overrides these with the model's 'demo' values.
    """
    info = generate.make_info(model_definition)
    pars = {}
    for parameter in info['parameters']:
        pars[parameter[0]] = parameter[2]
    pars.update(info['demo'])
    return pars
438
def main():
    """
    Command line entry point: parse sys.argv and run compare().

    Arguments starting with '-' are options.  The rest are the model name,
    optionally followed by the OpenCL and sasview evaluation counts, plus any
    number of key=val parameter settings (see USAGE).  Prints a message and
    exits with status 1 on an unknown model, option or parameter.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
    # key=val settings are recognised by their '=', so they no longer need
    # both optional counts to be given before them
    positional = [arg for arg in args if '=' not in arg]
    assignments = [arg for arg in args if '=' in arg]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(positional) == 0:
        print(USAGE)
        print(columnize(MODELS, indent="  "))
        sys.exit(1)
    if positional[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(positional[0], models))
        sys.exit(1)
    if len(positional) > 3:
        print("Unexpected arguments: %s"%(", ".join(positional[3:])))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = positional[0]
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)

    Nopencl = int(positional[1]) if len(positional) > 1 else 5
    Nsasview = int(positional[2]) if len(positional) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in assignments:
        k, _, v = arg.partition('=')
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        # distribution types are names (e.g. "gaussian"); everything else
        # is numeric.  The original tested the *value* for "type", so
        # "radius_pd_type=gaussian" crashed in float().
        set_pars[k] = v if k.endswith('_pd_type') else float(v)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
485
# Run the command line interface only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.