source: sasmodels/compare.py @ 7841376

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 7841376 was f173599, checked in by pkienzle, 10 years ago

Merge branch 'master' of github.com:sasview/sasmodels

  • Property mode set to 100755
File size: 12.2 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[87985ca]8
[1726b21]9import numpy as np
[473183c]10
[87985ca]11from sasmodels.bumps_model import BumpsModel, plot_data, tic
[87fce00]12try: from sasmodels import kernelcl
13except: from sasmodels import kerneldll as kernelcl
14from sasmodels import kerneldll
[87985ca]15from sasmodels.convert import revert_model
16
# Directory containing this script; the model sources live beneath it in
# sasmodels/models/.
ROOT = dirname(__file__)

# Names of the models available for comparison: every module in
# <ROOT>/sasmodels/models whose file name starts with a letter (so files
# such as __init__.py are skipped), with the ".py" extension removed.
_MODEL_PATTERN = joinpath(ROOT, "sasmodels", "models", "[a-zA-Z]*.py")
MODELS = [filename[:-len(".py")]
          for filename in map(basename, sorted(glob.glob(_MODEL_PATTERN)))]
21
[8a20be5]22
def sasview_model(modelname, **pars):
    """
    Build a SasView model instance for *modelname* configured with *pars*.

    The name and parameters are given in sasmodels form and are first
    translated to their sasview equivalents with revert_model.  Entries
    ending in ``_pd``, ``_pd_n``, ``_pd_nsigma`` and ``_pd_type`` are stored
    in the model's dispersion table (fields width, npts, nsigmas and type);
    every other entry is passed to ``model.setParam``.

    Raises ValueError if sas.models has no class matching the model name.
    """
    modelname, pars = revert_model(modelname, pars)
    package = __import__('sas.models.' + modelname)
    module = getattr(package.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models" % modelname)
    model = ModelClass()

    # Parameter-name suffix -> dispersion field.  No name can end in two of
    # these suffixes, so the order of the table does not matter.
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
49
def load_opencl(modelname, dtype='single'):
    """
    Build the kernel for sasmodels.models.<modelname> using ``kernelcl``.

    *dtype* selects the kernel precision.  Note that ``kernelcl`` is bound
    to the ``kerneldll`` backend when importing sasmodels.kernelcl fails.
    """
    package = __import__('sasmodels.models.' + modelname)
    model_module = getattr(package.models, modelname, None)
    return kernelcl.load_model(model_module, dtype=dtype)
55
def load_ctypes(modelname, dtype='single'):
    """
    Build the kernel for sasmodels.models.<modelname> using ``kerneldll``.

    *dtype* selects the kernel precision.
    """
    package = __import__('sasmodels.models.' + modelname)
    model_module = getattr(package.models, modelname, None)
    return kerneldll.load_model(model_module, dtype=dtype)
61
def randomize(p, v):
    """
    Return a random value for parameter *p*, or its current value *v*.

    The kind of parameter is guessed from its name, checked in this order:

    * ``*_pd_n``, ``*_pd_nsigma``, ``*_pd_type``: *v* is returned unchanged
    * name contains theta, phi or psi (orientation): a ``*_pd`` width in
      [0, 45), otherwise an angle in [-180, 180)
    * name contains sld: a value in [-0.5, 10)
    * any other ``*_pd``: a relative width in [0, 1)
    * anything else (lengths, scale, background): a value in [0, 200)
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v
    is_width = p.endswith('_pd')
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        return 45*np.random.rand() if is_width else 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    return np.random.rand() if is_width else 200*np.random.rand()
85
def randomize_model(name, pars, seed=None):
    """
    Return ``(random_pars, seed)`` with every parameter of *pars* randomized.

    Each value is drawn by randomize().  When *seed* is None a new seed is
    drawn from numpy's global generator; either way the seed that was used
    is returned so the run can be repeated.  For capped_cylinder the radius
    and cap_radius are swapped if needed so that cap_radius >= radius.
    """
    if seed is None:
        seed = np.random.randint(1e9)
    np.random.seed(seed)

    # Visit the parameters in sorted order so that a given seed always
    # assigns the same random numbers to the same parameters.
    new_pars = {}
    for key in sorted(pars):
        new_pars[key] = randomize(key, pars[key])

    # The capped cylinder model requires cap_radius >= radius.
    if name == 'capped_cylinder':
        cap_radius, radius = new_pars['cap_radius'], new_pars['radius']
        if cap_radius < radius:
            new_pars['radius'], new_pars['cap_radius'] = cap_radius, radius
    return new_pars, seed
96
def parlist(pars):
    """
    Format *pars* as "name: value" lines, one per parameter, sorted by name.
    """
    lines = ["%s: %s" % item for item in sorted(pars.items())]
    return "\n".join(lines)
99
def suppress_pd(pars):
    """
    Set every polydispersity width (each ``*_pd`` entry) in *pars* to 0.

    This makes the model monodisperse so it can be tested more quickly.
    Only the width entries are changed; the associated ``*_pd_n``,
    ``*_pd_nsigma`` and ``*_pd_type`` entries are left as they are.
    *pars* is modified in place and None is returned.
    """
    for p in pars:
        if p.endswith("_pd"):
            # Rebinding an existing key while iterating is safe: the size
            # of the dict does not change.
            pars[p] = 0
109
def eval_sasview(name, pars, data, Nevals=1):
    """
    Evaluate the SasView model *name* with parameters *pars* on *data*.

    Data with a ``qx_data`` attribute is treated as 2-D and evaluated at
    ``[data.qx_data, data.qy_data]``; anything else is evaluated at
    ``data.x``.  The model is evaluated *Nevals* times so the timing can be
    averaged.

    Returns (value, average_time).  value is the result of the last
    evaluation.  average_time is toc()*1000/Nevals, which the callers report
    in ms.

    Raises ValueError if Nevals < 1, since there would be no value to
    return and the average would divide by zero.
    """
    if Nevals < 1:
        raise ValueError("Nevals must be at least 1, got %r" % (Nevals,))
    model = sasview_model(name, **pars)
    # Choose the evaluation points once, outside the timed loop.
    if hasattr(data, 'qx_data'):
        q = [data.qx_data, data.qy_data]
    else:
        q = data.x
    toc = tic()
    for _ in range(Nevals):
        value = model.evalDistribution(q)
    average_time = toc()*1000./Nevals
    return value, average_time
120
def eval_opencl(name, pars, data, dtype='single', Nevals=1, cutoff=0):
    """
    Evaluate model *name* on *data* with the kernel from load_opencl.

    *pars* are passed to BumpsModel along with the polydispersity *cutoff*.
    The theory is computed *Nevals* times so the timing can be averaged.

    If building the kernel with the requested *dtype* fails, the error is
    printed and the kernel is rebuilt in single precision.  When single
    precision was requested in the first place, retrying would be
    pointless, so the original exception is re-raised instead.

    Returns (value, average_time).  value is the theory from the last
    evaluation.  average_time is toc()*1000/Nevals, which the callers report
    in ms.

    Raises ValueError if Nevals < 1.
    """
    if Nevals < 1:
        raise ValueError("Nevals must be at least 1, got %r" % (Nevals,))
    try:
        model = load_opencl(name, dtype=dtype)
    except Exception as exc:
        if dtype == 'single':
            raise
        print(exc)
        print("... trying again with single precision")
        model = load_opencl(name, dtype='single')
    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./Nevals
    return value, average_time
136
def eval_ctypes(name, pars, data, dtype='double', Nevals=1, cutoff=0):
    """
    Evaluate model *name* on *data* with the kernel from load_ctypes.

    *pars* are passed to BumpsModel along with the polydispersity *cutoff*.
    The theory is computed *Nevals* times so the timing can be averaged.

    Returns (value, average_time).  value is the theory from the last
    evaluation.  average_time is toc()*1000/Nevals, which the callers report
    in ms.

    Raises ValueError if Nevals < 1, since there would be no value to
    return and the average would divide by zero.
    """
    if Nevals < 1:
        raise ValueError("Nevals must be at least 1, got %r" % (Nevals,))
    model = load_ctypes(name, dtype=dtype)
    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./Nevals
    return value, average_time
146
def make_data(qmax, is2D, Nq=128):
    """
    Create an empty data set to evaluate the models on.

    1-D: Nq q values logarithmically spaced from qmax/1000 up to qmax.
    2-D: built by empty_data2D from Nq values linearly spaced in
    [-qmax, qmax] (presumably an Nq x Nq grid), with set_beam_stop applied
    at 0.004.

    Returns (data, index), where index selects the points to compare: every
    point in 1-D, and the points outside ``data.mask`` in 2-D.
    """
    if not is2D:
        from sasmodels.bumps_model import empty_data1D
        log_qmax = math.log10(qmax)
        q = np.logspace(log_qmax - 3, log_qmax, Nq)
        return empty_data1D(q), slice(None, None)

    from sasmodels.bumps_model import empty_data2D, set_beam_stop
    data = empty_data2D(np.linspace(-qmax, qmax, Nq))
    set_beam_stop(data, 0.004)
    return data, ~data.mask
159
def _plot_error_histogram(relerr, name):
    """
    Draw a histogram of log10(|relerr|) in a new figure.

    Exact matches (relerr == 0) have no logarithm, so they are placed at
    half the smallest nonzero error.  If every error is zero there is
    nothing to histogram, and a message is printed instead.
    """
    import matplotlib.pyplot as plt
    # Copy: relerr may be a view of the error array already passed to
    # plot_data, which must not be altered before plt.show().
    v = np.array(relerr)
    nonzero = np.abs(v[v != 0])
    if nonzero.size == 0:
        print("all relative errors are zero; no histogram to plot")
        return
    v[v == 0] = 0.5*np.min(nonzero)
    plt.figure()
    plt.hist(np.log10(np.abs(v)), density=True, bins=50)
    plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
    plt.ylabel('P(err)')
    plt.title('Comparison of single and double precision models for %s' % name)


def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Evaluate model *name* with OpenCL and/or a CPU engine and compare them.

    *Nocl* and *Ncpu* are the number of timed evaluations for the OpenCL
    kernel and for the CPU engine.  The CPU engine is sasview, or the ctypes
    kernel when ``-ctypes`` is in *opts*.  A count of 0 skips that engine.
    When both engines run, the absolute and relative differences are
    printed.  Results are plotted unless ``-noplot`` is given.

    *opts* holds the raw ``-name`` / ``-name=value`` command line options
    (see USAGE).  *set_pars* holds the ``key=value`` overrides, which are
    applied after any ``-random`` randomization.  *pars* itself is not
    modified; a copy is used.
    """
    # "-name=value" options; split on the first "=" only, so that a value
    # may itself contain "="
    opt_values = dict(s.split('=', 1) for s in opts if '=' in s)

    # Sort out data
    qmax = 1.0 if '-highq' in opts else (0.2 if '-midq' in opts else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    is2D = "-1d" not in opts
    data, index = make_data(qmax, is2D, Nq)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # Work on a copy so the caller's parameter dict is left untouched.
    pars = dict(pars)

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(name, pars, seed=seed)
        print("Randomize using -random=%i" % seed)
    pars.update(set_pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars " + parlist(pars))

    # OpenCL calculation.  It gets the same polydispersity cutoff as the
    # ctypes calculation, so both kernels include the same points.
    if Nocl > 0:
        ocl, ocl_time = eval_opencl(name, pars, data, dtype=dtype,
                                    cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f"
              % (ocl_time, sum(ocl[index])))

    # ctypes/sasview calculation
    if Ncpu > 0:
        if "-ctypes" in opts:
            comp = "ctypes"
            cpu, cpu_time = eval_ctypes(name, pars, data, dtype=dtype,
                                        cutoff=cutoff, Nevals=Ncpu)
        else:
            comp = "sasview"
            cpu, cpu_time = eval_sasview(name, pars, data, Ncpu)
        print("%s t=%.1f ms, intensity=%.0f"
              % (comp, cpu_time, sum(cpu[index])))

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        resid, relerr = np.zeros_like(ocl), np.zeros_like(ocl)
        resid[index] = (ocl - cpu)[index]
        relerr[index] = resid[index]/cpu[index]
        print("max(|ocl-%s|) %s" % (comp, max(abs(resid[index]))))
        print("max(|(ocl-%s)/%s|) %s"
              % (comp, comp, max(abs(relerr[index]))))
        p98 = int(len(relerr[index])*0.98)
        print("98%% (|(ocl-%s)/%s|) < %s"
              % (comp, comp, np.sort(abs(relerr[index]))[p98]))

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_data(data, cpu, scale='log')
        plt.title("%s t=%.1f ms" % (comp, cpu_time))
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_data(data, ocl, scale='log')
        plt.title("opencl t=%.1f ms" % ocl_time)
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr = resid, "abs err"
        else:
            err, errstr = relerr, "rel err"
        plot_data(data, err, scale='linear')
        plt.title("max %s = %.3g" % (errstr, max(abs(err[index]))))
    # A colour bar needs something to have been drawn.
    if is2D and (Ncpu > 0 or Nocl > 0):
        plt.colorbar()

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        _plot_error_histogram(relerr[index], name)

    plt.show()
250
# ===========================================================================
#
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val ...]

Compare the speed and value for a model between the SasView original (or the
ctypes kernel, with -ctypes) and the OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the CPU model, sasview or ctypes
(default=1)

Options (* for default):

    -plot*/-noplot plot or suppress the plot of the model
    -single*/-double use single or double precision for the comparison
    -lowq*/-midq/-highq use q values up to 0.05, 0.2 or 1.0
    -Nq=128 set the number of Q points in the data set
    -1d/-2d* compute 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* test the cpu calculation using ctypes or sasview
    -cutoff=1e-5*/value cutoff for including a point in polydispersity
    -pars/-nopars* print the parameter set or not
    -abs/-rel* plot absolute or relative error
    -hist/-nohist* plot histogram of relative error

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:

    %s
"""
285
# Flag options, accepted on the command line as -name.  Each group lists
# the alternatives for one setting.
_NAME_OPTION_GROUPS = (
    ('plot', 'noplot'),
    ('single', 'double'),
    ('lowq', 'midq', 'highq'),
    ('2d', '1d'),
    ('preset', 'random'),
    ('poly', 'mono'),
    ('sasview', 'ctypes'),
    ('nopars', 'pars'),
    ('rel', 'abs'),
    ('hist', 'nohist'),
)
NAME_OPTIONS = set(option for group in _NAME_OPTION_GROUPS for option in group)

# Options taking a value, accepted as -name=value.
# Note: random is both a name option and a value option.
VALUE_OPTIONS = ['cutoff', 'random', 'Nq']
302
def get_demo_pars(name):
    """
    Return the demonstration parameters for model *name* as a new dict.

    Uses the model module's ``demo`` dict when it defines one.  Otherwise
    each entry of the module's ``parameters`` list is mapped from its first
    element (the name) to its third element (the default value).
    """
    import sasmodels.models
    __import__('sasmodels.models.' + name)
    model = getattr(sasmodels.models, name)
    demo = getattr(model, 'demo', None)
    if demo is None:
        return dict((p[0], p[2]) for p in model.parameters)
    # Copy: main() adds default polydispersity settings to the returned
    # dict, which must not leak into the model module's own demo dict.
    return dict(demo)
310
def main():
    """
    Command line entry point: parse sys.argv (see USAGE) and run compare().

    Arguments starting with "-" are options.  The remaining arguments are:
    the model name, up to two evaluation counts (Nopencl, then Nsasview),
    and any number of key=value parameter overrides.  Exits with status 1
    after printing a message when the arguments are invalid.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
    models = "\n    ".join("%-15s" % v for v in MODELS)
    if not args:
        print(USAGE % models)
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s" % (args[0], models))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s=' % t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s" % (", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    pars = get_demo_pars(name)

    # The arguments after the model name are evaluation counts or key=value
    # overrides.  Telling them apart by the "=" lets the counts be omitted,
    # e.g. "compare.py sphere radius=20".
    counts = [a for a in args[1:] if '=' not in a]
    assignments = [a for a in args[1:] if '=' in a]
    if len(counts) > 2:
        print("Too many counts %s; expected [Nopencl] [Nsasview]"
              % " ".join(counts))
        sys.exit(1)
    try:
        Nopencl = int(counts[0]) if len(counts) > 0 else 5
        Nsasview = int(counts[1]) if len(counts) > 1 else 1
    except ValueError:
        print("Nopencl and Nsasview must be integers, got %s"
              % " ".join(counts))
        sys.exit(1)

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        pars.setdefault(p + "_pd_nsigma", 3)
        pars.setdefault(p + "_pd_type", "gaussian")

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in assignments:
        # split on the first "=" only, so a value may itself contain "="
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s" % (k, ", ".join(sorted(s))))
            sys.exit(1)
        if k.endswith('_pd_type'):
            # distribution types are names such as "gaussian", not numbers
            set_pars[k] = v
        else:
            try:
                set_pars[k] = float(v)
            except ValueError:
                print("%s=%r is not a number" % (k, v))
                sys.exit(1)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.