source: sasmodels/sasmodels/compare.py @ a4a7308

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since a4a7308 was a4a7308, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

nicer formatting for parameters

  • Property mode set to 100755
File size: 29.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33from os.path import basename, dirname, join as joinpath
34import glob
35import datetime
36import traceback
37
38import numpy as np
39
40from . import core
41from . import kerneldll
42from . import generate
43from .data import plot_theory, empty_data1D, empty_data2D
44from .direct_model import DirectModel
45from .convert import revert_model, constrain_new_to_old
46
47USAGE = """
48usage: compare.py model N1 N2 [options...] [key=val]
49
50Compare the speed and value for a model between the SasView original and the
51sasmodels rewrite.
52
53model is the name of the model to compare (see below).
54N1 is the number of times to run sasmodels (default=1).
55N2 is the number times to run sasview (default=1).
56
57Options (* for default):
58
59    -plot*/-noplot plots or suppress the plot of the model
60    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
61    -nq=128 sets the number of Q points in the data set
62    -1d*/-2d computes 1d or 2d data
63    -preset*/-random[=seed] preset or random parameters
64    -mono/-poly* force monodisperse/polydisperse
65    -cutoff=1e-5* cutoff value for including a point in polydispersity
66    -pars/-nopars* prints the parameter set or not
67    -abs/-rel* plot relative or absolute error
68    -linear/-log*/-q4 intensity scaling
69    -hist/-nohist* plot histogram of relative error
70    -res=0 sets the resolution width dQ/Q if calculating with resolution
71    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
72    -edit starts the parameter explorer
73
74Any two calculation engines can be selected for comparison:
75
76    -single/-double/-half/-fast sets an OpenCL calculation engine
77    -single!/-double!/-quad! sets an OpenMP calculation engine
78    -sasview sets the sasview calculation engine
79
80The default is -single -sasview.  Note that the interpretation of quad
81precision depends on architecture, and may vary from 64-bit to 128-bit,
82with 80-bit floats being common (1e-19 precision).
83
84Key=value pairs allow you to set specific values for the model parameters.
85"""
86
87# Update docs with command line usage string.   This is separate from the usual
88# doc string so that we can display it at run time if there is an error.
89# lin
90__doc__ = (__doc__  # pylint: disable=redefined-builtin
91           + """
92Program description
93-------------------
94
95"""
96           + USAGE)
97
98kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
99
100# List of available models
101ROOT = dirname(__file__)
102MODELS = [basename(f)[:-3]
103          for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))]
104
105# CRUFT python 2.6
106if not hasattr(datetime.timedelta, 'total_seconds'):
107    def delay(dt):
108        """Return number date-time delta as number seconds"""
109        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
110else:
111    def delay(dt):
112        """Return number date-time delta as number seconds"""
113        return dt.total_seconds()
114
115
116class push_seed(object):
117    """
118    Set the seed value for the random number generator.
119
120    When used in a with statement, the random number generator state is
121    restored after the with statement is complete.
122
123    :Parameters:
124
125    *seed* : int or array_like, optional
126        Seed for RandomState
127
128    :Example:
129
130    Seed can be used directly to set the seed::
131
132        >>> from numpy.random import randint
133        >>> push_seed(24)
134        <...push_seed object at...>
135        >>> print(randint(0,1000000,3))
136        [242082    899 211136]
137
138    Seed can also be used in a with statement, which sets the random
139    number generator state for the enclosed computations and restores
140    it to the previous state on completion::
141
142        >>> with push_seed(24):
143        ...    print(randint(0,1000000,3))
144        [242082    899 211136]
145
146    Using nested contexts, we can demonstrate that state is indeed
147    restored after the block completes::
148
149        >>> with push_seed(24):
150        ...    print(randint(0,1000000))
151        ...    with push_seed(24):
152        ...        print(randint(0,1000000,3))
153        ...    print(randint(0,1000000))
154        242082
155        [242082    899 211136]
156        899
157
158    The restore step is protected against exceptions in the block::
159
160        >>> with push_seed(24):
161        ...    print(randint(0,1000000))
162        ...    try:
163        ...        with push_seed(24):
164        ...            print(randint(0,1000000,3))
165        ...            raise Exception()
166        ...    except:
167        ...        print("Exception raised")
168        ...    print(randint(0,1000000))
169        242082
170        [242082    899 211136]
171        Exception raised
172        899
173    """
174    def __init__(self, seed=None):
175        self._state = np.random.get_state()
176        np.random.seed(seed)
177
178    def __enter__(self):
179        return None
180
181    def __exit__(self, *args):
182        np.random.set_state(self._state)
183
184def tic():
185    """
186    Timer function.
187
188    Use "toc=tic()" to start the clock and "toc()" to measure
189    a time interval.
190    """
191    then = datetime.datetime.now()
192    return lambda: delay(datetime.datetime.now() - then)
193
194
195def set_beam_stop(data, radius, outer=None):
196    """
197    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
198
199    Note: this function does not use the sasview package
200    """
201    if hasattr(data, 'qx_data'):
202        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
203        data.mask = (q < radius)
204        if outer is not None:
205            data.mask |= (q >= outer)
206    else:
207        data.mask = (data.x < radius)
208        if outer is not None:
209            data.mask |= (data.x >= outer)
210
211
212def parameter_range(p, v):
213    """
214    Choose a parameter range based on parameter name and initial value.
215    """
216    if p.endswith('_pd_n'):
217        return [0, 100]
218    elif p.endswith('_pd_nsigma'):
219        return [0, 5]
220    elif p.endswith('_pd_type'):
221        return v
222    elif any(s in p for s in ('theta', 'phi', 'psi')):
223        # orientation in [-180,180], orientation pd in [0,45]
224        if p.endswith('_pd'):
225            return [0, 45]
226        else:
227            return [-180, 180]
228    elif 'sld' in p:
229        return [-0.5, 10]
230    elif p.endswith('_pd'):
231        return [0, 1]
232    elif p == 'background':
233        return [0, 10]
234    elif p == 'scale':
235        return [0, 1e3]
236    elif p == 'case_num':
237        # RPA hack
238        return [0, 10]
239    elif v < 0:
240        # Kxy parameters in rpa model can be negative
241        return [2*v, -2*v]
242    else:
243        return [0, (2*v if v > 0 else 1)]
244
245
246def _randomize_one(p, v):
247    """
248    Randomize a single parameter.
249    """
250    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
251        return v
252    else:
253        return np.random.uniform(*parameter_range(p, v))
254
255
256def randomize_pars(pars, seed=None):
257    """
258    Generate random values for all of the parameters.
259
260    Valid ranges for the random number generator are guessed from the name of
261    the parameter; this will not account for constraints such as cap radius
262    greater than cylinder radius in the capped_cylinder model, so
263    :func:`constrain_pars` needs to be called afterward..
264    """
265    with push_seed(seed):
266        # Note: the sort guarantees order `of calls to random number generator
267        pars = dict((p, _randomize_one(p, v))
268                    for p, v in sorted(pars.items()))
269    return pars
270
271def constrain_pars(model_definition, pars):
272    """
273    Restrict parameters to valid values.
274
275    This includes model specific code for models such as capped_cylinder
276    which need to support within model constraints (cap radius more than
277    cylinder radius in this case).
278    """
279    name = model_definition.name
280    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
281        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
282    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
283        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
284
285    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
286    if name == 'guinier':
287        #q_max = 0.2  # mid q maximum
288        q_max = 1.0  # high q maximum
289        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
290        pars['rg'] = min(pars['rg'], rg_max)
291
292    if name == 'rpa':
293        # Make sure phi sums to 1.0
294        if pars['case_num'] < 2:
295            pars['Phia'] = 0.
296            pars['Phib'] = 0.
297        elif pars['case_num'] < 5:
298            pars['Phia'] = 0.
299        total = sum(pars['Phi'+c] for c in 'abcd')
300        for c in 'abcd':
301            pars['Phi'+c] /= total
302
303def parlist(pars):
304    """
305    Format the parameter list for printing.
306    """
307    active = None
308    fields = {}
309    lines = []
310    for k, v in sorted(pars.items()):
311        parts = k.split('_pd')
312        #print(k, active, parts)
313        if len(parts) == 1:
314            if active: lines.append(_format_par(active, **fields))
315            active = k
316            fields = {'value': v}
317        else:
318            assert parts[0] == active
319            if parts[1]:
320                fields[parts[1][1:]] = v
321            else:
322                fields['pd'] = v
323    if active: lines.append(_format_par(active, **fields))
324    return "\n".join(lines)
325
326    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
327
328def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
329    line = "%s: %g"%(name, value)
330    if pd != 0.  and n != 0:
331        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
332                % (pd, n, nsigma, nsigma, type)
333    return line
334
335def suppress_pd(pars):
336    """
337    Suppress theta_pd for now until the normalization is resolved.
338
339    May also suppress complete polydispersity of the model to test
340    models more quickly.
341    """
342    pars = pars.copy()
343    for p in pars:
344        if p.endswith("_pd_n"): pars[p] = 0
345    return pars
346
347def eval_sasview(model_definition, data):
348    """
349    Return a model calculator using the SasView fitting engine.
350    """
351    # importing sas here so that the error message will be that sas failed to
352    # import rather than the more obscure smear_selection not imported error
353    import sas
354    from sas.models.qsmearing import smear_selection
355
356    # convert model parameters from sasmodel form to sasview form
357    #print("old",sorted(pars.items()))
358    modelname, _ = revert_model(model_definition, {})
359    #print("new",sorted(_pars.items()))
360    sas = __import__('sas.models.'+modelname)
361    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None)
362    if ModelClass is None:
363        raise ValueError("could not find model %r in sas.models"%modelname)
364    model = ModelClass()
365    smearer = smear_selection(data, model=model)
366
367    if hasattr(data, 'qx_data'):
368        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
369        index = ((~data.mask) & (~np.isnan(data.data))
370                 & (q >= data.qmin) & (q <= data.qmax))
371        if smearer is not None:
372            smearer.model = model  # because smear_selection has a bug
373            smearer.accuracy = data.accuracy
374            smearer.set_index(index)
375            theory = lambda: smearer.get_value()
376        else:
377            theory = lambda: model.evalDistribution([data.qx_data[index],
378                                                     data.qy_data[index]])
379    elif smearer is not None:
380        theory = lambda: smearer(model.evalDistribution(data.x))
381    else:
382        theory = lambda: model.evalDistribution(data.x)
383
384    def calculator(**pars):
385        """
386        Sasview calculator for model.
387        """
388        # paying for parameter conversion each time to keep life simple, if not fast
389        _, pars = revert_model(model_definition, pars)
390        for k, v in pars.items():
391            parts = k.split('.')  # polydispersity components
392            if len(parts) == 2:
393                model.dispersion[parts[0]][parts[1]] = v
394            else:
395                model.setParam(k, v)
396        return theory()
397
398    calculator.engine = "sasview"
399    return calculator
400
401DTYPE_MAP = {
402    'half': '16',
403    'fast': 'fast',
404    'single': '32',
405    'double': '64',
406    'quad': '128',
407    'f16': '16',
408    'f32': '32',
409    'f64': '64',
410    'longdouble': '128',
411}
412def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
413    """
414    Return a model calculator using the OpenCL calculation engine.
415    """
416    try:
417        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
418    except Exception as exc:
419        print(exc)
420        print("... trying again with single precision")
421        dtype = 'single'
422        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
423    calculator = DirectModel(data, model, cutoff=cutoff)
424    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
425    return calculator
426
427def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
428    """
429    Return a model calculator using the DLL calculation engine.
430    """
431    if dtype == 'quad':
432        dtype = 'longdouble'
433    model = core.load_model(model_definition, dtype=dtype, platform="dll")
434    calculator = DirectModel(data, model, cutoff=cutoff)
435    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
436    return calculator
437
438def time_calculation(calculator, pars, Nevals=1):
439    """
440    Compute the average calculation time over N evaluations.
441
442    An additional call is generated without polydispersity in order to
443    initialize the calculation engine, and make the average more stable.
444    """
445    # initialize the code so time is more accurate
446    value = calculator(**suppress_pd(pars))
447    toc = tic()
448    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
449        value = calculator(**pars)
450    average_time = toc()*1000./Nevals
451    return value, average_time
452
453def make_data(opts):
454    """
455    Generate an empty dataset, used with the model to set Q points
456    and resolution.
457
458    *opts* contains the options, with 'qmax', 'nq', 'res',
459    'accuracy', 'is2d' and 'view' parsed from the command line.
460    """
461    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
462    if opts['is2d']:
463        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
464        data.accuracy = opts['accuracy']
465        set_beam_stop(data, 0.004)
466        index = ~data.mask
467    else:
468        if opts['view'] == 'log':
469            qmax = math.log10(qmax)
470            q = np.logspace(qmax-3, qmax, nq)
471        else:
472            q = np.linspace(0.001*qmax, qmax, nq)
473        data = empty_data1D(q, resolution=res)
474        index = slice(None, None)
475    return data, index
476
477def make_engine(model_definition, data, dtype, cutoff):
478    """
479    Generate the appropriate calculation engine for the given datatype.
480
481    Datatypes with '!' appended are evaluated using external C DLLs rather
482    than OpenCL.
483    """
484    if dtype == 'sasview':
485        return eval_sasview(model_definition, data)
486    elif dtype.endswith('!'):
487        return eval_ctypes(model_definition, data, dtype=dtype[:-1],
488                           cutoff=cutoff)
489    else:
490        return eval_opencl(model_definition, data, dtype=dtype,
491                           cutoff=cutoff)
492
493def compare(opts, limits=None):
494    """
495    Preform a comparison using options from the command line.
496
497    *limits* are the limits on the values to use, either to set the y-axis
498    for 1D or to set the colormap scale for 2D.  If None, then they are
499    inferred from the data and returned. When exploring using Bumps,
500    the limits are set when the model is initially called, and maintained
501    as the values are adjusted, making it easier to see the effects of the
502    parameters.
503    """
504    Nbase, Ncomp = opts['n1'], opts['n2']
505    pars = opts['pars']
506    data = opts['data']
507
508    # Base calculation
509    if Nbase > 0:
510        base = opts['engines'][0]
511        try:
512            base_value, base_time = time_calculation(base, pars, Nbase)
513            print("%s t=%.1f ms, intensity=%.0f"
514                  % (base.engine, base_time, sum(base_value)))
515        except ImportError:
516            traceback.print_exc()
517            Nbase = 0
518
519    # Comparison calculation
520    if Ncomp > 0:
521        comp = opts['engines'][1]
522        try:
523            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
524            print("%s t=%.1f ms, intensity=%.0f"
525                  % (comp.engine, comp_time, sum(comp_value)))
526        except ImportError:
527            traceback.print_exc()
528            Ncomp = 0
529
530    # Compare, but only if computing both forms
531    if Nbase > 0 and Ncomp > 0:
532        resid = (base_value - comp_value)
533        relerr = resid/comp_value
534        _print_stats("|%s-%s|"
535                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
536                     resid)
537        _print_stats("|(%s-%s)/%s|"
538                     % (base.engine, comp.engine, comp.engine),
539                     relerr)
540
541    # Plot if requested
542    if not opts['plot'] and not opts['explore']: return
543    view = opts['view']
544    import matplotlib.pyplot as plt
545    if limits is None:
546        vmin, vmax = np.Inf, -np.Inf
547        if Nbase > 0:
548            vmin = min(vmin, min(base_value))
549            vmax = max(vmax, max(base_value))
550        if Ncomp > 0:
551            vmin = min(vmin, min(comp_value))
552            vmax = max(vmax, max(comp_value))
553        limits = vmin, vmax
554
555    if Nbase > 0:
556        if Ncomp > 0: plt.subplot(131)
557        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
558        plt.title("%s t=%.1f ms"%(base.engine, base_time))
559        #cbar_title = "log I"
560    if Ncomp > 0:
561        if Nbase > 0: plt.subplot(132)
562        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
563        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
564        #cbar_title = "log I"
565    if Ncomp > 0 and Nbase > 0:
566        plt.subplot(133)
567        if not opts['rel_err']:
568            err, errstr, errview = resid, "abs err", "linear"
569        else:
570            err, errstr, errview = abs(relerr), "rel err", "log"
571        #err,errstr = base/comp,"ratio"
572        plot_theory(data, None, resid=err, view=errview, use_data=False)
573        if view == 'linear':
574            plt.xscale('linear')
575        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
576        #cbar_title = errstr if errview=="linear" else "log "+errstr
577    #if is2D:
578    #    h = plt.colorbar()
579    #    h.ax.set_title(cbar_title)
580
581    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
582        plt.figure()
583        v = relerr
584        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
585        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
586        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
587                   % (base.engine, comp.engine, comp.engine))
588        plt.ylabel('P(err)')
589        plt.title('Distribution of relative error between calculation engines')
590
591    if not opts['explore']:
592        plt.show()
593
594    return limits
595
596def _print_stats(label, err):
597    sorted_err = np.sort(abs(err))
598    p50 = int((len(err)-1)*0.50)
599    p98 = int((len(err)-1)*0.98)
600    data = [
601        "max:%.3e"%sorted_err[-1],
602        "median:%.3e"%sorted_err[p50],
603        "98%%:%.3e"%sorted_err[p98],
604        "rms:%.3e"%np.sqrt(np.mean(err**2)),
605        "zero-offset:%+.3e"%np.mean(err),
606        ]
607    print(label+"  "+"  ".join(data))
608
609
610
611# ===========================================================================
612#
613NAME_OPTIONS = set([
614    'plot', 'noplot',
615    'half', 'fast', 'single', 'double',
616    'single!', 'double!', 'quad!', 'sasview',
617    'lowq', 'midq', 'highq', 'exq',
618    '2d', '1d',
619    'preset', 'random',
620    'poly', 'mono',
621    'nopars', 'pars',
622    'rel', 'abs',
623    'linear', 'log', 'q4',
624    'hist', 'nohist',
625    'edit',
626    ])
627VALUE_OPTIONS = [
628    # Note: random is both a name option and a value option
629    'cutoff', 'random', 'nq', 'res', 'accuracy',
630    ]
631
632def columnize(L, indent="", width=79):
633    """
634    Format a list of strings into columns.
635
636    Returns a string with carriage returns ready for printing.
637    """
638    column_width = max(len(w) for w in L) + 1
639    num_columns = (width - len(indent)) // column_width
640    num_rows = len(L) // num_columns
641    L = L + [""] * (num_rows*num_columns - len(L))
642    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
643    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
644             for row in zip(*columns)]
645    output = indent + ("\n"+indent).join(lines)
646    return output
647
648
649def get_demo_pars(model_definition):
650    """
651    Extract demo parameters from the model definition.
652    """
653    info = generate.make_info(model_definition)
654    # Get the default values for the parameters
655    pars = dict((p[0], p[2]) for p in info['parameters'])
656
657    # Fill in default values for the polydispersity parameters
658    for p in info['parameters']:
659        if p[4] in ('volume', 'orientation'):
660            pars[p[0]+'_pd'] = 0.0
661            pars[p[0]+'_pd_n'] = 0
662            pars[p[0]+'_pd_nsigma'] = 3.0
663            pars[p[0]+'_pd_type'] = "gaussian"
664
665    # Plug in values given in demo
666    pars.update(info['demo'])
667    return pars
668
669def parse_opts():
670    """
671    Parse command line options.
672    """
673    flags = [arg for arg in sys.argv[1:]
674             if arg.startswith('-')]
675    values = [arg for arg in sys.argv[1:]
676              if not arg.startswith('-') and '=' in arg]
677    args = [arg for arg in sys.argv[1:]
678            if not arg.startswith('-') and '=' not in arg]
679    models = "\n    ".join("%-15s"%v for v in MODELS)
680    if len(args) == 0:
681        print(USAGE)
682        print("\nAvailable models:")
683        print(columnize(MODELS, indent="  "))
684        sys.exit(1)
685    if args[0] not in MODELS:
686        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
687        sys.exit(1)
688    if len(args) > 3:
689        print("expected parameters: model N1 N2")
690
691    invalid = [o[1:] for o in flags
692               if o[1:] not in NAME_OPTIONS
693               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
694    if invalid:
695        print("Invalid options: %s"%(", ".join(invalid)))
696        sys.exit(1)
697
698
699    # pylint: disable=bad-whitespace
700    # Interpret the flags
701    opts = {
702        'plot'      : True,
703        'view'      : 'log',
704        'is2d'      : False,
705        'qmax'      : 0.05,
706        'nq'        : 128,
707        'res'       : 0.0,
708        'accuracy'  : 'Low',
709        'cutoff'    : 1e-5,
710        'seed'      : -1,  # default to preset
711        'mono'      : False,
712        'show_pars' : False,
713        'show_hist' : False,
714        'rel_err'   : True,
715        'explore'   : False,
716    }
717    engines = []
718    for arg in flags:
719        if arg == '-noplot':    opts['plot'] = False
720        elif arg == '-plot':    opts['plot'] = True
721        elif arg == '-linear':  opts['view'] = 'linear'
722        elif arg == '-log':     opts['view'] = 'log'
723        elif arg == '-q4':      opts['view'] = 'q4'
724        elif arg == '-1d':      opts['is2d'] = False
725        elif arg == '-2d':      opts['is2d'] = True
726        elif arg == '-exq':     opts['qmax'] = 10.0
727        elif arg == '-highq':   opts['qmax'] = 1.0
728        elif arg == '-midq':    opts['qmax'] = 0.2
729        elif arg == '-lowq':    opts['qmax'] = 0.05
730        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
731        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
732        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
733        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
734        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
735        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
736        elif arg == '-preset':  opts['seed'] = -1
737        elif arg == '-mono':    opts['mono'] = True
738        elif arg == '-poly':    opts['mono'] = False
739        elif arg == '-pars':    opts['show_pars'] = True
740        elif arg == '-nopars':  opts['show_pars'] = False
741        elif arg == '-hist':    opts['show_hist'] = True
742        elif arg == '-nohist':  opts['show_hist'] = False
743        elif arg == '-rel':     opts['rel_err'] = True
744        elif arg == '-abs':     opts['rel_err'] = False
745        elif arg == '-half':    engines.append(arg[1:])
746        elif arg == '-fast':    engines.append(arg[1:])
747        elif arg == '-single':  engines.append(arg[1:])
748        elif arg == '-double':  engines.append(arg[1:])
749        elif arg == '-single!': engines.append(arg[1:])
750        elif arg == '-double!': engines.append(arg[1:])
751        elif arg == '-quad!':   engines.append(arg[1:])
752        elif arg == '-sasview': engines.append(arg[1:])
753        elif arg == '-edit':    opts['explore'] = True
754    # pylint: enable=bad-whitespace
755
756    if len(engines) == 0:
757        engines.extend(['single', 'sasview'])
758    elif len(engines) == 1:
759        if engines[0][0] != 'sasview':
760            engines.append('sasview')
761        else:
762            engines.append('single')
763    elif len(engines) > 2:
764        del engines[2:]
765
766    name = args[0]
767    model_definition = core.load_model_definition(name)
768
769    n1 = int(args[1]) if len(args) > 1 else 1
770    n2 = int(args[2]) if len(args) > 2 else 1
771
772    # Get demo parameters from model definition, or use default parameters
773    # if model does not define demo parameters
774    pars = get_demo_pars(model_definition)
775
776    # Fill in parameters given on the command line
777    presets = {}
778    for arg in values:
779        k, v = arg.split('=', 1)
780        if k not in pars:
781            # extract base name without polydispersity info
782            s = set(p.split('_pd')[0] for p in pars)
783            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
784            sys.exit(1)
785        presets[k] = float(v) if not k.endswith('type') else v
786
787    # randomize parameters
788    #pars.update(set_pars)  # set value before random to control range
789    if opts['seed'] > -1:
790        pars = randomize_pars(pars, seed=opts['seed'])
791        print("Randomize using -random=%i"%opts['seed'])
792    if opts['mono']:
793        pars = suppress_pd(pars)
794    pars.update(presets)  # set value after random to control value
795    constrain_pars(model_definition, pars)
796    constrain_new_to_old(model_definition, pars)
797    if opts['show_pars']:
798        print(str(parlist(pars)))
799
800    # Create the computational engines
801    data, _ = make_data(opts)
802    if n1:
803        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
804    else:
805        base = None
806    if n2:
807        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
808    else:
809        comp = None
810
811    # pylint: disable=bad-whitespace
812    # Remember it all
813    opts.update({
814        'name'      : name,
815        'def'       : model_definition,
816        'n1'        : n1,
817        'n2'        : n2,
818        'presets'   : presets,
819        'pars'      : pars,
820        'data'      : data,
821        'engines'   : [base, comp],
822    })
823    # pylint: enable=bad-whitespace
824
825    return opts
826
827def explore(opts):
828    """
829    Explore the model using the Bumps GUI.
830    """
831    import wx
832    from bumps.names import FitProblem
833    from bumps.gui.app_frame import AppFrame
834
835    problem = FitProblem(Explore(opts))
836    is_mac = "cocoa" in wx.version()
837    app = wx.App()
838    frame = AppFrame(parent=None, title="explore")
839    if not is_mac: frame.Show()
840    frame.panel.set_model(model=problem)
841    frame.panel.Layout()
842    frame.panel.aui.Split(0, wx.TOP)
843    if is_mac: frame.Show()
844    app.MainLoop()
845
846class Explore(object):
847    """
848    Bumps wrapper for a SAS model comparison.
849
850    The resulting object can be used as a Bumps fit problem so that
851    parameters can be adjusted in the GUI, with plots updated on the fly.
852    """
853    def __init__(self, opts):
854        from bumps.cli import config_matplotlib
855        from . import bumps_model
856        config_matplotlib()
857        self.opts = opts
858        info = generate.make_info(opts['def'])
859        pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])
860        if not opts['is2d']:
861            active = [base + ext
862                      for base in info['partype']['pd-1d']
863                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
864            active.extend(info['partype']['fixed-1d'])
865            for k in active:
866                v = pars[k]
867                v.range(*parameter_range(k, v.value))
868        else:
869            for k, v in pars.items():
870                v.range(*parameter_range(k, v.value))
871
872        self.pars = pars
873        self.pd_types = pd_types
874        self.limits = None
875
876    def numpoints(self):
877        """
878        Return the number of points.
879        """
880        return len(self.pars) + 1  # so dof is 1
881
882    def parameters(self):
883        """
884        Return a dictionary of parameters.
885        """
886        return self.pars
887
888    def nllf(self):
889        """
890        Return cost.
891        """
892        # pylint: disable=no-self-use
893        return 0.  # No nllf
894
895    def plot(self, view='log'):
896        """
897        Plot the data and residuals.
898        """
899        pars = dict((k, v.value) for k, v in self.pars.items())
900        pars.update(self.pd_types)
901        self.opts['pars'] = pars
902        limits = compare(self.opts, limits=self.limits)
903        if self.limits is None:
904            vmin, vmax = limits
905            vmax = 1.3*vmax
906            vmin = vmax*1e-7
907            self.limits = vmin, vmax
908
909
910def main():
911    """
912    Main program.
913    """
914    opts = parse_opts()
915    if opts['explore']:
916        explore(opts)
917    else:
918        compare(opts)
919
920if __name__ == "__main__":
921    main()
Note: See TracBrowser for help on using the repository browser.