source: sasmodels/sasmodels/compare.py @ 248561a

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 248561a was 248561a, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

allow math functions such as sqrt and atan in parameter expressions

  • Property mode set to 100755
File size: 39.4 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model is the name of the model to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -cutoff=1e-5* cutoff value for including a point in polydispersity
76    -pars/-nopars* prints the parameter set or not
77    -abs/-rel* plot relative or absolute error
78    -linear/-log*/-q4 intensity scaling
79    -hist/-nohist* plot histogram of relative error
80    -res=0 sets the resolution width dQ/Q if calculating with resolution
81    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
82    -edit starts the parameter explorer
83    -default/-demo* use demo vs default parameters
84    -html shows the model docs instead of running the model
85    -title="note" adds note to the plot title, after the model name
86
87Any two calculation engines can be selected for comparison:
88
89    -single/-double/-half/-fast sets an OpenCL calculation engine
90    -single!/-double!/-quad! sets an OpenMP calculation engine
91    -sasview sets the sasview calculation engine
92
93The default is -single -sasview.  Note that the interpretation of quad
94precision depends on architecture, and may vary from 64-bit to 128-bit,
95with 80-bit floats being common (1e-19 precision).
96
97Key=value pairs allow you to set specific values for the model parameters.
98"""
99
100# Update docs with command line usage string.   This is separate from the usual
101# doc string so that we can display it at run time if there is an error.
102# lin
103__doc__ = (__doc__  # pylint: disable=redefined-builtin
104           + """
105Program description
106-------------------
107
108"""
109           + USAGE)
110
111kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
112
113# list of math functions for use in evaluating parameters
114MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
115
116# CRUFT python 2.6
117if not hasattr(datetime.timedelta, 'total_seconds'):
118    def delay(dt):
119        """Return number date-time delta as number seconds"""
120        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
121else:
122    def delay(dt):
123        """Return number date-time delta as number seconds"""
124        return dt.total_seconds()
125
126
127class push_seed(object):
128    """
129    Set the seed value for the random number generator.
130
131    When used in a with statement, the random number generator state is
132    restored after the with statement is complete.
133
134    :Parameters:
135
136    *seed* : int or array_like, optional
137        Seed for RandomState
138
139    :Example:
140
141    Seed can be used directly to set the seed::
142
143        >>> from numpy.random import randint
144        >>> push_seed(24)
145        <...push_seed object at...>
146        >>> print(randint(0,1000000,3))
147        [242082    899 211136]
148
149    Seed can also be used in a with statement, which sets the random
150    number generator state for the enclosed computations and restores
151    it to the previous state on completion::
152
153        >>> with push_seed(24):
154        ...    print(randint(0,1000000,3))
155        [242082    899 211136]
156
157    Using nested contexts, we can demonstrate that state is indeed
158    restored after the block completes::
159
160        >>> with push_seed(24):
161        ...    print(randint(0,1000000))
162        ...    with push_seed(24):
163        ...        print(randint(0,1000000,3))
164        ...    print(randint(0,1000000))
165        242082
166        [242082    899 211136]
167        899
168
169    The restore step is protected against exceptions in the block::
170
171        >>> with push_seed(24):
172        ...    print(randint(0,1000000))
173        ...    try:
174        ...        with push_seed(24):
175        ...            print(randint(0,1000000,3))
176        ...            raise Exception()
177        ...    except Exception:
178        ...        print("Exception raised")
179        ...    print(randint(0,1000000))
180        242082
181        [242082    899 211136]
182        Exception raised
183        899
184    """
185    def __init__(self, seed=None):
186        # type: (Optional[int]) -> None
187        self._state = np.random.get_state()
188        np.random.seed(seed)
189
190    def __enter__(self):
191        # type: () -> None
192        pass
193
194    def __exit__(self, exc_type, exc_value, traceback):
195        # type: (Any, BaseException, Any) -> None
196        # TODO: better typing for __exit__ method
197        np.random.set_state(self._state)
198
199def tic():
200    # type: () -> Callable[[], float]
201    """
202    Timer function.
203
204    Use "toc=tic()" to start the clock and "toc()" to measure
205    a time interval.
206    """
207    then = datetime.datetime.now()
208    return lambda: delay(datetime.datetime.now() - then)
209
210
211def set_beam_stop(data, radius, outer=None):
212    # type: (Data, float, float) -> None
213    """
214    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
215
216    Note: this function does not require sasview
217    """
218    if hasattr(data, 'qx_data'):
219        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
220        data.mask = (q < radius)
221        if outer is not None:
222            data.mask |= (q >= outer)
223    else:
224        data.mask = (data.x < radius)
225        if outer is not None:
226            data.mask |= (data.x >= outer)
227
228
229def parameter_range(p, v):
230    # type: (str, float) -> Tuple[float, float]
231    """
232    Choose a parameter range based on parameter name and initial value.
233    """
234    # process the polydispersity options
235    if p.endswith('_pd_n'):
236        return 0., 100.
237    elif p.endswith('_pd_nsigma'):
238        return 0., 5.
239    elif p.endswith('_pd_type'):
240        raise ValueError("Cannot return a range for a string value")
241    elif any(s in p for s in ('theta', 'phi', 'psi')):
242        # orientation in [-180,180], orientation pd in [0,45]
243        if p.endswith('_pd'):
244            return 0., 45.
245        else:
246            return -180., 180.
247    elif p.endswith('_pd'):
248        return 0., 1.
249    elif 'sld' in p:
250        return -0.5, 10.
251    elif p == 'background':
252        return 0., 10.
253    elif p == 'scale':
254        return 0., 1.e3
255    elif v < 0.:
256        return 2.*v, -2.*v
257    else:
258        return 0., (2.*v if v > 0. else 1.)
259
260
261def _randomize_one(model_info, p, v):
262    # type: (ModelInfo, str, float) -> float
263    # type: (ModelInfo, str, str) -> str
264    """
265    Randomize a single parameter.
266    """
267    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
268        return v
269
270    # Find the parameter definition
271    for par in model_info.parameters.call_parameters:
272        if par.name == p:
273            break
274    else:
275        raise ValueError("unknown parameter %r"%p)
276
277    if len(par.limits) > 2:  # choice list
278        return np.random.randint(len(par.limits))
279
280    limits = par.limits
281    if not np.isfinite(limits).all():
282        low, high = parameter_range(p, v)
283        limits = (max(limits[0], low), min(limits[1], high))
284
285    return np.random.uniform(*limits)
286
287
288def randomize_pars(model_info, pars, seed=None):
289    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
290    """
291    Generate random values for all of the parameters.
292
293    Valid ranges for the random number generator are guessed from the name of
294    the parameter; this will not account for constraints such as cap radius
295    greater than cylinder radius in the capped_cylinder model, so
296    :func:`constrain_pars` needs to be called afterward..
297    """
298    with push_seed(seed):
299        # Note: the sort guarantees order `of calls to random number generator
300        random_pars = dict((p, _randomize_one(model_info, p, v))
301                           for p, v in sorted(pars.items()))
302    return random_pars
303
304def constrain_pars(model_info, pars):
305    # type: (ModelInfo, ParameterSet) -> None
306    """
307    Restrict parameters to valid values.
308
309    This includes model specific code for models such as capped_cylinder
310    which need to support within model constraints (cap radius more than
311    cylinder radius in this case).
312
313    Warning: this updates the *pars* dictionary in place.
314    """
315    name = model_info.id
316    # if it is a product model, then just look at the form factor since
317    # none of the structure factors need any constraints.
318    if '*' in name:
319        name = name.split('*')[0]
320
321    if name == 'capped_cylinder' and pars['radius_cap'] < pars['radius']:
322        pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
323    if name == 'barbell' and pars['radius_bell'] < pars['radius']:
324        pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
325
326    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
327    if name == 'guinier':
328        #q_max = 0.2  # mid q maximum
329        q_max = 1.0  # high q maximum
330        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
331        pars['rg'] = min(pars['rg'], rg_max)
332
333    if name == 'rpa':
334        # Make sure phi sums to 1.0
335        if pars['case_num'] < 2:
336            pars['Phi1'] = 0.
337            pars['Phi2'] = 0.
338        elif pars['case_num'] < 5:
339            pars['Phi1'] = 0.
340        total = sum(pars['Phi'+c] for c in '1234')
341        for c in '1234':
342            pars['Phi'+c] /= total
343
344def parlist(model_info, pars, is2d):
345    # type: (ModelInfo, ParameterSet, bool) -> str
346    """
347    Format the parameter list for printing.
348    """
349    lines = []
350    parameters = model_info.parameters
351    for p in parameters.user_parameters(pars, is2d):
352        fields = dict(
353            value=pars.get(p.id, p.default),
354            pd=pars.get(p.id+"_pd", 0.),
355            n=int(pars.get(p.id+"_pd_n", 0)),
356            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
357            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
358            relative_pd=p.relative_pd,
359        )
360        lines.append(_format_par(p.name, **fields))
361    return "\n".join(lines)
362
363    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
364
365def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
366                relative_pd=False):
367    # type: (str, float, float, int, float, str) -> str
368    line = "%s: %g"%(name, value)
369    if pd != 0.  and n != 0:
370        if relative_pd:
371            pd *= value
372        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
373                % (pd, n, nsigma, nsigma, pdtype)
374    return line
375
376def suppress_pd(pars):
377    # type: (ParameterSet) -> ParameterSet
378    """
379    Suppress theta_pd for now until the normalization is resolved.
380
381    May also suppress complete polydispersity of the model to test
382    models more quickly.
383    """
384    pars = pars.copy()
385    for p in pars:
386        if p.endswith("_pd_n"): pars[p] = 0
387    return pars
388
389def eval_sasview(model_info, data):
390    # type: (Modelinfo, Data) -> Calculator
391    """
392    Return a model calculator using the pre-4.0 SasView models.
393    """
394    # importing sas here so that the error message will be that sas failed to
395    # import rather than the more obscure smear_selection not imported error
396    import sas
397    import sas.models
398    from sas.models.qsmearing import smear_selection
399    from sas.models.MultiplicationModel import MultiplicationModel
400    from sas.models.dispersion_models import models as dispersers
401
402    def get_model_class(name):
403        # type: (str) -> "sas.models.BaseComponent"
404        #print("new",sorted(_pars.items()))
405        __import__('sas.models.' + name)
406        ModelClass = getattr(getattr(sas.models, name, None), name, None)
407        if ModelClass is None:
408            raise ValueError("could not find model %r in sas.models"%name)
409        return ModelClass
410
411    # WARNING: ugly hack when handling model!
412    # Sasview models with multiplicity need to be created with the target
413    # multiplicity, so we cannot create the target model ahead of time for
414    # for multiplicity models.  Instead we store the model in a list and
415    # update the first element of that list with the new multiplicity model
416    # every time we evaluate.
417
418    # grab the sasview model, or create it if it is a product model
419    if model_info.composition:
420        composition_type, parts = model_info.composition
421        if composition_type == 'product':
422            P, S = [get_model_class(revert_name(p))() for p in parts]
423            model = [MultiplicationModel(P, S)]
424        else:
425            raise ValueError("sasview mixture models not supported by compare")
426    else:
427        old_name = revert_name(model_info)
428        if old_name is None:
429            raise ValueError("model %r does not exist in old sasview"
430                            % model_info.id)
431        ModelClass = get_model_class(old_name)
432        model = [ModelClass()]
433    model[0].disperser_handles = {}
434
435    # build a smearer with which to call the model, if necessary
436    smearer = smear_selection(data, model=model)
437    if hasattr(data, 'qx_data'):
438        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
439        index = ((~data.mask) & (~np.isnan(data.data))
440                 & (q >= data.qmin) & (q <= data.qmax))
441        if smearer is not None:
442            smearer.model = model  # because smear_selection has a bug
443            smearer.accuracy = data.accuracy
444            smearer.set_index(index)
445            def _call_smearer():
446                smearer.model = model[0]
447                return smearer.get_value()
448            theory = _call_smearer
449        else:
450            theory = lambda: model[0].evalDistribution([data.qx_data[index],
451                                                        data.qy_data[index]])
452    elif smearer is not None:
453        theory = lambda: smearer(model[0].evalDistribution(data.x))
454    else:
455        theory = lambda: model[0].evalDistribution(data.x)
456
457    def calculator(**pars):
458        # type: (float, ...) -> np.ndarray
459        """
460        Sasview calculator for model.
461        """
462        oldpars = revert_pars(model_info, pars)
463        # For multiplicity models, create a model with the correct multiplicity
464        control = oldpars.pop("CONTROL", None)
465        if control is not None:
466            # sphericalSLD has one fewer multiplicity.  This update should
467            # happen in revert_pars, but it hasn't been called yet.
468            model[0] = ModelClass(control)
469        # paying for parameter conversion each time to keep life simple, if not fast
470        for k, v in oldpars.items():
471            if k.endswith('.type'):
472                par = k[:-5]
473                if v == 'gaussian': continue
474                cls = dispersers[v if v != 'rectangle' else 'rectangula']
475                handle = cls()
476                model[0].disperser_handles[par] = handle
477                try:
478                    model[0].set_dispersion(par, handle)
479                except Exception:
480                    exception.annotate_exception("while setting %s to %r"
481                                                 %(par, v))
482                    raise
483
484
485        #print("sasview pars",oldpars)
486        for k, v in oldpars.items():
487            name_attr = k.split('.')  # polydispersity components
488            if len(name_attr) == 2:
489                par, disp_par = name_attr
490                model[0].dispersion[par][disp_par] = v
491            else:
492                model[0].setParam(k, v)
493        return theory()
494
495    calculator.engine = "sasview"
496    return calculator
497
498DTYPE_MAP = {
499    'half': '16',
500    'fast': 'fast',
501    'single': '32',
502    'double': '64',
503    'quad': '128',
504    'f16': '16',
505    'f32': '32',
506    'f64': '64',
507    'longdouble': '128',
508}
509def eval_opencl(model_info, data, dtype='single', cutoff=0.):
510    # type: (ModelInfo, Data, str, float) -> Calculator
511    """
512    Return a model calculator using the OpenCL calculation engine.
513    """
514    if not core.HAVE_OPENCL:
515        raise RuntimeError("OpenCL not available")
516    model = core.build_model(model_info, dtype=dtype, platform="ocl")
517    calculator = DirectModel(data, model, cutoff=cutoff)
518    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
519    return calculator
520
521def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
522    # type: (ModelInfo, Data, str, float) -> Calculator
523    """
524    Return a model calculator using the DLL calculation engine.
525    """
526    model = core.build_model(model_info, dtype=dtype, platform="dll")
527    calculator = DirectModel(data, model, cutoff=cutoff)
528    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
529    return calculator
530
531def time_calculation(calculator, pars, evals=1):
532    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
533    """
534    Compute the average calculation time over N evaluations.
535
536    An additional call is generated without polydispersity in order to
537    initialize the calculation engine, and make the average more stable.
538    """
539    # initialize the code so time is more accurate
540    if evals > 1:
541        calculator(**suppress_pd(pars))
542    toc = tic()
543    # make sure there is at least one eval
544    value = calculator(**pars)
545    for _ in range(evals-1):
546        value = calculator(**pars)
547    average_time = toc()*1000. / evals
548    #print("I(q)",value)
549    return value, average_time
550
551def make_data(opts):
552    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
553    """
554    Generate an empty dataset, used with the model to set Q points
555    and resolution.
556
557    *opts* contains the options, with 'qmax', 'nq', 'res',
558    'accuracy', 'is2d' and 'view' parsed from the command line.
559    """
560    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
561    if opts['is2d']:
562        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
563        data = empty_data2D(q, resolution=res)
564        data.accuracy = opts['accuracy']
565        set_beam_stop(data, 0.0004)
566        index = ~data.mask
567    else:
568        if opts['view'] == 'log' and not opts['zero']:
569            qmax = math.log10(qmax)
570            q = np.logspace(qmax-3, qmax, nq)
571        else:
572            q = np.linspace(0.001*qmax, qmax, nq)
573        if opts['zero']:
574            q = np.hstack((0, q))
575        data = empty_data1D(q, resolution=res)
576        index = slice(None, None)
577    return data, index
578
579def make_engine(model_info, data, dtype, cutoff):
580    # type: (ModelInfo, Data, str, float) -> Calculator
581    """
582    Generate the appropriate calculation engine for the given datatype.
583
584    Datatypes with '!' appended are evaluated using external C DLLs rather
585    than OpenCL.
586    """
587    if dtype == 'sasview':
588        return eval_sasview(model_info, data)
589    elif dtype.endswith('!'):
590        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
591    else:
592        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
593
594def _show_invalid(data, theory):
595    # type: (Data, np.ma.ndarray) -> None
596    """
597    Display a list of the non-finite values in theory.
598    """
599    if not theory.mask.any():
600        return
601
602    if hasattr(data, 'x'):
603        bad = zip(data.x[theory.mask], theory[theory.mask])
604        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
605
606
607def compare(opts, limits=None):
608    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
609    """
610    Preform a comparison using options from the command line.
611
612    *limits* are the limits on the values to use, either to set the y-axis
613    for 1D or to set the colormap scale for 2D.  If None, then they are
614    inferred from the data and returned. When exploring using Bumps,
615    the limits are set when the model is initially called, and maintained
616    as the values are adjusted, making it easier to see the effects of the
617    parameters.
618    """
619    n_base, n_comp = opts['count']
620    pars, pars2 = opts['pars']
621    data = opts['data']
622
623    # silence the linter
624    base = opts['engines'][0] if n_base else None
625    comp = opts['engines'][1] if n_comp else None
626    base_time = comp_time = None
627    base_value = comp_value = resid = relerr = None
628
629    # Base calculation
630    if n_base > 0:
631        try:
632            base_raw, base_time = time_calculation(base, pars, n_base)
633            base_value = np.ma.masked_invalid(base_raw)
634            print("%s t=%.2f ms, intensity=%.0f"
635                  % (base.engine, base_time, base_value.sum()))
636            _show_invalid(data, base_value)
637        except ImportError:
638            traceback.print_exc()
639            n_base = 0
640
641    # Comparison calculation
642    if n_comp > 0:
643        try:
644            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
645            comp_value = np.ma.masked_invalid(comp_raw)
646            print("%s t=%.2f ms, intensity=%.0f"
647                  % (comp.engine, comp_time, comp_value.sum()))
648            _show_invalid(data, comp_value)
649        except ImportError:
650            traceback.print_exc()
651            n_comp = 0
652
653    # Compare, but only if computing both forms
654    if n_base > 0 and n_comp > 0:
655        resid = (base_value - comp_value)
656        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
657        _print_stats("|%s-%s|"
658                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
659                     resid)
660        _print_stats("|(%s-%s)/%s|"
661                     % (base.engine, comp.engine, comp.engine),
662                     relerr)
663
664    # Plot if requested
665    if not opts['plot'] and not opts['explore']: return
666    view = opts['view']
667    import matplotlib.pyplot as plt
668    if limits is None:
669        vmin, vmax = np.Inf, -np.Inf
670        if n_base > 0:
671            vmin = min(vmin, base_value.min())
672            vmax = max(vmax, base_value.max())
673        if n_comp > 0:
674            vmin = min(vmin, comp_value.min())
675            vmax = max(vmax, comp_value.max())
676        limits = vmin, vmax
677
678    if n_base > 0:
679        if n_comp > 0: plt.subplot(131)
680        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
681        plt.title("%s t=%.2f ms"%(base.engine, base_time))
682        #cbar_title = "log I"
683    if n_comp > 0:
684        if n_base > 0: plt.subplot(132)
685        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
686        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
687        #cbar_title = "log I"
688    if n_comp > 0 and n_base > 0:
689        plt.subplot(133)
690        if not opts['rel_err']:
691            err, errstr, errview = resid, "abs err", "linear"
692        else:
693            err, errstr, errview = abs(relerr), "rel err", "log"
694        #sorted = np.sort(err.flatten())
695        #cutoff = sorted[int(sorted.size*0.95)]
696        #err[err>cutoff] = cutoff
697        #err,errstr = base/comp,"ratio"
698        plot_theory(data, None, resid=err, view=errview, use_data=False)
699        if view == 'linear':
700            plt.xscale('linear')
701        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
702        #cbar_title = errstr if errview=="linear" else "log "+errstr
703    #if is2D:
704    #    h = plt.colorbar()
705    #    h.ax.set_title(cbar_title)
706    fig = plt.gcf()
707    extra_title = ' '+opts['title'] if opts['title'] else ''
708    fig.suptitle(":".join(opts['name']) + extra_title)
709
710    if n_comp > 0 and n_base > 0 and opts['show_hist']:
711        plt.figure()
712        v = relerr
713        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
714        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
715        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
716                   % (base.engine, comp.engine, comp.engine))
717        plt.ylabel('P(err)')
718        plt.title('Distribution of relative error between calculation engines')
719
720    if not opts['explore']:
721        plt.show()
722
723    return limits
724
725def _print_stats(label, err):
726    # type: (str, np.ma.ndarray) -> None
727    # work with trimmed data, not the full set
728    sorted_err = np.sort(abs(err.compressed()))
729    p50 = int((len(sorted_err)-1)*0.50)
730    p98 = int((len(sorted_err)-1)*0.98)
731    data = [
732        "max:%.3e"%sorted_err[-1],
733        "median:%.3e"%sorted_err[p50],
734        "98%%:%.3e"%sorted_err[p98],
735        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
736        "zero-offset:%+.3e"%np.mean(sorted_err),
737        ]
738    print(label+"  "+"  ".join(data))
739
740
741
742# ===========================================================================
743#
744NAME_OPTIONS = set([
745    'plot', 'noplot',
746    'half', 'fast', 'single', 'double',
747    'single!', 'double!', 'quad!', 'sasview',
748    'lowq', 'midq', 'highq', 'exq', 'zero',
749    '2d', '1d',
750    'preset', 'random',
751    'poly', 'mono',
752    'nopars', 'pars',
753    'rel', 'abs',
754    'linear', 'log', 'q4',
755    'hist', 'nohist',
756    'edit', 'html',
757    'demo', 'default',
758    ])
759VALUE_OPTIONS = [
760    # Note: random is both a name option and a value option
761    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
762    ]
763
764def columnize(items, indent="", width=79):
765    # type: (List[str], str, int) -> str
766    """
767    Format a list of strings into columns.
768
769    Returns a string with carriage returns ready for printing.
770    """
771    column_width = max(len(w) for w in items) + 1
772    num_columns = (width - len(indent)) // column_width
773    num_rows = len(items) // num_columns
774    items = items + [""] * (num_rows * num_columns - len(items))
775    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
776    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
777             for row in zip(*columns)]
778    output = indent + ("\n"+indent).join(lines)
779    return output
780
781
782def get_pars(model_info, use_demo=False):
783    # type: (ModelInfo, bool) -> ParameterSet
784    """
785    Extract demo parameters from the model definition.
786    """
787    # Get the default values for the parameters
788    pars = {}
789    for p in model_info.parameters.call_parameters:
790        parts = [('', p.default)]
791        if p.polydisperse:
792            parts.append(('_pd', 0.0))
793            parts.append(('_pd_n', 0))
794            parts.append(('_pd_nsigma', 3.0))
795            parts.append(('_pd_type', "gaussian"))
796        for ext, val in parts:
797            if p.length > 1:
798                dict(("%s%d%s" % (p.id, k, ext), val)
799                     for k in range(1, p.length+1))
800            else:
801                pars[p.id + ext] = val
802
803    # Plug in values given in demo
804    if use_demo:
805        pars.update(model_info.demo)
806    return pars
807
808INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
809def isnumber(str):
810    match = FLOAT_RE.match(str)
811    isfloat = (match and not str[match.end():])
812    return isfloat or INTEGER_RE.match(str)
813
814def parse_opts(argv):
815    # type: (List[str]) -> Dict[str, Any]
816    """
817    Parse command line options.
818    """
819    MODELS = core.list_models()
820    flags = [arg for arg in argv
821             if arg.startswith('-')]
822    values = [arg for arg in argv
823              if not arg.startswith('-') and '=' in arg]
824    positional_args = [arg for arg in argv
825            if not arg.startswith('-') and '=' not in arg]
826    models = "\n    ".join("%-15s"%v for v in MODELS)
827    if len(positional_args) == 0:
828        print(USAGE)
829        print("\nAvailable models:")
830        print(columnize(MODELS, indent="  "))
831        return None
832    if len(positional_args) > 3:
833        print("expected parameters: model N1 N2")
834
835    invalid = [o[1:] for o in flags
836               if o[1:] not in NAME_OPTIONS
837               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
838    if invalid:
839        print("Invalid options: %s"%(", ".join(invalid)))
840        return None
841
842    name = positional_args[0]
843    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
844    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
845
846    # pylint: disable=bad-whitespace
847    # Interpret the flags
848    opts = {
849        'plot'      : True,
850        'view'      : 'log',
851        'is2d'      : False,
852        'qmax'      : 0.05,
853        'nq'        : 128,
854        'res'       : 0.0,
855        'accuracy'  : 'Low',
856        'cutoff'    : 0.0,
857        'seed'      : -1,  # default to preset
858        'mono'      : False,
859        'show_pars' : False,
860        'show_hist' : False,
861        'rel_err'   : True,
862        'explore'   : False,
863        'use_demo'  : True,
864        'zero'      : False,
865        'html'      : False,
866        'title'     : None,
867    }
868    engines = []
869    for arg in flags:
870        if arg == '-noplot':    opts['plot'] = False
871        elif arg == '-plot':    opts['plot'] = True
872        elif arg == '-linear':  opts['view'] = 'linear'
873        elif arg == '-log':     opts['view'] = 'log'
874        elif arg == '-q4':      opts['view'] = 'q4'
875        elif arg == '-1d':      opts['is2d'] = False
876        elif arg == '-2d':      opts['is2d'] = True
877        elif arg == '-exq':     opts['qmax'] = 10.0
878        elif arg == '-highq':   opts['qmax'] = 1.0
879        elif arg == '-midq':    opts['qmax'] = 0.2
880        elif arg == '-lowq':    opts['qmax'] = 0.05
881        elif arg == '-zero':    opts['zero'] = True
882        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
883        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
884        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
885        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
886        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
887        elif arg.startswith('-title'):     opts['title'] = arg[7:]
888        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
889        elif arg == '-preset':  opts['seed'] = -1
890        elif arg == '-mono':    opts['mono'] = True
891        elif arg == '-poly':    opts['mono'] = False
892        elif arg == '-pars':    opts['show_pars'] = True
893        elif arg == '-nopars':  opts['show_pars'] = False
894        elif arg == '-hist':    opts['show_hist'] = True
895        elif arg == '-nohist':  opts['show_hist'] = False
896        elif arg == '-rel':     opts['rel_err'] = True
897        elif arg == '-abs':     opts['rel_err'] = False
898        elif arg == '-half':    engines.append(arg[1:])
899        elif arg == '-fast':    engines.append(arg[1:])
900        elif arg == '-single':  engines.append(arg[1:])
901        elif arg == '-double':  engines.append(arg[1:])
902        elif arg == '-single!': engines.append(arg[1:])
903        elif arg == '-double!': engines.append(arg[1:])
904        elif arg == '-quad!':   engines.append(arg[1:])
905        elif arg == '-sasview': engines.append(arg[1:])
906        elif arg == '-edit':    opts['explore'] = True
907        elif arg == '-demo':    opts['use_demo'] = True
908        elif arg == '-default':    opts['use_demo'] = False
909        elif arg == '-html':    opts['html'] = True
910    # pylint: enable=bad-whitespace
911
912    if ':' in name:
913        name, name2 = name.split(':',2)
914    else:
915        name2 = name
916    try:
917        model_info = core.load_model_info(name)
918        model_info2 = core.load_model_info(name2) if name2 != name else model_info
919    except ImportError as exc:
920        print(str(exc))
921        print("Could not find model; use one of:\n    " + models)
922        return None
923
924    # Get demo parameters from model definition, or use default parameters
925    # if model does not define demo parameters
926    pars = get_pars(model_info, opts['use_demo'])
927    pars2 = get_pars(model_info2, opts['use_demo'])
928    pars2.update((k, v) for k, v in pars.items() if k in pars2)
929    # randomize parameters
930    #pars.update(set_pars)  # set value before random to control range
931    if opts['seed'] > -1:
932        pars = randomize_pars(model_info, pars, seed=opts['seed'])
933        if model_info != model_info2:
934            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
935        else:
936            pars2 = pars.copy()
937        print("Randomize using -random=%i"%opts['seed'])
938    if opts['mono']:
939        pars = suppress_pd(pars)
940        pars2 = suppress_pd(pars2)
941
942    # Fill in parameters given on the command line
943    presets = {}
944    presets2 = {}
945    for arg in values:
946        k, v = arg.split('=', 1)
947        if k not in pars and k not in pars2:
948            # extract base name without polydispersity info
949            s = set(p.split('_pd')[0] for p in pars)
950            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
951            return None
952        v1, v2 = v.split(':',2) if ':' in v else (v,v)
953        if v1 and k in pars:
954            presets[k] = float(v1) if isnumber(v1) else v1
955        if v2 and k in pars2:
956            presets2[k] = float(v2) if isnumber(v2) else v2
957
958    # Evaluate preset parameter expressions
959    context = MATH.copy()
960    context.update(pars)
961    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
962    for k, v in presets.items():
963        if not isinstance(v, float) and not k.endswith('_type'):
964            presets[k] = eval(v, context)
965    context.update(presets)
966    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
967    for k, v in presets2.items():
968        if not isinstance(v, float) and not k.endswith('_type'):
969            presets2[k] = eval(v, context)
970
971    # update parameters with presets
972    pars.update(presets)  # set value after random to control value
973    pars2.update(presets2)  # set value after random to control value
974    #import pprint; pprint.pprint(model_info)
975    constrain_pars(model_info, pars)
976    constrain_pars(model_info2, pars2)
977
978    same_model = name == name2 and pars == pars
979    if len(engines) == 0:
980        if same_model:
981            engines.extend(['single', 'double'])
982        else:
983            engines.extend(['single', 'single'])
984    elif len(engines) == 1:
985        if not same_model:
986            engines.append(engines[0])
987        elif engines[0] == 'double':
988            engines.append('single')
989        else:
990            engines.append('double')
991    elif len(engines) > 2:
992        del engines[2:]
993
994    use_sasview = any(engine == 'sasview' and count > 0
995                      for engine, count in zip(engines, [n1, n2]))
996    if use_sasview:
997        constrain_new_to_old(model_info, pars)
998        constrain_new_to_old(model_info2, pars2)
999
1000    if opts['show_pars']:
1001        if not same_model:
1002            print("==== %s ====="%model_info.name)
1003            print(str(parlist(model_info, pars, opts['is2d'])))
1004            print("==== %s ====="%model_info2.name)
1005            print(str(parlist(model_info2, pars2, opts['is2d'])))
1006        else:
1007            print(str(parlist(model_info, pars, opts['is2d'])))
1008
1009    # Create the computational engines
1010    data, _ = make_data(opts)
1011    if n1:
1012        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1013    else:
1014        base = None
1015    if n2:
1016        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1017    else:
1018        comp = None
1019
1020    # pylint: disable=bad-whitespace
1021    # Remember it all
1022    opts.update({
1023        'data'      : data,
1024        'name'      : [name, name2],
1025        'def'       : [model_info, model_info2],
1026        'count'     : [n1, n2],
1027        'presets'   : [presets, presets2],
1028        'pars'      : [pars, pars2],
1029        'engines'   : [base, comp],
1030    })
1031    # pylint: enable=bad-whitespace
1032
1033    return opts
1034
1035def show_docs(opts):
1036    # type: (Dict[str, Any]) -> None
1037    """
1038    show html docs for the model
1039    """
1040    import wx  # type: ignore
1041    from .generate import view_html_from_info
1042    app = wx.App() if wx.GetApp() is None else None
1043    view_html_from_info(opts['def'][0])
1044    if app: app.MainLoop()
1045
1046
1047def explore(opts):
1048    # type: (Dict[str, Any]) -> None
1049    """
1050    explore the model using the bumps gui.
1051    """
1052    import wx  # type: ignore
1053    from bumps.names import FitProblem  # type: ignore
1054    from bumps.gui.app_frame import AppFrame  # type: ignore
1055
1056    is_mac = "cocoa" in wx.version()
1057    # Create an app if not running embedded
1058    app = wx.App() if wx.GetApp() is None else None
1059    problem = FitProblem(Explore(opts))
1060    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1061    if not is_mac: frame.Show()
1062    frame.panel.set_model(model=problem)
1063    frame.panel.Layout()
1064    frame.panel.aui.Split(0, wx.TOP)
1065    if is_mac: frame.Show()
1066    # If running withing an app, start the main loop
1067    if app: app.MainLoop()
1068
1069class Explore(object):
1070    """
1071    Bumps wrapper for a SAS model comparison.
1072
1073    The resulting object can be used as a Bumps fit problem so that
1074    parameters can be adjusted in the GUI, with plots updated on the fly.
1075    """
1076    def __init__(self, opts):
1077        # type: (Dict[str, Any]) -> None
1078        from bumps.cli import config_matplotlib  # type: ignore
1079        from . import bumps_model
1080        config_matplotlib()
1081        self.opts = opts
1082        model_info = opts['def'][0]
1083        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0])
1084        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1085        if not opts['is2d']:
1086            for p in model_info.parameters.user_parameters(is2d=False):
1087                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1088                    k = p.name+ext
1089                    v = pars.get(k, None)
1090                    if v is not None:
1091                        v.range(*parameter_range(k, v.value))
1092        else:
1093            for k, v in pars.items():
1094                v.range(*parameter_range(k, v.value))
1095
1096        self.pars = pars
1097        self.pd_types = pd_types
1098        self.limits = None
1099
1100    def numpoints(self):
1101        # type: () -> int
1102        """
1103        Return the number of points.
1104        """
1105        return len(self.pars) + 1  # so dof is 1
1106
1107    def parameters(self):
1108        # type: () -> Any   # Dict/List hierarchy of parameters
1109        """
1110        Return a dictionary of parameters.
1111        """
1112        return self.pars
1113
1114    def nllf(self):
1115        # type: () -> float
1116        """
1117        Return cost.
1118        """
1119        # pylint: disable=no-self-use
1120        return 0.  # No nllf
1121
1122    def plot(self, view='log'):
1123        # type: (str) -> None
1124        """
1125        Plot the data and residuals.
1126        """
1127        pars = dict((k, v.value) for k, v in self.pars.items())
1128        pars.update(self.pd_types)
1129        self.opts['pars'][0] = pars
1130        self.opts['pars'][1] = pars
1131        limits = compare(self.opts, limits=self.limits)
1132        if self.limits is None:
1133            vmin, vmax = limits
1134            self.limits = vmax*1e-7, 1.3*vmax
1135
1136
1137def main(*argv):
1138    # type: (*str) -> None
1139    """
1140    Main program.
1141    """
1142    opts = parse_opts(argv)
1143    if opts is not None:
1144        if opts['html']:
1145            show_docs(opts)
1146        elif opts['explore']:
1147            explore(opts)
1148        else:
1149            compare(opts)
1150
1151if __name__ == "__main__":
1152    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.