source: sasmodels/sasmodels/compare.py @ f3bd37f

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since f3bd37f was f3bd37f, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix compare_many

  • Property mode set to 100755
File size: 34.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from .data import plot_theory, empty_data1D, empty_data2D
41from .direct_model import DirectModel
42from .convert import revert_name, revert_pars, constrain_new_to_old
43
44try:
45    from typing import Optional, Dict, Any, Callable, Tuple
46except:
47    pass
48else:
49    from .modelinfo import ModelInfo, Parameter, ParameterSet
50    from .data import Data
51    Calculator = Callable[[float], np.ndarray]
52
53USAGE = """
54usage: compare.py model N1 N2 [options...] [key=val]
55
56Compare the speed and value for a model between the SasView original and the
57sasmodels rewrite.
58
59model is the name of the model to compare (see below).
60N1 is the number of times to run sasmodels (default=1).
61N2 is the number times to run sasview (default=1).
62
63Options (* for default):
64
65    -plot*/-noplot plots or suppress the plot of the model
66    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
67    -nq=128 sets the number of Q points in the data set
68    -zero indicates that q=0 should be included
69    -1d*/-2d computes 1d or 2d data
70    -preset*/-random[=seed] preset or random parameters
71    -mono/-poly* force monodisperse/polydisperse
72    -cutoff=1e-5* cutoff value for including a point in polydispersity
73    -pars/-nopars* prints the parameter set or not
74    -abs/-rel* plot relative or absolute error
75    -linear/-log*/-q4 intensity scaling
76    -hist/-nohist* plot histogram of relative error
77    -res=0 sets the resolution width dQ/Q if calculating with resolution
78    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
79    -edit starts the parameter explorer
80    -default/-demo* use demo vs default parameters
81
82Any two calculation engines can be selected for comparison:
83
84    -single/-double/-half/-fast sets an OpenCL calculation engine
85    -single!/-double!/-quad! sets an OpenMP calculation engine
86    -sasview sets the sasview calculation engine
87
88The default is -single -sasview.  Note that the interpretation of quad
89precision depends on architecture, and may vary from 64-bit to 128-bit,
90with 80-bit floats being common (1e-19 precision).
91
92Key=value pairs allow you to set specific values for the model parameters.
93"""
94
95# Update docs with command line usage string.   This is separate from the usual
96# doc string so that we can display it at run time if there is an error.
97# lin
98__doc__ = (__doc__  # pylint: disable=redefined-builtin
99           + """
100Program description
101-------------------
102
103"""
104           + USAGE)
105
106kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
107
108# CRUFT python 2.6
109if not hasattr(datetime.timedelta, 'total_seconds'):
110    def delay(dt):
111        """Return number date-time delta as number seconds"""
112        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
113else:
114    def delay(dt):
115        """Return number date-time delta as number seconds"""
116        return dt.total_seconds()
117
118
119class push_seed(object):
120    """
121    Set the seed value for the random number generator.
122
123    When used in a with statement, the random number generator state is
124    restored after the with statement is complete.
125
126    :Parameters:
127
128    *seed* : int or array_like, optional
129        Seed for RandomState
130
131    :Example:
132
133    Seed can be used directly to set the seed::
134
135        >>> from numpy.random import randint
136        >>> push_seed(24)
137        <...push_seed object at...>
138        >>> print(randint(0,1000000,3))
139        [242082    899 211136]
140
141    Seed can also be used in a with statement, which sets the random
142    number generator state for the enclosed computations and restores
143    it to the previous state on completion::
144
145        >>> with push_seed(24):
146        ...    print(randint(0,1000000,3))
147        [242082    899 211136]
148
149    Using nested contexts, we can demonstrate that state is indeed
150    restored after the block completes::
151
152        >>> with push_seed(24):
153        ...    print(randint(0,1000000))
154        ...    with push_seed(24):
155        ...        print(randint(0,1000000,3))
156        ...    print(randint(0,1000000))
157        242082
158        [242082    899 211136]
159        899
160
161    The restore step is protected against exceptions in the block::
162
163        >>> with push_seed(24):
164        ...    print(randint(0,1000000))
165        ...    try:
166        ...        with push_seed(24):
167        ...            print(randint(0,1000000,3))
168        ...            raise Exception()
169        ...    except Exception:
170        ...        print("Exception raised")
171        ...    print(randint(0,1000000))
172        242082
173        [242082    899 211136]
174        Exception raised
175        899
176    """
177    def __init__(self, seed=None):
178        # type: (Optional[int]) -> None
179        self._state = np.random.get_state()
180        np.random.seed(seed)
181
182    def __enter__(self):
183        # type: () -> None
184        pass
185
186    def __exit__(self, type, value, traceback):
187        # type: (Any, BaseException, Any) -> None
188        # TODO: better typing for __exit__ method
189        np.random.set_state(self._state)
190
191def tic():
192    # type: () -> Callable[[], float]
193    """
194    Timer function.
195
196    Use "toc=tic()" to start the clock and "toc()" to measure
197    a time interval.
198    """
199    then = datetime.datetime.now()
200    return lambda: delay(datetime.datetime.now() - then)
201
202
203def set_beam_stop(data, radius, outer=None):
204    # type: (Data, float, float) -> None
205    """
206    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
207
208    Note: this function does not require sasview
209    """
210    if hasattr(data, 'qx_data'):
211        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
212        data.mask = (q < radius)
213        if outer is not None:
214            data.mask |= (q >= outer)
215    else:
216        data.mask = (data.x < radius)
217        if outer is not None:
218            data.mask |= (data.x >= outer)
219
220
221def parameter_range(p, v):
222    # type: (str, float) -> Tuple[float, float]
223    """
224    Choose a parameter range based on parameter name and initial value.
225    """
226    # process the polydispersity options
227    if p.endswith('_pd_n'):
228        return 0., 100.
229    elif p.endswith('_pd_nsigma'):
230        return 0., 5.
231    elif p.endswith('_pd_type'):
232        raise ValueError("Cannot return a range for a string value")
233    elif any(s in p for s in ('theta', 'phi', 'psi')):
234        # orientation in [-180,180], orientation pd in [0,45]
235        if p.endswith('_pd'):
236            return 0., 45.
237        else:
238            return -180., 180.
239    elif p.endswith('_pd'):
240        return 0., 1.
241    elif 'sld' in p:
242        return -0.5, 10.
243    elif p == 'background':
244        return 0., 10.
245    elif p == 'scale':
246        return 0., 1.e3
247    elif v < 0.:
248        return 2.*v, -2.*v
249    else:
250        return 0., (2.*v if v > 0. else 1.)
251
252
253def _randomize_one(model_info, p, v):
254    # type: (ModelInfo, str, float) -> float
255    # type: (ModelInfo, str, str) -> str
256    """
257    Randomize a single parameter.
258    """
259    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
260        return v
261
262    # Find the parameter definition
263    for par in model_info.parameters.call_parameters:
264        if par.name == p:
265            break
266    else:
267        raise ValueError("unknown parameter %r"%p)
268
269    if len(par.limits) > 2:  # choice list
270        return np.random.randint(len(par.limits))
271
272    limits = par.limits
273    if not np.isfinite(limits).all():
274        low, high = parameter_range(p, v)
275        limits = (max(limits[0], low), min(limits[1], high))
276
277    return np.random.uniform(*limits)
278
279
280def randomize_pars(model_info, pars, seed=None):
281    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
282    """
283    Generate random values for all of the parameters.
284
285    Valid ranges for the random number generator are guessed from the name of
286    the parameter; this will not account for constraints such as cap radius
287    greater than cylinder radius in the capped_cylinder model, so
288    :func:`constrain_pars` needs to be called afterward..
289    """
290    with push_seed(seed):
291        # Note: the sort guarantees order `of calls to random number generator
292        random_pars = dict((p, _randomize_one(model_info, p, v))
293                           for p, v in sorted(pars.items()))
294    return random_pars
295
296def constrain_pars(model_info, pars):
297    # type: (ModelInfo, ParameterSet) -> None
298    """
299    Restrict parameters to valid values.
300
301    This includes model specific code for models such as capped_cylinder
302    which need to support within model constraints (cap radius more than
303    cylinder radius in this case).
304
305    Warning: this updates the *pars* dictionary in place.
306    """
307    name = model_info.id
308    # if it is a product model, then just look at the form factor since
309    # none of the structure factors need any constraints.
310    if '*' in name:
311        name = name.split('*')[0]
312
313    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
314        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
315    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
316        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
317
318    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
319    if name == 'guinier':
320        #q_max = 0.2  # mid q maximum
321        q_max = 1.0  # high q maximum
322        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
323        pars['rg'] = min(pars['rg'], rg_max)
324
325    if name == 'rpa':
326        # Make sure phi sums to 1.0
327        if pars['case_num'] < 2:
328            pars['Phi1'] = 0.
329            pars['Phi2'] = 0.
330        elif pars['case_num'] < 5:
331            pars['Phi1'] = 0.
332        total = sum(pars['Phi'+c] for c in '1234')
333        for c in '1234':
334            pars['Phi'+c] /= total
335
336def parlist(model_info, pars, is2d):
337    # type: (ModelInfo, ParameterSet, bool) -> str
338    """
339    Format the parameter list for printing.
340    """
341    lines = []
342    parameters = model_info.parameters
343    for p in parameters.user_parameters(pars, is2d):
344        fields = dict(
345            value=pars.get(p.id, p.default),
346            pd=pars.get(p.id+"_pd", 0.),
347            n=int(pars.get(p.id+"_pd_n", 0)),
348            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
349            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
350        )
351        lines.append(_format_par(p.name, **fields))
352    return "\n".join(lines)
353
354    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
355
356def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian'):
357    # type: (str, float, float, int, float, str) -> str
358    line = "%s: %g"%(name, value)
359    if pd != 0.  and n != 0:
360        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
361                % (pd, n, nsigma, nsigma, pdtype)
362    return line
363
364def suppress_pd(pars):
365    # type: (ParameterSet) -> ParameterSet
366    """
367    Suppress theta_pd for now until the normalization is resolved.
368
369    May also suppress complete polydispersity of the model to test
370    models more quickly.
371    """
372    pars = pars.copy()
373    for p in pars:
374        if p.endswith("_pd_n"): pars[p] = 0
375    return pars
376
377def eval_sasview(model_info, data):
378    # type: (Modelinfo, Data) -> Calculator
379    """
380    Return a model calculator using the pre-4.0 SasView models.
381    """
382    # importing sas here so that the error message will be that sas failed to
383    # import rather than the more obscure smear_selection not imported error
384    import sas
385    import sas.models
386    from sas.models.qsmearing import smear_selection
387    from sas.models.MultiplicationModel import MultiplicationModel
388
389    def get_model(name):
390        # type: (str) -> "sas.models.BaseComponent"
391        #print("new",sorted(_pars.items()))
392        __import__('sas.models.' + name)
393        ModelClass = getattr(getattr(sas.models, name, None), name, None)
394        if ModelClass is None:
395            raise ValueError("could not find model %r in sas.models"%name)
396        return ModelClass()
397
398    # grab the sasview model, or create it if it is a product model
399    if model_info.composition:
400        composition_type, parts = model_info.composition
401        if composition_type == 'product':
402            P, S = [get_model(revert_name(p)) for p in parts]
403            model = MultiplicationModel(P, S)
404        else:
405            raise ValueError("sasview mixture models not supported by compare")
406    else:
407        old_name = revert_name(model_info)
408        if old_name is None:
409            raise ValueError("model %r does not exist in old sasview"
410                            % model_info.id)
411        model = get_model(old_name)
412
413    # build a smearer with which to call the model, if necessary
414    smearer = smear_selection(data, model=model)
415    if hasattr(data, 'qx_data'):
416        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
417        index = ((~data.mask) & (~np.isnan(data.data))
418                 & (q >= data.qmin) & (q <= data.qmax))
419        if smearer is not None:
420            smearer.model = model  # because smear_selection has a bug
421            smearer.accuracy = data.accuracy
422            smearer.set_index(index)
423            theory = lambda: smearer.get_value()
424        else:
425            theory = lambda: model.evalDistribution([data.qx_data[index],
426                                                     data.qy_data[index]])
427    elif smearer is not None:
428        theory = lambda: smearer(model.evalDistribution(data.x))
429    else:
430        theory = lambda: model.evalDistribution(data.x)
431
432    def calculator(**pars):
433        # type: (float, ...) -> np.ndarray
434        """
435        Sasview calculator for model.
436        """
437        # paying for parameter conversion each time to keep life simple, if not fast
438        pars = revert_pars(model_info, pars)
439        for k, v in pars.items():
440            name_attr = k.split('.')  # polydispersity components
441            if len(name_attr) == 2:
442                model.dispersion[name_attr[0]][name_attr[1]] = v
443            else:
444                model.setParam(k, v)
445        return theory()
446
447    calculator.engine = "sasview"
448    return calculator
449
450DTYPE_MAP = {
451    'half': '16',
452    'fast': 'fast',
453    'single': '32',
454    'double': '64',
455    'quad': '128',
456    'f16': '16',
457    'f32': '32',
458    'f64': '64',
459    'longdouble': '128',
460}
461def eval_opencl(model_info, data, dtype='single', cutoff=0.):
462    # type: (ModelInfo, Data, str, float) -> Calculator
463    """
464    Return a model calculator using the OpenCL calculation engine.
465    """
466    if not core.HAVE_OPENCL:
467        raise RuntimeError("OpenCL not available")
468    model = core.build_model(model_info, dtype=dtype, platform="ocl")
469    calculator = DirectModel(data, model, cutoff=cutoff)
470    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
471    return calculator
472
473def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
474    # type: (ModelInfo, Data, str, float) -> Calculator
475    """
476    Return a model calculator using the DLL calculation engine.
477    """
478    model = core.build_model(model_info, dtype=dtype, platform="dll")
479    calculator = DirectModel(data, model, cutoff=cutoff)
480    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
481    return calculator
482
483def time_calculation(calculator, pars, Nevals=1):
484    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
485    """
486    Compute the average calculation time over N evaluations.
487
488    An additional call is generated without polydispersity in order to
489    initialize the calculation engine, and make the average more stable.
490    """
491    # initialize the code so time is more accurate
492    if Nevals > 1:
493        calculator(**suppress_pd(pars))
494    toc = tic()
495    # make sure there is at least one eval
496    value = calculator(**pars)
497    for _ in range(Nevals-1):
498        value = calculator(**pars)
499    average_time = toc()*1000./Nevals
500    #print("I(q)",value)
501    return value, average_time
502
503def make_data(opts):
504    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
505    """
506    Generate an empty dataset, used with the model to set Q points
507    and resolution.
508
509    *opts* contains the options, with 'qmax', 'nq', 'res',
510    'accuracy', 'is2d' and 'view' parsed from the command line.
511    """
512    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
513    if opts['is2d']:
514        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
515        data = empty_data2D(q, resolution=res)
516        data.accuracy = opts['accuracy']
517        set_beam_stop(data, 0.0004)
518        index = ~data.mask
519    else:
520        if opts['view'] == 'log' and not opts['zero']:
521            qmax = math.log10(qmax)
522            q = np.logspace(qmax-3, qmax, nq)
523        else:
524            q = np.linspace(0.001*qmax, qmax, nq)
525        if opts['zero']:
526            q = np.hstack((0, q))
527        data = empty_data1D(q, resolution=res)
528        index = slice(None, None)
529    return data, index
530
531def make_engine(model_info, data, dtype, cutoff):
532    # type: (ModelInfo, Data, str, float) -> Calculator
533    """
534    Generate the appropriate calculation engine for the given datatype.
535
536    Datatypes with '!' appended are evaluated using external C DLLs rather
537    than OpenCL.
538    """
539    if dtype == 'sasview':
540        return eval_sasview(model_info, data)
541    elif dtype.endswith('!'):
542        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
543    else:
544        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
545
546def _show_invalid(data, theory):
547    # type: (Data, np.ma.ndarray) -> None
548    """
549    Display a list of the non-finite values in theory.
550    """
551    if not theory.mask.any():
552        return
553
554    if hasattr(data, 'x'):
555        bad = zip(data.x[theory.mask], theory[theory.mask])
556        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
557
558
559def compare(opts, limits=None):
560    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
561    """
562    Preform a comparison using options from the command line.
563
564    *limits* are the limits on the values to use, either to set the y-axis
565    for 1D or to set the colormap scale for 2D.  If None, then they are
566    inferred from the data and returned. When exploring using Bumps,
567    the limits are set when the model is initially called, and maintained
568    as the values are adjusted, making it easier to see the effects of the
569    parameters.
570    """
571    Nbase, Ncomp = opts['n1'], opts['n2']
572    pars = opts['pars']
573    data = opts['data']
574
575    # silence the linter
576    base = opts['engines'][0] if Nbase else None
577    comp = opts['engines'][1] if Ncomp else None
578    base_time = comp_time = None
579    base_value = comp_value = resid = relerr = None
580
581    # Base calculation
582    if Nbase > 0:
583        try:
584            base_raw, base_time = time_calculation(base, pars, Nbase)
585            base_value = np.ma.masked_invalid(base_raw)
586            print("%s t=%.2f ms, intensity=%.0f"
587                  % (base.engine, base_time, base_value.sum()))
588            _show_invalid(data, base_value)
589        except ImportError:
590            traceback.print_exc()
591            Nbase = 0
592
593    # Comparison calculation
594    if Ncomp > 0:
595        try:
596            comp_raw, comp_time = time_calculation(comp, pars, Ncomp)
597            comp_value = np.ma.masked_invalid(comp_raw)
598            print("%s t=%.2f ms, intensity=%.0f"
599                  % (comp.engine, comp_time, comp_value.sum()))
600            _show_invalid(data, comp_value)
601        except ImportError:
602            traceback.print_exc()
603            Ncomp = 0
604
605    # Compare, but only if computing both forms
606    if Nbase > 0 and Ncomp > 0:
607        resid = (base_value - comp_value)
608        relerr = resid/np.where(comp_value!=0., abs(comp_value), 1.0)
609        _print_stats("|%s-%s|"
610                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
611                     resid)
612        _print_stats("|(%s-%s)/%s|"
613                     % (base.engine, comp.engine, comp.engine),
614                     relerr)
615
616    # Plot if requested
617    if not opts['plot'] and not opts['explore']: return
618    view = opts['view']
619    import matplotlib.pyplot as plt
620    if limits is None:
621        vmin, vmax = np.Inf, -np.Inf
622        if Nbase > 0:
623            vmin = min(vmin, base_value.min())
624            vmax = max(vmax, base_value.max())
625        if Ncomp > 0:
626            vmin = min(vmin, comp_value.min())
627            vmax = max(vmax, comp_value.max())
628        limits = vmin, vmax
629
630    if Nbase > 0:
631        if Ncomp > 0: plt.subplot(131)
632        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
633        plt.title("%s t=%.2f ms"%(base.engine, base_time))
634        #cbar_title = "log I"
635    if Ncomp > 0:
636        if Nbase > 0: plt.subplot(132)
637        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
638        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
639        #cbar_title = "log I"
640    if Ncomp > 0 and Nbase > 0:
641        plt.subplot(133)
642        if not opts['rel_err']:
643            err, errstr, errview = resid, "abs err", "linear"
644        else:
645            err, errstr, errview = abs(relerr), "rel err", "log"
646        #err,errstr = base/comp,"ratio"
647        plot_theory(data, None, resid=err, view=errview, use_data=False)
648        if view == 'linear':
649            plt.xscale('linear')
650        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
651        #cbar_title = errstr if errview=="linear" else "log "+errstr
652    #if is2D:
653    #    h = plt.colorbar()
654    #    h.ax.set_title(cbar_title)
655
656    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
657        plt.figure()
658        v = relerr
659        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
660        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
661        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
662                   % (base.engine, comp.engine, comp.engine))
663        plt.ylabel('P(err)')
664        plt.title('Distribution of relative error between calculation engines')
665
666    if not opts['explore']:
667        plt.show()
668
669    return limits
670
671def _print_stats(label, err):
672    # type: (str, np.ma.ndarray) -> None
673    # work with trimmed data, not the full set
674    sorted_err = np.sort(abs(err.compressed()))
675    p50 = int((len(sorted_err)-1)*0.50)
676    p98 = int((len(sorted_err)-1)*0.98)
677    data = [
678        "max:%.3e"%sorted_err[-1],
679        "median:%.3e"%sorted_err[p50],
680        "98%%:%.3e"%sorted_err[p98],
681        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
682        "zero-offset:%+.3e"%np.mean(sorted_err),
683        ]
684    print(label+"  "+"  ".join(data))
685
686
687
688# ===========================================================================
689#
690NAME_OPTIONS = set([
691    'plot', 'noplot',
692    'half', 'fast', 'single', 'double',
693    'single!', 'double!', 'quad!', 'sasview',
694    'lowq', 'midq', 'highq', 'exq', 'zero',
695    '2d', '1d',
696    'preset', 'random',
697    'poly', 'mono',
698    'nopars', 'pars',
699    'rel', 'abs',
700    'linear', 'log', 'q4',
701    'hist', 'nohist',
702    'edit',
703    'demo', 'default',
704    ])
705VALUE_OPTIONS = [
706    # Note: random is both a name option and a value option
707    'cutoff', 'random', 'nq', 'res', 'accuracy',
708    ]
709
710def columnize(L, indent="", width=79):
711    # type: (List[str], str, int) -> str
712    """
713    Format a list of strings into columns.
714
715    Returns a string with carriage returns ready for printing.
716    """
717    column_width = max(len(w) for w in L) + 1
718    num_columns = (width - len(indent)) // column_width
719    num_rows = len(L) // num_columns
720    L = L + [""] * (num_rows*num_columns - len(L))
721    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
722    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
723             for row in zip(*columns)]
724    output = indent + ("\n"+indent).join(lines)
725    return output
726
727
728def get_pars(model_info, use_demo=False):
729    # type: (ModelInfo, bool) -> ParameterSet
730    """
731    Extract demo parameters from the model definition.
732    """
733    # Get the default values for the parameters
734    pars = {}
735    for p in model_info.parameters.call_parameters:
736        parts = [('', p.default)]
737        if p.polydisperse:
738            parts.append(('_pd', 0.0))
739            parts.append(('_pd_n', 0))
740            parts.append(('_pd_nsigma', 3.0))
741            parts.append(('_pd_type', "gaussian"))
742        for ext, val in parts:
743            if p.length > 1:
744                dict(("%s%d%s"%(p.id,k,ext), val) for k in range(p.length))
745            else:
746                pars[p.id+ext] = val
747
748    # Plug in values given in demo
749    if use_demo:
750        pars.update(model_info.demo)
751    return pars
752
753
754def parse_opts():
755    # type: () -> Dict[str, Any]
756    """
757    Parse command line options.
758    """
759    MODELS = core.list_models()
760    flags = [arg for arg in sys.argv[1:]
761             if arg.startswith('-')]
762    values = [arg for arg in sys.argv[1:]
763              if not arg.startswith('-') and '=' in arg]
764    args = [arg for arg in sys.argv[1:]
765            if not arg.startswith('-') and '=' not in arg]
766    models = "\n    ".join("%-15s"%v for v in MODELS)
767    if len(args) == 0:
768        print(USAGE)
769        print("\nAvailable models:")
770        print(columnize(MODELS, indent="  "))
771        sys.exit(1)
772    if len(args) > 3:
773        print("expected parameters: model N1 N2")
774
775    name = args[0]
776    try:
777        model_info = core.load_model_info(name)
778    except ImportError as exc:
779        print(str(exc))
780        print("Could not find model; use one of:\n    " + models)
781        sys.exit(1)
782
783    invalid = [o[1:] for o in flags
784               if o[1:] not in NAME_OPTIONS
785               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
786    if invalid:
787        print("Invalid options: %s"%(", ".join(invalid)))
788        sys.exit(1)
789
790
791    # pylint: disable=bad-whitespace
792    # Interpret the flags
793    opts = {
794        'plot'      : True,
795        'view'      : 'log',
796        'is2d'      : False,
797        'qmax'      : 0.05,
798        'nq'        : 128,
799        'res'       : 0.0,
800        'accuracy'  : 'Low',
801        'cutoff'    : 0.0,
802        'seed'      : -1,  # default to preset
803        'mono'      : False,
804        'show_pars' : False,
805        'show_hist' : False,
806        'rel_err'   : True,
807        'explore'   : False,
808        'use_demo'  : True,
809        'zero'      : False,
810    }
811    engines = []
812    for arg in flags:
813        if arg == '-noplot':    opts['plot'] = False
814        elif arg == '-plot':    opts['plot'] = True
815        elif arg == '-linear':  opts['view'] = 'linear'
816        elif arg == '-log':     opts['view'] = 'log'
817        elif arg == '-q4':      opts['view'] = 'q4'
818        elif arg == '-1d':      opts['is2d'] = False
819        elif arg == '-2d':      opts['is2d'] = True
820        elif arg == '-exq':     opts['qmax'] = 10.0
821        elif arg == '-highq':   opts['qmax'] = 1.0
822        elif arg == '-midq':    opts['qmax'] = 0.2
823        elif arg == '-lowq':    opts['qmax'] = 0.05
824        elif arg == '-zero':    opts['zero'] = True
825        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
826        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
827        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
828        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
829        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
830        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
831        elif arg == '-preset':  opts['seed'] = -1
832        elif arg == '-mono':    opts['mono'] = True
833        elif arg == '-poly':    opts['mono'] = False
834        elif arg == '-pars':    opts['show_pars'] = True
835        elif arg == '-nopars':  opts['show_pars'] = False
836        elif arg == '-hist':    opts['show_hist'] = True
837        elif arg == '-nohist':  opts['show_hist'] = False
838        elif arg == '-rel':     opts['rel_err'] = True
839        elif arg == '-abs':     opts['rel_err'] = False
840        elif arg == '-half':    engines.append(arg[1:])
841        elif arg == '-fast':    engines.append(arg[1:])
842        elif arg == '-single':  engines.append(arg[1:])
843        elif arg == '-double':  engines.append(arg[1:])
844        elif arg == '-single!': engines.append(arg[1:])
845        elif arg == '-double!': engines.append(arg[1:])
846        elif arg == '-quad!':   engines.append(arg[1:])
847        elif arg == '-sasview': engines.append(arg[1:])
848        elif arg == '-edit':    opts['explore'] = True
849        elif arg == '-demo':    opts['use_demo'] = True
850        elif arg == '-default':    opts['use_demo'] = False
851    # pylint: enable=bad-whitespace
852
853    if len(engines) == 0:
854        engines.extend(['single', 'double'])
855    elif len(engines) == 1:
856        if engines[0][0] == 'double':
857            engines.append('single')
858        else:
859            engines.append('double')
860    elif len(engines) > 2:
861        del engines[2:]
862
863    n1 = int(args[1]) if len(args) > 1 else 1
864    n2 = int(args[2]) if len(args) > 2 else 1
865    use_sasview = any(engine=='sasview' and count>0
866                      for engine, count in zip(engines, [n1, n2]))
867
868    # Get demo parameters from model definition, or use default parameters
869    # if model does not define demo parameters
870    pars = get_pars(model_info, opts['use_demo'])
871
872
873    # Fill in parameters given on the command line
874    presets = {}
875    for arg in values:
876        k, v = arg.split('=', 1)
877        if k not in pars:
878            # extract base name without polydispersity info
879            s = set(p.split('_pd')[0] for p in pars)
880            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
881            sys.exit(1)
882        presets[k] = float(v) if not k.endswith('type') else v
883
884    # randomize parameters
885    #pars.update(set_pars)  # set value before random to control range
886    if opts['seed'] > -1:
887        pars = randomize_pars(model_info, pars, seed=opts['seed'])
888        print("Randomize using -random=%i"%opts['seed'])
889    if opts['mono']:
890        pars = suppress_pd(pars)
891    pars.update(presets)  # set value after random to control value
892    #import pprint; pprint.pprint(model_info)
893    constrain_pars(model_info, pars)
894    if use_sasview:
895        constrain_new_to_old(model_info, pars)
896    if opts['show_pars']:
897        print(str(parlist(model_info, pars, opts['is2d'])))
898
899    # Create the computational engines
900    data, _ = make_data(opts)
901    if n1:
902        base = make_engine(model_info, data, engines[0], opts['cutoff'])
903    else:
904        base = None
905    if n2:
906        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
907    else:
908        comp = None
909
910    # pylint: disable=bad-whitespace
911    # Remember it all
912    opts.update({
913        'name'      : name,
914        'def'       : model_info,
915        'n1'        : n1,
916        'n2'        : n2,
917        'presets'   : presets,
918        'pars'      : pars,
919        'data'      : data,
920        'engines'   : [base, comp],
921    })
922    # pylint: enable=bad-whitespace
923
924    return opts
925
926def explore(opts):
927    # type: (Dict[str, Any]) -> None
928    """
929    Explore the model using the Bumps GUI.
930    """
931    import wx  # type: ignore
932    from bumps.names import FitProblem  # type: ignore
933    from bumps.gui.app_frame import AppFrame  # type: ignore
934
935    problem = FitProblem(Explore(opts))
936    is_mac = "cocoa" in wx.version()
937    app = wx.App()
938    frame = AppFrame(parent=None, title="explore")
939    if not is_mac: frame.Show()
940    frame.panel.set_model(model=problem)
941    frame.panel.Layout()
942    frame.panel.aui.Split(0, wx.TOP)
943    if is_mac: frame.Show()
944    app.MainLoop()
945
946class Explore(object):
947    """
948    Bumps wrapper for a SAS model comparison.
949
950    The resulting object can be used as a Bumps fit problem so that
951    parameters can be adjusted in the GUI, with plots updated on the fly.
952    """
953    def __init__(self, opts):
954        # type: (Dict[str, Any]) -> None
955        from bumps.cli import config_matplotlib  # type: ignore
956        from . import bumps_model
957        config_matplotlib()
958        self.opts = opts
959        model_info = opts['def']
960        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
961        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
962        if not opts['is2d']:
963            for p in model_info.parameters.user_parameters(is2d=False):
964                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
965                    k = p.name+ext
966                    v = pars.get(k, None)
967                    if v is not None:
968                        v.range(*parameter_range(k, v.value))
969        else:
970            for k, v in pars.items():
971                v.range(*parameter_range(k, v.value))
972
973        self.pars = pars
974        self.pd_types = pd_types
975        self.limits = None
976
977    def numpoints(self):
978        # type: () -> int
979        """
980        Return the number of points.
981        """
982        return len(self.pars) + 1  # so dof is 1
983
984    def parameters(self):
985        # type: () -> Any   # Dict/List hierarchy of parameters
986        """
987        Return a dictionary of parameters.
988        """
989        return self.pars
990
991    def nllf(self):
992        # type: () -> float
993        """
994        Return cost.
995        """
996        # pylint: disable=no-self-use
997        return 0.  # No nllf
998
999    def plot(self, view='log'):
1000        # type: (str) -> None
1001        """
1002        Plot the data and residuals.
1003        """
1004        pars = dict((k, v.value) for k, v in self.pars.items())
1005        pars.update(self.pd_types)
1006        self.opts['pars'] = pars
1007        limits = compare(self.opts, limits=self.limits)
1008        if self.limits is None:
1009            vmin, vmax = limits
1010            self.limits = vmax*1e-7, 1.3*vmax
1011
1012
1013def main():
1014    # type: () -> None
1015    """
1016    Main program.
1017    """
1018    opts = parse_opts()
1019    if opts['explore']:
1020        explore(opts)
1021    else:
1022        compare(opts)
1023
1024if __name__ == "__main__":
1025    main()
Note: See TracBrowser for help on using the repository browser.