source: sasmodels/sasmodels/compare.py @ 1d4017a

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 1d4017a was 1d4017a, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

add command for list all parameters in all models

  • Property mode set to 100755
File size: 29.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31USAGE = """
32usage: compare.py model N1 N2 [options...] [key=val]
33
34Compare the speed and value for a model between the SasView original and the
35sasmodels rewrite.
36
37model is the name of the model to compare (see below).
38N1 is the number of times to run sasmodels (default=1).
39N2 is the number times to run sasview (default=1).
40
41Options (* for default):
42
43    -plot*/-noplot plots or suppress the plot of the model
44    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
45    -nq=128 sets the number of Q points in the data set
46    -1d*/-2d computes 1d or 2d data
47    -preset*/-random[=seed] preset or random parameters
48    -mono/-poly* force monodisperse/polydisperse
49    -cutoff=1e-5* cutoff value for including a point in polydispersity
50    -pars/-nopars* prints the parameter set or not
51    -abs/-rel* plot relative or absolute error
52    -linear/-log*/-q4 intensity scaling
53    -hist/-nohist* plot histogram of relative error
54    -res=0 sets the resolution width dQ/Q if calculating with resolution
55    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
56    -edit starts the parameter explorer
57
58Any two calculation engines can be selected for comparison:
59
60    -single/-double/-half/-fast sets an OpenCL calculation engine
61    -single!/-double!/-quad! sets an OpenMP calculation engine
62    -sasview sets the sasview calculation engine
63
64The default is -single -sasview.  Note that the interpretation of quad
65precision depends on architecture, and may vary from 64-bit to 128-bit,
66with 80-bit floats being common (1e-19 precision).
67
68Key=value pairs allow you to set specific values for the model parameters.
69"""
70
71# Update docs with command line usage string.   This is separate from the usual
72# doc string so that we can display it at run time if there is an error.
73# lin
74__doc__ = (__doc__  # pylint: disable=redefined-builtin
75           + """
76Program description
77-------------------
78
79"""
80           + USAGE)
81
82
83
84import sys
85import math
86from os.path import basename, dirname, join as joinpath
87import glob
88import datetime
89import traceback
90
91import numpy as np
92
93ROOT = dirname(__file__)
94sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
95
96
97from . import core
98from . import kerneldll
99from . import generate
100from .data import plot_theory, empty_data1D, empty_data2D
101from .direct_model import DirectModel
102from .convert import revert_model, constrain_new_to_old
103kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
104
105# List of available models
106MODELS = [basename(f)[:-3]
107          for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))]
108
109# CRUFT python 2.6
110if not hasattr(datetime.timedelta, 'total_seconds'):
111    def delay(dt):
112        """Return number date-time delta as number seconds"""
113        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
114else:
115    def delay(dt):
116        """Return number date-time delta as number seconds"""
117        return dt.total_seconds()
118
119
120class push_seed(object):
121    """
122    Set the seed value for the random number generator.
123
124    When used in a with statement, the random number generator state is
125    restored after the with statement is complete.
126
127    :Parameters:
128
129    *seed* : int or array_like, optional
130        Seed for RandomState
131
132    :Example:
133
134    Seed can be used directly to set the seed::
135
136        >>> from numpy.random import randint
137        >>> push_seed(24)
138        <...push_seed object at...>
139        >>> print(randint(0,1000000,3))
140        [242082    899 211136]
141
142    Seed can also be used in a with statement, which sets the random
143    number generator state for the enclosed computations and restores
144    it to the previous state on completion::
145
146        >>> with push_seed(24):
147        ...    print(randint(0,1000000,3))
148        [242082    899 211136]
149
150    Using nested contexts, we can demonstrate that state is indeed
151    restored after the block completes::
152
153        >>> with push_seed(24):
154        ...    print(randint(0,1000000))
155        ...    with push_seed(24):
156        ...        print(randint(0,1000000,3))
157        ...    print(randint(0,1000000))
158        242082
159        [242082    899 211136]
160        899
161
162    The restore step is protected against exceptions in the block::
163
164        >>> with push_seed(24):
165        ...    print(randint(0,1000000))
166        ...    try:
167        ...        with push_seed(24):
168        ...            print(randint(0,1000000,3))
169        ...            raise Exception()
170        ...    except:
171        ...        print("Exception raised")
172        ...    print(randint(0,1000000))
173        242082
174        [242082    899 211136]
175        Exception raised
176        899
177    """
178    def __init__(self, seed=None):
179        self._state = np.random.get_state()
180        np.random.seed(seed)
181
182    def __enter__(self):
183        return None
184
185    def __exit__(self, *args):
186        np.random.set_state(self._state)
187
188def tic():
189    """
190    Timer function.
191
192    Use "toc=tic()" to start the clock and "toc()" to measure
193    a time interval.
194    """
195    then = datetime.datetime.now()
196    return lambda: delay(datetime.datetime.now() - then)
197
198
199def set_beam_stop(data, radius, outer=None):
200    """
201    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
202
203    Note: this function does not use the sasview package
204    """
205    if hasattr(data, 'qx_data'):
206        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
207        data.mask = (q < radius)
208        if outer is not None:
209            data.mask |= (q >= outer)
210    else:
211        data.mask = (data.x < radius)
212        if outer is not None:
213            data.mask |= (data.x >= outer)
214
215
216def parameter_range(p, v):
217    """
218    Choose a parameter range based on parameter name and initial value.
219    """
220    if p.endswith('_pd_n'):
221        return [0, 100]
222    elif p.endswith('_pd_nsigma'):
223        return [0, 5]
224    elif p.endswith('_pd_type'):
225        return v
226    elif any(s in p for s in ('theta', 'phi', 'psi')):
227        # orientation in [-180,180], orientation pd in [0,45]
228        if p.endswith('_pd'):
229            return [0, 45]
230        else:
231            return [-180, 180]
232    elif 'sld' in p:
233        return [-0.5, 10]
234    elif p.endswith('_pd'):
235        return [0, 1]
236    elif p == 'background':
237        return [0, 10]
238    elif p == 'scale':
239        return [0, 1e3]
240    elif p == 'case_num':
241        # RPA hack
242        return [0, 10]
243    elif v < 0:
244        # Kxy parameters in rpa model can be negative
245        return [2*v, -2*v]
246    else:
247        return [0, (2*v if v > 0 else 1)]
248
249
250def _randomize_one(p, v):
251    """
252    Randomize a single parameter.
253    """
254    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
255        return v
256    else:
257        return np.random.uniform(*parameter_range(p, v))
258
259
260def randomize_pars(pars, seed=None):
261    """
262    Generate random values for all of the parameters.
263
264    Valid ranges for the random number generator are guessed from the name of
265    the parameter; this will not account for constraints such as cap radius
266    greater than cylinder radius in the capped_cylinder model, so
267    :func:`constrain_pars` needs to be called afterward..
268    """
269    with push_seed(seed):
270        # Note: the sort guarantees order `of calls to random number generator
271        pars = dict((p, _randomize_one(p, v))
272                    for p, v in sorted(pars.items()))
273    return pars
274
275def constrain_pars(model_definition, pars):
276    """
277    Restrict parameters to valid values.
278
279    This includes model specific code for models such as capped_cylinder
280    which need to support within model constraints (cap radius more than
281    cylinder radius in this case).
282    """
283    name = model_definition.name
284    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
285        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
286    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
287        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
288
289    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
290    if name == 'guinier':
291        #q_max = 0.2  # mid q maximum
292        q_max = 1.0  # high q maximum
293        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
294        pars['rg'] = min(pars['rg'], rg_max)
295
296    if name == 'rpa':
297        # Make sure phi sums to 1.0
298        if pars['case_num'] < 2:
299            pars['Phia'] = 0.
300            pars['Phib'] = 0.
301        elif pars['case_num'] < 5:
302            pars['Phia'] = 0.
303        total = sum(pars['Phi'+c] for c in 'abcd')
304        for c in 'abcd':
305            pars['Phi'+c] /= total
306
307def parlist(pars):
308    """
309    Format the parameter list for printing.
310    """
311    return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
312
313def suppress_pd(pars):
314    """
315    Suppress theta_pd for now until the normalization is resolved.
316
317    May also suppress complete polydispersity of the model to test
318    models more quickly.
319    """
320    pars = pars.copy()
321    for p in pars:
322        if p.endswith("_pd_n"): pars[p] = 0
323    return pars
324
325def eval_sasview(model_definition, data):
326    """
327    Return a model calculator using the SasView fitting engine.
328    """
329    # importing sas here so that the error message will be that sas failed to
330    # import rather than the more obscure smear_selection not imported error
331    import sas
332    from sas.models.qsmearing import smear_selection
333
334    # convert model parameters from sasmodel form to sasview form
335    #print("old",sorted(pars.items()))
336    modelname, _ = revert_model(model_definition, {})
337    #print("new",sorted(_pars.items()))
338    sas = __import__('sas.models.'+modelname)
339    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None)
340    if ModelClass is None:
341        raise ValueError("could not find model %r in sas.models"%modelname)
342    model = ModelClass()
343    smearer = smear_selection(data, model=model)
344
345    if hasattr(data, 'qx_data'):
346        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
347        index = ((~data.mask) & (~np.isnan(data.data))
348                 & (q >= data.qmin) & (q <= data.qmax))
349        if smearer is not None:
350            smearer.model = model  # because smear_selection has a bug
351            smearer.accuracy = data.accuracy
352            smearer.set_index(index)
353            theory = lambda: smearer.get_value()
354        else:
355            theory = lambda: model.evalDistribution([data.qx_data[index],
356                                                     data.qy_data[index]])
357    elif smearer is not None:
358        theory = lambda: smearer(model.evalDistribution(data.x))
359    else:
360        theory = lambda: model.evalDistribution(data.x)
361
362    def calculator(**pars):
363        """
364        Sasview calculator for model.
365        """
366        # paying for parameter conversion each time to keep life simple, if not fast
367        _, pars = revert_model(model_definition, pars)
368        for k, v in pars.items():
369            parts = k.split('.')  # polydispersity components
370            if len(parts) == 2:
371                model.dispersion[parts[0]][parts[1]] = v
372            else:
373                model.setParam(k, v)
374        return theory()
375
376    calculator.engine = "sasview"
377    return calculator
378
379DTYPE_MAP = {
380    'half': '16',
381    'fast': 'fast',
382    'single': '32',
383    'double': '64',
384    'quad': '128',
385    'f16': '16',
386    'f32': '32',
387    'f64': '64',
388    'longdouble': '128',
389}
390def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
391    """
392    Return a model calculator using the OpenCL calculation engine.
393    """
394    try:
395        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
396    except Exception as exc:
397        print(exc)
398        print("... trying again with single precision")
399        dtype = 'single'
400        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
401    calculator = DirectModel(data, model, cutoff=cutoff)
402    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
403    return calculator
404
405def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
406    """
407    Return a model calculator using the DLL calculation engine.
408    """
409    if dtype == 'quad':
410        dtype = 'longdouble'
411    model = core.load_model(model_definition, dtype=dtype, platform="dll")
412    calculator = DirectModel(data, model, cutoff=cutoff)
413    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
414    return calculator
415
416def time_calculation(calculator, pars, Nevals=1):
417    """
418    Compute the average calculation time over N evaluations.
419
420    An additional call is generated without polydispersity in order to
421    initialize the calculation engine, and make the average more stable.
422    """
423    # initialize the code so time is more accurate
424    value = calculator(**suppress_pd(pars))
425    toc = tic()
426    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
427        value = calculator(**pars)
428    average_time = toc()*1000./Nevals
429    return value, average_time
430
431def make_data(opts):
432    """
433    Generate an empty dataset, used with the model to set Q points
434    and resolution.
435
436    *opts* contains the options, with 'qmax', 'nq', 'res',
437    'accuracy', 'is2d' and 'view' parsed from the command line.
438    """
439    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
440    if opts['is2d']:
441        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
442        data.accuracy = opts['accuracy']
443        set_beam_stop(data, 0.004)
444        index = ~data.mask
445    else:
446        if opts['view'] == 'log':
447            qmax = math.log10(qmax)
448            q = np.logspace(qmax-3, qmax, nq)
449        else:
450            q = np.linspace(0.001*qmax, qmax, nq)
451        data = empty_data1D(q, resolution=res)
452        index = slice(None, None)
453    return data, index
454
455def make_engine(model_definition, data, dtype, cutoff):
456    """
457    Generate the appropriate calculation engine for the given datatype.
458
459    Datatypes with '!' appended are evaluated using external C DLLs rather
460    than OpenCL.
461    """
462    if dtype == 'sasview':
463        return eval_sasview(model_definition, data)
464    elif dtype.endswith('!'):
465        return eval_ctypes(model_definition, data, dtype=dtype[:-1],
466                           cutoff=cutoff)
467    else:
468        return eval_opencl(model_definition, data, dtype=dtype,
469                           cutoff=cutoff)
470
471def compare(opts, limits=None):
472    """
473    Preform a comparison using options from the command line.
474
475    *limits* are the limits on the values to use, either to set the y-axis
476    for 1D or to set the colormap scale for 2D.  If None, then they are
477    inferred from the data and returned. When exploring using Bumps,
478    the limits are set when the model is initially called, and maintained
479    as the values are adjusted, making it easier to see the effects of the
480    parameters.
481    """
482    Nbase, Ncomp = opts['n1'], opts['n2']
483    pars = opts['pars']
484    data = opts['data']
485
486    # Base calculation
487    if Nbase > 0:
488        base = opts['engines'][0]
489        try:
490            base_value, base_time = time_calculation(base, pars, Nbase)
491            print("%s t=%.1f ms, intensity=%.0f"
492                  % (base.engine, base_time, sum(base_value)))
493        except ImportError:
494            traceback.print_exc()
495            Nbase = 0
496
497    # Comparison calculation
498    if Ncomp > 0:
499        comp = opts['engines'][1]
500        try:
501            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
502            print("%s t=%.1f ms, intensity=%.0f"
503                  % (comp.engine, comp_time, sum(comp_value)))
504        except ImportError:
505            traceback.print_exc()
506            Ncomp = 0
507
508    # Compare, but only if computing both forms
509    if Nbase > 0 and Ncomp > 0:
510        resid = (base_value - comp_value)
511        relerr = resid/comp_value
512        _print_stats("|%s-%s|"
513                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
514                     resid)
515        _print_stats("|(%s-%s)/%s|"
516                     % (base.engine, comp.engine, comp.engine),
517                     relerr)
518
519    # Plot if requested
520    if not opts['plot'] and not opts['explore']: return
521    view = opts['view']
522    import matplotlib.pyplot as plt
523    if limits is None:
524        vmin, vmax = np.Inf, -np.Inf
525        if Nbase > 0:
526            vmin = min(vmin, min(base_value))
527            vmax = max(vmax, max(base_value))
528        if Ncomp > 0:
529            vmin = min(vmin, min(comp_value))
530            vmax = max(vmax, max(comp_value))
531        limits = vmin, vmax
532
533    if Nbase > 0:
534        if Ncomp > 0: plt.subplot(131)
535        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
536        plt.title("%s t=%.1f ms"%(base.engine, base_time))
537        #cbar_title = "log I"
538    if Ncomp > 0:
539        if Nbase > 0: plt.subplot(132)
540        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
541        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
542        #cbar_title = "log I"
543    if Ncomp > 0 and Nbase > 0:
544        plt.subplot(133)
545        if '-abs' in opts:
546            err, errstr, errview = resid, "abs err", "linear"
547        else:
548            err, errstr, errview = abs(relerr), "rel err", "log"
549        #err,errstr = base/comp,"ratio"
550        plot_theory(data, None, resid=err, view=errview, use_data=False)
551        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
552        #cbar_title = errstr if errview=="linear" else "log "+errstr
553    #if is2D:
554    #    h = plt.colorbar()
555    #    h.ax.set_title(cbar_title)
556
557    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
558        plt.figure()
559        v = relerr
560        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
561        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
562        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
563                   % (base.engine, comp.engine, comp.engine))
564        plt.ylabel('P(err)')
565        plt.title('Distribution of relative error between calculation engines')
566
567    if not opts['explore']:
568        plt.show()
569
570    return limits
571
572def _print_stats(label, err):
573    sorted_err = np.sort(abs(err))
574    p50 = int((len(err)-1)*0.50)
575    p98 = int((len(err)-1)*0.98)
576    data = [
577        "max:%.3e"%sorted_err[-1],
578        "median:%.3e"%sorted_err[p50],
579        "98%%:%.3e"%sorted_err[p98],
580        "rms:%.3e"%np.sqrt(np.mean(err**2)),
581        "zero-offset:%+.3e"%np.mean(err),
582        ]
583    print(label+"  "+"  ".join(data))
584
585
586
587# ===========================================================================
588#
589NAME_OPTIONS = set([
590    'plot', 'noplot',
591    'half', 'fast', 'single', 'double',
592    'single!', 'double!', 'quad!', 'sasview',
593    'lowq', 'midq', 'highq', 'exq',
594    '2d', '1d',
595    'preset', 'random',
596    'poly', 'mono',
597    'nopars', 'pars',
598    'rel', 'abs',
599    'linear', 'log', 'q4',
600    'hist', 'nohist',
601    'edit',
602    ])
603VALUE_OPTIONS = [
604    # Note: random is both a name option and a value option
605    'cutoff', 'random', 'nq', 'res', 'accuracy',
606    ]
607
608def columnize(L, indent="", width=79):
609    """
610    Format a list of strings into columns.
611
612    Returns a string with carriage returns ready for printing.
613    """
614    column_width = max(len(w) for w in L) + 1
615    num_columns = (width - len(indent)) // column_width
616    num_rows = len(L) // num_columns
617    L = L + [""] * (num_rows*num_columns - len(L))
618    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
619    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
620             for row in zip(*columns)]
621    output = indent + ("\n"+indent).join(lines)
622    return output
623
624
625def get_demo_pars(model_definition):
626    """
627    Extract demo parameters from the model definition.
628    """
629    info = generate.make_info(model_definition)
630    # Get the default values for the parameters
631    pars = dict((p[0], p[2]) for p in info['parameters'])
632
633    # Fill in default values for the polydispersity parameters
634    for p in info['parameters']:
635        if p[4] in ('volume', 'orientation'):
636            pars[p[0]+'_pd'] = 0.0
637            pars[p[0]+'_pd_n'] = 0
638            pars[p[0]+'_pd_nsigma'] = 3.0
639            pars[p[0]+'_pd_type'] = "gaussian"
640
641    # Plug in values given in demo
642    pars.update(info['demo'])
643    return pars
644
645def parse_opts():
646    """
647    Parse command line options.
648    """
649    flags = [arg for arg in sys.argv[1:]
650             if arg.startswith('-')]
651    values = [arg for arg in sys.argv[1:]
652              if not arg.startswith('-') and '=' in arg]
653    args = [arg for arg in sys.argv[1:]
654            if not arg.startswith('-') and '=' not in arg]
655    models = "\n    ".join("%-15s"%v for v in MODELS)
656    if len(args) == 0:
657        print(USAGE)
658        print("\nAvailable models:")
659        print(columnize(MODELS, indent="  "))
660        sys.exit(1)
661    if args[0] not in MODELS:
662        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
663        sys.exit(1)
664    if len(args) > 3:
665        print("expected parameters: model N1 N2")
666
667    invalid = [o[1:] for o in flags
668               if o[1:] not in NAME_OPTIONS
669               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
670    if invalid:
671        print("Invalid options: %s"%(", ".join(invalid)))
672        sys.exit(1)
673
674
675    # pylint: disable=bad-whitespace
676    # Interpret the flags
677    opts = {
678        'plot'      : True,
679        'view'      : 'log',
680        'is2d'      : False,
681        'qmax'      : 0.05,
682        'nq'        : 128,
683        'res'       : 0.0,
684        'accuracy'  : 'Low',
685        'cutoff'    : 1e-5,
686        'seed'      : -1,  # default to preset
687        'mono'      : False,
688        'show_pars' : False,
689        'show_hist' : False,
690        'rel_err'   : True,
691        'explore'   : False,
692    }
693    engines = []
694    for arg in flags:
695        if arg == '-noplot':    opts['plot'] = False
696        elif arg == '-plot':    opts['plot'] = True
697        elif arg == '-linear':  opts['view'] = 'linear'
698        elif arg == '-log':     opts['view'] = 'log'
699        elif arg == '-q4':      opts['view'] = 'q4'
700        elif arg == '-1d':      opts['is2d'] = False
701        elif arg == '-2d':      opts['is2d'] = True
702        elif arg == '-exq':     opts['qmax'] = 10.0
703        elif arg == '-highq':   opts['qmax'] = 1.0
704        elif arg == '-midq':    opts['qmax'] = 0.2
705        elif arg == '-loq':     opts['qmax'] = 0.05
706        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
707        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
708        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
709        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
710        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
711        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
712        elif arg == '-preset':  opts['seed'] = -1
713        elif arg == '-mono':    opts['mono'] = True
714        elif arg == '-poly':    opts['mono'] = False
715        elif arg == '-pars':    opts['show_pars'] = True
716        elif arg == '-nopars':  opts['show_pars'] = False
717        elif arg == '-hist':    opts['show_hist'] = True
718        elif arg == '-nohist':  opts['show_hist'] = False
719        elif arg == '-rel':     opts['rel_err'] = True
720        elif arg == '-abs':     opts['rel_err'] = False
721        elif arg == '-half':    engines.append(arg[1:])
722        elif arg == '-fast':    engines.append(arg[1:])
723        elif arg == '-single':  engines.append(arg[1:])
724        elif arg == '-double':  engines.append(arg[1:])
725        elif arg == '-single!': engines.append(arg[1:])
726        elif arg == '-double!': engines.append(arg[1:])
727        elif arg == '-quad!':   engines.append(arg[1:])
728        elif arg == '-sasview': engines.append(arg[1:])
729        elif arg == '-edit':    opts['explore'] = True
730    # pylint: enable=bad-whitespace
731
732    if len(engines) == 0:
733        engines.extend(['single', 'sasview'])
734    elif len(engines) == 1:
735        if engines[0][0] != 'sasview':
736            engines.append('sasview')
737        else:
738            engines.append('single')
739    elif len(engines) > 2:
740        del engines[2:]
741
742    name = args[0]
743    model_definition = core.load_model_definition(name)
744
745    n1 = int(args[1]) if len(args) > 1 else 1
746    n2 = int(args[2]) if len(args) > 2 else 1
747
748    # Get demo parameters from model definition, or use default parameters
749    # if model does not define demo parameters
750    pars = get_demo_pars(model_definition)
751
752    # Fill in parameters given on the command line
753    presets = {}
754    for arg in values:
755        k, v = arg.split('=', 1)
756        if k not in pars:
757            # extract base name without polydispersity info
758            s = set(p.split('_pd')[0] for p in pars)
759            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
760            sys.exit(1)
761        presets[k] = float(v) if not k.endswith('type') else v
762
763    # randomize parameters
764    #pars.update(set_pars)  # set value before random to control range
765    if opts['seed'] > -1:
766        pars = randomize_pars(pars, seed=opts['seed'])
767        print("Randomize using -random=%i"%opts['seed'])
768    if opts['mono']:
769        pars = suppress_pd(pars)
770    pars.update(presets)  # set value after random to control value
771    constrain_pars(model_definition, pars)
772    constrain_new_to_old(model_definition, pars)
773    if opts['show_pars']:
774        print("pars " + str(parlist(pars)))
775
776    # Create the computational engines
777    data, _ = make_data(opts)
778    if n1:
779        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
780    else:
781        base = None
782    if n2:
783        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
784    else:
785        comp = None
786
787    # pylint: disable=bad-whitespace
788    # Remember it all
789    opts.update({
790        'name'      : name,
791        'def'       : model_definition,
792        'n1'        : n1,
793        'n2'        : n2,
794        'presets'   : presets,
795        'pars'      : pars,
796        'data'      : data,
797        'engines'   : [base, comp],
798    })
799    # pylint: enable=bad-whitespace
800
801    return opts
802
803def explore(opts):
804    """
805    Explore the model using the Bumps GUI.
806    """
807    import wx
808    from bumps.names import FitProblem
809    from bumps.gui.app_frame import AppFrame
810
811    problem = FitProblem(Explore(opts))
812    is_mac = "cocoa" in wx.version()
813    app = wx.App()
814    frame = AppFrame(parent=None, title="explore")
815    if not is_mac: frame.Show()
816    frame.panel.set_model(model=problem)
817    frame.panel.Layout()
818    frame.panel.aui.Split(0, wx.TOP)
819    if is_mac: frame.Show()
820    app.MainLoop()
821
822class Explore(object):
823    """
824    Bumps wrapper for a SAS model comparison.
825
826    The resulting object can be used as a Bumps fit problem so that
827    parameters can be adjusted in the GUI, with plots updated on the fly.
828    """
829    def __init__(self, opts):
830        from bumps.cli import config_matplotlib
831        from . import bumps_model
832        config_matplotlib()
833        self.opts = opts
834        info = generate.make_info(opts['def'])
835        pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])
836        if not opts['is2d']:
837            active = [base + ext
838                      for base in info['partype']['pd-1d']
839                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
840            active.extend(info['partype']['fixed-1d'])
841            for k in active:
842                v = pars[k]
843                v.range(*parameter_range(k, v.value))
844        else:
845            for k, v in pars.items():
846                v.range(*parameter_range(k, v.value))
847
848        self.pars = pars
849        self.pd_types = pd_types
850        self.limits = None
851
852    def numpoints(self):
853        """
854        Return the number of points.
855        """
856        return len(self.pars) + 1  # so dof is 1
857
858    def parameters(self):
859        """
860        Return a dictionary of parameters.
861        """
862        return self.pars
863
864    def nllf(self):
865        """
866        Return cost.
867        """
868        # pylint: disable=no-self-use
869        return 0.  # No nllf
870
871    def plot(self, view='log'):
872        """
873        Plot the data and residuals.
874        """
875        pars = dict((k, v.value) for k, v in self.pars.items())
876        pars.update(self.pd_types)
877        self.opts['pars'] = pars
878        limits = compare(self.opts, limits=self.limits)
879        if self.limits is None:
880            vmin, vmax = limits
881            vmax = 1.3*vmax
882            vmin = vmax*1e-7
883            self.limits = vmin, vmax
884
885
886def main():
887    """
888    Main program.
889    """
890    opts = parse_opts()
891    if opts['explore']:
892        explore(opts)
893    else:
894        compare(opts)
895
896if __name__ == "__main__":
897    main()
Note: See TracBrowser for help on using the repository browser.