source: sasmodels/compare.py @ 1e11735

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 1e11735 was a503bfd, checked in by pkienzle, 10 years ago

move sasview→sasmodels conversion info to model definition

  • Property mode set to 100755
File size: 12.0 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[87985ca]8
[1726b21]9import numpy as np
[473183c]10
[87985ca]11from sasmodels.bumps_model import BumpsModel, plot_data, tic
[f786ff3]12from sasmodels import kernelcl, kerneldll
[87985ca]13from sasmodels.convert import revert_model
14
# Names of the available models: one entry per python file in
# sasmodels/models (relative to this script) whose name starts with a letter.
ROOT = dirname(__file__)
MODELS = []
for _model_file in sorted(glob.glob(
        joinpath(ROOT, "sasmodels", "models", "[a-zA-Z]*.py"))):
    # Drop the directory and the trailing ".py" to get the model name.
    MODELS.append(basename(_model_file)[:-3])
19
[8a20be5]20
def sasview_model(modelname, **pars):
    """
    Load a sasview model given the model name and set its parameters.

    *modelname* and *pars* are given in sasmodels form; they are converted
    to sasview form with *revert_model* before the model class is looked up
    in ``sas.models``.

    Parameters ending in ``_pd``, ``_pd_n``, ``_pd_nsigma`` and ``_pd_type``
    are stored in the model's dispersion table as ``width``, ``npts``,
    ``nsigmas`` and ``type`` respectively; every other parameter is set
    with ``model.setParam``.

    Raises *ValueError* if ``sas.models`` has no class for the model.
    """
    # convert model parameters from sasmodel form to sasview form
    modelname, pars = revert_model(modelname, pars)

    # __import__ returns the top-level package; the model class lives in
    # sas.models.<modelname>.<modelname>
    sas = __import__('sas.models.'+modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # (parameter suffix, dispersion field) pairs.  No suffix is a tail of
    # another, so at most one of them matches a given parameter name.
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            # not a polydispersity parameter
            model.setParam(key, value)
    return model
47
def load_opencl(modelname, dtype='single'):
    """
    Import ``sasmodels.models.<modelname>`` and load it with *kernelcl*.

    *dtype* is passed through to ``kernelcl.load_model``.  Returns whatever
    ``kernelcl.load_model`` returns.
    """
    # __import__ returns the top-level package, so the model definition
    # has to be fetched from package.models by name.
    package = __import__('sasmodels.models.'+modelname)
    definition = getattr(package.models, modelname, None)
    return kernelcl.load_model(definition, dtype=dtype)
53
def load_ctypes(modelname, dtype='single'):
    """
    Import ``sasmodels.models.<modelname>`` and load it with *kerneldll*.

    *dtype* is passed through to ``kerneldll.load_model``.  Returns whatever
    ``kerneldll.load_model`` returns.
    """
    # __import__ returns the top-level package, so the model definition
    # has to be fetched from package.models by name.
    package = __import__('sasmodels.models.'+modelname)
    definition = getattr(package.models, modelname, None)
    return kerneldll.load_model(definition, dtype=dtype)
59
def randomize(p, v):
    """
    Return a random value for parameter *p*, or *v* if *p* is left alone.

    The range is guessed from the parameter name:

    * names ending in ``_pd_n``, ``_pd_nsigma`` or ``_pd_type`` keep *v*
    * names containing ``theta``, ``phi`` or ``psi`` are drawn from
      [-180, 180), or from [0, 45) when the name ends in ``_pd``
    * names containing ``sld`` are drawn from [-0.5, 10)
    * other names ending in ``_pd`` are drawn from [0, 1)
    * everything else (length, scale, background) is drawn from [0, 200)

    Exactly one call to ``np.random.rand`` is made per randomized value.
    """
    # Polydispersity controls are not randomized.
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v

    is_width = p.endswith('_pd')

    # orientation in [-180,180], orientation pd in [0,45]
    if 'theta' in p or 'phi' in p or 'psi' in p:
        if is_width:
            return 45*np.random.rand()
        return 360*np.random.rand() - 180

    # sld in [-0.5,10]
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5

    # length pd in [0,1]
    if is_width:
        return np.random.rand()

    # length, scale, background in [0,200]
    return 200*np.random.rand()
83
def randomize_model(name, pars, seed=None):
    """
    Return a randomized copy of *pars* together with the seed used.

    If *seed* is None, a seed is drawn with ``np.random.randint(1e9)``.
    The numpy random number generator is then seeded with it, so calling
    again with the returned seed reproduces the same parameter set.

    Returns the pair *(randomized_pars, seed)*; *pars* is not modified.
    """
    if seed is None:
        seed = np.random.randint(1e9)
    np.random.seed(seed)

    # Note: the sort guarantees order of calls to random number generator
    randomized = {}
    for key, value in sorted(pars.items()):
        randomized[key] = randomize(key, value)

    # The capped cylinder model has a constraint on its parameters: swap
    # the two radii if the cap radius came out smaller than the radius.
    if name == 'capped_cylinder':
        cap_radius, radius = randomized['cap_radius'], randomized['radius']
        if cap_radius < radius:
            randomized['radius'] = cap_radius
            randomized['cap_radius'] = radius
    return randomized, seed
94
def parlist(pars):
    """
    Format *pars* as one ``name: value`` line per parameter, sorted by name.
    """
    lines = []
    for key, value in sorted(pars.items()):
        lines.append("%s: %s"%(key, value))
    return "\n".join(lines)
97
def suppress_pd(pars):
    """
    Set every polydispersity width in *pars* to zero, in place.

    Every key ending in ``_pd`` is set to 0; other keys, including the
    ``_pd_n``, ``_pd_nsigma`` and ``_pd_type`` controls, are left alone.
    This makes the model monodisperse so that it can be tested more quickly.
    """
    width_names = [key for key in pars if key.endswith("_pd")]
    for key in width_names:
        pars[key] = 0
107
def eval_sasview(name, pars, data, Nevals=1):
    """
    Evaluate the sasview version of model *name* on *data*, *Nevals* times.

    *data* is treated as 2D if it has a ``qx_data`` attribute, in which case
    the model is evaluated on ``[data.qx_data, data.qy_data]``; otherwise
    it is evaluated on ``data.x``.

    Returns *(value, average_time)*: the result of the last evaluation and
    the time per evaluation in milliseconds.
    """
    model = sasview_model(name, **pars)
    toc = tic()
    for _ in range(Nevals):
        if not hasattr(data, 'qx_data'):
            value = model.evalDistribution(data.x)
        else:
            value = model.evalDistribution([data.qx_data, data.qy_data])
    mean_ms = toc()*1000./Nevals
    return value, mean_ms
118
def eval_opencl(name, pars, data, dtype='single', Nevals=1, cutoff=0):
    """
    Evaluate the OpenCL version of model *name* on *data*, *Nevals* times.

    The model is loaded with the requested *dtype*.  If loading fails for
    any reason, the error is printed and loading is attempted once more
    with single precision; a failure of that second attempt propagates.

    *cutoff* and *pars* are passed through to *BumpsModel*.

    Returns *(value, average_time)*: the theory from the last evaluation
    and the time per evaluation in milliseconds.
    """
    try:
        model = load_opencl(name, dtype=dtype)
    except Exception as exc:
        # Single-argument print calls behave the same on python 2 and 3.
        print(exc)
        print("... trying again with single precision")
        model = load_opencl(name, dtype='single')
    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
    toc = tic()
    for _ in range(Nevals):
        # update() before theory() on every pass, so each pass is timed
        # as a full evaluation
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./Nevals
    return value, average_time
134
def eval_ctypes(name, pars, data, dtype='double', Nevals=1, cutoff=0):
    """
    Evaluate the ctypes version of model *name* on *data*, *Nevals* times.

    The model is loaded with the requested *dtype*; *cutoff* and *pars*
    are passed through to *BumpsModel*.

    Returns *(value, average_time)*: the theory from the last evaluation
    and the time per evaluation in milliseconds.
    """
    kernel = load_ctypes(name, dtype=dtype)
    problem = BumpsModel(data, kernel, cutoff=cutoff, **pars)
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    mean_ms = toc()*1000./Nevals
    return value, mean_ms
144
def make_data(qmax, is2D, Nq=128):
    """
    Create an empty data set to evaluate the models on.

    For 2D data, q runs over *Nq* linearly spaced points from -*qmax* to
    *qmax*, and a beam stop of 0.004 is set.  For 1D data, q runs over
    *Nq* log-spaced points covering the three decades below *qmax*.

    Returns *(data, index)* where *index* selects the points to compare:
    the unmasked points for 2D data, or every point for 1D data.
    """
    if not is2D:
        from sasmodels.bumps_model import empty_data1D
        log_qmax = math.log10(qmax)
        q = np.logspace(log_qmax-3, log_qmax, Nq)
        return empty_data1D(q), slice(None, None)

    from sasmodels.bumps_model import empty_data2D, set_beam_stop
    data = empty_data2D(np.linspace(-qmax, qmax, Nq))
    set_beam_stop(data, 0.004)
    return data, ~data.mask
157
def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Evaluate model *name* with OpenCL and on the CPU, and report the results.

    *pars* is the base parameter set; it is not modified.

    *Ncpu* is the number of CPU evaluations (ctypes if ``-ctypes`` is in
    *opts*, otherwise sasview) and *Nocl* the number of OpenCL evaluations.
    A calculation is skipped when its count is zero, and the two results
    are only compared when both counts are positive.

    *opts* is the list of command line options, each starting with ``-``.
    Options of the form ``-key=value`` supply a value (``-Nq``, ``-cutoff``,
    ``-random``).

    *set_pars* holds parameter values which override *pars*, including any
    randomized values.

    Timings, summed intensities and error statistics are printed.  Unless
    ``-noplot`` is in *opts* the results are plotted with matplotlib and
    the function blocks in ``plt.show()``.
    """
    # Split "-key=value" options into a lookup table keyed by "-key".
    opt_values = {}
    for opt in opts:
        parts = opt.split('=')
        if len(parts) == 2:
            opt_values[parts[0]] = parts[1]

    # Sort out data
    if '-highq' in opts:
        qmax = 1.0
    elif '-midq' in opts:
        qmax = 0.2
    else:
        qmax = 0.05
    Nq = int(opt_values.get('-Nq', '128'))
    is2D = "-1d" not in opts
    data, index = make_data(qmax, is2D, Nq)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(name, pars, seed=seed)
        print("Randomize using -random=%i"%seed)
    else:
        # Work on a copy so the caller's dictionary is left untouched by
        # the updates below.
        pars = dict(pars)
    pars.update(set_pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars %s"%parlist(pars))

    # OpenCl calculation
    if Nocl > 0:
        # The cutoff is passed explicitly so that -cutoff applies to the
        # OpenCL calculation as well as to the ctypes calculation.
        ocl, ocl_time = eval_opencl(name, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl[index])))

    # ctypes/sasview calculation
    if Ncpu > 0 and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(name, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index])))
    elif Ncpu > 0:
        cpu, cpu_time = eval_sasview(name, pars, data, Ncpu)
        comp = "sasview"
        print("sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index])))

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        # Points outside index keep a zero error.
        resid, relerr = np.zeros_like(ocl), np.zeros_like(ocl)
        resid[index] = (ocl - cpu)[index]
        relerr[index] = resid[index]/cpu[index]
        print("max(|ocl-%s|) %s"%(comp, max(abs(resid[index]))))
        print("max(|(ocl-%s)/%s|) %s"%(comp, comp, max(abs(relerr[index]))))
        # Value below which 98% of the relative errors fall.
        p98 = int(len(relerr[index])*0.98)
        print("98%% (|(ocl-%s)/%s|) < %s"
              %(comp, comp, np.sort(abs(relerr[index]))[p98]))

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_data(data, cpu, scale='log')
        plt.title("%s t=%.1f ms"%(comp, cpu_time))
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_data(data, ocl, scale='log')
        plt.title("opencl t=%.1f ms"%ocl_time)
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        err = resid if '-abs' in opts else relerr
        errstr = "abs err" if '-abs' in opts else "rel err"
        plot_data(data, err, scale='linear')
        plt.title("max %s = %.3g"%(errstr, max(abs(err[index]))))
    if is2D:
        plt.colorbar()

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        # np.abs returns a new array, so filling in the zeros below does
        # not alter relerr (for 1D data relerr[index] is a view).
        abs_err = np.abs(relerr[index])
        nonzero = abs_err[abs_err != 0]
        # If every error is exactly zero there is nothing to histogram
        # and no smallest nonzero error to stand in for the zeros.
        if nonzero.size > 0:
            # Replace exact zeros by half the smallest nonzero error so
            # that log10 is finite everywhere.
            abs_err[abs_err == 0] = 0.5*np.min(nonzero)
            plt.figure()
            # NOTE(review): recent matplotlib releases replace the 'normed'
            # keyword with 'density' -- confirm against the version in use.
            plt.hist(np.log10(abs_err), normed=1, bins=50)
            plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
            plt.ylabel('P(err)')
            plt.title('Comparison of single and double precision models for %s'%name)

    plt.show()
248
# ===========================================================================
# Command line interface: usage text and the recognized options.
#
# USAGE has a single %s placeholder for the list of available models.
USAGE = """
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the Sasview model (default=1)

Options (* for default):

    -plot*/-noplot plots or suppress the plot of the model
    -single*/-double uses double precision for comparison
    -lowq*/-midq/-highq use q values up to 0.05, 0.2 or 1.0
    -Nq=128 sets the number of Q points in the data set
    -1d/-2d* computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
    -cutoff=1e-5*/value cutoff for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot relative or absolute error
    -hist/-nohist* plot histogram of relative error

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:

    %s
"""

# Options given as a bare flag, without the leading '-'.
NAME_OPTIONS = set((
    'plot', 'noplot',
    'single', 'double',
    'lowq', 'midq', 'highq',
    '2d', '1d',
    'preset', 'random',
    'poly', 'mono',
    'sasview', 'ctypes',
    'nopars', 'pars',
    'rel', 'abs',
    'hist', 'nohist',
))

# Options given as -name=value, without the leading '-'.
# Note: random is both a name option and a value option
VALUE_OPTIONS = ['cutoff', 'random', 'Nq']
300
def main():
    """
    Command line entry point: parse ``sys.argv`` and call *compare*.

    Arguments starting with ``-`` are options.  Of the remaining arguments
    the first is the model name, arguments containing ``=`` are
    ``key=value`` parameter settings, and up to two further arguments are
    the evaluation counts *Nopencl* (default 5) and *Nsasview* (default 1).

    The usage message is printed if no model is given.  An unknown model,
    an invalid option, an unknown parameter name or an unexpected extra
    argument is reported.  In all of these cases the program exits with
    status 1.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE%models)
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
        sys.exit(1)

    # An option is valid if it is a known flag or a known -name=value form.
    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Separate key=value settings from the evaluation counts, so that the
    # counts really are optional, as the usage message shows them.
    values = [arg for arg in args[1:] if '=' in arg]
    counts = [arg for arg in args[1:] if '=' not in arg]
    if len(counts) > 2:
        print("Unexpected arguments: %s"%(", ".join(counts[2:])))
        sys.exit(1)
    Nopencl = int(counts[0]) if len(counts) > 0 else 5
    Nsasview = int(counts[1]) if len(counts) > 1 else 1

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    import sasmodels.models
    __import__('sasmodels.models.'+name)
    model = getattr(sasmodels.models, name)
    demo = getattr(model, 'demo', None)
    if demo is None:
        # presumably p[0] is the parameter name and p[2] its default
        # value -- TODO confirm against the model parameter table layout
        pars = dict((p[0], p[2]) for p in model.parameters)
    else:
        # Copy, so the defaults filled in below do not modify the demo
        # dictionary stored on the model module.
        pars = dict(demo)

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in values:
        # Split on the first '=' only, so the value may itself contain '='.
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        # The distribution type (key ending in 'type') is a string such as
        # "gaussian"; every other parameter is numeric.
        set_pars[k] = v if k.endswith('type') else float(v)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.