source: sasmodels/sasmodels/compare.py @ b514adf

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since b514adf was b514adf, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

set constraints so multi_compare has fewer spurious errors

  • Property mode set to 100755
File size: 16.6 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[7cf2cfd]8import datetime
[5753e4e]9import traceback
[87985ca]10
[1726b21]11import numpy as np
[473183c]12
[29fc2a3]13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
[e922c5d]17from . import core
18from . import kerneldll
[cd3dba0]19from . import generate
[e922c5d]20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
[750ffa5]23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]24
# Names of the available models: one entry per python file in ROOT/models
# whose file name starts with a letter, with the ".py" extension removed.
_MODEL_FILES = sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))
MODELS = [basename(path)[:-3] for path in _MODEL_FILES]
[d547f16]28
# CRUFT python 2.6: timedelta.total_seconds() is not available there, so
# fall back to computing the value from the individual fields.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the time delta *dt* as a number of seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the time delta *dt* as a number of seconds."""
        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
38
39
def tic():
    """
    Start a timer.

    Returns a function which, each time it is called, gives the number of
    seconds elapsed since :func:`tic` was called.  Use "toc=tic()" to start
    the clock and "toc()" to measure a time interval.
    """
    start = datetime.datetime.now()

    def toc():
        return delay(datetime.datetime.now() - start)

    return toc
49
50
def set_beam_stop(data, radius, outer=None):
    """
    Mask the points of *data* with q below *radius*.

    If *outer* is given, points with q at or above *outer* are masked as
    well, leaving an annulus.  For data with a *qx_data* attribute q is
    sqrt(qx_data^2 + qy_data^2); otherwise q is *data.x*.  The result is
    stored in *data.mask*, with True marking a masked point.

    Note: this function does not use the sasview package
    """
    if hasattr(data, 'qx_data'):
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q = data.x
    mask = (q < radius)
    if outer is not None:
        mask |= (q >= outer)
    data.mask = mask
66
[8a20be5]67
def sasview_model(model_definition, **pars):
    """
    Create a sasview model for *model_definition* and set its parameters.

    The model name and the parameters *pars* are first converted from
    sasmodels form to sasview form using *revert_model*.  Parameters ending
    in *_pd*, *_pd_n*, *_pd_nsigma* and *_pd_type* are stored in the model
    dispersion table as *width*, *npts*, *nsigmas* and *type* respectively;
    all other parameters are set with *setParam*.

    Raises *ValueError* if the model class is not found in sas.models.
    """
    # convert model parameters from sasmodel form to sasview form
    modelname, pars = revert_model(model_definition, pars)
    sas = __import__('sas.models.'+modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # Suffix of the parameter name -> field of the dispersion entry.  No
    # name can end with more than one of these suffixes, so the order of
    # the checks does not matter.
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
94
def randomize(p, v):
    """
    Return a random value for the parameter named *p* with current value *v*.

    The kind of parameter is guessed from its name:

    * *_pd_n*, *_pd_nsigma* and *_pd_type* are returned unchanged
    * names containing theta, phi or psi are in [-180, 180], or in [0, 45]
      if they end in *_pd*
    * names containing sld are in [-0.5, 10]
    * other names ending in *_pd* are in [0, 1]
    * anything else is in [0, 2*v], or in [0, 2] if *v* is zero
    """
    # Polydispersity control values are not randomized.
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v

    is_width = p.endswith('_pd')
    if 'theta' in p or 'phi' in p or 'psi' in p:
        # orientation in [-180,180], orientation pd in [0,45]
        if is_width:
            return 45*np.random.rand()
        return 360*np.random.rand() - 180
    if 'sld' in p:
        # sld in [-0.5,10]
        return 10.5*np.random.rand() - 0.5
    if is_width:
        # length pd in [0,1]
        return np.random.rand()

    # values from 0 to 2*x for all other parameters
    scale = v if v != 0 else 1
    return 2*np.random.rand()*scale
[87985ca]118
def randomize_model(pars, seed=None):
    """
    Return a randomized copy of *pars* together with the seed used.

    If *seed* is None a seed is drawn from the numpy random number
    generator.  The numpy generator is then seeded with *seed*, so calling
    again with the returned seed reproduces the same parameter values.
    """
    if seed is None:
        seed = np.random.randint(1e9)
    np.random.seed(seed)

    # Note: the sort guarantees order of calls to random number generator
    randomized = {}
    for key, value in sorted(pars.items()):
        randomized[key] = randomize(key, value)
    return randomized, seed
127
def constrain_pars(model_definition, pars):
    """
    Adjust *pars* in place so that they are valid for the model.

    The model is identified by *model_definition.name*.  For capped_cylinder
    and barbell the cap/bell radius is swapped with the radius if it is the
    smaller of the two.  For guinier, *rg* is limited so that the intensity
    stays above the single precision cutoff.  Some parameters are removed
    for models where they are only a problem when comparing to sasview.
    """
    name = model_definition.name

    # The end cap radius must not be smaller than the cylinder radius.
    for model_name, cap in (('capped_cylinder', 'cap_radius'),
                            ('barbell', 'bell_radius')):
        if name == model_name and pars[cap] < pars['radius']:
            pars['radius'], pars[cap] = pars[cap], pars['radius']

    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
    if name == 'guinier':
        #q_max = 0.2  # mid q maximum
        q_max = 1.0  # high q maximum
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)

    # These constraints are only needed for comparison to sasview
    if name == 'teubner_strey' or name == 'broad_peak':
        del pars['scale']
    if name == 'guinier':
        del pars['background']
    category = getattr(model_definition, 'category', None)
    if category == 'structure-factor':
        del pars['scale']
        del pars['background']
