source: sasmodels/sasmodels/compare.py @ 7be65ea

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 7be65ea was cd3dba0, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

improve compare.py so that parameters can be constrained to valid values

  • Property mode set to 100755
File size: 16.2 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[7cf2cfd]8import datetime
[87985ca]9
[1726b21]10import numpy as np
[473183c]11
[29fc2a3]12ROOT = dirname(__file__)
13sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
14
15
[e922c5d]16from . import core
17from . import kerneldll
[cd3dba0]18from . import generate
[e922c5d]19from .data import plot_theory, empty_data1D, empty_data2D
20from .direct_model import DirectModel
21from .convert import revert_model
[750ffa5]22kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]23
# Names of the available models: one entry for each python file in the
# "models" directory under ROOT whose name starts with a letter, with the
# ".py" extension removed, in sorted order.
MODELS = [
    basename(path)[:-3]
    for path in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))
]
[d547f16]27
# CRUFT python 2.6: fall back to arithmetic on the timedelta fields when
# timedelta.total_seconds() is not available.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the timedelta *dt* as a number of seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the timedelta *dt* as a number of seconds."""
        whole_seconds = dt.days * 86400 + dt.seconds
        return whole_seconds + 1e-6 * dt.microseconds
37
38
def tic():
    """
    Start a stopwatch.

    "toc = tic()" records the current time; every later call to "toc()"
    returns the number of seconds elapsed since that moment.
    """
    start = datetime.datetime.now()

    def toc():
        return delay(datetime.datetime.now() - start)

    return toc
48
49
def set_beam_stop(data, radius, outer=None):
    """
    Set *data.mask* to flag the points hidden by a beam stop.

    Points whose q is less than *radius* are flagged.  If *outer* is given,
    points with q greater than or equal to *outer* are flagged as well,
    leaving an annulus unflagged.  For data with a *qx_data* attribute, q is
    sqrt(qx_data**2 + qy_data**2); otherwise q is *data.x*.

    Note: this function does not use the sasview package
    """
    if hasattr(data, 'qx_data'):
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q = data.x
    mask = (q < radius)
    if outer is not None:
        mask |= (q >= outer)
    data.mask = mask
