source: sasmodels/compare.py @ 77ad412

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 77ad412 was 0763009, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

minor code cleanup

  • Property mode set to 100755
File size: 14.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8
9import numpy as np
10
11from sasmodels.bumps_model import Model, Experiment, plot_theory, tic
12from sasmodels import core
13from sasmodels import kerneldll
14from sasmodels.convert import revert_model
15kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
16
17# List of available models
18ROOT = dirname(__file__)
19MODELS = [basename(f)[:-3]
20          for f in sorted(glob.glob(joinpath(ROOT,"sasmodels","models","[a-zA-Z]*.py")))]
21
22
23def sasview_model(model_definition, **pars):
24    """
25    Load a sasview model given the model name.
26    """
27    # convert model parameters from sasmodel form to sasview form
28    #print "old",sorted(pars.items())
29    modelname, pars = revert_model(model_definition, pars)
30    #print "new",sorted(pars.items())
31    sas = __import__('sas.models.'+modelname)
32    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
33    if ModelClass is None:
34        raise ValueError("could not find model %r in sas.models"%modelname)
35    model = ModelClass()
36
37    for k,v in pars.items():
38        if k.endswith("_pd"):
39            model.dispersion[k[:-3]]['width'] = v
40        elif k.endswith("_pd_n"):
41            model.dispersion[k[:-5]]['npts'] = v
42        elif k.endswith("_pd_nsigma"):
43            model.dispersion[k[:-10]]['nsigmas'] = v
44        elif k.endswith("_pd_type"):
45            model.dispersion[k[:-8]]['type'] = v
46        else:
47            model.setParam(k, v)
48    return model
49
50def randomize(p, v):
51    """
52    Randomizing parameter.
53
54    Guess the parameter type from name.
55    """
56    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
57        return v
58    elif any(s in p for s in ('theta','phi','psi')):
59        # orientation in [-180,180], orientation pd in [0,45]
60        if p.endswith('_pd'):
61            return 45*np.random.rand()
62        else:
63            return 360*np.random.rand() - 180
64    elif 'sld' in p:
65        # sld in in [-0.5,10]
66        return 10.5*np.random.rand() - 0.5
67    elif p.endswith('_pd'):
68        # length pd in [0,1]
69        return np.random.rand()
70    else:
71        # values from 0 to 2*x for all other parameters
72        return 2*np.random.rand()*(v if v != 0 else 1)
73
74def randomize_model(name, pars, seed=None):
75    if seed is None:
76        seed = np.random.randint(1e9)
77    np.random.seed(seed)
78    # Note: the sort guarantees order of calls to random number generator
79    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
80    # The capped cylinder model has a constraint on its parameters
81    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
82        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
83    return pars, seed
84
85def parlist(pars):
86    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
87
88def suppress_pd(pars):
89    """
90    Suppress theta_pd for now until the normalization is resolved.
91
92    May also suppress complete polydispersity of the model to test
93    models more quickly.
94    """
95    for p in pars:
96        if p.endswith("_pd"): pars[p] = 0
97
98def eval_sasview(name, pars, data, Nevals=1):
99    from sas.models.qsmearing import smear_selection
100    model = sasview_model(name, **pars)
101    smearer = smear_selection(data, model=model)
102    value = None  # silence the linter
103    toc = tic()
104    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
105        if hasattr(data, 'qx_data'):
106            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
107            index = ((~data.mask) & (~np.isnan(data.data))
108                     & (q >= data.qmin) & (q <= data.qmax))
109            if smearer is not None:
110                smearer.model = model  # because smear_selection has a bug
111                smearer.accuracy = data.accuracy
112                smearer.set_index(index)
113                value = smearer.get_value()
114            else:
115                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
116        else:
117            value = model.evalDistribution(data.x)
118            if smearer is not None:
119                value = smearer(value)
120    average_time = toc()*1000./Nevals
121    return value, average_time
122
123def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
124    try:
125        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
126    except Exception,exc:
127        print exc
128        print "... trying again with single precision"
129        model = core.load_model(model_definition, dtype='single', platform="ocl")
130    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
131    value = None  # silence the linter
132    toc = tic()
133    for _ in range(max(Nevals, 1)):  # force at least one eval
134        #pars['scale'] = np.random.rand()
135        problem.update()
136        value = problem.theory()
137    average_time = toc()*1000./Nevals
138    return value, average_time
139
140def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
141    model = core.load_model(model_definition, dtype=dtype, platform="dll")
142    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
143    value = None  # silence the linter
144    toc = tic()
145    for _ in range(max(Nevals, 1)):  # force at least one eval
146        problem.update()
147        value = problem.theory()
148    average_time = toc()*1000./Nevals
149    return value, average_time
150
151def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
152    if is2D:
153        from sasmodels.bumps_model import empty_data2D, set_beam_stop
154        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
155        data.accuracy = accuracy
156        set_beam_stop(data, 0.004)
157        index = ~data.mask
158    else:
159        from sasmodels.bumps_model import empty_data1D
160        if view == 'log':
161            qmax = math.log10(qmax)
162            q = np.logspace(qmax-3, qmax, Nq)
163        else:
164            q = np.linspace(0.001*qmax, qmax, Nq)
165        data = empty_data1D(q, resolution=resolution)
166        index = slice(None, None)
167    return data, index
168
169def compare(name, pars, Ncpu, Nocl, opts, set_pars):
170    view = 'linear' if '-linear' in opts else 'log' if '-log' in opts else 'q4' if '-q4' in opts else 'log'
171
172    opt_values = dict(split
173                      for s in opts for split in ((s.split('='),))
174                      if len(split) == 2)
175    # Sort out data
176    qmax = 10.0 if '-exq' in opts else 1.0 if '-highq' in opts else 0.2 if '-midq' in opts else 0.05
177    Nq = int(opt_values.get('-Nq', '128'))
178    res = float(opt_values.get('-res', '0'))
179    accuracy = opt_values.get('-accuracy', 'Low')
180    is2D = not "-1d" in opts
181    data, index = make_data(qmax, is2D, Nq, res, accuracy, view=view)
182
183
184    # modelling accuracy is determined by dtype and cutoff
185    dtype = 'double' if '-double' in opts else 'single'
186    cutoff = float(opt_values.get('-cutoff','1e-5'))
187
188    # randomize parameters
189    pars.update(set_pars)
190    if '-random' in opts or '-random' in opt_values:
191        seed = int(opt_values['-random']) if '-random' in opt_values else None
192        pars, seed = randomize_model(name, pars, seed=seed)
193        print "Randomize using -random=%i"%seed
194
195    # parameter selection
196    if '-mono' in opts:
197        suppress_pd(pars)
198    if '-pars' in opts:
199        print "pars",parlist(pars)
200
201    model_definition = core.load_model_definition(name)
202    # OpenCl calculation
203    if Nocl > 0:
204        ocl, ocl_time = eval_opencl(model_definition, pars, data,
205                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
206        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl))
207        #print max(ocl), min(ocl)
208
209    # ctypes/sasview calculation
210    if Ncpu > 0 and "-ctypes" in opts:
211        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
212                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
213        comp = "ctypes"
214        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
215    elif Ncpu > 0:
216        cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
217        comp = "sasview"
218        print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu))
219
220    # Compare, but only if computing both forms
221    if Nocl > 0 and Ncpu > 0:
222        #print "speedup %.2g"%(cpu_time/ocl_time)
223        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
224        #cpu *= max(ocl/cpu)
225        resid = (ocl - cpu)
226        relerr = resid/cpu
227        #bad = (relerr>1e-4)
228        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
229        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
230        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)
231
232    # Plot if requested
233    if '-noplot' in opts: return
234    import matplotlib.pyplot as plt
235    if Ncpu > 0:
236        if Nocl > 0: plt.subplot(131)
237        plot_theory(data, cpu, view=view)
238        plt.title("%s t=%.1f ms"%(comp,cpu_time))
239        cbar_title = "log I"
240    if Nocl > 0:
241        if Ncpu > 0: plt.subplot(132)
242        plot_theory(data, ocl, view=view)
243        plt.title("opencl t=%.1f ms"%ocl_time)
244        cbar_title = "log I"
245    if Ncpu > 0 and Nocl > 0:
246        plt.subplot(133)
247        if '-abs' in opts:
248            err,errstr,errview = resid, "abs err", "linear"
249        else:
250            err,errstr,errview = abs(relerr), "rel err", "log"
251        #err,errstr = ocl/cpu,"ratio"
252        plot_theory(data, err, view=errview)
253        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
254        cbar_title = errstr if errview=="linear" else "log "+errstr
255    if is2D:
256        h = plt.colorbar()
257        h.ax.set_title(cbar_title)
258
259    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
260        plt.figure()
261        v = relerr
262        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
263        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
264        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
265        plt.ylabel('P(err)')
266        plt.title('Comparison of single and double precision models for %s'%name)
267
268    plt.show()
269
270def _print_stats(label, err):
271    sorted_err = np.sort(abs(err))
272    p50 = int((len(err)-1)*0.50)
273    p98 = int((len(err)-1)*0.98)
274    data = [
275        "max:%.3e"%sorted_err[-1],
276        "median:%.3e"%sorted_err[p50],
277        "98%%:%.3e"%sorted_err[p98],
278        "rms:%.3e"%np.sqrt(np.mean(err**2)),
279        "zero-offset:%+.3e"%np.mean(err),
280        ]
281    print label,"  ".join(data)
282
283
284
285# ===========================================================================
286#
287USAGE="""
288usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
289
290Compare the speed and value for a model between the SasView original and the
291OpenCL rewrite.
292
293model is the name of the model to compare (see below).
294Nopencl is the number of times to run the OpenCL model (default=5)
295Nsasview is the number of times to run the Sasview model (default=1)
296
297Options (* for default):
298
299    -plot*/-noplot plots or suppress the plot of the model
300    -single*/-double uses double precision for comparison
301    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
302    -Nq=128 sets the number of Q points in the data set
303    -1d/-2d* computes 1d or 2d data
304    -preset*/-random[=seed] preset or random parameters
305    -mono/-poly* force monodisperse/polydisperse
306    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
307    -cutoff=1e-5* cutoff value for including a point in polydispersity
308    -pars/-nopars* prints the parameter set or not
309    -abs/-rel* plot relative or absolute error
310    -linear/-log/-q4 intensity scaling
311    -hist/-nohist* plot histogram of relative error
312    -res=0 sets the resolution width dQ/Q if calculating with resolution
313    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh
314
315Key=value pairs allow you to set specific values to any of the model
316parameters.
317
318Available models:
319
320    %s
321"""
322
323NAME_OPTIONS = set([
324    'plot','noplot',
325    'single','double',
326    'lowq','midq','highq','exq',
327    '2d','1d',
328    'preset','random',
329    'poly','mono',
330    'sasview','ctypes',
331    'nopars','pars',
332    'rel','abs',
333    'linear', 'log', 'q4',
334    'hist','nohist',
335    ])
336VALUE_OPTIONS = [
337    # Note: random is both a name option and a value option
338    'cutoff', 'random', 'Nq', 'res', 'accuracy',
339    ]
340
341def get_demo_pars(name):
342    import sasmodels.models
343    __import__('sasmodels.models.'+name)
344    model = getattr(sasmodels.models, name)
345    pars = getattr(model, 'demo', None)
346    if pars is None: pars = dict((p[0],p[2]) for p in model.parameters)
347    return pars
348
349def main():
350    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
351    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
352    models = "\n    ".join("%-15s"%v for v in MODELS)
353    if len(args) == 0:
354        print(USAGE%models)
355        sys.exit(1)
356    if args[0] not in MODELS:
357        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
358        sys.exit(1)
359
360    invalid = [o[1:] for o in opts
361               if o[1:] not in NAME_OPTIONS
362                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
363    if invalid:
364        print "Invalid options: %s"%(", ".join(invalid))
365        sys.exit(1)
366
367    # Get demo parameters from model definition, or use default parameters
368    # if model does not define demo parameters
369    name = args[0]
370    pars = get_demo_pars(name)
371
372    Nopencl = int(args[1]) if len(args) > 1 else 5
373    Nsasview = int(args[2]) if len(args) > 2 else 1
374
375    # Fill in default polydispersity parameters
376    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
377    for p in pds:
378        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
379        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
380
381    # Fill in parameters given on the command line
382    set_pars = {}
383    for arg in args[3:]:
384        k,v = arg.split('=')
385        if k not in pars:
386            # extract base name without distribution
387            s = set(p.split('_pd')[0] for p in pars)
388            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
389            sys.exit(1)
390        set_pars[k] = float(v) if not v.endswith('type') else v
391
392    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
393
394if __name__ == "__main__":
395    main()
Note: See TracBrowser for help on using the repository browser.