source: sasmodels/compare.py @ d6adfbe

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since d6adfbe was 5134b2c, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

fix compare plots so they show both positive and negative error

  • Property mode set to 100755
File size: 12.4 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[87985ca]8
[1726b21]9import numpy as np
[473183c]10
[87985ca]11from sasmodels.bumps_model import BumpsModel, plot_data, tic
[87fce00]12try: from sasmodels import kernelcl
13except: from sasmodels import kerneldll as kernelcl
14from sasmodels import kerneldll
[87985ca]15from sasmodels.convert import revert_model
16
# Model names found in sasmodels/models/ next to this script: every *.py file
# whose name starts with a letter (so __init__.py and similar are skipped),
# with the ".py" extension removed, in sorted order.
ROOT = dirname(__file__)
MODELS = [
    basename(path)[:-3]
    for path in sorted(glob.glob(joinpath(ROOT, "sasmodels", "models", "[a-zA-Z]*.py")))
]
21
[8a20be5]22
def sasview_model(modelname, **pars):
    """
    Load a sasview model given the model name and configure its parameters.

    The name and parameters are first converted from sasmodels form to
    sasview form by revert_model.  Keys ending in ``_pd``, ``_pd_n``,
    ``_pd_nsigma`` or ``_pd_type`` set the 'width', 'npts', 'nsigmas' or
    'type' entry of ``model.dispersion[<base name>]``; every other key is
    passed to ``model.setParam``.

    Raises ValueError when sas.models has no class for the model (an
    ImportError from ``__import__`` comes first if the module is missing).
    """
    modelname, pars = revert_model(modelname, pars)
    sas = __import__('sas.models.'+modelname)
    model_module = getattr(sas.models, modelname, None)
    ModelClass = getattr(model_module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # (key suffix, dispersion field); no key can end in two of these suffixes
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
49
def load_opencl(modelname, dtype='single'):
    """
    Import sasmodels.models.<modelname> and return kernelcl.load_model for it.

    *dtype* is forwarded unchanged as the requested precision.  Note that
    kernelcl is bound to kerneldll when the OpenCL backend fails to import.
    """
    package = __import__('sasmodels.models.' + modelname)
    model_module = getattr(package.models, modelname, None)
    return kernelcl.load_model(model_module, dtype=dtype)
55
def load_ctypes(modelname, dtype='single'):
    """
    Import sasmodels.models.<modelname> and return kerneldll.load_model for it.

    *dtype* is forwarded unchanged as the requested precision.
    """
    package = __import__('sasmodels.models.' + modelname)
    model_module = getattr(package.models, modelname, None)
    return kerneldll.load_model(model_module, dtype=dtype)
61
def randomize(p, v):
    """
    Return a random value for parameter *p*, or *v* when it is not randomized.

    The kind of parameter is guessed from its name:

    * ends in ``_pd_n``, ``_pd_nsigma`` or ``_pd_type``: *v* is returned
    * contains theta/phi/psi: orientation in [-180, 180), or an orientation
      polydispersity in [0, 45) when the name ends in ``_pd``
    * contains ``sld``: value in [-0.5, 10)
    * other names ending in ``_pd``: length polydispersity in [0, 1)
    * anything else (length, scale, background): value in [0, 200)

    Each randomized result consumes exactly one np.random.rand() draw.
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v
    is_pd = p.endswith('_pd')
    if 'theta' in p or 'phi' in p or 'psi' in p:
        return 45*np.random.rand() if is_pd else 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    if is_pd:
        return np.random.rand()
    return 200*np.random.rand()
85
def randomize_model(name, pars, seed=None):
    """
    Return ``(randomized_pars, seed)`` for model *name*.

    Numpy's global generator is seeded with *seed* (a new seed is drawn when
    it is None) and every entry of *pars* is passed through randomize(), so
    the same seed reproduces the same parameter set.  *pars* is not modified.
    """
    if seed is None:
        seed = np.random.randint(1e9)
    np.random.seed(seed)
    # Walking the parameters in sorted order fixes the sequence of random
    # draws, independent of dict ordering.
    randomized = {}
    for par_name, value in sorted(pars.items()):
        randomized[par_name] = randomize(par_name, value)
    # capped_cylinder is only valid with cap_radius >= radius
    if name == 'capped_cylinder' and randomized['cap_radius'] < randomized['radius']:
        randomized['radius'], randomized['cap_radius'] = (
            randomized['cap_radius'], randomized['radius'])
    return randomized, seed
96
def parlist(pars):
    """Return one ``name: value`` line per parameter, sorted by name."""
    lines = ["%s: %s" % item for item in sorted(pars.items())]
    return "\n".join(lines)
99
def suppress_pd(pars):
    """
    Make the model monodisperse by zeroing every polydispersity width.

    Every key of *pars* ending in ``_pd`` is set to 0 in place; the
    ``_pd_n``, ``_pd_nsigma`` and ``_pd_type`` entries are left unchanged.
    Returns None.
    """
    for p in pars:
        # Only existing keys are reassigned, so iterating while writing is safe.
        if p.endswith("_pd"):
            pars[p] = 0
109
def eval_sasview(name, pars, data, Nevals=1):
    """
    Evaluate the SasView model *name* on *data* *Nevals* times.

    Data with a ``qx_data`` attribute is evaluated at
    ``[data.qx_data, data.qy_data]``, anything else at ``data.x``.
    Returns the last result and ``toc()*1000/Nevals``, the average time per
    evaluation (in ms, assuming toc() reports elapsed seconds).
    """
    model = sasview_model(name, **pars)
    stopwatch = tic()
    for _ in range(Nevals):
        q = [data.qx_data, data.qy_data] if hasattr(data, 'qx_data') else data.x
        value = model.evalDistribution(q)
    average_ms = stopwatch()*1000./Nevals
    return value, average_ms
120
def eval_opencl(name, pars, data, dtype='single', Nevals=1, cutoff=0):
    """
    Evaluate the OpenCL kernel for model *name* on *data* *Nevals* times.

    If loading the kernel at precision *dtype* fails, the error is printed
    and loading is retried once in single precision.  When *dtype* is
    already 'single' there is nothing to fall back to, so the original
    exception propagates instead of repeating the identical call.

    *cutoff* is passed to BumpsModel along with *pars*.  Returns the last
    theory value and ``toc()*1000/Nevals``, the average time per evaluation
    (in ms, assuming toc() reports elapsed seconds).
    """
    try:
        model = load_opencl(name, dtype=dtype)
    except Exception as exc:
        if dtype == 'single':
            raise
        print(exc)
        print("... trying again with single precision")
        model = load_opencl(name, dtype='single')
    problem = BumpsModel(data, model, cutoff=cutoff, **pars)
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    average_time = toc()*1000./Nevals
    return value, average_time
136
def eval_ctypes(name, pars, data, dtype='double', Nevals=1, cutoff=0):
    """
    Evaluate the ctypes (kerneldll) kernel for model *name* on *data*.

    The kernel is wrapped in a BumpsModel with *cutoff* and *pars* and its
    theory is recomputed *Nevals* times.  Returns the last theory value and
    the average time per evaluation (in ms, assuming toc() reports seconds).
    """
    kernel = load_ctypes(name, dtype=dtype)
    problem = BumpsModel(data, kernel, cutoff=cutoff, **pars)
    stopwatch = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    average_ms = stopwatch()*1000./Nevals
    return value, average_ms
146
def make_data(qmax, is2D, Nq=128):
    """
    Build an empty data set for model evaluation; returns ``(data, index)``.

    1-D: *Nq* q values spaced logarithmically over three decades ending at
    *qmax*; *index* selects every point.

    2-D: built by empty_data2D from ``np.linspace(-qmax, qmax, Nq)``
    (presumably used for both qx and qy -- confirm in empty_data2D), then
    set_beam_stop(data, 0.004) is applied.  *index* is ``~data.mask``, i.e.
    the points left unmasked.
    """
    if not is2D:
        from sasmodels.bumps_model import empty_data1D
        log_qmax = math.log10(qmax)
        q = np.logspace(log_qmax - 3, log_qmax, Nq)
        return empty_data1D(q), slice(None, None)

    from sasmodels.bumps_model import empty_data2D, set_beam_stop
    data = empty_data2D(np.linspace(-qmax, qmax, Nq))
    set_beam_stop(data, 0.004)
    return data, ~data.mask
159
def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Evaluate model *name* with OpenCL and/or a CPU backend and compare them.

    *pars* are the base parameter values.  *set_pars* are the values given
    on the command line and are applied last, after any randomization.
    *Nocl* and *Ncpu* are the numbers of OpenCL and CPU evaluations used for
    timing; a count of 0 skips that backend.  The CPU backend is SasView, or
    ctypes with ``-ctypes``.  *opts* is the list of command-line options
    (see USAGE).

    Timings and total intensities are printed.  When both backends run, the
    absolute and relative differences are also printed.  Unless ``-noplot``
    is given, the results (and their difference) are plotted.

    Note: *pars* may be modified in place (set_pars and ``-mono``).
    """
    # Options of the form -name=value
    opt_values = {}
    for opt in opts:
        parts = opt.split('=')
        if len(parts) == 2:
            opt_values[parts[0]] = parts[1]

    # Sort out data
    if '-exq' in opts:
        qmax = 10.0
    elif '-highq' in opts:
        qmax = 1.0
    elif '-midq' in opts:
        qmax = 0.2
    else:
        qmax = 0.05
    Nq = int(opt_values.get('-Nq', '128'))
    is2D = "-1d" not in opts
    data, index = make_data(qmax, is2D, Nq)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(name, pars, seed=seed)
        print("Randomize using -random=%i"%seed)
    pars.update(set_pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars " + parlist(pars))

    # OpenCL calculation.  Use the same polydispersity cutoff as the ctypes
    # calculation; otherwise the two are not computing the same thing.
    if Nocl > 0:
        ocl, ocl_time = eval_opencl(name, pars, data, dtype=dtype,
                                    Nevals=Nocl, cutoff=cutoff)
        print("opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl[index])))

    # ctypes/sasview calculation
    if Ncpu > 0 and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(name, pars, data, dtype=dtype,
                                    cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index])))
    elif Ncpu > 0:
        cpu, cpu_time = eval_sasview(name, pars, data, Ncpu)
        comp = "sasview"
        print("sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[index])))

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        resid, relerr = np.zeros_like(ocl), np.zeros_like(ocl)
        resid[index] = (ocl - cpu)[index]
        relerr[index] = resid[index]/cpu[index]
        print("max(|ocl-%s|) %s"%(comp, max(abs(resid[index]))))
        print("max(|(ocl-%s)/%s|) %s"%(comp, comp, max(abs(relerr[index]))))
        p98 = int(len(relerr[index])*0.98)
        print("98%% (|(ocl-%s)/%s|) < %s"
              %(comp, comp, np.sort(abs(relerr[index]))[p98]))

    # Plot if requested, and only if something was computed; with nothing
    # plotted there is no image for the colorbar and no cbar_title.
    if '-noplot' in opts or not (Ncpu > 0 or Nocl > 0):
        return
    import matplotlib.pyplot as plt
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_data(data, cpu, scale='log')
        plt.title("%s t=%.1f ms"%(comp, cpu_time))
        cbar_title = "log I"
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_data(data, ocl, scale='log')
        plt.title("opencl t=%.1f ms"%ocl_time)
        cbar_title = "log I"
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, scale = resid, "abs err", "linear"
        else:
            err, errstr, scale = abs(relerr), "rel err", "log"
        plot_data(data, err, scale=scale)
        plt.title("max %s = %.3g"%(errstr, max(abs(err[index]))))
        cbar_title = errstr if scale == "linear" else "log "+errstr
    if is2D:
        h = plt.colorbar()
        h.ax.set_title(cbar_title)

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        # np.abs returns a new array, so relerr itself is left untouched
        abs_err = np.abs(relerr[index])
        nonzero = abs_err[abs_err != 0]
        if nonzero.size == 0:
            print("relative error is zero everywhere; no histogram to plot")
        else:
            # log10(0) is -inf, which histogram binning cannot handle, so
            # exact matches are shown at half the smallest nonzero error.
            abs_err[abs_err == 0] = 0.5*np.min(nonzero)
            plt.figure()
            # 'normed' was removed from matplotlib; 'density' is its replacement
            plt.hist(np.log10(abs_err), density=True, bins=50)
            plt.xlabel('log10(err), err = |F(q) opencl - F(q) %s| / |F(q) %s|'
                       % (comp, comp))
            plt.ylabel('P(err)')
            plt.title('Comparison of opencl and %s models for %s' % (comp, name))

    plt.show()
257
# ===========================================================================
#
# Command-line help text; the single %s is filled with the list of models.
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the SasView model, or the ctypes
model with -ctypes (default=1)

Options (* for default):

    -plot*/-noplot plot the model or suppress the plot
    -single*/-double use single or double precision for comparison
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d/-2d* compute 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* whether cpu is tested using ctypes or sasview
    -cutoff=1e-5*/value cutoff for including a point in polydispersity
    -pars/-nopars* print the parameter set or not
    -abs/-rel* plot absolute or relative error
    -hist/-nohist* plot histogram of relative error

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:

    %s
"""

# Options accepted as bare -name
NAME_OPTIONS = set([
    'plot','noplot',
    'single','double',
    'lowq','midq','highq','exq',
    '2d','1d',
    'preset','random',
    'poly','mono',
    'sasview','ctypes',
    'nopars','pars',
    'rel','abs',
    'hist','nohist',
    ])
# Options accepted as -name=value
VALUE_OPTIONS = [
    # Note: random is both a name option and a value option
    'cutoff', 'random', 'Nq',
    ]
309
def get_demo_pars(name):
    """
    Return a new dict of demo parameter values for model *name*.

    Uses the model module's ``demo`` attribute when defined.  Otherwise it
    builds ``{p[0]: p[2]}`` from ``model.parameters``, presumably parameter
    name and default value (confirm against the model definition format).
    A copy is always returned, so callers that fill in or override
    parameters do not modify the imported model module's ``demo`` dict.
    """
    import sasmodels.models
    __import__('sasmodels.models.'+name)
    model = getattr(sasmodels.models, name)
    pars = getattr(model, 'demo', None)
    if pars is None:
        return dict((p[0], p[2]) for p in model.parameters)
    return dict(pars)
317
def main():
    """
    Command-line entry point; see USAGE for the accepted arguments.

    Arguments starting with '-' are options.  Among the other arguments:
    the first is the model name, those containing '=' are key=value
    parameter settings, and up to two remaining ones give the OpenCL and
    SasView evaluation counts (default 5 and 1).

    Prints a message and exits with status 1 in any of these cases:
    missing or unknown model, unknown option, unknown parameter, or too
    many count arguments.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE%models)
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    pars = get_demo_pars(name)

    # Counts are the bare values after the model name.  key=value settings
    # are recognized by their '=', so they are never mistaken for a count
    # even when fewer than two counts are given.
    counts = [arg for arg in args[1:] if '=' not in arg]
    settings = [arg for arg in args[1:] if '=' in arg]
    if len(counts) > 2:
        print("Unexpected arguments: %s"%(" ".join(counts[2:])))
        sys.exit(1)
    Nopencl = int(counts[0]) if len(counts) > 0 else 5
    Nsasview = int(counts[1]) if len(counts) > 1 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in settings:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        # Distribution types are names such as "gaussian", and everything
        # else is numeric.  The test must look at the key: the value of a
        # *_pd_type parameter never ends in "type".
        set_pars[k] = v if k.endswith('_pd_type') else float(v)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)
362
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import.
    main()
Note: See TracBrowser for help on using the repository browser.