source: sasmodels/sasmodels/compare.py @ fc0fcd0

core_shell_microgels costrafo411 magnetic_model release_v0.94 release_v0.95 ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since fc0fcd0 was fc0fcd0, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

remove duplicate code: model list is defined in core

  • Property mode set to 100755
File size: 29.7 KB
RevLine 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Program to compare models using different compute engines.

This program lets you compare results between OpenCL and DLL versions
of the code and between precisions (half, fast, single, double, quad),
where fast precision is single precision using native functions for
trig, etc., and may not be completely IEEE 754 compliant.  This lets
you make sure that the model calculations are stable, or decide whether
you need to tag the model as double precision only.

Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
sasmodels root to see the command line options.

Note that there is no way within sasmodels to select between an
OpenCL CPU device and a GPU device, but you can do so by setting the
PYOPENCL_CTX environment variable ahead of time.  Start a python
interpreter and enter::

    import pyopencl as cl
    cl.create_some_context()

This will prompt you to select from the available OpenCL devices
and tell you which string to use for the PYOPENCL_CTX variable.
On Windows you will need to remove the quotes.
"""
28
29from __future__ import print_function
30
[190fc2b]31import sys
32import math
33from os.path import basename, dirname, join as joinpath
34import glob
35import datetime
36import traceback
37
38import numpy as np
39
40from . import core
41from . import kerneldll
42from . import generate
43from .data import plot_theory, empty_data1D, empty_data2D
44from .direct_model import DirectModel
45from .convert import revert_model, constrain_new_to_old
46
USAGE = """
usage: compare.py model N1 N2 [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
sasmodels rewrite.

model is the name of the model to compare (see below).
N1 is the number of times to run the first engine (default=1).
N2 is the number of times to run the second engine (default=1).

Options (* for default):

    -plot*/-noplot plot or suppress the plot of the model
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
    -edit starts the parameter explorer

Any two calculation engines can be selected for comparison:

    -single/-double/-half/-fast sets an OpenCL calculation engine
    -single!/-double!/-quad! sets an OpenMP calculation engine
    -sasview sets the sasview calculation engine

The default is -single -sasview.  Note that the interpretation of quad
precision depends on architecture, and may vary from 64-bit to 128-bit,
with 80-bit floats being common (1e-19 precision).

Key=value pairs allow you to set specific values for the model parameters.
"""
76    -single/-double/-half/-fast sets an OpenCL calculation engine
77    -single!/-double!/-quad! sets an OpenMP calculation engine
78    -sasview sets the sasview calculation engine
79
80The default is -single -sasview.  Note that the interpretation of quad
81precision depends on architecture, and may vary from 64-bit to 128-bit,
82with 80-bit floats being common (1e-19 precision).
83
84Key=value pairs allow you to set specific values for the model parameters.
85"""
86
87# Update docs with command line usage string.   This is separate from the usual
88# doc string so that we can display it at run time if there is an error.
89# lin
[d15a908]90__doc__ = (__doc__  # pylint: disable=redefined-builtin
91           + """
[caeb06d]92Program description
93-------------------
94
[d15a908]95"""
96           + USAGE)
[caeb06d]97
[750ffa5]98kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]99
[fc0fcd0]100MODELS = core.list_models()
[d547f16]101
# CRUFT: datetime.timedelta.total_seconds() is missing from python 2.6
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the duration of the timedelta *dt* in seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the duration of the timedelta *dt* in seconds."""
        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
111
112
class push_seed(object):
    """
    Set the seed value for the random number generator.

    The global numpy random state is saved and the generator is reseeded
    as soon as the object is created, so the object can be used on its
    own simply to set the seed.  When used in a with statement, the saved
    state is restored when the block exits, whether or not it raised.

    :Parameters:

    *seed* : int or array_like, optional
        Seed for RandomState

    :Example:

    Seed can be used directly to set the seed::

        >>> from numpy.random import randint
        >>> push_seed(24)
        <...push_seed object at...>
        >>> print(randint(0,1000000,3))
        [242082    899 211136]

    Seed can also be used in a with statement, which sets the random
    number generator state for the enclosed computations and restores
    it to the previous state on completion::

        >>> with push_seed(24):
        ...    print(randint(0,1000000,3))
        [242082    899 211136]

    Using nested contexts, we can demonstrate that state is indeed
    restored after the block completes::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    with push_seed(24):
        ...        print(randint(0,1000000,3))
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        899

    The restore step is protected against exceptions in the block::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    try:
        ...        with push_seed(24):
        ...            print(randint(0,1000000,3))
        ...            raise Exception()
        ...    except:
        ...        print("Exception raised")
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        Exception raised
        899
    """
    def __init__(self, seed=None):
        saved_state = np.random.get_state()
        np.random.seed(seed)
        self._state = saved_state

    def __enter__(self):
        # The seed was already applied in __init__; nothing is bound by "as".
        return None

    def __exit__(self, *args):
        # Restore the saved state.  Returning None lets exceptions propagate.
        np.random.set_state(self._state)
180
def tic():
    """
    Start a timer.

    Returns a function *toc*; calling ``toc()`` gives the number of seconds
    elapsed since ``tic()`` was called.  Typical use is ``toc = tic()``
    followed later by ``toc()``.
    """
    start = datetime.datetime.now()

    def toc():
        """Return the seconds elapsed since the matching tic() call."""
        return delay(datetime.datetime.now() - start)

    return toc
190
191
def set_beam_stop(data, radius, outer=None):
    """
    Mask out points with |q| below *radius*, as if blocked by a beam stop.

    For 2D data (anything with a ``qx_data`` attribute) |q| is computed
    from ``qx_data`` and ``qy_data``; otherwise ``data.x`` is used.  If
    *outer* is given, points with |q| >= *outer* are masked as well,
    leaving an annulus.  The result is stored in ``data.mask`` (True means
    masked).

    Note: this function does not use the sasview package
    """
    if hasattr(data, 'qx_data'):
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q = data.x
    mask = (q < radius)
    if outer is not None:
        mask |= (q >= outer)
    data.mask = mask
207
[8a20be5]208
def parameter_range(p, v):
    """
    Guess a sensible [low, high] range for parameter *p* with value *v*.

    The guess is driven by the parameter name: polydispersity point counts,
    sigma widths, orientation angles and their dispersion, SLDs, relative
    dispersion widths, background, scale and the RPA case number each get
    a fixed range.  Negative values get a range symmetric about zero, and
    anything else gets [0, 2*v] (or [0, 1] when *v* is zero).
    Distribution type parameters ('_pd_type') return *v* unchanged.
    """
    if p.endswith('_pd_n'):
        return [0, 100]
    if p.endswith('_pd_nsigma'):
        return [0, 5]
    if p.endswith('_pd_type'):
        return v
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        # orientation in [-180,180], orientation pd in [0,45]
        return [0, 45] if p.endswith('_pd') else [-180, 180]
    if 'sld' in p:
        return [-0.5, 10]
    if p.endswith('_pd'):
        return [0, 1]
    if p in ('background', 'case_num'):  # case_num: RPA hack
        return [0, 10]
    if p == 'scale':
        return [0, 1e3]
    if v < 0:
        # Kxy parameters in rpa model can be negative
        return [2*v, -2*v]
    upper = 2*v if v > 0 else 1
    return [0, upper]
[87985ca]241
[4f2478e]242
[ec7e360]243def _randomize_one(p, v):
244    """
[caeb06d]245    Randomize a single parameter.
[ec7e360]246    """
[caeb06d]247    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
[ec7e360]248        return v
249    else:
250        return np.random.uniform(*parameter_range(p, v))
[cd3dba0]251
[4f2478e]252
def randomize_pars(pars, seed=None):
    """
    Generate random values for all of the parameters.

    Valid ranges for the random number generator are guessed from the name
    of the parameter.  This does not account for constraints such as cap
    radius greater than cylinder radius in the capped_cylinder model, so
    :func:`constrain_pars` needs to be called afterward.

    The global numpy random state is seeded with *seed* for the duration of
    the call and restored afterwards.  Returns a new dictionary; *pars* is
    not modified.
    """
    randomized = {}
    with push_seed(seed):
        # Visit parameters in sorted order so that the sequence of calls to
        # the random number generator, and hence the result for a given
        # seed, does not depend on dictionary ordering.
        for name, value in sorted(pars.items()):
            randomized[name] = _randomize_one(name, value)
    return randomized
[cd3dba0]267
def constrain_pars(model_definition, pars):
    """
    Restrict parameters to valid values.

    This includes model specific code for models such as capped_cylinder
    which need to support within model constraints (cap radius more than
    cylinder radius in this case).

    The model is identified by ``model_definition.name`` and *pars* is
    modified in place:

    * capped_cylinder / barbell: swap the radii so that the cap (bell)
      radius is at least the cylinder radius.
    * guinier: cap rg so that the intensity stays above 1e-30 up to q=1.
    * rpa: zero the unused volume fractions for the lower case numbers,
      then normalize Phia..Phid to sum to 1.  If they sum to zero they
      are left unchanged rather than dividing by zero.
    """
    name = model_definition.name
    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']

    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
    if name == 'guinier':
        # rg_max solves scale*exp(-(q_max*rg)**2/3) = 1e-30 for rg
        #q_max = 0.2  # mid q maximum
        q_max = 1.0  # high q maximum
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)

    if name == 'rpa':
        # Make sure phi sums to 1.0
        if pars['case_num'] < 2:
            pars['Phia'] = 0.
            pars['Phib'] = 0.
        elif pars['case_num'] < 5:
            pars['Phia'] = 0.
        total = sum(pars['Phi'+c] for c in 'abcd')
        # All fractions zero: nothing sensible to normalize, so leave them.
        if total != 0:
            for c in 'abcd':
                pars['Phi'+c] /= total
299
[87985ca]300def parlist(pars):
[caeb06d]301    """
302    Format the parameter list for printing.
303    """
[a4a7308]304    active = None
305    fields = {}
306    lines = []
307    for k, v in sorted(pars.items()):
308        parts = k.split('_pd')
309        #print(k, active, parts)
310        if len(parts) == 1:
311            if active: lines.append(_format_par(active, **fields))
312            active = k
313            fields = {'value': v}
314        else:
315            assert parts[0] == active
316            if parts[1]:
317                fields[parts[1][1:]] = v
318            else:
319                fields['pd'] = v
320    if active: lines.append(_format_par(active, **fields))
321    return "\n".join(lines)
322
323    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
324
325def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
326    line = "%s: %g"%(name, value)
327    if pd != 0.  and n != 0:
328        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
329                % (pd, n, nsigma, nsigma, type)
330    return line
[87985ca]331
def suppress_pd(pars):
    """
    Return a copy of *pars* with polydispersity turned off.

    Every "*_pd_n" entry (the number of dispersion points) is set to 0 in
    the copy, which lets models be tested more quickly.  *pars* itself is
    not modified.
    """
    result = pars.copy()
    result.update((name, 0) for name in pars if name.endswith("_pd_n"))
    return result
[87985ca]343
def eval_sasview(model_definition, data):
    """
    Return a model calculator using the SasView fitting engine.

    The SasView model class is looked up in the ``sas.models`` package
    under the model name returned by ``revert_model``.  The returned
    ``calculator(**pars)`` converts *pars* to SasView form on every call,
    applies them to the SasView model and evaluates the theory on *data*.
    Data with a ``qx_data`` attribute is treated as 2D, anything else as
    1D.  The resolution smearer from ``smear_selection`` is used when one
    is returned.  The calculator's *engine* attribute is "sasview".

    Raises ImportError if the sas package cannot be imported, and
    ValueError if the model class is not found in sas.models.
    """
    # importing sas here so that the error message will be that sas failed to
    # import rather than the more obscure smear_selection not imported error
    import sas
    from sas.models.qsmearing import smear_selection

    # Only the SasView model name is needed here; the parameter values are
    # converted on each call to calculator() below.
    modelname, _ = revert_model(model_definition, {})
    # __import__ returns the top level package, so walk down to the class.
    sas = __import__('sas.models.'+modelname)
    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None)
    if ModelClass is None:
        raise ValueError("could not find model %r in sas.models"%modelname)
    model = ModelClass()
    smearer = smear_selection(data, model=model)

    if hasattr(data, 'qx_data'):
        # 2D: evaluate only unmasked, non-NaN points with qmin <= |q| <= qmax
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
        index = ((~data.mask) & (~np.isnan(data.data))
                 & (q >= data.qmin) & (q <= data.qmax))
        if smearer is not None:
            smearer.model = model  # because smear_selection has a bug
            smearer.accuracy = data.accuracy
            smearer.set_index(index)
            theory = lambda: smearer.get_value()
        else:
            theory = lambda: model.evalDistribution([data.qx_data[index],
                                                     data.qy_data[index]])
    elif smearer is not None:
        theory = lambda: smearer(model.evalDistribution(data.x))
    else:
        theory = lambda: model.evalDistribution(data.x)

    def calculator(**pars):
        """
        Sasview calculator for model.

        After conversion, parameter names of the form "base.field" set
        ``model.dispersion[base][field]``; all others go through
        ``model.setParam``.  Returns the evaluated theory.
        """
        # paying for parameter conversion each time to keep life simple, if not fast
        _, pars = revert_model(model_definition, pars)
        for k, v in pars.items():
            parts = k.split('.')  # polydispersity components
            if len(parts) == 2:
                model.dispersion[parts[0]][parts[1]] = v
            else:
                model.setParam(k, v)
        return theory()

    calculator.engine = "sasview"
    return calculator
397
# Precision name -> code appended to the engine label ("OCL..." / "OMP...").
DTYPE_MAP = dict(
    half='16',
    fast='fast',
    single='32',
    double='64',
    quad='128',
    f16='16',
    f32='32',
    f64='64',
    longdouble='128',
)
def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
    """
    Return a model calculator using the OpenCL calculation engine.

    If the model cannot be loaded at the requested precision *dtype*, the
    error is printed and the load is retried in single precision.  When
    *dtype* is already 'single' there is nothing different to retry, so the
    original exception is re-raised.

    *cutoff* is passed through to DirectModel.  The calculator's *engine*
    attribute is "OCL" followed by the DTYPE_MAP code for the precision
    actually used.
    """
    try:
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    except Exception as exc:
        if dtype == 'single':
            # retrying the identical load would only repeat the failure
            raise
        print(exc)
        print("... trying again with single precision")
        dtype = 'single'
        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
    calculator = DirectModel(data, model, cutoff=cutoff)
    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
    return calculator
[216a9e1]423
def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
    """
    Return a model calculator using the DLL calculation engine.

    A *dtype* of 'quad' is requested from the loader as 'longdouble'.
    *cutoff* is passed through to DirectModel.  The calculator's *engine*
    attribute is "OMP" followed by the DTYPE_MAP code for the precision.
    """
    precision = 'longdouble' if dtype == 'quad' else dtype
    model = core.load_model(model_definition, dtype=precision, platform="dll")
    calculator = DirectModel(data, model, cutoff=cutoff)
    calculator.engine = "OMP%s" % DTYPE_MAP[precision]
    return calculator
434
def time_calculation(calculator, pars, Nevals=1):
    """
    Compute the average calculation time over N evaluations.

    An additional call is generated without polydispersity in order to
    initialize the calculation engine, and make the average more stable.
    At least one timed evaluation is always made, even if *Nevals* < 1.

    Returns the value of the last evaluation and the average time per
    evaluation in milliseconds.
    """
    # initialize the code so time is more accurate
    value = calculator(**suppress_pd(pars))
    # make sure there is at least one eval, and average over the number of
    # evaluations actually performed (avoids dividing by zero)
    num_evals = max(Nevals, 1)
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
449
def make_data(opts):
    """
    Build an empty dataset whose Q points and resolution are used to
    evaluate the models.

    *opts* supplies 'qmax', 'nq', 'res', 'is2d', plus 'accuracy' for 2D
    data or 'view' for 1D data, as parsed from the command line.  2D data
    is an nq x nq grid over [-qmax, qmax] with a beam stop of radius 0.004.
    1D data spans three decades up to qmax, log spaced for the 'log' view
    and linearly spaced otherwise.

    Returns the data and an index selecting the points to use.
    """
    qmax = opts['qmax']
    nq = opts['nq']
    res = opts['res']
    if not opts['is2d']:
        if opts['view'] == 'log':
            log_qmax = math.log10(qmax)
            q = np.logspace(log_qmax-3, log_qmax, nq)
        else:
            q = np.linspace(0.001*qmax, qmax, nq)
        return empty_data1D(q, resolution=res), slice(None, None)

    data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
    data.accuracy = opts['accuracy']
    set_beam_stop(data, 0.004)
    return data, ~data.mask
473
def make_engine(model_definition, data, dtype, cutoff):
    """
    Return the calculation engine for the precision name *dtype*.

    'sasview' selects the SasView engine.  A name ending in '!' (e.g.
    'double!') selects the C DLL engine at the precision named before the
    '!'.  Any other name selects the OpenCL engine.  *cutoff* is passed to
    the DLL and OpenCL engines only.
    """
    if dtype == 'sasview':
        return eval_sasview(model_definition, data)
    if not dtype.endswith('!'):
        return eval_opencl(model_definition, data, dtype=dtype,
                           cutoff=cutoff)
    return eval_ctypes(model_definition, data, dtype=dtype[:-1],
                       cutoff=cutoff)
[87985ca]489
def compare(opts, limits=None):
    """
    Perform a comparison using options from the command line.

    *opts* is the options dictionary built by :func:`parse_opts`.  The two
    engines in opts['engines'] are timed over opts['n1'] and opts['n2']
    evaluations; an engine with zero evaluations, or one that raises
    ImportError, is skipped.  When both are computed, statistics of the
    absolute and relative differences are printed.  Results are plotted
    when opts['plot'] or opts['explore'] is set, and a histogram of the
    relative error is added when opts['show_hist'] is set.

    *limits* are the limits on the values to use, either to set the y-axis
    for 1D or to set the colormap scale for 2D.  If None, then they are
    inferred from the data and returned. When exploring using Bumps,
    the limits are set when the model is initially called, and maintained
    as the values are adjusted, making it easier to see the effects of the
    parameters.

    Returns *limits*, or None when neither plotting nor exploring.
    """
    Nbase, Ncomp = opts['n1'], opts['n2']
    pars = opts['pars']
    data = opts['data']

    # Base calculation
    if Nbase > 0:
        base = opts['engines'][0]
        try:
            base_value, base_time = time_calculation(base, pars, Nbase)
            print("%s t=%.1f ms, intensity=%.0f"
                  % (base.engine, base_time, sum(base_value)))
        except ImportError:
            # report the failure and carry on without this engine
            traceback.print_exc()
            Nbase = 0

    # Comparison calculation
    if Ncomp > 0:
        comp = opts['engines'][1]
        try:
            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
            print("%s t=%.1f ms, intensity=%.0f"
                  % (comp.engine, comp_time, sum(comp_value)))
        except ImportError:
            traceback.print_exc()
            Ncomp = 0

    # Compare, but only if computing both forms
    if Nbase > 0 and Ncomp > 0:
        resid = (base_value - comp_value)
        relerr = resid/comp_value
        _print_stats("|%s-%s|"
                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
                     resid)
        _print_stats("|(%s-%s)/%s|"
                     % (base.engine, comp.engine, comp.engine),
                     relerr)

    # Plot if requested
    if not opts['plot'] and not opts['explore']:
        return None
    view = opts['view']
    import matplotlib.pyplot as plt
    if limits is None:
        # np.inf rather than the np.Inf alias, which NumPy 2 removed
        vmin, vmax = np.inf, -np.inf
        if Nbase > 0:
            vmin = min(vmin, min(base_value))
            vmax = max(vmax, max(base_value))
        if Ncomp > 0:
            vmin = min(vmin, min(comp_value))
            vmax = max(vmax, max(comp_value))
        limits = vmin, vmax

    if Nbase > 0:
        if Ncomp > 0:
            plt.subplot(131)
        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
        plt.title("%s t=%.1f ms"%(base.engine, base_time))
    if Ncomp > 0:
        if Nbase > 0:
            plt.subplot(132)
        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
    if Ncomp > 0 and Nbase > 0:
        plt.subplot(133)
        if not opts['rel_err']:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, use_data=False)
        if view == 'linear':
            plt.xscale('linear')
        plt.title("max %s = %.3g"%(errstr, max(abs(err))))

    # The parsed flag is stored under 'show_hist'; checking for '-hist'
    # in opts would never match.
    if Ncomp > 0 and Nbase > 0 and opts.get('show_hist', False):
        plt.figure()
        # np.abs makes a new array, so relerr is left untouched.  Zeros are
        # replaced by half the smallest nonzero error so the log is finite.
        v = np.abs(relerr)
        nonzero = v[v != 0]
        if nonzero.size:
            v[v == 0] = 0.5*np.min(nonzero)
            plt.hist(np.log10(v), density=True, bins=50)
        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
                   % (base.engine, comp.engine, comp.engine))
        plt.ylabel('P(err)')
        plt.title('Distribution of relative error between calculation engines')

    if not opts['explore']:
        plt.show()

    return limits
592
[0763009]593def _print_stats(label, err):
594    sorted_err = np.sort(abs(err))
595    p50 = int((len(err)-1)*0.50)
596    p98 = int((len(err)-1)*0.98)
597    data = [
598        "max:%.3e"%sorted_err[-1],
599        "median:%.3e"%sorted_err[p50],
600        "98%%:%.3e"%sorted_err[p98],
601        "rms:%.3e"%np.sqrt(np.mean(err**2)),
602        "zero-offset:%+.3e"%np.mean(err),
603        ]
[caeb06d]604    print(label+"  "+"  ".join(data))
[0763009]605
606
607
# ===========================================================================
#
# Flags that take no value, given on the command line as "-name".
NAME_OPTIONS = set("""
    plot noplot
    half fast single double
    single! double! quad! sasview
    lowq midq highq exq
    2d 1d
    preset random
    poly mono
    nopars pars
    rel abs
    linear log q4
    hist nohist
    edit
    """.split())
# Flags given as "-name=value".
# Note: random is both a name option and a value option
VALUE_OPTIONS = 'cutoff random nq res accuracy'.split()
628
def columnize(L, indent="", width=79):
    """
    Format a list of strings into columns.

    The strings in *L* are laid out down the columns and then across,
    each left-justified in a field one character wider than the longest
    string, with every line prefixed by *indent*.  The number of columns
    is chosen from *width*, with at least one column, and every string
    appears in the output.

    Returns a newline-separated string ready for printing, or the empty
    string if *L* is empty.
    """
    entries = list(L)
    if not entries:
        return ""
    column_width = max(len(w) for w in entries) + 1
    # at least one column, even when an entry is wider than *width*
    num_columns = max((width - len(indent)) // column_width, 1)
    # round up so the final partial column is not dropped
    num_rows = -(-len(entries) // num_columns)
    # drop columns that would contain only padding
    num_columns = -(-len(entries) // num_rows)
    entries += [""] * (num_rows*num_columns - len(entries))
    columns = [entries[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    output = indent + ("\n"+indent).join(lines)
    return output
644
645
def get_demo_pars(model_definition):
    """
    Extract demo parameters from the model definition.

    Starts from each parameter's default value and adds polydispersity
    defaults (width 0.0, 0 points, 3.0 sigma, "gaussian") for every
    'volume' or 'orientation' parameter.  The model's demo values then
    override these.  Returns a new dictionary.
    """
    info = generate.make_info(model_definition)
    # NOTE(review): entries are indexed as p[0]=name, p[2]=default value,
    # p[4]=parameter type here; confirm against generate.make_info.
    table = info['parameters']

    pars = {}
    for entry in table:
        pars[entry[0]] = entry[2]

    for entry in table:
        if entry[4] not in ('volume', 'orientation'):
            continue
        base = entry[0]
        pars[base+'_pd'] = 0.0
        pars[base+'_pd_n'] = 0
        pars[base+'_pd_nsigma'] = 3.0
        pars[base+'_pd_type'] = "gaussian"

    pars.update(info['demo'])
    return pars
665
def parse_opts():
    """
    Parse command line options.

    Reads sys.argv:
    * the first positional argument is the model name;
    * the optional second and third positional arguments are N1 and N2,
      the evaluation counts for the two engines;
    * arguments starting with '-' are flags;
    * other arguments containing '=' set model parameter values.

    Prints a message and exits with status 1 if no model is given, the
    model cannot be loaded, a flag is not recognized, or a parameter name
    is invalid.

    Returns the options dictionary used by :func:`compare`.  It holds the
    parameter values, the data set and the two calculation engines.
    """
    flags = [arg for arg in sys.argv[1:]
             if arg.startswith('-')]
    values = [arg for arg in sys.argv[1:]
              if not arg.startswith('-') and '=' in arg]
    args = [arg for arg in sys.argv[1:]
            if not arg.startswith('-') and '=' not in arg]
    # MODELS is the module level list of available model names
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(args) == 0:
        print(USAGE)
        print("\nAvailable models:")
        print(columnize(MODELS, indent="  "))
        sys.exit(1)

    name = args[0]
    try:
        model_definition = core.load_model_definition(name)
    except ImportError as exc:
        print(str(exc))
        print("Use one of:\n    " + models)
        sys.exit(1)
    if len(args) > 3:
        print("expected parameters: model N1 N2")

    invalid = [o[1:] for o in flags
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        sys.exit(1)


    # pylint: disable=bad-whitespace
    # Interpret the flags
    opts = {
        'plot'      : True,
        'view'      : 'log',
        'is2d'      : False,
        'qmax'      : 0.05,
        'nq'        : 128,
        'res'       : 0.0,
        'accuracy'  : 'Low',
        'cutoff'    : 1e-5,
        'seed'      : -1,  # default to preset
        'mono'      : False,
        'show_pars' : False,
        'show_hist' : False,
        'rel_err'   : True,
        'explore'   : False,
    }
    engines = []
    for arg in flags:
        if arg == '-noplot':    opts['plot'] = False
        elif arg == '-plot':    opts['plot'] = True
        elif arg == '-linear':  opts['view'] = 'linear'
        elif arg == '-log':     opts['view'] = 'log'
        elif arg == '-q4':      opts['view'] = 'q4'
        elif arg == '-1d':      opts['is2d'] = False
        elif arg == '-2d':      opts['is2d'] = True
        elif arg == '-exq':     opts['qmax'] = 10.0
        elif arg == '-highq':   opts['qmax'] = 1.0
        elif arg == '-midq':    opts['qmax'] = 0.2
        elif arg == '-lowq':    opts['qmax'] = 0.05
        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
        # integer bound: newer numpy rejects a float upper limit in randint
        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
        elif arg == '-preset':  opts['seed'] = -1
        elif arg == '-mono':    opts['mono'] = True
        elif arg == '-poly':    opts['mono'] = False
        elif arg == '-pars':    opts['show_pars'] = True
        elif arg == '-nopars':  opts['show_pars'] = False
        elif arg == '-hist':    opts['show_hist'] = True
        elif arg == '-nohist':  opts['show_hist'] = False
        elif arg == '-rel':     opts['rel_err'] = True
        elif arg == '-abs':     opts['rel_err'] = False
        elif arg == '-half':    engines.append(arg[1:])
        elif arg == '-fast':    engines.append(arg[1:])
        elif arg == '-single':  engines.append(arg[1:])
        elif arg == '-double':  engines.append(arg[1:])
        elif arg == '-single!': engines.append(arg[1:])
        elif arg == '-double!': engines.append(arg[1:])
        elif arg == '-quad!':   engines.append(arg[1:])
        elif arg == '-sasview': engines.append(arg[1:])
        elif arg == '-edit':    opts['explore'] = True
    # pylint: enable=bad-whitespace

    if len(engines) == 0:
        engines.extend(['single', 'sasview'])
    elif len(engines) == 1:
        # Compare the single engine given against sasview, or against
        # single precision OpenCL when the engine given is sasview itself.
        if engines[0] != 'sasview':
            engines.append('sasview')
        else:
            engines.append('single')
    elif len(engines) > 2:
        del engines[2:]

    n1 = int(args[1]) if len(args) > 1 else 1
    n2 = int(args[2]) if len(args) > 2 else 1

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    pars = get_demo_pars(model_definition)

    # Fill in parameters given on the command line
    presets = {}
    for arg in values:
        k, v = arg.split('=', 1)
        if k not in pars:
            # extract base name without polydispersity info
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            sys.exit(1)
        presets[k] = float(v) if not k.endswith('type') else v

    # randomize parameters
    #pars.update(set_pars)  # set value before random to control range
    if opts['seed'] > -1:
        pars = randomize_pars(pars, seed=opts['seed'])
        print("Randomize using -random=%i"%opts['seed'])
    if opts['mono']:
        pars = suppress_pd(pars)
    pars.update(presets)  # set value after random to control value
    constrain_pars(model_definition, pars)
    constrain_new_to_old(model_definition, pars)
    if opts['show_pars']:
        print(str(parlist(pars)))

    # Create the computational engines
    data, _ = make_data(opts)
    if n1:
        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
    else:
        base = None
    if n2:
        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
    else:
        comp = None

    # pylint: disable=bad-whitespace
    # Remember it all
    opts.update({
        'name'      : name,
        'def'       : model_definition,
        'n1'        : n1,
        'n2'        : n2,
        'presets'   : presets,
        'pars'      : pars,
        'data'      : data,
        'engines'   : [base, comp],
    })
    # pylint: enable=bad-whitespace

    return opts
826
def explore(opts):
    """
    Open the Bumps GUI on a model comparison for interactive exploration.

    *opts* is the options dictionary built by ``parse_opts``.  It is wrapped
    in an :class:`Explore` object, turned into a Bumps ``FitProblem`` and
    loaded into a Bumps ``AppFrame``; the wx main loop is then run.
    """
    import wx
    from bumps.names import FitProblem
    from bumps.gui.app_frame import AppFrame

    fit_problem = FitProblem(Explore(opts))
    # On cocoa (macOS) builds of wx the frame is shown only after the panel
    # has been populated and split; elsewhere it is shown up front.
    show_late = "cocoa" in wx.version()
    application = wx.App()
    frame = AppFrame(parent=None, title="explore")
    if not show_late:
        frame.Show()
    panel = frame.panel
    panel.set_model(model=fit_problem)
    panel.Layout()
    panel.aui.Split(0, wx.TOP)
    if show_late:
        frame.Show()
    application.MainLoop()
845
class Explore(object):
    """
    Present a SAS model comparison as a Bumps fitness object.

    The object supplies the methods used by a Bumps ``FitProblem``
    (``parameters``, ``numpoints``, ``nllf`` and ``plot``).  This lets the
    model parameters be adjusted in the Bumps GUI, and every redraw reruns
    the comparison with the current values.
    """
    def __init__(self, opts):
        from bumps.cli import config_matplotlib
        from . import bumps_model
        config_matplotlib()
        self.opts = opts
        model_info = generate.make_info(opts['def'])
        pars, pd_types = bumps_model.create_parameters(model_info,
                                                       **opts['pars'])
        if opts['is2d']:
            # For 2-D data every parameter gets an adjustable range.
            adjustable = list(pars)
        else:
            # For 1-D data only the 1-D polydisperse parameters (and their
            # dispersion controls) plus the fixed 1-D parameters are ranged.
            partype = model_info['partype']
            adjustable = [name + suffix
                          for name in partype['pd-1d']
                          for suffix in ('', '_pd', '_pd_n', '_pd_nsigma')]
            adjustable += partype['fixed-1d']
        for name in adjustable:
            par = pars[name]
            par.range(*parameter_range(name, par.value))

        self.pars = pars
        self.pd_types = pd_types
        # Plot limits; fixed on the first call to plot().
        self.limits = None

    def numpoints(self):
        """
        Report one more point than there are parameters, so dof is 1.
        """
        return len(self.pars) + 1

    def parameters(self):
        """
        Give back the dictionary of Bumps parameters.
        """
        return self.pars

    def nllf(self):
        """
        Always zero: exploring a model involves no fit cost.
        """
        # pylint: disable=no-self-use
        return 0.0

    def plot(self, view='log'):
        """
        Rerun the comparison with the current parameter values and plot it.

        On the first call the limits returned by ``compare`` set a fixed
        plot range of [1.3*vmax*1e-7, 1.3*vmax], which is used for all later
        redraws.  *view* is accepted but not used.
        """
        current = {name: par.value for name, par in self.pars.items()}
        current.update(self.pd_types)
        self.opts['pars'] = current
        limits = compare(self.opts, limits=self.limits)
        if self.limits is None:
            _, vmax = limits
            upper = 1.3*vmax
            self.limits = upper*1e-7, upper
[87985ca]908
909
def main():
    """
    Command-line entry point.

    Parse the options, then either open the interactive explorer (when
    ``-edit`` set ``opts['explore']``) or run a single comparison.
    """
    opts = parse_opts()
    action = explore if opts['explore'] else compare
    action(opts)
919
# Run the command-line program when this file is executed directly.
if __name__ == '__main__':
    main()
Note: See TracBrowser for help on using the repository browser.