source: sasmodels/sasmodels/compare.py @ dd7fc12

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since dd7fc12 was dd7fc12, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix kerneldll dtype problem; more type hinting

  • Property mode set to 100755
File size: 33.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from .data import plot_theory, empty_data1D, empty_data2D
41from .direct_model import DirectModel
42from .convert import revert_name, revert_pars, constrain_new_to_old
43
44try:
45    from typing import Optional, Dict, Any, Callable, Tuple
46except:
47    pass
48else:
49    from .modelinfo import ModelInfo, Parameter, ParameterSet
50    from .data import Data
51    Calculator = Callable[[float, ...], np.ndarray]
52
53USAGE = """
54usage: compare.py model N1 N2 [options...] [key=val]
55
56Compare the speed and value for a model between the SasView original and the
57sasmodels rewrite.
58
59model is the name of the model to compare (see below).
60N1 is the number of times to run sasmodels (default=1).
61N2 is the number times to run sasview (default=1).
62
63Options (* for default):
64
65    -plot*/-noplot plots or suppress the plot of the model
66    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
67    -nq=128 sets the number of Q points in the data set
68    -zero indicates that q=0 should be included
69    -1d*/-2d computes 1d or 2d data
70    -preset*/-random[=seed] preset or random parameters
71    -mono/-poly* force monodisperse/polydisperse
72    -cutoff=1e-5* cutoff value for including a point in polydispersity
73    -pars/-nopars* prints the parameter set or not
74    -abs/-rel* plot relative or absolute error
75    -linear/-log*/-q4 intensity scaling
76    -hist/-nohist* plot histogram of relative error
77    -res=0 sets the resolution width dQ/Q if calculating with resolution
78    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
79    -edit starts the parameter explorer
80    -default/-demo* use demo vs default parameters
81
82Any two calculation engines can be selected for comparison:
83
84    -single/-double/-half/-fast sets an OpenCL calculation engine
85    -single!/-double!/-quad! sets an OpenMP calculation engine
86    -sasview sets the sasview calculation engine
87
88The default is -single -sasview.  Note that the interpretation of quad
89precision depends on architecture, and may vary from 64-bit to 128-bit,
90with 80-bit floats being common (1e-19 precision).
91
92Key=value pairs allow you to set specific values for the model parameters.
93"""
94
95# Update docs with command line usage string.   This is separate from the usual
96# doc string so that we can display it at run time if there is an error.
97# lin
98__doc__ = (__doc__  # pylint: disable=redefined-builtin
99           + """
100Program description
101-------------------
102
103"""
104           + USAGE)
105
106kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
107
108# CRUFT python 2.6
109if not hasattr(datetime.timedelta, 'total_seconds'):
110    def delay(dt):
111        """Return number date-time delta as number seconds"""
112        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
113else:
114    def delay(dt):
115        """Return number date-time delta as number seconds"""
116        return dt.total_seconds()
117
118
119class push_seed(object):
120    """
121    Set the seed value for the random number generator.
122
123    When used in a with statement, the random number generator state is
124    restored after the with statement is complete.
125
126    :Parameters:
127
128    *seed* : int or array_like, optional
129        Seed for RandomState
130
131    :Example:
132
133    Seed can be used directly to set the seed::
134
135        >>> from numpy.random import randint
136        >>> push_seed(24)
137        <...push_seed object at...>
138        >>> print(randint(0,1000000,3))
139        [242082    899 211136]
140
141    Seed can also be used in a with statement, which sets the random
142    number generator state for the enclosed computations and restores
143    it to the previous state on completion::
144
145        >>> with push_seed(24):
146        ...    print(randint(0,1000000,3))
147        [242082    899 211136]
148
149    Using nested contexts, we can demonstrate that state is indeed
150    restored after the block completes::
151
152        >>> with push_seed(24):
153        ...    print(randint(0,1000000))
154        ...    with push_seed(24):
155        ...        print(randint(0,1000000,3))
156        ...    print(randint(0,1000000))
157        242082
158        [242082    899 211136]
159        899
160
161    The restore step is protected against exceptions in the block::
162
163        >>> with push_seed(24):
164        ...    print(randint(0,1000000))
165        ...    try:
166        ...        with push_seed(24):
167        ...            print(randint(0,1000000,3))
168        ...            raise Exception()
169        ...    except Exception:
170        ...        print("Exception raised")
171        ...    print(randint(0,1000000))
172        242082
173        [242082    899 211136]
174        Exception raised
175        899
176    """
177    def __init__(self, seed=None):
178        # type: (Optional[int]) -> None
179        self._state = np.random.get_state()
180        np.random.seed(seed)
181
182    def __enter__(self):
183        # type: () -> None
184        pass
185
186    def __exit__(self, type, value, traceback):
187        # type: (Any, BaseException, Any) -> None
188        # TODO: better typing for __exit__ method
189        np.random.set_state(self._state)
190
191def tic():
192    # type: () -> Callable[[], float]
193    """
194    Timer function.
195
196    Use "toc=tic()" to start the clock and "toc()" to measure
197    a time interval.
198    """
199    then = datetime.datetime.now()
200    return lambda: delay(datetime.datetime.now() - then)
201
202
203def set_beam_stop(data, radius, outer=None):
204    # type: (Data, float, float) -> None
205    """
206    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
207
208    Note: this function does not require sasview
209    """
210    if hasattr(data, 'qx_data'):
211        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
212        data.mask = (q < radius)
213        if outer is not None:
214            data.mask |= (q >= outer)
215    else:
216        data.mask = (data.x < radius)
217        if outer is not None:
218            data.mask |= (data.x >= outer)
219
220
221def parameter_range(p, v):
222    # type: (str, float) -> Tuple[float, float]
223    """
224    Choose a parameter range based on parameter name and initial value.
225    """
226    # process the polydispersity options
227    if p.endswith('_pd_n'):
228        return 0., 100.
229    elif p.endswith('_pd_nsigma'):
230        return 0., 5.
231    elif p.endswith('_pd_type'):
232        raise ValueError("Cannot return a range for a string value")
233    elif any(s in p for s in ('theta', 'phi', 'psi')):
234        # orientation in [-180,180], orientation pd in [0,45]
235        if p.endswith('_pd'):
236            return 0., 45.
237        else:
238            return -180., 180.
239    elif p.endswith('_pd'):
240        return 0., 1.
241    elif 'sld' in p:
242        return -0.5, 10.
243    elif p == 'background':
244        return 0., 10.
245    elif p == 'scale':
246        return 0., 1.e3
247    elif v < 0.:
248        return 2.*v, -2.*v
249    else:
250        return 0., (2.*v if v > 0. else 1.)
251
252
253def _randomize_one(model_info, p, v):
254    # type: (ModelInfo, str, float) -> float
255    # type: (ModelInfo, str, str) -> str
256    """
257    Randomize a single parameter.
258    """
259    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
260        return v
261
262    # Find the parameter definition
263    for par in model_info.parameters.call_parameters:
264        if par.name == p:
265            break
266    else:
267        raise ValueError("unknown parameter %r"%p)
268
269    if len(par.limits) > 2:  # choice list
270        return np.random.randint(len(par.limits))
271
272    limits = par.limits
273    if not np.isfinite(limits).all():
274        low, high = parameter_range(p, v)
275        limits = (max(limits[0], low), min(limits[1], high))
276
277    return np.random.uniform(*limits)
278
279
280def randomize_pars(model_info, pars, seed=None):
281    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
282    """
283    Generate random values for all of the parameters.
284
285    Valid ranges for the random number generator are guessed from the name of
286    the parameter; this will not account for constraints such as cap radius
287    greater than cylinder radius in the capped_cylinder model, so
288    :func:`constrain_pars` needs to be called afterward..
289    """
290    with push_seed(seed):
291        # Note: the sort guarantees order `of calls to random number generator
292        random_pars = dict((p, _randomize_one(model_info, p, v))
293                           for p, v in sorted(pars.items()))
294    return random_pars
295
296def constrain_pars(model_info, pars):
297    # type: (ModelInfo, ParameterSet) -> None
298    """
299    Restrict parameters to valid values.
300
301    This includes model specific code for models such as capped_cylinder
302    which need to support within model constraints (cap radius more than
303    cylinder radius in this case).
304
305    Warning: this updates the *pars* dictionary in place.
306    """
307    name = model_info.id
308    # if it is a product model, then just look at the form factor since
309    # none of the structure factors need any constraints.
310    if '*' in name:
311        name = name.split('*')[0]
312
313    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
314        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
315    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
316        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
317
318    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
319    if name == 'guinier':
320        #q_max = 0.2  # mid q maximum
321        q_max = 1.0  # high q maximum
322        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
323        pars['rg'] = min(pars['rg'], rg_max)
324
325    if name == 'rpa':
326        # Make sure phi sums to 1.0
327        if pars['case_num'] < 2:
328            pars['Phi1'] = 0.
329            pars['Phi2'] = 0.
330        elif pars['case_num'] < 5:
331            pars['Phi1'] = 0.
332        total = sum(pars['Phi'+c] for c in '1234')
333        for c in '1234':
334            pars['Phi'+c] /= total
335
336def parlist(model_info, pars, is2d):
337    # type: (ModelInfo, ParameterSet, bool) -> str
338    """
339    Format the parameter list for printing.
340    """
341    lines = []
342    parameters = model_info.parameters
343    for p in parameters.user_parameters(pars, is2d):
344        fields = dict(
345            value=pars.get(p.id, p.default),
346            pd=pars.get(p.id+"_pd", 0.),
347            n=int(pars.get(p.id+"_pd_n", 0)),
348            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
349            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
350        )
351        lines.append(_format_par(p.name, **fields))
352    return "\n".join(lines)
353
354    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
355
356def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian'):
357    # type: (str, float, float, int, float, str) -> str
358    line = "%s: %g"%(name, value)
359    if pd != 0.  and n != 0:
360        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
361                % (pd, n, nsigma, nsigma, pdtype)
362    return line
363
364def suppress_pd(pars):
365    # type: (ParameterSet) -> ParameterSet
366    """
367    Suppress theta_pd for now until the normalization is resolved.
368
369    May also suppress complete polydispersity of the model to test
370    models more quickly.
371    """
372    pars = pars.copy()
373    for p in pars:
374        if p.endswith("_pd_n"): pars[p] = 0
375    return pars
376
377def eval_sasview(model_info, data):
378    # type: (Modelinfo, Data) -> Calculator
379    """
380    Return a model calculator using the pre-4.0 SasView models.
381    """
382    # importing sas here so that the error message will be that sas failed to
383    # import rather than the more obscure smear_selection not imported error
384    import sas
385    from sas.models.qsmearing import smear_selection
386    import sas.models
387
388    def get_model(name):
389        # type: (str) -> "sas.models.BaseComponent"
390        #print("new",sorted(_pars.items()))
391        __import__('sas.models.' + name)
392        ModelClass = getattr(getattr(sas.models, name, None), name, None)
393        if ModelClass is None:
394            raise ValueError("could not find model %r in sas.models"%name)
395        return ModelClass()
396
397    # grab the sasview model, or create it if it is a product model
398    if model_info.composition:
399        composition_type, parts = model_info.composition
400        if composition_type == 'product':
401            from sas.models.MultiplicationModel import MultiplicationModel
402            P, S = [get_model(revert_name(p)) for p in parts]
403            model = MultiplicationModel(P, S)
404        else:
405            raise ValueError("sasview mixture models not supported by compare")
406    else:
407        model = get_model(revert_name(model_info))
408
409    # build a smearer with which to call the model, if necessary
410    smearer = smear_selection(data, model=model)
411    if hasattr(data, 'qx_data'):
412        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
413        index = ((~data.mask) & (~np.isnan(data.data))
414                 & (q >= data.qmin) & (q <= data.qmax))
415        if smearer is not None:
416            smearer.model = model  # because smear_selection has a bug
417            smearer.accuracy = data.accuracy
418            smearer.set_index(index)
419            theory = lambda: smearer.get_value()
420        else:
421            theory = lambda: model.evalDistribution([data.qx_data[index],
422                                                     data.qy_data[index]])
423    elif smearer is not None:
424        theory = lambda: smearer(model.evalDistribution(data.x))
425    else:
426        theory = lambda: model.evalDistribution(data.x)
427
428    def calculator(**pars):
429        # type: (float, ...) -> np.ndarray
430        """
431        Sasview calculator for model.
432        """
433        # paying for parameter conversion each time to keep life simple, if not fast
434        pars = revert_pars(model_info, pars)
435        for k, v in pars.items():
436            name_attr = k.split('.')  # polydispersity components
437            if len(name_attr) == 2:
438                model.dispersion[name_attr[0]][name_attr[1]] = v
439            else:
440                model.setParam(k, v)
441        return theory()
442
443    calculator.engine = "sasview"
444    return calculator
445
446DTYPE_MAP = {
447    'half': '16',
448    'fast': 'fast',
449    'single': '32',
450    'double': '64',
451    'quad': '128',
452    'f16': '16',
453    'f32': '32',
454    'f64': '64',
455    'longdouble': '128',
456}
457def eval_opencl(model_info, data, dtype='single', cutoff=0.):
458    # type: (ModelInfo, Data, str, float) -> Calculator
459    """
460    Return a model calculator using the OpenCL calculation engine.
461    """
462    try:
463        model = core.build_model(model_info, dtype=dtype, platform="ocl")
464    except Exception as exc:
465        print(exc)
466        print("... trying again with single precision")
467        model = core.build_model(model_info, dtype='single', platform="ocl")
468    calculator = DirectModel(data, model, cutoff=cutoff)
469    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
470    return calculator
471
472def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
473    # type: (ModelInfo, Data, str, float) -> Calculator
474    """
475    Return a model calculator using the DLL calculation engine.
476    """
477    model = core.build_model(model_info, dtype=dtype, platform="dll")
478    calculator = DirectModel(data, model, cutoff=cutoff)
479    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
480    return calculator
481
482def time_calculation(calculator, pars, Nevals=1):
483    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
484    """
485    Compute the average calculation time over N evaluations.
486
487    An additional call is generated without polydispersity in order to
488    initialize the calculation engine, and make the average more stable.
489    """
490    # initialize the code so time is more accurate
491    if Nevals > 1:
492        calculator(**suppress_pd(pars))
493    toc = tic()
494    # make sure there is at least one eval
495    value = calculator(**pars)
496    for _ in range(Nevals-1):
497        value = calculator(**pars)
498    average_time = toc()*1000./Nevals
499    return value, average_time
500
501def make_data(opts):
502    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
503    """
504    Generate an empty dataset, used with the model to set Q points
505    and resolution.
506
507    *opts* contains the options, with 'qmax', 'nq', 'res',
508    'accuracy', 'is2d' and 'view' parsed from the command line.
509    """
510    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
511    if opts['is2d']:
512        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
513        data = empty_data2D(q, resolution=res)
514        data.accuracy = opts['accuracy']
515        set_beam_stop(data, 0.0004)
516        index = ~data.mask
517    else:
518        if opts['view'] == 'log' and not opts['zero']:
519            qmax = math.log10(qmax)
520            q = np.logspace(qmax-3, qmax, nq)
521        else:
522            q = np.linspace(0.001*qmax, qmax, nq)
523        if opts['zero']:
524            q = np.hstack((0, q))
525        data = empty_data1D(q, resolution=res)
526        index = slice(None, None)
527    return data, index
528
529def make_engine(model_info, data, dtype, cutoff):
530    # type: (ModelInfo, Data, str, float) -> Calculator
531    """
532    Generate the appropriate calculation engine for the given datatype.
533
534    Datatypes with '!' appended are evaluated using external C DLLs rather
535    than OpenCL.
536    """
537    if dtype == 'sasview':
538        return eval_sasview(model_info, data)
539    elif dtype.endswith('!'):
540        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
541    else:
542        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
543
544def _show_invalid(data, theory):
545    # type: (Data, np.ma.ndarray) -> None
546    """
547    Display a list of the non-finite values in theory.
548    """
549    if not theory.mask.any():
550        return
551
552    if hasattr(data, 'x'):
553        bad = zip(data.x[theory.mask], theory[theory.mask])
554        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
555
556
557def compare(opts, limits=None):
558    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
559    """
560    Preform a comparison using options from the command line.
561
562    *limits* are the limits on the values to use, either to set the y-axis
563    for 1D or to set the colormap scale for 2D.  If None, then they are
564    inferred from the data and returned. When exploring using Bumps,
565    the limits are set when the model is initially called, and maintained
566    as the values are adjusted, making it easier to see the effects of the
567    parameters.
568    """
569    Nbase, Ncomp = opts['n1'], opts['n2']
570    pars = opts['pars']
571    data = opts['data']
572
573    # silence the linter
574    base = opts['engines'][0] if Nbase else None
575    comp = opts['engines'][1] if Ncomp else None
576    base_time = comp_time = None
577    base_value = comp_value = resid = relerr = None
578
579    # Base calculation
580    if Nbase > 0:
581        try:
582            base_raw, base_time = time_calculation(base, pars, Nbase)
583            base_value = np.ma.masked_invalid(base_raw)
584            print("%s t=%.2f ms, intensity=%.0f"
585                  % (base.engine, base_time, base_value.sum()))
586            _show_invalid(data, base_value)
587        except ImportError:
588            traceback.print_exc()
589            Nbase = 0
590
591    # Comparison calculation
592    if Ncomp > 0:
593        try:
594            comp_raw, comp_time = time_calculation(comp, pars, Ncomp)
595            comp_value = np.ma.masked_invalid(comp_raw)
596            print("%s t=%.2f ms, intensity=%.0f"
597                  % (comp.engine, comp_time, comp_value.sum()))
598            _show_invalid(data, comp_value)
599        except ImportError:
600            traceback.print_exc()
601            Ncomp = 0
602
603    # Compare, but only if computing both forms
604    if Nbase > 0 and Ncomp > 0:
605        resid = (base_value - comp_value)
606        relerr = resid/comp_value
607        _print_stats("|%s-%s|"
608                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
609                     resid)
610        _print_stats("|(%s-%s)/%s|"
611                     % (base.engine, comp.engine, comp.engine),
612                     relerr)
613
614    # Plot if requested
615    if not opts['plot'] and not opts['explore']: return
616    view = opts['view']
617    import matplotlib.pyplot as plt
618    if limits is None:
619        vmin, vmax = np.Inf, -np.Inf
620        if Nbase > 0:
621            vmin = min(vmin, base_value.min())
622            vmax = max(vmax, base_value.max())
623        if Ncomp > 0:
624            vmin = min(vmin, comp_value.min())
625            vmax = max(vmax, comp_value.max())
626        limits = vmin, vmax
627
628    if Nbase > 0:
629        if Ncomp > 0: plt.subplot(131)
630        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
631        plt.title("%s t=%.2f ms"%(base.engine, base_time))
632        #cbar_title = "log I"
633    if Ncomp > 0:
634        if Nbase > 0: plt.subplot(132)
635        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
636        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
637        #cbar_title = "log I"
638    if Ncomp > 0 and Nbase > 0:
639        plt.subplot(133)
640        if not opts['rel_err']:
641            err, errstr, errview = resid, "abs err", "linear"
642        else:
643            err, errstr, errview = abs(relerr), "rel err", "log"
644        #err,errstr = base/comp,"ratio"
645        plot_theory(data, None, resid=err, view=errview, use_data=False)
646        if view == 'linear':
647            plt.xscale('linear')
648        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
649        #cbar_title = errstr if errview=="linear" else "log "+errstr
650    #if is2D:
651    #    h = plt.colorbar()
652    #    h.ax.set_title(cbar_title)
653
654    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
655        plt.figure()
656        v = relerr
657        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
658        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
659        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
660                   % (base.engine, comp.engine, comp.engine))
661        plt.ylabel('P(err)')
662        plt.title('Distribution of relative error between calculation engines')
663
664    if not opts['explore']:
665        plt.show()
666
667    return limits
668
669def _print_stats(label, err):
670    # type: (str, np.ma.ndarray) -> None
671    # work with trimmed data, not the full set
672    sorted_err = np.sort(abs(err.compressed()))
673    p50 = int((len(sorted_err)-1)*0.50)
674    p98 = int((len(sorted_err)-1)*0.98)
675    data = [
676        "max:%.3e"%sorted_err[-1],
677        "median:%.3e"%sorted_err[p50],
678        "98%%:%.3e"%sorted_err[p98],
679        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
680        "zero-offset:%+.3e"%np.mean(sorted_err),
681        ]
682    print(label+"  "+"  ".join(data))
683
684
685
686# ===========================================================================
687#
688NAME_OPTIONS = set([
689    'plot', 'noplot',
690    'half', 'fast', 'single', 'double',
691    'single!', 'double!', 'quad!', 'sasview',
692    'lowq', 'midq', 'highq', 'exq', 'zero',
693    '2d', '1d',
694    'preset', 'random',
695    'poly', 'mono',
696    'nopars', 'pars',
697    'rel', 'abs',
698    'linear', 'log', 'q4',
699    'hist', 'nohist',
700    'edit',
701    'demo', 'default',
702    ])
703VALUE_OPTIONS = [
704    # Note: random is both a name option and a value option
705    'cutoff', 'random', 'nq', 'res', 'accuracy',
706    ]
707
708def columnize(L, indent="", width=79):
709    # type: (List[str], str, int) -> str
710    """
711    Format a list of strings into columns.
712
713    Returns a string with carriage returns ready for printing.
714    """
715    column_width = max(len(w) for w in L) + 1
716    num_columns = (width - len(indent)) // column_width
717    num_rows = len(L) // num_columns
718    L = L + [""] * (num_rows*num_columns - len(L))
719    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
720    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
721             for row in zip(*columns)]
722    output = indent + ("\n"+indent).join(lines)
723    return output
724
725
726def get_pars(model_info, use_demo=False):
727    # type: (ModelInfo, bool) -> ParameterSet
728    """
729    Extract demo parameters from the model definition.
730    """
731    # Get the default values for the parameters
732    pars = {}
733    for p in model_info.parameters.call_parameters:
734        parts = [('', p.default)]
735        if p.polydisperse:
736            parts.append(('_pd', 0.0))
737            parts.append(('_pd_n', 0))
738            parts.append(('_pd_nsigma', 3.0))
739            parts.append(('_pd_type', "gaussian"))
740        for ext, val in parts:
741            if p.length > 1:
742                dict(("%s%d%s"%(p.id,k,ext), val) for k in range(p.length))
743            else:
744                pars[p.id+ext] = val
745
746    # Plug in values given in demo
747    if use_demo:
748        pars.update(model_info.demo)
749    return pars
750
751
752def parse_opts():
753    # type: () -> Dict[str, Any]
754    """
755    Parse command line options.
756    """
757    MODELS = core.list_models()
758    flags = [arg for arg in sys.argv[1:]
759             if arg.startswith('-')]
760    values = [arg for arg in sys.argv[1:]
761              if not arg.startswith('-') and '=' in arg]
762    args = [arg for arg in sys.argv[1:]
763            if not arg.startswith('-') and '=' not in arg]
764    models = "\n    ".join("%-15s"%v for v in MODELS)
765    if len(args) == 0:
766        print(USAGE)
767        print("\nAvailable models:")
768        print(columnize(MODELS, indent="  "))
769        sys.exit(1)
770    if len(args) > 3:
771        print("expected parameters: model N1 N2")
772
773    name = args[0]
774    try:
775        model_info = core.load_model_info(name)
776    except ImportError as exc:
777        print(str(exc))
778        print("Could not find model; use one of:\n    " + models)
779        sys.exit(1)
780
781    invalid = [o[1:] for o in flags
782               if o[1:] not in NAME_OPTIONS
783               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
784    if invalid:
785        print("Invalid options: %s"%(", ".join(invalid)))
786        sys.exit(1)
787
788
789    # pylint: disable=bad-whitespace
790    # Interpret the flags
791    opts = {
792        'plot'      : True,
793        'view'      : 'log',
794        'is2d'      : False,
795        'qmax'      : 0.05,
796        'nq'        : 128,
797        'res'       : 0.0,
798        'accuracy'  : 'Low',
799        'cutoff'    : 0.0,
800        'seed'      : -1,  # default to preset
801        'mono'      : False,
802        'show_pars' : False,
803        'show_hist' : False,
804        'rel_err'   : True,
805        'explore'   : False,
806        'use_demo'  : True,
807        'zero'      : False,
808    }
809    engines = []
810    for arg in flags:
811        if arg == '-noplot':    opts['plot'] = False
812        elif arg == '-plot':    opts['plot'] = True
813        elif arg == '-linear':  opts['view'] = 'linear'
814        elif arg == '-log':     opts['view'] = 'log'
815        elif arg == '-q4':      opts['view'] = 'q4'
816        elif arg == '-1d':      opts['is2d'] = False
817        elif arg == '-2d':      opts['is2d'] = True
818        elif arg == '-exq':     opts['qmax'] = 10.0
819        elif arg == '-highq':   opts['qmax'] = 1.0
820        elif arg == '-midq':    opts['qmax'] = 0.2
821        elif arg == '-lowq':    opts['qmax'] = 0.05
822        elif arg == '-zero':    opts['zero'] = True
823        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
824        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
825        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
826        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
827        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
828        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
829        elif arg == '-preset':  opts['seed'] = -1
830        elif arg == '-mono':    opts['mono'] = True
831        elif arg == '-poly':    opts['mono'] = False
832        elif arg == '-pars':    opts['show_pars'] = True
833        elif arg == '-nopars':  opts['show_pars'] = False
834        elif arg == '-hist':    opts['show_hist'] = True
835        elif arg == '-nohist':  opts['show_hist'] = False
836        elif arg == '-rel':     opts['rel_err'] = True
837        elif arg == '-abs':     opts['rel_err'] = False
838        elif arg == '-half':    engines.append(arg[1:])
839        elif arg == '-fast':    engines.append(arg[1:])
840        elif arg == '-single':  engines.append(arg[1:])
841        elif arg == '-double':  engines.append(arg[1:])
842        elif arg == '-single!': engines.append(arg[1:])
843        elif arg == '-double!': engines.append(arg[1:])
844        elif arg == '-quad!':   engines.append(arg[1:])
845        elif arg == '-sasview': engines.append(arg[1:])
846        elif arg == '-edit':    opts['explore'] = True
847        elif arg == '-demo':    opts['use_demo'] = True
848        elif arg == '-default':    opts['use_demo'] = False
849    # pylint: enable=bad-whitespace
850
851    if len(engines) == 0:
852        engines.extend(['single', 'double'])
853    elif len(engines) == 1:
854        if engines[0][0] != 'double':
855            engines.append('double')
856        else:
857            engines.append('single')
858    elif len(engines) > 2:
859        del engines[2:]
860
861    n1 = int(args[1]) if len(args) > 1 else 1
862    n2 = int(args[2]) if len(args) > 2 else 1
863    use_sasview = any(engine=='sasview' and count>0
864                      for engine, count in zip(engines, [n1, n2]))
865
866    # Get demo parameters from model definition, or use default parameters
867    # if model does not define demo parameters
868    pars = get_pars(model_info, opts['use_demo'])
869
870
871    # Fill in parameters given on the command line
872    presets = {}
873    for arg in values:
874        k, v = arg.split('=', 1)
875        if k not in pars:
876            # extract base name without polydispersity info
877            s = set(p.split('_pd')[0] for p in pars)
878            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
879            sys.exit(1)
880        presets[k] = float(v) if not k.endswith('type') else v
881
882    # randomize parameters
883    #pars.update(set_pars)  # set value before random to control range
884    if opts['seed'] > -1:
885        pars = randomize_pars(model_info, pars, seed=opts['seed'])
886        print("Randomize using -random=%i"%opts['seed'])
887    if opts['mono']:
888        pars = suppress_pd(pars)
889    pars.update(presets)  # set value after random to control value
890    #import pprint; pprint.pprint(model_info)
891    constrain_pars(model_info, pars)
892    if use_sasview:
893        constrain_new_to_old(model_info, pars)
894    if opts['show_pars']:
895        print(str(parlist(model_info, pars, opts['is2d'])))
896
897    # Create the computational engines
898    data, _ = make_data(opts)
899    if n1:
900        base = make_engine(model_info, data, engines[0], opts['cutoff'])
901    else:
902        base = None
903    if n2:
904        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
905    else:
906        comp = None
907
908    # pylint: disable=bad-whitespace
909    # Remember it all
910    opts.update({
911        'name'      : name,
912        'def'       : model_info,
913        'n1'        : n1,
914        'n2'        : n2,
915        'presets'   : presets,
916        'pars'      : pars,
917        'data'      : data,
918        'engines'   : [base, comp],
919    })
920    # pylint: enable=bad-whitespace
921
922    return opts
923
924def explore(opts):
925    # type: (Dict[str, Any]) -> None
926    """
927    Explore the model using the Bumps GUI.
928    """
929    import wx  # type: ignore
930    from bumps.names import FitProblem  # type: ignore
931    from bumps.gui.app_frame import AppFrame  # type: ignore
932
933    problem = FitProblem(Explore(opts))
934    is_mac = "cocoa" in wx.version()
935    app = wx.App()
936    frame = AppFrame(parent=None, title="explore")
937    if not is_mac: frame.Show()
938    frame.panel.set_model(model=problem)
939    frame.panel.Layout()
940    frame.panel.aui.Split(0, wx.TOP)
941    if is_mac: frame.Show()
942    app.MainLoop()
943
944class Explore(object):
945    """
946    Bumps wrapper for a SAS model comparison.
947
948    The resulting object can be used as a Bumps fit problem so that
949    parameters can be adjusted in the GUI, with plots updated on the fly.
950    """
951    def __init__(self, opts):
952        # type: (Dict[str, Any]) -> None
953        from bumps.cli import config_matplotlib  # type: ignore
954        from . import bumps_model
955        config_matplotlib()
956        self.opts = opts
957        model_info = opts['def']
958        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
959        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
960        if not opts['is2d']:
961            for p in model_info.parameters.user_parameters(is2d=False):
962                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
963                    k = p.name+ext
964                    v = pars.get(k, None)
965                    if v is not None:
966                        v.range(*parameter_range(k, v.value))
967        else:
968            for k, v in pars.items():
969                v.range(*parameter_range(k, v.value))
970
971        self.pars = pars
972        self.pd_types = pd_types
973        self.limits = None
974
975    def numpoints(self):
976        # type: () -> int
977        """
978        Return the number of points.
979        """
980        return len(self.pars) + 1  # so dof is 1
981
982    def parameters(self):
983        # type: () -> Any   # Dict/List hierarchy of parameters
984        """
985        Return a dictionary of parameters.
986        """
987        return self.pars
988
989    def nllf(self):
990        # type: () -> float
991        """
992        Return cost.
993        """
994        # pylint: disable=no-self-use
995        return 0.  # No nllf
996
997    def plot(self, view='log'):
998        # type: (str) -> None
999        """
1000        Plot the data and residuals.
1001        """
1002        pars = dict((k, v.value) for k, v in self.pars.items())
1003        pars.update(self.pd_types)
1004        self.opts['pars'] = pars
1005        limits = compare(self.opts, limits=self.limits)
1006        if self.limits is None:
1007            vmin, vmax = limits
1008            self.limits = vmax*1e-7, 1.3*vmax
1009
1010
1011def main():
1012    # type: () -> None
1013    """
1014    Main program.
1015    """
1016    opts = parse_opts()
1017    if opts['explore']:
1018        explore(opts)
1019    else:
1020        compare(opts)
1021
1022if __name__ == "__main__":
1023    main()
Note: See TracBrowser for help on using the repository browser.