source: sasmodels/compare.py @ 29fc2a3

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 29fc2a3 was 29fc2a3, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

make sure sasmodels is on the path for compare.py

  • Property mode set to 100755
File size: 14.0 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
import sys
import math
from os.path import abspath, basename, dirname, join as joinpath
import glob

import numpy as np
[473183c]10
# Directory holding this script; the sasmodels package sits beside it.
# abspath() guards against dirname(__file__) being '' (script started from
# its own directory) and keeps ROOT valid if the working directory changes.
ROOT = dirname(abspath(__file__))
sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
13
14
[346bc88]15from sasmodels.bumps_model import Model, Experiment, plot_theory, tic
[aa4946b]16from sasmodels import core
[87fce00]17from sasmodels import kerneldll
[87985ca]18from sasmodels.convert import revert_model
# The ctypes kernels run in single precision unless -double is given, so
# turn on kerneldll's ALLOW_SINGLE_PRECISION_DLLS switch.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True

# Names of the available models: the .py files under sasmodels/models whose
# names start with a letter (skipping __init__.py and similar), taken in
# sorted path order with the ".py" suffix removed.
MODELS = [
    basename(path)[:-len(".py")]
    for path in sorted(glob.glob(
        joinpath(ROOT, "sasmodels", "models", "[a-zA-Z]*.py")))
]
24
[8a20be5]25
def sasview_model(model_definition, **pars):
    """
    Build the SasView (sas.models) equivalent of a sasmodels model.

    The sasmodels-style parameters in *pars* are translated with
    revert_model, the matching class is looked up in sas.models.<modelname>,
    and each parameter is applied to a new instance:

    * ``*_pd``, ``*_pd_n``, ``*_pd_nsigma`` and ``*_pd_type`` set the
      'width', 'npts', 'nsigmas' and 'type' entries of model.dispersion
      for the base parameter;
    * anything else goes through model.setParam.

    Raises ValueError if sas.models has no such model class.
    """
    modelname, pars = revert_model(model_definition, pars)
    sas = __import__('sas.models.' + modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models" % modelname)
    model = ModelClass()

    # (suffix, dispersion field) pairs; no key can end with more than one.
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            model.setParam(key, value)
    return model
52
def randomize(p, v):
    """
    Return a random value for parameter *p* whose current value is *v*.

    The kind of parameter is guessed from its name:

    * ``*_pd_n``, ``*_pd_nsigma``, ``*_pd_type``: returned unchanged;
    * names containing theta/phi/psi: an angle in [-180, 180], or a
      width in [0, 45] for the ``*_pd`` variant;
    * names containing sld: a value in [-0.5, 10];
    * other ``*_pd`` widths: a value in [0, 1];
    * everything else: a value in [0, 2*v] (or [0, 2] when v is 0).
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v
    is_pd = p.endswith('_pd')
    if 'theta' in p or 'phi' in p or 'psi' in p:
        return 45*np.random.rand() if is_pd else 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    if is_pd:
        return np.random.rand()
    scale = v if v != 0 else 1
    return 2*np.random.rand()*scale
[87985ca]76
def randomize_model(name, pars, seed=None):
    """
    Return (new_pars, seed) with every value of *pars* passed through
    randomize().

    *pars* itself is not modified.  The global numpy random generator is
    reseeded with *seed* (a fresh seed is drawn when it is None), and the
    seed is returned so the run can be repeated with -random=seed.
    """
    if seed is None:
        # randint takes integer bounds; pass an int rather than float 1e9
        seed = np.random.randint(10**9)
    np.random.seed(seed)
    # Note: the sort guarantees order of calls to random number generator
    pars = dict((p, randomize(p, v)) for p, v in sorted(pars.items()))
    # The capped cylinder model has a constraint on its parameters
    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
    return pars, seed
87
def parlist(pars):
    """Format *pars* as "name: value" lines, sorted by parameter name."""
    lines = ["%s: %s" % item for item in sorted(pars.items())]
    return "\n".join(lines)
90
def suppress_pd(pars):
    """
    Make the model monodisperse by zeroing every ``*_pd`` width in *pars*.

    The dict is modified in place.  The ``*_pd_n``, ``*_pd_nsigma`` and
    ``*_pd_type`` entries are left untouched.  Originally intended to
    suppress theta_pd until its normalization is resolved; it also makes
    model tests run faster.
    """
    pd_keys = [key for key in pars if key.endswith("_pd")]
    for key in pd_keys:
        pars[key] = 0
100
def eval_sasview(name, pars, data, Nevals=1):
    """
    Evaluate the model using the original SasView implementation.

    *name* is the model definition handed to sasview_model(), *pars* the
    sasmodels-style parameter values, and *data* a data set from
    make_data().  Returns (value, average_time): the theory from the last
    evaluation and the mean time per evaluation in ms.  At least one
    evaluation is run even if *Nevals* < 1.
    """
    from sas.models.qsmearing import smear_selection
    model = sasview_model(name, **pars)
    smearer = smear_selection(data, model=model)
    is2D = hasattr(data, 'qx_data')
    if is2D:
        # The set of q points to evaluate depends only on the data, so
        # build it once, outside the timed loop.
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
        index = ((~data.mask) & (~np.isnan(data.data))
                 & (q >= data.qmin) & (q <= data.qmax))
    Nevals = max(Nevals, 1)  # make sure there is at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(Nevals):
        if not is2D:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
        elif smearer is not None:
            smearer.model = model  # because smear_selection has a bug
            smearer.accuracy = data.accuracy
            smearer.set_index(index)
            value = smearer.get_value()
        else:
            value = model.evalDistribution([data.qx_data[index],
                                            data.qy_data[index]])
    # Average over the evaluations actually run (never divides by zero).
    average_time = toc()*1000./Nevals
    return value, average_time
125
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
    """
    Evaluate the model with the OpenCL kernel.

    The kernel is loaded with the requested *dtype*.  If that load raises,
    the error is printed and the load is retried in single precision.
    Returns (value, average_time): the theory from the last evaluation and
    the mean time per evaluation in ms.  At least one evaluation is run
    even if *Nevals* < 1.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single', platform="ocl")
    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
    Nevals = max(Nevals, 1)  # force at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    # Average over the evaluations actually run (never divides by zero).
    average_time = toc()*1000./Nevals
    return value, average_time
142
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate the model with the compiled C kernel loaded through ctypes.

    Returns (value, average_time): the theory from the last evaluation and
    the mean time per evaluation in ms.  At least one evaluation is run
    even if *Nevals* < 1.
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    problem = Experiment(data, Model(model, **pars), cutoff=cutoff)
    Nevals = max(Nevals, 1)  # force at least one eval
    value = None  # silence the linter
    toc = tic()
    for _ in range(Nevals):
        problem.update()
        value = problem.theory()
    # Average over the evaluations actually run (never divides by zero).
    average_time = toc()*1000./Nevals
    return value, average_time
153
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Create an empty data set to evaluate the models on.

    2D: an Nq x Nq grid spanning [-qmax, qmax] with a 0.004 beam stop and
    the given resolution *accuracy*; the returned index is the unmasked
    points.

    1D: Nq points up to qmax, log spaced over three decades when *view* is
    'log' and linearly spaced from 0.001*qmax otherwise; the returned index
    selects every point.

    Returns (data, index).
    """
    if not is2D:
        from sasmodels.bumps_model import empty_data1D
        if view == 'log':
            log_qmax = math.log10(qmax)
            q = np.logspace(log_qmax - 3, log_qmax, Nq)
        else:
            q = np.linspace(0.001*qmax, qmax, Nq)
        data = empty_data1D(q, resolution=resolution)
        return data, slice(None, None)

    from sasmodels.bumps_model import empty_data2D, set_beam_stop
    data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution)
    data.accuracy = accuracy
    set_beam_stop(data, 0.004)
    return data, ~data.mask
171
def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Evaluate model *name* with OpenCL and/or SasView (or ctypes) and compare.

    *pars* holds the starting parameter values and is updated in place with
    *set_pars* before any randomization.  *Nocl* and *Ncpu* give how many
    timed OpenCL and CPU evaluations to run; a count of zero skips that
    engine.  *opts* is the list of command-line options (see USAGE).

    Timings and intensity sums are printed.  Error statistics are printed
    when both engines run.  Results are plotted unless -noplot is given or
    nothing was computed.
    """
    view = ('linear' if '-linear' in opts
            else 'log' if '-log' in opts
            else 'q4' if '-q4' in opts
            else 'log')

    # Options of the form -key=value (anything with extra '=' is ignored)
    splits = (s.split('=') for s in opts)
    opt_values = dict(kv for kv in splits if len(kv) == 2)

    # Sort out data
    qmax = (10.0 if '-exq' in opts
            else 1.0 if '-highq' in opts
            else 0.2 if '-midq' in opts
            else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = '-1d' not in opts
    data, _ = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff', '1e-5'))

    # randomize parameters
    pars.update(set_pars)
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(name, pars, seed=seed)
        print("Randomize using -random=%i" % seed)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars " + parlist(pars))

    model_definition = core.load_model_definition(name)
    # OpenCl calculation
    if Nocl > 0:
        ocl, ocl_time = eval_opencl(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f" % (ocl_time, sum(ocl)))

    # ctypes/sasview calculation
    if Ncpu > 0 and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f" % (cpu_time, sum(cpu)))
    elif Ncpu > 0:
        cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
        comp = "sasview"
        print("sasview t=%.1f ms, intensity=%.0f" % (cpu_time, sum(cpu)))

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        resid = (ocl - cpu)
        relerr = resid/cpu
        _print_stats("|ocl-%s|" % comp + (" "*(3+len(comp))), resid)
        _print_stats("|(ocl-%s)/%s|" % (comp, comp), relerr)

    # Plot if requested; with nothing computed there is nothing to show
    # (and no image for the colorbar to attach to).
    if '-noplot' in opts or (Ncpu <= 0 and Nocl <= 0):
        return
    import matplotlib.pyplot as plt
    cbar_title = "log I"
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_theory(data, cpu, view=view)
        plt.title("%s t=%.1f ms" % (comp, cpu_time))
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_theory(data, ocl, view=view)
        plt.title("opencl t=%.1f ms" % ocl_time)
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, err, view=errview)
        plt.title("max %s = %.3g" % (errstr, max(abs(err))))
        cbar_title = errstr if errview == "linear" else "log " + errstr
    if is2D:
        h = plt.colorbar()
        h.ax.set_title(cbar_title)

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        # |relerr| over finite points only (cpu == 0 yields inf/nan).
        # np.abs returns a new array, so relerr itself is left unchanged.
        v = np.abs(relerr[np.isfinite(relerr)])
        nonzero = v[v != 0]
        if nonzero.size == 0:
            print("relative error is zero everywhere; skipping histogram")
        else:
            # exact zeros would give -inf under log10; put them half the
            # smallest nonzero error instead
            v[v == 0] = 0.5*np.min(nonzero)
            plt.figure()
            # density= replaces normed=, which matplotlib 3.1 removed
            plt.hist(np.log10(v), density=True, bins=50)
            plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
            plt.ylabel('P(err)')
            plt.title('Comparison of single and double precision models for %s' % name)

    plt.show()
272
def _print_stats(label, err):
    """
    Print one line of summary statistics for the error vector *err*.

    max, median and 98% are order statistics of |err| (index-based, no
    interpolation).  rms and zero-offset (the mean) use the signed values.
    """
    magnitude = np.sort(abs(err))
    last = len(err) - 1
    median = magnitude[int(last*0.50)]
    p98 = magnitude[int(last*0.98)]
    fields = (
        "max:%.3e" % magnitude[-1],
        "median:%.3e" % median,
        "98%%:%.3e" % p98,
        "rms:%.3e" % np.sqrt(np.mean(err**2)),
        "zero-offset:%+.3e" % np.mean(err),
    )
    print("%s %s" % (label, "  ".join(fields)))
285
286
287
[87985ca]288# ===========================================================================
289#
# Help text printed by main(); the single %s receives the model list.
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the Sasview model, or the ctypes
model with -ctypes (default=1)

Options (* for default):

    -plot*/-noplot plot or suppress the plot of the model
    -single*/-double use single or double precision for comparison
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d/-2d* computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* whether cpu is tested using ctypes or sasview
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values to any of the model
parameters.  They are only recognized after Nopencl and Nsasview, so give
both counts when setting parameters.

Available models:

    %s
"""
325
# Options given as bare flags (-name); each line is a group of
# alternatives, see USAGE for their meaning.
NAME_OPTIONS = set("""
    plot noplot
    single double
    lowq midq highq exq
    2d 1d
    preset random
    poly mono
    sasview ctypes
    nopars pars
    rel abs
    linear log q4
    hist nohist
    """.split())
# Options given as -name=value.
# Note: random is both a name option and a value option
VALUE_OPTIONS = "cutoff random Nq res accuracy".split()
343
def get_demo_pars(name):
    """
    Return the demo parameters of model *name* as a new dict.

    Uses the ``demo`` attribute of sasmodels.models.<name> when present.
    Otherwise it builds the dict from the (p[0], p[2]) pairs of each entry in
    the model's ``parameters`` list (presumably name and default value).
    A copy is returned because callers add defaults to and overwrite entries
    in the result, which would otherwise modify the model module's own
    ``demo`` dict.
    """
    import sasmodels.models
    __import__('sasmodels.models.' + name)
    model = getattr(sasmodels.models, name)
    pars = getattr(model, 'demo', None)
    if pars is None:
        return dict((p[0], p[2]) for p in model.parameters)
    return dict(pars)
351
def main():
    """
    Command line entry point: parse sys.argv and run compare().

    See USAGE for the accepted arguments.  Invalid arguments print a
    message and exit with status 1.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
    models = "\n    ".join("%-15s" % v for v in MODELS)
    if len(args) == 0:
        print(USAGE % models)
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s" % (args[0], models))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s=' % t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s" % (", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    pars = get_demo_pars(name)

    try:
        Nopencl = int(args[1]) if len(args) > 1 else 5
        Nsasview = int(args[2]) if len(args) > 2 else 1
    except ValueError:
        # e.g. "compare.py sphere radius=50": key=val must follow the counts
        print("Nopencl and Nsasview must be integers; key=val pairs go after them")
        sys.exit(1)

    # Fill in default polydispersity parameters
    pds = set(p[:-len('_pd')] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p + "_pd_nsigma" not in pars:
            pars[p + "_pd_nsigma"] = 3
        if p + "_pd_type" not in pars:
            pars[p + "_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in args[3:]:
        k, sep, v = arg.partition('=')
        if not sep:
            print("%r is not of the form key=value" % arg)
            sys.exit(1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s" % (k, ", ".join(sorted(s))))
            sys.exit(1)
        # Distribution types are names (e.g. gaussian); all other
        # parameters are numbers.  The test must be on the key, not the value.
        set_pars[k] = v if k.endswith('_pd_type') else float(v)

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.