source: sasmodels/compare.py @ af1d68c

core_shell_microgels costrafo411 magnetic_model release_v0.94 release_v0.95 ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since af1d68c was aa4946b, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

refactor so kernels are loaded via core.load_model

  • Property mode set to 100755
File size: 12.9 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[87985ca]8
[1726b21]9import numpy as np
[473183c]10
[87985ca]11from sasmodels.bumps_model import BumpsModel, plot_data, tic
[aa4946b]12from sasmodels import core
[87fce00]13from sasmodels import kerneldll
[87985ca]14from sasmodels.convert import revert_model
# Permit single-precision DLL kernels: eval_ctypes is called with the
# dtype chosen in compare(), which defaults to 'single'.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True

# Names of the available models: every sasmodels/models/*.py file under
# ROOT whose name starts with a letter, with ".py" stripped, taken in
# sorted path order.
ROOT = dirname(__file__)
MODELS = [
    basename(f)[:-3]
    for f in sorted(glob.glob(
        joinpath(ROOT, "sasmodels", "models", "[a-zA-Z]*.py")))
    ]
21
[8a20be5]22
def sasview_model(model_definition, **pars):
    """
    Build the SasView model object corresponding to *model_definition*.

    revert_model translates the sasmodels parameters in *pars* to SasView
    form and returns the SasView model name, which is looked up as
    sas.models.<name>.<name>.  Polydispersity entries (keys ending in
    _pd, _pd_n, _pd_nsigma or _pd_type) are stored in the model's
    dispersion table; every other key is passed to setParam.

    Raises ValueError if sas.models does not provide the model class.
    """
    modelname, pars = revert_model(model_definition, pars)
    sas = __import__('sas.models.' + modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # (parameter-name suffix, field of the dispersion entry it sets)
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
        )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
