source: sasmodels/compare.py @ bcd3aa3

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since bcd3aa3 was aa4946b, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

refactor so kernels are loaded via core.load_model

  • Property mode set to 100755
File size: 12.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8
9import numpy as np
10
11from sasmodels.bumps_model import BumpsModel, plot_data, tic
12from sasmodels import core
13from sasmodels import kerneldll
14from sasmodels.convert import revert_model
15kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
16
17# List of available models
18ROOT = dirname(__file__)
19MODELS = [basename(f)[:-3]
20          for f in sorted(glob.glob(joinpath(ROOT,"sasmodels","models","[a-zA-Z]*.py")))]
21
22
23def sasview_model(model_definition, **pars):
24    """
25    Load a sasview model given the model name.
26    """
27    # convert model parameters from sasmodel form to sasview form
28    #print "old",sorted(pars.items())
29    modelname, pars = revert_model(model_definition, pars)
30    #print "new",sorted(pars.items())
31    sas = __import__('sas.models.'+modelname)
32    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
33    if ModelClass is None:
34        raise ValueError("could not find model %r in sas.models"%modelname)
35    model = ModelClass()
36
37    for k,v in pars.items():
38        if k.endswith("_pd"):
39            model.dispersion[k[:-3]]['width'] = v
40        elif k.endswith("_pd_n"):
41            model.dispersion[k[:-5]]['npts'] = v
42        elif k.endswith("_pd_nsigma"):
43            model.dispersion[k[:-10]]['nsigmas'] = v
44        elif k.endswith("_pd_type"):
45            model.dispersion[k[:-8]]['type'] = v
46        else:
47            model.setParam(k, v)
48    return model
49
50def randomize(p, v):
51    """
52    Randomizing parameter.
53
54    Guess the parameter type from name.
55    """
56    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
57        return v
58    elif any(s in p for s in ('theta','phi','psi')):
59        # orientation in [-180,180], orientation pd in [0,45]
60        if p.endswith('_pd'):
61            return 45*np.random.rand()
62        else:
63            return 360*np.random.rand() - 180
64    elif 'sld' in p:
65        # sld in in [-0.5,10]
66        return 10.5*np.random.rand() - 0.5
67    elif p.endswith('_pd'):
68        # length pd in [0,1]
69        return np.random.rand()
70    else:
71        # values from 0 to 2*x for all other parameters
72        return 2*np.random.rand()*(v if v != 0 else 1)
73
74def randomize_model(name, pars, seed=None):
75    if seed is None:
76        seed = np.random.randint(1e9)
77    np.random.seed(seed)
78    # Note: the sort guarantees order of calls to random number generator
79    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
80    # The capped cylinder model has a constraint on its parameters
81    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
82        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
83    return pars, seed
84
85def parlist(pars):
86    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
87
88def suppress_pd(pars):
89    """
90    Suppress theta_pd for now until the normalization is resolved.
91
92    May also suppress complete polydispersity of the model to test
93    models more quickly.
94    """
95    for p in pars:
96        if p.endswith("_pd"): pars[p] = 0
97
98def eval_sasview(name, pars, data, Nevals=1):
99    model = sasview_model(name, **pars)
100    toc = tic()
101    for _ in range(Nevals):
102        if hasattr(data, 'qx_data'):
103            value = model.evalDistribution([data.qx_data, data.qy_data])
104        else:
105            value = model.evalDistribution(data.x)
106    average_time = toc()*1000./Nevals
107    return value, average_time
108
109def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0):
110    try:
111        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
112    except Exception,exc:
113        print exc
114        print "... trying again with single precision"
115        model = core.load_model(model_definition, dtype='single', platform="ocl")
116    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
117    toc = tic()
118    for _ in range(Nevals):
119        #pars['scale'] = np.random.rand()
120        problem.update()
121        value = problem.theory()
122    average_time = toc()*1000./Nevals
123    return value, average_time
124
125def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0):
126    model = core.load_model(model_definition, dtype=dtype, platform="dll")
127    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
128    toc = tic()
129    for _ in range(Nevals):
130        problem.update()
131        value = problem.theory()
132    average_time = toc()*1000./Nevals
133    return value, average_time
134
135def make_data(qmax, is2D, Nq=128, view='log'):
136    if is2D:
137        from sasmodels.bumps_model import empty_data2D, set_beam_stop
138        data = empty_data2D(np.linspace(-qmax, qmax, Nq))
139        set_beam_stop(data, 0.004)
140        index = ~data.mask
141    else:
142        from sasmodels.bumps_model import empty_data1D
143        if view == 'log':
144            qmax = math.log10(qmax)
145            q = np.logspace(qmax-3, qmax, Nq)
146        else:
147            q = np.linspace(0.001*qmax, qmax, Nq)
148        data = empty_data1D(q)
149        index = slice(None, None)
150    return data, index
151
152def compare(name, pars, Ncpu, Nocl, opts, set_pars):
153    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
154
155    opt_values = dict(split
156                      for s in opts for split in ((s.split('='),))
157                      if len(split) == 2)
158    # Sort out data
159    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
160    Nq = int(opt_values.get('-Nq', '128'))
161    is2D = not "-1d" in opts
162    data, index = make_data(qmax, is2D, Nq, view=view)
163
164
165    # modelling accuracy is determined by dtype and cutoff
166    dtype = 'double' if '-double' in opts else 'single'
167    cutoff = float(opt_values.get('-cutoff','1e-5'))
168
169    # randomize parameters
170    pars.update(set_pars)
171    if '-random' in opts or '-random' in opt_values:
172        seed = int(opt_values['-random']) if '-random' in opt_values else None
173        pars, seed = randomize_model(name, pars, seed=seed)
174        print "Randomize using -random=%i"%seed
175
176    # parameter selection
177    if '-mono' in opts:
178        suppress_pd(pars)
179    if '-pars' in opts:
180        print "pars",parlist(pars)
181
182    model_definition = core.load_model_definition(name)
183    # OpenCl calculation
184    if Nocl > 0:
185        ocl, ocl_time = eval_opencl(model_definition, pars, data,
186                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
187        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl[index]))
188        #print max(ocl), min(ocl)
189
190    # ctypes/sasview calculation
191    if Ncpu > 0 and "-ctypes" in opts:
192        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
193                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
194        comp = "ctypes"
195        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index]))
196    elif Ncpu > 0:
197        cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
198        comp = "sasview"
199        print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index]))
200
201    # Compare, but only if computing both forms
202    if Nocl > 0 and Ncpu > 0:
203        #print "speedup %.2g"%(cpu_time/ocl_time)
204        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
205        #cpu *= max(ocl/cpu)
206        resid, relerr = np.zeros_like(ocl), np.zeros_like(ocl)
207        resid[index] = (ocl - cpu)[index]
208        relerr[index] = resid[index]/cpu[index]
209        #bad = (relerr>1e-4)
210        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
211        def stats(label,err):
212            sorted_err = np.sort(abs(err))
213            p50 = int((len(err)-1)*0.50)
214            p98 = int((len(err)-1)*0.98)
215            data = [
216                "max:%.3e"%sorted_err[-1],
217                "median:%.3e"%sorted_err[p50],
218                "98%%:%.3e"%sorted_err[p98],
219                "rms:%.3e"%np.sqrt(np.mean(err**2)),
220                "zero-offset:%+.3e"%np.mean(err),
221                ]
222            print label,"  ".join(data)
223        stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid[index])
224        stats("|(ocl-%s)/%s|"%(comp,comp), relerr[index])
225
226    # Plot if requested
227    if '-noplot' in opts: return
228    import matplotlib.pyplot as plt
229    if Ncpu > 0:
230        if Nocl > 0: plt.subplot(131)
231        plot_data(data, cpu, view=view)
232        plt.title("%s t=%.1f ms"%(comp,cpu_time))
233        cbar_title = "log I"
234    if Nocl > 0:
235        if Ncpu > 0: plt.subplot(132)
236        plot_data(data, ocl, view=view)
237        plt.title("opencl t=%.1f ms"%ocl_time)
238        cbar_title = "log I"
239    if Ncpu > 0 and Nocl > 0:
240        plt.subplot(133)
241        if '-abs' in opts:
242            err,errstr,errview = resid, "abs err", "linear"
243        else:
244            err,errstr,errview = abs(relerr), "rel err", "log"
245        #err,errstr = ocl/cpu,"ratio"
246        plot_data(data, err, view=errview)
247        plt.title("max %s = %.3g"%(errstr, max(abs(err[index]))))
248        cbar_title = errstr if errview=="linear" else "log "+errstr
249    if is2D:
250        h = plt.colorbar()
251        h.ax.set_title(cbar_title)
252
253    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
254        plt.figure()
255        v = relerr[index]
256        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
257        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
258        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
259        plt.ylabel('P(err)')
260        plt.title('Comparison of single and double precision models for %s'%name)
261
262    plt.show()
263
264# ===========================================================================
265#
266USAGE="""
267usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
268
269Compare the speed and value for a model between the SasView original and the
270OpenCL rewrite.
271
272model is the name of the model to compare (see below).
273Nopencl is the number of times to run the OpenCL model (default=5)
274Nsasview is the number of times to run the Sasview model (default=1)
275
276Options (* for default):
277
278    -plot*/-noplot plots or suppress the plot of the model
279    -single*/-double uses double precision for comparison
280    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
281    -Nq=128 sets the number of Q points in the data set
282    -1d/-2d* computes 1d or 2d data
283    -preset*/-random[=seed] preset or random parameters
284    -mono/-poly* force monodisperse/polydisperse
285    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
286    -cutoff=1e-5*/value cutoff for including a point in polydispersity
287    -pars/-nopars* prints the parameter set or not
288    -abs/-rel* plot relative or absolute error
289    -linear/-log/-q4 intensity scaling
290    -hist/-nohist* plot histogram of relative error
291
292Key=value pairs allow you to set specific values to any of the model
293parameters.
294
295Available models:
296
297    %s
298"""
299
300NAME_OPTIONS = set([
301    'plot','noplot',
302    'single','double',
303    'lowq','midq','highq','exq',
304    '2d','1d',
305    'preset','random',
306    'poly','mono',
307    'sasview','ctypes',
308    'nopars','pars',
309    'rel','abs',
310    'linear', 'log', 'q4',
311    'hist','nohist',
312    ])
313VALUE_OPTIONS = [
314    # Note: random is both a name option and a value option
315    'cutoff', 'random', 'Nq',
316    ]
317
318def get_demo_pars(name):
319    import sasmodels.models
320    __import__('sasmodels.models.'+name)
321    model = getattr(sasmodels.models, name)
322    pars = getattr(model, 'demo', None)
323    if pars is None: pars = dict((p[0],p[2]) for p in model.parameters)
324    return pars
325
326def main():
327    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
328    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
329    models = "\n    ".join("%-15s"%v for v in MODELS)
330    if len(args) == 0:
331        print(USAGE%models)
332        sys.exit(1)
333    if args[0] not in MODELS:
334        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
335        sys.exit(1)
336
337    invalid = [o[1:] for o in opts
338               if o[1:] not in NAME_OPTIONS
339                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
340    if invalid:
341        print "Invalid options: %s"%(", ".join(invalid))
342        sys.exit(1)
343
344    # Get demo parameters from model definition, or use default parameters
345    # if model does not define demo parameters
346    name = args[0]
347    pars = get_demo_pars(name)
348
349    Nopencl = int(args[1]) if len(args) > 1 else 5
350    Nsasview = int(args[2]) if len(args) > 2 else 1
351
352    # Fill in default polydispersity parameters
353    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
354    for p in pds:
355        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
356        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
357
358    # Fill in parameters given on the command line
359    set_pars = {}
360    for arg in args[3:]:
361        k,v = arg.split('=')
362        if k not in pars:
363            # extract base name without distribution
364            s = set(p.split('_pd')[0] for p in pars)
365            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
366            sys.exit(1)
367        set_pars[k] = float(v) if not v.endswith('type') else v
368
369    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
370
371if __name__ == "__main__":
372    main()
Note: See TracBrowser for help on using the repository browser.