source: sasmodels/sasmodels/compare.py @ 9a66e65

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9a66e65 was 9a66e65, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor sasmodels to sasview parameter conversion

  • Property mode set to 100755
File size: 17.8 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[7cf2cfd]8import datetime
[5753e4e]9import traceback
[87985ca]10
[1726b21]11import numpy as np
[473183c]12
# Directory containing this module.  It is pushed to the front of sys.path
# so that modules found there take precedence over installed copies.
ROOT = dirname(__file__)
sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
[e922c5d]17from . import core
18from . import kerneldll
[cd3dba0]19from . import generate
[e922c5d]20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
[9a66e65]22from .convert import revert_model, constrain_new_to_old
[750ffa5]23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]24
# List of available models: the names (with the trailing ".py" stripped) of
# the files ROOT/models/[a-zA-Z]*.py, in sorted order.
MODELS = [basename(f)[:-3]
          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
[d547f16]28
# CRUFT: python 2.6 has no timedelta.total_seconds(), so compute it by hand.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Convert the datetime.timedelta *dt* to a float number of seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Convert the datetime.timedelta *dt* to a float number of seconds."""
        seconds_per_day = 86400
        return dt.days * seconds_per_day + dt.seconds + 1e-6 * dt.microseconds
38
39
def tic():
    """
    Start a wall-clock timer.

    Returns a function which, each time it is called, reports the number of
    seconds elapsed since tic() was called::

        toc = tic()
        ...
        elapsed = toc()
    """
    start = datetime.datetime.now()

    def toc():
        return delay(datetime.datetime.now() - start)

    return toc
49
50
def set_beam_stop(data, radius, outer=None):
    """
    Mask out every point with |q| < *radius*; when *outer* is given, also mask
    points with |q| >= *outer*, leaving an annulus of valid data.

    2D data (anything with a ``qx_data`` attribute) uses the magnitude of
    (qx, qy); 1D data uses ``data.x``.  The result is stored in ``data.mask``.

    Note: this function does not use the sasview package
    """
    if hasattr(data, 'qx_data'):
        q_magnitude = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q_magnitude = data.x
    data.mask = (q_magnitude < radius)
    if outer is not None:
        data.mask |= (q_magnitude >= outer)
66
[8a20be5]67
def sasview_model(model_definition, **pars):
    """
    Instantiate the SasView model corresponding to *model_definition*.

    The sasmodels-style *pars* are first translated to SasView names and
    values with revert_model.  Keys that split into exactly two parts on '.'
    (polydispersity components) are written into ``model.dispersion``; every
    other key is applied with ``model.setParam``.

    Raises ValueError if ``sas.models.<name>.<name>`` does not exist.
    """
    modelname, sasview_pars = revert_model(model_definition, pars)
    package = __import__('sas.models.' + modelname)
    module = getattr(package.models, modelname, None)
    ModelClass = getattr(module, modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()

    for key, value in sasview_pars.items():
        pieces = key.split('.')
        if len(pieces) != 2:
            model.setParam(key, value)
            continue
        owner, component = pieces
        model.dispersion[owner][component] = value
    return model
89
def randomize(p, v):
    """
    Return a random value for parameter *p* whose current value is *v*.

    The kind of parameter is guessed from its name:

    * ``*_pd_n``, ``*_pd_nsigma``, ``*_pd_type``: returned unchanged
    * names containing theta/phi/psi: angle in [-180, 180), or a
      polydispersity width in [0, 45) for ``*_pd``
    * names containing ``sld``: value in [-0.5, 10)
    * other ``*_pd``: relative width in [0, 1)
    * anything else: value in [0, 2*v), using v=1 when v is zero
    """
    if p.endswith(('_pd_n', '_pd_nsigma', '_pd_type')):
        return v

    is_dispersion = p.endswith('_pd')
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        if is_dispersion:
            return 45*np.random.rand()
        return 360*np.random.rand() - 180
    if 'sld' in p:
        return 10.5*np.random.rand() - 0.5
    if is_dispersion:
        return np.random.rand()
    reference = v if v != 0 else 1
    return 2*np.random.rand()*reference
[87985ca]113
def randomize_model(pars, seed=None):
    """
    Return ``(randomized_pars, seed)``: a new dict holding a random value for
    every entry of *pars* (see randomize), and the seed that produced it.

    When *seed* is None a fresh one is drawn so the run can be reproduced
    later.  Parameters are visited in sorted-name order, which fixes the
    order of calls to the random number generator for a given seed.
    """
    if seed is None:
        seed = np.random.randint(1e9)
    np.random.seed(seed)
    randomized = {}
    for name in sorted(pars):
        randomized[name] = randomize(name, pars[name])
    return randomized, seed
122
def constrain_pars(model_definition, pars):
    """
    Restrict parameters to valid values, modifying *pars* in place.

    * capped_cylinder / barbell: swap radii so that the cap/bell radius is
      never smaller than ``radius``.
    * guinier: clip ``rg`` so that the intensity at q = 1.0 stays above 1e-30.
    """
    name = model_definition.name

    # second radius that must be at least as large as 'radius'
    outer_radius = {
        'capped_cylinder': 'cap_radius',
        'barbell': 'bell_radius',
    }.get(name)
    if outer_radius is not None and pars[outer_radius] < pars['radius']:
        pars['radius'], pars[outer_radius] = pars[outer_radius], pars['radius']

    if name != 'guinier':
        return
    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff):
    #   scale*exp(-(q*Rg)**2/3) > 1e-30  <=>  Rg < sqrt(90 ln 10 + 3 ln scale)/q
    q_max = 1.0  # high q maximum (0.2 would be the mid q maximum)
    rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
    pars['rg'] = min(pars['rg'], rg_max)
[cd3dba0]139
def parlist(pars):
    """Format *pars* as one "name: value" line per parameter, sorted by name."""
    lines = ["%s: %s" % (name, value) for name, value in sorted(pars.items())]
    return "\n".join(lines)
142
def suppress_pd(pars):
    """
    Make the model monodisperse by setting every ``*_pd`` width in *pars*
    to zero, in place.

    Used by the -mono option, e.g. to test models more quickly.
    """
    dispersion_widths = [name for name in pars if name.endswith("_pd")]
    for name in dispersion_widths:
        pars[name] = 0
152
def eval_sasview(model_definition, pars, data, Nevals=1):
    """
    Evaluate the SasView version of *model_definition* on *data*.

    The model is evaluated ``max(Nevals, 1)`` times, applying resolution
    smearing when smear_selection returns a smearer.  Returns
    ``(value, average_time)``, where *value* is the result of the last
    evaluation and *average_time* is in milliseconds per evaluation.

    Raises ImportError if the sas package is not available.
    """
    # importing sas here so that the error message will be that sas failed to
    # import rather than the more obscure smear_selection not imported error
    import sas
    from sas.models.qsmearing import smear_selection
    model = sasview_model(model_definition, **pars)
    smearer = smear_selection(data, model=model)
    value = None  # silence the linter
    # Always evaluate at least once, and average over the evaluations that
    # were actually performed (dividing by Nevals failed for Nevals=0).
    num_evals = max(Nevals, 1)
    toc = tic()
    for _ in range(num_evals):
        if hasattr(data, 'qx_data'):
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            index = ((~data.mask) & (~np.isnan(data.data))
                     & (q >= data.qmin) & (q <= data.qmax))
            if smearer is not None:
                smearer.model = model  # because smear_selection has a bug
                smearer.accuracy = data.accuracy
                smearer.set_index(index)
                value = smearer.get_value()
            else:
                value = model.evalDistribution([data.qx_data[index],
                                                data.qy_data[index]])
        else:
            value = model.evalDistribution(data.x)
            if smearer is not None:
                value = smearer(value)
    average_time = toc()*1000./num_evals
    return value, average_time
180
def eval_opencl(model_definition, pars, data, dtype='single', Nevals=1,
                cutoff=0., fast=False):
    """
    Evaluate *model_definition* on *data* using the OpenCL kernel.

    If loading the model at the requested *dtype* raises, the error is
    printed and the model is loaded again in single precision.  The model is
    evaluated ``max(Nevals, 1)`` times.  Returns ``(value, average_time)``,
    where *value* is the result of the last evaluation and *average_time* is
    in milliseconds per evaluation.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype,
                                platform="ocl", fast=fast)
    except Exception as exc:
        # presumably the device cannot handle the requested precision
        print(exc)
        print("... trying again with single precision")
        model = core.load_model(model_definition, dtype='single',
                                platform="ocl", fast=fast)
    calculator = DirectModel(data, model, cutoff=cutoff)
    value = None  # silence the linter
    # force at least one eval, and average over the evals actually run
    # (dividing by Nevals failed for Nevals=0)
    num_evals = max(Nevals, 1)
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
198
[7cf2cfd]199
def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.):
    """
    Evaluate *model_definition* on *data* using the compiled ctypes (DLL)
    kernel.

    The model is evaluated ``max(Nevals, 1)`` times.  Returns
    ``(value, average_time)``, where *value* is the result of the last
    evaluation and *average_time* is in milliseconds per evaluation.
    """
    model = core.load_model(model_definition, dtype=dtype, platform="dll")
    calculator = DirectModel(data, model, cutoff=cutoff)
    value = None  # silence the linter
    # force at least one eval, and average over the evals actually run
    # (dividing by Nevals failed for Nevals=0)
    num_evals = max(Nevals, 1)
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
209
[7cf2cfd]210
def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'):
    """
    Build an empty data set extending to *qmax* with *Nq* points per axis.

    1D data uses log-spaced q over three decades when *view* is 'log', and
    linear spacing from 0.001*qmax otherwise.  2D data spans [-qmax, qmax]
    in both directions, with a beam stop of radius 0.004 applied.

    Returns ``(data, index)`` where *index* selects the usable points: the
    unmasked points for 2D data, or everything for 1D data.
    """
    if not is2D:
        if view == 'log':
            log_qmax = math.log10(qmax)
            q = np.logspace(log_qmax-3, log_qmax, Nq)
        else:
            q = np.linspace(0.001*qmax, qmax, Nq)
        return empty_data1D(q, resolution=resolution), slice(None, None)

    q_axis = np.linspace(-qmax, qmax, Nq)
    data = empty_data2D(q_axis, resolution=resolution)
    data.accuracy = accuracy
    set_beam_stop(data, 0.004)
    return data, ~data.mask
226
def compare(name, pars, Ncomp, Nbase, opts, set_pars):
    """
    Compute model *name* with two calculators, print timing and error
    statistics, and optionally plot the results.

    *pars* holds the starting parameter values; it may be modified in place.
    *Nbase* and *Ncomp* are the number of evaluations of the base and
    comparison calculators; 0 disables that calculator.  *opts* is the list
    of command line options (e.g. "-2d", "-Nq=256").  *set_pars* holds
    parameter values given explicitly on the command line; they are applied
    after any randomization.

    The base calculator is SasView when both -ctypes and -sasview are given,
    otherwise OpenCL.  The comparison calculator is ctypes when -ctypes is
    given, otherwise SasView.
    """
    model_definition = core.load_model_definition(name)

    view = ('linear' if '-linear' in opts
            else 'log' if '-log' in opts
            else 'q4' if '-q4' in opts
            else 'log')

    # "-key=value" options, e.g. {'-Nq': '256'}
    opt_values = dict(split
                      for s in opts for split in ((s.split('='),))
                      if len(split) == 2)
    # Sort out data
    qmax = (10.0 if '-exq' in opts
            else 1.0 if '-highq' in opts
            else 0.2 if '-midq' in opts
            else 0.05)
    Nq = int(opt_values.get('-Nq', '128'))
    res = float(opt_values.get('-res', '0'))
    accuracy = opt_values.get('-accuracy', 'Low')
    is2D = "-2d" in opts
    data, _ = make_data(qmax, is2D, Nq, res, accuracy, view=view)

    # modelling accuracy is determined by dtype and cutoff
    dtype = ('longdouble' if '-quad' in opts
             else 'double' if '-double' in opts
             else 'half' if '-half' in opts
             else 'single')
    cutoff = float(opt_values.get('-cutoff', '1e-5'))
    # Compare strings by value; 'is' only worked by accident of interning.
    fast = "-fast" in opts and dtype == 'single'

    # randomize parameters
    if '-random' in opts or '-random' in opt_values:
        seed = int(opt_values['-random']) if '-random' in opt_values else None
        pars, seed = randomize_model(pars, seed=seed)
        print("Randomize using -random=%i"%seed)
    pars.update(set_pars)  # set value after random to control value
    constrain_pars(model_definition, pars)
    constrain_new_to_old(model_definition, pars)

    # parameter selection
    if '-mono' in opts:
        suppress_pd(pars)
    if '-pars' in opts:
        print("pars "+str(parlist(pars)))

    # Base calculation
    if 0:
        # debugging hook: compare against a fixed target model instead
        from sasmodels.models import sphere as target
        base_name = target.name
        base, base_time = eval_ctypes(target, pars, data,
                dtype='longdouble', cutoff=0., Nevals=Nbase)
    elif Nbase > 0 and "-ctypes" in opts and "-sasview" in opts:
        try:
            # the base count is Nbase, not Ncomp
            base, base_time = eval_sasview(model_definition, pars, data, Nbase)
            base_name = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))
        except ImportError:
            traceback.print_exc()
            Nbase = 0
    elif Nbase > 0:
        base, base_time = eval_opencl(model_definition, pars, data,
                dtype=dtype, cutoff=cutoff, Nevals=Nbase, fast=fast)
        base_name = "ocl"
        print("opencl t=%.1f ms, intensity=%.0f"%(base_time, sum(base)))

    # Comparison calculation
    if Ncomp > 0 and "-ctypes" in opts:
        comp, comp_time = eval_ctypes(model_definition, pars, data,
                dtype=dtype, cutoff=cutoff, Nevals=Ncomp)
        comp_name = "ctypes"
        print("ctypes t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
    elif Ncomp > 0:
        try:
            comp, comp_time = eval_sasview(model_definition, pars, data, Ncomp)
            comp_name = "sasview"
            print("sasview t=%.1f ms, intensity=%.0f"%(comp_time, sum(comp)))
        except ImportError:
            traceback.print_exc()
            Ncomp = 0

    # Compare, but only if computing both forms
    if Nbase > 0 and Ncomp > 0:
        resid = (base - comp)
        relerr = resid/comp
        _print_stats("|%s-%s|"%(base_name,comp_name)+(" "*(3+len(comp_name))), resid)
        _print_stats("|(%s-%s)/%s|"%(base_name,comp_name,comp_name), relerr)

    # Plot if requested
    if '-noplot' in opts:
        return
    import matplotlib.pyplot as plt
    if Ncomp > 0:
        if Nbase > 0:
            plt.subplot(131)
        plot_theory(data, comp, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(comp_name, comp_time))
    if Nbase > 0:
        if Ncomp > 0:
            plt.subplot(132)
        plot_theory(data, base, view=view, plot_data=False)
        plt.title("%s t=%.1f ms"%(base_name, base_time))
    if Ncomp > 0 and Nbase > 0:
        plt.subplot(133)
        if '-abs' in opts:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, plot_data=False)
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))

    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
        v = relerr
        nonzero = np.abs(v[v != 0])
        if len(nonzero) == 0:
            # log10 of an all-zero error has no finite range to histogram
            print("relative error is zero everywhere; no histogram to plot")
        else:
            plt.figure()
            # log10 needs positive values, so exact matches are shown at
            # half the smallest nonzero error
            v[v == 0] = 0.5*np.min(nonzero)
            # 'density' replaces the 'normed' keyword removed in matplotlib 3.1
            plt.hist(np.log10(np.abs(v)), density=True, bins=50)
            plt.xlabel('log10(err), err = | F(q) single - F(q) double| / | F(q) double |')
            plt.ylabel('P(err)')
            plt.title('Comparison of single and double precision models for %s'%name)

    plt.show()
364
[0763009]365def _print_stats(label, err):
366    sorted_err = np.sort(abs(err))
367    p50 = int((len(err)-1)*0.50)
368    p98 = int((len(err)-1)*0.98)
369    data = [
370        "max:%.3e"%sorted_err[-1],
371        "median:%.3e"%sorted_err[p50],
372        "98%%:%.3e"%sorted_err[p98],
373        "rms:%.3e"%np.sqrt(np.mean(err**2)),
374        "zero-offset:%+.3e"%np.mean(err),
375        ]
[9404dd3]376    print(label+"  ".join(data))
[0763009]377
378
379
# ===========================================================================
#
# Command line help; main() prints this followed by the available models.
USAGE="""
usage: compare.py model [Nopencl] [Nsasview] [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
OpenCL rewrite.

model is the name of the model to compare (see below).
Nopencl is the number of times to run the OpenCL model (default=5)
Nsasview is the number of times to run the Sasview model (default=1)

Options (* for default):

    -plot*/-noplot plots or suppress the plot of the model
    -half/-single*/-double/-quad/-fast sets the calculation precision
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -Nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -ctypes/-sasview* selects gpu:cpu, gpu:sasview, or sasview:cpu if both
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh

Key=value pairs allow you to set specific values to any of the model
parameters.

Available models:
"""
415
[7cf2cfd]416
# Flag options accepted on the command line without a value (e.g. "-2d").
# main() rejects any "-option" that is neither in this set nor of the form
# "-name=value" for a name in VALUE_OPTIONS.
NAME_OPTIONS = set([
    'plot', 'noplot',
    'half', 'single', 'double', 'quad', 'fast',
    'lowq', 'midq', 'highq', 'exq',
    '2d', '1d',
    'preset', 'random',
    'poly', 'mono',
    'sasview', 'ctypes',
    'nopars', 'pars',
    'rel', 'abs',
    'linear', 'log', 'q4',
    'hist', 'nohist',
    ])
# Options accepted as "-name=value" (e.g. "-Nq=256").
VALUE_OPTIONS = [
    # Note: random is both a name option and a value option
    'cutoff', 'random', 'Nq', 'res', 'accuracy',
    ]
434
def columnize(L, indent="", width=79):
    """
    Format the strings in *L* as a column-major table that fits in *width*
    characters, with every line prefixed by *indent*.

    Entries are filled top to bottom, then left to right.  Every entry is
    shown: the row count is rounded up and the last column padded with
    blanks.  At least one column is used, even when an entry is wider
    than the page.  An empty *L* yields just *indent*.
    """
    if not L:
        return indent
    column_width = max(len(w) for w in L) + 1
    # always at least one column, even if the widest entry exceeds the page
    num_columns = max(1, (width - len(indent)) // column_width)
    # ceiling division: floor division silently dropped the trailing entries
    num_rows = (len(L) + num_columns - 1) // num_columns
    L = list(L) + [""] * (num_rows*num_columns - len(L))
    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    output = indent + ("\n"+indent).join(lines)
    return output
445
446
def get_demo_pars(model_definition):
    """
    Return a dict of starting parameter values for *model_definition*: the
    value in slot 2 of each parameter description, keyed by slot 0,
    overridden by the model's 'demo' values.
    """
    info = generate.make_info(model_definition)
    pars = {}
    # presumably slot 0 is the parameter name and slot 2 its default value
    for parameter in info['parameters']:
        pars[parameter[0]] = parameter[2]
    pars.update(info['demo'])
    return pars
452
def main():
    """
    Command line entry point: parse sys.argv and run compare().

    Arguments starting with '-' are options; "key=value" arguments set
    model parameters; the remaining arguments are the model name followed
    by up to two evaluation counts.  Exits with status 1 on invalid input.
    """
    opts = [arg for arg in sys.argv[1:] if arg.startswith('-')]
    popts = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' in arg]
    args = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' not in arg]
    if len(args) == 0:
        print(USAGE)
        print(columnize(MODELS, indent="  "))
        sys.exit(1)
    if args[0] not in MODELS:
        models = "\n    ".join("%-15s"%v for v in MODELS)
        print("Model %r not available. Use one of:\n    %s"%(args[0],models))
        sys.exit(1)
    if len(args) > 3:
        print("expected parameters: model Nopencl Nsasview")

    invalid = [o[1:] for o in opts
               if o[1:] not in NAME_OPTIONS
                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    name = args[0]
    model_definition = core.load_model_definition(name)
    pars = get_demo_pars(model_definition)

    # NOTE(review): USAGE calls the first count "Nopencl", but it becomes
    # Ncomp, which compare() uses for the sasview/ctypes run -- confirm which
    # is intended.
    Ncomp = int(args[1]) if len(args) > 1 else 5
    Nbase = int(args[2]) if len(args) > 2 else 1

    # Fill in default polydispersity parameters
    pds = set(p.split('_pd')[0] for p in pars if p.endswith('_pd'))
    for p in pds:
        if p+"_pd_nsigma" not in pars:
            pars[p+"_pd_nsigma"] = 3
        if p+"_pd_type" not in pars:
            pars[p+"_pd_type"] = "gaussian"

    # Fill in parameters given on the command line
    set_pars = {}
    for arg in popts:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without distribution
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s))))
            sys.exit(1)
        # Distribution types (e.g. radius_pd_type=schulz) are names, not
        # numbers.  Test the key: testing the value made float() fail.
        set_pars[k] = v if k.endswith('type') else float(v)

    compare(name, pars, Ncomp, Nbase, opts, set_pars)
[87985ca]502
# Run the command line interface when executed as a script.
# NOTE(review): the relative imports at the top of this module mean it
# presumably has to be run as a module (python -m ...) -- confirm.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.