source: sasmodels/sasmodels/compare.py @ 3b681fa

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 3b681fa was 234c532, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

add -html option to sascomp to show the model help

  • Property mode set to 100755
File size: 36.5 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from . import weights
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_name, revert_pars, constrain_new_to_old
44
45try:
46    from typing import Optional, Dict, Any, Callable, Tuple
47except:
48    pass
49else:
50    from .modelinfo import ModelInfo, Parameter, ParameterSet
51    from .data import Data
52    Calculator = Callable[[float], np.ndarray]
53
54USAGE = """
55usage: compare.py model N1 N2 [options...] [key=val]
56
57Compare the speed and value for a model between the SasView original and the
58sasmodels rewrite.
59
60model is the name of the model to compare (see below).
61N1 is the number of times to run sasmodels (default=1).
62N2 is the number times to run sasview (default=1).
63
64Options (* for default):
65
66    -plot*/-noplot plots or suppress the plot of the model
67    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
68    -nq=128 sets the number of Q points in the data set
69    -zero indicates that q=0 should be included
70    -1d*/-2d computes 1d or 2d data
71    -preset*/-random[=seed] preset or random parameters
72    -mono/-poly* force monodisperse/polydisperse
73    -cutoff=1e-5* cutoff value for including a point in polydispersity
74    -pars/-nopars* prints the parameter set or not
75    -abs/-rel* plot relative or absolute error
76    -linear/-log*/-q4 intensity scaling
77    -hist/-nohist* plot histogram of relative error
78    -res=0 sets the resolution width dQ/Q if calculating with resolution
79    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
80    -edit starts the parameter explorer
81    -default/-demo* use demo vs default parameters
82    -html shows the model docs instead of running the model
83
84Any two calculation engines can be selected for comparison:
85
86    -single/-double/-half/-fast sets an OpenCL calculation engine
87    -single!/-double!/-quad! sets an OpenMP calculation engine
88    -sasview sets the sasview calculation engine
89
90The default is -single -sasview.  Note that the interpretation of quad
91precision depends on architecture, and may vary from 64-bit to 128-bit,
92with 80-bit floats being common (1e-19 precision).
93
94Key=value pairs allow you to set specific values for the model parameters.
95"""
96
97# Update docs with command line usage string.   This is separate from the usual
98# doc string so that we can display it at run time if there is an error.
99# lin
100__doc__ = (__doc__  # pylint: disable=redefined-builtin
101           + """
102Program description
103-------------------
104
105"""
106           + USAGE)
107
108kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
109
110# CRUFT python 2.6
111if not hasattr(datetime.timedelta, 'total_seconds'):
112    def delay(dt):
113        """Return number date-time delta as number seconds"""
114        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
115else:
116    def delay(dt):
117        """Return number date-time delta as number seconds"""
118        return dt.total_seconds()
119
120
121class push_seed(object):
122    """
123    Set the seed value for the random number generator.
124
125    When used in a with statement, the random number generator state is
126    restored after the with statement is complete.
127
128    :Parameters:
129
130    *seed* : int or array_like, optional
131        Seed for RandomState
132
133    :Example:
134
135    Seed can be used directly to set the seed::
136
137        >>> from numpy.random import randint
138        >>> push_seed(24)
139        <...push_seed object at...>
140        >>> print(randint(0,1000000,3))
141        [242082    899 211136]
142
143    Seed can also be used in a with statement, which sets the random
144    number generator state for the enclosed computations and restores
145    it to the previous state on completion::
146
147        >>> with push_seed(24):
148        ...    print(randint(0,1000000,3))
149        [242082    899 211136]
150
151    Using nested contexts, we can demonstrate that state is indeed
152    restored after the block completes::
153
154        >>> with push_seed(24):
155        ...    print(randint(0,1000000))
156        ...    with push_seed(24):
157        ...        print(randint(0,1000000,3))
158        ...    print(randint(0,1000000))
159        242082
160        [242082    899 211136]
161        899
162
163    The restore step is protected against exceptions in the block::
164
165        >>> with push_seed(24):
166        ...    print(randint(0,1000000))
167        ...    try:
168        ...        with push_seed(24):
169        ...            print(randint(0,1000000,3))
170        ...            raise Exception()
171        ...    except Exception:
172        ...        print("Exception raised")
173        ...    print(randint(0,1000000))
174        242082
175        [242082    899 211136]
176        Exception raised
177        899
178    """
179    def __init__(self, seed=None):
180        # type: (Optional[int]) -> None
181        self._state = np.random.get_state()
182        np.random.seed(seed)
183
184    def __enter__(self):
185        # type: () -> None
186        pass
187
188    def __exit__(self, exc_type, exc_value, traceback):
189        # type: (Any, BaseException, Any) -> None
190        # TODO: better typing for __exit__ method
191        np.random.set_state(self._state)
192
193def tic():
194    # type: () -> Callable[[], float]
195    """
196    Timer function.
197
198    Use "toc=tic()" to start the clock and "toc()" to measure
199    a time interval.
200    """
201    then = datetime.datetime.now()
202    return lambda: delay(datetime.datetime.now() - then)
203
204
205def set_beam_stop(data, radius, outer=None):
206    # type: (Data, float, float) -> None
207    """
208    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
209
210    Note: this function does not require sasview
211    """
212    if hasattr(data, 'qx_data'):
213        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
214        data.mask = (q < radius)
215        if outer is not None:
216            data.mask |= (q >= outer)
217    else:
218        data.mask = (data.x < radius)
219        if outer is not None:
220            data.mask |= (data.x >= outer)
221
222
223def parameter_range(p, v):
224    # type: (str, float) -> Tuple[float, float]
225    """
226    Choose a parameter range based on parameter name and initial value.
227    """
228    # process the polydispersity options
229    if p.endswith('_pd_n'):
230        return 0., 100.
231    elif p.endswith('_pd_nsigma'):
232        return 0., 5.
233    elif p.endswith('_pd_type'):
234        raise ValueError("Cannot return a range for a string value")
235    elif any(s in p for s in ('theta', 'phi', 'psi')):
236        # orientation in [-180,180], orientation pd in [0,45]
237        if p.endswith('_pd'):
238            return 0., 45.
239        else:
240            return -180., 180.
241    elif p.endswith('_pd'):
242        return 0., 1.
243    elif 'sld' in p:
244        return -0.5, 10.
245    elif p == 'background':
246        return 0., 10.
247    elif p == 'scale':
248        return 0., 1.e3
249    elif v < 0.:
250        return 2.*v, -2.*v
251    else:
252        return 0., (2.*v if v > 0. else 1.)
253
254
255def _randomize_one(model_info, p, v):
256    # type: (ModelInfo, str, float) -> float
257    # type: (ModelInfo, str, str) -> str
258    """
259    Randomize a single parameter.
260    """
261    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
262        return v
263
264    # Find the parameter definition
265    for par in model_info.parameters.call_parameters:
266        if par.name == p:
267            break
268    else:
269        raise ValueError("unknown parameter %r"%p)
270
271    if len(par.limits) > 2:  # choice list
272        return np.random.randint(len(par.limits))
273
274    limits = par.limits
275    if not np.isfinite(limits).all():
276        low, high = parameter_range(p, v)
277        limits = (max(limits[0], low), min(limits[1], high))
278
279    return np.random.uniform(*limits)
280
281
282def randomize_pars(model_info, pars, seed=None):
283    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
284    """
285    Generate random values for all of the parameters.
286
287    Valid ranges for the random number generator are guessed from the name of
288    the parameter; this will not account for constraints such as cap radius
289    greater than cylinder radius in the capped_cylinder model, so
290    :func:`constrain_pars` needs to be called afterward..
291    """
292    with push_seed(seed):
293        # Note: the sort guarantees order `of calls to random number generator
294        random_pars = dict((p, _randomize_one(model_info, p, v))
295                           for p, v in sorted(pars.items()))
296    return random_pars
297
298def constrain_pars(model_info, pars):
299    # type: (ModelInfo, ParameterSet) -> None
300    """
301    Restrict parameters to valid values.
302
303    This includes model specific code for models such as capped_cylinder
304    which need to support within model constraints (cap radius more than
305    cylinder radius in this case).
306
307    Warning: this updates the *pars* dictionary in place.
308    """
309    name = model_info.id
310    # if it is a product model, then just look at the form factor since
311    # none of the structure factors need any constraints.
312    if '*' in name:
313        name = name.split('*')[0]
314
315    if name == 'capped_cylinder' and pars['radius_cap'] < pars['radius']:
316        pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
317    if name == 'barbell' and pars['radius_bell'] < pars['radius']:
318        pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
319
320    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
321    if name == 'guinier':
322        #q_max = 0.2  # mid q maximum
323        q_max = 1.0  # high q maximum
324        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
325        pars['rg'] = min(pars['rg'], rg_max)
326
327    if name == 'rpa':
328        # Make sure phi sums to 1.0
329        if pars['case_num'] < 2:
330            pars['Phi1'] = 0.
331            pars['Phi2'] = 0.
332        elif pars['case_num'] < 5:
333            pars['Phi1'] = 0.
334        total = sum(pars['Phi'+c] for c in '1234')
335        for c in '1234':
336            pars['Phi'+c] /= total
337
338def parlist(model_info, pars, is2d):
339    # type: (ModelInfo, ParameterSet, bool) -> str
340    """
341    Format the parameter list for printing.
342    """
343    lines = []
344    parameters = model_info.parameters
345    for p in parameters.user_parameters(pars, is2d):
346        fields = dict(
347            value=pars.get(p.id, p.default),
348            pd=pars.get(p.id+"_pd", 0.),
349            n=int(pars.get(p.id+"_pd_n", 0)),
350            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
351            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
352            relative_pd=p.relative_pd,
353        )
354        lines.append(_format_par(p.name, **fields))
355    return "\n".join(lines)
356
357    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
358
359def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
360                relative_pd=False):
361    # type: (str, float, float, int, float, str) -> str
362    line = "%s: %g"%(name, value)
363    if pd != 0.  and n != 0:
364        if relative_pd:
365            pd *= value
366        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
367                % (pd, n, nsigma, nsigma, pdtype)
368    return line
369
370def suppress_pd(pars):
371    # type: (ParameterSet) -> ParameterSet
372    """
373    Suppress theta_pd for now until the normalization is resolved.
374
375    May also suppress complete polydispersity of the model to test
376    models more quickly.
377    """
378    pars = pars.copy()
379    for p in pars:
380        if p.endswith("_pd_n"): pars[p] = 0
381    return pars
382
383def eval_sasview(model_info, data):
384    # type: (Modelinfo, Data) -> Calculator
385    """
386    Return a model calculator using the pre-4.0 SasView models.
387    """
388    # importing sas here so that the error message will be that sas failed to
389    # import rather than the more obscure smear_selection not imported error
390    import sas
391    import sas.models
392    from sas.models.qsmearing import smear_selection
393    from sas.models.MultiplicationModel import MultiplicationModel
394    from sas.models.dispersion_models import models as dispersers
395
396    def get_model_class(name):
397        # type: (str) -> "sas.models.BaseComponent"
398        #print("new",sorted(_pars.items()))
399        __import__('sas.models.' + name)
400        ModelClass = getattr(getattr(sas.models, name, None), name, None)
401        if ModelClass is None:
402            raise ValueError("could not find model %r in sas.models"%name)
403        return ModelClass
404
405    # WARNING: ugly hack when handling model!
406    # Sasview models with multiplicity need to be created with the target
407    # multiplicity, so we cannot create the target model ahead of time for
408    # for multiplicity models.  Instead we store the model in a list and
409    # update the first element of that list with the new multiplicity model
410    # every time we evaluate.
411
412    # grab the sasview model, or create it if it is a product model
413    if model_info.composition:
414        composition_type, parts = model_info.composition
415        if composition_type == 'product':
416            P, S = [get_model_class(revert_name(p))() for p in parts]
417            model = [MultiplicationModel(P, S)]
418        else:
419            raise ValueError("sasview mixture models not supported by compare")
420    else:
421        old_name = revert_name(model_info)
422        if old_name is None:
423            raise ValueError("model %r does not exist in old sasview"
424                            % model_info.id)
425        ModelClass = get_model_class(old_name)
426        model = [ModelClass()]
427    model[0].disperser_handles = {}
428
429    # build a smearer with which to call the model, if necessary
430    smearer = smear_selection(data, model=model)
431    if hasattr(data, 'qx_data'):
432        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
433        index = ((~data.mask) & (~np.isnan(data.data))
434                 & (q >= data.qmin) & (q <= data.qmax))
435        if smearer is not None:
436            smearer.model = model  # because smear_selection has a bug
437            smearer.accuracy = data.accuracy
438            smearer.set_index(index)
439            def _call_smearer():
440                smearer.model = model[0]
441                return smearer.get_value()
442            theory = _call_smearer
443        else:
444            theory = lambda: model[0].evalDistribution([data.qx_data[index],
445                                                        data.qy_data[index]])
446    elif smearer is not None:
447        theory = lambda: smearer(model[0].evalDistribution(data.x))
448    else:
449        theory = lambda: model[0].evalDistribution(data.x)
450
451    def calculator(**pars):
452        # type: (float, ...) -> np.ndarray
453        """
454        Sasview calculator for model.
455        """
456        oldpars = revert_pars(model_info, pars)
457        # For multiplicity models, create a model with the correct multiplicity
458        control = oldpars.pop("CONTROL", None)
459        if control is not None:
460            # sphericalSLD has one fewer multiplicity.  This update should
461            # happen in revert_pars, but it hasn't been called yet.
462            model[0] = ModelClass(control)
463        # paying for parameter conversion each time to keep life simple, if not fast
464        for k, v in oldpars.items():
465            if k.endswith('.type'):
466                par = k[:-5]
467                cls = dispersers[v if v != 'rectangle' else 'rectangula']
468                handle = cls()
469                model[0].disperser_handles[par] = handle
470                model[0].set_dispersion(par, handle)
471
472        #print("sasview pars",oldpars)
473        for k, v in oldpars.items():
474            name_attr = k.split('.')  # polydispersity components
475            if len(name_attr) == 2:
476                par, disp_par = name_attr
477                model[0].dispersion[par][disp_par] = v
478            else:
479                model[0].setParam(k, v)
480        return theory()
481
482    calculator.engine = "sasview"
483    return calculator
484
485DTYPE_MAP = {
486    'half': '16',
487    'fast': 'fast',
488    'single': '32',
489    'double': '64',
490    'quad': '128',
491    'f16': '16',
492    'f32': '32',
493    'f64': '64',
494    'longdouble': '128',
495}
496def eval_opencl(model_info, data, dtype='single', cutoff=0.):
497    # type: (ModelInfo, Data, str, float) -> Calculator
498    """
499    Return a model calculator using the OpenCL calculation engine.
500    """
501    if not core.HAVE_OPENCL:
502        raise RuntimeError("OpenCL not available")
503    model = core.build_model(model_info, dtype=dtype, platform="ocl")
504    calculator = DirectModel(data, model, cutoff=cutoff)
505    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
506    return calculator
507
508def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
509    # type: (ModelInfo, Data, str, float) -> Calculator
510    """
511    Return a model calculator using the DLL calculation engine.
512    """
513    model = core.build_model(model_info, dtype=dtype, platform="dll")
514    calculator = DirectModel(data, model, cutoff=cutoff)
515    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
516    return calculator
517
518def time_calculation(calculator, pars, evals=1):
519    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
520    """
521    Compute the average calculation time over N evaluations.
522
523    An additional call is generated without polydispersity in order to
524    initialize the calculation engine, and make the average more stable.
525    """
526    # initialize the code so time is more accurate
527    if evals > 1:
528        calculator(**suppress_pd(pars))
529    toc = tic()
530    # make sure there is at least one eval
531    value = calculator(**pars)
532    for _ in range(evals-1):
533        value = calculator(**pars)
534    average_time = toc()*1000. / evals
535    #print("I(q)",value)
536    return value, average_time
537
538def make_data(opts):
539    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
540    """
541    Generate an empty dataset, used with the model to set Q points
542    and resolution.
543
544    *opts* contains the options, with 'qmax', 'nq', 'res',
545    'accuracy', 'is2d' and 'view' parsed from the command line.
546    """
547    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
548    if opts['is2d']:
549        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
550        data = empty_data2D(q, resolution=res)
551        data.accuracy = opts['accuracy']
552        set_beam_stop(data, 0.0004)
553        index = ~data.mask
554    else:
555        if opts['view'] == 'log' and not opts['zero']:
556            qmax = math.log10(qmax)
557            q = np.logspace(qmax-3, qmax, nq)
558        else:
559            q = np.linspace(0.001*qmax, qmax, nq)
560        if opts['zero']:
561            q = np.hstack((0, q))
562        data = empty_data1D(q, resolution=res)
563        index = slice(None, None)
564    return data, index
565
566def make_engine(model_info, data, dtype, cutoff):
567    # type: (ModelInfo, Data, str, float) -> Calculator
568    """
569    Generate the appropriate calculation engine for the given datatype.
570
571    Datatypes with '!' appended are evaluated using external C DLLs rather
572    than OpenCL.
573    """
574    if dtype == 'sasview':
575        return eval_sasview(model_info, data)
576    elif dtype.endswith('!'):
577        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
578    else:
579        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
580
581def _show_invalid(data, theory):
582    # type: (Data, np.ma.ndarray) -> None
583    """
584    Display a list of the non-finite values in theory.
585    """
586    if not theory.mask.any():
587        return
588
589    if hasattr(data, 'x'):
590        bad = zip(data.x[theory.mask], theory[theory.mask])
591        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
592
593
594def compare(opts, limits=None):
595    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
596    """
597    Preform a comparison using options from the command line.
598
599    *limits* are the limits on the values to use, either to set the y-axis
600    for 1D or to set the colormap scale for 2D.  If None, then they are
601    inferred from the data and returned. When exploring using Bumps,
602    the limits are set when the model is initially called, and maintained
603    as the values are adjusted, making it easier to see the effects of the
604    parameters.
605    """
606    n_base, n_comp = opts['n1'], opts['n2']
607    pars = opts['pars']
608    data = opts['data']
609
610    # silence the linter
611    base = opts['engines'][0] if n_base else None
612    comp = opts['engines'][1] if n_comp else None
613    base_time = comp_time = None
614    base_value = comp_value = resid = relerr = None
615
616    # Base calculation
617    if n_base > 0:
618        try:
619            base_raw, base_time = time_calculation(base, pars, n_base)
620            base_value = np.ma.masked_invalid(base_raw)
621            print("%s t=%.2f ms, intensity=%.0f"
622                  % (base.engine, base_time, base_value.sum()))
623            _show_invalid(data, base_value)
624        except ImportError:
625            traceback.print_exc()
626            n_base = 0
627
628    # Comparison calculation
629    if n_comp > 0:
630        try:
631            comp_raw, comp_time = time_calculation(comp, pars, n_comp)
632            comp_value = np.ma.masked_invalid(comp_raw)
633            print("%s t=%.2f ms, intensity=%.0f"
634                  % (comp.engine, comp_time, comp_value.sum()))
635            _show_invalid(data, comp_value)
636        except ImportError:
637            traceback.print_exc()
638            n_comp = 0
639
640    # Compare, but only if computing both forms
641    if n_base > 0 and n_comp > 0:
642        resid = (base_value - comp_value)
643        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
644        _print_stats("|%s-%s|"
645                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
646                     resid)
647        _print_stats("|(%s-%s)/%s|"
648                     % (base.engine, comp.engine, comp.engine),
649                     relerr)
650
651    # Plot if requested
652    if not opts['plot'] and not opts['explore']: return
653    view = opts['view']
654    import matplotlib.pyplot as plt
655    if limits is None:
656        vmin, vmax = np.Inf, -np.Inf
657        if n_base > 0:
658            vmin = min(vmin, base_value.min())
659            vmax = max(vmax, base_value.max())
660        if n_comp > 0:
661            vmin = min(vmin, comp_value.min())
662            vmax = max(vmax, comp_value.max())
663        limits = vmin, vmax
664
665    if n_base > 0:
666        if n_comp > 0: plt.subplot(131)
667        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
668        plt.title("%s t=%.2f ms"%(base.engine, base_time))
669        #cbar_title = "log I"
670    if n_comp > 0:
671        if n_base > 0: plt.subplot(132)
672        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
673        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
674        #cbar_title = "log I"
675    if n_comp > 0 and n_base > 0:
676        plt.subplot(133)
677        if not opts['rel_err']:
678            err, errstr, errview = resid, "abs err", "linear"
679        else:
680            err, errstr, errview = abs(relerr), "rel err", "log"
681        #err,errstr = base/comp,"ratio"
682        plot_theory(data, None, resid=err, view=errview, use_data=False)
683        if view == 'linear':
684            plt.xscale('linear')
685        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
686        #cbar_title = errstr if errview=="linear" else "log "+errstr
687    #if is2D:
688    #    h = plt.colorbar()
689    #    h.ax.set_title(cbar_title)
690    fig = plt.gcf()
691    fig.suptitle(opts['name'])
692
693    if n_comp > 0 and n_base > 0 and '-hist' in opts:
694        plt.figure()
695        v = relerr
696        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
697        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
698        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
699                   % (base.engine, comp.engine, comp.engine))
700        plt.ylabel('P(err)')
701        plt.title('Distribution of relative error between calculation engines')
702
703    if not opts['explore']:
704        plt.show()
705
706    return limits
707
708def _print_stats(label, err):
709    # type: (str, np.ma.ndarray) -> None
710    # work with trimmed data, not the full set
711    sorted_err = np.sort(abs(err.compressed()))
712    p50 = int((len(sorted_err)-1)*0.50)
713    p98 = int((len(sorted_err)-1)*0.98)
714    data = [
715        "max:%.3e"%sorted_err[-1],
716        "median:%.3e"%sorted_err[p50],
717        "98%%:%.3e"%sorted_err[p98],
718        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
719        "zero-offset:%+.3e"%np.mean(sorted_err),
720        ]
721    print(label+"  "+"  ".join(data))
722
723
724
725# ===========================================================================
726#
727NAME_OPTIONS = set([
728    'plot', 'noplot',
729    'half', 'fast', 'single', 'double',
730    'single!', 'double!', 'quad!', 'sasview',
731    'lowq', 'midq', 'highq', 'exq', 'zero',
732    '2d', '1d',
733    'preset', 'random',
734    'poly', 'mono',
735    'nopars', 'pars',
736    'rel', 'abs',
737    'linear', 'log', 'q4',
738    'hist', 'nohist',
739    'edit', 'html',
740    'demo', 'default',
741    ])
742VALUE_OPTIONS = [
743    # Note: random is both a name option and a value option
744    'cutoff', 'random', 'nq', 'res', 'accuracy',
745    ]
746
747def columnize(items, indent="", width=79):
748    # type: (List[str], str, int) -> str
749    """
750    Format a list of strings into columns.
751
752    Returns a string with carriage returns ready for printing.
753    """
754    column_width = max(len(w) for w in items) + 1
755    num_columns = (width - len(indent)) // column_width
756    num_rows = len(items) // num_columns
757    items = items + [""] * (num_rows * num_columns - len(items))
758    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
759    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
760             for row in zip(*columns)]
761    output = indent + ("\n"+indent).join(lines)
762    return output
763
764
765def get_pars(model_info, use_demo=False):
766    # type: (ModelInfo, bool) -> ParameterSet
767    """
768    Extract demo parameters from the model definition.
769    """
770    # Get the default values for the parameters
771    pars = {}
772    for p in model_info.parameters.call_parameters:
773        parts = [('', p.default)]
774        if p.polydisperse:
775            parts.append(('_pd', 0.0))
776            parts.append(('_pd_n', 0))
777            parts.append(('_pd_nsigma', 3.0))
778            parts.append(('_pd_type', "gaussian"))
779        for ext, val in parts:
780            if p.length > 1:
781                dict(("%s%d%s" % (p.id, k, ext), val)
782                     for k in range(1, p.length+1))
783            else:
784                pars[p.id + ext] = val
785
786    # Plug in values given in demo
787    if use_demo:
788        pars.update(model_info.demo)
789    return pars
790
791
792def parse_opts(argv):
793    # type: (List[str]) -> Dict[str, Any]
794    """
795    Parse command line options.
796    """
797    MODELS = core.list_models()
798    flags = [arg for arg in argv
799             if arg.startswith('-')]
800    values = [arg for arg in argv
801              if not arg.startswith('-') and '=' in arg]
802    positional_args = [arg for arg in argv
803            if not arg.startswith('-') and '=' not in arg]
804    models = "\n    ".join("%-15s"%v for v in MODELS)
805    if len(positional_args) == 0:
806        print(USAGE)
807        print("\nAvailable models:")
808        print(columnize(MODELS, indent="  "))
809        return None
810    if len(positional_args) > 3:
811        print("expected parameters: model N1 N2")
812
813    name = positional_args[0]
814    try:
815        model_info = core.load_model_info(name)
816    except ImportError as exc:
817        print(str(exc))
818        print("Could not find model; use one of:\n    " + models)
819        return None
820
821    invalid = [o[1:] for o in flags
822               if o[1:] not in NAME_OPTIONS
823               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
824    if invalid:
825        print("Invalid options: %s"%(", ".join(invalid)))
826        return None
827
828
829    # pylint: disable=bad-whitespace
830    # Interpret the flags
831    opts = {
832        'plot'      : True,
833        'view'      : 'log',
834        'is2d'      : False,
835        'qmax'      : 0.05,
836        'nq'        : 128,
837        'res'       : 0.0,
838        'accuracy'  : 'Low',
839        'cutoff'    : 0.0,
840        'seed'      : -1,  # default to preset
841        'mono'      : False,
842        'show_pars' : False,
843        'show_hist' : False,
844        'rel_err'   : True,
845        'explore'   : False,
846        'use_demo'  : True,
847        'zero'      : False,
848        'html'      : False,
849    }
850    engines = []
851    for arg in flags:
852        if arg == '-noplot':    opts['plot'] = False
853        elif arg == '-plot':    opts['plot'] = True
854        elif arg == '-linear':  opts['view'] = 'linear'
855        elif arg == '-log':     opts['view'] = 'log'
856        elif arg == '-q4':      opts['view'] = 'q4'
857        elif arg == '-1d':      opts['is2d'] = False
858        elif arg == '-2d':      opts['is2d'] = True
859        elif arg == '-exq':     opts['qmax'] = 10.0
860        elif arg == '-highq':   opts['qmax'] = 1.0
861        elif arg == '-midq':    opts['qmax'] = 0.2
862        elif arg == '-lowq':    opts['qmax'] = 0.05
863        elif arg == '-zero':    opts['zero'] = True
864        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
865        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
866        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
867        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
868        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
869        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
870        elif arg == '-preset':  opts['seed'] = -1
871        elif arg == '-mono':    opts['mono'] = True
872        elif arg == '-poly':    opts['mono'] = False
873        elif arg == '-pars':    opts['show_pars'] = True
874        elif arg == '-nopars':  opts['show_pars'] = False
875        elif arg == '-hist':    opts['show_hist'] = True
876        elif arg == '-nohist':  opts['show_hist'] = False
877        elif arg == '-rel':     opts['rel_err'] = True
878        elif arg == '-abs':     opts['rel_err'] = False
879        elif arg == '-half':    engines.append(arg[1:])
880        elif arg == '-fast':    engines.append(arg[1:])
881        elif arg == '-single':  engines.append(arg[1:])
882        elif arg == '-double':  engines.append(arg[1:])
883        elif arg == '-single!': engines.append(arg[1:])
884        elif arg == '-double!': engines.append(arg[1:])
885        elif arg == '-quad!':   engines.append(arg[1:])
886        elif arg == '-sasview': engines.append(arg[1:])
887        elif arg == '-edit':    opts['explore'] = True
888        elif arg == '-demo':    opts['use_demo'] = True
889        elif arg == '-default':    opts['use_demo'] = False
890        elif arg == '-html':    opts['html'] = True
891    # pylint: enable=bad-whitespace
892
893    if len(engines) == 0:
894        engines.extend(['single', 'double'])
895    elif len(engines) == 1:
896        if engines[0][0] == 'double':
897            engines.append('single')
898        else:
899            engines.append('double')
900    elif len(engines) > 2:
901        del engines[2:]
902
903    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
904    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
905    use_sasview = any(engine == 'sasview' and count > 0
906                      for engine, count in zip(engines, [n1, n2]))
907
908    # Get demo parameters from model definition, or use default parameters
909    # if model does not define demo parameters
910    pars = get_pars(model_info, opts['use_demo'])
911
912
913    # Fill in parameters given on the command line
914    presets = {}
915    for arg in values:
916        k, v = arg.split('=', 1)
917        if k not in pars:
918            # extract base name without polydispersity info
919            s = set(p.split('_pd')[0] for p in pars)
920            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
921            return None
922        presets[k] = float(v) if not k.endswith('type') else v
923
924    # randomize parameters
925    #pars.update(set_pars)  # set value before random to control range
926    if opts['seed'] > -1:
927        pars = randomize_pars(model_info, pars, seed=opts['seed'])
928        print("Randomize using -random=%i"%opts['seed'])
929    if opts['mono']:
930        pars = suppress_pd(pars)
931    pars.update(presets)  # set value after random to control value
932    #import pprint; pprint.pprint(model_info)
933    constrain_pars(model_info, pars)
934    if use_sasview:
935        constrain_new_to_old(model_info, pars)
936    if opts['show_pars']:
937        print(str(parlist(model_info, pars, opts['is2d'])))
938
939    # Create the computational engines
940    data, _ = make_data(opts)
941    if n1:
942        base = make_engine(model_info, data, engines[0], opts['cutoff'])
943    else:
944        base = None
945    if n2:
946        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
947    else:
948        comp = None
949
950    # pylint: disable=bad-whitespace
951    # Remember it all
952    opts.update({
953        'name'      : name,
954        'def'       : model_info,
955        'n1'        : n1,
956        'n2'        : n2,
957        'presets'   : presets,
958        'pars'      : pars,
959        'data'      : data,
960        'engines'   : [base, comp],
961    })
962    # pylint: enable=bad-whitespace
963
964    return opts
965
966def show_docs(opts):
967    # type: (Dict[str, Any]) -> None
968    """
969    show html docs for the model
970    """
971    import wx  # type: ignore
972    from .generate import view_html_from_info
973    app = wx.App() if wx.GetApp() is None else None
974    view_html_from_info(opts['def'])
975    if app: app.MainLoop()
976
977
978def explore(opts):
979    # type: (Dict[str, Any]) -> None
980    """
981    explore the model using the bumps gui.
982    """
983    import wx  # type: ignore
984    from bumps.names import FitProblem  # type: ignore
985    from bumps.gui.app_frame import AppFrame  # type: ignore
986
987    is_mac = "cocoa" in wx.version()
988    # Create an app if not running embedded
989    app = wx.App() if wx.GetApp() is None else None
990    problem = FitProblem(Explore(opts))
991    frame = AppFrame(parent=None, title="explore", size=(1000,700))
992    if not is_mac: frame.Show()
993    frame.panel.set_model(model=problem)
994    frame.panel.Layout()
995    frame.panel.aui.Split(0, wx.TOP)
996    if is_mac: frame.Show()
997    # If running withing an app, start the main loop
998    if app: app.MainLoop()
999
1000class Explore(object):
1001    """
1002    Bumps wrapper for a SAS model comparison.
1003
1004    The resulting object can be used as a Bumps fit problem so that
1005    parameters can be adjusted in the GUI, with plots updated on the fly.
1006    """
1007    def __init__(self, opts):
1008        # type: (Dict[str, Any]) -> None
1009        from bumps.cli import config_matplotlib  # type: ignore
1010        from . import bumps_model
1011        config_matplotlib()
1012        self.opts = opts
1013        model_info = opts['def']
1014        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
1015        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1016        if not opts['is2d']:
1017            for p in model_info.parameters.user_parameters(is2d=False):
1018                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1019                    k = p.name+ext
1020                    v = pars.get(k, None)
1021                    if v is not None:
1022                        v.range(*parameter_range(k, v.value))
1023        else:
1024            for k, v in pars.items():
1025                v.range(*parameter_range(k, v.value))
1026
1027        self.pars = pars
1028        self.pd_types = pd_types
1029        self.limits = None
1030
1031    def numpoints(self):
1032        # type: () -> int
1033        """
1034        Return the number of points.
1035        """
1036        return len(self.pars) + 1  # so dof is 1
1037
1038    def parameters(self):
1039        # type: () -> Any   # Dict/List hierarchy of parameters
1040        """
1041        Return a dictionary of parameters.
1042        """
1043        return self.pars
1044
1045    def nllf(self):
1046        # type: () -> float
1047        """
1048        Return cost.
1049        """
1050        # pylint: disable=no-self-use
1051        return 0.  # No nllf
1052
1053    def plot(self, view='log'):
1054        # type: (str) -> None
1055        """
1056        Plot the data and residuals.
1057        """
1058        pars = dict((k, v.value) for k, v in self.pars.items())
1059        pars.update(self.pd_types)
1060        self.opts['pars'] = pars
1061        limits = compare(self.opts, limits=self.limits)
1062        if self.limits is None:
1063            vmin, vmax = limits
1064            self.limits = vmax*1e-7, 1.3*vmax
1065
1066
1067def main(*argv):
1068    # type: (*str) -> None
1069    """
1070    Main program.
1071    """
1072    opts = parse_opts(argv)
1073    if opts is not None:
1074        if opts['html']:
1075            show_docs(opts)
1076        elif opts['explore']:
1077            explore(opts)
1078        else:
1079            compare(opts)
1080
1081if __name__ == "__main__":
1082    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.