source: sasmodels/compare.py @ 7cf2cfd

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 7cf2cfd was 7cf2cfd, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor compare.py so that bumps/sasview not required for simple tests

  • Property mode set to 100755
File size: 15.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9
10import numpy as np
11
12ROOT = dirname(__file__)
13sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
14
15
16from sasmodels import core
17from sasmodels import kerneldll
18from sasmodels.data import plot_theory, empty_data1D, empty_data2D
19from sasmodels.direct_model import DirectModel
20from sasmodels.convert import revert_model
21kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
22
23# List of available models
24MODELS = [basename(f)[:-3]
25          for f in sorted(glob.glob(joinpath(ROOT,"sasmodels","models","[a-zA-Z]*.py")))]
26
27# CRUFT python 2.6
28if not hasattr(datetime.timedelta, 'total_seconds'):
29    def delay(dt):
30        """Return number date-time delta as number seconds"""
31        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
32else:
33    def delay(dt):
34        """Return number date-time delta as number seconds"""
35        return dt.total_seconds()
36
37
38def tic():
39    """
40    Timer function.
41
42    Use "toc=tic()" to start the clock and "toc()" to measure
43    a time interval.
44    """
45    then = datetime.datetime.now()
46    return lambda: delay(datetime.datetime.now() - then)
47
48
49def set_beam_stop(data, radius, outer=None):
50    """
51    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
52
53    Note: this function does not use the sasview package
54    """
55    if hasattr(data, 'qx_data'):
56        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
57        data.mask = (q < radius)
58        if outer is not None:
59            data.mask |= (q >= outer)
60    else:
61        data.mask = (data.x < radius)
62        if outer is not None:
63            data.mask |= (data.x >= outer)
64
65
66def sasview_model(model_definition, **pars):
67    """
68    Load a sasview model given the model name.
69    """
70    # convert model parameters from sasmodel form to sasview form
71    #print "old",sorted(pars.items())
72    modelname, pars = revert_model(model_definition, pars)
73    #print "new",sorted(pars.items())
74    sas = __import__('sas.models.'+modelname)
75    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
76    if ModelClass is None:
77        raise ValueError("could not find model %r in sas.models"%modelname)
78    model = ModelClass()
79
80    for k,v in pars.items():
81        if k.endswith("_pd"):
82            model.dispersion[k[:-3]]['width'] = v
83        elif k.endswith("_pd_n"):
84            model.dispersion[k[:-5]]['npts'] = v
85        elif k.endswith("_pd_nsigma"):
86            model.dispersion[k[:-10]]['nsigmas'] = v
87        elif k.endswith("_pd_type"):
88            model.dispersion[k[:-8]]['type'] = v
89        else:
90            model.setParam(k, v)
91    return model
92
93def randomize(p, v):
94    """
95    Randomizing parameter.
96
97    Guess the parameter type from name.
98    """
99    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
100        return v
101    elif any(s in p for s in ('theta','phi','psi')):
102        # orientation in [-180,180], orientation pd in [0,45]
103        if p.endswith('_pd'):
104            return 45*np.random.rand()
105        else:
106            return 360*np.random.rand() - 180
107    elif 'sld' in p:
108        # sld in in [-0.5,10]
109        return 10.5*np.random.rand() - 0.5
110    elif p.endswith('_pd'):
111        # length pd in [0,1]
112        return np.random.rand()
113    else:
114        # values from 0 to 2*x for all other parameters
115        return 2*np.random.rand()*(v if v != 0 else 1)
116
117def randomize_model(name, pars, seed=None):
118    if seed is None:
119        seed = np.random.randint(1e9)
120    np.random.seed(seed)
121    # Note: the sort guarantees order of calls to random number generator
122    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
123    # The capped cylinder model has a constraint on its parameters
124    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
125        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
126    return pars, seed
127
128def parlist(pars):
129    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
130
131def suppress_pd(pars):
132    """
133    Suppress theta_pd for now until the normalization is resolved.
134
135    May also suppress complete polydispersity of the model to test
136    models more quickly.
137    """
138    for p in pars:
139        if p.endswith("_pd"): pars[p] = 0
140
141def eval_sasview(model_definition, pars, data, Nevals=1):
142    from sas.models.qsmearing import smear_selection
143    model = sasview_model(model_definition, **pars)
144    smearer = smear_selection(data, model=model)
145    value = None  # silence the linter
146    toc = tic()
147    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
148        if hasattr(data, 'qx_data'):
149            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
150            index = ((~data.mask) & (~np.isnan(data.data))
151                     & (q >= data.qmin) & (q <= data.qmax))
152            if smearer is not None:
153                smearer.model = model  # because smear_selection has a bug
154                smearer.accuracy = data.accuracy
155                smearer.set_index(index)
156                value = smearer.get_value()
157            else:
158                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
159        else:
160            value = model.evalDistribution(data.x)
161            if smearer is not None:
162                value = smearer(value)
163    average_time = toc()*1000./Nevals
164    return value, average_time
165
166def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
167    try:
168        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
169    except Exception,exc:
170        print exc
171        print "... trying again with single precision"
172        model = core.load_model(model_definition, dtype='single', platform="ocl")
173    calculator = DirectModel(data, model, cutoff=cutoff)
174    value = None  # silence the linter
175    toc = tic()
176    for _ in range(max(Nevals, 1)):  # force at least one eval
177        value = calculator(**pars)
178    average_time = toc()*1000./Nevals
179    return value, average_time
180
181
182def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
183    model = core.load_model(model_definition, dtype=dtype, platform="dll")
184    calculator = DirectModel(data, model, cutoff=cutoff)
185    value = None  # silence the linter
186    toc = tic()
187    for _ in range(max(Nevals, 1)):  # force at least one eval
188        value = calculator(**pars)
189    average_time = toc()*1000./Nevals
190    return value, average_time
191
192
193def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
194    if is2D:
195        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
196        data.accuracy = accuracy
197        set_beam_stop(data, 0.004)
198        index = ~data.mask
199    else:
200        if view == 'log':
201            qmax = math.log10(qmax)
202            q = np.logspace(qmax-3, qmax, Nq)
203        else:
204            q = np.linspace(0.001*qmax, qmax, Nq)
205        data = empty_data1D(q, resolution=resolution)
206        index = slice(None, None)
207    return data, index
208
209def compare(name, pars, Ncpu, Nocl, opts, set_pars):
210    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
211
212    opt_values = dict(split
213                      for s in opts for split in ((s.split('='),))
214                      if len(split) == 2)
215    # Sort out data
216    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
217    Nq = int(opt_values.get('-Nq', '128'))
218    res = float(opt_values.get('-res', '0'))
219    accuracy = opt_values.get('-accuracy', 'Low')
220    is2D = not "-1d" in opts
221    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
222
223
224    # modelling accuracy is determined by dtype and cutoff
225    dtype = 'double' if '-double' in opts else 'single'
226    cutoff = float(opt_values.get('-cutoff','1e-5'))
227
228    # randomize parameters
229    #pars.update(set_pars)  # set value before random to control range
230    if '-random' in opts or '-random' in opt_values:
231        seed = int(opt_values['-random']) if '-random' in opt_values else None
232        pars, seed = randomize_model(name, pars, seed=seed)
233        print "Randomize using -random=%i"%seed
234    pars.update(set_pars)  # set value after random to control value
235
236    # parameter selection
237    if '-mono' in opts:
238        suppress_pd(pars)
239    if '-pars' in opts:
240        print "pars",parlist(pars)
241
242    model_definition = core.load_model_definition(name)
243    # OpenCl calculation
244    if Nocl > 0:
245        ocl, ocl_time = eval_opencl(model_definition, pars, data,
246                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
247        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
248        #print max(ocl), min(ocl)
249
250    # ctypes/sasview calculation
251    if Ncpu > 0 and "-ctypes" in opts:
252        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
253                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
254        comp = "ctypes"
255        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
256    elif Ncpu > 0:
257        try:
258            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
259            comp = "sasview"
260            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background'])
261            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
262        except ImportError:
263            Ncpu = 0
264
265    # Compare, but only if computing both forms
266    if Nocl > 0 and Ncpu > 0:
267        #print "speedup %.2g"%(cpu_time/ocl_time)
268        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
269        #cpu *= max(ocl/cpu)
270        resid = (ocl - cpu)
271        relerr = resid/cpu
272        #bad = (relerr>1e-4)
273        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
274        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
275        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
276
277    # Plot if requested
278    if '-noplot' in opts: return
279    import matplotlib.pyplot as plt
280    if Ncpu > 0:
281        if Nocl > 0: plt.subplot(131)
282        plot_theory(data, cpu, view=view, plot_data=False)
283        plt.title("%s t=%.1f ms"%(comp,cpu_time))
284        #cbar_title = "log I"
285    if Nocl > 0:
286        if Ncpu > 0: plt.subplot(132)
287        plot_theory(data, ocl, view=view, plot_data=False)
288        plt.title("opencl t=%.1f ms"%ocl_time)
289        #cbar_title = "log I"
290    if Ncpu > 0 and Nocl > 0:
291        plt.subplot(133)
292        if '-abs' in opts:
293            err,errstr,errview = resid, "abs err", "linear"
294        else:
295            err,errstr,errview = abs(relerr), "rel err", "log"
296        #err,errstr = ocl/cpu,"ratio"
297        plot_theory(data, None, resid=err, view=errview, plot_data=False)
298        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
299        #cbar_title = errstr if errview=="linear" else "log "+errstr
300    #if is2D:
301    #    h = plt.colorbar()
302    #    h.ax.set_title(cbar_title)
303
304    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
305        plt.figure()
306        v = relerr
307        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
308        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
309        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
310        plt.ylabel('P(err)')
311        plt.title('Comparison of single and double precision models for %s'%name)
312
313    plt.show()
314
315def _print_stats(label, err):
316    sorted_err = np.sort(abs(err))
317    p50 = int((len(err)-1)*0.50)
318    p98 = int((len(err)-1)*0.98)
319    data = [
320        "max:%.3e"%sorted_err[-1],
321        "median:%.3e"%sorted_err[p50],
322        "98%%:%.3e"%sorted_err[p98],
323        "rms:%.3e"%np.sqrt(np.mean(err**2)),
324        "zero-offset:%+.3e"%np.mean(err),
325        ]
326    print label,"  ".join(data)
327
328
329
330# ===========================================================================
331#
332USAGE="""
333usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
334
335Compare the speed and value for a model between the SasView original and the
336OpenCL rewrite.
337
338model is the name of the model to compare (see below).
339Nopencl is the number of times to run the OpenCL model (default=5)
340Nsasview is the number of times to run the Sasview model (default=1)
341
342Options (* for default):
343
344    -plot*/-noplot plots or suppress the plot of the model
345    -single*/-double uses double precision for comparison
346    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
347    -Nq=128 sets the number of Q points in the data set
348    -1d/-2d* computes 1d or 2d data
349    -preset*/-random[=seed] preset or random parameters
350    -mono/-poly* force monodisperse/polydisperse
351    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
352    -cutoff=1e-5* cutoff value for including a point in polydispersity
353    -pars/-nopars* prints the parameter set or not
354    -abs/-rel* plot relative or absolute error
355    -linear/-log/-q4 intensity scaling
356    -hist/-nohist* plot histogram of relative error
357    -res=0 sets the resolution width dQ/Q if calculating with resolution
358    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
359
360Key=value pairs allow you to set specific values to any of the model
361parameters.
362
363Available models:
364"""
365
366
367NAME_OPTIONS = set([
368    'plot','noplot',
369    'single','double',
370    'lowq','midq','highq','exq',
371    '2d','1d',
372    'preset','random',
373    'poly','mono',
374    'sasview','ctypes',
375    'nopars','pars',
376    'rel','abs',
377    'linear', 'log', 'q4',
378    'hist','nohist',
379    ])
380VALUE_OPTIONS = [
381    # Note: random is both a name option and a value option
382    'cutoff', 'random', 'Nq', 'res', 'accuracy',
383    ]
384
385def columnize(L, indent="", width=79):
386    column_width = max(len(w) for w in L) + 1
387    num_columns = (width - len(indent)) // column_width
388    num_rows = len(L) // num_columns
389    L = L + [""] * (num_rows*num_columns - len(L))
390    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
391    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
392             for row in zip(*columns)]
393    output = indent + ("\n"+indent).join(lines)
394    return output
395
396
397def get_demo_pars(name):
398    import sasmodels.models
399    __import__('sasmodels.models.'+name)
400    model = getattr(sasmodels.models, name)
401    pars = getattr(model, 'demo', None)
402    if pars is None: pars = dict((p[0],p[2]) for p in model.parameters)
403    return pars
404
405def main():
406    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
407    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
408    models = "\n    ".join("%-15s"%v for v in MODELS)
409    if len(args) == 0:
410        print(USAGE)
411        print(columnize(MODELS, indent="  "))
412        sys.exit(1)
413    if args[0] not in MODELS:
414        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
415        sys.exit(1)
416
417    invalid = [o[1:] for o in opts
418               if o[1:] not in NAME_OPTIONS
419                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
420    if invalid:
421        print "Invalid options: %s"%(", ".join(invalid))
422        sys.exit(1)
423
424    # Get demo parameters from model definition, or use default parameters
425    # if model does not define demo parameters
426    name = args[0]
427    pars = get_demo_pars(name)
428
429    Nopencl = int(args[1]) if len(args) > 1 else 5
430    Nsasview = int(args[2]) if len(args) > 2 else 1
431
432    # Fill in default polydispersity parameters
433    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
434    for p in pds:
435        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
436        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
437
438    # Fill in parameters given on the command line
439    set_pars = {}
440    for arg in args[3:]:
441        k,v = arg.split('=')
442        if k not in pars:
443            # extract base name without distribution
444            s = set(p.split('_pd')[0] for p in pars)
445            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
446            sys.exit(1)
447        set_pars[k] = float(v) if not v.endswith('type') else v
448
449    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
450
451if __name__ == "__main__":
452    main()
Note: See TracBrowser for help on using the repository browser.