source: sasmodels/compare.py @ 373d1b6

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 373d1b6 was 373d1b6, checked in by pkienzle, 9 years ago

refactor model parameter handling

  • Property mode set to 100755
File size: 12.1 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import math
6from os.path import basename, dirname, join as joinpath
7import glob
8
9import numpy as np
10
11from sasmodels.bumps_model import BumpsModel, plot_data, tic
12from sasmodels import kernelcl, kerneldll
13from sasmodels.convert import revert_model
14
15# List of available models
16ROOT = dirname(__file__)
17MODELS = [basename(f)[:-3]
18          for f in sorted(glob.glob(joinpath(ROOT,"sasmodels","models","[a-zA-Z]*.py")))]
19
20
21def sasview_model(modelname, **pars):
22    """
23    Load a sasview model given the model name.
24    """
25    # convert model parameters from sasmodel form to sasview form
26    #print "old",sorted(pars.items())
27    modelname, pars = revert_model(modelname, pars)
28    #print "new",sorted(pars.items())
29    sas = __import__('sas.models.'+modelname)
30    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
31    if ModelClass is None:
32        raise ValueError("could not find model %r in sas.models"%modelname)
33    model = ModelClass()
34
35    for k,v in pars.items():
36        if k.endswith("_pd"):
37            model.dispersion[k[:-3]]['width'] = v
38        elif k.endswith("_pd_n"):
39            model.dispersion[k[:-5]]['npts'] = v
40        elif k.endswith("_pd_nsigma"):
41            model.dispersion[k[:-10]]['nsigmas'] = v
42        elif k.endswith("_pd_type"):
43            model.dispersion[k[:-8]]['type'] = v
44        else:
45            model.setParam(k, v)
46    return model
47
48def load_opencl(modelname, dtype='single'):
49    sasmodels = __import__('sasmodels.models.'+modelname)
50    module = getattr(sasmodels.models, modelname, None)
51    kernel = kernelcl.load_model(module, dtype=dtype)
52    return kernel
53
54def load_ctypes(modelname, dtype='single'):
55    sasmodels = __import__('sasmodels.models.'+modelname)
56    module = getattr(sasmodels.models, modelname, None)
57    kernel = kerneldll.load_model(module, dtype=dtype)
58    return kernel
59
60def randomize(p, v):
61    """
62    Randomizing parameter.
63
64    Guess the parameter type from name.
65    """
66    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
67        return v
68    elif any(s in p for s in ('theta','phi','psi')):
69        # orientation in [-180,180], orientation pd in [0,45]
70        if p.endswith('_pd'):
71            return 45*np.random.rand()
72        else:
73            return 360*np.random.rand() - 180
74    elif 'sld' in p:
75        # sld in in [-0.5,10]
76        return 10.5*np.random.rand() - 0.5
77    elif p.endswith('_pd'):
78        # length pd in [0,1]
79        return np.random.rand()
80    else:
81        # length, scale, background in [0,200]
82        return 200*np.random.rand()
83
84def randomize_model(name, pars, seed=None):
85    if seed is None:
86        seed = np.random.randint(1e9)
87    np.random.seed(seed)
88    # Note: the sort guarantees order of calls to random number generator
89    pars = dict((p,randomize(p,v)) for p,v in sorted(pars.items()))
90    # The capped cylinder model has a constraint on its parameters
91    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
92        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
93    return pars, seed
94
95def parlist(pars):
96    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
97
98def suppress_pd(pars):
99    """
100    Suppress theta_pd for now until the normalization is resolved.
101
102    May also suppress complete polydispersity of the model to test
103    models more quickly.
104    """
105    for p in pars:
106        if p.endswith("_pd"): pars[p] = 0
107
108def eval_sasview(name, pars, data, Nevals=1):
109    model = sasview_model(name, **pars)
110    toc = tic()
111    for _ in range(Nevals):
112        if hasattr(data, 'qx_data'):
113            value = model.evalDistribution([data.qx_data, data.qy_data])
114        else:
115            value = model.evalDistribution(data.x)
116    average_time = toc()*1000./Nevals
117    return value, average_time
118
119def eval_opencl(name, pars, data, dtype='single', Nevals=1, cutoff=0):
120    try:
121        model = load_opencl(name, dtype=dtype)
122    except Exception,exc:
123        print exc
124        print "... trying again with single precision"
125        model = load_opencl(name, dtype='single')
126    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
127    toc = tic()
128    for _ in range(Nevals):
129        #pars['scale'] = np.random.rand()
130        problem.update()
131        value = problem.theory()
132    average_time = toc()*1000./Nevals
133    return value, average_time
134
135def eval_ctypes(name, pars, data, dtype='double', Nevals=1, cutoff=0):
136    model = load_ctypes(name, dtype=dtype)
137    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
138    toc = tic()
139    for _ in range(Nevals):
140        problem.update()
141        value = problem.theory()
142    average_time = toc()*1000./Nevals
143    return value, average_time
144
145def make_data(qmax, is2D, Nq=128):
146    if is2D:
147        from sasmodels.bumps_model import empty_data2D, set_beam_stop
148        data = empty_data2D(np.linspace(-qmax, qmax, Nq))
149        set_beam_stop(data, 0.004)
150        index = ~data.mask
151    else:
152        from sasmodels.bumps_model import empty_data1D
153        qmax = math.log10(qmax)
154        data = empty_data1D(np.logspace(qmax-3, qmax, Nq))
155        index = slice(None, None)
156    return data, index
157
158def compare(name, pars, Ncpu, Nocl, opts, set_pars):
159    opt_values = dict(split
160                      for s in opts for split in ((s.split('='),))
161                      if len(split) == 2)
162    # Sort out data
163    qmax = 1.0 if '-highq' in opts else (0.2 if '-midq' in opts else 0.05)
164    Nq = int(opt_values.get('-Nq', '128'))
165    is2D = not "-1d" in opts
166    data, index = make_data(qmax, is2D, Nq)
167
168
169    # modelling accuracy is determined by dtype and cutoff
170    dtype = 'double' if '-double' in opts else 'single'
171    cutoff = float(opt_values.get('-cutoff','1e-5'))
172
173    # randomize parameters
174    if '-random' in opts or '-random' in opt_values:
175        seed = int(opt_values['-random']) if '-random' in opt_values else None
176        pars, seed = randomize_model(name, pars, seed=seed)
177        print "Randomize using -random=%i"%seed
178    pars.update(set_pars)
179
180    # parameter selection
181    if '-mono' in opts:
182        suppress_pd(pars)
183    if '-pars' in opts:
184        print "pars",parlist(pars)
185
186    # OpenCl calculation
187    if Nocl > 0:
188        ocl, ocl_time = eval_opencl(name, pars, data, dtype, Nocl)
189        print "opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl[index]))
190        #print max(ocl), min(ocl)
191
192    # ctypes/sasview calculation
193    if Ncpu > 0 and "-ctypes" in opts:
194        cpu, cpu_time = eval_ctypes(name, pars, data, dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
195        comp = "ctypes"
196        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index]))
197    elif Ncpu > 0:
198        cpu, cpu_time = eval_sasview(name, pars, data, Ncpu)
199        comp = "sasview"
200        print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index]))
201
202    # Compare, but only if computing both forms
203    if Nocl > 0 and Ncpu > 0:
204        #print "speedup %.2g"%(cpu_time/ocl_time)
205        #print "max |ocl/cpu|", max(abs(ocl/cpu)), "%.15g"%max(abs(ocl)), "%.15g"%max(abs(cpu))
206        #cpu *= max(ocl/cpu)
207        resid, relerr = np.zeros_like(ocl), np.zeros_like(ocl)
208        resid[index] = (ocl - cpu)[index]
209        relerr[index] = resid[index]/cpu[index]
210        #bad = (relerr>1e-4)
211        #print relerr[bad],cpu[bad],ocl[bad],data.qx_data[bad],data.qy_data[bad]
212        print "max(|ocl-%s|)"%comp, max(abs(resid[index]))
213        print "max(|(ocl-%s)/%s|)"%(comp,comp), max(abs(relerr[index]))
214        p98 = int(len(relerr[index])*0.98)
215        print "98%% (|(ocl-%s)/%s|) <"%(comp,comp), np.sort(abs(relerr[index]))[p98]
216
217
218    # Plot if requested
219    if '-noplot' in opts: return
220    import matplotlib.pyplot as plt
221    if Ncpu > 0:
222        if Nocl > 0: plt.subplot(131)
223        plot_data(data, cpu, scale='log')
224        plt.title("%s t=%.1f ms"%(comp,cpu_time))
225    if Nocl > 0:
226        if Ncpu > 0: plt.subplot(132)
227        plot_data(data, ocl, scale='log')
228        plt.title("opencl t=%.1f ms"%ocl_time)
229    if Ncpu > 0 and Nocl > 0:
230        plt.subplot(133)
231        err = resid if '-abs' in opts else relerr
232        errstr = "abs err" if '-abs' in opts else "rel err"
233        #err,errstr = ocl/cpu,"ratio"
234        plot_data(data, err, scale='linear')
235        plt.title("max %s = %.3g"%(errstr, max(abs(err[index]))))
236    if is2D: plt.colorbar()
237
238    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
239        plt.figure()
240        v = relerr[index]
241        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
242        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
243        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |');
244        plt.ylabel('P(err)')
245        plt.title('Comparison of single and double precision models for %s'%name)
246
247    plt.show()
248
249# ===========================================================================
250#
251USAGE="""
252usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]
253
254Compare the speed and value for a model between the SasView original and the
255OpenCL rewrite.
256
257model is the name of the model to compare (see below).
258Nopencl is the number of times to run the OpenCL model (default=5)
259Nsasview is the number of times to run the Sasview model (default=1)
260
261Options (* for default):
262
263    -plot*/-noplot plots or suppress the plot of the model
264    -single*/-double uses double precision for comparison
265    -lowq*/-midq/-highq use q values up to 0.05, 0.2 or 1.0
266    -Nq=128 sets the number of Q points in the data set
267    -1d/-2d* computes 1d or 2d data
268    -preset*/-random[=seed] preset or random parameters
269    -mono/-poly* force monodisperse/polydisperse
270    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
271    -cutoff=1e-5*/value cutoff for including a point in polydispersity
272    -pars/-nopars* prints the parameter set or not
273    -abs/-rel* plot relative or absolute error
274    -hist/-nohist* plot histogram of relative error
275
276Key=value pairs allow you to set specific values to any of the model
277parameters.
278
279Available models:
280
281    %s
282"""
283
284NAME_OPTIONS = set([
285    'plot','noplot',
286    'single','double',
287    'lowq','midq','highq',
288    '2d','1d',
289    'preset','random',
290    'poly','mono',
291    'sasview','ctypes',
292    'nopars','pars',
293    'rel','abs',
294    'hist','nohist',
295    ])
296VALUE_OPTIONS = [
297    # Note: random is both a name option and a value option
298    'cutoff', 'random', 'Nq',
299    ]
300
301def get_demo_pars(name):
302    import sasmodels.models
303    __import__('sasmodels.models.'+name)
304    model = getattr(sasmodels.models, name)
305    pars = getattr(model, 'demo', None)
306    if pars is None: pars = dict((p[0],p[2]) for p in model.parameters)
307    return pars
308
309def main():
310    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
311    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
312    models = "\n    ".join("%-15s"%v for v in MODELS)
313    if len(args) == 0:
314        print(USAGE%models)
315        sys.exit(1)
316    if args[0] not in MODELS:
317        print "Model %r not available. Use one of:\n    %s"%(args[0],models)
318        sys.exit(1)
319
320    invalid = [o[1:] for o in opts
321               if o[1:] not in NAME_OPTIONS
322                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
323    if invalid:
324        print "Invalid options: %s"%(", ".join(invalid))
325        sys.exit(1)
326
327    # Get demo parameters from model definition, or use default parameters
328    # if model does not define demo parameters
329    name = args[0]
330    pars = get_demo_pars(name)
331
332    Nopencl = int(args[1]) if len(args) > 1 else 5
333    Nsasview = int(args[2]) if len(args) > 2 else 1
334
335    # Fill in default polydispersity parameters
336    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
337    for p in pds:
338        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
339        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"
340
341    # Fill in parameters given on the command line
342    set_pars = {}
343    for arg in args[3:]:
344        k,v = arg.split('=')
345        if k not in pars:
346            # extract base name without distribution
347            s = set(p.split('_pd')[0] for p in pars)
348            print "%r invalid; parameters are: %s"%(k,", ".join(sorted(s)))
349            sys.exit(1)
350        set_pars[k] = float(v) if not v.endswith('type') else v
351
352    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
353
354if __name__ == "__main__":
355    main()
Note: See TracBrowser for help on using the repository browser.