source: sasmodels/sasmodels/compare.py @ 5753e4e

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5753e4e was 5753e4e, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

support compare/multi_compare on windows

  • Property mode set to 100755
File size: 16.2 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9import traceback
10
11import numpy as np
12
13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
17from . import core
18from . import kerneldll
19from . import generate
20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
24
25# List of available models
26MODELS = [basename(f)[:-3]
27          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
28
29# CRUFT python 2.6
30if not hasattr(datetime.timedelta, 'total_seconds'):
31    def delay(dt):
32        """Return number date-time delta as number seconds"""
33        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
34else:
35    def delay(dt):
36        """Return number date-time delta as number seconds"""
37        return dt.total_seconds()
38
39
40def tic():
41    """
42    Timer function.
43
44    Use "toc=tic()" to start the clock and "toc()" to measure
45    a time interval.
46    """
47    then = datetime.datetime.now()
48    return lambda: delay(datetime.datetime.now() - then)
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54
55    Note: this function does not use the sasview package
56    """
57    if hasattr(data, 'qx_data'):
58        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
59        data.mask = (q < radius)
60        if outer is not None:
61            data.mask |= (q >= outer)
62    else:
63        data.mask = (data.x < radius)
64        if outer is not None:
65            data.mask |= (data.x >= outer)
66
67
68def sasview_model(model_definition, **pars):
69    """
70    Load a sasview model given the model name.
71    """
72    # convert model parameters from sasmodel form to sasview form
73    #print "old",sorted(pars.items())
74    modelname, pars = revert_model(model_definition, pars)
75    #print "new",sorted(pars.items())
76    sas = __import__('sas.models.'+modelname)
77    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
78    if ModelClass is None:
79        raise ValueError("could not find model %r in sas.models"%modelname)
80    model = ModelClass()
81
82    for k,v in pars.items():
83        if k.endswith("_pd"):
84            model.dispersion[k[:-3]]['width'] = v
85        elif k.endswith("_pd_n"):
86            model.dispersion[k[:-5]]['npts'] = v
87        elif k.endswith("_pd_nsigma"):
88            model.dispersion[k[:-10]]['nsigmas'] = v
89        elif k.endswith("_pd_type"):
90            model.dispersion[k[:-8]]['type'] = v
91        else:
92            model.setParam(k, v)
93    return model
94
95def randomize(p, v):
96    """
97    Randomizing parameter.
98
99    Guess the parameter type from name.
100    """
101    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
102        return v
103    elif any(s in p for s in ('theta','phi','psi')):
104        # orientation in [-180,180], orientation pd in [0,45]
105        if p.endswith('_pd'):
106            return 45*np.random.rand()
107        else:
108            return 360*np.random.rand() - 180
109    elif 'sld' in p:
110        # sld in in [-0.5,10]
111        return 10.5*np.random.rand() - 0.5
112    elif p.endswith('_pd'):
113        # length pd in [0,1]
114        return np.random.rand()
115    else:
116        # values from 0 to 2*x for all other parameters
117        return 2*np.random.rand()*(v if v != 0 else 1)
118
119def randomize_model(pars, seed=None):
120    if seed is None:
121        seed = np.random.randint(1e9)
122    np.random.seed(seed)
123    # Note: the sort guarantees order of calls to random number generator
124    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
125
126    return pars, seed
127
128def constrain_pars(model_definition, pars):
129    name = model_definition.name
130    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
131        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
132
133    # These constraints are only needed for comparison to sasview
134    if name in ('teubner_strey','broad_peak'):
135        del pars['scale']
136    if name in ('guinier',):
137        del pars['background']
138    if getattr(model_definition, 'category', None) == 'structure-factor':
139        del pars['scale'], pars['background']
140
141
142def parlist(pars):
143    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
144
145def suppress_pd(pars):
146    """
147    Suppress theta_pd for now until the normalization is resolved.
148
149    May also suppress complete polydispersity of the model to test
150    models more quickly.
151    """
152    for p in pars:
153        if p.endswith("_pd"): pars[p] = 0
154
155def eval_sasview(model_definition, pars, data, Nevals=1):
156    from sas.models.qsmearing import smear_selection
157    model = sasview_model(model_definition, **pars)
158    smearer = smear_selection(data, model=model)
159    value = None  # silence the linter
160    toc = tic()
161    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
162        if hasattr(data, 'qx_data'):
163            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
164            index = ((~data.mask) & (~np.isnan(data.data))
165                     & (q >= data.qmin) & (q <= data.qmax))
166            if smearer is not None:
167                smearer.model = model  # because smear_selection has a bug
168                smearer.accuracy = data.accuracy
169                smearer.set_index(index)
170                value = smearer.get_value()
171            else:
172                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
173        else:
174            value = model.evalDistribution(data.x)
175            if smearer is not None:
176                value = smearer(value)
177    average_time = toc()*1000./Nevals
178    return value, average_time
179
180def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
181    try:
182        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
183    except Exception,exc:
184        print exc
185        print "... trying again with single precision"
186        model = core.load_model(model_definition, dtype='single', platform="ocl")
187    calculator = DirectModel(data, model, cutoff=cutoff)
188    value = None  # silence the linter
189    toc = tic()
190    for _ in range(max(Nevals, 1)):  # force at least one eval
191        value = calculator(**pars)
192    average_time = toc()*1000./Nevals
193    return value, average_time
194
195
196def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
197    model = core.load_model(model_definition, dtype=dtype, platform="dll")
198    calculator = DirectModel(data, model, cutoff=cutoff)
199    value = None  # silence the linter
200    toc = tic()
201    for _ in range(max(Nevals, 1)):  # force at least one eval
202        value = calculator(**pars)
203    average_time = toc()*1000./Nevals
204    return value, average_time
205
206
207def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
208    if is2D:
209        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
210        data.accuracy = accuracy
211        set_beam_stop(data, 0.004)
212        index = ~data.mask
213    else:
214        if view == 'log':
215            qmax = math.log10(qmax)
216            q = np.logspace(qmax-3, qmax, Nq)
217        else:
218            q = np.linspace(0.001*qmax, qmax, Nq)
219        data = empty_data1D(q, resolution=resolution)
220        index = slice(None, None)
221    return data, index
222
223def compare(name, pars, Ncpu, Nocl, opts, set_pars):
224    model_definition = core.load_model_definition(name)
225
226    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
227
228    opt_values = dict(split
229                      for s in opts for split in ((s.split('='),))
230                      if len(split) == 2)
231    # Sort out data
232    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
233    Nq = int(opt_values.get('-Nq', '128'))
234    res = float(opt_values.get('-res', '0'))
235    accuracy = opt_values.get('-accuracy', 'Low')
236    is2D = "-2d" in opts
237    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
238
239
240    # modelling accuracy is determined by dtype and cutoff
241    dtype = 'double' if '-double' in opts else 'single'
242    cutoff = float(opt_values.get('-cutoff','1e-5'))
243
244    # randomize parameters
245    #pars.update(set_pars)  # set value before random to control range
246    if '-random' in opts or '-random' in opt_values:
247        seed = int(opt_values['-random']) if '-random' in opt_values else None
248        pars, seed = randomize_model(pars, seed=seed)
249        constrain_pars(model_definition, pars)
250        print "Randomize using -random=%i"%seed
251    pars.update(set_pars)  # set value after random to control value
252
253    # parameter selection
254    if '-mono' in opts:
255        suppress_pd(pars)
256    if '-pars' in opts:
257        print "pars",parlist(pars)
258
259    # OpenCl calculation
260    if Nocl > 0:
261        ocl, ocl_time = eval_opencl(model_definition, pars, data,
262                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
263        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
264        #print "ocl", ocl
265        #print max(ocl), min(ocl)
266
267    # ctypes/sasview calculation
268    if Ncpu > 0 and "-ctypes" in opts:
269        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
270                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
271        comp = "ctypes"
272        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
273    elif Ncpu > 0:
274        try:
275            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
276            comp = "sasview"
277            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background'])
278            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
279            #print "sasview",cpu
280        except ImportError:
281            traceback.print_exc()
282            Ncpu = 0
283
284    # Compare, but only if computing both forms
285    if Nocl > 0 and Ncpu > 0:
286        #print "speedup %.2g"%(cpu_time/ocl_time)
287        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
288        #cpu *= max(ocl/cpu)
289        resid = (ocl - cpu)
290        relerr = resid/cpu
291        #bad = (relerr>1e-4)
292        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
293        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
294        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
295
296    # Plot if requested
297    if '-noplot' in opts: return
298    import matplotlib.pyplot as plt
299    if Ncpu > 0:
300        if Nocl > 0: plt.subplot(131)
301        plot_theory(data, cpu, view=view, plot_data=False)
302        plt.title("%s t=%.1f ms"%(comp,cpu_time))
303        #cbar_title = "log I"
304    if Nocl > 0:
305        if Ncpu > 0: plt.subplot(132)
306        plot_theory(data, ocl, view=view, plot_data=False)
307        plt.title("opencl t=%.1f ms"%ocl_time)
308        #cbar_title = "log I"
309    if Ncpu > 0 and Nocl > 0:
310        plt.subplot(133)
311        if '-abs' in opts:
312            err,errstr,errview = resid, "abs err", "linear"
313        else:
314            err,errstr,errview = abs(relerr), "rel err", "log"
315        #err,errstr = ocl/cpu,"ratio"
316        plot_theory(data, None, resid=err, view=errview, plot_data=False)
317        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
318        #cbar_title = errstr if errview=="linear" else "log "+errstr
319    #if is2D:
320    #    h = plt.colorbar()
321    #    h.ax.set_title(cbar_title)
322
323    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
324        plt.figure()
325        v = relerr
326        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
327        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
328        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
329        plt.ylabel('P(err)')
330        plt.title('Comparison of single and double precision models for %s'%name)
331
332    plt.show()
333
334def _print_stats(label, err):
335    sorted_err = np.sort(abs(err))
336    p50 = int((len(err)-1)*0.50)
337    p98 = int((len(err)-1)*0.98)
338    data = [
339        "max:%.3e"%sorted_err[-1],
340        "median:%.3e"%sorted_err[p50],
341        "98%%:%.3e"%sorted_err[p98],
342        "rms:%.3e"%np.sqrt(np.mean(err**2)),
343        "zero-offset:%+.3e"%np.mean(err),
344        ]
345    print label,"  ".join(data)
346
347
348
349# ===========================================================================
350#
351USAGE="""
352usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
353
354Compare the speed and value for a model between the SasView original and the
355OpenCL rewrite.
356
357model is the name of the model to compare (see below).
358Nopencl is the number of times to run the OpenCL model (default=5)
359Nsasview is the number of times to run the Sasview model (default=1)
360
361Options (* for default):
362
363    -plot*/-noplot plots or suppress the plot of the model
364    -single*/-double uses double precision for comparison
365    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
366    -Nq=128 sets the number of Q points in the data set
367    -1d*/-2d computes 1d or 2d data
368    -preset*/-random[=seed] preset or random parameters
369    -mono/-poly* force monodisperse/polydisperse
370    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
371    -cutoff=1e-5* cutoff value for including a point in polydispersity
372    -pars/-nopars* prints the parameter set or not
373    -abs/-rel* plot relative or absolute error
374    -linear/-log/-q4 intensity scaling
375    -hist/-nohist* plot histogram of relative error
376    -res=0 sets the resolution width dQ/Q if calculating with resolution
377    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
378
379Key=value pairs allow you to set specific values to any of the model
380parameters.
381
382Available models:
383"""
384
385
386NAME_OPTIONS = set([
387    'plot','noplot',
388    'single','double',
389    'lowq','midq','highq','exq',
390    '2d','1d',
391    'preset','random',
392    'poly','mono',
393    'sasview','ctypes',
394    'nopars','pars',
395    'rel','abs',
396    'linear', 'log', 'q4',
397    'hist','nohist',
398    ])
399VALUE_OPTIONS = [
400    # Note: random is both a name option and a value option
401    'cutoff', 'random', 'Nq', 'res', 'accuracy',
402    ]
403
404def columnize(L, indent="", width=79):
405    column_width = max(len(w) for w in L) + 1
406    num_columns = (width - len(indent)) // column_width
407    num_rows = len(L) // num_columns
408    L = L + [""] * (num_rows*num_columns - len(L))
409    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
410    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
411             for row in zip(*columns)]
412    output = indent + ("\n"+indent).join(lines)
413    return output
414
415
416def get_demo_pars(model_definition):
417    info = generate.make_info(model_definition)
418    pars = dict((p[0],p[2]) for p in info['parameters'])
419    pars.update(info['demo'])
420    return pars
421
422def main():
423    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
424    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
425    models = "\n    ".join("%-15s"%v for v in MODELS)
426    if len(args) == 0:
427        print(USAGE)
428        print(columnize(MODELS, indent="  "))
429        sys.exit(1)
430    if args[0] not in MODELS:
431        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
432        sys.exit(1)
433
434    invalid = [o[1:] for o in opts
435               if o[1:] not in NAME_OPTIONS
436                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
437    if invalid:
438        print "Invalid options: %s"%(", ".join(invalid))
439        sys.exit(1)
440
441    # Get demo parameters from model definition, or use default parameters
442    # if model does not define demo parameters
443    name = args[0]
444    model_definition = core.load_model_definition(name)
445    pars = get_demo_pars(model_definition)
446
447    Nopencl = int(args[1]) if len(args) > 1 else 5
448    Nsasview = int(args[2]) if len(args) > 2 else 1
449
450    # Fill in default polydispersity parameters
451    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
452    for p in pds:
453        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
454        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
455
456    # Fill in parameters given on the command line
457    set_pars = {}
458    for arg in args[3:]:
459        k,v = arg.split('=')
460        if k not in pars:
461            # extract base name without distribution
462            s = set(p.split('_pd')[0] for p in pars)
463            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
464            sys.exit(1)
465        set_pars[k] = float(v) if not v.endswith('type') else v
466
467    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
468
469if __name__ == "__main__":
470    main()
Note: See TracBrowser for help on using the repository browser.