source: sasmodels/sasmodels/compare.py @ 6869ceb

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 6869ceb was 6869ceb, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

code cleanup

  • Property mode set to 100755
File size: 29.6 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import generate
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_model, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71
72Any two calculation engines can be selected for comparison:
73
74    -single/-double/-half/-fast sets an OpenCL calculation engine
75    -single!/-double!/-quad! sets an OpenMP calculation engine
76    -sasview sets the sasview calculation engine
77
78The default is -single -sasview.  Note that the interpretation of quad
79precision depends on architecture, and may vary from 64-bit to 128-bit,
80with 80-bit floats being common (1e-19 precision).
81
82Key=value pairs allow you to set specific values for the model parameters.
83"""
84
85# Update docs with command line usage string.   This is separate from the usual
86# doc string so that we can display it at run time if there is an error.
87# lin
88__doc__ = (__doc__  # pylint: disable=redefined-builtin
89           + """
90Program description
91-------------------
92
93"""
94           + USAGE)
95
96kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
97
98MODELS = core.list_models()
99
100# CRUFT python 2.6
101if not hasattr(datetime.timedelta, 'total_seconds'):
102    def delay(dt):
103        """Return number date-time delta as number seconds"""
104        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
105else:
106    def delay(dt):
107        """Return number date-time delta as number seconds"""
108        return dt.total_seconds()
109
110
111class push_seed(object):
112    """
113    Set the seed value for the random number generator.
114
115    When used in a with statement, the random number generator state is
116    restored after the with statement is complete.
117
118    :Parameters:
119
120    *seed* : int or array_like, optional
121        Seed for RandomState
122
123    :Example:
124
125    Seed can be used directly to set the seed::
126
127        >>> from numpy.random import randint
128        >>> push_seed(24)
129        <...push_seed object at...>
130        >>> print(randint(0,1000000,3))
131        [242082    899 211136]
132
133    Seed can also be used in a with statement, which sets the random
134    number generator state for the enclosed computations and restores
135    it to the previous state on completion::
136
137        >>> with push_seed(24):
138        ...    print(randint(0,1000000,3))
139        [242082    899 211136]
140
141    Using nested contexts, we can demonstrate that state is indeed
142    restored after the block completes::
143
144        >>> with push_seed(24):
145        ...    print(randint(0,1000000))
146        ...    with push_seed(24):
147        ...        print(randint(0,1000000,3))
148        ...    print(randint(0,1000000))
149        242082
150        [242082    899 211136]
151        899
152
153    The restore step is protected against exceptions in the block::
154
155        >>> with push_seed(24):
156        ...    print(randint(0,1000000))
157        ...    try:
158        ...        with push_seed(24):
159        ...            print(randint(0,1000000,3))
160        ...            raise Exception()
161        ...    except:
162        ...        print("Exception raised")
163        ...    print(randint(0,1000000))
164        242082
165        [242082    899 211136]
166        Exception raised
167        899
168    """
169    def __init__(self, seed=None):
170        self._state = np.random.get_state()
171        np.random.seed(seed)
172
173    def __enter__(self):
174        return None
175
176    def __exit__(self, *args):
177        np.random.set_state(self._state)
178
179def tic():
180    """
181    Timer function.
182
183    Use "toc=tic()" to start the clock and "toc()" to measure
184    a time interval.
185    """
186    then = datetime.datetime.now()
187    return lambda: delay(datetime.datetime.now() - then)
188
189
190def set_beam_stop(data, radius, outer=None):
191    """
192    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
193
194    Note: this function does not use the sasview package
195    """
196    if hasattr(data, 'qx_data'):
197        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
198        data.mask = (q < radius)
199        if outer is not None:
200            data.mask |= (q >= outer)
201    else:
202        data.mask = (data.x < radius)
203        if outer is not None:
204            data.mask |= (data.x >= outer)
205
206
207def parameter_range(p, v):
208    """
209    Choose a parameter range based on parameter name and initial value.
210    """
211    if p.endswith('_pd_n'):
212        return [0, 100]
213    elif p.endswith('_pd_nsigma'):
214        return [0, 5]
215    elif p.endswith('_pd_type'):
216        return v
217    elif any(s in p for s in ('theta', 'phi', 'psi')):
218        # orientation in [-180,180], orientation pd in [0,45]
219        if p.endswith('_pd'):
220            return [0, 45]
221        else:
222            return [-180, 180]
223    elif 'sld' in p:
224        return [-0.5, 10]
225    elif p.endswith('_pd'):
226        return [0, 1]
227    elif p == 'background':
228        return [0, 10]
229    elif p == 'scale':
230        return [0, 1e3]
231    elif p == 'case_num':
232        # RPA hack
233        return [0, 10]
234    elif v < 0:
235        # Kxy parameters in rpa model can be negative
236        return [2*v, -2*v]
237    else:
238        return [0, (2*v if v > 0 else 1)]
239
240
241def _randomize_one(p, v):
242    """
243    Randomize a single parameter.
244    """
245    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
246        return v
247    else:
248        return np.random.uniform(*parameter_range(p, v))
249
250
251def randomize_pars(pars, seed=None):
252    """
253    Generate random values for all of the parameters.
254
255    Valid ranges for the random number generator are guessed from the name of
256    the parameter; this will not account for constraints such as cap radius
257    greater than cylinder radius in the capped_cylinder model, so
258    :func:`constrain_pars` needs to be called afterward..
259    """
260    with push_seed(seed):
261        # Note: the sort guarantees order `of calls to random number generator
262        pars = dict((p, _randomize_one(p, v))
263                    for p, v in sorted(pars.items()))
264    return pars
265
266def constrain_pars(model_definition, pars):
267    """
268    Restrict parameters to valid values.
269
270    This includes model specific code for models such as capped_cylinder
271    which need to support within model constraints (cap radius more than
272    cylinder radius in this case).
273    """
274    name = model_definition.name
275    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
276        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
277    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
278        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
279
280    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
281    if name == 'guinier':
282        #q_max = 0.2  # mid q maximum
283        q_max = 1.0  # high q maximum
284        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
285        pars['rg'] = min(pars['rg'], rg_max)
286
287    if name == 'rpa':
288        # Make sure phi sums to 1.0
289        if pars['case_num'] < 2:
290            pars['Phia'] = 0.
291            pars['Phib'] = 0.
292        elif pars['case_num'] < 5:
293            pars['Phia'] = 0.
294        total = sum(pars['Phi'+c] for c in 'abcd')
295        for c in 'abcd':
296            pars['Phi'+c] /= total
297
298def parlist(pars):
299    """
300    Format the parameter list for printing.
301    """
302    active = None
303    fields = {}
304    lines = []
305    for k, v in sorted(pars.items()):
306        parts = k.split('_pd')
307        #print(k, active, parts)
308        if len(parts) == 1:
309            if active: lines.append(_format_par(active, **fields))
310            active = k
311            fields = {'value': v}
312        else:
313            assert parts[0] == active
314            if parts[1]:
315                fields[parts[1][1:]] = v
316            else:
317                fields['pd'] = v
318    if active: lines.append(_format_par(active, **fields))
319    return "\n".join(lines)
320
321    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
322
323def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
324    line = "%s: %g"%(name, value)
325    if pd != 0.  and n != 0:
326        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
327                % (pd, n, nsigma, nsigma, type)
328    return line
329
330def suppress_pd(pars):
331    """
332    Suppress theta_pd for now until the normalization is resolved.
333
334    May also suppress complete polydispersity of the model to test
335    models more quickly.
336    """
337    pars = pars.copy()
338    for p in pars:
339        if p.endswith("_pd_n"): pars[p] = 0
340    return pars
341
342def eval_sasview(model_definition, data):
343    """
344    Return a model calculator using the SasView fitting engine.
345    """
346    # importing sas here so that the error message will be that sas failed to
347    # import rather than the more obscure smear_selection not imported error
348    import sas
349    from sas.models.qsmearing import smear_selection
350
351    # convert model parameters from sasmodel form to sasview form
352    #print("old",sorted(pars.items()))
353    modelname, _ = revert_model(model_definition, {})
354    #print("new",sorted(_pars.items()))
355    sas = __import__('sas.models.'+modelname)
356    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None)
357    if ModelClass is None:
358        raise ValueError("could not find model %r in sas.models"%modelname)
359    model = ModelClass()
360    smearer = smear_selection(data, model=model)
361
362    if hasattr(data, 'qx_data'):
363        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
364        index = ((~data.mask) & (~np.isnan(data.data))
365                 & (q >= data.qmin) & (q <= data.qmax))
366        if smearer is not None:
367            smearer.model = model  # because smear_selection has a bug
368            smearer.accuracy = data.accuracy
369            smearer.set_index(index)
370            theory = lambda: smearer.get_value()
371        else:
372            theory = lambda: model.evalDistribution([data.qx_data[index],
373                                                     data.qy_data[index]])
374    elif smearer is not None:
375        theory = lambda: smearer(model.evalDistribution(data.x))
376    else:
377        theory = lambda: model.evalDistribution(data.x)
378
379    def calculator(**pars):
380        """
381        Sasview calculator for model.
382        """
383        # paying for parameter conversion each time to keep life simple, if not fast
384        _, pars = revert_model(model_definition, pars)
385        for k, v in pars.items():
386            parts = k.split('.')  # polydispersity components
387            if len(parts) == 2:
388                model.dispersion[parts[0]][parts[1]] = v
389            else:
390                model.setParam(k, v)
391        return theory()
392
393    calculator.engine = "sasview"
394    return calculator
395
396DTYPE_MAP = {
397    'half': '16',
398    'fast': 'fast',
399    'single': '32',
400    'double': '64',
401    'quad': '128',
402    'f16': '16',
403    'f32': '32',
404    'f64': '64',
405    'longdouble': '128',
406}
407def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
408    """
409    Return a model calculator using the OpenCL calculation engine.
410    """
411    try:
412        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
413    except Exception as exc:
414        print(exc)
415        print("... trying again with single precision")
416        dtype = 'single'
417        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
418    calculator = DirectModel(data, model, cutoff=cutoff)
419    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
420    return calculator
421
422def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
423    """
424    Return a model calculator using the DLL calculation engine.
425    """
426    if dtype == 'quad':
427        dtype = 'longdouble'
428    model = core.load_model(model_definition, dtype=dtype, platform="dll")
429    calculator = DirectModel(data, model, cutoff=cutoff)
430    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
431    return calculator
432
433def time_calculation(calculator, pars, Nevals=1):
434    """
435    Compute the average calculation time over N evaluations.
436
437    An additional call is generated without polydispersity in order to
438    initialize the calculation engine, and make the average more stable.
439    """
440    # initialize the code so time is more accurate
441    value = calculator(**suppress_pd(pars))
442    toc = tic()
443    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
444        value = calculator(**pars)
445    average_time = toc()*1000./Nevals
446    return value, average_time
447
448def make_data(opts):
449    """
450    Generate an empty dataset, used with the model to set Q points
451    and resolution.
452
453    *opts* contains the options, with 'qmax', 'nq', 'res',
454    'accuracy', 'is2d' and 'view' parsed from the command line.
455    """
456    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
457    if opts['is2d']:
458        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
459        data.accuracy = opts['accuracy']
460        set_beam_stop(data, 0.004)
461        index = ~data.mask
462    else:
463        if opts['view'] == 'log':
464            qmax = math.log10(qmax)
465            q = np.logspace(qmax-3, qmax, nq)
466        else:
467            q = np.linspace(0.001*qmax, qmax, nq)
468        data = empty_data1D(q, resolution=res)
469        index = slice(None, None)
470    return data, index
471
472def make_engine(model_definition, data, dtype, cutoff):
473    """
474    Generate the appropriate calculation engine for the given datatype.
475
476    Datatypes with '!' appended are evaluated using external C DLLs rather
477    than OpenCL.
478    """
479    if dtype == 'sasview':
480        return eval_sasview(model_definition, data)
481    elif dtype.endswith('!'):
482        return eval_ctypes(model_definition, data, dtype=dtype[:-1],
483                           cutoff=cutoff)
484    else:
485        return eval_opencl(model_definition, data, dtype=dtype,
486                           cutoff=cutoff)
487
488def compare(opts, limits=None):
489    """
490    Preform a comparison using options from the command line.
491
492    *limits* are the limits on the values to use, either to set the y-axis
493    for 1D or to set the colormap scale for 2D.  If None, then they are
494    inferred from the data and returned. When exploring using Bumps,
495    the limits are set when the model is initially called, and maintained
496    as the values are adjusted, making it easier to see the effects of the
497    parameters.
498    """
499    Nbase, Ncomp = opts['n1'], opts['n2']
500    pars = opts['pars']
501    data = opts['data']
502
503    # Base calculation
504    if Nbase > 0:
505        base = opts['engines'][0]
506        try:
507            base_value, base_time = time_calculation(base, pars, Nbase)
508            print("%s t=%.1f ms, intensity=%.0f"
509                  % (base.engine, base_time, sum(base_value)))
510        except ImportError:
511            traceback.print_exc()
512            Nbase = 0
513
514    # Comparison calculation
515    if Ncomp > 0:
516        comp = opts['engines'][1]
517        try:
518            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
519            print("%s t=%.1f ms, intensity=%.0f"
520                  % (comp.engine, comp_time, sum(comp_value)))
521        except ImportError:
522            traceback.print_exc()
523            Ncomp = 0
524
525    # Compare, but only if computing both forms
526    if Nbase > 0 and Ncomp > 0:
527        resid = (base_value - comp_value)
528        relerr = resid/comp_value
529        _print_stats("|%s-%s|"
530                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
531                     resid)
532        _print_stats("|(%s-%s)/%s|"
533                     % (base.engine, comp.engine, comp.engine),
534                     relerr)
535
536    # Plot if requested
537    if not opts['plot'] and not opts['explore']: return
538    view = opts['view']
539    import matplotlib.pyplot as plt
540    if limits is None:
541        vmin, vmax = np.Inf, -np.Inf
542        if Nbase > 0:
543            vmin = min(vmin, min(base_value))
544            vmax = max(vmax, max(base_value))
545        if Ncomp > 0:
546            vmin = min(vmin, min(comp_value))
547            vmax = max(vmax, max(comp_value))
548        limits = vmin, vmax
549
550    if Nbase > 0:
551        if Ncomp > 0: plt.subplot(131)
552        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
553        plt.title("%s t=%.1f ms"%(base.engine, base_time))
554        #cbar_title = "log I"
555    if Ncomp > 0:
556        if Nbase > 0: plt.subplot(132)
557        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
558        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
559        #cbar_title = "log I"
560    if Ncomp > 0 and Nbase > 0:
561        plt.subplot(133)
562        if not opts['rel_err']:
563            err, errstr, errview = resid, "abs err", "linear"
564        else:
565            err, errstr, errview = abs(relerr), "rel err", "log"
566        #err,errstr = base/comp,"ratio"
567        plot_theory(data, None, resid=err, view=errview, use_data=False)
568        if view == 'linear':
569            plt.xscale('linear')
570        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
571        #cbar_title = errstr if errview=="linear" else "log "+errstr
572    #if is2D:
573    #    h = plt.colorbar()
574    #    h.ax.set_title(cbar_title)
575
576    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
577        plt.figure()
578        v = relerr
579        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
580        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
581        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
582                   % (base.engine, comp.engine, comp.engine))
583        plt.ylabel('P(err)')
584        plt.title('Distribution of relative error between calculation engines')
585
586    if not opts['explore']:
587        plt.show()
588
589    return limits
590
591def _print_stats(label, err):
592    sorted_err = np.sort(abs(err))
593    p50 = int((len(err)-1)*0.50)
594    p98 = int((len(err)-1)*0.98)
595    data = [
596        "max:%.3e"%sorted_err[-1],
597        "median:%.3e"%sorted_err[p50],
598        "98%%:%.3e"%sorted_err[p98],
599        "rms:%.3e"%np.sqrt(np.mean(err**2)),
600        "zero-offset:%+.3e"%np.mean(err),
601        ]
602    print(label+"  "+"  ".join(data))
603
604
605
606# ===========================================================================
607#
608NAME_OPTIONS = set([
609    'plot', 'noplot',
610    'half', 'fast', 'single', 'double',
611    'single!', 'double!', 'quad!', 'sasview',
612    'lowq', 'midq', 'highq', 'exq',
613    '2d', '1d',
614    'preset', 'random',
615    'poly', 'mono',
616    'nopars', 'pars',
617    'rel', 'abs',
618    'linear', 'log', 'q4',
619    'hist', 'nohist',
620    'edit',
621    ])
622VALUE_OPTIONS = [
623    # Note: random is both a name option and a value option
624    'cutoff', 'random', 'nq', 'res', 'accuracy',
625    ]
626
627def columnize(L, indent="", width=79):
628    """
629    Format a list of strings into columns.
630
631    Returns a string with carriage returns ready for printing.
632    """
633    column_width = max(len(w) for w in L) + 1
634    num_columns = (width - len(indent)) // column_width
635    num_rows = len(L) // num_columns
636    L = L + [""] * (num_rows*num_columns - len(L))
637    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
638    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
639             for row in zip(*columns)]
640    output = indent + ("\n"+indent).join(lines)
641    return output
642
643
644def get_demo_pars(model_definition):
645    """
646    Extract demo parameters from the model definition.
647    """
648    info = generate.make_info(model_definition)
649    # Get the default values for the parameters
650    pars = dict((p[0], p[2]) for p in info['parameters'])
651
652    # Fill in default values for the polydispersity parameters
653    for p in info['parameters']:
654        if p[4] in ('volume', 'orientation'):
655            pars[p[0]+'_pd'] = 0.0
656            pars[p[0]+'_pd_n'] = 0
657            pars[p[0]+'_pd_nsigma'] = 3.0
658            pars[p[0]+'_pd_type'] = "gaussian"
659
660    # Plug in values given in demo
661    pars.update(info['demo'])
662    return pars
663
664def parse_opts():
665    """
666    Parse command line options.
667    """
668    MODELS = core.list_models()
669    flags = [arg for arg in sys.argv[1:]
670             if arg.startswith('-')]
671    values = [arg for arg in sys.argv[1:]
672              if not arg.startswith('-') and '=' in arg]
673    args = [arg for arg in sys.argv[1:]
674            if not arg.startswith('-') and '=' not in arg]
675    models = "\n    ".join("%-15s"%v for v in MODELS)
676    if len(args) == 0:
677        print(USAGE)
678        print("\nAvailable models:")
679        print(columnize(MODELS, indent="  "))
680        sys.exit(1)
681
682    name = args[0]
683    try:
684        model_definition = core.load_model_definition(name)
685    except ImportError, exc:
686        print(str(exc))
687        print("Use one of:\n    " + models)
688        sys.exit(1)
689    if len(args) > 3:
690        print("expected parameters: model N1 N2")
691
692    invalid = [o[1:] for o in flags
693               if o[1:] not in NAME_OPTIONS
694               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
695    if invalid:
696        print("Invalid options: %s"%(", ".join(invalid)))
697        sys.exit(1)
698
699
700    # pylint: disable=bad-whitespace
701    # Interpret the flags
702    opts = {
703        'plot'      : True,
704        'view'      : 'log',
705        'is2d'      : False,
706        'qmax'      : 0.05,
707        'nq'        : 128,
708        'res'       : 0.0,
709        'accuracy'  : 'Low',
710        'cutoff'    : 1e-5,
711        'seed'      : -1,  # default to preset
712        'mono'      : False,
713        'show_pars' : False,
714        'show_hist' : False,
715        'rel_err'   : True,
716        'explore'   : False,
717    }
718    engines = []
719    for arg in flags:
720        if arg == '-noplot':    opts['plot'] = False
721        elif arg == '-plot':    opts['plot'] = True
722        elif arg == '-linear':  opts['view'] = 'linear'
723        elif arg == '-log':     opts['view'] = 'log'
724        elif arg == '-q4':      opts['view'] = 'q4'
725        elif arg == '-1d':      opts['is2d'] = False
726        elif arg == '-2d':      opts['is2d'] = True
727        elif arg == '-exq':     opts['qmax'] = 10.0
728        elif arg == '-highq':   opts['qmax'] = 1.0
729        elif arg == '-midq':    opts['qmax'] = 0.2
730        elif arg == '-lowq':    opts['qmax'] = 0.05
731        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
732        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
733        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
734        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
735        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
736        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
737        elif arg == '-preset':  opts['seed'] = -1
738        elif arg == '-mono':    opts['mono'] = True
739        elif arg == '-poly':    opts['mono'] = False
740        elif arg == '-pars':    opts['show_pars'] = True
741        elif arg == '-nopars':  opts['show_pars'] = False
742        elif arg == '-hist':    opts['show_hist'] = True
743        elif arg == '-nohist':  opts['show_hist'] = False
744        elif arg == '-rel':     opts['rel_err'] = True
745        elif arg == '-abs':     opts['rel_err'] = False
746        elif arg == '-half':    engines.append(arg[1:])
747        elif arg == '-fast':    engines.append(arg[1:])
748        elif arg == '-single':  engines.append(arg[1:])
749        elif arg == '-double':  engines.append(arg[1:])
750        elif arg == '-single!': engines.append(arg[1:])
751        elif arg == '-double!': engines.append(arg[1:])
752        elif arg == '-quad!':   engines.append(arg[1:])
753        elif arg == '-sasview': engines.append(arg[1:])
754        elif arg == '-edit':    opts['explore'] = True
755    # pylint: enable=bad-whitespace
756
757    if len(engines) == 0:
758        engines.extend(['single', 'sasview'])
759    elif len(engines) == 1:
760        if engines[0][0] != 'sasview':
761            engines.append('sasview')
762        else:
763            engines.append('single')
764    elif len(engines) > 2:
765        del engines[2:]
766
767    n1 = int(args[1]) if len(args) > 1 else 1
768    n2 = int(args[2]) if len(args) > 2 else 1
769
770    # Get demo parameters from model definition, or use default parameters
771    # if model does not define demo parameters
772    pars = get_demo_pars(model_definition)
773
774    # Fill in parameters given on the command line
775    presets = {}
776    for arg in values:
777        k, v = arg.split('=', 1)
778        if k not in pars:
779            # extract base name without polydispersity info
780            s = set(p.split('_pd')[0] for p in pars)
781            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
782            sys.exit(1)
783        presets[k] = float(v) if not k.endswith('type') else v
784
785    # randomize parameters
786    #pars.update(set_pars)  # set value before random to control range
787    if opts['seed'] > -1:
788        pars = randomize_pars(pars, seed=opts['seed'])
789        print("Randomize using -random=%i"%opts['seed'])
790    if opts['mono']:
791        pars = suppress_pd(pars)
792    pars.update(presets)  # set value after random to control value
793    constrain_pars(model_definition, pars)
794    constrain_new_to_old(model_definition, pars)
795    if opts['show_pars']:
796        print(str(parlist(pars)))
797
798    # Create the computational engines
799    data, _ = make_data(opts)
800    if n1:
801        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
802    else:
803        base = None
804    if n2:
805        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
806    else:
807        comp = None
808
809    # pylint: disable=bad-whitespace
810    # Remember it all
811    opts.update({
812        'name'      : name,
813        'def'       : model_definition,
814        'n1'        : n1,
815        'n2'        : n2,
816        'presets'   : presets,
817        'pars'      : pars,
818        'data'      : data,
819        'engines'   : [base, comp],
820    })
821    # pylint: enable=bad-whitespace
822
823    return opts
824
825def explore(opts):
826    """
827    Explore the model using the Bumps GUI.
828    """
829    import wx
830    from bumps.names import FitProblem
831    from bumps.gui.app_frame import AppFrame
832
833    problem = FitProblem(Explore(opts))
834    is_mac = "cocoa" in wx.version()
835    app = wx.App()
836    frame = AppFrame(parent=None, title="explore")
837    if not is_mac: frame.Show()
838    frame.panel.set_model(model=problem)
839    frame.panel.Layout()
840    frame.panel.aui.Split(0, wx.TOP)
841    if is_mac: frame.Show()
842    app.MainLoop()
843
844class Explore(object):
845    """
846    Bumps wrapper for a SAS model comparison.
847
848    The resulting object can be used as a Bumps fit problem so that
849    parameters can be adjusted in the GUI, with plots updated on the fly.
850    """
851    def __init__(self, opts):
852        from bumps.cli import config_matplotlib
853        from . import bumps_model
854        config_matplotlib()
855        self.opts = opts
856        info = generate.make_info(opts['def'])
857        pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])
858        if not opts['is2d']:
859            active = [base + ext
860                      for base in info['partype']['pd-1d']
861                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
862            active.extend(info['partype']['fixed-1d'])
863            for k in active:
864                v = pars[k]
865                v.range(*parameter_range(k, v.value))
866        else:
867            for k, v in pars.items():
868                v.range(*parameter_range(k, v.value))
869
870        self.pars = pars
871        self.pd_types = pd_types
872        self.limits = None
873
874    def numpoints(self):
875        """
876        Return the number of points.
877        """
878        return len(self.pars) + 1  # so dof is 1
879
880    def parameters(self):
881        """
882        Return a dictionary of parameters.
883        """
884        return self.pars
885
886    def nllf(self):
887        """
888        Return cost.
889        """
890        # pylint: disable=no-self-use
891        return 0.  # No nllf
892
893    def plot(self, view='log'):
894        """
895        Plot the data and residuals.
896        """
897        pars = dict((k, v.value) for k, v in self.pars.items())
898        pars.update(self.pd_types)
899        self.opts['pars'] = pars
900        limits = compare(self.opts, limits=self.limits)
901        if self.limits is None:
902            vmin, vmax = limits
903            vmax = 1.3*vmax
904            vmin = vmax*1e-7
905            self.limits = vmin, vmax
906
907
908def main():
909    """
910    Main program.
911    """
912    opts = parse_opts()
913    if opts['explore']:
914        explore(opts)
915    else:
916        compare(opts)
917
918if __name__ == "__main__":
919    main()
Note: See TracBrowser for help on using the repository browser.