source: sasmodels/compare.py @ 9d76d29

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9d76d29 was 0763009, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

minor code cleanup

  • Property mode set to 100755
File size: 14.0 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[87985ca]8
[1726b21]9import numpy as np
[473183c]10
[346bc88]11from sasmodels.bumps_model import Model, Experiment, plot_theory, tic
[aa4946b]12from sasmodels import core
[87fce00]13from sasmodels import kerneldll
[87985ca]14from sasmodels.convert import revert_model
[750ffa5]15kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]16
# Names of the available models: one entry per python file under
# ROOT/sasmodels/models whose name starts with a letter.
ROOT = dirname(__file__)
_MODEL_FILES = sorted(
    glob.glob(joinpath(ROOT, "sasmodels", "models", "[a-zA-Z]*.py")))
# Drop the directory and the trailing ".py" to get the bare model name.
MODELS = [basename(path)[:-3] for path in _MODEL_FILES]
21
[8a20be5]22
def sasview_model(model_definition, **pars):
    """
    Create a sasview model instance and load the given parameters into it.

    The model definition and parameters are first passed through
    *revert_model* to convert them from sasmodels form to sasview form,
    yielding the sasview model name and its parameter dictionary.  The
    class named *modelname* is then looked up in the module
    ``sas.models.<modelname>``.

    Parameters ending in ``_pd``, ``_pd_n``, ``_pd_nsigma`` and ``_pd_type``
    are stored in the model's dispersion table; all others are set with
    ``setParam``.

    Raises *ValueError* if the model class cannot be found.
    """
    modelname, pars = revert_model(model_definition, pars)
    # __import__ on a dotted name returns the top-level package.
    sas = __import__('sas.models.'+modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # Parameter suffix -> field of the dispersion entry for the base name.
    # No name can end with more than one of these suffixes.
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
49
def randomize(p, v):
    """
    Return a random value for parameter *p* whose current value is *v*.

    The kind of parameter is guessed from its name:

    * names ending in ``_pd_n``, ``_pd_nsigma`` or ``_pd_type`` keep *v*;
    * names containing ``theta``, ``phi`` or ``psi`` are drawn from
      [-180, 180], or from [0, 45] if the name ends in ``_pd``;
    * names containing ``sld`` are drawn from [-0.5, 10];
    * other names ending in ``_pd`` are drawn from [0, 1];
    * anything else is drawn from 0 to 2*v (0 to 2 when *v* is zero).
    """
    # Polydispersity control settings are never randomized.
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v

    is_width = p.endswith('_pd')

    if any(tag in p for tag in ('theta', 'phi', 'psi')):
        if is_width:
            return 45*np.random.rand()
        return 360*np.random.rand() - 180

    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5

    if is_width:
        return np.random.rand()

    # A zero value would always scale to zero, so use unit scale instead.
    scale = v if v != 0 else 1
    return 2*np.random.rand()*scale
[87985ca]73
def randomize_model(name, pars, seed=None):
    """
    Return a randomized copy of *pars* together with the seed used.

    *name* is the model name, *pars* the parameter dictionary and *seed*
    the seed for numpy's global random number generator.  If *seed* is
    None a seed is drawn at random.  The global generator is reseeded as a
    side effect, so the same seed reproduces the same parameter set.

    Returns the tuple *(pars, seed)*; the input dictionary is not modified.
    """
    if seed is None:
        # randint is documented for integer bounds, so pass an int
        # rather than the float 1e9.
        seed = np.random.randint(10**9)
    np.random.seed(seed)
    # Note: the sort guarantees order of calls to random number generator
    pars = dict((p, randomize(p, v)) for p, v in sorted(pars.items()))
    # The capped cylinder model has a constraint on its parameters: swap
    # the two radii when the cap radius came out smaller than the radius.
    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
    return pars, seed
84
def parlist(pars):
    """
    Format the parameter dictionary as one ``name: value`` line per
    parameter, with the lines in sorted order.
    """
    lines = []
    for name, value in sorted(pars.items()):
        lines.append("%s: %s"%(name, value))
    return "\n".join(lines)
87
def suppress_pd(pars):
    """
    Set every polydispersity width in *pars* to zero, in place.

    Only keys ending in ``_pd`` are changed; the dictionary is modified
    directly and nothing is returned.

    Used to suppress theta_pd until the normalization is resolved, and to
    turn off polydispersity entirely so that models can be tested more
    quickly.
    """
    widths = [p for p in pars if p.endswith("_pd")]
    for p in widths:
        pars[p] = 0
97
def eval_sasview(name, pars, data, Nevals=1):
    """
    Evaluate the model with the sasview implementation and time it.

    *name* and *pars* are handed to :func:`sasview_model` to build the
    model.  *data* is the data set to evaluate on; it is treated as 2D if
    it has a ``qx_data`` attribute and as 1D (using ``data.x``) otherwise.
    *Nevals* is the number of timed evaluations; at least one evaluation is
    always performed.

    Returns *(value, average_time)* where *value* is the result of the last
    evaluation and *average_time* is the elapsed time per evaluation
    (the callers report it in milliseconds).
    """
    from sas.models.qsmearing import smear_selection
    model = sasview_model(name, **pars)
    smearer = smear_selection(data, model=model)
    # The loop below always runs at least once, so average over the number
    # of evaluations actually performed rather than the requested count;
    # dividing by Nevals fails for 0 and is wrong for negative values.
    nevals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(nevals):
        if hasattr(data, 'qx_data'):
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            # Keep unmasked, non-NaN points inside the requested q range.
            index = ((~data.mask) & (~np.isnan(data.data))
                     & (q >= data.qmin) & (q <= data.qmax))
            if smearer is not None:
                smearer.model = model  # because smear_selection has a bug
                smearer.accuracy = data.accuracy
                smearer.set_index(index)
                value = smearer.get_value()
            else:
                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
        else:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
    average_time = toc()*1000./nevals
    return value, average_time
122
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
    """
    Evaluate the model on the OpenCL platform and time it.

    The model is loaded with ``core.load_model(..., platform="ocl")`` using
    the requested *dtype*.  If loading fails, the error is printed and the
    load is retried with single precision.  *pars*, *data* and *cutoff* are
    used to build the ``Experiment`` that is evaluated.  *Nevals* is the
    number of timed evaluations; at least one is always performed.

    Returns *(value, average_time)* where *value* is the theory from the
    last evaluation and *average_time* is the elapsed time per evaluation
    (the callers report it in milliseconds).
    """
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        # Deliberately broad: any load failure triggers the single
        # precision fallback, and a second failure propagates.
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single', platform="ocl")
    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
    # Average over the number of evaluations actually performed; dividing
    # by Nevals fails for 0 and is wrong for negative values.
    nevals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(nevals):  # force at least one eval
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./nevals
    return value, average_time
139
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate the model on the "dll" platform and time it.

    The model is loaded with ``core.load_model(..., platform="dll")`` using
    the requested *dtype*.  *pars*, *data* and *cutoff* are used to build
    the ``Experiment`` that is evaluated.  *Nevals* is the number of timed
    evaluations; at least one is always performed.

    Returns *(value, average_time)* where *value* is the theory from the
    last evaluation and *average_time* is the elapsed time per evaluation
    (the callers report it in milliseconds).
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
    # Average over the number of evaluations actually performed; dividing
    # by Nevals fails for 0 and is wrong for negative values.
    nevals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(nevals):  # force at least one eval
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./nevals
    return value, average_time
150
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Create an empty data set to evaluate the models on.

    For 2D data, *Nq* points from -*qmax* to *qmax* are passed to
    ``empty_data2D``, *accuracy* is stored on the data set and a beam stop
    of 0.004 is applied.  For 1D data, *Nq* points are generated up to
    *qmax*: spaced logarithmically over three decades when *view* is
    'log', otherwise spaced linearly starting at 0.001*qmax.

    Returns *(data, index)*, where *index* is the inverse of the data mask
    for 2D data and a slice selecting everything for 1D data.
    """
    if not is2D:
        from sasmodels.bumps_model import empty_data1D
        if view == 'log':
            top = math.log10(qmax)
            q = np.logspace(top-3, top, Nq)
        else:
            q = np.linspace(0.001*qmax, qmax, Nq)
        return empty_data1D(q, resolution=resolution), slice(None, None)

    from sasmodels.bumps_model import empty_data2D, set_beam_stop
    data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
    data.accuracy = accuracy
    set_beam_stop(data, 0.004)
    return data, ~data.mask
168
def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Evaluate model *name* with the OpenCL and CPU calculators, then print
    and plot the comparison.

    *pars* is the parameter dictionary.  It is updated in place with
    *set_pars* before use and replaced by randomized values when a
    ``-random`` option is given.  *Ncpu* and *Nocl* are the number of
    evaluations for the CPU (sasview, or ctypes with ``-ctypes``) and
    OpenCL calculations; a count of zero skips that calculation.  *opts*
    is the list of command line options, each including its leading ``-``.

    Difference statistics are printed only when both calculations are run.
    Plots are shown unless ``-noplot`` is in *opts*.
    """
    view = ('linear' if '-linear' in opts
            else 'log' if '-log' in opts
            else 'q4' if '-q4' in opts
            else 'log')

    # Options of the form -key=value, keyed by "-key".  Options which do
    # not split into exactly one key and one value are ignored.
    pairs = (s.split('=') for s in opts)
    opt_values = dict(pair for pair in pairs if len(pair) == 2)

    # Sort out data
    qmax = (10.0 if '-exq' in opts
            else 1.0 if '-highq' in opts
            else 0.2 if '-midq' in opts
            else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = "-1d" not in opts
    data, _ = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # randomize parameters
    pars.update(set_pars)
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(name, pars, seed=seed)
        print("Randomize using -random=%i"%seed)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars " + parlist(pars))

    have_cpu = Ncpu > 0
    have_ocl = Nocl > 0

    model_definition = core.load_model_definition(name)
    # OpenCl calculation
    if have_ocl:
        ocl, ocl_time = eval_opencl(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl)))

    # ctypes/sasview calculation
    if have_cpu and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))
    elif have_cpu:
        cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
        comp = "sasview"
        print("sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))

    # Compare, but only if computing both forms
    if have_ocl and have_cpu:
        resid = (ocl - cpu)
        relerr = resid/cpu
        # Pad the first label so both statistics lines have equal width.
        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
        _print_stats("|(ocl-%s)/%s|"%(comp, comp), relerr)

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    # Title for the colorbar of the last plot drawn; stays None when
    # nothing was plotted so that no colorbar is requested.
    cbar_title = None
    if have_cpu:
        if have_ocl:
            plt.subplot(131)
        plot_theory(data, cpu, view=view)
        plt.title("%s t=%.1f ms"%(comp, cpu_time))
        cbar_title = "log I"
    if have_ocl:
        if have_cpu:
            plt.subplot(132)
        plot_theory(data, ocl, view=view)
        plt.title("opencl t=%.1f ms"%ocl_time)
        cbar_title = "log I"
    if have_cpu and have_ocl:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, err, view=errview)
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
        cbar_title = errstr if errview == "linear" else "log "+errstr
    if is2D and cbar_title is not None:
        h = plt.colorbar()
        h.ax.set_title(cbar_title)

    if have_cpu and have_ocl and '-hist' in opts:
        plt.figure()
        v = relerr
        # Replace exact zeros by half the smallest non-zero error so that
        # log10 below is finite.
        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
        # NOTE(review): newer matplotlib releases replaced `normed` with
        # `density`; confirm against the matplotlib version in use.
        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
        plt.ylabel('P(err)')
        plt.title('Comparison of single and double precision models for %s'%name)

    plt.show()