65
[8a20be5]66
def sasview_model(model_definition, **pars):
    """
    Create a sasview model instance for *model_definition* and set its
    parameters from the keyword arguments *pars*.

    Raises ValueError if the model class cannot be found in sas.models.
    """
    # convert model parameters from sasmodel form to sasview form
    modelname, pars = revert_model(model_definition, pars)
    sas = __import__('sas.models.'+modelname)
    module = getattr(sas.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    # Parameter name suffix -> entry of the model's dispersion table.  A name
    # can end with at most one of these suffixes, so test order is irrelevant.
    dispersion_fields = (
        ("_pd", 'width'),
        ("_pd_n", 'npts'),
        ("_pd_nsigma", 'nsigmas'),
        ("_pd_type", 'type'),
    )
    for key, value in pars.items():
        for suffix, field in dispersion_fields:
            if key.endswith(suffix):
                model.dispersion[key[:-len(suffix)]][field] = value
                break
        else:
            # not a polydispersity parameter
            model.setParam(key, value)
    return model
93
def randomize(p, v):
    """
    Return a random value for the parameter named *p* whose current value
    is *v*.

    The kind of parameter, and hence the range of the random value, is
    guessed from the name.
    """
    # polydispersity control parameters are returned unchanged
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v

    is_width = p.endswith('_pd')
    if 'theta' in p or 'phi' in p or 'psi' in p:
        # orientation in [-180,180], orientation pd in [0,45]
        if is_width:
            return 45*np.random.rand()
        return 360*np.random.rand() - 180
    if 'sld' in p:
        # sld in [-0.5,10]
        return 10.5*np.random.rand() - 0.5
    if is_width:
        # length pd in [0,1]
        return np.random.rand()

    # values from 0 to 2*x for all other parameters (0 to 2 if x is zero)
    reference = v if v != 0 else 1
    return 2*np.random.rand()*reference
[87985ca]117
def randomize_model(pars, seed=None):
    """
    Return a randomized copy of the parameter dictionary *pars* together
    with the seed that was used.

    If *seed* is None a random seed is drawn first.  The numpy global random
    number generator is reseeded with the seed, so the same seed reproduces
    the same parameter set.
    """
    if seed is None:
        seed = np.random.randint(1e9)
    np.random.seed(seed)
    # Note: the sort guarantees order of calls to random number generator
    randomized = {}
    for name, value in sorted(pars.items()):
        randomized[name] = randomize(name, value)

    return randomized, seed
126
def constrain_pars(model_definition, pars):
    """
    Adjust the parameter dictionary *pars* in place so that it holds values
    which are valid for the model.

    *model_definition* must have a *name* attribute and may have a
    *category* attribute.  Parameters that are to be removed but are not
    present in *pars* are ignored.
    """
    name = model_definition.name
    # make sure the cap radius is the larger of the two radii
    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']

    # These constraints are only needed for comparison to sasview
    if name in ('teubner_strey','broad_peak'):
        pars.pop('scale', None)
    if name in ('guinier',):
        pars.pop('background', None)
    if getattr(model_definition, 'category', None) == 'structure-factor':
        pars.pop('scale', None)
        pars.pop('background', None)
139
[216a9e1]140
def parlist(pars):
    """
    Format the dictionary *pars* as text with one "name: value" line per
    parameter, sorted by parameter name.
    """
    lines = []
    for name, value in sorted(pars.items()):
        lines.append("%s: %s"%(name, value))
    return "\n".join(lines)
143
def suppress_pd(pars):
    """
    Set every parameter of *pars* whose name ends in "_pd" to zero, in place.

    Suppress theta_pd for now until the normalization is resolved.

    May also suppress complete polydispersity of the model to test
    models more quickly.
    """
    width_names = [name for name in pars if name.endswith("_pd")]
    for name in width_names:
        pars[name] = 0
153
def eval_sasview(model_definition, pars, data, Nevals=1):
    """
    Evaluate the sasview version of the model on *data*.

    The model is built by sasview_model() from *model_definition* and
    *pars*, and is evaluated *Nevals* times (at least once).

    Returns the value of the last evaluation and the average time per
    evaluation in milliseconds.

    Raises ImportError if sas.models.qsmearing is not available.
    """
    from sas.models.qsmearing import smear_selection
    model = sasview_model(model_definition, **pars)
    smearer = smear_selection(data, model=model)
    # make sure there is at least one eval, and average over the number of
    # evaluations actually performed so that Nevals=0 does not divide by zero
    num_evals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(num_evals):
        if hasattr(data, 'qx_data'):
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            index = ((~data.mask) & (~np.isnan(data.data))
                     & (q >= data.qmin) & (q <= data.qmax))
            if smearer is not None:
                smearer.model = model  # because smear_selection has a bug
                smearer.accuracy = data.accuracy
                smearer.set_index(index)
                value = smearer.get_value()
            else:
                value = model.evalDistribution([data.qx_data[index], data.qy_data[index]])
        else:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
    average_time = toc()*1000./num_evals
    return value, average_time
178
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1, cutoff=0.):
    """
    Evaluate the model on *data* using the "ocl" platform.

    The model is loaded with precision *dtype*; if loading fails for any
    reason the error is printed and the load is retried with 'single'.
    The model is evaluated *Nevals* times (at least once), with *cutoff*
    passed on to DirectModel.

    Returns the value of the last evaluation and the average time per
    evaluation in milliseconds.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        # deliberate best effort: report the failure and fall back
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single', platform="ocl")
    calculator = DirectModel(data, model, cutoff=cutoff)
    # force at least one eval, and average over the number of evaluations
    # actually performed so that Nevals=0 does not divide by zero
    num_evals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
193
[7cf2cfd]194
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate the model on *data* using the "dll" platform.

    The model is loaded with precision *dtype* and evaluated *Nevals* times
    (at least once), with *cutoff* passed on to DirectModel.

    Returns the value of the last evaluation and the average time per
    evaluation in milliseconds.
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    calculator = DirectModel(data, model, cutoff=cutoff)
    # force at least one eval, and average over the number of evaluations
    # actually performed so that Nevals=0 does not divide by zero
    num_evals = max(Nevals, 1)
    value = None  # silence the linter
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
204
[7cf2cfd]205
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Create an empty data set with *Nq* points for q values up to *qmax*.

    For 2D data the axis runs linearly from -qmax to qmax, *accuracy* is
    stored on the data object and a beam stop of radius 0.004 is applied.
    For 1D data the q values span three decades up to qmax when *view* is
    'log', and run linearly from 0.001*qmax to qmax for any other view.

    Returns the data object and an index selecting the points to use: the
    unmasked points for 2D data, everything for 1D data.
    """
    if is2D:
        axis = np.linspace(-qmax, qmax, Nq)
        data = empty_data2D(axis, resolution=resolution)
        data.accuracy = accuracy
        set_beam_stop(data, 0.004)
        return data, ~data.mask

    if view == 'log':
        log_qmax = math.log10(qmax)
        q = np.logspace(log_qmax-3, log_qmax, Nq)
    else:
        q = np.linspace(0.001*qmax, qmax, Nq)
    data = empty_data1D(q, resolution=resolution)
    return data, slice(None, None)
