source: sasmodels/sasmodels/compare.py @ 5efe850

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5efe850 was 8749d5b9, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

missing default value for zero

  • Property mode set to 100755
File size: 31.3 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import product
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_name, revert_pars, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -zero indicates that q=0 should be included
61    -1d*/-2d computes 1d or 2d data
62    -preset*/-random[=seed] preset or random parameters
63    -mono/-poly* force monodisperse/polydisperse
64    -cutoff=1e-5* cutoff value for including a point in polydispersity
65    -pars/-nopars* prints the parameter set or not
66    -abs/-rel* plot relative or absolute error
67    -linear/-log*/-q4 intensity scaling
68    -hist/-nohist* plot histogram of relative error
69    -res=0 sets the resolution width dQ/Q if calculating with resolution
70    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
71    -edit starts the parameter explorer
72    -default/-demo* use demo vs default parameters
73
74Any two calculation engines can be selected for comparison:
75
76    -single/-double/-half/-fast sets an OpenCL calculation engine
77    -single!/-double!/-quad! sets an OpenMP calculation engine
78    -sasview sets the sasview calculation engine
79
80The default is -single -sasview.  Note that the interpretation of quad
81precision depends on architecture, and may vary from 64-bit to 128-bit,
82with 80-bit floats being common (1e-19 precision).
83
84Key=value pairs allow you to set specific values for the model parameters.
85"""
86
87# Update docs with command line usage string.   This is separate from the usual
88# doc string so that we can display it at run time if there is an error.
89# lin
90__doc__ = (__doc__  # pylint: disable=redefined-builtin
91           + """
92Program description
93-------------------
94
95"""
96           + USAGE)
97
98kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
99
100MODELS = core.list_models()
101
102# CRUFT python 2.6
103if not hasattr(datetime.timedelta, 'total_seconds'):
104    def delay(dt):
105        """Return number date-time delta as number seconds"""
106        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
107else:
108    def delay(dt):
109        """Return number date-time delta as number seconds"""
110        return dt.total_seconds()
111
112
113class push_seed(object):
114    """
115    Set the seed value for the random number generator.
116
117    When used in a with statement, the random number generator state is
118    restored after the with statement is complete.
119
120    :Parameters:
121
122    *seed* : int or array_like, optional
123        Seed for RandomState
124
125    :Example:
126
127    Seed can be used directly to set the seed::
128
129        >>> from numpy.random import randint
130        >>> push_seed(24)
131        <...push_seed object at...>
132        >>> print(randint(0,1000000,3))
133        [242082    899 211136]
134
135    Seed can also be used in a with statement, which sets the random
136    number generator state for the enclosed computations and restores
137    it to the previous state on completion::
138
139        >>> with push_seed(24):
140        ...    print(randint(0,1000000,3))
141        [242082    899 211136]
142
143    Using nested contexts, we can demonstrate that state is indeed
144    restored after the block completes::
145
146        >>> with push_seed(24):
147        ...    print(randint(0,1000000))
148        ...    with push_seed(24):
149        ...        print(randint(0,1000000,3))
150        ...    print(randint(0,1000000))
151        242082
152        [242082    899 211136]
153        899
154
155    The restore step is protected against exceptions in the block::
156
157        >>> with push_seed(24):
158        ...    print(randint(0,1000000))
159        ...    try:
160        ...        with push_seed(24):
161        ...            print(randint(0,1000000,3))
162        ...            raise Exception()
163        ...    except:
164        ...        print("Exception raised")
165        ...    print(randint(0,1000000))
166        242082
167        [242082    899 211136]
168        Exception raised
169        899
170    """
171    def __init__(self, seed=None):
172        self._state = np.random.get_state()
173        np.random.seed(seed)
174
175    def __enter__(self):
176        return None
177
178    def __exit__(self, *args):
179        np.random.set_state(self._state)
180
181def tic():
182    """
183    Timer function.
184
185    Use "toc=tic()" to start the clock and "toc()" to measure
186    a time interval.
187    """
188    then = datetime.datetime.now()
189    return lambda: delay(datetime.datetime.now() - then)
190
191
192def set_beam_stop(data, radius, outer=None):
193    """
194    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
195
196    Note: this function does not use the sasview package
197    """
198    if hasattr(data, 'qx_data'):
199        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
200        data.mask = (q < radius)
201        if outer is not None:
202            data.mask |= (q >= outer)
203    else:
204        data.mask = (data.x < radius)
205        if outer is not None:
206            data.mask |= (data.x >= outer)
207
208
209def parameter_range(p, v):
210    """
211    Choose a parameter range based on parameter name and initial value.
212    """
213    if p.endswith('_pd_n'):
214        return [0, 100]
215    elif p.endswith('_pd_nsigma'):
216        return [0, 5]
217    elif p.endswith('_pd_type'):
218        return v
219    elif any(s in p for s in ('theta', 'phi', 'psi')):
220        # orientation in [-180,180], orientation pd in [0,45]
221        if p.endswith('_pd'):
222            return [0, 45]
223        else:
224            return [-180, 180]
225    elif 'sld' in p:
226        return [-0.5, 10]
227    elif p.endswith('_pd'):
228        return [0, 1]
229    elif p == 'background':
230        return [0, 10]
231    elif p == 'scale':
232        return [0, 1e3]
233    elif p == 'case_num':
234        # RPA hack
235        return [0, 10]
236    elif v < 0:
237        # Kxy parameters in rpa model can be negative
238        return [2*v, -2*v]
239    else:
240        return [0, (2*v if v > 0 else 1)]
241
242
243def _randomize_one(p, v):
244    """
245    Randomize a single parameter.
246    """
247    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
248        return v
249    else:
250        return np.random.uniform(*parameter_range(p, v))
251
252
253def randomize_pars(pars, seed=None):
254    """
255    Generate random values for all of the parameters.
256
257    Valid ranges for the random number generator are guessed from the name of
258    the parameter; this will not account for constraints such as cap radius
259    greater than cylinder radius in the capped_cylinder model, so
260    :func:`constrain_pars` needs to be called afterward..
261    """
262    with push_seed(seed):
263        # Note: the sort guarantees order `of calls to random number generator
264        pars = dict((p, _randomize_one(p, v))
265                    for p, v in sorted(pars.items()))
266    return pars
267
268def constrain_pars(model_info, pars):
269    """
270    Restrict parameters to valid values.
271
272    This includes model specific code for models such as capped_cylinder
273    which need to support within model constraints (cap radius more than
274    cylinder radius in this case).
275    """
276    name = model_info['id']
277    # if it is a product model, then just look at the form factor since
278    # none of the structure factors need any constraints.
279    if '*' in name:
280        name = name.split('*')[0]
281
282    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
283        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
284    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
285        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
286
287    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
288    if name == 'guinier':
289        #q_max = 0.2  # mid q maximum
290        q_max = 1.0  # high q maximum
291        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
292        pars['rg'] = min(pars['rg'], rg_max)
293
294    if name == 'rpa':
295        # Make sure phi sums to 1.0
296        if pars['case_num'] < 2:
297            pars['Phia'] = 0.
298            pars['Phib'] = 0.
299        elif pars['case_num'] < 5:
300            pars['Phia'] = 0.
301        total = sum(pars['Phi'+c] for c in 'abcd')
302        for c in 'abcd':
303            pars['Phi'+c] /= total
304
305def parlist(model_info, pars, is2d):
306    """
307    Format the parameter list for printing.
308    """
309    if is2d:
310        exclude = lambda n: False
311    else:
312        partype = model_info['partype']
313        par1d = set(partype['fixed-1d']+partype['pd-1d'])
314        exclude = lambda n: n not in par1d
315    lines = []
316    for p in model_info['parameters']:
317        if exclude(p.name): continue
318        fields = dict(
319            value=pars.get(p.name, p.default),
320            pd=pars.get(p.name+"_pd", 0.),
321            n=int(pars.get(p.name+"_pd_n", 0)),
322            nsigma=pars.get(p.name+"_pd_nsgima", 3.),
323            type=pars.get(p.name+"_pd_type", 'gaussian'))
324        lines.append(_format_par(p.name, **fields))
325    return "\n".join(lines)
326
327    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
328
329def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
330    line = "%s: %g"%(name, value)
331    if pd != 0.  and n != 0:
332        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
333                % (pd, n, nsigma, nsigma, type)
334    return line
335
336def suppress_pd(pars):
337    """
338    Suppress theta_pd for now until the normalization is resolved.
339
340    May also suppress complete polydispersity of the model to test
341    models more quickly.
342    """
343    pars = pars.copy()
344    for p in pars:
345        if p.endswith("_pd_n"): pars[p] = 0
346    return pars
347
348def eval_sasview(model_info, data):
349    """
350    Return a model calculator using the pre-4.0 SasView models.
351    """
352    # importing sas here so that the error message will be that sas failed to
353    # import rather than the more obscure smear_selection not imported error
354    import sas
355    from sas.models.qsmearing import smear_selection
356
357    def get_model(name):
358        #print("new",sorted(_pars.items()))
359        sas = __import__('sas.models.' + name)
360        ModelClass = getattr(getattr(sas.models, name, None), name, None)
361        if ModelClass is None:
362            raise ValueError("could not find model %r in sas.models"%name)
363        return ModelClass()
364
365    # grab the sasview model, or create it if it is a product model
366    if model_info['composition']:
367        composition_type, parts = model_info['composition']
368        if composition_type == 'product':
369            from sas.models.MultiplicationModel import MultiplicationModel
370            P, S = [get_model(revert_name(p)) for p in parts]
371            model = MultiplicationModel(P, S)
372        else:
373            raise ValueError("sasview mixture models not supported by compare")
374    else:
375        model = get_model(revert_name(model_info))
376
377    # build a smearer with which to call the model, if necessary
378    smearer = smear_selection(data, model=model)
379    if hasattr(data, 'qx_data'):
380        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
381        index = ((~data.mask) & (~np.isnan(data.data))
382                 & (q >= data.qmin) & (q <= data.qmax))
383        if smearer is not None:
384            smearer.model = model  # because smear_selection has a bug
385            smearer.accuracy = data.accuracy
386            smearer.set_index(index)
387            theory = lambda: smearer.get_value()
388        else:
389            theory = lambda: model.evalDistribution([data.qx_data[index],
390                                                     data.qy_data[index]])
391    elif smearer is not None:
392        theory = lambda: smearer(model.evalDistribution(data.x))
393    else:
394        theory = lambda: model.evalDistribution(data.x)
395
396    def calculator(**pars):
397        """
398        Sasview calculator for model.
399        """
400        # paying for parameter conversion each time to keep life simple, if not fast
401        pars = revert_pars(model_info, pars)
402        for k, v in pars.items():
403            parts = k.split('.')  # polydispersity components
404            if len(parts) == 2:
405                model.dispersion[parts[0]][parts[1]] = v
406            else:
407                model.setParam(k, v)
408        return theory()
409
410    calculator.engine = "sasview"
411    return calculator
412
413DTYPE_MAP = {
414    'half': '16',
415    'fast': 'fast',
416    'single': '32',
417    'double': '64',
418    'quad': '128',
419    'f16': '16',
420    'f32': '32',
421    'f64': '64',
422    'longdouble': '128',
423}
424def eval_opencl(model_info, data, dtype='single', cutoff=0.):
425    """
426    Return a model calculator using the OpenCL calculation engine.
427    """
428    try:
429        model = core.build_model(model_info, dtype=dtype, platform="ocl")
430    except Exception as exc:
431        print(exc)
432        print("... trying again with single precision")
433        model = core.build_model(model_info, dtype='single', platform="ocl")
434    calculator = DirectModel(data, model, cutoff=cutoff)
435    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
436    return calculator
437
438def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
439    """
440    Return a model calculator using the DLL calculation engine.
441    """
442    if dtype == 'quad':
443        dtype = 'longdouble'
444    model = core.build_model(model_info, dtype=dtype, platform="dll")
445    calculator = DirectModel(data, model, cutoff=cutoff)
446    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
447    return calculator
448
449def time_calculation(calculator, pars, Nevals=1):
450    """
451    Compute the average calculation time over N evaluations.
452
453    An additional call is generated without polydispersity in order to
454    initialize the calculation engine, and make the average more stable.
455    """
456    # initialize the code so time is more accurate
457    value = calculator(**suppress_pd(pars))
458    toc = tic()
459    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
460        value = calculator(**pars)
461    average_time = toc()*1000./Nevals
462    return value, average_time
463
464def make_data(opts):
465    """
466    Generate an empty dataset, used with the model to set Q points
467    and resolution.
468
469    *opts* contains the options, with 'qmax', 'nq', 'res',
470    'accuracy', 'is2d' and 'view' parsed from the command line.
471    """
472    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
473    if opts['is2d']:
474        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
475        data.accuracy = opts['accuracy']
476        set_beam_stop(data, 0.0004)
477        index = ~data.mask
478    else:
479        if opts['view'] == 'log' and not opts['zero']:
480            qmax = math.log10(qmax)
481            q = np.logspace(qmax-3, qmax, nq)
482        else:
483            q = np.linspace(0.001*qmax, qmax, nq)
484        if opts['zero']:
485            q = np.hstack((0, q))
486        data = empty_data1D(q, resolution=res)
487        index = slice(None, None)
488    return data, index
489
490def make_engine(model_info, data, dtype, cutoff):
491    """
492    Generate the appropriate calculation engine for the given datatype.
493
494    Datatypes with '!' appended are evaluated using external C DLLs rather
495    than OpenCL.
496    """
497    if dtype == 'sasview':
498        return eval_sasview(model_info, data)
499    elif dtype.endswith('!'):
500        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
501    else:
502        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
503
504def _show_invalid(data, theory):
505    if not theory.mask.any():
506        return
507
508    if hasattr(data, 'x'):
509        bad = zip(data.x[theory.mask], theory[theory.mask])
510        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x,y in bad))
511
512
513def compare(opts, limits=None):
514    """
515    Preform a comparison using options from the command line.
516
517    *limits* are the limits on the values to use, either to set the y-axis
518    for 1D or to set the colormap scale for 2D.  If None, then they are
519    inferred from the data and returned. When exploring using Bumps,
520    the limits are set when the model is initially called, and maintained
521    as the values are adjusted, making it easier to see the effects of the
522    parameters.
523    """
524    Nbase, Ncomp = opts['n1'], opts['n2']
525    pars = opts['pars']
526    data = opts['data']
527
528    # Base calculation
529    if Nbase > 0:
530        base = opts['engines'][0]
531        try:
532            base_value, base_time = time_calculation(base, pars, Nbase)
533            base_value = np.ma.masked_invalid(base_value)
534            print("%s t=%.2f ms, intensity=%.0f"
535                  % (base.engine, base_time, base_value.sum()))
536            _show_invalid(data, base_value)
537        except ImportError:
538            traceback.print_exc()
539            Nbase = 0
540
541    # Comparison calculation
542    if Ncomp > 0:
543        comp = opts['engines'][1]
544        try:
545            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
546            comp_value = np.ma.masked_invalid(comp_value)
547            print("%s t=%.2f ms, intensity=%.0f"
548                  % (comp.engine, comp_time, comp_value.sum()))
549            _show_invalid(data, comp_value)
550        except ImportError:
551            traceback.print_exc()
552            Ncomp = 0
553
554    # Compare, but only if computing both forms
555    if Nbase > 0 and Ncomp > 0:
556        resid = (base_value - comp_value)
557        relerr = resid/comp_value
558        _print_stats("|%s-%s|"
559                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
560                     resid)
561        _print_stats("|(%s-%s)/%s|"
562                     % (base.engine, comp.engine, comp.engine),
563                     relerr)
564
565    # Plot if requested
566    if not opts['plot'] and not opts['explore']: return
567    view = opts['view']
568    import matplotlib.pyplot as plt
569    if limits is None:
570        vmin, vmax = np.Inf, -np.Inf
571        if Nbase > 0:
572            vmin = min(vmin, base_value.min())
573            vmax = max(vmax, base_value.max())
574        if Ncomp > 0:
575            vmin = min(vmin, comp_value.min())
576            vmax = max(vmax, comp_value.max())
577        limits = vmin, vmax
578
579    if Nbase > 0:
580        if Ncomp > 0: plt.subplot(131)
581        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
582        plt.title("%s t=%.2f ms"%(base.engine, base_time))
583        #cbar_title = "log I"
584    if Ncomp > 0:
585        if Nbase > 0: plt.subplot(132)
586        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
587        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
588        #cbar_title = "log I"
589    if Ncomp > 0 and Nbase > 0:
590        plt.subplot(133)
591        if not opts['rel_err']:
592            err, errstr, errview = resid, "abs err", "linear"
593        else:
594            err, errstr, errview = abs(relerr), "rel err", "log"
595        #err,errstr = base/comp,"ratio"
596        plot_theory(data, None, resid=err, view=errview, use_data=False)
597        if view == 'linear':
598            plt.xscale('linear')
599        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
600        #cbar_title = errstr if errview=="linear" else "log "+errstr
601    #if is2D:
602    #    h = plt.colorbar()
603    #    h.ax.set_title(cbar_title)
604
605    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
606        plt.figure()
607        v = relerr
608        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
609        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
610        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
611                   % (base.engine, comp.engine, comp.engine))
612        plt.ylabel('P(err)')
613        plt.title('Distribution of relative error between calculation engines')
614
615    if not opts['explore']:
616        plt.show()
617
618    return limits
619
620def _print_stats(label, err):
621    sorted_err = np.sort(abs(err.compressed()))
622    p50 = int((len(err)-1)*0.50)
623    p98 = int((len(err)-1)*0.98)
624    data = [
625        "max:%.3e"%sorted_err[-1],
626        "median:%.3e"%sorted_err[p50],
627        "98%%:%.3e"%sorted_err[p98],
628        "rms:%.3e"%np.sqrt(np.mean(err**2)),
629        "zero-offset:%+.3e"%np.mean(err),
630        ]
631    print(label+"  "+"  ".join(data))
632
633
634
635# ===========================================================================
636#
637NAME_OPTIONS = set([
638    'plot', 'noplot',
639    'half', 'fast', 'single', 'double',
640    'single!', 'double!', 'quad!', 'sasview',
641    'lowq', 'midq', 'highq', 'exq', 'zero',
642    '2d', '1d',
643    'preset', 'random',
644    'poly', 'mono',
645    'nopars', 'pars',
646    'rel', 'abs',
647    'linear', 'log', 'q4',
648    'hist', 'nohist',
649    'edit',
650    'demo', 'default',
651    ])
652VALUE_OPTIONS = [
653    # Note: random is both a name option and a value option
654    'cutoff', 'random', 'nq', 'res', 'accuracy',
655    ]
656
657def columnize(L, indent="", width=79):
658    """
659    Format a list of strings into columns.
660
661    Returns a string with carriage returns ready for printing.
662    """
663    column_width = max(len(w) for w in L) + 1
664    num_columns = (width - len(indent)) // column_width
665    num_rows = len(L) // num_columns
666    L = L + [""] * (num_rows*num_columns - len(L))
667    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
668    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
669             for row in zip(*columns)]
670    output = indent + ("\n"+indent).join(lines)
671    return output
672
673
674def get_pars(model_info, use_demo=False):
675    """
676    Extract demo parameters from the model definition.
677    """
678    # Get the default values for the parameters
679    pars = dict((p.name, p.default) for p in model_info['parameters'])
680
681    # Fill in default values for the polydispersity parameters
682    for p in model_info['parameters']:
683        if p.type in ('volume', 'orientation'):
684            pars[p.name+'_pd'] = 0.0
685            pars[p.name+'_pd_n'] = 0
686            pars[p.name+'_pd_nsigma'] = 3.0
687            pars[p.name+'_pd_type'] = "gaussian"
688
689    # Plug in values given in demo
690    if use_demo:
691        pars.update(model_info['demo'])
692    return pars
693
694
695def parse_opts():
696    """
697    Parse command line options.
698    """
699    MODELS = core.list_models()
700    flags = [arg for arg in sys.argv[1:]
701             if arg.startswith('-')]
702    values = [arg for arg in sys.argv[1:]
703              if not arg.startswith('-') and '=' in arg]
704    args = [arg for arg in sys.argv[1:]
705            if not arg.startswith('-') and '=' not in arg]
706    models = "\n    ".join("%-15s"%v for v in MODELS)
707    if len(args) == 0:
708        print(USAGE)
709        print("\nAvailable models:")
710        print(columnize(MODELS, indent="  "))
711        sys.exit(1)
712    if len(args) > 3:
713        print("expected parameters: model N1 N2")
714
715    name = args[0]
716    try:
717        model_info = core.load_model_info(name)
718    except ImportError as exc:
719        print(str(exc))
720        print("Could not find model; use one of:\n    " + models)
721        sys.exit(1)
722
723    invalid = [o[1:] for o in flags
724               if o[1:] not in NAME_OPTIONS
725               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
726    if invalid:
727        print("Invalid options: %s"%(", ".join(invalid)))
728        sys.exit(1)
729
730
731    # pylint: disable=bad-whitespace
732    # Interpret the flags
733    opts = {
734        'plot'      : True,
735        'view'      : 'log',
736        'is2d'      : False,
737        'qmax'      : 0.05,
738        'nq'        : 128,
739        'res'       : 0.0,
740        'accuracy'  : 'Low',
741        'cutoff'    : 0.0,
742        'seed'      : -1,  # default to preset
743        'mono'      : False,
744        'show_pars' : False,
745        'show_hist' : False,
746        'rel_err'   : True,
747        'explore'   : False,
748        'use_demo'  : True,
749        'zero'      : False,
750    }
751    engines = []
752    for arg in flags:
753        if arg == '-noplot':    opts['plot'] = False
754        elif arg == '-plot':    opts['plot'] = True
755        elif arg == '-linear':  opts['view'] = 'linear'
756        elif arg == '-log':     opts['view'] = 'log'
757        elif arg == '-q4':      opts['view'] = 'q4'
758        elif arg == '-1d':      opts['is2d'] = False
759        elif arg == '-2d':      opts['is2d'] = True
760        elif arg == '-exq':     opts['qmax'] = 10.0
761        elif arg == '-highq':   opts['qmax'] = 1.0
762        elif arg == '-midq':    opts['qmax'] = 0.2
763        elif arg == '-lowq':    opts['qmax'] = 0.05
764        elif arg == '-zero':    opts['zero'] = True
765        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
766        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
767        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
768        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
769        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
770        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
771        elif arg == '-preset':  opts['seed'] = -1
772        elif arg == '-mono':    opts['mono'] = True
773        elif arg == '-poly':    opts['mono'] = False
774        elif arg == '-pars':    opts['show_pars'] = True
775        elif arg == '-nopars':  opts['show_pars'] = False
776        elif arg == '-hist':    opts['show_hist'] = True
777        elif arg == '-nohist':  opts['show_hist'] = False
778        elif arg == '-rel':     opts['rel_err'] = True
779        elif arg == '-abs':     opts['rel_err'] = False
780        elif arg == '-half':    engines.append(arg[1:])
781        elif arg == '-fast':    engines.append(arg[1:])
782        elif arg == '-single':  engines.append(arg[1:])
783        elif arg == '-double':  engines.append(arg[1:])
784        elif arg == '-single!': engines.append(arg[1:])
785        elif arg == '-double!': engines.append(arg[1:])
786        elif arg == '-quad!':   engines.append(arg[1:])
787        elif arg == '-sasview': engines.append(arg[1:])
788        elif arg == '-edit':    opts['explore'] = True
789        elif arg == '-demo':    opts['use_demo'] = True
790        elif arg == '-default':    opts['use_demo'] = False
791    # pylint: enable=bad-whitespace
792
793    if len(engines) == 0:
794        engines.extend(['single', 'sasview'])
795    elif len(engines) == 1:
796        if engines[0][0] != 'sasview':
797            engines.append('sasview')
798        else:
799            engines.append('single')
800    elif len(engines) > 2:
801        del engines[2:]
802
803    n1 = int(args[1]) if len(args) > 1 else 1
804    n2 = int(args[2]) if len(args) > 2 else 1
805    use_sasview = any(engine=='sasview' and count>0
806                      for engine, count in zip(engines, [n1, n2]))
807
808    # Get demo parameters from model definition, or use default parameters
809    # if model does not define demo parameters
810    pars = get_pars(model_info, opts['use_demo'])
811
812
813    # Fill in parameters given on the command line
814    presets = {}
815    for arg in values:
816        k, v = arg.split('=', 1)
817        if k not in pars:
818            # extract base name without polydispersity info
819            s = set(p.split('_pd')[0] for p in pars)
820            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
821            sys.exit(1)
822        presets[k] = float(v) if not k.endswith('type') else v
823
824    # randomize parameters
825    #pars.update(set_pars)  # set value before random to control range
826    if opts['seed'] > -1:
827        pars = randomize_pars(pars, seed=opts['seed'])
828        print("Randomize using -random=%i"%opts['seed'])
829    if opts['mono']:
830        pars = suppress_pd(pars)
831    pars.update(presets)  # set value after random to control value
832    #import pprint; pprint.pprint(model_info)
833    constrain_pars(model_info, pars)
834    if use_sasview:
835        constrain_new_to_old(model_info, pars)
836    if opts['show_pars']:
837        print(str(parlist(model_info, pars, opts['is2d'])))
838
839    # Create the computational engines
840    data, _ = make_data(opts)
841    if n1:
842        base = make_engine(model_info, data, engines[0], opts['cutoff'])
843    else:
844        base = None
845    if n2:
846        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
847    else:
848        comp = None
849
850    # pylint: disable=bad-whitespace
851    # Remember it all
852    opts.update({
853        'name'      : name,
854        'def'       : model_info,
855        'n1'        : n1,
856        'n2'        : n2,
857        'presets'   : presets,
858        'pars'      : pars,
859        'data'      : data,
860        'engines'   : [base, comp],
861    })
862    # pylint: enable=bad-whitespace
863
864    return opts
865
866def explore(opts):
867    """
868    Explore the model using the Bumps GUI.
869    """
870    import wx
871    from bumps.names import FitProblem
872    from bumps.gui.app_frame import AppFrame
873
874    problem = FitProblem(Explore(opts))
875    is_mac = "cocoa" in wx.version()
876    app = wx.App()
877    frame = AppFrame(parent=None, title="explore")
878    if not is_mac: frame.Show()
879    frame.panel.set_model(model=problem)
880    frame.panel.Layout()
881    frame.panel.aui.Split(0, wx.TOP)
882    if is_mac: frame.Show()
883    app.MainLoop()
884
885class Explore(object):
886    """
887    Bumps wrapper for a SAS model comparison.
888
889    The resulting object can be used as a Bumps fit problem so that
890    parameters can be adjusted in the GUI, with plots updated on the fly.
891    """
892    def __init__(self, opts):
893        from bumps.cli import config_matplotlib
894        from . import bumps_model
895        config_matplotlib()
896        self.opts = opts
897        model_info = opts['def']
898        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
899        if not opts['is2d']:
900            active = [base + ext
901                      for base in model_info['partype']['pd-1d']
902                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
903            active.extend(model_info['partype']['fixed-1d'])
904            for k in active:
905                v = pars[k]
906                v.range(*parameter_range(k, v.value))
907        else:
908            for k, v in pars.items():
909                v.range(*parameter_range(k, v.value))
910
911        self.pars = pars
912        self.pd_types = pd_types
913        self.limits = None
914
915    def numpoints(self):
916        """
917        Return the number of points.
918        """
919        return len(self.pars) + 1  # so dof is 1
920
921    def parameters(self):
922        """
923        Return a dictionary of parameters.
924        """
925        return self.pars
926
927    def nllf(self):
928        """
929        Return cost.
930        """
931        # pylint: disable=no-self-use
932        return 0.  # No nllf
933
934    def plot(self, view='log'):
935        """
936        Plot the data and residuals.
937        """
938        pars = dict((k, v.value) for k, v in self.pars.items())
939        pars.update(self.pd_types)
940        self.opts['pars'] = pars
941        limits = compare(self.opts, limits=self.limits)
942        if self.limits is None:
943            vmin, vmax = limits
944            vmax = 1.3*vmax
945            vmin = vmax*1e-7
946            self.limits = vmin, vmax
947
948
949def main():
950    """
951    Main program.
952    """
953    opts = parse_opts()
954    if opts['explore']:
955        explore(opts)
956    else:
957        compare(opts)
958
959if __name__ == "__main__":
960    main()
Note: See TracBrowser for help on using the repository browser.