source: sasmodels/compare.py @ 29fc2a3

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 29fc2a3 was 29fc2a3, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

make sure sasmodels is on the path for compare.py

  • Property mode set to 100755
File size: 14.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8
9import numpy as np
10
11ROOT = dirname(__file__)
12sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
13
14
15from sasmodels.bumps_model import Model, Experiment, plot_theory, tic
16from sasmodels import core
17from sasmodels import kerneldll
18from sasmodels.convert import revert_model
19kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
20
21# List of available models
22MODELS = [basename(f)[:-3]
23          for f in sorted(glob.glob(joinpath(ROOT,"sasmodels","models","[a-zA-Z]*.py")))]
24
25
26def sasview_model(model_definition, **pars):
27    """
28    Load a sasview model given the model name.
29    """
30    # convert model parameters from sasmodel form to sasview form
31    #print "old",sorted(pars.items())
32    modelname, pars = revert_model(model_definition, pars)
33    #print "new",sorted(pars.items())
34    sas = __import__('sas.models.'+modelname)
35    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
36    if ModelClass is None:
37        raise ValueError("could not find model %r in sas.models"%modelname)
38    model = ModelClass()
39
40    for k,v in pars.items():
41        if k.endswith("_pd"):
42            model.dispersion[k[:-3]]['width'] = v
43        elif k.endswith("_pd_n"):
44            model.dispersion[k[:-5]]['npts'] = v
45        elif k.endswith("_pd_nsigma"):
46            model.dispersion[k[:-10]]['nsigmas'] = v
47        elif k.endswith("_pd_type"):
48            model.dispersion[k[:-8]]['type'] = v
49        else:
50            model.setParam(k, v)
51    return model
52
53def randomize(p, v):
54    """
55    Randomizing parameter.
56
57    Guess the parameter type from name.
58    """
59    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
60        return v
61    elif any(s in p for s in ('theta','phi','psi')):
62        # orientation in [-180,180], orientation pd in [0,45]
63        if p.endswith('_pd'):
64            return 45*np.random.rand()
65        else:
66            return 360*np.random.rand() - 180
67    elif 'sld' in p:
68        # sld in in [-0.5,10]
69        return 10.5*np.random.rand() - 0.5
70    elif p.endswith('_pd'):
71        # length pd in [0,1]
72        return np.random.rand()
73    else:
74        # values from 0 to 2*x for all other parameters
75        return 2*np.random.rand()*(v if v != 0 else 1)
76
77def randomize_model(name, pars, seed=None):
78    if seed is None:
79        seed = np.random.randint(1e9)
80    np.random.seed(seed)
81    # Note: the sort guarantees order of calls to random number generator
82    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
83    # The capped cylinder model has a constraint on its parameters
84    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
85        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
86    return pars, seed
87
88def parlist(pars):
89    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
90
91def suppress_pd(pars):
92    """
93    Suppress theta_pd for now until the normalization is resolved.
94
95    May also suppress complete polydispersity of the model to test
96    models more quickly.
97    """
98    for p in pars:
99        if p.endswith("_pd"): pars[p] = 0
100
101def eval_sasview(name, pars, data, Nevals=1):
102    from sas.models.qsmearing import smear_selection
103    model = sasview_model(name, **pars)
104    smearer = smear_selection(data, model=model)
105    value = None  # silence the linter
106    toc = tic()
107    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
108        if hasattr(data, 'qx_data'):
109            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
110            index = ((~data.mask) & (~np.isnan(data.data))
111                     & (q >= data.qmin) & (q <= data.qmax))
112            if smearer is not None:
113                smearer.model = model  # because smear_selection has a bug
114                smearer.accuracy = data.accuracy
115                smearer.set_index(index)
116                value = smearer.get_value()
117            else:
118                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
119        else:
120            value = model.evalDistribution(data.x)
121            if smearer is not None:
122                value = smearer(value)
123    average_time = toc()*1000./Nevals
124    return value, average_time
125
126def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
127    try:
128        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
129    except Exception,exc:
130        print exc
131        print "... trying again with single precision"
132        model = core.load_model(model_definition, dtype='single', platform="ocl")
133    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
134    value = None  # silence the linter
135    toc = tic()
136    for _ in range(max(Nevals, 1)):  # force at least one eval
137        #pars['scale'] = np.random.rand()
138        problem.update()
139        value = problem.theory()
140    average_time = toc()*1000./Nevals
141    return value, average_time
142
143def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
144    model = core.load_model(model_definition, dtype=dtype, platform="dll")
145    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
146    value = None  # silence the linter
147    toc = tic()
148    for _ in range(max(Nevals, 1)):  # force at least one eval
149        problem.update()
150        value = problem.theory()
151    average_time = toc()*1000./Nevals
152    return value, average_time
153
154def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
155    if is2D:
156        from sasmodels.bumps_model import empty_data2D, set_beam_stop
157        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
158        data.accuracy = accuracy
159        set_beam_stop(data, 0.004)
160        index = ~data.mask
161    else:
162        from sasmodels.bumps_model import empty_data1D
163        if view == 'log':
164            qmax = math.log10(qmax)
165            q = np.logspace(qmax-3, qmax, Nq)
166        else:
167            q = np.linspace(0.001*qmax, qmax, Nq)
168        data = empty_data1D(q, resolution=resolution)
169        index = slice(None, None)
170    return data, index
171
172def compare(name, pars, Ncpu, Nocl, opts, set_pars):
173    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
174
175    opt_values = dict(split
176                      for s in opts for split in ((s.split('='),))
177                      if len(split) == 2)
178    # Sort out data
179    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
180    Nq = int(opt_values.get('-Nq', '128'))
181    res = float(opt_values.get('-res', '0'))
182    accuracy = opt_values.get('-accuracy', 'Low')
183    is2D = not "-1d" in opts
184    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
185
186
187    # modelling accuracy is determined by dtype and cutoff
188    dtype = 'double' if '-double' in opts else 'single'
189    cutoff = float(opt_values.get('-cutoff','1e-5'))
190
191    # randomize parameters
192    pars.update(set_pars)
193    if '-random' in opts or '-random' in opt_values:
194        seed = int(opt_values['-random']) if '-random' in opt_values else None
195        pars, seed = randomize_model(name, pars, seed=seed)
196        print "Randomize using -random=%i"%seed
197
198    # parameter selection
199    if '-mono' in opts:
200        suppress_pd(pars)
201    if '-pars' in opts:
202        print "pars",parlist(pars)
203
204    model_definition = core.load_model_definition(name)
205    # OpenCl calculation
206    if Nocl > 0:
207        ocl, ocl_time = eval_opencl(model_definition, pars, data,
208                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
209        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
210        #print max(ocl), min(ocl)
211
212    # ctypes/sasview calculation
213    if Ncpu > 0 and "-ctypes" in opts:
214        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
215                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
216        comp = "ctypes"
217        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
218    elif Ncpu > 0:
219        cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
220        comp = "sasview"
221        print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
222
223    # Compare, but only if computing both forms
224    if Nocl > 0 and Ncpu > 0:
225        #print "speedup %.2g"%(cpu_time/ocl_time)
226        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
227        #cpu *= max(ocl/cpu)
228        resid = (ocl - cpu)
229        relerr = resid/cpu
230        #bad = (relerr>1e-4)
231        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
232        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
233        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
234
235    # Plot if requested
236    if '-noplot' in opts: return
237    import matplotlib.pyplot as plt
238    if Ncpu > 0:
239        if Nocl > 0: plt.subplot(131)
240        plot_theory(data, cpu, view=view)
241        plt.title("%s t=%.1f ms"%(comp,cpu_time))
242        cbar_title = "log I"
243    if Nocl > 0:
244        if Ncpu > 0: plt.subplot(132)
245        plot_theory(data, ocl, view=view)
246        plt.title("opencl t=%.1f ms"%ocl_time)
247        cbar_title = "log I"
248    if Ncpu > 0 and Nocl > 0:
249        plt.subplot(133)
250        if '-abs' in opts:
251            err,errstr,errview = resid, "abs err", "linear"
252        else:
253            err,errstr,errview = abs(relerr), "rel err", "log"
254        #err,errstr = ocl/cpu,"ratio"
255        plot_theory(data, err, view=errview)
256        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
257        cbar_title = errstr if errview=="linear" else "log "+errstr
258    if is2D:
259        h = plt.colorbar()
260        h.ax.set_title(cbar_title)
261
262    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
263        plt.figure()
264        v = relerr
265        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
266        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
267        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
268        plt.ylabel('P(err)')
269        plt.title('Comparison of single and double precision models for %s'%name)
270
271    plt.show()
272
273def _print_stats(label, err):
274    sorted_err = np.sort(abs(err))
275    p50 = int((len(err)-1)*0.50)
276    p98 = int((len(err)-1)*0.98)
277    data = [
278        "max:%.3e"%sorted_err[-1],
279        "median:%.3e"%sorted_err[p50],
280        "98%%:%.3e"%sorted_err[p98],
281        "rms:%.3e"%np.sqrt(np.mean(err**2)),
282        "zero-offset:%+.3e"%np.mean(err),
283        ]
284    print label,"  ".join(data)
285
286
287
288# ===========================================================================
289#
290USAGE="""
291usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
292
293Compare the speed and value for a model between the SasView original and the
294OpenCL rewrite.
295
296model is the name of the model to compare (see below).
297Nopencl is the number of times to run the OpenCL model (default=5)
298Nsasview is the number of times to run the Sasview model (default=1)
299
300Options (* for default):
301
302    -plot*/-noplot plots or suppress the plot of the model
303    -single*/-double uses double precision for comparison
304    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
305    -Nq=128 sets the number of Q points in the data set
306    -1d/-2d* computes 1d or 2d data
307    -preset*/-random[=seed] preset or random parameters
308    -mono/-poly* force monodisperse/polydisperse
309    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
310    -cutoff=1e-5* cutoff value for including a point in polydispersity
311    -pars/-nopars* prints the parameter set or not
312    -abs/-rel* plot relative or absolute error
313    -linear/-log/-q4 intensity scaling
314    -hist/-nohist* plot histogram of relative error
315    -res=0 sets the resolution width dQ/Q if calculating with resolution
316    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
317
318Key=value pairs allow you to set specific values to any of the model
319parameters.
320
321Available models:
322
323    %s
324"""
325
326NAME_OPTIONS = set([
327    'plot','noplot',
328    'single','double',
329    'lowq','midq','highq','exq',
330    '2d','1d',
331    'preset','random',
332    'poly','mono',
333    'sasview','ctypes',
334    'nopars','pars',
335    'rel','abs',
336    'linear', 'log', 'q4',
337    'hist','nohist',
338    ])
339VALUE_OPTIONS = [
340    # Note: random is both a name option and a value option
341    'cutoff', 'random', 'Nq', 'res', 'accuracy',
342    ]
343
344def get_demo_pars(name):
345    import sasmodels.models
346    __import__('sasmodels.models.'+name)
347    model = getattr(sasmodels.models, name)
348    pars = getattr(model, 'demo', None)
349    if pars is None: pars = dict((p[0],p[2]) for p in model.parameters)
350    return pars
351
352def main():
353    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
354    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
355    models = "\n    ".join("%-15s"%v for v in MODELS)
356    if len(args) == 0:
357        print(USAGE%models)
358        sys.exit(1)
359    if args[0] not in MODELS:
360        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
361        sys.exit(1)
362
363    invalid = [o[1:] for o in opts
364               if o[1:] not in NAME_OPTIONS
365                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
366    if invalid:
367        print "Invalid options: %s"%(", ".join(invalid))
368        sys.exit(1)
369
370    # Get demo parameters from model definition, or use default parameters
371    # if model does not define demo parameters
372    name = args[0]
373    pars = get_demo_pars(name)
374
375    Nopencl = int(args[1]) if len(args) > 1 else 5
376    Nsasview = int(args[2]) if len(args) > 2 else 1
377
378    # Fill in default polydispersity parameters
379    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
380    for p in pds:
381        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
382        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
383
384    # Fill in parameters given on the command line
385    set_pars = {}
386    for arg in args[3:]:
387        k,v = arg.split('=')
388        if k not in pars:
389            # extract base name without distribution
390            s = set(p.split('_pd')[0] for p in pars)
391            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
392            sys.exit(1)
393        set_pars[k] = float(v) if not v.endswith('type') else v
394
395    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
396
397if __name__ == "__main__":
398    main()
Note: See TracBrowser for help on using the repository browser.