149
[216a9e1]150
def parlist(pars):
    """
    Format *pars* as one "name: value" line per parameter, sorted by name.
    """
    lines = []
    for name, value in sorted(pars.items()):
        lines.append("%s: %s"%(name, value))
    return "\n".join(lines)
153
def suppress_pd(pars):
    """
    Set every parameter whose name ends in *_pd* to zero, in place.

    Suppress theta_pd for now until the normalization is resolved.

    May also suppress complete polydispersity of the model to test
    models more quickly.
    """
    width_names = [name for name in pars if name.endswith("_pd")]
    for name in width_names:
        pars[name] = 0
163
def eval_sasview(model_definition, pars, data, Nevals=1):
    """
    Evaluate the model using sasview.

    *pars* are the sasmodels parameters; they are converted to sasview form
    by :func:`sasview_model`.  The model is evaluated *Nevals* times, but
    always at least once so that a value is available.

    Returns the value from the last evaluation and the average time per
    evaluation in milliseconds.

    Raises *ImportError* if the sasview package is not available.
    """
    from sas.models.qsmearing import smear_selection
    model = sasview_model(model_definition, **pars)
    smearer = smear_selection(data, model=model)
    # The number of evaluations actually performed.  This must also be the
    # divisor for the average time, otherwise Nevals=0 divides by zero.
    num_evals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(num_evals):
        if hasattr(data, 'qx_data'):
            # The index is recomputed inside the timed loop on purpose, so
            # that the timing stays comparable with earlier measurements.
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            index = ((~data.mask) & (~np.isnan(data.data))
                     & (q >= data.qmin) & (q <= data.qmax))
            if smearer is not None:
                smearer.model = model  # because smear_selection has a bug
                smearer.accuracy = data.accuracy
                smearer.set_index(index)
                value = smearer.get_value()
            else:
                value = model.evalDistribution([data.qx_data[index],
                                                data.qy_data[index]])
        else:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
    average_time = toc()*1000./num_evals
    return value, average_time
188
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
    """
    Evaluate the model using the OpenCL ("ocl") platform.

    The model is loaded with precision *dtype*; if that fails for any
    reason the error is printed and the model is loaded again with single
    precision.  The model is evaluated *Nevals* times, but always at least
    once so that a value is available.  *cutoff* is passed on to
    *DirectModel*.

    Returns the value from the last evaluation and the average time per
    evaluation in milliseconds.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single', platform="ocl")
    calculator = DirectModel(data, model, cutoff=cutoff)
    # The number of evaluations actually performed.  This must also be the
    # divisor for the average time, otherwise Nevals=0 divides by zero.
    num_evals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
203
[7cf2cfd]204
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate the model using the compiled "dll" platform.

    The model is loaded with precision *dtype* and evaluated *Nevals*
    times, but always at least once so that a value is available.
    *cutoff* is passed on to *DirectModel*.

    Returns the value from the last evaluation and the average time per
    evaluation in milliseconds.
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    calculator = DirectModel(data, model, cutoff=cutoff)
    # The number of evaluations actually performed.  This must also be the
    # divisor for the average time, otherwise Nevals=0 divides by zero.
    num_evals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
214
[7cf2cfd]215
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Create an empty data set for the model calculation.

    For 2D data, qx and qy each have *Nq* points in [-*qmax*, *qmax*], the
    *accuracy* is attached to the data and a beam stop of radius 0.004 is
    added.  For 1D data, q has *Nq* points from *qmax*/1000 to *qmax*, with
    log spacing if *view* is 'log' and linear spacing otherwise.

    Returns the data and an index selecting the points to use: the unmasked
    points for 2D data and all points for 1D data.
    """
    if is2D:
        q_axis = np.linspace(-qmax, qmax, Nq)
        data = empty_data2D(q_axis, resolution=resolution)
        data.accuracy = accuracy
        set_beam_stop(data, 0.004)
        return data, ~data.mask

    if view == 'log':
        qmax = math.log10(qmax)
        q = np.logspace(qmax-3, qmax, Nq)
    else:
        q = np.linspace(0.001*qmax, qmax, Nq)
    return empty_data1D(q, resolution=resolution), slice(None, None)
