source: sasmodels/sasmodels/compare.py @ cf404cb

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since cf404cb was cf404cb, checked in by krzywon, 8 years ago

Added pearl necklace model and numeric method for intrgrating sin(x)/x.
Modified compare.py to ignore PD for number of pearls and string
thickness.

  • Property mode set to 100755
File size: 18.3 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9import traceback
10
11import numpy as np
12
13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
17from . import core
18from . import kerneldll
19from . import generate
20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
22from .convert import revert_model
23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
24
25# List of available models
26MODELS = [basename(f)[:-3]
27          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
28
29# CRUFT python 2.6
30if not hasattr(datetime.timedelta, 'total_seconds'):
31    def delay(dt):
32        """Return number date-time delta as number seconds"""
33        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
34else:
35    def delay(dt):
36        """Return number date-time delta as number seconds"""
37        return dt.total_seconds()
38
39
40def tic():
41    """
42    Timer function.
43
44    Use "toc=tic()" to start the clock and "toc()" to measure
45    a time interval.
46    """
47    then = datetime.datetime.now()
48    return lambda: delay(datetime.datetime.now() - then)
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54
55    Note: this function does not use the sasview package
56    """
57    if hasattr(data, 'qx_data'):
58        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
59        data.mask = (q < radius)
60        if outer is not None:
61            data.mask |= (q >= outer)
62    else:
63        data.mask = (data.x < radius)
64        if outer is not None:
65            data.mask |= (data.x >= outer)
66
67
68def sasview_model(model_definition, **pars):
69    """
70    Load a sasview model given the model name.
71    """
72    # convert model parameters from sasmodel form to sasview form
73    #print("old",sorted(pars.items()))
74    modelname, pars = revert_model(model_definition, pars)
75    #print("new",sorted(pars.items()))
76    sas = __import__('sas.models.'+modelname)
77    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
78    if ModelClass is None:
79        raise ValueError("could not find model %r in sas.models"%modelname)
80    model = ModelClass()
81
82    for k,v in pars.items():
83        if k.startswith("num_pearls_") or k.startswith("thick_string_"):
84            continue
85        if k.endswith("_pd"):
86            model.dispersion[k[:-3]]['width'] = v
87        elif k.endswith("_pd_n"):
88            model.dispersion[k[:-5]]['npts'] = v
89        elif k.endswith("_pd_nsigma"):
90            model.dispersion[k[:-10]]['nsigmas'] = v
91        elif k.endswith("_pd_type"):
92            model.dispersion[k[:-8]]['type'] = v
93        else:
94            model.setParam(k, v)
95    return model
96
97def randomize(p, v):
98    """
99    Randomizing parameter.
100
101    Guess the parameter type from name.
102    """
103    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
104        return v
105    elif any(s in p for s in ('theta','phi','psi')):
106        # orientation in [-180,180], orientation pd in [0,45]
107        if p.endswith('_pd'):
108            return 45*np.random.rand()
109        else:
110            return 360*np.random.rand() - 180
111    elif 'sld' in p:
112        # sld in in [-0.5,10]
113        return 10.5*np.random.rand() - 0.5
114    elif p.endswith('_pd'):
115        # length pd in [0,1]
116        return np.random.rand()
117    else:
118        # values from 0 to 2*x for all other parameters
119        return 2*np.random.rand()*(v if v != 0 else 1)
120
121def randomize_model(pars, seed=None):
122    if seed is None:
123        seed = np.random.randint(1e9)
124    np.random.seed(seed)
125    # Note: the sort guarantees order of calls to random number generator
126    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
127
128    return pars, seed
129
130def constrain_pars(model_definition, pars):
131    name = model_definition.name
132    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
133        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
134    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
135        pars['radius'],pars['bell_radius'] = pars['bell_radius'],pars['radius']
136
137    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
138    if name == 'guinier':
139        #q_max = 0.2  # mid q maximum
140        q_max = 1.0  # high q maximum
141        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
142        pars['rg'] = min(pars['rg'],rg_max)
143
144    # These constraints are only needed for comparison to sasview
145    if name in ('teubner_strey','broad_peak'):
146        del pars['scale']
147    if name in ('guinier',):
148        del pars['background']
149    if getattr(model_definition, 'category', None) == 'structure-factor':
150        del pars['scale'], pars['background']
151
152
153def parlist(pars):
154    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
155
156def suppress_pd(pars):
157    """
158    Suppress theta_pd for now until the normalization is resolved.
159
160    May also suppress complete polydispersity of the model to test
161    models more quickly.
162    """
163    for p in pars:
164        if p.endswith("_pd"): pars[p] = 0
165
166def eval_sasview(model_definition, pars, data, Nevals=1):
167    # importing sas here so that the error message will be that sas failed to
168    # import rather than the more obscure smear_selection not imported error
169    import sas
170    from sas.models.qsmearing import smear_selection
171    model = sasview_model(model_definition, **pars)
172    smearer = smear_selection(data, model=model)
173    value = None  # silence the linter
174    toc = tic()
175    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
176        if hasattr(data, 'qx_data'):
177            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
178            index = ((~data.mask) & (~np.isnan(data.data))
179                     & (q >= data.qmin) & (q <= data.qmax))
180            if smearer is not None:
181                smearer.model = model  # because smear_selection has a bug
182                smearer.accuracy = data.accuracy
183                smearer.set_index(index)
184                value = smearer.get_value()
185            else:
186                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
187        else:
188            value = model.evalDistribution(data.x)
189            if smearer is not None:
190                value = smearer(value)
191    average_time = toc()*1000./Nevals
192    return value, average_time
193
194def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1,
195                cutoff=0., fast=False):
196    try:
197        model = core.load_model(model_definition, dtype=dtype,
198                                platform="ocl", fast=fast)
199    except Exception as exc:
200        print(exc)
201        print("... trying again with single precision")
202        model = core.load_model(model_definition, dtype='single',
203                                platform="ocl", fast=fast)
204    calculator = DirectModel(data, model, cutoff=cutoff)
205    value = None  # silence the linter
206    toc = tic()
207    for _ in range(max(Nevals, 1)):  # force at least one eval
208        value = calculator(**pars)
209    average_time = toc()*1000./Nevals
210    return value, average_time
211
212
213def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
214    model = core.load_model(model_definition, dtype=dtype, platform="dll")
215    calculator = DirectModel(data, model, cutoff=cutoff)
216    value = None  # silence the linter
217    toc = tic()
218    for _ in range(max(Nevals, 1)):  # force at least one eval
219        value = calculator(**pars)
220    average_time = toc()*1000./Nevals
221    return value, average_time
222
223
224def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
225    if is2D:
226        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
227        data.accuracy = accuracy
228        set_beam_stop(data, 0.004)
229        index = ~data.mask
230    else:
231        if view == 'log':
232            qmax = math.log10(qmax)
233            q = np.logspace(qmax-3, qmax, Nq)
234        else:
235            q = np.linspace(0.001*qmax, qmax, Nq)
236        data = empty_data1D(q, resolution=resolution)
237        index = slice(None, None)
238    return data, index
239
240def compare(name, pars, Ncomp, Nbase, opts, set_pars):
241    model_definition = core.load_model_definition(name)
242
243    view = ('linear' if '-linear' in opts
244            else 'log' if '-log' in opts
245            else 'q4' if '-q4' in opts
246            else 'log')
247
248    opt_values = dict(split
249                      for s in opts for split in ((s.split('='),))
250                      if len(split) == 2)
251    # Sort out data
252    qmax = (10.0 if '-exq' in opts
253            else 1.0 if '-highq' in opts
254            else 0.2 if '-midq' in opts
255            else 0.05)
256    Nq = int(opt_values.get('-Nq', '128'))
257    res = float(opt_values.get('-res', '0'))
258    accuracy = opt_values.get('-accuracy', 'Low')
259    is2D = "-2d" in opts
260    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
261
262
263    # modelling accuracy is determined by dtype and cutoff
264    dtype = ('longdouble' if '-quad' in opts
265             else 'double' if '-double' in opts
266             else 'half' if '-half' in opts
267             else 'single')
268    cutoff = float(opt_values.get('-cutoff','1e-5'))
269    fast = "-fast" in opts and dtype is 'single'
270
271    # randomize parameters
272    #pars.update(set_pars)  # set value before random to control range
273    if '-random' in opts or '-random' in opt_values:
274        seed = int(opt_values['-random']) if '-random' in opt_values else None
275        pars, seed = randomize_model(pars, seed=seed)
276        print("Randomize using -random=%i"%seed)
277    pars.update(set_pars)  # set value after random to control value
278    constrain_pars(model_definition, pars)
279
280    # parameter selection
281    if '-mono' in opts:
282        suppress_pd(pars)
283    if '-pars' in opts:
284        print("pars "+str(parlist(pars)))
285
286    # Base calculation
287    if 0:
288        from sasmodels.models import sphere as target
289        base_name = target.name
290        base, base_time = eval_ctypes(target, pars, data,
291                dtype='longdouble', cutoff=0., Nevals=Ncomp)
292    elif Nbase > 0 and "-ctypes" in opts and "-sasview" in opts:
293        try:
294            base, base_time = eval_sasview(model_definition, pars, data, Ncomp)
295            base_name = "sasview"
296            #print("base/sasview", (base-pars['background'])/(comp-pars['background']))
297            print("sasview t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))
298            #print("sasview",comp)
299        except ImportError:
300            traceback.print_exc()
301            Nbase = 0
302    elif Nbase > 0:
303        base, base_time = eval_opencl(model_definition, pars, data,
304                dtype=dtype, cutoff=cutoff, Nevals=Nbase, fast=fast)
305        base_name = "ocl"
306        print("opencl t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))
307        #print("base " + base)
308        #print(max(base), min(base))
309
310    # Comparison calculation
311    if Ncomp > 0 and "-ctypes" in opts:
312        comp, comp_time = eval_ctypes(model_definition, pars, data,
313                dtype=dtype, cutoff=cutoff, Nevals=Ncomp)
314        comp_name = "ctypes"
315        print("ctypes t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
316    elif Ncomp > 0:
317        try:
318            comp, comp_time = eval_sasview(model_definition, pars, data, Ncomp)
319            comp_name = "sasview"
320            #print("base/sasview", (base-pars['background'])/(comp-pars['background']))
321            print("sasview t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
322            #print("sasview",comp)
323        except ImportError:
324            traceback.print_exc()
325            Ncomp = 0
326
327    # Compare, but only if computing both forms
328    if Nbase > 0 and Ncomp > 0:
329        #print("speedup %.2g"%(comp_time/base_time))
330        #print("max |base/comp|", max(abs(base/comp)), "%.15g"%max(abs(base)), "%.15g"%max(abs(comp)))
331        #comp *= max(base/comp)
332        resid = (base - comp)
333        relerr = resid/comp
334        #bad = (relerr>1e-4)
335        #print(relerr[bad],comp[bad],base[bad],data.qx_data[bad],data.qy_data[bad])
336        _print_stats("|%s-%s|"%(base_name,comp_name)+(" "*(3+len(comp_name))), resid)
337        _print_stats("|(%s-%s)/%s|"%(base_name,comp_name,comp_name), relerr)
338
339    # Plot if requested
340    if '-noplot' in opts: return
341    import matplotlib.pyplot as plt
342    if Ncomp > 0:
343        if Nbase > 0: plt.subplot(131)
344        plot_theory(data, comp, view=view, plot_data=False)
345        plt.title("%s t=%.1f ms"%(comp_name,comp_time))
346        #cbar_title = "log I"
347    if Nbase > 0:
348        if Ncomp > 0: plt.subplot(132)
349        plot_theory(data, base, view=view, plot_data=False)
350        plt.title("%s t=%.1f ms"%(base_name,base_time))
351        #cbar_title = "log I"
352    if Ncomp > 0 and Nbase > 0:
353        plt.subplot(133)
354        if '-abs' in opts:
355            err,errstr,errview = resid, "abs err", "linear"
356        else:
357            err,errstr,errview = abs(relerr), "rel err", "log"
358        #err,errstr = base/comp,"ratio"
359        plot_theory(data, None, resid=err, view=errview, plot_data=False)
360        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
361        #cbar_title = errstr if errview=="linear" else "log "+errstr
362    #if is2D:
363    #    h = plt.colorbar()
364    #    h.ax.set_title(cbar_title)
365
366    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
367        plt.figure()
368        v = relerr
369        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
370        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
371        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
372        plt.ylabel('P(err)')
373        plt.title('Comparison of single and double precision models for %s'%name)
374
375    plt.show()
376
377def _print_stats(label, err):
378    sorted_err = np.sort(abs(err))
379    p50 = int((len(err)-1)*0.50)
380    p98 = int((len(err)-1)*0.98)
381    data = [
382        "max:%.3e"%sorted_err[-1],
383        "median:%.3e"%sorted_err[p50],
384        "98%%:%.3e"%sorted_err[p98],
385        "rms:%.3e"%np.sqrt(np.mean(err**2)),
386        "zero-offset:%+.3e"%np.mean(err),
387        ]
388    print(label+"  ".join(data))
389
390
391
392# ===========================================================================
393#
394USAGE="""
395usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
396
397Compare the speed and value for a model between the SasView original and the
398OpenCL rewrite.
399
400model is the name of the model to compare (see below).
401Nopencl is the number of times to run the OpenCL model (default=5)
402Nsasview is the number of times to run the Sasview model (default=1)
403
404Options (* for default):
405
406    -plot*/-noplot plots or suppress the plot of the model
407    -half/-single*/-double/-quad/-fast sets the calculation precision
408    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
409    -Nq=128 sets the number of Q points in the data set
410    -1d*/-2d computes 1d or 2d data
411    -preset*/-random[=seed] preset or random parameters
412    -mono/-poly* force monodisperse/polydisperse
413    -ctypes/-sasview* selects gpu:cpu, gpu:sasview, or sasview:cpu if both
414    -cutoff=1e-5* cutoff value for including a point in polydispersity
415    -pars/-nopars* prints the parameter set or not
416    -abs/-rel* plot relative or absolute error
417    -linear/-log/-q4 intensity scaling
418    -hist/-nohist* plot histogram of relative error
419    -res=0 sets the resolution width dQ/Q if calculating with resolution
420    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
421
422Key=value pairs allow you to set specific values to any of the model
423parameters.
424
425Available models:
426"""
427
428
429NAME_OPTIONS = set([
430    'plot', 'noplot',
431    'half', 'single', 'double', 'quad', 'fast',
432    'lowq', 'midq', 'highq', 'exq',
433    '2d', '1d',
434    'preset', 'random',
435    'poly', 'mono',
436    'sasview', 'ctypes',
437    'nopars', 'pars',
438    'rel', 'abs',
439    'linear', 'log', 'q4',
440    'hist', 'nohist',
441    ])
442VALUE_OPTIONS = [
443    # Note: random is both a name option and a value option
444    'cutoff', 'random', 'Nq', 'res', 'accuracy',
445    ]
446
447def columnize(L, indent="", width=79):
448    column_width = max(len(w) for w in L) + 1
449    num_columns = (width - len(indent)) // column_width
450    num_rows = len(L) // num_columns
451    L = L + [""] * (num_rows*num_columns - len(L))
452    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
453    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
454             for row in zip(*columns)]
455    output = indent + ("\n"+indent).join(lines)
456    return output
457
458
459def get_demo_pars(model_definition):
460    info = generate.make_info(model_definition)
461    pars = dict((p[0],p[2]) for p in info['parameters'])
462    pars.update(info['demo'])
463    return pars
464
465def main():
466    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
467    popts = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' in arg]
468    args = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' not in arg]
469    models = "\n    ".join("%-15s"%v for v in MODELS)
470    if len(args) == 0:
471        print(USAGE)
472        print(columnize(MODELS, indent="  "))
473        sys.exit(1)
474    if args[0] not in MODELS:
475        print("Model %r not available. Use one of:\n    %s"%(args[0],models))
476        sys.exit(1)
477    if len(args) > 3:
478        print("expected parameters: model Nopencl Nsasview")
479
480    invalid = [o[1:] for o in opts
481               if o[1:] not in NAME_OPTIONS
482                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
483    if invalid:
484        print("Invalid options: %s"%(", ".join(invalid)))
485        sys.exit(1)
486
487    # Get demo parameters from model definition, or use default parameters
488    # if model does not define demo parameters
489    name = args[0]
490    model_definition = core.load_model_definition(name)
491    pars = get_demo_pars(model_definition)
492
493    Ncomp = int(args[1]) if len(args) > 1 else 5
494    Nbase = int(args[2]) if len(args) > 2 else 1
495
496    # Fill in default polydispersity parameters
497    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
498    for p in pds:
499        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
500        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
501
502    # Fill in parameters given on the command line
503    set_pars = {}
504    for arg in popts:
505        k,v = arg.split('=',1)
506        if k not in pars:
507            # extract base name without distribution
508            s = set(p.split('_pd')[0] for p in pars)
509            print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s))))
510            sys.exit(1)
511        set_pars[k] = float(v) if not v.endswith('type') else v
512
513    compare(name, pars, Ncomp, Nbase, opts, set_pars)
514
515if __name__ == "__main__":
516    main()
Note: See TracBrowser for help on using the repository browser.