221
def compare(name, pars, Ncpu, Nocl, opts, set_pars):
    """
    Compare the OpenCL evaluation of model *name* against a CPU reference.

    *pars* is the dictionary of parameter values for the model.

    *Ncpu* is the number of times to evaluate the reference, which is the
    ctypes version of the model if "-ctypes" is in *opts* and the sasview
    version otherwise.  If sasview cannot be imported the reference is
    silently skipped.

    *Nocl* is the number of times to evaluate the OpenCL version.

    *opts* is the list of command line options, each starting with "-".

    *set_pars* is a dictionary of parameter values which override *pars*;
    they are applied after randomization so that they control the value.

    Timing and error statistics are printed, and the results are plotted
    unless "-noplot" is in *opts*.
    """
    model_definition = core.load_model_definition(name)

    view = ('linear' if '-linear' in opts
            else 'log' if '-log' in opts
            else 'q4' if '-q4' in opts
            else 'log')

    # options of the form "-name=value", stored as {"-name": "value"}
    opt_values = dict(split
                      for s in opts for split in ((s.split('='),))
                      if len(split) == 2)
    # Sort out data
    qmax = (10.0 if '-exq' in opts
            else 1.0 if '-highq' in opts
            else 0.2 if '-midq' in opts
            else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = "-2d" in opts
    data, _ = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = 'double' if '-double' in opts else 'single'
    cutoff = float(opt_values.get('-cutoff','1e-5'))

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(pars, seed=seed)
        constrain_pars(model_definition, pars)
        print("Randomize using -random=%i"%seed)
    pars.update(set_pars)  # set value after random to control value

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars %s"%parlist(pars))

    # OpenCl calculation
    if Nocl > 0:
        ocl, ocl_time = eval_opencl(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Nocl)
        print("opencl t=%.1f ms, intensity=%.0f"%(ocl_time, sum(ocl)))

    # ctypes/sasview calculation
    if Ncpu > 0 and "-ctypes" in opts:
        cpu, cpu_time = eval_ctypes(model_definition, pars, data,
                                    dtype=dtype, cutoff=cutoff, Nevals=Ncpu)
        comp = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))
    elif Ncpu > 0:
        try:
            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu)
            comp = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)))
        except ImportError:
            # sasview is not available, so there is nothing to compare to
            Ncpu = 0

    # Compare, but only if computing both forms
    if Nocl > 0 and Ncpu > 0:
        resid = (ocl - cpu)
        relerr = resid/cpu
        _print_stats("|ocl-%s|"%comp+(" "*(3+len(comp))), resid)
        _print_stats("|(ocl-%s)/%s|"%(comp,comp), relerr)

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncpu > 0:
        if Nocl > 0:
            plt.subplot(131)
        plot_theory(data, cpu, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(comp,cpu_time))
    if Nocl > 0:
        if Ncpu > 0:
            plt.subplot(132)
        plot_theory(data, ocl, view=view, plot_data=False)
        plt.title("opencl t=%.1f ms"%ocl_time)
    if Ncpu > 0 and Nocl > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err,errstr,errview = resid, "abs err", "linear"
        else:
            err,errstr,errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, plot_data=False)
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))

    if Ncpu > 0 and Nocl > 0 and '-hist' in opts:
        plt.figure()
        v = relerr
        # replace exact zeros by half the smallest nonzero error so that
        # log10 is defined for every point
        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
        plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
        plt.ylabel('P(err)')
        plt.title('Comparison of single and double precision models for %s'%name)

    plt.show()