269
[0763009]270def _print_stats(label, err):
271    sorted_err = np.sort(abs(err))
272    p50 = int((len(err)-1)*0.50)
273    p98 = int((len(err)-1)*0.98)
274    data = [
275        "max:%.3e"%sorted_err[-1],
276        "median:%.3e"%sorted_err[p50],
277        "98%%:%.3e"%sorted_err[p98],
278        "rms:%.3e"%np.sqrt(np.mean(err**2)),
279        "zero-offset:%+.3e"%np.mean(err),
280        ]
281    print label,"  ".join(data)
282
283
284
# ===========================================================================
# Command line interface: help text and the recognized option names.
#
# USAGE has a single %s slot which receives the list of available models.
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the Sasview model (default=1)

Options (* for default):

    -plot*/-noplot plots or suppress the plot of the model
    -single*/-double uses double precision for comparison
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d/-2d* computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot relative or absolute error
    -linear/-log/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:

    %s
"""

# Options given as a bare "-name", grouped by the alternatives they select.
NAME_OPTIONS = set(
    flag
    for group in (
        ('plot', 'noplot'),
        ('single', 'double'),
        ('lowq', 'midq', 'highq', 'exq'),
        ('2d', '1d'),
        ('preset', 'random'),
        ('poly', 'mono'),
        ('sasview', 'ctypes'),
        ('nopars', 'pars'),
        ('rel', 'abs'),
        ('linear', 'log', 'q4'),
        ('hist', 'nohist'),
    )
    for flag in group)

# Options given as "-name=value".
# Note: random is both a name option and a value option
VALUE_OPTIONS = ['cutoff', 'random', 'Nq', 'res', 'accuracy']
340
def get_demo_pars(name):
    """
    Return the demonstration parameters for the model called *name*.

    The module ``sasmodels.models.<name>`` is imported and its ``demo``
    attribute is returned when present.  Otherwise a dictionary is built
    from ``model.parameters``, mapping item 0 of each entry to item 2.
    """
    import sasmodels.models
    __import__('sasmodels.models.'+name)
    model = getattr(sasmodels.models, name)
    demo = getattr(model, 'demo', None)
    if demo is not None:
        return demo
    return dict((p[0], p[2]) for p in model.parameters)
348
def main():
    """
    Command line entry point: parse ``sys.argv`` and run :func:`compare`.

    Arguments starting with ``-`` are options.  Of the remaining
    arguments, those containing ``=`` are parameter settings and the rest
    are positional: the model name, then the optional OpenCL and sasview
    evaluation counts.  Parameter settings may therefore be given whether
    or not the counts are present.

    Prints a message and exits with status 1 if no model is given, or if
    the model, an option, a parameter name or an extra positional argument
    is not recognized.
    """
    argv = sys.argv[1:]
    opts = [arg for arg in argv if arg.startswith('-')]
    positional = [arg for arg in argv if not arg.startswith('-')]
    # Separate key=value settings from the model name and counts so that
    # the counts remain optional, as the usage text describes.
    args = [arg for arg in positional if '=' not in arg]
    values = [arg for arg in positional if '=' in arg]

    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE%models)
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
        sys.exit(1)
    if len(args) > 3:
        print("Unexpected arguments: %s"%(", ".join(args[3:])))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    pars = get_demo_pars(name)

    Nopencl = int(args[1]) if len(args) > 1 else 5
    Nsasview = int(args[2]) if len(args) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in values:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        # Distribution types (keys ending in "type") are names such as
        # "gaussian" and stay strings; every other value is numeric.
        set_pars[k] = v if k.endswith('type') else float(v)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.