source: sasmodels/sasmodels/compare.py @ 73a3e22

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 73a3e22 was 73a3e22, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

make -1d the default for compare.sh

  • Property mode set to 100755
File size: 15.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9
10import numpy as np
11
12ROOT = dirname(__file__)
13sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
14
15
16from . import core
17from . import kerneldll
18from . import models
19from .data import plot_theory, empty_data1D, empty_data2D
20from .direct_model import DirectModel
21from .convert import revert_model
22kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
23
24# List of available models
25MODELS = [basename(f)[:-3]
26          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
27
28# CRUFT python 2.6
29if not hasattr(datetime.timedelta, 'total_seconds'):
30    def delay(dt):
31        """Return number date-time delta as number seconds"""
32        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
33else:
34    def delay(dt):
35        """Return number date-time delta as number seconds"""
36        return dt.total_seconds()
37
38
39def tic():
40    """
41    Timer function.
42
43    Use "toc=tic()" to start the clock and "toc()" to measure
44    a time interval.
45    """
46    then = datetime.datetime.now()
47    return lambda: delay(datetime.datetime.now() - then)
48
49
50def set_beam_stop(data, radius, outer=None):
51    """
52    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
53
54    Note: this function does not use the sasview package
55    """
56    if hasattr(data, 'qx_data'):
57        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
58        data.mask = (q < radius)
59        if outer is not None:
60            data.mask |= (q >= outer)
61    else:
62        data.mask = (data.x < radius)
63        if outer is not None:
64            data.mask |= (data.x >= outer)
65
66
67def sasview_model(model_definition, **pars):
68    """
69    Load a sasview model given the model name.
70    """
71    # convert model parameters from sasmodel form to sasview form
72    #print "old",sorted(pars.items())
73    modelname, pars = revert_model(model_definition, pars)
74    #print "new",sorted(pars.items())
75    sas = __import__('sas.models.'+modelname)
76    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
77    if ModelClass is None:
78        raise ValueError("could not find model %r in sas.models"%modelname)
79    model = ModelClass()
80
81    for k,v in pars.items():
82        if k.endswith("_pd"):
83            model.dispersion[k[:-3]]['width'] = v
84        elif k.endswith("_pd_n"):
85            model.dispersion[k[:-5]]['npts'] = v
86        elif k.endswith("_pd_nsigma"):
87            model.dispersion[k[:-10]]['nsigmas'] = v
88        elif k.endswith("_pd_type"):
89            model.dispersion[k[:-8]]['type'] = v
90        else:
91            model.setParam(k, v)
92    return model
93
94def randomize(p, v):
95    """
96    Randomizing parameter.
97
98    Guess the parameter type from name.
99    """
100    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
101        return v
102    elif any(s in p for s in ('theta','phi','psi')):
103        # orientation in [-180,180], orientation pd in [0,45]
104        if p.endswith('_pd'):
105            return 45*np.random.rand()
106        else:
107            return 360*np.random.rand() - 180
108    elif 'sld' in p:
109        # sld in in [-0.5,10]
110        return 10.5*np.random.rand() - 0.5
111    elif p.endswith('_pd'):
112        # length pd in [0,1]
113        return np.random.rand()
114    else:
115        # values from 0 to 2*x for all other parameters
116        return 2*np.random.rand()*(v if v != 0 else 1)
117
118def randomize_model(name, pars, seed=None):
119    if seed is None:
120        seed = np.random.randint(1e9)
121    np.random.seed(seed)
122    # Note: the sort guarantees order of calls to random number generator
123    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
124    # The capped cylinder model has a constraint on its parameters
125    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
126        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
127    return pars, seed
128
129def parlist(pars):
130    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
131
132def suppress_pd(pars):
133    """
134    Suppress theta_pd for now until the normalization is resolved.
135
136    May also suppress complete polydispersity of the model to test
137    models more quickly.
138    """
139    for p in pars:
140        if p.endswith("_pd"): pars[p] = 0
141
142def eval_sasview(model_definition, pars, data, Nevals=1):
143    from sas.models.qsmearing import smear_selection
144    model = sasview_model(model_definition, **pars)
145    smearer = smear_selection(data, model=model)
146    value = None  # silence the linter
147    toc = tic()
148    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
149        if hasattr(data, 'qx_data'):
150            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
151            index = ((~data.mask) & (~np.isnan(data.data))
152                     & (q >= data.qmin) & (q <= data.qmax))
153            if smearer is not None:
154                smearer.model = model  # because smear_selection has a bug
155                smearer.accuracy = data.accuracy
156                smearer.set_index(index)
157                value = smearer.get_value()
158            else:
159                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
160        else:
161            value = model.evalDistribution(data.x)
162            if smearer is not None:
163                value = smearer(value)
164    average_time = toc()*1000./Nevals
165    return value, average_time
166
167def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
168    try:
169        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
170    except Exception,exc:
171        print exc
172        print "... trying again with single precision"
173        model = core.load_model(model_definition, dtype='single', platform="ocl")
174    calculator = DirectModel(data, model, cutoff=cutoff)
175    value = None  # silence the linter
176    toc = tic()
177    for _ in range(max(Nevals, 1)):  # force at least one eval
178        value = calculator(**pars)
179    average_time = toc()*1000./Nevals
180    return value, average_time
181
182
183def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
184    model = core.load_model(model_definition, dtype=dtype, platform="dll")
185    calculator = DirectModel(data, model, cutoff=cutoff)
186    value = None  # silence the linter
187    toc = tic()
188    for _ in range(max(Nevals, 1)):  # force at least one eval
189        value = calculator(**pars)
190    average_time = toc()*1000./Nevals
191    return value, average_time
192
193
194def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
195    if is2D:
196        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
197        data.accuracy = accuracy
198        set_beam_stop(data, 0.004)
199        index = ~data.mask
200    else:
201        if view == 'log':
202            qmax = math.log10(qmax)
203            q = np.logspace(qmax-3, qmax, Nq)
204        else:
205            q = np.linspace(0.001*qmax, qmax, Nq)
206        data = empty_data1D(q, resolution=resolution)
207        index = slice(None, None)
208    return data, index
209
210def compare(name, pars, Ncpu, Nocl, opts, set_pars):
211    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
212
213    opt_values = dict(split
214                      for s in opts for split in ((s.split('='),))
215                      if len(split) == 2)
216    # Sort out data
217    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
218    Nq = int(opt_values.get('-Nq', '128'))
219    res = float(opt_values.get('-res', '0'))
220    accuracy = opt_values.get('-accuracy', 'Low')
221    is2D = "-2d" in opts
222    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
223
224
225    # modelling accuracy is determined by dtype and cutoff
226    dtype = 'double' if '-double' in opts else 'single'
227    cutoff = float(opt_values.get('-cutoff','1e-5'))
228
229    # randomize parameters
230    #pars.update(set_pars)  # set value before random to control range
231    if '-random' in opts or '-random' in opt_values:
232        seed = int(opt_values['-random']) if '-random' in opt_values else None
233        pars, seed = randomize_model(name, pars, seed=seed)
234        print "Randomize using -random=%i"%seed
235    pars.update(set_pars)  # set value after random to control value
236
237    # parameter selection
238    if '-mono' in opts:
239        suppress_pd(pars)
240    if '-pars' in opts:
241        print "pars",parlist(pars)
242
243    model_definition = core.load_model_definition(name)
244    # OpenCl calculation
245    if Nocl > 0:
246        ocl, ocl_time = eval_opencl(model_definition, pars, data,
247                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
248        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
249        #print max(ocl), min(ocl)
250
251    # ctypes/sasview calculation
252    if Ncpu > 0 and "-ctypes" in opts:
253        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
254                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
255        comp = "ctypes"
256        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
257    elif Ncpu > 0:
258        try:
259            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
260            comp = "sasview"
261            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background'])
262            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
263        except ImportError:
264            Ncpu = 0
265
266    # Compare, but only if computing both forms
267    if Nocl > 0 and Ncpu > 0:
268        #print "speedup %.2g"%(cpu_time/ocl_time)
269        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
270        #cpu *= max(ocl/cpu)
271        resid = (ocl - cpu)
272        relerr = resid/cpu
273        #bad = (relerr>1e-4)
274        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
275        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
276        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
277
278    # Plot if requested
279    if '-noplot' in opts: return
280    import matplotlib.pyplot as plt
281    if Ncpu > 0:
282        if Nocl > 0: plt.subplot(131)
283        plot_theory(data, cpu, view=view, plot_data=False)
284        plt.title("%s t=%.1f ms"%(comp,cpu_time))
285        #cbar_title = "log I"
286    if Nocl > 0:
287        if Ncpu > 0: plt.subplot(132)
288        plot_theory(data, ocl, view=view, plot_data=False)
289        plt.title("opencl t=%.1f ms"%ocl_time)
290        #cbar_title = "log I"
291    if Ncpu > 0 and Nocl > 0:
292        plt.subplot(133)
293        if '-abs' in opts:
294            err,errstr,errview = resid, "abs err", "linear"
295        else:
296            err,errstr,errview = abs(relerr), "rel err", "log"
297        #err,errstr = ocl/cpu,"ratio"
298        plot_theory(data, None, resid=err, view=errview, plot_data=False)
299        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
300        #cbar_title = errstr if errview=="linear" else "log "+errstr
301    #if is2D:
302    #    h = plt.colorbar()
303    #    h.ax.set_title(cbar_title)
304
305    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
306        plt.figure()
307        v = relerr
308        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
309        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
310        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
311        plt.ylabel('P(err)')
312        plt.title('Comparison of single and double precision models for %s'%name)
313
314    plt.show()
315
316def _print_stats(label, err):
317    sorted_err = np.sort(abs(err))
318    p50 = int((len(err)-1)*0.50)
319    p98 = int((len(err)-1)*0.98)
320    data = [
321        "max:%.3e"%sorted_err[-1],
322        "median:%.3e"%sorted_err[p50],
323        "98%%:%.3e"%sorted_err[p98],
324        "rms:%.3e"%np.sqrt(np.mean(err**2)),
325        "zero-offset:%+.3e"%np.mean(err),
326        ]
327    print label,"  ".join(data)
328
329
330
331# ===========================================================================
332#
333USAGE="""
334usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
335
336Compare the speed and value for a model between the SasView original and the
337OpenCL rewrite.
338
339model is the name of the model to compare (see below).
340Nopencl is the number of times to run the OpenCL model (default=5)
341Nsasview is the number of times to run the Sasview model (default=1)
342
343Options (* for default):
344
345    -plot*/-noplot plots or suppress the plot of the model
346    -single*/-double uses double precision for comparison
347    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
348    -Nq=128 sets the number of Q points in the data set
349    -1d*/-2d computes 1d or 2d data
350    -preset*/-random[=seed] preset or random parameters
351    -mono/-poly* force monodisperse/polydisperse
352    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
353    -cutoff=1e-5* cutoff value for including a point in polydispersity
354    -pars/-nopars* prints the parameter set or not
355    -abs/-rel* plot relative or absolute error
356    -linear/-log/-q4 intensity scaling
357    -hist/-nohist* plot histogram of relative error
358    -res=0 sets the resolution width dQ/Q if calculating with resolution
359    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
360
361Key=value pairs allow you to set specific values to any of the model
362parameters.
363
364Available models:
365"""
366
367
368NAME_OPTIONS = set([
369    'plot','noplot',
370    'single','double',
371    'lowq','midq','highq','exq',
372    '2d','1d',
373    'preset','random',
374    'poly','mono',
375    'sasview','ctypes',
376    'nopars','pars',
377    'rel','abs',
378    'linear', 'log', 'q4',
379    'hist','nohist',
380    ])
381VALUE_OPTIONS = [
382    # Note: random is both a name option and a value option
383    'cutoff', 'random', 'Nq', 'res', 'accuracy',
384    ]
385
386def columnize(L, indent="", width=79):
387    column_width = max(len(w) for w in L) + 1
388    num_columns = (width - len(indent)) // column_width
389    num_rows = len(L) // num_columns
390    L = L + [""] * (num_rows*num_columns - len(L))
391    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
392    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
393             for row in zip(*columns)]
394    output = indent + ("\n"+indent).join(lines)
395    return output
396
397
398def get_demo_pars(name):
399    __import__('sasmodels.models.'+name)
400    model = getattr(models, name)
401    pars = getattr(model, 'demo', None)
402    if pars is None: pars = dict((p[0],p[2]) for p in model.parameters)
403    return pars
404
405def main():
406    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
407    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
408    models = "\n    ".join("%-15s"%v for v in MODELS)
409    if len(args) == 0:
410        print(USAGE)
411        print(columnize(MODELS, indent="  "))
412        sys.exit(1)
413    if args[0] not in MODELS:
414        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
415        sys.exit(1)
416
417    invalid = [o[1:] for o in opts
418               if o[1:] not in NAME_OPTIONS
419                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
420    if invalid:
421        print "Invalid options: %s"%(", ".join(invalid))
422        sys.exit(1)
423
424    # Get demo parameters from model definition, or use default parameters
425    # if model does not define demo parameters
426    name = args[0]
427    pars = get_demo_pars(name)
428
429    Nopencl = int(args[1]) if len(args) > 1 else 5
430    Nsasview = int(args[2]) if len(args) > 2 else 1
431
432    # Fill in default polydispersity parameters
433    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
434    for p in pds:
435        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
436        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
437
438    # Fill in parameters given on the command line
439    set_pars = {}
440    for arg in args[3:]:
441        k,v = arg.split('=')
442        if k not in pars:
443            # extract base name without distribution
444            s = set(p.split('_pd')[0] for p in pars)
445            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
446            sys.exit(1)
447        set_pars[k] = float(v) if not v.endswith('type') else v
448
449    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
450
451if __name__ == "__main__":
452    main()
Note: See TracBrowser for help on using the repository browser.