source: sasmodels/sasmodels/compare.py @ d2bb604

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since d2bb604 was d2bb604, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix models so all dll tests pass

  • Property mode set to 100755
File size: 30.5 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from .data import plot_theory, empty_data1D, empty_data2D
41from .direct_model import DirectModel
42from .convert import revert_name, revert_pars, constrain_new_to_old
43
44USAGE = """
45usage: compare.py model N1 N2 [options...] [key=val]
46
47Compare the speed and value for a model between the SasView original and the
48sasmodels rewrite.
49
50model is the name of the model to compare (see below).
51N1 is the number of times to run sasmodels (default=1).
52N2 is the number times to run sasview (default=1).
53
54Options (* for default):
55
56    -plot*/-noplot plots or suppress the plot of the model
57    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
58    -nq=128 sets the number of Q points in the data set
59    -1d*/-2d computes 1d or 2d data
60    -preset*/-random[=seed] preset or random parameters
61    -mono/-poly* force monodisperse/polydisperse
62    -cutoff=1e-5* cutoff value for including a point in polydispersity
63    -pars/-nopars* prints the parameter set or not
64    -abs/-rel* plot relative or absolute error
65    -linear/-log*/-q4 intensity scaling
66    -hist/-nohist* plot histogram of relative error
67    -res=0 sets the resolution width dQ/Q if calculating with resolution
68    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
69    -edit starts the parameter explorer
70    -default/-demo* use demo vs default parameters
71
72Any two calculation engines can be selected for comparison:
73
74    -single/-double/-half/-fast sets an OpenCL calculation engine
75    -single!/-double!/-quad! sets an OpenMP calculation engine
76    -sasview sets the sasview calculation engine
77
78The default is -single -sasview.  Note that the interpretation of quad
79precision depends on architecture, and may vary from 64-bit to 128-bit,
80with 80-bit floats being common (1e-19 precision).
81
82Key=value pairs allow you to set specific values for the model parameters.
83"""
84
85# Update docs with command line usage string.   This is separate from the usual
86# doc string so that we can display it at run time if there is an error.
87# lin
88__doc__ = (__doc__  # pylint: disable=redefined-builtin
89           + """
90Program description
91-------------------
92
93"""
94           + USAGE)
95
96kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
97
98MODELS = core.list_models()
99
100# CRUFT python 2.6
101if not hasattr(datetime.timedelta, 'total_seconds'):
102    def delay(dt):
103        """Return number date-time delta as number seconds"""
104        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
105else:
106    def delay(dt):
107        """Return number date-time delta as number seconds"""
108        return dt.total_seconds()
109
110
111class push_seed(object):
112    """
113    Set the seed value for the random number generator.
114
115    When used in a with statement, the random number generator state is
116    restored after the with statement is complete.
117
118    :Parameters:
119
120    *seed* : int or array_like, optional
121        Seed for RandomState
122
123    :Example:
124
125    Seed can be used directly to set the seed::
126
127        >>> from numpy.random import randint
128        >>> push_seed(24)
129        <...push_seed object at...>
130        >>> print(randint(0,1000000,3))
131        [242082    899 211136]
132
133    Seed can also be used in a with statement, which sets the random
134    number generator state for the enclosed computations and restores
135    it to the previous state on completion::
136
137        >>> with push_seed(24):
138        ...    print(randint(0,1000000,3))
139        [242082    899 211136]
140
141    Using nested contexts, we can demonstrate that state is indeed
142    restored after the block completes::
143
144        >>> with push_seed(24):
145        ...    print(randint(0,1000000))
146        ...    with push_seed(24):
147        ...        print(randint(0,1000000,3))
148        ...    print(randint(0,1000000))
149        242082
150        [242082    899 211136]
151        899
152
153    The restore step is protected against exceptions in the block::
154
155        >>> with push_seed(24):
156        ...    print(randint(0,1000000))
157        ...    try:
158        ...        with push_seed(24):
159        ...            print(randint(0,1000000,3))
160        ...            raise Exception()
161        ...    except:
162        ...        print("Exception raised")
163        ...    print(randint(0,1000000))
164        242082
165        [242082    899 211136]
166        Exception raised
167        899
168    """
169    def __init__(self, seed=None):
170        self._state = np.random.get_state()
171        np.random.seed(seed)
172
173    def __enter__(self):
174        return None
175
176    def __exit__(self, *args):
177        np.random.set_state(self._state)
178
179def tic():
180    """
181    Timer function.
182
183    Use "toc=tic()" to start the clock and "toc()" to measure
184    a time interval.
185    """
186    then = datetime.datetime.now()
187    return lambda: delay(datetime.datetime.now() - then)
188
189
190def set_beam_stop(data, radius, outer=None):
191    """
192    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
193
194    Note: this function does not use the sasview package
195    """
196    if hasattr(data, 'qx_data'):
197        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
198        data.mask = (q < radius)
199        if outer is not None:
200            data.mask |= (q >= outer)
201    else:
202        data.mask = (data.x < radius)
203        if outer is not None:
204            data.mask |= (data.x >= outer)
205
206
207def parameter_range(p, v):
208    """
209    Choose a parameter range based on parameter name and initial value.
210    """
211    if p.endswith('_pd_n'):
212        return [0, 100]
213    elif p.endswith('_pd_nsigma'):
214        return [0, 5]
215    elif p.endswith('_pd_type'):
216        return v
217    elif any(s in p for s in ('theta', 'phi', 'psi')):
218        # orientation in [-180,180], orientation pd in [0,45]
219        if p.endswith('_pd'):
220            return [0, 45]
221        else:
222            return [-180, 180]
223    elif 'sld' in p:
224        return [-0.5, 10]
225    elif p.endswith('_pd'):
226        return [0, 1]
227    elif p == 'background':
228        return [0, 10]
229    elif p == 'scale':
230        return [0, 1e3]
231    elif p == 'case_num':
232        # RPA hack
233        return [0, 10]
234    elif v < 0:
235        # Kxy parameters in rpa model can be negative
236        return [2*v, -2*v]
237    else:
238        return [0, (2*v if v > 0 else 1)]
239
240
241def _randomize_one(p, v):
242    """
243    Randomize a single parameter.
244    """
245    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
246        return v
247    else:
248        return np.random.uniform(*parameter_range(p, v))
249
250
251def randomize_pars(pars, seed=None):
252    """
253    Generate random values for all of the parameters.
254
255    Valid ranges for the random number generator are guessed from the name of
256    the parameter; this will not account for constraints such as cap radius
257    greater than cylinder radius in the capped_cylinder model, so
258    :func:`constrain_pars` needs to be called afterward..
259    """
260    with push_seed(seed):
261        # Note: the sort guarantees order `of calls to random number generator
262        pars = dict((p, _randomize_one(p, v))
263                    for p, v in sorted(pars.items()))
264    return pars
265
266def constrain_pars(model_info, pars):
267    """
268    Restrict parameters to valid values.
269
270    This includes model specific code for models such as capped_cylinder
271    which need to support within model constraints (cap radius more than
272    cylinder radius in this case).
273    """
274    name = model_info['id']
275    # if it is a product model, then just look at the form factor since
276    # none of the structure factors need any constraints.
277    if '*' in name:
278        name = name.split('*')[0]
279
280    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
281        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
282    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
283        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
284
285    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
286    if name == 'guinier':
287        #q_max = 0.2  # mid q maximum
288        q_max = 1.0  # high q maximum
289        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
290        pars['rg'] = min(pars['rg'], rg_max)
291
292    if name == 'rpa':
293        # Make sure phi sums to 1.0
294        if pars['case_num'] < 2:
295            pars['Phia'] = 0.
296            pars['Phib'] = 0.
297        elif pars['case_num'] < 5:
298            pars['Phia'] = 0.
299        total = sum(pars['Phi'+c] for c in 'abcd')
300        for c in 'abcd':
301            pars['Phi'+c] /= total
302
303def parlist(model_info, pars, is2d):
304    """
305    Format the parameter list for printing.
306    """
307    lines = []
308    parameters = model_info['parameters']
309    for p in parameters.user_parameters(pars, is2d):
310        fields = dict(
311            value=pars.get(p.id, p.default),
312            pd=pars.get(p.id+"_pd", 0.),
313            n=int(pars.get(p.id+"_pd_n", 0)),
314            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
315            type=pars.get(p.id+"_pd_type", 'gaussian'))
316        lines.append(_format_par(p.name, **fields))
317    return "\n".join(lines)
318
319    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
320
321def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
322    line = "%s: %g"%(name, value)
323    if pd != 0.  and n != 0:
324        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
325                % (pd, n, nsigma, nsigma, type)
326    return line
327
328def suppress_pd(pars):
329    """
330    Suppress theta_pd for now until the normalization is resolved.
331
332    May also suppress complete polydispersity of the model to test
333    models more quickly.
334    """
335    pars = pars.copy()
336    for p in pars:
337        if p.endswith("_pd_n"): pars[p] = 0
338    return pars
339
340def eval_sasview(model_info, data):
341    """
342    Return a model calculator using the pre-4.0 SasView models.
343    """
344    # importing sas here so that the error message will be that sas failed to
345    # import rather than the more obscure smear_selection not imported error
346    import sas
347    from sas.models.qsmearing import smear_selection
348
349    def get_model(name):
350        #print("new",sorted(_pars.items()))
351        sas = __import__('sas.models.' + name)
352        ModelClass = getattr(getattr(sas.models, name, None), name, None)
353        if ModelClass is None:
354            raise ValueError("could not find model %r in sas.models"%name)
355        return ModelClass()
356
357    # grab the sasview model, or create it if it is a product model
358    if model_info['composition']:
359        composition_type, parts = model_info['composition']
360        if composition_type == 'product':
361            from sas.models.MultiplicationModel import MultiplicationModel
362            P, S = [get_model(revert_name(p)) for p in parts]
363            model = MultiplicationModel(P, S)
364        else:
365            raise ValueError("sasview mixture models not supported by compare")
366    else:
367        model = get_model(revert_name(model_info))
368
369    # build a smearer with which to call the model, if necessary
370    smearer = smear_selection(data, model=model)
371    if hasattr(data, 'qx_data'):
372        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
373        index = ((~data.mask) & (~np.isnan(data.data))
374                 & (q >= data.qmin) & (q <= data.qmax))
375        if smearer is not None:
376            smearer.model = model  # because smear_selection has a bug
377            smearer.accuracy = data.accuracy
378            smearer.set_index(index)
379            theory = lambda: smearer.get_value()
380        else:
381            theory = lambda: model.evalDistribution([data.qx_data[index],
382                                                     data.qy_data[index]])
383    elif smearer is not None:
384        theory = lambda: smearer(model.evalDistribution(data.x))
385    else:
386        theory = lambda: model.evalDistribution(data.x)
387
388    def calculator(**pars):
389        """
390        Sasview calculator for model.
391        """
392        # paying for parameter conversion each time to keep life simple, if not fast
393        pars = revert_pars(model_info, pars)
394        for k, v in pars.items():
395            parts = k.split('.')  # polydispersity components
396            if len(parts) == 2:
397                model.dispersion[parts[0]][parts[1]] = v
398            else:
399                model.setParam(k, v)
400        return theory()
401
402    calculator.engine = "sasview"
403    return calculator
404
405DTYPE_MAP = {
406    'half': '16',
407    'fast': 'fast',
408    'single': '32',
409    'double': '64',
410    'quad': '128',
411    'f16': '16',
412    'f32': '32',
413    'f64': '64',
414    'longdouble': '128',
415}
416def eval_opencl(model_info, data, dtype='single', cutoff=0.):
417    """
418    Return a model calculator using the OpenCL calculation engine.
419    """
420    try:
421        model = core.build_model(model_info, dtype=dtype, platform="ocl")
422    except Exception as exc:
423        print(exc)
424        print("... trying again with single precision")
425        model = core.build_model(model_info, dtype='single', platform="ocl")
426    calculator = DirectModel(data, model, cutoff=cutoff)
427    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
428    return calculator
429
430def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
431    """
432    Return a model calculator using the DLL calculation engine.
433    """
434    if dtype == 'quad':
435        dtype = 'longdouble'
436    model = core.build_model(model_info, dtype=dtype, platform="dll")
437    calculator = DirectModel(data, model, cutoff=cutoff)
438    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
439    return calculator
440
441def time_calculation(calculator, pars, Nevals=1):
442    """
443    Compute the average calculation time over N evaluations.
444
445    An additional call is generated without polydispersity in order to
446    initialize the calculation engine, and make the average more stable.
447    """
448    # initialize the code so time is more accurate
449    if Nevals > 1:
450        value = calculator(**suppress_pd(pars))
451    toc = tic()
452    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
453        value = calculator(**pars)
454    average_time = toc()*1000./Nevals
455    return value, average_time
456
457def make_data(opts):
458    """
459    Generate an empty dataset, used with the model to set Q points
460    and resolution.
461
462    *opts* contains the options, with 'qmax', 'nq', 'res',
463    'accuracy', 'is2d' and 'view' parsed from the command line.
464    """
465    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
466    if opts['is2d']:
467        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
468        data.accuracy = opts['accuracy']
469        set_beam_stop(data, 0.0004)
470        index = ~data.mask
471    else:
472        if opts['view'] == 'log':
473            qmax = math.log10(qmax)
474            q = np.logspace(qmax-3, qmax, nq)
475        else:
476            q = np.linspace(0.001*qmax, qmax, nq)
477        data = empty_data1D(q, resolution=res)
478        index = slice(None, None)
479    return data, index
480
481def make_engine(model_info, data, dtype, cutoff):
482    """
483    Generate the appropriate calculation engine for the given datatype.
484
485    Datatypes with '!' appended are evaluated using external C DLLs rather
486    than OpenCL.
487    """
488    if dtype == 'sasview':
489        return eval_sasview(model_info, data)
490    elif dtype.endswith('!'):
491        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
492    else:
493        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
494
495def compare(opts, limits=None):
496    """
497    Preform a comparison using options from the command line.
498
499    *limits* are the limits on the values to use, either to set the y-axis
500    for 1D or to set the colormap scale for 2D.  If None, then they are
501    inferred from the data and returned. When exploring using Bumps,
502    the limits are set when the model is initially called, and maintained
503    as the values are adjusted, making it easier to see the effects of the
504    parameters.
505    """
506    Nbase, Ncomp = opts['n1'], opts['n2']
507    pars = opts['pars']
508    data = opts['data']
509
510    # Base calculation
511    if Nbase > 0:
512        base = opts['engines'][0]
513        try:
514            base_value, base_time = time_calculation(base, pars, Nbase)
515            print("%s t=%.2f ms, intensity=%.0f"
516                  % (base.engine, base_time, sum(base_value)))
517        except ImportError:
518            traceback.print_exc()
519            Nbase = 0
520
521    # Comparison calculation
522    if Ncomp > 0:
523        comp = opts['engines'][1]
524        try:
525            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
526            print("%s t=%.2f ms, intensity=%.0f"
527                  % (comp.engine, comp_time, sum(comp_value)))
528        except ImportError:
529            traceback.print_exc()
530            Ncomp = 0
531
532    # Compare, but only if computing both forms
533    if Nbase > 0 and Ncomp > 0:
534        resid = (base_value - comp_value)
535        relerr = resid/comp_value
536        _print_stats("|%s-%s|"
537                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
538                     resid)
539        _print_stats("|(%s-%s)/%s|"
540                     % (base.engine, comp.engine, comp.engine),
541                     relerr)
542
543    # Plot if requested
544    if not opts['plot'] and not opts['explore']: return
545    view = opts['view']
546    import matplotlib.pyplot as plt
547    if limits is None:
548        vmin, vmax = np.Inf, -np.Inf
549        if Nbase > 0:
550            vmin = min(vmin, min(base_value))
551            vmax = max(vmax, max(base_value))
552        if Ncomp > 0:
553            vmin = min(vmin, min(comp_value))
554            vmax = max(vmax, max(comp_value))
555        limits = vmin, vmax
556
557    if Nbase > 0:
558        if Ncomp > 0: plt.subplot(131)
559        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
560        plt.title("%s t=%.2f ms"%(base.engine, base_time))
561        #cbar_title = "log I"
562    if Ncomp > 0:
563        if Nbase > 0: plt.subplot(132)
564        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
565        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
566        #cbar_title = "log I"
567    if Ncomp > 0 and Nbase > 0:
568        plt.subplot(133)
569        if not opts['rel_err']:
570            err, errstr, errview = resid, "abs err", "linear"
571        else:
572            err, errstr, errview = abs(relerr), "rel err", "log"
573        #err,errstr = base/comp,"ratio"
574        plot_theory(data, None, resid=err, view=errview, use_data=False)
575        if view == 'linear':
576            plt.xscale('linear')
577        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
578        #cbar_title = errstr if errview=="linear" else "log "+errstr
579    #if is2D:
580    #    h = plt.colorbar()
581    #    h.ax.set_title(cbar_title)
582
583    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
584        plt.figure()
585        v = relerr
586        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
587        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
588        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
589                   % (base.engine, comp.engine, comp.engine))
590        plt.ylabel('P(err)')
591        plt.title('Distribution of relative error between calculation engines')
592
593    if not opts['explore']:
594        plt.show()
595
596    return limits
597
598def _print_stats(label, err):
599    sorted_err = np.sort(abs(err))
600    p50 = int((len(err)-1)*0.50)
601    p98 = int((len(err)-1)*0.98)
602    data = [
603        "max:%.3e"%sorted_err[-1],
604        "median:%.3e"%sorted_err[p50],
605        "98%%:%.3e"%sorted_err[p98],
606        "rms:%.3e"%np.sqrt(np.mean(err**2)),
607        "zero-offset:%+.3e"%np.mean(err),
608        ]
609    print(label+"  "+"  ".join(data))
610
611
612
613# ===========================================================================
614#
615NAME_OPTIONS = set([
616    'plot', 'noplot',
617    'half', 'fast', 'single', 'double',
618    'single!', 'double!', 'quad!', 'sasview',
619    'lowq', 'midq', 'highq', 'exq',
620    '2d', '1d',
621    'preset', 'random',
622    'poly', 'mono',
623    'nopars', 'pars',
624    'rel', 'abs',
625    'linear', 'log', 'q4',
626    'hist', 'nohist',
627    'edit',
628    'demo', 'default',
629    ])
630VALUE_OPTIONS = [
631    # Note: random is both a name option and a value option
632    'cutoff', 'random', 'nq', 'res', 'accuracy',
633    ]
634
635def columnize(L, indent="", width=79):
636    """
637    Format a list of strings into columns.
638
639    Returns a string with carriage returns ready for printing.
640    """
641    column_width = max(len(w) for w in L) + 1
642    num_columns = (width - len(indent)) // column_width
643    num_rows = len(L) // num_columns
644    L = L + [""] * (num_rows*num_columns - len(L))
645    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
646    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
647             for row in zip(*columns)]
648    output = indent + ("\n"+indent).join(lines)
649    return output
650
651
652def get_pars(model_info, use_demo=False):
653    """
654    Extract demo parameters from the model definition.
655    """
656    # Get the default values for the parameters
657    pars = {}
658    for p in model_info['parameters'].call_parameters:
659        parts = [('', p.default)]
660        if p.polydisperse:
661            parts.append(('_pd', 0.0))
662            parts.append(('_pd_n', 0))
663            parts.append(('_pd_nsigma', 3.0))
664            parts.append(('_pd_type', "gaussian"))
665        for ext, val in parts:
666            if p.length > 1:
667                dict(("%s%d%s"%(p.id,k,ext), val) for k in range(p.length))
668            else:
669                pars[p.id+ext] = val
670
671    # Plug in values given in demo
672    if use_demo:
673        pars.update(model_info['demo'])
674    return pars
675
676
677def parse_opts():
678    """
679    Parse command line options.
680    """
681    MODELS = core.list_models()
682    flags = [arg for arg in sys.argv[1:]
683             if arg.startswith('-')]
684    values = [arg for arg in sys.argv[1:]
685              if not arg.startswith('-') and '=' in arg]
686    args = [arg for arg in sys.argv[1:]
687            if not arg.startswith('-') and '=' not in arg]
688    models = "\n    ".join("%-15s"%v for v in MODELS)
689    if len(args) == 0:
690        print(USAGE)
691        print("\nAvailable models:")
692        print(columnize(MODELS, indent="  "))
693        sys.exit(1)
694    if len(args) > 3:
695        print("expected parameters: model N1 N2")
696
697    name = args[0]
698    try:
699        model_info = core.load_model_info(name)
700    except ImportError, exc:
701        print(str(exc))
702        print("Could not find model; use one of:\n    " + models)
703        sys.exit(1)
704
705    invalid = [o[1:] for o in flags
706               if o[1:] not in NAME_OPTIONS
707               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
708    if invalid:
709        print("Invalid options: %s"%(", ".join(invalid)))
710        sys.exit(1)
711
712
713    # pylint: disable=bad-whitespace
714    # Interpret the flags
715    opts = {
716        'plot'      : True,
717        'view'      : 'log',
718        'is2d'      : False,
719        'qmax'      : 0.05,
720        'nq'        : 128,
721        'res'       : 0.0,
722        'accuracy'  : 'Low',
723        'cutoff'    : 0.0,
724        'seed'      : -1,  # default to preset
725        'mono'      : False,
726        'show_pars' : False,
727        'show_hist' : False,
728        'rel_err'   : True,
729        'explore'   : False,
730        'use_demo'  : True,
731    }
732    engines = []
733    for arg in flags:
734        if arg == '-noplot':    opts['plot'] = False
735        elif arg == '-plot':    opts['plot'] = True
736        elif arg == '-linear':  opts['view'] = 'linear'
737        elif arg == '-log':     opts['view'] = 'log'
738        elif arg == '-q4':      opts['view'] = 'q4'
739        elif arg == '-1d':      opts['is2d'] = False
740        elif arg == '-2d':      opts['is2d'] = True
741        elif arg == '-exq':     opts['qmax'] = 10.0
742        elif arg == '-highq':   opts['qmax'] = 1.0
743        elif arg == '-midq':    opts['qmax'] = 0.2
744        elif arg == '-lowq':    opts['qmax'] = 0.05
745        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
746        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
747        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
748        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
749        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
750        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
751        elif arg == '-preset':  opts['seed'] = -1
752        elif arg == '-mono':    opts['mono'] = True
753        elif arg == '-poly':    opts['mono'] = False
754        elif arg == '-pars':    opts['show_pars'] = True
755        elif arg == '-nopars':  opts['show_pars'] = False
756        elif arg == '-hist':    opts['show_hist'] = True
757        elif arg == '-nohist':  opts['show_hist'] = False
758        elif arg == '-rel':     opts['rel_err'] = True
759        elif arg == '-abs':     opts['rel_err'] = False
760        elif arg == '-half':    engines.append(arg[1:])
761        elif arg == '-fast':    engines.append(arg[1:])
762        elif arg == '-single':  engines.append(arg[1:])
763        elif arg == '-double':  engines.append(arg[1:])
764        elif arg == '-single!': engines.append(arg[1:])
765        elif arg == '-double!': engines.append(arg[1:])
766        elif arg == '-quad!':   engines.append(arg[1:])
767        elif arg == '-sasview': engines.append(arg[1:])
768        elif arg == '-edit':    opts['explore'] = True
769        elif arg == '-demo':    opts['use_demo'] = True
770        elif arg == '-default':    opts['use_demo'] = False
771    # pylint: enable=bad-whitespace
772
773    if len(engines) == 0:
774        engines.extend(['single', 'double'])
775    elif len(engines) == 1:
776        if engines[0][0] != 'double':
777            engines.append('double')
778        else:
779            engines.append('single')
780    elif len(engines) > 2:
781        del engines[2:]
782
783    n1 = int(args[1]) if len(args) > 1 else 1
784    n2 = int(args[2]) if len(args) > 2 else 1
785    use_sasview = any(engine=='sasview' and count>0
786                      for engine, count in zip(engines, [n1, n2]))
787
788    # Get demo parameters from model definition, or use default parameters
789    # if model does not define demo parameters
790    pars = get_pars(model_info, opts['use_demo'])
791
792
793    # Fill in parameters given on the command line
794    presets = {}
795    for arg in values:
796        k, v = arg.split('=', 1)
797        if k not in pars:
798            # extract base name without polydispersity info
799            s = set(p.split('_pd')[0] for p in pars)
800            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
801            sys.exit(1)
802        presets[k] = float(v) if not k.endswith('type') else v
803
804    # randomize parameters
805    #pars.update(set_pars)  # set value before random to control range
806    if opts['seed'] > -1:
807        pars = randomize_pars(pars, seed=opts['seed'])
808        print("Randomize using -random=%i"%opts['seed'])
809    if opts['mono']:
810        pars = suppress_pd(pars)
811    pars.update(presets)  # set value after random to control value
812    #import pprint; pprint.pprint(model_info)
813    constrain_pars(model_info, pars)
814    if use_sasview:
815        constrain_new_to_old(model_info, pars)
816    if opts['show_pars']:
817        print(str(parlist(model_info, pars, opts['is2d'])))
818
819    # Create the computational engines
820    data, _ = make_data(opts)
821    if n1:
822        base = make_engine(model_info, data, engines[0], opts['cutoff'])
823    else:
824        base = None
825    if n2:
826        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
827    else:
828        comp = None
829
830    # pylint: disable=bad-whitespace
831    # Remember it all
832    opts.update({
833        'name'      : name,
834        'def'       : model_info,
835        'n1'        : n1,
836        'n2'        : n2,
837        'presets'   : presets,
838        'pars'      : pars,
839        'data'      : data,
840        'engines'   : [base, comp],
841    })
842    # pylint: enable=bad-whitespace
843
844    return opts
845
846def explore(opts):
847    """
848    Explore the model using the Bumps GUI.
849    """
850    import wx
851    from bumps.names import FitProblem
852    from bumps.gui.app_frame import AppFrame
853
854    problem = FitProblem(Explore(opts))
855    is_mac = "cocoa" in wx.version()
856    app = wx.App()
857    frame = AppFrame(parent=None, title="explore")
858    if not is_mac: frame.Show()
859    frame.panel.set_model(model=problem)
860    frame.panel.Layout()
861    frame.panel.aui.Split(0, wx.TOP)
862    if is_mac: frame.Show()
863    app.MainLoop()
864
865class Explore(object):
866    """
867    Bumps wrapper for a SAS model comparison.
868
869    The resulting object can be used as a Bumps fit problem so that
870    parameters can be adjusted in the GUI, with plots updated on the fly.
871    """
872    def __init__(self, opts):
873        from bumps.cli import config_matplotlib
874        from . import bumps_model
875        config_matplotlib()
876        self.opts = opts
877        model_info = opts['def']
878        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
879        if not opts['is2d']:
880            for p in model_info['parameters'].type['1d']:
881                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
882                    k = p.name+ext
883                    v = pars.get(k, None)
884                    if v is not None:
885                        v.range(*parameter_range(k, v.value))
886        else:
887            for k, v in pars.items():
888                v.range(*parameter_range(k, v.value))
889
890        self.pars = pars
891        self.pd_types = pd_types
892        self.limits = None
893
894    def numpoints(self):
895        """
896        Return the number of points.
897        """
898        return len(self.pars) + 1  # so dof is 1
899
900    def parameters(self):
901        """
902        Return a dictionary of parameters.
903        """
904        return self.pars
905
906    def nllf(self):
907        """
908        Return cost.
909        """
910        # pylint: disable=no-self-use
911        return 0.  # No nllf
912
913    def plot(self, view='log'):
914        """
915        Plot the data and residuals.
916        """
917        pars = dict((k, v.value) for k, v in self.pars.items())
918        pars.update(self.pd_types)
919        self.opts['pars'] = pars
920        limits = compare(self.opts, limits=self.limits)
921        if self.limits is None:
922            vmin, vmax = limits
923            vmax = 1.3*vmax
924            vmin = vmax*1e-7
925            self.limits = vmin, vmax
926
927
928def main():
929    """
930    Main program.
931    """
932    opts = parse_opts()
933    if opts['explore']:
934        explore(opts)
935    else:
936        compare(opts)
937
938if __name__ == "__main__":
939    main()
Note: See TracBrowser for help on using the repository browser.