231
def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Compare the OpenCL calculation of model *name* against a CPU reference.

    *pars* are the starting parameter values; they are randomized if the
    -random option is present, then overridden by *set_pars*, then passed
    through :func:`constrain_pars`.  Note that *pars* is updated in place
    unless it is randomized.

    *Nocl* is the number of OpenCL evaluations and *Ncpu* the number of
    reference evaluations; a count of zero skips that calculation.  The
    reference is ctypes if -ctypes is in *opts*, otherwise sasview.  If
    sasview can not be imported the traceback is printed and the reference
    calculation is skipped.

    *opts* is the list of command line options, each starting with '-'.

    Timing and error statistics are printed, and the results are plotted
    unless -noplot is in *opts*.
    """
    model_definition = core.load_model_definition(name)

    # Intensity scaling; log is used if nothing is requested.
    if '-linear' in opts:
        view = 'linear'
    elif '-log' in opts:
        view = 'log'
    elif '-q4' in opts:
        view = 'q4'
    else:
        view = 'log'

    # Options of the form "-key=value", stored under "-key".
    opt_values = dict(s.split('=', 1) for s in opts if '=' in s)

    # Sort out data
    if '-exq' in opts:
        qmax = 10.0
    elif '-highq' in opts:
        qmax = 1.0
    elif '-midq' in opts:
        qmax = 0.2
    else:
        qmax = 0.05
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = "-2d" in opts
    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # randomize parameters
    #pars.update(set_pars)  # set value before random to control range
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(pars, seed=seed)
        print("Randomize using -random=%i"%seed)
    pars.update(set_pars)  # set value after random to control value
    constrain_pars(model_definition, pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars %s"%parlist(pars))

    # Results; these stay None for a calculation which is not performed.
    ocl = ocl_time = None
    cpu = cpu_time = comp = None
    resid = relerr = None

    # OpenCl calculation
    if Nocl > 0:
        ocl, ocl_time = eval_opencl(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl)))

    # ctypes/sasview calculation
    if Ncpu > 0 and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))
    elif Ncpu > 0:
        try:
            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
            comp = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))
        except ImportError:
            # sasview is not available; continue without the reference.
            traceback.print_exc()
            Ncpu = 0

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        resid = (ocl - cpu)
        relerr = resid/cpu
        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
        _print_stats("|(ocl-%s)/%s|"%(comp, comp), relerr)

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_theory(data, cpu, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(comp, cpu_time))
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_theory(data, ocl, view=view, plot_data=False)
        plt.title("opencl t=%.1f ms"%ocl_time)
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, plot_data=False)
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        plt.figure()
        # Work on a copy so that the relative error itself is not modified.
        v = np.array(relerr)
        # Replace exact zeros by half the smallest non-zero error so that
        # the log can be taken.
        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
        # NOTE(review): the "normed" keyword has been replaced by "density"
        # in newer matplotlib releases; confirm against the version in use.
        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
        plt.ylabel('P(err)')
        plt.title('Comparison of single and double precision models for %s'%name)

    plt.show()
342
def _print_stats(label, err):
    """
    Print a one line summary of the error values *err*, preceded by *label*.

    The summary gives the maximum, the median and the 98th percentile of
    the absolute error, the root mean square error and the mean error
    (zero-offset).  *err* must contain at least one value.
    """
    err = np.asarray(err)
    sorted_err = np.sort(np.abs(err))
    # Positions of the median and the 98th percentile in the sorted errors.
    p50 = int((len(err)-1)*0.50)
    p98 = int((len(err)-1)*0.98)
    data = [
        "max:%.3e"%sorted_err[-1],
        "median:%.3e"%sorted_err[p50],
        "98%%:%.3e"%sorted_err[p98],
        "rms:%.3e"%np.sqrt(np.mean(err**2)),
        "zero-offset:%+.3e"%np.mean(err),
        ]
    # Single formatted string so the output is the same for the print
    # statement and the print function.
    print("%s %s"%(label, "  ".join(data)))
355
356
357
[87985ca]358# ===========================================================================
359#
# Help text printed when no model is given; the list of available models
# is printed after it.
USAGE = """
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the Sasview model (default=1)

