source: sasmodels/compare.py @ cb6ecf4

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since cb6ecf4 was 87c722e, checked in by pkienzle, 10 years ago

use sas instead of sans

  • Property mode set to 100755
File size: 14.4 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6
7import numpy as np
8
9from sasmodels.bumps_model import BumpsModel, plot_data, tic
10from sasmodels import gpu, dll
11from sasmodels.convert import revert_model
12
13
14def sasview_model(modelname, **pars):
15    """
16    Load a sasview model given the model name.
17    """
18    # convert model parameters from sasmodel form to sasview form
19    #print "old",sorted(pars.items())
20    modelname, pars = revert_model(modelname, pars)
21    #print "new",sorted(pars.items())
22    sas = __import__('sas.models.'+modelname)
23    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
24    if ModelClass is None:
25        raise ValueError("could not find model %r in sas.models"%modelname)
26    model = ModelClass()
27
28    for k,v in pars.items():
29        if k.endswith("_pd"):
30            model.dispersion[k[:-3]]['width'] = v
31        elif k.endswith("_pd_n"):
32            model.dispersion[k[:-5]]['npts'] = v
33        elif k.endswith("_pd_nsigma"):
34            model.dispersion[k[:-10]]['nsigmas'] = v
35        elif k.endswith("_pd_type"):
36            model.dispersion[k[:-8]]['type'] = v
37        else:
38            model.setParam(k, v)
39    return model
40
41def load_opencl(modelname, dtype='single'):
42    sasmodels = __import__('sasmodels.models.'+modelname)
43    module = getattr(sasmodels.models, modelname, None)
44    kernel = gpu.load_model(module, dtype=dtype)
45    return kernel
46
47def load_ctypes(modelname, dtype='single'):
48    sasmodels = __import__('sasmodels.models.'+modelname)
49    module = getattr(sasmodels.models, modelname, None)
50    kernel = dll.load_model(module, dtype=dtype)
51    return kernel
52
53def randomize(p, v):
54    """
55    Randomizing parameter.
56
57    Guess the parameter type from name.
58    """
59    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
60        return v
61    elif any(s in p for s in ('theta','phi','psi')):
62        # orientation in [-180,180], orientation pd in [0,45]
63        if p.endswith('_pd'):
64            return 45*np.random.rand()
65        else:
66            return 360*np.random.rand() - 180
67    elif 'sld' in p:
68        # sld in in [-0.5,10]
69        return 10.5*np.random.rand() - 0.5
70    elif p.endswith('_pd'):
71        # length pd in [0,1]
72        return np.random.rand()
73    else:
74        # length, scale, background in [0,200]
75        return 200*np.random.rand()
76
77def randomize_model(name, pars, seed=None):
78    if seed is None:
79        seed = np.random.randint(1e9)
80    np.random.seed(seed)
81    # Note: the sort guarantees order of calls to random number generator
82    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
83    # The capped cylinder model has a constraint on its parameters
84    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
85        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
86    return pars, seed
87
88def parlist(pars):
89    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
90
91def suppress_pd(pars):
92    """
93    Suppress theta_pd for now until the normalization is resolved.
94
95    May also suppress complete polydispersity of the model to test
96    models more quickly.
97    """
98    for p in pars:
99        if p.endswith("_pd"): pars[p] = 0
100
101def eval_sasview(name, pars, data, Nevals=1):
102    model = sasview_model(name, **pars)
103    toc = tic()
104    for _ in range(Nevals):
105        if hasattr(data, 'qx_data'):
106            value = model.evalDistribution([data.qx_data, data.qy_data])
107        else:
108            value = model.evalDistribution(data.x)
109    average_time = toc()*1000./Nevals
110    return value, average_time
111
112def eval_opencl(name, pars, data, dtype='single', Nevals=1, cutoff=0):
113    try:
114        model = load_opencl(name, dtype=dtype)
115    except Exception,exc:
116        print exc
117        print "... trying again with single precision"
118        model = load_opencl(name, dtype='single')
119    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
120    toc = tic()
121    for _ in range(Nevals):
122        #pars['scale'] = np.random.rand()
123        problem.update()
124        value = problem.theory()
125    average_time = toc()*1000./Nevals
126    return value, average_time
127
128def eval_ctypes(name, pars, data, dtype='double', Nevals=1, cutoff=0):
129    model = load_ctypes(name, dtype=dtype)
130    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
131    toc = tic()
132    for _ in range(Nevals):
133        problem.update()
134        value = problem.theory()
135    average_time = toc()*1000./Nevals
136    return value, average_time
137
138def make_data(qmax, is2D, Nq=128):
139    if is2D:
140        from sasmodels.bumps_model import empty_data2D, set_beam_stop
141        data = empty_data2D(np.linspace(-qmax, qmax, Nq))
142        set_beam_stop(data, 0.004)
143        index = ~data.mask
144    else:
145        from sasmodels.bumps_model import empty_data1D
146        qmax = math.log10(qmax)
147        data = empty_data1D(np.logspace(qmax-3, qmax, Nq))
148        index = slice(None, None)
149    return data, index
150
151def compare(name, pars, Ncpu, Ngpu, opts, set_pars):
152    opt_values = dict(split
153                      for s in opts for split in ((s.split('='),))
154                      if len(split) == 2)
155    # Sort out data
156    qmax = 1.0 if '-highq' in opts else (0.2 if '-midq' in opts else 0.05)
157    Nq = int(opt_values.get('-Nq', '128'))
158    is2D = not "-1d" in opts
159    data, index = make_data(qmax, is2D, Nq)
160
161
162    # modelling accuracy is determined by dtype and cutoff
163    dtype = 'double' if '-double' in opts else 'single'
164    cutoff = float(opt_values.get('-cutoff','1e-5'))
165
166    # randomize parameters
167    if '-random' in opts or '-random' in opt_values:
168        seed = int(opt_values['-random']) if '-random' in opt_values else None
169        pars, seed = randomize_model(name, pars, seed=seed)
170        print "Randomize using -random=%i"%seed
171    pars.update(set_pars)
172
173    # parameter selection
174    if '-mono' in opts:
175        suppress_pd(pars)
176    if '-pars' in opts:
177        print "pars",parlist(pars)
178
179    # OpenCl calculation
180    if Ngpu > 0:
181        gpu, gpu_time = eval_opencl(name, pars, data, dtype, Ngpu)
182        print "opencl t=%.1f ms, intensity=%.0f"%(gpu_time, sum(gpu[index]))
183        #print max(gpu), min(gpu)
184
185    # ctypes/sasview calculation
186    if Ncpu > 0 and "-ctypes" in opts:
187        cpu, cpu_time = eval_ctypes(name, pars, data, dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
188        comp = "ctypes"
189        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index]))
190    elif Ncpu > 0:
191        cpu, cpu_time = eval_sasview(name, pars, data, Ncpu)
192        comp = "sasview"
193        print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index]))
194
195    # Compare, but only if computing both forms
196    if Ngpu > 0 and Ncpu > 0:
197        #print "speedup %.2g"%(cpu_time/gpu_time)
198        #print "max |gpu/cpu|", max(abs(gpu/cpu)), "%.15g"%max(abs(gpu)), "%.15g"%max(abs(cpu))
199        #cpu *= max(gpu/cpu)
200        resid, relerr = np.zeros_like(gpu), np.zeros_like(gpu)
201        resid[index] = (gpu - cpu)[index]
202        relerr[index] = resid[index]/cpu[index]
203        #bad = (relerr>1e-4)
204        #print relerr[bad],cpu[bad],gpu[bad],data.qx_data[bad],data.qy_data[bad]
205        print "max(|ocl-%s|)"%comp, max(abs(resid[index]))
206        print "max(|(ocl-%s)/%s|)"%(comp,comp), max(abs(relerr[index]))
207        p98 = int(len(relerr[index])*0.98)
208        print "98%% (|(ocl-%s)/%s|) <"%(comp,comp), np.sort(abs(relerr[index]))[p98]
209
210
211    # Plot if requested
212    if '-noplot' in opts: return
213    import matplotlib.pyplot as plt
214    if Ncpu > 0:
215        if Ngpu > 0: plt.subplot(131)
216        plot_data(data, cpu, scale='log')
217        plt.title("%s t=%.1f ms"%(comp,cpu_time))
218    if Ngpu > 0:
219        if Ncpu > 0: plt.subplot(132)
220        plot_data(data, gpu, scale='log')
221        plt.title("opencl t=%.1f ms"%gpu_time)
222    if Ncpu > 0 and Ngpu > 0:
223        plt.subplot(133)
224        err = resid if '-abs' in opts else relerr
225        errstr = "abs err" if '-abs' in opts else "rel err"
226        #err,errstr = gpu/cpu,"ratio"
227        plot_data(data, err, scale='linear')
228        plt.title("max %s = %.3g"%(errstr, max(abs(err[index]))))
229    if is2D: plt.colorbar()
230
231    if Ncpu > 0 and Ngpu > 0 and '-hist' in opts:
232        plt.figure()
233        v = relerr[index]
234        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
235        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
236        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
237        plt.ylabel('P(err)')
238        plt.title('Comparison of single and double precision models for %s'%name)
239
240    plt.show()
241
242# ===========================================================================
243#
244USAGE="""
245usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
246
247Compare the speed and value for a model between the SasView original and the
248OpenCL rewrite.
249
250model is the name of the model to compare (see below).
251Nopencl is the number of times to run the OpenCL model (default=5)
252Nsasview is the number of times to run the Sasview model (default=1)
253
254Options (* for default):
255
256    -plot*/-noplot plots or suppress the plot of the model
257    -single*/-double uses double precision for comparison
258    -lowq*/-midq/-highq use q values up to 0.05, 0.2 or 1.0
259    -Nq=128 sets the number of Q points in the data set
260    -1d/-2d* computes 1d or 2d data
261    -preset*/-random[=seed] preset or random parameters
262    -mono/-poly* force monodisperse/polydisperse
263    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
264    -cutoff=1e-5*/value cutoff for including a point in polydispersity
265    -pars/-nopars* prints the parameter set or not
266    -abs/-rel* plot relative or absolute error
267    -hist/-nohist* plot histogram of relative error
268
269Key=value pairs allow you to set specific values to any of the model
270parameters.
271
272Available models:
273
274    %s
275"""
276
277NAME_OPTIONS = set([
278    'plot','noplot',
279    'single','double',
280    'lowq','midq','highq',
281    '2d','1d',
282    'preset','random',
283    'poly','mono',
284    'sasview','ctypes',
285    'nopars','pars',
286    'rel','abs',
287    'hist','nohist',
288    ])
289VALUE_OPTIONS = [
290    # Note: random is both a name option and a value option
291    'cutoff', 'random', 'Nq',
292    ]
293
294def main():
295    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
296    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
297    models = "\n    ".join("%-7s: %s"%(k,v.__name__.replace('_',' '))
298                           for k,v in sorted(MODELS.items()))
299    if len(args) == 0:
300        print(USAGE%models)
301        sys.exit(1)
302    if args[0] not in MODELS:
303        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
304        sys.exit(1)
305
306    invalid = [o[1:] for o in opts
307               if o[1:] not in NAME_OPTIONS
308                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
309    if invalid:
310        print "Invalid options: %s"%(", ".join(invalid))
311        sys.exit(1)
312
313    name, pars = MODELS[args[0]]()
314    Nopencl = int(args[1]) if len(args) > 1 else 5
315    Nsasview = int(args[2]) if len(args) > 2 else 1
316
317    # Fill in default polydispersity parameters
318    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
319    for p in pds:
320        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
321        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
322
323    # Fill in parameters given on the command line
324    set_pars = {}
325    for arg in args[3:]:
326        k,v = arg.split('=')
327        if k not in pars:
328            # extract base name without distribution
329            s = set(p.split('_pd')[0] for p in pars)
330            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
331            sys.exit(1)
332        set_pars[k] = float(v) if not v.endswith('type') else v
333
334    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
335
336# ===========================================================================
337#
338
339MODELS = {}
340def model(name):
341    def gather_function(fn):
342        MODELS[name] = fn
343        return fn
344    return gather_function
345
346
347@model('cyl')
348def cylinder():
349    pars = dict(
350        scale=1, background=0,
351        sld=6, solvent_sld=1,
352        #radius=5, length=20,
353        radius=260, length=290,
354        theta=30, phi=0,
355        radius_pd=.2, radius_pd_n=9,
356        length_pd=.2,length_pd_n=10,
357        theta_pd=15, theta_pd_n=45,
358        phi_pd=15, phi_pd_n=1,
359        )
360    return 'cylinder', pars
361
362@model('capcyl')
363def capped_cylinder():
364    pars = dict(
365        scale=1, background=0,
366        sld=6, solvent_sld=1,
367        radius=260, cap_radius=290, length=290,
368        theta=30, phi=15,
369        radius_pd=.2, radius_pd_n=1,
370        cap_radius_pd=.2, cap_radius_pd_n=1,
371        length_pd=.2, length_pd_n=10,
372        theta_pd=15, theta_pd_n=45,
373        phi_pd=15, phi_pd_n=1,
374        )
375    return 'capped_cylinder', pars
376
377
378@model('cscyl')
379def core_shell_cylinder():
380    pars = dict(
381        scale=1, background=0,
382        core_sld=6, shell_sld=8, solvent_sld=1,
383        radius=45, thickness=25, length=340,
384        theta=30, phi=15,
385        radius_pd=.2, radius_pd_n=1,
386        length_pd=.2, length_pd_n=10,
387        thickness_pd=.2, thickness_pd_n=10,
388        theta_pd=15, theta_pd_n=45,
389        phi_pd=15, phi_pd_n=1,
390        )
391    return 'core_shell_cylinder', pars
392
393
394@model('ell')
395def ellipsoid():
396    pars = dict(
397        scale=1, background=0,
398        sld=6, solvent_sld=1,
399        rpolar=50, requatorial=30,
400        theta=30, phi=15,
401        rpolar_pd=.2, rpolar_pd_n=15,
402        requatorial_pd=.2, requatorial_pd_n=15,
403        theta_pd=15, theta_pd_n=45,
404        phi_pd=15, phi_pd_n=1,
405        )
406    return 'ellipsoid', pars
407
408
409@model('ell3')
410def triaxial_ellipsoid():
411    pars = dict(
412        scale=1, background=0,
413        sld=6, solvent_sld=1,
414        theta=30, phi=15, psi=5,
415        req_minor=25, req_major=36, rpolar=50,
416        req_minor_pd=0, req_minor_pd_n=1,
417        req_major_pd=0, req_major_pd_n=1,
418        rpolar_pd=.2, rpolar_pd_n=30,
419        theta_pd=15, theta_pd_n=45,
420        phi_pd=15, phi_pd_n=1,
421        psi_pd=15, psi_pd_n=1,
422        )
423    return 'triaxial_ellipsoid', pars
424
425@model('sphpy')
426def spherepy():
427    pars = dict(
428        scale=1, background=0,
429        sld=6, solvent_sld=1,
430        radius=120,
431        radius_pd=.2, radius_pd_n=45,
432        )
433    return 'spherepy', pars
434
435@model('sph')
436def sphere():
437    pars = dict(
438        scale=1, background=0,
439        sld=6, solvent_sld=1,
440        radius=120,
441        radius_pd=.2, radius_pd_n=45,
442        )
443    return 'sphere', pars
444
445@model('lam')
446def lamellar():
447    pars = dict(
448        scale=1, background=0,
449        sld=6, solvent_sld=1,
450        thickness=40,
451        thickness_pd= 0.2, thickness_pd_n=40,
452        )
453    return 'lamellar', pars
454
455if __name__ == "__main__":
456    main()
Note: See TracBrowser for help on using the repository browser.