source: sasmodels/sasmodels/compare.py @ cd3dba0

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since cd3dba0 was cd3dba0, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

improve compare.py so that parameters can be constrained to valid values

  • Property mode set to 100755
File size: 16.2 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8import datetime
9
10import numpy as np
11
12ROOT = dirname(__file__)
13sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
14
15
16from . import core
17from . import kerneldll
18from . import generate
19from .data import plot_theory, empty_data1D, empty_data2D
20from .direct_model import DirectModel
21from .convert import revert_model
22kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
23
24# List of available models
25MODELS = [basename(f)[:-3]
26          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
27
28# CRUFT python 2.6
29if not hasattr(datetime.timedelta, 'total_seconds'):
30    def delay(dt):
31        """Return number date-time delta as number seconds"""
32        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
33else:
34    def delay(dt):
35        """Return number date-time delta as number seconds"""
36        return dt.total_seconds()
37
38
39def tic():
40    """
41    Timer function.
42
43    Use "toc=tic()" to start the clock and "toc()" to measure
44    a time interval.
45    """
46    then = datetime.datetime.now()
47    return lambda: delay(datetime.datetime.now() - then)
48
49
50def set_beam_stop(data, radius, outer=None):
51    """
52    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
53
54    Note: this function does not use the sasview package
55    """
56    if hasattr(data, 'qx_data'):
57        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
58        data.mask = (q < radius)
59        if outer is not None:
60            data.mask |= (q >= outer)
61    else:
62        data.mask = (data.x < radius)
63        if outer is not None:
64            data.mask |= (data.x >= outer)
65
66
67def sasview_model(model_definition, **pars):
68    """
69    Load a sasview model given the model name.
70    """
71    # convert model parameters from sasmodel form to sasview form
72    #print "old",sorted(pars.items())
73    modelname, pars = revert_model(model_definition, pars)
74    #print "new",sorted(pars.items())
75    sas = __import__('sas.models.'+modelname)
76    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
77    if ModelClass is None:
78        raise ValueError("could not find model %r in sas.models"%modelname)
79    model = ModelClass()
80
81    for k,v in pars.items():
82        if k.endswith("_pd"):
83            model.dispersion[k[:-3]]['width'] = v
84        elif k.endswith("_pd_n"):
85            model.dispersion[k[:-5]]['npts'] = v
86        elif k.endswith("_pd_nsigma"):
87            model.dispersion[k[:-10]]['nsigmas'] = v
88        elif k.endswith("_pd_type"):
89            model.dispersion[k[:-8]]['type'] = v
90        else:
91            model.setParam(k, v)
92    return model
93
94def randomize(p, v):
95    """
96    Randomizing parameter.
97
98    Guess the parameter type from name.
99    """
100    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
101        return v
102    elif any(s in p for s in ('theta','phi','psi')):
103        # orientation in [-180,180], orientation pd in [0,45]
104        if p.endswith('_pd'):
105            return 45*np.random.rand()
106        else:
107            return 360*np.random.rand() - 180
108    elif 'sld' in p:
109        # sld in in [-0.5,10]
110        return 10.5*np.random.rand() - 0.5
111    elif p.endswith('_pd'):
112        # length pd in [0,1]
113        return np.random.rand()
114    else:
115        # values from 0 to 2*x for all other parameters
116        return 2*np.random.rand()*(v if v != 0 else 1)
117
118def randomize_model(pars, seed=None):
119    if seed is None:
120        seed = np.random.randint(1e9)
121    np.random.seed(seed)
122    # Note: the sort guarantees order of calls to random number generator
123    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
124
125    return pars, seed
126
127def constrain_pars(model_definition, pars):
128    name = model_definition.name
129    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
130        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
131
132    # These constraints are only needed for comparison to sasview
133    if name in ('teubner_strey','broad_peak'):
134        del pars['scale']
135    if name in ('guinier',):
136        del pars['background']
137    if getattr(model_definition, 'category', None) == 'structure-factor':
138        del pars['scale'], pars['background']
139
140
141def parlist(pars):
142    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
143
144def suppress_pd(pars):
145    """
146    Suppress theta_pd for now until the normalization is resolved.
147
148    May also suppress complete polydispersity of the model to test
149    models more quickly.
150    """
151    for p in pars:
152        if p.endswith("_pd"): pars[p] = 0
153
154def eval_sasview(model_definition, pars, data, Nevals=1):
155    from sas.models.qsmearing import smear_selection
156    model = sasview_model(model_definition, **pars)
157    smearer = smear_selection(data, model=model)
158    value = None  # silence the linter
159    toc = tic()
160    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
161        if hasattr(data, 'qx_data'):
162            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
163            index = ((~data.mask) & (~np.isnan(data.data))
164                     & (q >= data.qmin) & (q <= data.qmax))
165            if smearer is not None:
166                smearer.model = model  # because smear_selection has a bug
167                smearer.accuracy = data.accuracy
168                smearer.set_index(index)
169                value = smearer.get_value()
170            else:
171                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
172        else:
173            value = model.evalDistribution(data.x)
174            if smearer is not None:
175                value = smearer(value)
176    average_time = toc()*1000./Nevals
177    return value, average_time
178
179def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
180    try:
181        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
182    except Exception,exc:
183        print exc
184        print "... trying again with single precision"
185        model = core.load_model(model_definition, dtype='single', platform="ocl")
186    calculator = DirectModel(data, model, cutoff=cutoff)
187    value = None  # silence the linter
188    toc = tic()
189    for _ in range(max(Nevals, 1)):  # force at least one eval
190        value = calculator(**pars)
191    average_time = toc()*1000./Nevals
192    return value, average_time
193
194
195def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
196    model = core.load_model(model_definition, dtype=dtype, platform="dll")
197    calculator = DirectModel(data, model, cutoff=cutoff)
198    value = None  # silence the linter
199    toc = tic()
200    for _ in range(max(Nevals, 1)):  # force at least one eval
201        value = calculator(**pars)
202    average_time = toc()*1000./Nevals
203    return value, average_time
204
205
206def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
207    if is2D:
208        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
209        data.accuracy = accuracy
210        set_beam_stop(data, 0.004)
211        index = ~data.mask
212    else:
213        if view == 'log':
214            qmax = math.log10(qmax)
215            q = np.logspace(qmax-3, qmax, Nq)
216        else:
217            q = np.linspace(0.001*qmax, qmax, Nq)
218        data = empty_data1D(q, resolution=resolution)
219        index = slice(None, None)
220    return data, index
221
222def compare(name, pars, Ncpu, Nocl, opts, set_pars):
223    model_definition = core.load_model_definition(name)
224
225    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
226
227    opt_values = dict(split
228                      for s in opts for split in ((s.split('='),))
229                      if len(split) == 2)
230    # Sort out data
231    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
232    Nq = int(opt_values.get('-Nq', '128'))
233    res = float(opt_values.get('-res', '0'))
234    accuracy = opt_values.get('-accuracy', 'Low')
235    is2D = "-2d" in opts
236    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
237
238
239    # modelling accuracy is determined by dtype and cutoff
240    dtype = 'double' if '-double' in opts else 'single'
241    cutoff = float(opt_values.get('-cutoff','1e-5'))
242
243    # randomize parameters
244    #pars.update(set_pars)  # set value before random to control range
245    if '-random' in opts or '-random' in opt_values:
246        seed = int(opt_values['-random']) if '-random' in opt_values else None
247        pars, seed = randomize_model(pars, seed=seed)
248        constrain_pars(model_definition, pars)
249        print "Randomize using -random=%i"%seed
250    pars.update(set_pars)  # set value after random to control value
251
252    # parameter selection
253    if '-mono' in opts:
254        suppress_pd(pars)
255    if '-pars' in opts:
256        print "pars",parlist(pars)
257
258    # OpenCl calculation
259    if Nocl > 0:
260        ocl, ocl_time = eval_opencl(model_definition, pars, data,
261                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
262        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
263        #print "ocl", ocl
264        #print max(ocl), min(ocl)
265
266    # ctypes/sasview calculation
267    if Ncpu > 0 and "-ctypes" in opts:
268        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
269                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
270        comp = "ctypes"
271        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
272    elif Ncpu > 0:
273        try:
274            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
275            comp = "sasview"
276            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background'])
277            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
278            #print "sasview",cpu
279        except ImportError:
280            Ncpu = 0
281
282    # Compare, but only if computing both forms
283    if Nocl > 0 and Ncpu > 0:
284        #print "speedup %.2g"%(cpu_time/ocl_time)
285        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
286        #cpu *= max(ocl/cpu)
287        resid = (ocl - cpu)
288        relerr = resid/cpu
289        #bad = (relerr>1e-4)
290        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
291        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
292        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
293
294    # Plot if requested
295    if '-noplot' in opts: return
296    import matplotlib.pyplot as plt
297    if Ncpu > 0:
298        if Nocl > 0: plt.subplot(131)
299        plot_theory(data, cpu, view=view, plot_data=False)
300        plt.title("%s t=%.1f ms"%(comp,cpu_time))
301        #cbar_title = "log I"
302    if Nocl > 0:
303        if Ncpu > 0: plt.subplot(132)
304        plot_theory(data, ocl, view=view, plot_data=False)
305        plt.title("opencl t=%.1f ms"%ocl_time)
306        #cbar_title = "log I"
307    if Ncpu > 0 and Nocl > 0:
308        plt.subplot(133)
309        if '-abs' in opts:
310            err,errstr,errview = resid, "abs err", "linear"
311        else:
312            err,errstr,errview = abs(relerr), "rel err", "log"
313        #err,errstr = ocl/cpu,"ratio"
314        plot_theory(data, None, resid=err, view=errview, plot_data=False)
315        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
316        #cbar_title = errstr if errview=="linear" else "log "+errstr
317    #if is2D:
318    #    h = plt.colorbar()
319    #    h.ax.set_title(cbar_title)
320
321    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
322        plt.figure()
323        v = relerr
324        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
325        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
326        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
327        plt.ylabel('P(err)')
328        plt.title('Comparison of single and double precision models for %s'%name)
329
330    plt.show()
331
332def _print_stats(label, err):
333    sorted_err = np.sort(abs(err))
334    p50 = int((len(err)-1)*0.50)
335    p98 = int((len(err)-1)*0.98)
336    data = [
337        "max:%.3e"%sorted_err[-1],
338        "median:%.3e"%sorted_err[p50],
339        "98%%:%.3e"%sorted_err[p98],
340        "rms:%.3e"%np.sqrt(np.mean(err**2)),
341        "zero-offset:%+.3e"%np.mean(err),
342        ]
343    print label,"  ".join(data)
344
345
346
347# ===========================================================================
348#
349USAGE="""
350usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
351
352Compare the speed and value for a model between the SasView original and the
353OpenCL rewrite.
354
355model is the name of the model to compare (see below).
356Nopencl is the number of times to run the OpenCL model (default=5)
357Nsasview is the number of times to run the Sasview model (default=1)
358
359Options (* for default):
360
361    -plot*/-noplot plots or suppress the plot of the model
362    -single*/-double uses double precision for comparison
363    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
364    -Nq=128 sets the number of Q points in the data set
365    -1d*/-2d computes 1d or 2d data
366    -preset*/-random[=seed] preset or random parameters
367    -mono/-poly* force monodisperse/polydisperse
368    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
369    -cutoff=1e-5* cutoff value for including a point in polydispersity
370    -pars/-nopars* prints the parameter set or not
371    -abs/-rel* plot relative or absolute error
372    -linear/-log/-q4 intensity scaling
373    -hist/-nohist* plot histogram of relative error
374    -res=0 sets the resolution width dQ/Q if calculating with resolution
375    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
376
377Key=value pairs allow you to set specific values to any of the model
378parameters.
379
380Available models:
381"""
382
383
384NAME_OPTIONS = set([
385    'plot','noplot',
386    'single','double',
387    'lowq','midq','highq','exq',
388    '2d','1d',
389    'preset','random',
390    'poly','mono',
391    'sasview','ctypes',
392    'nopars','pars',
393    'rel','abs',
394    'linear', 'log', 'q4',
395    'hist','nohist',
396    ])
397VALUE_OPTIONS = [
398    # Note: random is both a name option and a value option
399    'cutoff', 'random', 'Nq', 'res', 'accuracy',
400    ]
401
402def columnize(L, indent="", width=79):
403    column_width = max(len(w) for w in L) + 1
404    num_columns = (width - len(indent)) // column_width
405    num_rows = len(L) // num_columns
406    L = L + [""] * (num_rows*num_columns - len(L))
407    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
408    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
409             for row in zip(*columns)]
410    output = indent + ("\n"+indent).join(lines)
411    return output
412
413
414def get_demo_pars(model_definition):
415    info = generate.make_info(model_definition)
416    pars = dict((p[0],p[2]) for p in info['parameters'])
417    pars.update(info['demo'])
418    return pars
419
420def main():
421    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
422    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
423    models = "\n    ".join("%-15s"%v for v in MODELS)
424    if len(args) == 0:
425        print(USAGE)
426        print(columnize(MODELS, indent="  "))
427        sys.exit(1)
428    if args[0] not in MODELS:
429        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
430        sys.exit(1)
431
432    invalid = [o[1:] for o in opts
433               if o[1:] not in NAME_OPTIONS
434                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
435    if invalid:
436        print "Invalid options: %s"%(", ".join(invalid))
437        sys.exit(1)
438
439    # Get demo parameters from model definition, or use default parameters
440    # if model does not define demo parameters
441    name = args[0]
442    model_definition = core.load_model_definition(name)
443    pars = get_demo_pars(model_definition)
444
445    Nopencl = int(args[1]) if len(args) > 1 else 5
446    Nsasview = int(args[2]) if len(args) > 2 else 1
447
448    # Fill in default polydispersity parameters
449    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
450    for p in pds:
451        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
452        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
453
454    # Fill in parameters given on the command line
455    set_pars = {}
456    for arg in args[3:]:
457        k,v = arg.split('=')
458        if k not in pars:
459            # extract base name without distribution
460            s = set(p.split('_pd')[0] for p in pars)
461            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
462            sys.exit(1)
463        set_pars[k] = float(v) if not v.endswith('type') else v
464
465    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
466
467if __name__ == "__main__":
468    main()
Note: See TracBrowser for help on using the repository browser.