331
def _print_stats(label, err):
    """
    Print *label* followed by summary statistics for the error array *err*.

    The maximum, median and 98th percentile are taken over abs(err); the
    rms and the mean ("zero-offset") are computed from the signed values.
    """
    sorted_err = np.sort(abs(err))
    p50 = int((len(err)-1)*0.50)
    p98 = int((len(err)-1)*0.98)
    data = [
        "max:%.3e"%sorted_err[-1],
        "median:%.3e"%sorted_err[p50],
        "98%%:%.3e"%sorted_err[p98],
        "rms:%.3e"%np.sqrt(np.mean(err**2)),
        "zero-offset:%+.3e"%np.mean(err),
        ]
    # single-argument print gives the same output on python 2 and python 3
    print("%s %s"%(label, "  ".join(data)))
344
345
346
# ===========================================================================
#
# Help text printed by main() when no model name is given on the command
# line.  The list of available models is printed immediately after it.
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the Sasview model (default=1)

Options (* for default):

    -plot*/-noplot plots or suppress the plot of the model
    -single*/-double uses double precision for comparison
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* whether cpu is tested using sasview or ctypes
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot relative or absolute error
    -linear/-log/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low resolution accuracy Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:
"""
382
[7cf2cfd]383
# Options accepted on the command line as a bare flag, "-name".  main()
# rejects any option that is neither one of these nor a value option.
NAME_OPTIONS = set([
    'plot','noplot',
    'single','double',
    'lowq','midq','highq','exq',
    '2d','1d',
    'preset','random',
    'poly','mono',
    'sasview','ctypes',
    'nopars','pars',
    'rel','abs',
    'linear', 'log', 'q4',
    'hist','nohist',
    ])
# Options accepted on the command line with a value, "-name=value".
VALUE_OPTIONS = [
    # Note: random is both a name option and a value option
    'cutoff', 'random', 'Nq', 'res', 'accuracy',
    ]
401
def columnize(L, indent="", width=79):
    """
    Format the list of strings *L* as text laid out in columns.

    Entries run down the first column, then down the second, and so on.
    Every column is one character wider than the longest entry, and as many
    columns are used as fit in *width* characters after *indent*, which is
    prepended to every line.  At least one column is always used, even if
    an entry is wider than the available space.

    Returns the formatted text, which is just *indent* if *L* is empty.
    """
    if not L:
        return indent
    column_width = max(len(w) for w in L) + 1
    num_columns = max((width - len(indent)) // column_width, 1)
    # Round up so that the entries of a partially filled last row are kept
    # rather than dropped.
    num_rows = (len(L) + num_columns - 1) // num_columns
    # pad with empty entries so that every column has num_rows entries
    padded = list(L) + [""] * (num_rows*num_columns - len(L))
    columns = [padded[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    output = indent + ("\n"+indent).join(lines)
    return output
412
413
def get_demo_pars(model_definition):
    """
    Return the dictionary of demonstration parameter values for the model.

    Every entry of the 'parameters' list from generate.make_info()
    contributes its first field as the key and its third field as the value
    (presumably the parameter name and its default -- confirm against
    generate.make_info).  The result is then updated with the 'demo' values.
    """
    info = generate.make_info(model_definition)
    pars = {}
    for entry in info['parameters']:
        pars[entry[0]] = entry[2]
    pars.update(info['demo'])
    return pars
419
def main():
    """
    Run the model comparison described by the command line in sys.argv.

    Arguments starting with "-" are options; the remaining arguments are,
    in order, the model name, the number of OpenCL evaluations (default 5),
    the number of sasview evaluations (default 1) and any number of
    key=value parameter settings.

    Prints a message and exits with status 1 if the model name is missing
    or unknown, if an option is not recognized, or if a parameter setting
    is malformed or names an unknown parameter.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-')]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE)
        print(columnize(MODELS, indent="  "))
        sys.exit(1)
    if args[0] not in MODELS:
        print("Model %r not available. Use one of:\n    %s"%(args[0],models))
        sys.exit(1)

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)

    Nopencl = int(args[1]) if len(args) > 1 else 5
    Nsasview = int(args[2]) if len(args) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars: pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars: pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in args[3:]:
        # split on the first "=" only, and report arguments without one
        # instead of failing on the tuple unpacking
        parts = arg.split('=', 1)
        if len(parts) != 2:
            print("%r invalid; expected key=value"%arg)
            sys.exit(1)
        k,v = parts
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s))))
            sys.exit(1)
        # The parameter name, not its value, says whether the value is text:
        # "*_pd_type" parameters hold a distribution name such as "gaussian".
        set_pars[k] = float(v) if not k.endswith('type') else v

    compare(name, pars, Nsasview, Nopencl, opts, set_pars)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.