Options (* for default):

    -plot*/-noplot plots or suppress the plot of the model
    -single*/-double uses double precision for comparison
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot relative or absolute error
    -linear/-log/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:
"""

# Options which are given as a bare name, e.g. "-noplot".
NAME_OPTIONS = set(
    "plot noplot "
    "single double "
    "lowq midq highq exq "
    "2d 1d "
    "preset random "
    "poly mono "
    "sasview ctypes "
    "nopars pars "
    "rel abs "
    "linear log q4 "
    "hist nohist".split()
)

# Options which are given with a value, e.g. "-Nq=128".
# Note: random is both a name option and a value option
VALUE_OPTIONS = ['cutoff', 'random', 'Nq', 'res', 'accuracy']
412
def columnize(L, indent="", width=79):
    """
    Format the list of strings *L* in columns, filled column by column.

    Each line starts with *indent*, and the number of columns is chosen so
    that lines fit in *width* characters where possible.  There is always
    at least one column, even if an entry is wider than the line.  Every
    entry is padded to the width of the longest entry plus one.

    Returns the formatted text, or the empty string if *L* is empty.
    """
    if not L:
        return ""
    column_width = max(len(w) for w in L) + 1
    # Use at least one column so that an entry wider than the line does not
    # lead to a division by zero, and no more columns than entries.
    num_columns = (width - len(indent)) // column_width
    num_columns = max(1, min(num_columns, len(L)))
    # Round up so that the entries which do not fill a complete row are
    # still shown.
    num_rows = (len(L) + num_columns - 1) // num_columns
    padded = list(L) + [""] * (num_rows*num_columns - len(L))
    columns = [padded[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    output = indent + ("\n"+indent).join(lines)
    return output
423
424
def get_demo_pars(model_definition):
    """
    Return the demonstration parameters for the model as a dictionary.

    The dictionary is built from the parameter table of the model info,
    using item 0 of each row as the key and item 2 as the value (presumably
    the parameter name and its default; confirm against generate.make_info),
    and then updated with the *demo* values from the model info.
    """
    info = generate.make_info(model_definition)
    pars = {}
    for row in info['parameters']:
        pars[row[0]] = row[2]
    pars.update(info['demo'])
    return pars
430
def main():
    """
    Run the comparison from the command line; see USAGE for the arguments.

    Arguments starting with '-' are options.  Of the remaining arguments,
    those containing '=' are parameter values and the others are the model
    name followed by the optional number of OpenCL and sasview evaluations.

    The program exits with status 1 if no model is given, or if the model,
    an option or a parameter name is not recognized.
    """
    argv = sys.argv[1:]
    opts = [arg for arg in argv if arg.startswith('-')]
    # key=value pairs are separated from the positional arguments so that
    # the evaluation counts can be left out, as shown in USAGE.
    args = [arg for arg in argv if not arg.startswith('-') and '=' not in arg]
    par_args = [arg for arg in argv if not arg.startswith('-') and '=' in arg]
    if len(args) == 0:
        print(USAGE)
        print(columnize(MODELS, indent="  "))
        sys.exit(1)
    if args[0] not in MODELS:
        models = "\n    ".join("%-15s"%v for v in MODELS)
        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
        sys.exit(1)
    if len(args) > 3:
        print("Too many arguments: %s"%(", ".join(args[3:])))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)

    Nopencl = int(args[1]) if len(args) > 1 else 5
    Nsasview = int(args[2]) if len(args) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in par_args:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        # The distribution type is a name such as "gaussian"; every other
        # parameter is a number.
        set_pars[k] = v if k.endswith('type') else float(v)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.