49
def randomize(p, v):
    """
    Return a random replacement for value *v* of the parameter named *p*.

    The kind of parameter is guessed from its name:

    * _pd_n, _pd_nsigma and _pd_type entries are returned unchanged;
    * names containing theta, phi or psi are orientations, drawn from
      [-180, 180), with their _pd widths drawn from [0, 45);
    * names containing sld are drawn from [-0.5, 10);
    * any other _pd width is drawn from [0, 1);
    * everything else is drawn between 0 and 2*v (0 and 2 when v is 0).
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v
    is_width = p.endswith('_pd')
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        if is_width:
            return 45*np.random.rand()
        return 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    if is_width:
        return np.random.rand()
    scale = 1 if v == 0 else v
    return 2*np.random.rand()*scale
[87985ca]73
def randomize_model(name, pars, seed=None):
    """
    Return (randomized copy of *pars*, seed used).

    numpy's global random generator is seeded with *seed*.  Each parameter
    is then passed through randomize() in sorted-name order, so the same
    seed reproduces the same values.  When *seed* is None a fresh seed is
    drawn and returned, so the run can be repeated with -random=<seed>.

    The capped_cylinder model has a constraint between its parameters, so
    radius and cap_radius are swapped when cap_radius < radius.
    """
    if seed is None:
        # randint takes integer bounds; 10**9 is the bound previously
        # written as the float 1e9.
        seed = np.random.randint(10**9)
    np.random.seed(seed)
    # Note: the sort guarantees order of calls to random number generator
    pars = dict((p, randomize(p, v)) for p, v in sorted(pars.items()))
    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
    return pars, seed
84
def parlist(pars):
    """Format *pars* as "name: value" lines, sorted by parameter name."""
    lines = ["%s: %s" % item for item in sorted(pars.items())]
    return "\n".join(lines)
87
def suppress_pd(pars):
    """
    Make the model monodisperse by zeroing every polydispersity width.

    Every key of *pars* that ends in "_pd" is set to 0 in place.  The
    _pd_n, _pd_nsigma and _pd_type entries are left alone.  Returns None.
    """
    width_keys = [key for key in pars if key.endswith("_pd")]
    for key in width_keys:
        pars[key] = 0
97
def eval_sasview(name, pars, data, Nevals=1):
    """
    Time *Nevals* evaluations of the SasView version of model *name*.

    Data with a qx_data attribute is evaluated on [qx_data, qy_data];
    otherwise it is evaluated on data.x.

    Returns (value, average_time).  value is the result of the last
    evaluation.  average_time is toc()*1000/Nevals, printed by the caller
    as milliseconds.

    Raises ValueError if Nevals < 1, since no evaluation would be made.
    """
    if Nevals < 1:
        raise ValueError("Nevals must be at least 1, got %r" % (Nevals,))
    model = sasview_model(name, **pars)
    # Pick the 1-D or 2-D input once, outside the timed loop.
    if hasattr(data, 'qx_data'):
        q_input = [data.qx_data, data.qy_data]
    else:
        q_input = data.x
    toc = tic()
    for _ in range(Nevals):
        value = model.evalDistribution(q_input)
    average_time = toc()*1000./Nevals
    return value, average_time
108
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0):
    """
    Time *Nevals* evaluations of the model on the OpenCL platform.

    The kernel is loaded with precision *dtype*.  If loading fails and
    *dtype* is not already 'single', the error is printed and the load is
    retried in single precision.  A failure in single precision propagates.
    *pars* and *cutoff* are passed to BumpsModel together with *data*.

    Returns (value, average_time).  value is the theory from the last
    evaluation.  average_time is toc()*1000/Nevals, printed by the caller
    as milliseconds.

    Raises ValueError if Nevals < 1, since no evaluation would be made.
    """
    if Nevals < 1:
        raise ValueError("Nevals must be at least 1, got %r" % (Nevals,))
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        if dtype == 'single':
            # Retrying with the same precision would only repeat the failure.
            raise
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single', platform="ocl")
    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./Nevals
    return value, average_time
124
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0):
    """
    Time *Nevals* evaluations of the model compiled as a DLL (ctypes).

    The kernel is loaded with precision *dtype*.  *pars* and *cutoff* are
    passed to BumpsModel together with *data*.

    Returns (value, average_time).  value is the theory from the last
    evaluation.  average_time is toc()*1000/Nevals, printed by the caller
    as milliseconds.

    Raises ValueError if Nevals < 1, since no evaluation would be made.
    """
    if Nevals < 1:
        raise ValueError("Nevals must be at least 1, got %r" % (Nevals,))
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./Nevals
    return value, average_time
134
def make_data(qmax, is2D, Nq=128, view='log'):
    """
    Create an empty data set on which to evaluate the models.

    1-D: Nq q values from qmax/1000 to qmax, log-spaced when view is
    'log' and linearly spaced otherwise.  index is slice(None, None), i.e.
    every point.

    2-D: built by empty_data2D from Nq points spanning [-qmax, qmax], with
    a 0.004 beam stop applied.  index is the complement of data.mask.

    Returns (data, index).
    """
    if not is2D:
        from sasmodels.bumps_model import empty_data1D
        if view == 'log':
            log_qmax = math.log10(qmax)
            q = np.logspace(log_qmax-3, log_qmax, Nq)
        else:
            q = np.linspace(0.001*qmax, qmax, Nq)
        return empty_data1D(q), slice(None, None)

    from sasmodels.bumps_model import empty_data2D, set_beam_stop
    data = empty_data2D(np.linspace(-qmax, qmax, Nq))
    set_beam_stop(data, 0.004)
    return data, ~data.mask
151
def _print_stats(label, err):
    """
    Print one line of error statistics for the array *err*, prefixed by *label*.

    Reports the max, median and 98th percentile of |err|, read off the
    sorted values without interpolation.  Also reports the rms of err and
    its signed mean ("zero-offset").  *err* must be non-empty.
    """
    sorted_err = np.sort(abs(err))
    p50 = int((len(err)-1)*0.50)
    p98 = int((len(err)-1)*0.98)
    fields = [
        "max:%.3e"%sorted_err[-1],
        "median:%.3e"%sorted_err[p50],
        "98%%:%.3e"%sorted_err[p98],
        "rms:%.3e"%np.sqrt(np.mean(err**2)),
        "zero-offset:%+.3e"%np.mean(err),
        ]
    print("%s %s"%(label, "  ".join(fields)))


def _plot_relerr_hist(relerr, comp, name):
    """
    Draw, in a new figure, a histogram of log10 |relerr| for opencl vs *comp*.

    log10 is undefined at zero, so exact zeros are drawn at half the
    smallest nonzero error.  If every error is zero, a message is printed
    and no figure is created.
    """
    import matplotlib.pyplot as plt
    v = np.abs(relerr)  # new array, so the caller's relerr is untouched
    nonzero = v[v != 0]
    if nonzero.size == 0:
        print("all relative errors are zero; no histogram to plot")
        return
    v[v == 0] = 0.5*np.min(nonzero)
    plt.figure()
    # NOTE(review): newer matplotlib removed normed=; use density=True there.
    plt.hist(np.log10(v), normed=1, bins=50)
    plt.xlabel('log10(err), err = | F(q) opencl - F(q) %s| / | F(q) %s |'
               % (comp, comp))
    plt.ylabel('P(err)')
    plt.title('Comparison of opencl and %s models for %s' % (comp, name))


def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Evaluate model *name* with OpenCL and/or on the CPU, then report and plot.

    *pars* are the starting parameters.  They are copied, so the caller's
    dict is not modified.  *set_pars* holds the key=value settings from the
    command line; these are applied after -random and -mono so that they
    always take effect.

    *Nocl* and *Ncpu* are the number of OpenCL and CPU evaluations to time;
    zero skips that calculation.  The CPU reference is sasview unless
    "-ctypes" is in *opts*.  *opts* is the list of "-option" strings
    described in USAGE.

    Prints the timing and total intensity of each calculation, plus error
    statistics when both are computed.  Unless "-noplot" is given, the
    results are shown with matplotlib.
    """
    if '-linear' in opts:
        view = 'linear'
    elif '-log' in opts:
        view = 'log'
    elif '-q4' in opts:
        view = 'q4'
    else:
        view = 'log'

    # "-key=value" options, e.g. {'-Nq': '128', '-random': '5'}
    opt_values = {}
    for opt in opts:
        parts = opt.split('=')
        if len(parts) == 2:
            opt_values[parts[0]] = parts[1]

    # Sort out data
    if '-exq' in opts:
        qmax = 10.0
    elif '-highq' in opts:
        qmax = 1.0
    elif '-midq' in opts:
        qmax = 0.2
    else:
        qmax = 0.05
    Nq = int(opt_values.get('-Nq', '128'))
    is2D = "-1d" not in opts
    data, index = make_data(qmax, is2D, Nq, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    pars = dict(pars)

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(name, pars, seed=seed)
        print("Randomize using -random=%i"%seed)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    # Explicit key=value settings go last so -random/-mono cannot override them.
    pars.update(set_pars)
    if '-pars' in opts:
        print("pars " + parlist(pars))

    model_definition = core.load_model_definition(name)
    # OpenCL calculation
    if Nocl > 0:
        ocl, ocl_time = eval_opencl(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl[index])))

    # ctypes/sasview calculation
    if Ncpu > 0 and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index])))
    elif Ncpu > 0:
        cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
        comp = "sasview"
        print("sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index])))

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        resid, relerr = np.zeros_like(ocl), np.zeros_like(ocl)
        resid[index] = (ocl - cpu)[index]
        relerr[index] = resid[index]/cpu[index]
        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid[index])
        _print_stats("|(ocl-%s)/%s|"%(comp, comp), relerr[index])

    # Plot if requested
    if '-noplot' in opts:
        return
    if Ncpu <= 0 and Nocl <= 0:
        # Nothing was computed, so there is nothing to plot.
        return
    import matplotlib.pyplot as plt
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_data(data, cpu, view=view)
        plt.title("%s t=%.1f ms"%(comp, cpu_time))
        cbar_title = "log I"
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_data(data, ocl, view=view)
        plt.title("opencl t=%.1f ms"%ocl_time)
        cbar_title = "log I"
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_data(data, err, view=errview)
        plt.title("max %s = %.3g"%(errstr, max(abs(err[index]))))
        cbar_title = errstr if errview == "linear" else "log "+errstr
    if is2D:
        h = plt.colorbar()
        h.ax.set_title(cbar_title)

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        _plot_relerr_hist(relerr[index], comp, name)

    plt.show()
