source: sasmodels/sasmodels/compare.py @ 9404dd3

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9404dd3 was 9404dd3, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

python 3.x support

  • Property mode set to 100755
File size: 18.0 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[7cf2cfd]8import datetime
[5753e4e]9import traceback
[87985ca]10
[1726b21]11import numpy as np
[473183c]12
# Directory containing this file; the model definitions are looked up below it.
ROOT = dirname(__file__)
# NOTE(review): ROOT is this file's own directory (the sasmodels package
# directory), so this puts the modules next to this file on the path rather
# than the directory containing the sasmodels package -- confirm intent.
sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path


from . import core
from . import kerneldll
from . import generate
from .data import plot_theory, empty_data1D, empty_data2D
from .direct_model import DirectModel
from .convert import revert_model
# NOTE(review): presumably allows the dll (ctypes) kernels to be built in
# single precision as well as double; confirm against kerneldll.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True

# List of available models: the names of the ROOT/models/*.py files whose
# name starts with a letter, in sorted order, with ".py" removed.
MODELS = [basename(f)[:-3]
          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
[d547f16]28
# CRUFT python 2.6: timedelta.total_seconds() only exists from python 2.7 on,
# so fall back to computing the seconds by hand when it is missing.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the date-time delta *dt* as a number of seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the date-time delta *dt* as a number of seconds."""
        return 86400*dt.days + dt.seconds + dt.microseconds*1e-6
38
39
def tic():
    """
    Start a timer.

    Returns a function which, when called, reports the number of seconds
    elapsed since tic() was called::

        toc = tic()
        ...
        elapsed = toc()
    """
    start = datetime.datetime.now()

    def toc():
        return delay(datetime.datetime.now() - start)

    return toc
49
50
def set_beam_stop(data, radius, outer=None):
    """
    Replace data.mask so that points with |q| < *radius* are masked.

    If *outer* is given, points with |q| >= *outer* are masked as well,
    leaving only the annulus radius <= |q| < outer unmasked.  For 2-D data
    (anything with a qx_data attribute) |q| is computed from qx_data and
    qy_data; otherwise data.x is used.

    Note: this function does not use the sasview package
    """
    if hasattr(data, 'qx_data'):
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q = data.x
    mask = (q < radius)
    if outer is not None:
        mask |= (q >= outer)
    data.mask = mask
66
[8a20be5]67
def sasview_model(model_definition, **pars):
    """
    Build the sasview model instance matching *model_definition*.

    The sasmodels parameters in *pars* are first converted to their sasview
    names with revert_model(); the model class of the same name is then
    looked up in the sas.models package, instantiated, and given the
    parameter values.  Polydispersity parameters (``*_pd``, ``*_pd_n``,
    ``*_pd_nsigma``, ``*_pd_type``) are stored in model.dispersion; all
    others are set with model.setParam().

    Raises ValueError if sas.models has no class for the model.
    """
    # convert model parameters from sasmodel form to sasview form
    modelname, pars = revert_model(model_definition, pars)
    sas = __import__('sas.models.'+modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # parameter suffix -> field of the model.dispersion entry it sets
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
94
def randomize(p, v):
    """
    Return a random value for parameter *p* whose current value is *v*.

    The kind of parameter is guessed from its name:

    * polydispersity point counts, sigma widths and distribution types are
      returned unchanged;
    * orientation angles (theta/phi/psi) are drawn from [-180, 180], and
      their polydispersity widths from [0, 45];
    * scattering length densities (sld) are drawn from [-0.5, 10];
    * other polydispersity widths are drawn from [0, 1];
    * anything else is drawn from [0, 2*v], or [0, 2] when *v* is zero.
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v
    is_pd = p.endswith('_pd')
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        return 45*np.random.rand() if is_pd else 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    if is_pd:
        return np.random.rand()
    return 2*np.random.rand()*(v if v != 0 else 1)
[87985ca]118
def randomize_model(pars, seed=None):
    """
    Return a randomized copy of the parameter dictionary *pars*.

    Every value is replaced by randomize(name, value).  The global numpy
    random generator is seeded with *seed* first so that a run can be
    repeated; when *seed* is None a new seed in [0, 10**9) is drawn.

    Returns (new_pars, seed).
    """
    if seed is None:
        # randint documents integer bounds, so pass an int rather than 1e9.
        seed = np.random.randint(10**9)
    np.random.seed(seed)
    # Note: the sort guarantees order of calls to random number generator
    pars = dict((p, randomize(p, v)) for p, v in sorted(pars.items()))
    return pars, seed
127
def constrain_pars(model_definition, pars):
    """
    Adjust the parameter dictionary *pars* in place for the given model.

    * capped_cylinder / barbell: swap radius and cap_radius / bell_radius
      if needed so that the cap or bell radius is not smaller than radius.
    * guinier: limit rg so that I(q) stays above 1e-30 (the single
      precision cutoff) up to q = 1.0.
    * For the comparison with sasview: delete scale for teubner_strey and
      broad_peak, background for guinier, and both for models whose
      category is 'structure-factor'.
    """
    name = model_definition.name

    outer_radius = {'capped_cylinder': 'cap_radius',
                    'barbell': 'bell_radius'}.get(name)
    if outer_radius is not None and pars[outer_radius] < pars['radius']:
        pars['radius'], pars[outer_radius] = pars[outer_radius], pars['radius']

    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
    if name == 'guinier':
        q_max = 1.0  # high q maximum (0.2 would be the mid q maximum)
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)

    # These constraints are only needed for comparison to sasview
    dropped = []
    if name in ('teubner_strey', 'broad_peak'):
        dropped.append('scale')
    if name == 'guinier':
        dropped.append('background')
    if getattr(model_definition, 'category', None) == 'structure-factor':
        dropped += ['scale', 'background']
    for key in dropped:
        del pars[key]
149
[216a9e1]150
def parlist(pars):
    """Return the parameters as "name: value" lines, sorted by name."""
    return "\n".join(["%s: %s" % item for item in sorted(pars.items())])
153
def suppress_pd(pars):
    """
    Make the model monodisperse by setting every ``*_pd`` width to 0.

    *pars* is modified in place; nothing is returned.
    """
    pd_widths = [key for key in pars if key.endswith("_pd")]
    for key in pd_widths:
        pars[key] = 0
163
def eval_sasview(model_definition, pars, data, Nevals=1):
    """
    Evaluate the sasview version of a model on *data*.

    The model is built with sasview_model(model_definition, **pars), and
    resolution smearing is applied when smear_selection() returns a smearer
    for *data*.  The evaluation is repeated *Nevals* times (at least once)
    for timing.

    Returns (value, average_time), where average_time is the mean time per
    evaluation in milliseconds.

    Raises ImportError if the sas package cannot be imported.
    """
    # importing sas here so that the error message will be that sas failed to
    # import rather than the more obscure smear_selection not imported error
    import sas
    from sas.models.qsmearing import smear_selection
    model = sasview_model(model_definition, **pars)
    smearer = smear_selection(data, model=model)
    is2D = hasattr(data, 'qx_data')
    if is2D:
        # The selected q points do not change between repeats.
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
        index = ((~data.mask) & (~np.isnan(data.data))
                 & (q >= data.qmin) & (q <= data.qmax))
    Nruns = max(Nevals, 1)  # make sure there is at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(Nruns):
        if is2D:
            if smearer is not None:
                smearer.model = model  # because smear_selection has a bug
                smearer.accuracy = data.accuracy
                smearer.set_index(index)
                value = smearer.get_value()
            else:
                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
        else:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
    # Average over the evaluations actually run, so Nevals=0 cannot divide
    # by zero.
    average_time = toc()*1000./Nruns
    return value, average_time
191
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
    """
    Evaluate the OpenCL version of a model on *data*.

    The model is loaded at precision *dtype*.  If that fails and *dtype* is
    not already 'single', the error is printed and single precision is tried
    instead.  The evaluation, with polydispersity *cutoff*, is repeated
    *Nevals* times (at least once) for timing.

    Returns (value, average_time), where average_time is the mean time per
    evaluation in milliseconds.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        if dtype == 'single':
            # Retrying at the same precision would fail the same way.
            raise
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single', platform="ocl")
    calculator = DirectModel(data, model, cutoff=cutoff)
    Nruns = max(Nevals, 1)  # force at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(Nruns):
        value = calculator(**pars)
    # Average over the evaluations actually run, so Nevals=0 cannot divide
    # by zero.
    average_time = toc()*1000./Nruns
    return value, average_time
206
[7cf2cfd]207
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate the compiled dll version of a model on *data*.

    The model is loaded at precision *dtype* on the "dll" platform, and the
    evaluation, with polydispersity *cutoff*, is repeated *Nevals* times (at
    least once) for timing.

    Returns (value, average_time), where average_time is the mean time per
    evaluation in milliseconds.
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    calculator = DirectModel(data, model, cutoff=cutoff)
    Nruns = max(Nevals, 1)  # force at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(Nruns):
        value = calculator(**pars)
    # Average over the evaluations actually run, so Nevals=0 cannot divide
    # by zero.
    average_time = toc()*1000./Nruns
    return value, average_time
217
[7cf2cfd]218
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Build an empty data set on which the models are evaluated.

    2-D: a grid over [-qmax, qmax] with *Nq* points per axis, carrying the
    given *resolution* and *accuracy*, with a beam stop of radius 0.004.

    1-D: *Nq* points over [qmax/1000, qmax], spaced logarithmically when
    *view* is 'log' and linearly otherwise.

    Returns (data, index), where index selects the points outside the beam
    stop (2-D) or all points (1-D).
    """
    if not is2D:
        if view == 'log':
            log_qmax = math.log10(qmax)
            q = np.logspace(log_qmax-3, log_qmax, Nq)
        else:
            q = np.linspace(0.001*qmax, qmax, Nq)
        return empty_data1D(q, resolution=resolution), slice(None, None)

    data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
    data.accuracy = accuracy
    set_beam_stop(data, 0.004)
    return data, ~data.mask
234
def compare(name, pars, Ncomp, Nbase, opts, set_pars):
    """
    Evaluate model *name* with two calculators and report speed and accuracy.

    *pars* holds the starting parameter values and may be updated in place.
    It is processed in this order:

    * randomized when -random is given;
    * overridden by *set_pars*;
    * adjusted by constrain_pars();
    * made monodisperse when -mono is given.

    *opts* is the list of command line options described in USAGE.

    The base calculation runs *Nbase* times, with OpenCL, or with sasview
    when both -ctypes and -sasview are given.  The comparison calculation
    runs *Ncomp* times, with ctypes when -ctypes is given and otherwise with
    sasview.  A count of zero skips that calculation, and so does a failure
    to import sasview.  When both run, statistics of the absolute and
    relative differences are printed.  Unless -noplot is given, the results
    are then plotted.
    """
    model_definition = core.load_model_definition(name)

    view = ('linear' if '-linear' in opts
            else 'log' if '-log' in opts
            else 'q4' if '-q4' in opts
            else 'log')

    # Options of the form -name=value
    opt_values = dict(split
                      for s in opts for split in ((s.split('='),))
                      if len(split) == 2)

    # Sort out data
    qmax = (10.0 if '-exq' in opts
            else 1.0 if '-highq' in opts
            else 0.2 if '-midq' in opts
            else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = "-2d" in opts
    data, _ = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = ('longdouble' if '-quad' in opts
             else 'double' if '-double' in opts
             else 'single')
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(pars, seed=seed)
        print("Randomize using -random=%i"%seed)
    pars.update(set_pars)  # set value after random to control value
    constrain_pars(model_definition, pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars "+str(parlist(pars)))

    # Base calculation
    if Nbase > 0 and "-ctypes" in opts and "-sasview" in opts:
        try:
            # The base calculation is repeated Nbase times, not Ncomp.
            base, base_time = eval_sasview(model_definition, pars, data, Nbase)
            base_name = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))
        except ImportError:
            traceback.print_exc()
            Nbase = 0
    elif Nbase > 0:
        base, base_time = eval_opencl(model_definition, pars, data,
                                      dtype=dtype, cutoff=cutoff, Nevals=Nbase)
        base_name = "ocl"
        print("opencl t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))

    # Comparison calculation
    if Ncomp > 0 and "-ctypes" in opts:
        comp, comp_time = eval_ctypes(model_definition, pars, data,
                                      dtype=dtype, cutoff=cutoff, Nevals=Ncomp)
        comp_name = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
    elif Ncomp > 0:
        try:
            comp, comp_time = eval_sasview(model_definition, pars, data, Ncomp)
            comp_name = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
        except ImportError:
            traceback.print_exc()
            Ncomp = 0

    # Compare, but only if computing both forms
    if Nbase > 0 and Ncomp > 0:
        resid = (base - comp)
        relerr = resid/comp
        _print_stats("|%s-%s|"%(base_name, comp_name)+(" "*(3+len(comp_name))), resid)
        _print_stats("|(%s-%s)/%s|"%(base_name, comp_name, comp_name), relerr)

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncomp > 0:
        if Nbase > 0:
            plt.subplot(131)
        plot_theory(data, comp, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(comp_name, comp_time))
    if Nbase > 0:
        if Ncomp > 0:
            plt.subplot(132)
        plot_theory(data, base, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(base_name, base_time))
    if Ncomp > 0 and Nbase > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, plot_data=False)
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))

    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
        # Work on a copy so that relerr itself is left untouched.
        v = np.array(relerr)
        nonzero = (v != 0)
        if nonzero.any():
            # log10(0) is -inf, so show exact matches at half the smallest
            # nonzero error instead.
            v[~nonzero] = 0.5*np.min(np.abs(v[nonzero]))
            plt.figure()
            # NOTE(review): matplotlib 3.1 removed `normed`; newer versions
            # need density=True here.
            plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
            plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
            plt.ylabel('P(err)')
            plt.title('Comparison of single and double precision models for %s'%name)
        else:
            print("relative error is zero everywhere; no histogram plotted")

    plt.show()
369
[0763009]370def _print_stats(label, err):
371    sorted_err = np.sort(abs(err))
372    p50 = int((len(err)-1)*0.50)
373    p98 = int((len(err)-1)*0.98)
374    data = [
375        "max:%.3e"%sorted_err[-1],
376        "median:%.3e"%sorted_err[p50],
377        "98%%:%.3e"%sorted_err[p98],
378        "rms:%.3e"%np.sqrt(np.mean(err**2)),
379        "zero-offset:%+.3e"%np.mean(err),
380        ]
[9404dd3]381    print(label+"  ".join(data))
[0763009]382
383
384
[87985ca]385# ===========================================================================
386#
# Help text printed by main() when it is run without a model name; the list
# of models is printed after it.
USAGE="""
usage: compare.py model [Ncomp] [Nbase] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Ncomp is the number of times to run the comparison model: SasView, or ctypes
    with -ctypes (default=5)
Nbase is the number of times to run the base model: OpenCL, or SasView with
    both -ctypes and -sasview (default=1)

Options (* for default):

    -plot*/-noplot plot or suppress the plot of the model
    -single*/-double/-quad use single/double/quad precision for comparison
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* selects gpu:cpu, gpu:sasview, or sasview:cpu if both
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:
"""
420
[7cf2cfd]421
# Options given as a bare flag on the command line, e.g. -noplot
NAME_OPTIONS = set("""
    plot noplot
    single double quad
    lowq midq highq exq
    2d 1d
    preset random
    poly mono
    sasview ctypes
    nopars pars
    rel abs
    linear log q4
    hist nohist
    """.split())
# Options given as -name=value.
# Note: random is both a name option and a value option
VALUE_OPTIONS = "cutoff random Nq res accuracy".split()
439
def columnize(L, indent="", width=79):
    """
    Format the strings in *L* as a table of columns.

    Entries run down each column, then across.  Every line starts with
    *indent* and fits within *width* characters, unless a single entry is
    itself too long.  At least one column is always used.

    Returns the table as a newline-separated string, or "" if *L* is empty.
    """
    L = list(L)
    if not L:
        return ""
    column_width = max(len(w) for w in L) + 1
    # Each column takes column_width characters plus a one-space separator,
    # except the last, which has no separator after it.
    num_columns = max(1, (width - len(indent) + 1) // (column_width + 1))
    # Ceiling division so that every entry is kept.
    num_rows = -(-len(L) // num_columns)
    # Drop columns that would otherwise be entirely empty.
    num_columns = -(-len(L) // num_rows)
    L = L + [""] * (num_rows*num_columns - len(L))
    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    return indent + ("\n"+indent).join(lines)
450
451
def get_demo_pars(model_definition):
    """
    Return the demo parameter values for *model_definition*.

    Each entry of the model's parameter table contributes its third field
    (presumably the default value) under its name.  The model's 'demo'
    values then override these.
    """
    info = generate.make_info(model_definition)
    pars = {}
    for entry in info['parameters']:
        pars[entry[0]] = entry[2]
    pars.update(info['demo'])
    return pars
457
def main():
    """
    Command line entry point; see USAGE for the accepted arguments.

    Arguments are sorted into three kinds:

    * arguments starting with '-' are options;
    * ``key=value`` arguments set model parameters;
    * the rest are the model name and the optional counts passed to
      compare() as Ncomp and Nbase.

    Exits with status 1 on invalid input.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    popts = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' in arg]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' not in arg]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE)
        print(columnize(MODELS, indent="  "))
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
        sys.exit(1)
    if len(args) > 3:
        print("expected parameters: model Nopencl Nsasview")
        # Don't silently ignore the extra arguments.
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)

    Ncomp = int(args[1]) if len(args) > 1 else 5
    Nbase = int(args[2]) if len(args) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in popts:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        # The parameter name, not its value, says whether it is a string:
        # distribution types such as "gaussian" are kept as text.
        if k.endswith('_pd_type'):
            set_pars[k] = v
        else:
            try:
                set_pars[k] = float(v)
            except ValueError:
                print("%r invalid value for %r; expected a number"%(v, k))
                sys.exit(1)

    compare(name, pars, Ncomp, Nbase, opts, set_pars)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.