source: sasmodels/sasmodels/compare.py @ 251f54b

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 251f54b was 158cee4, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

sascomp: apply constraints on random parameters before evaluating expressions

  • Property mode set to 100755
File size: 41.4 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model is the name of the model to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -magnetic/-nonmagnetic* suppress magnetism
76    -cutoff=1e-5* cutoff value for including a point in polydispersity
77    -pars/-nopars* prints the parameter set or not
78    -abs/-rel* plot relative or absolute error
79    -linear/-log*/-q4 intensity scaling
80    -hist/-nohist* plot histogram of relative error
81    -res=0 sets the resolution width dQ/Q if calculating with resolution
82    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
83    -edit starts the parameter explorer
84    -default/-demo* use demo vs default parameters
85    -html shows the model docs instead of running the model
86    -title="note" adds note to the plot title, after the model name
87
88Any two calculation engines can be selected for comparison:
89
90    -single/-double/-half/-fast sets an OpenCL calculation engine
91    -single!/-double!/-quad! sets an OpenMP calculation engine
92    -sasview sets the sasview calculation engine
93
94The default is -single -sasview.  Note that the interpretation of quad
95precision depends on architecture, and may vary from 64-bit to 128-bit,
96with 80-bit floats being common (1e-19 precision).
97
98Key=value pairs allow you to set specific values for the model parameters.
99"""
100
101# Update docs with command line usage string.   This is separate from the usual
102# doc string so that we can display it at run time if there is an error.
103# lin
104__doc__ = (__doc__  # pylint: disable=redefined-builtin
105           + """
106Program description
107-------------------
108
109"""
110           + USAGE)
111
112kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
113
114# list of math functions for use in evaluating parameters
115MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
116
117# CRUFT python 2.6
118if not hasattr(datetime.timedelta, 'total_seconds'):
119    def delay(dt):
120        """Return number date-time delta as number seconds"""
121        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
122else:
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.total_seconds()
126
127
128class push_seed(object):
129    """
130    Set the seed value for the random number generator.
131
132    When used in a with statement, the random number generator state is
133    restored after the with statement is complete.
134
135    :Parameters:
136
137    *seed* : int or array_like, optional
138        Seed for RandomState
139
140    :Example:
141
142    Seed can be used directly to set the seed::
143
144        >>> from numpy.random import randint
145        >>> push_seed(24)
146        <...push_seed object at...>
147        >>> print(randint(0,1000000,3))
148        [242082    899 211136]
149
150    Seed can also be used in a with statement, which sets the random
151    number generator state for the enclosed computations and restores
152    it to the previous state on completion::
153
154        >>> with push_seed(24):
155        ...    print(randint(0,1000000,3))
156        [242082    899 211136]
157
158    Using nested contexts, we can demonstrate that state is indeed
159    restored after the block completes::
160
161        >>> with push_seed(24):
162        ...    print(randint(0,1000000))
163        ...    with push_seed(24):
164        ...        print(randint(0,1000000,3))
165        ...    print(randint(0,1000000))
166        242082
167        [242082    899 211136]
168        899
169
170    The restore step is protected against exceptions in the block::
171
172        >>> with push_seed(24):
173        ...    print(randint(0,1000000))
174        ...    try:
175        ...        with push_seed(24):
176        ...            print(randint(0,1000000,3))
177        ...            raise Exception()
178        ...    except Exception:
179        ...        print("Exception raised")
180        ...    print(randint(0,1000000))
181        242082
182        [242082    899 211136]
183        Exception raised
184        899
185    """
186    def __init__(self, seed=None):
187        # type: (Optional[int]) -> None
188        self._state = np.random.get_state()
189        np.random.seed(seed)
190
191    def __enter__(self):
192        # type: () -> None
193        pass
194
195    def __exit__(self, exc_type, exc_value, traceback):
196        # type: (Any, BaseException, Any) -> None
197        # TODO: better typing for __exit__ method
198        np.random.set_state(self._state)
199
200def tic():
201    # type: () -> Callable[[], float]
202    """
203    Timer function.
204
205    Use "toc=tic()" to start the clock and "toc()" to measure
206    a time interval.
207    """
208    then = datetime.datetime.now()
209    return lambda: delay(datetime.datetime.now() - then)
210
211
212def set_beam_stop(data, radius, outer=None):
213    # type: (Data, float, float) -> None
214    """
215    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
216
217    Note: this function does not require sasview
218    """
219    if hasattr(data, 'qx_data'):
220        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
221        data.mask = (q < radius)
222        if outer is not None:
223            data.mask |= (q >= outer)
224    else:
225        data.mask = (data.x < radius)
226        if outer is not None:
227            data.mask |= (data.x >= outer)
228
229
230def parameter_range(p, v):
231    # type: (str, float) -> Tuple[float, float]
232    """
233    Choose a parameter range based on parameter name and initial value.
234    """
235    # process the polydispersity options
236    if p.endswith('_pd_n'):
237        return 0., 100.
238    elif p.endswith('_pd_nsigma'):
239        return 0., 5.
240    elif p.endswith('_pd_type'):
241        raise ValueError("Cannot return a range for a string value")
242    elif any(s in p for s in ('theta', 'phi', 'psi')):
243        # orientation in [-180,180], orientation pd in [0,45]
244        if p.endswith('_pd'):
245            return 0., 45.
246        else:
247            return -180., 180.
248    elif p.endswith('_pd'):
249        return 0., 1.
250    elif 'sld' in p:
251        return -0.5, 10.
252    elif p == 'background':
253        return 0., 10.
254    elif p == 'scale':
255        return 0., 1.e3
256    elif v < 0.:
257        return 2.*v, -2.*v
258    else:
259        return 0., (2.*v if v > 0. else 1.)
260
261
262def _randomize_one(model_info, p, v):
263    # type: (ModelInfo, str, float) -> float
264    # type: (ModelInfo, str, str) -> str
265    """
266    Randomize a single parameter.
267    """
268    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
269        return v
270
271    # Find the parameter definition
272    for par in model_info.parameters.call_parameters:
273        if par.name == p:
274            break
275    else:
276        raise ValueError("unknown parameter %r"%p)
277
278    if len(par.limits) > 2:  # choice list
279        return np.random.randint(len(par.limits))
280
281    limits = par.limits
282    if not np.isfinite(limits).all():
283        low, high = parameter_range(p, v)
284        limits = (max(limits[0], low), min(limits[1], high))
285
286    return np.random.uniform(*limits)
287
288
289def randomize_pars(model_info, pars, seed=None):
290    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
291    """
292    Generate random values for all of the parameters.
293
294    Valid ranges for the random number generator are guessed from the name of
295    the parameter; this will not account for constraints such as cap radius
296    greater than cylinder radius in the capped_cylinder model, so
297    :func:`constrain_pars` needs to be called afterward..
298    """
299    with push_seed(seed):
300        # Note: the sort guarantees order `of calls to random number generator
301        random_pars = dict((p, _randomize_one(model_info, p, v))
302                           for p, v in sorted(pars.items()))
303    return random_pars
304
305def constrain_pars(model_info, pars):
306    # type: (ModelInfo, ParameterSet) -> None
307    """
308    Restrict parameters to valid values.
309
310    This includes model specific code for models such as capped_cylinder
311    which need to support within model constraints (cap radius more than
312    cylinder radius in this case).
313
314    Warning: this updates the *pars* dictionary in place.
315    """
316    name = model_info.id
317    # if it is a product model, then just look at the form factor since
318    # none of the structure factors need any constraints.
319    if '*' in name:
320        name = name.split('*')[0]
321
322    if name == 'barbell':
323        if pars['radius_bell'] < pars['radius']:
324            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
325
326    elif name == 'capped_cylinder':
327        if pars['radius_cap'] < pars['radius']:
328            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
329
330    elif name == 'guinier':
331        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
332        #q_max = 0.2  # mid q maximum
333        q_max = 1.0  # high q maximum
334        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
335        pars['rg'] = min(pars['rg'], rg_max)
336
337    elif name == 'rpa':
338        # Make sure phi sums to 1.0
339        if pars['case_num'] < 2:
340            pars['Phi1'] = 0.
341            pars['Phi2'] = 0.
342        elif pars['case_num'] < 5:
343            pars['Phi1'] = 0.
344        total = sum(pars['Phi'+c] for c in '1234')
345        for c in '1234':
346            pars['Phi'+c] /= total
347
348    elif name == 'stacked_disks':
349        pars['n_stacking'] = math.ceil(pars['n_stacking'])
350
351def parlist(model_info, pars, is2d):
352    # type: (ModelInfo, ParameterSet, bool) -> str
353    """
354    Format the parameter list for printing.
355    """
356    lines = []
357    parameters = model_info.parameters
358    magnetic = False
359    for p in parameters.user_parameters(pars, is2d):
360        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
361            continue
362        if p.id.startswith('up:') and not magnetic:
363            continue
364        fields = dict(
365            value=pars.get(p.id, p.default),
366            pd=pars.get(p.id+"_pd", 0.),
367            n=int(pars.get(p.id+"_pd_n", 0)),
368            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
369            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
370            relative_pd=p.relative_pd,
371            M0=pars.get('M0:'+p.id, 0.),
372            mphi=pars.get('mphi:'+p.id, 0.),
373            mtheta=pars.get('mtheta:'+p.id, 0.),
374        )
375        lines.append(_format_par(p.name, **fields))
376        magnetic = magnetic or fields['M0'] != 0.
377    return "\n".join(lines)
378
379    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
380
381def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
382                relative_pd=False, M0=0., mphi=0., mtheta=0.):
383    # type: (str, float, float, int, float, str) -> str
384    line = "%s: %g"%(name, value)
385    if pd != 0.  and n != 0:
386        if relative_pd:
387            pd *= value
388        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
389                % (pd, n, nsigma, nsigma, pdtype)
390    if M0 != 0.:
391        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
392    return line
393
394def suppress_pd(pars):
395    # type: (ParameterSet) -> ParameterSet
396    """
397    Suppress theta_pd for now until the normalization is resolved.
398
399    May also suppress complete polydispersity of the model to test
400    models more quickly.
401    """
402    pars = pars.copy()
403    for p in pars:
404        if p.endswith("_pd_n"): pars[p] = 0
405    return pars
406
407def suppress_magnetism(pars):
408    # type: (ParameterSet) -> ParameterSet
409    """
410    Suppress theta_pd for now until the normalization is resolved.
411
412    May also suppress complete polydispersity of the model to test
413    models more quickly.
414    """
415    pars = pars.copy()
416    for p in pars:
417        if p.startswith("M0:"): pars[p] = 0
418    return pars
419
420def eval_sasview(model_info, data):
421    # type: (Modelinfo, Data) -> Calculator
422    """
423    Return a model calculator using the pre-4.0 SasView models.
424    """
425    # importing sas here so that the error message will be that sas failed to
426    # import rather than the more obscure smear_selection not imported error
427    import sas
428    import sas.models
429    from sas.models.qsmearing import smear_selection
430    from sas.models.MultiplicationModel import MultiplicationModel
431    from sas.models.dispersion_models import models as dispersers
432
433    def get_model_class(name):
434        # type: (str) -> "sas.models.BaseComponent"
435        #print("new",sorted(_pars.items()))
436        __import__('sas.models.' + name)
437        ModelClass = getattr(getattr(sas.models, name, None), name, None)
438        if ModelClass is None:
439            raise ValueError("could not find model %r in sas.models"%name)
440        return ModelClass
441
442    # WARNING: ugly hack when handling model!
443    # Sasview models with multiplicity need to be created with the target
444    # multiplicity, so we cannot create the target model ahead of time for
445    # for multiplicity models.  Instead we store the model in a list and
446    # update the first element of that list with the new multiplicity model
447    # every time we evaluate.
448
449    # grab the sasview model, or create it if it is a product model
450    if model_info.composition:
451        composition_type, parts = model_info.composition
452        if composition_type == 'product':
453            P, S = [get_model_class(revert_name(p))() for p in parts]
454            model = [MultiplicationModel(P, S)]
455        else:
456            raise ValueError("sasview mixture models not supported by compare")
457    else:
458        old_name = revert_name(model_info)
459        if old_name is None:
460            raise ValueError("model %r does not exist in old sasview"
461                            % model_info.id)
462        ModelClass = get_model_class(old_name)
463        model = [ModelClass()]
464    model[0].disperser_handles = {}
465
466    # build a smearer with which to call the model, if necessary
467    smearer = smear_selection(data, model=model)
468    if hasattr(data, 'qx_data'):
469        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
470        index = ((~data.mask) & (~np.isnan(data.data))
471                 & (q >= data.qmin) & (q <= data.qmax))
472        if smearer is not None:
473            smearer.model = model  # because smear_selection has a bug
474            smearer.accuracy = data.accuracy
475            smearer.set_index(index)
476            def _call_smearer():
477                smearer.model = model[0]
478                return smearer.get_value()
479            theory = _call_smearer
480        else:
481            theory = lambda: model[0].evalDistribution([data.qx_data[index],
482                                                        data.qy_data[index]])
483    elif smearer is not None:
484        theory = lambda: smearer(model[0].evalDistribution(data.x))
485    else:
486        theory = lambda: model[0].evalDistribution(data.x)
487
488    def calculator(**pars):
489        # type: (float, ...) -> np.ndarray
490        """
491        Sasview calculator for model.
492        """
493        oldpars = revert_pars(model_info, pars)
494        # For multiplicity models, create a model with the correct multiplicity
495        control = oldpars.pop("CONTROL", None)
496        if control is not None:
497            # sphericalSLD has one fewer multiplicity.  This update should
498            # happen in revert_pars, but it hasn't been called yet.
499            model[0] = ModelClass(control)
500        # paying for parameter conversion each time to keep life simple, if not fast
501        for k, v in oldpars.items():
502            if k.endswith('.type'):
503                par = k[:-5]
504                if v == 'gaussian': continue
505                cls = dispersers[v if v != 'rectangle' else 'rectangula']
506                handle = cls()
507                model[0].disperser_handles[par] = handle
508                try:
509                    model[0].set_dispersion(par, handle)
510                except Exception:
511                    exception.annotate_exception("while setting %s to %r"
512                                                 %(par, v))
513                    raise
514
515
516        #print("sasview pars",oldpars)
517        for k, v in oldpars.items():
518            name_attr = k.split('.')  # polydispersity components
519            if len(name_attr) == 2:
520                par, disp_par = name_attr
521                model[0].dispersion[par][disp_par] = v
522            else:
523                model[0].setParam(k, v)
524        return theory()
525
526    calculator.engine = "sasview"
527    return calculator
528
529DTYPE_MAP = {
530    'half': '16',
531    'fast': 'fast',
532    'single': '32',
533    'double': '64',
534    'quad': '128',
535    'f16': '16',
536    'f32': '32',
537    'f64': '64',
538    'longdouble': '128',
539}
540def eval_opencl(model_info, data, dtype='single', cutoff=0.):
541    # type: (ModelInfo, Data, str, float) -> Calculator
542    """
543    Return a model calculator using the OpenCL calculation engine.
544    """
545    if not core.HAVE_OPENCL:
546        raise RuntimeError("OpenCL not available")
547    model = core.build_model(model_info, dtype=dtype, platform="ocl")
548    calculator = DirectModel(data, model, cutoff=cutoff)
549    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
550    return calculator
551
552def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
553    # type: (ModelInfo, Data, str, float) -> Calculator
554    """
555    Return a model calculator using the DLL calculation engine.
556    """
557    model = core.build_model(model_info, dtype=dtype, platform="dll")
558    calculator = DirectModel(data, model, cutoff=cutoff)
559    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
560    return calculator
561
562def time_calculation(calculator, pars, evals=1):
563    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
564    """
565    Compute the average calculation time over N evaluations.
566
567    An additional call is generated without polydispersity in order to
568    initialize the calculation engine, and make the average more stable.
569    """
570    # initialize the code so time is more accurate
571    if evals > 1:
572        calculator(**suppress_pd(pars))
573    toc = tic()
574    # make sure there is at least one eval
575    value = calculator(**pars)
576    for _ in range(evals-1):
577        value = calculator(**pars)
578    average_time = toc()*1000. / evals
579    #print("I(q)",value)
580    return value, average_time
581
582def make_data(opts):
583    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
584    """
585    Generate an empty dataset, used with the model to set Q points
586    and resolution.
587
588    *opts* contains the options, with 'qmax', 'nq', 'res',
589    'accuracy', 'is2d' and 'view' parsed from the command line.
590    """
591    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
592    if opts['is2d']:
593        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
594        data = empty_data2D(q, resolution=res)
595        data.accuracy = opts['accuracy']
596        set_beam_stop(data, 0.0004)
597        index = ~data.mask
598    else:
599        if opts['view'] == 'log' and not opts['zero']:
600            qmax = math.log10(qmax)
601            q = np.logspace(qmax-3, qmax, nq)
602        else:
603            q = np.linspace(0.001*qmax, qmax, nq)
604        if opts['zero']:
605            q = np.hstack((0, q))
606        data = empty_data1D(q, resolution=res)
607        index = slice(None, None)
608    return data, index
609
610def make_engine(model_info, data, dtype, cutoff):
611    # type: (ModelInfo, Data, str, float) -> Calculator
612    """
613    Generate the appropriate calculation engine for the given datatype.
614
615    Datatypes with '!' appended are evaluated using external C DLLs rather
616    than OpenCL.
617    """
618    if dtype == 'sasview':
619        return eval_sasview(model_info, data)
620    elif dtype.endswith('!'):
621        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
622    else:
623        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
624
625def _show_invalid(data, theory):
626    # type: (Data, np.ma.ndarray) -> None
627    """
628    Display a list of the non-finite values in theory.
629    """
630    if not theory.mask.any():
631        return
632
633    if hasattr(data, 'x'):
634        bad = zip(data.x[theory.mask], theory[theory.mask])
635        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
636
637
638def compare(opts, limits=None):
639    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
640    """
641    Preform a comparison using options from the command line.
642
643    *limits* are the limits on the values to use, either to set the y-axis
644    for 1D or to set the colormap scale for 2D.  If None, then they are
645    inferred from the data and returned. When exploring using Bumps,
646    the limits are set when the model is initially called, and maintained
647    as the values are adjusted, making it easier to see the effects of the
648    parameters.
649    """
650    n_base, n_comp = opts['count']
651    pars, pars2 = opts['pars']
652    data = opts['data']
653
654    # silence the linter
655    base = opts['engines'][0] if n_base else None
656    comp = opts['engines'][1] if n_comp else None
657    base_time = comp_time = None
658    base_value = comp_value = resid = relerr = None
659
660    # Base calculation
661    if n_base > 0:
662        try:
663            base_raw, base_time = time_calculation(base, pars, n_base)
664            base_value = np.ma.masked_invalid(base_raw)
665            print("%s t=%.2f ms, intensity=%.0f"
666                  % (base.engine, base_time, base_value.sum()))
667            _show_invalid(data, base_value)
668        except ImportError:
669            traceback.print_exc()
670            n_base = 0
671
672    # Comparison calculation
673    if n_comp > 0:
674        try:
675            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
676            comp_value = np.ma.masked_invalid(comp_raw)
677            print("%s t=%.2f ms, intensity=%.0f"
678                  % (comp.engine, comp_time, comp_value.sum()))
679            _show_invalid(data, comp_value)
680        except ImportError:
681            traceback.print_exc()
682            n_comp = 0
683
684    # Compare, but only if computing both forms
685    if n_base > 0 and n_comp > 0:
686        resid = (base_value - comp_value)
687        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
688        _print_stats("|%s-%s|"
689                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
690                     resid)
691        _print_stats("|(%s-%s)/%s|"
692                     % (base.engine, comp.engine, comp.engine),
693                     relerr)
694
695    # Plot if requested
696    if not opts['plot'] and not opts['explore']: return
697    view = opts['view']
698    import matplotlib.pyplot as plt
699    if limits is None:
700        vmin, vmax = np.Inf, -np.Inf
701        if n_base > 0:
702            vmin = min(vmin, base_value.min())
703            vmax = max(vmax, base_value.max())
704        if n_comp > 0:
705            vmin = min(vmin, comp_value.min())
706            vmax = max(vmax, comp_value.max())
707        limits = vmin, vmax
708
709    if n_base > 0:
710        if n_comp > 0: plt.subplot(131)
711        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
712        plt.title("%s t=%.2f ms"%(base.engine, base_time))
713        #cbar_title = "log I"
714    if n_comp > 0:
715        if n_base > 0: plt.subplot(132)
716        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
717        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
718        #cbar_title = "log I"
719    if n_comp > 0 and n_base > 0:
720        if not opts['is2d']:
721            plot_theory(data, base_value, view=view, use_data=False, limits=limits)
722        plt.subplot(133)
723        if not opts['rel_err']:
724            err, errstr, errview = resid, "abs err", "linear"
725        else:
726            err, errstr, errview = abs(relerr), "rel err", "log"
727        if 0:  # 95% cutoff
728            sorted = np.sort(err.flatten())
729            cutoff = sorted[int(sorted.size*0.95)]
730            err[err>cutoff] = cutoff
731        #err,errstr = base/comp,"ratio"
732        plot_theory(data, None, resid=err, view=errview, use_data=False)
733        if view == 'linear':
734            plt.xscale('linear')
735        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
736        #cbar_title = errstr if errview=="linear" else "log "+errstr
737    #if is2D:
738    #    h = plt.colorbar()
739    #    h.ax.set_title(cbar_title)
740    fig = plt.gcf()
741    extra_title = ' '+opts['title'] if opts['title'] else ''
742    fig.suptitle(":".join(opts['name']) + extra_title)
743
744    if n_comp > 0 and n_base > 0 and opts['show_hist']:
745        plt.figure()
746        v = relerr
747        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
748        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
749        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
750                   % (base.engine, comp.engine, comp.engine))
751        plt.ylabel('P(err)')
752        plt.title('Distribution of relative error between calculation engines')
753
754    if not opts['explore']:
755        plt.show()
756
757    return limits
758
759def _print_stats(label, err):
760    # type: (str, np.ma.ndarray) -> None
761    # work with trimmed data, not the full set
762    sorted_err = np.sort(abs(err.compressed()))
763    p50 = int((len(sorted_err)-1)*0.50)
764    p98 = int((len(sorted_err)-1)*0.98)
765    data = [
766        "max:%.3e"%sorted_err[-1],
767        "median:%.3e"%sorted_err[p50],
768        "98%%:%.3e"%sorted_err[p98],
769        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
770        "zero-offset:%+.3e"%np.mean(sorted_err),
771        ]
772    print(label+"  "+"  ".join(data))
773
774
775
776# ===========================================================================
777#
778NAME_OPTIONS = set([
779    'plot', 'noplot',
780    'half', 'fast', 'single', 'double',
781    'single!', 'double!', 'quad!', 'sasview',
782    'lowq', 'midq', 'highq', 'exq', 'zero',
783    '2d', '1d',
784    'preset', 'random',
785    'poly', 'mono',
786    'magnetic', 'nonmagnetic',
787    'nopars', 'pars',
788    'rel', 'abs',
789    'linear', 'log', 'q4',
790    'hist', 'nohist',
791    'edit', 'html',
792    'demo', 'default',
793    ])
794VALUE_OPTIONS = [
795    # Note: random is both a name option and a value option
796    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
797    ]
798
799def columnize(items, indent="", width=79):
800    # type: (List[str], str, int) -> str
801    """
802    Format a list of strings into columns.
803
804    Returns a string with carriage returns ready for printing.
805    """
806    column_width = max(len(w) for w in items) + 1
807    num_columns = (width - len(indent)) // column_width
808    num_rows = len(items) // num_columns
809    items = items + [""] * (num_rows * num_columns - len(items))
810    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
811    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
812             for row in zip(*columns)]
813    output = indent + ("\n"+indent).join(lines)
814    return output
815
816
817def get_pars(model_info, use_demo=False):
818    # type: (ModelInfo, bool) -> ParameterSet
819    """
820    Extract demo parameters from the model definition.
821    """
822    # Get the default values for the parameters
823    pars = {}
824    for p in model_info.parameters.call_parameters:
825        parts = [('', p.default)]
826        if p.polydisperse:
827            parts.append(('_pd', 0.0))
828            parts.append(('_pd_n', 0))
829            parts.append(('_pd_nsigma', 3.0))
830            parts.append(('_pd_type', "gaussian"))
831        for ext, val in parts:
832            if p.length > 1:
833                dict(("%s%d%s" % (p.id, k, ext), val)
834                     for k in range(1, p.length+1))
835            else:
836                pars[p.id + ext] = val
837
838    # Plug in values given in demo
839    if use_demo:
840        pars.update(model_info.demo)
841    return pars
842
843INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
844def isnumber(str):
845    match = FLOAT_RE.match(str)
846    isfloat = (match and not str[match.end():])
847    return isfloat or INTEGER_RE.match(str)
848
849def parse_opts(argv):
850    # type: (List[str]) -> Dict[str, Any]
851    """
852    Parse command line options.
853    """
854    MODELS = core.list_models()
855    flags = [arg for arg in argv
856             if arg.startswith('-')]
857    values = [arg for arg in argv
858              if not arg.startswith('-') and '=' in arg]
859    positional_args = [arg for arg in argv
860            if not arg.startswith('-') and '=' not in arg]
861    models = "\n    ".join("%-15s"%v for v in MODELS)
862    if len(positional_args) == 0:
863        print(USAGE)
864        print("\nAvailable models:")
865        print(columnize(MODELS, indent="  "))
866        return None
867    if len(positional_args) > 3:
868        print("expected parameters: model N1 N2")
869
870    invalid = [o[1:] for o in flags
871               if o[1:] not in NAME_OPTIONS
872               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
873    if invalid:
874        print("Invalid options: %s"%(", ".join(invalid)))
875        return None
876
877    name = positional_args[0]
878    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
879    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
880
881    # pylint: disable=bad-whitespace
882    # Interpret the flags
883    opts = {
884        'plot'      : True,
885        'view'      : 'log',
886        'is2d'      : False,
887        'qmax'      : 0.05,
888        'nq'        : 128,
889        'res'       : 0.0,
890        'accuracy'  : 'Low',
891        'cutoff'    : 0.0,
892        'seed'      : -1,  # default to preset
893        'mono'      : False,
894        # Default to magnetic a magnetic moment is set on the command line
895        'magnetic'  : False,
896        'show_pars' : False,
897        'show_hist' : False,
898        'rel_err'   : True,
899        'explore'   : False,
900        'use_demo'  : True,
901        'zero'      : False,
902        'html'      : False,
903        'title'     : None,
904    }
905    engines = []
906    for arg in flags:
907        if arg == '-noplot':    opts['plot'] = False
908        elif arg == '-plot':    opts['plot'] = True
909        elif arg == '-linear':  opts['view'] = 'linear'
910        elif arg == '-log':     opts['view'] = 'log'
911        elif arg == '-q4':      opts['view'] = 'q4'
912        elif arg == '-1d':      opts['is2d'] = False
913        elif arg == '-2d':      opts['is2d'] = True
914        elif arg == '-exq':     opts['qmax'] = 10.0
915        elif arg == '-highq':   opts['qmax'] = 1.0
916        elif arg == '-midq':    opts['qmax'] = 0.2
917        elif arg == '-lowq':    opts['qmax'] = 0.05
918        elif arg == '-zero':    opts['zero'] = True
919        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
920        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
921        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
922        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
923        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
924        elif arg.startswith('-title'):     opts['title'] = arg[7:]
925        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
926        elif arg == '-preset':  opts['seed'] = -1
927        elif arg == '-mono':    opts['mono'] = True
928        elif arg == '-poly':    opts['mono'] = False
929        elif arg == '-magnetic':       opts['magnetic'] = True
930        elif arg == '-nonmagnetic':    opts['magnetic'] = False
931        elif arg == '-pars':    opts['show_pars'] = True
932        elif arg == '-nopars':  opts['show_pars'] = False
933        elif arg == '-hist':    opts['show_hist'] = True
934        elif arg == '-nohist':  opts['show_hist'] = False
935        elif arg == '-rel':     opts['rel_err'] = True
936        elif arg == '-abs':     opts['rel_err'] = False
937        elif arg == '-half':    engines.append(arg[1:])
938        elif arg == '-fast':    engines.append(arg[1:])
939        elif arg == '-single':  engines.append(arg[1:])
940        elif arg == '-double':  engines.append(arg[1:])
941        elif arg == '-single!': engines.append(arg[1:])
942        elif arg == '-double!': engines.append(arg[1:])
943        elif arg == '-quad!':   engines.append(arg[1:])
944        elif arg == '-sasview': engines.append(arg[1:])
945        elif arg == '-edit':    opts['explore'] = True
946        elif arg == '-demo':    opts['use_demo'] = True
947        elif arg == '-default':    opts['use_demo'] = False
948        elif arg == '-html':    opts['html'] = True
949    # pylint: enable=bad-whitespace
950
951    if ':' in name:
952        name, name2 = name.split(':',2)
953    else:
954        name2 = name
955    try:
956        model_info = core.load_model_info(name)
957        model_info2 = core.load_model_info(name2) if name2 != name else model_info
958    except ImportError as exc:
959        print(str(exc))
960        print("Could not find model; use one of:\n    " + models)
961        return None
962
963    # Get demo parameters from model definition, or use default parameters
964    # if model does not define demo parameters
965    pars = get_pars(model_info, opts['use_demo'])
966    pars2 = get_pars(model_info2, opts['use_demo'])
967    pars2.update((k, v) for k, v in pars.items() if k in pars2)
968    # randomize parameters
969    #pars.update(set_pars)  # set value before random to control range
970    if opts['seed'] > -1:
971        pars = randomize_pars(model_info, pars, seed=opts['seed'])
972        if model_info != model_info2:
973            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
974            # Share values for parameters with the same name
975            for k, v in pars.items():
976                if k in pars2:
977                    pars2[k] = v
978        else:
979            pars2 = pars.copy()
980        constrain_pars(model_info, pars)
981        constrain_pars(model_info2, pars2)
982        print("Randomize using -random=%i"%opts['seed'])
983    if opts['mono']:
984        pars = suppress_pd(pars)
985        pars2 = suppress_pd(pars2)
986    if not opts['magnetic']:
987        pars = suppress_magnetism(pars)
988        pars2 = suppress_magnetism(pars2)
989
990    # Fill in parameters given on the command line
991    presets = {}
992    presets2 = {}
993    for arg in values:
994        k, v = arg.split('=', 1)
995        if k not in pars and k not in pars2:
996            # extract base name without polydispersity info
997            s = set(p.split('_pd')[0] for p in pars)
998            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
999            return None
1000        v1, v2 = v.split(':',2) if ':' in v else (v,v)
1001        if v1 and k in pars:
1002            presets[k] = float(v1) if isnumber(v1) else v1
1003        if v2 and k in pars2:
1004            presets2[k] = float(v2) if isnumber(v2) else v2
1005
1006    # If pd given on the command line, default pd_n to 35
1007    for k, v in list(presets.items()):
1008        if k.endswith('_pd'):
1009            presets.setdefault(k+'_n', 35.)
1010    for k, v in list(presets2.items()):
1011        if k.endswith('_pd'):
1012            presets2.setdefault(k+'_n', 35.)
1013
1014    # Evaluate preset parameter expressions
1015    context = MATH.copy()
1016    context.update(pars)
1017    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
1018    for k, v in presets.items():
1019        if not isinstance(v, float) and not k.endswith('_type'):
1020            presets[k] = eval(v, context)
1021    context.update(presets)
1022    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
1023    for k, v in presets2.items():
1024        if not isinstance(v, float) and not k.endswith('_type'):
1025            presets2[k] = eval(v, context)
1026
1027    # update parameters with presets
1028    pars.update(presets)  # set value after random to control value
1029    pars2.update(presets2)  # set value after random to control value
1030    #import pprint; pprint.pprint(model_info)
1031
1032    same_model = name == name2 and pars == pars
1033    if len(engines) == 0:
1034        if same_model:
1035            engines.extend(['single', 'double'])
1036        else:
1037            engines.extend(['single', 'single'])
1038    elif len(engines) == 1:
1039        if not same_model:
1040            engines.append(engines[0])
1041        elif engines[0] == 'double':
1042            engines.append('single')
1043        else:
1044            engines.append('double')
1045    elif len(engines) > 2:
1046        del engines[2:]
1047
1048    use_sasview = any(engine == 'sasview' and count > 0
1049                      for engine, count in zip(engines, [n1, n2]))
1050    if use_sasview:
1051        constrain_new_to_old(model_info, pars)
1052        constrain_new_to_old(model_info2, pars2)
1053
1054    if opts['show_pars']:
1055        if not same_model:
1056            print("==== %s ====="%model_info.name)
1057            print(str(parlist(model_info, pars, opts['is2d'])))
1058            print("==== %s ====="%model_info2.name)
1059            print(str(parlist(model_info2, pars2, opts['is2d'])))
1060        else:
1061            print(str(parlist(model_info, pars, opts['is2d'])))
1062
1063    # Create the computational engines
1064    data, _ = make_data(opts)
1065    if n1:
1066        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1067    else:
1068        base = None
1069    if n2:
1070        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1071    else:
1072        comp = None
1073
1074    # pylint: disable=bad-whitespace
1075    # Remember it all
1076    opts.update({
1077        'data'      : data,
1078        'name'      : [name, name2],
1079        'def'       : [model_info, model_info2],
1080        'count'     : [n1, n2],
1081        'presets'   : [presets, presets2],
1082        'pars'      : [pars, pars2],
1083        'engines'   : [base, comp],
1084    })
1085    # pylint: enable=bad-whitespace
1086
1087    return opts
1088
1089def show_docs(opts):
1090    # type: (Dict[str, Any]) -> None
1091    """
1092    show html docs for the model
1093    """
1094    import wx  # type: ignore
1095    from .generate import view_html_from_info
1096    app = wx.App() if wx.GetApp() is None else None
1097    view_html_from_info(opts['def'][0])
1098    if app: app.MainLoop()
1099
1100
1101def explore(opts):
1102    # type: (Dict[str, Any]) -> None
1103    """
1104    explore the model using the bumps gui.
1105    """
1106    import wx  # type: ignore
1107    from bumps.names import FitProblem  # type: ignore
1108    from bumps.gui.app_frame import AppFrame  # type: ignore
1109
1110    is_mac = "cocoa" in wx.version()
1111    # Create an app if not running embedded
1112    app = wx.App() if wx.GetApp() is None else None
1113    problem = FitProblem(Explore(opts))
1114    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1115    if not is_mac: frame.Show()
1116    frame.panel.set_model(model=problem)
1117    frame.panel.Layout()
1118    frame.panel.aui.Split(0, wx.TOP)
1119    if is_mac: frame.Show()
1120    # If running withing an app, start the main loop
1121    if app: app.MainLoop()
1122
1123class Explore(object):
1124    """
1125    Bumps wrapper for a SAS model comparison.
1126
1127    The resulting object can be used as a Bumps fit problem so that
1128    parameters can be adjusted in the GUI, with plots updated on the fly.
1129    """
1130    def __init__(self, opts):
1131        # type: (Dict[str, Any]) -> None
1132        from bumps.cli import config_matplotlib  # type: ignore
1133        from . import bumps_model
1134        config_matplotlib()
1135        self.opts = opts
1136        model_info = opts['def'][0]
1137        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0])
1138        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1139        if not opts['is2d']:
1140            for p in model_info.parameters.user_parameters(is2d=False):
1141                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1142                    k = p.name+ext
1143                    v = pars.get(k, None)
1144                    if v is not None:
1145                        v.range(*parameter_range(k, v.value))
1146        else:
1147            for k, v in pars.items():
1148                v.range(*parameter_range(k, v.value))
1149
1150        self.pars = pars
1151        self.pd_types = pd_types
1152        self.limits = None
1153
1154    def numpoints(self):
1155        # type: () -> int
1156        """
1157        Return the number of points.
1158        """
1159        return len(self.pars) + 1  # so dof is 1
1160
1161    def parameters(self):
1162        # type: () -> Any   # Dict/List hierarchy of parameters
1163        """
1164        Return a dictionary of parameters.
1165        """
1166        return self.pars
1167
1168    def nllf(self):
1169        # type: () -> float
1170        """
1171        Return cost.
1172        """
1173        # pylint: disable=no-self-use
1174        return 0.  # No nllf
1175
1176    def plot(self, view='log'):
1177        # type: (str) -> None
1178        """
1179        Plot the data and residuals.
1180        """
1181        pars = dict((k, v.value) for k, v in self.pars.items())
1182        pars.update(self.pd_types)
1183        self.opts['pars'][0] = pars
1184        self.opts['pars'][1] = pars
1185        limits = compare(self.opts, limits=self.limits)
1186        if self.limits is None:
1187            vmin, vmax = limits
1188            self.limits = vmax*1e-7, 1.3*vmax
1189
1190
1191def main(*argv):
1192    # type: (*str) -> None
1193    """
1194    Main program.
1195    """
1196    opts = parse_opts(argv)
1197    if opts is not None:
1198        if opts['html']:
1199            show_docs(opts)
1200        elif opts['explore']:
1201            explore(opts)
1202        else:
1203            compare(opts)
1204
1205if __name__ == "__main__":
1206    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.