263
[87985ca]264# ===========================================================================
265#
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the CPU model, SasView or ctypes (default=1)

Options (* for default):

    -plot*/-noplot show or suppress the plot of the model
    -single*/-double use single or double precision for the kernels
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128* set the number of Q points in the data set
    -1d/-2d* compute 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* test the cpu using ctypes or sasview
    -cutoff=1e-5*/value cutoff for including a point in polydispersity
    -pars/-nopars* print the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:

    %s
"""
299
# Flag options, accepted on the command line as "-name".
NAME_OPTIONS = {
    'plot', 'noplot',
    'single', 'double',
    'lowq', 'midq', 'highq', 'exq',
    '2d', '1d',
    'preset', 'random',
    'poly', 'mono',
    'sasview', 'ctypes',
    'nopars', 'pars',
    'rel', 'abs',
    'linear', 'log', 'q4',
    'hist', 'nohist',
}
# Options accepted as "-name=value".
# Note: random is both a name option and a value option
VALUE_OPTIONS = ['cutoff', 'random', 'Nq']
317
def get_demo_pars(name):
    """
    Return the starting parameter values for model *name*.

    The module sasmodels.models.<name> is imported.  Its ``demo`` dict is
    returned when it defines one.  Otherwise a dict is built from its
    ``parameters`` table, mapping entry 0 (name) to entry 2 (default value).
    """
    import sasmodels.models
    __import__('sasmodels.models.' + name)
    module = getattr(sasmodels.models, name)
    demo = getattr(module, 'demo', None)
    if demo is not None:
        return demo
    return dict((entry[0], entry[2]) for entry in module.parameters)
325
def main():
    """
    Command-line entry point; see USAGE for the arguments.

    Arguments starting with "-" are options.  Other arguments containing
    "=" are key=value parameter settings; the remaining ones are the model
    name followed by the optional OpenCL and sasview evaluation counts.

    Exits with status 1 on any of:
    - a missing or unknown model;
    - extra positional arguments;
    - an unknown option;
    - an unknown parameter or a non-numeric value.
    """
    argv = sys.argv[1:]
    opts = [arg for arg in argv if arg.startswith('-')]
    plain = [arg for arg in argv if not arg.startswith('-')]
    args = [arg for arg in plain if '=' not in arg]
    keyvals = [arg for arg in plain if '=' in arg]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE%models)
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
        sys.exit(1)
    if len(args) > 3:
        print("Unexpected arguments: %s"%(", ".join(args[3:])))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters.  Copy them so that filling
    # in defaults below does not modify the model module's own dict.
    name = args[0]
    pars = dict(get_demo_pars(name))

    Nopencl = int(args[1]) if len(args) > 1 else 5
    Nsasview = int(args[2]) if len(args) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in keyvals:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        if k.endswith('type'):
            # distribution type names such as "gaussian" stay as strings
            set_pars[k] = v
        else:
            try:
                set_pars[k] = float(v)
            except ValueError:
                print("%r is not a number for parameter %r"%(v, k))
                sys.exit(1)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
370
if __name__ == "__main__":
    # Run the command-line comparison only when executed as a script.
    main()
Note: See TracBrowser for help on using the repository browser.