source: sasmodels/sasmodels/compare.py @ 85fe7f8

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 85fe7f8 was 85fe7f8, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

re-enable sasview after forward conversion and linting updates

  • Property mode set to 100755
File size: 43.4 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model is the name of the model to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -magnetic/-nonmagnetic* suppress magnetism
76    -cutoff=1e-5* cutoff value for including a point in polydispersity
77    -pars/-nopars* prints the parameter set or not
78    -abs/-rel* plot relative or absolute error
79    -linear/-log*/-q4 intensity scaling
80    -hist/-nohist* plot histogram of relative error
81    -res=0 sets the resolution width dQ/Q if calculating with resolution
82    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
83    -edit starts the parameter explorer
84    -default/-demo* use demo vs default parameters
85    -html shows the model docs instead of running the model
86    -title="note" adds note to the plot title, after the model name
87
88Any two calculation engines can be selected for comparison:
89
90    -single/-double/-half/-fast sets an OpenCL calculation engine
91    -single!/-double!/-quad! sets an OpenMP calculation engine
92    -sasview sets the sasview calculation engine
93
94The default is -single -sasview.  Note that the interpretation of quad
95precision depends on architecture, and may vary from 64-bit to 128-bit,
96with 80-bit floats being common (1e-19 precision).
97
98Key=value pairs allow you to set specific values for the model parameters.
99"""
100
101# Update docs with command line usage string.   This is separate from the usual
102# doc string so that we can display it at run time if there is an error.
103# lin
104__doc__ = (__doc__  # pylint: disable=redefined-builtin
105           + """
106Program description
107-------------------
108
109"""
110           + USAGE)
111
112kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
113
114# list of math functions for use in evaluating parameters
115MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
116
117# CRUFT python 2.6
118if not hasattr(datetime.timedelta, 'total_seconds'):
119    def delay(dt):
120        """Return number date-time delta as number seconds"""
121        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
122else:
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.total_seconds()
126
127
128class push_seed(object):
129    """
130    Set the seed value for the random number generator.
131
132    When used in a with statement, the random number generator state is
133    restored after the with statement is complete.
134
135    :Parameters:
136
137    *seed* : int or array_like, optional
138        Seed for RandomState
139
140    :Example:
141
142    Seed can be used directly to set the seed::
143
144        >>> from numpy.random import randint
145        >>> push_seed(24)
146        <...push_seed object at...>
147        >>> print(randint(0,1000000,3))
148        [242082    899 211136]
149
150    Seed can also be used in a with statement, which sets the random
151    number generator state for the enclosed computations and restores
152    it to the previous state on completion::
153
154        >>> with push_seed(24):
155        ...    print(randint(0,1000000,3))
156        [242082    899 211136]
157
158    Using nested contexts, we can demonstrate that state is indeed
159    restored after the block completes::
160
161        >>> with push_seed(24):
162        ...    print(randint(0,1000000))
163        ...    with push_seed(24):
164        ...        print(randint(0,1000000,3))
165        ...    print(randint(0,1000000))
166        242082
167        [242082    899 211136]
168        899
169
170    The restore step is protected against exceptions in the block::
171
172        >>> with push_seed(24):
173        ...    print(randint(0,1000000))
174        ...    try:
175        ...        with push_seed(24):
176        ...            print(randint(0,1000000,3))
177        ...            raise Exception()
178        ...    except Exception:
179        ...        print("Exception raised")
180        ...    print(randint(0,1000000))
181        242082
182        [242082    899 211136]
183        Exception raised
184        899
185    """
186    def __init__(self, seed=None):
187        # type: (Optional[int]) -> None
188        self._state = np.random.get_state()
189        np.random.seed(seed)
190
191    def __enter__(self):
192        # type: () -> None
193        pass
194
195    def __exit__(self, exc_type, exc_value, traceback):
196        # type: (Any, BaseException, Any) -> None
197        # TODO: better typing for __exit__ method
198        np.random.set_state(self._state)
199
200def tic():
201    # type: () -> Callable[[], float]
202    """
203    Timer function.
204
205    Use "toc=tic()" to start the clock and "toc()" to measure
206    a time interval.
207    """
208    then = datetime.datetime.now()
209    return lambda: delay(datetime.datetime.now() - then)
210
211
212def set_beam_stop(data, radius, outer=None):
213    # type: (Data, float, float) -> None
214    """
215    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
216
217    Note: this function does not require sasview
218    """
219    if hasattr(data, 'qx_data'):
220        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
221        data.mask = (q < radius)
222        if outer is not None:
223            data.mask |= (q >= outer)
224    else:
225        data.mask = (data.x < radius)
226        if outer is not None:
227            data.mask |= (data.x >= outer)
228
229
230def parameter_range(p, v):
231    # type: (str, float) -> Tuple[float, float]
232    """
233    Choose a parameter range based on parameter name and initial value.
234    """
235    # process the polydispersity options
236    if p.endswith('_pd_n'):
237        return 0., 100.
238    elif p.endswith('_pd_nsigma'):
239        return 0., 5.
240    elif p.endswith('_pd_type'):
241        raise ValueError("Cannot return a range for a string value")
242    elif any(s in p for s in ('theta', 'phi', 'psi')):
243        # orientation in [-180,180], orientation pd in [0,45]
244        if p.endswith('_pd'):
245            return 0., 45.
246        else:
247            return -180., 180.
248    elif p.endswith('_pd'):
249        return 0., 1.
250    elif 'sld' in p:
251        return -0.5, 10.
252    elif p == 'background':
253        return 0., 10.
254    elif p == 'scale':
255        return 0., 1.e3
256    elif v < 0.:
257        return 2.*v, -2.*v
258    else:
259        return 0., (2.*v if v > 0. else 1.)
260
261
262def _randomize_one(model_info, p, v):
263    # type: (ModelInfo, str, float) -> float
264    # type: (ModelInfo, str, str) -> str
265    """
266    Randomize a single parameter.
267    """
268    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
269        return v
270
271    # Find the parameter definition
272    for par in model_info.parameters.call_parameters:
273        if par.name == p:
274            break
275    else:
276        raise ValueError("unknown parameter %r"%p)
277
278    if len(par.limits) > 2:  # choice list
279        return np.random.randint(len(par.limits))
280
281    limits = par.limits
282    if not np.isfinite(limits).all():
283        low, high = parameter_range(p, v)
284        limits = (max(limits[0], low), min(limits[1], high))
285
286    return np.random.uniform(*limits)
287
288
289def randomize_pars(model_info, pars, seed=None):
290    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
291    """
292    Generate random values for all of the parameters.
293
294    Valid ranges for the random number generator are guessed from the name of
295    the parameter; this will not account for constraints such as cap radius
296    greater than cylinder radius in the capped_cylinder model, so
297    :func:`constrain_pars` needs to be called afterward..
298    """
299    with push_seed(seed):
300        # Note: the sort guarantees order `of calls to random number generator
301        random_pars = dict((p, _randomize_one(model_info, p, v))
302                           for p, v in sorted(pars.items()))
303    return random_pars
304
305def constrain_pars(model_info, pars):
306    # type: (ModelInfo, ParameterSet) -> None
307    """
308    Restrict parameters to valid values.
309
310    This includes model specific code for models such as capped_cylinder
311    which need to support within model constraints (cap radius more than
312    cylinder radius in this case).
313
314    Warning: this updates the *pars* dictionary in place.
315    """
316    name = model_info.id
317    # if it is a product model, then just look at the form factor since
318    # none of the structure factors need any constraints.
319    if '*' in name:
320        name = name.split('*')[0]
321
322    if name == 'barbell':
323        if pars['radius_bell'] < pars['radius']:
324            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
325
326    elif name == 'capped_cylinder':
327        if pars['radius_cap'] < pars['radius']:
328            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
329
330    elif name == 'guinier':
331        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
332        #q_max = 0.2  # mid q maximum
333        q_max = 1.0  # high q maximum
334        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
335        pars['rg'] = min(pars['rg'], rg_max)
336
337    elif name == 'pearl_necklace':
338        if pars['radius'] < pars['thick_string']:
339            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
340        pars['num_pearls'] = math.ceil(pars['num_pearls'])
341        pass
342
343    elif name == 'rpa':
344        # Make sure phi sums to 1.0
345        if pars['case_num'] < 2:
346            pars['Phi1'] = 0.
347            pars['Phi2'] = 0.
348        elif pars['case_num'] < 5:
349            pars['Phi1'] = 0.
350        total = sum(pars['Phi'+c] for c in '1234')
351        for c in '1234':
352            pars['Phi'+c] /= total
353
354    elif name == 'stacked_disks':
355        pars['n_stacking'] = math.ceil(pars['n_stacking'])
356
357def parlist(model_info, pars, is2d):
358    # type: (ModelInfo, ParameterSet, bool) -> str
359    """
360    Format the parameter list for printing.
361    """
362    lines = []
363    parameters = model_info.parameters
364    magnetic = False
365    for p in parameters.user_parameters(pars, is2d):
366        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
367            continue
368        if p.id.startswith('up:') and not magnetic:
369            continue
370        fields = dict(
371            value=pars.get(p.id, p.default),
372            pd=pars.get(p.id+"_pd", 0.),
373            n=int(pars.get(p.id+"_pd_n", 0)),
374            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
375            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
376            relative_pd=p.relative_pd,
377            M0=pars.get('M0:'+p.id, 0.),
378            mphi=pars.get('mphi:'+p.id, 0.),
379            mtheta=pars.get('mtheta:'+p.id, 0.),
380        )
381        lines.append(_format_par(p.name, **fields))
382        magnetic = magnetic or fields['M0'] != 0.
383    return "\n".join(lines)
384
385    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
386
387def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
388                relative_pd=False, M0=0., mphi=0., mtheta=0.):
389    # type: (str, float, float, int, float, str) -> str
390    line = "%s: %g"%(name, value)
391    if pd != 0.  and n != 0:
392        if relative_pd:
393            pd *= value
394        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
395                % (pd, n, nsigma, nsigma, pdtype)
396    if M0 != 0.:
397        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
398    return line
399
400def suppress_pd(pars):
401    # type: (ParameterSet) -> ParameterSet
402    """
403    Suppress theta_pd for now until the normalization is resolved.
404
405    May also suppress complete polydispersity of the model to test
406    models more quickly.
407    """
408    pars = pars.copy()
409    for p in pars:
410        if p.endswith("_pd_n"): pars[p] = 0
411    return pars
412
413def suppress_magnetism(pars):
414    # type: (ParameterSet) -> ParameterSet
415    """
416    Suppress theta_pd for now until the normalization is resolved.
417
418    May also suppress complete polydispersity of the model to test
419    models more quickly.
420    """
421    pars = pars.copy()
422    for p in pars:
423        if p.startswith("M0:"): pars[p] = 0
424    return pars
425
426def eval_sasview(model_info, data):
427    # type: (Modelinfo, Data) -> Calculator
428    """
429    Return a model calculator using the pre-4.0 SasView models.
430    """
431    # importing sas here so that the error message will be that sas failed to
432    # import rather than the more obscure smear_selection not imported error
433    import sas
434    import sas.models
435    from sas.models.qsmearing import smear_selection
436    from sas.models.MultiplicationModel import MultiplicationModel
437    from sas.models.dispersion_models import models as dispersers
438
439    def get_model_class(name):
440        # type: (str) -> "sas.models.BaseComponent"
441        #print("new",sorted(_pars.items()))
442        __import__('sas.models.' + name)
443        ModelClass = getattr(getattr(sas.models, name, None), name, None)
444        if ModelClass is None:
445            raise ValueError("could not find model %r in sas.models"%name)
446        return ModelClass
447
448    # WARNING: ugly hack when handling model!
449    # Sasview models with multiplicity need to be created with the target
450    # multiplicity, so we cannot create the target model ahead of time for
451    # for multiplicity models.  Instead we store the model in a list and
452    # update the first element of that list with the new multiplicity model
453    # every time we evaluate.
454
455    # grab the sasview model, or create it if it is a product model
456    if model_info.composition:
457        composition_type, parts = model_info.composition
458        if composition_type == 'product':
459            P, S = [get_model_class(revert_name(p))() for p in parts]
460            model = [MultiplicationModel(P, S)]
461        else:
462            raise ValueError("sasview mixture models not supported by compare")
463    else:
464        old_name = revert_name(model_info)
465        if old_name is None:
466            raise ValueError("model %r does not exist in old sasview"
467                            % model_info.id)
468        ModelClass = get_model_class(old_name)
469        model = [ModelClass()]
470    model[0].disperser_handles = {}
471
472    # build a smearer with which to call the model, if necessary
473    smearer = smear_selection(data, model=model)
474    if hasattr(data, 'qx_data'):
475        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
476        index = ((~data.mask) & (~np.isnan(data.data))
477                 & (q >= data.qmin) & (q <= data.qmax))
478        if smearer is not None:
479            smearer.model = model  # because smear_selection has a bug
480            smearer.accuracy = data.accuracy
481            smearer.set_index(index)
482            def _call_smearer():
483                smearer.model = model[0]
484                return smearer.get_value()
485            theory = _call_smearer
486        else:
487            theory = lambda: model[0].evalDistribution([data.qx_data[index],
488                                                        data.qy_data[index]])
489    elif smearer is not None:
490        theory = lambda: smearer(model[0].evalDistribution(data.x))
491    else:
492        theory = lambda: model[0].evalDistribution(data.x)
493
494    def calculator(**pars):
495        # type: (float, ...) -> np.ndarray
496        """
497        Sasview calculator for model.
498        """
499        oldpars = revert_pars(model_info, pars)
500        # For multiplicity models, create a model with the correct multiplicity
501        control = oldpars.pop("CONTROL", None)
502        if control is not None:
503            # sphericalSLD has one fewer multiplicity.  This update should
504            # happen in revert_pars, but it hasn't been called yet.
505            model[0] = ModelClass(control)
506        # paying for parameter conversion each time to keep life simple, if not fast
507        for k, v in oldpars.items():
508            if k.endswith('.type'):
509                par = k[:-5]
510                if v == 'gaussian': continue
511                cls = dispersers[v if v != 'rectangle' else 'rectangula']
512                handle = cls()
513                model[0].disperser_handles[par] = handle
514                try:
515                    model[0].set_dispersion(par, handle)
516                except Exception:
517                    exception.annotate_exception("while setting %s to %r"
518                                                 %(par, v))
519                    raise
520
521
522        #print("sasview pars",oldpars)
523        for k, v in oldpars.items():
524            name_attr = k.split('.')  # polydispersity components
525            if len(name_attr) == 2:
526                par, disp_par = name_attr
527                model[0].dispersion[par][disp_par] = v
528            else:
529                model[0].setParam(k, v)
530        return theory()
531
532    calculator.engine = "sasview"
533    return calculator
534
535DTYPE_MAP = {
536    'half': '16',
537    'fast': 'fast',
538    'single': '32',
539    'double': '64',
540    'quad': '128',
541    'f16': '16',
542    'f32': '32',
543    'f64': '64',
544    'longdouble': '128',
545}
546def eval_opencl(model_info, data, dtype='single', cutoff=0.):
547    # type: (ModelInfo, Data, str, float) -> Calculator
548    """
549    Return a model calculator using the OpenCL calculation engine.
550    """
551    if not core.HAVE_OPENCL:
552        raise RuntimeError("OpenCL not available")
553    model = core.build_model(model_info, dtype=dtype, platform="ocl")
554    calculator = DirectModel(data, model, cutoff=cutoff)
555    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
556    return calculator
557
558def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
559    # type: (ModelInfo, Data, str, float) -> Calculator
560    """
561    Return a model calculator using the DLL calculation engine.
562    """
563    model = core.build_model(model_info, dtype=dtype, platform="dll")
564    calculator = DirectModel(data, model, cutoff=cutoff)
565    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
566    return calculator
567
568def time_calculation(calculator, pars, evals=1):
569    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
570    """
571    Compute the average calculation time over N evaluations.
572
573    An additional call is generated without polydispersity in order to
574    initialize the calculation engine, and make the average more stable.
575    """
576    # initialize the code so time is more accurate
577    if evals > 1:
578        calculator(**suppress_pd(pars))
579    toc = tic()
580    # make sure there is at least one eval
581    value = calculator(**pars)
582    for _ in range(evals-1):
583        value = calculator(**pars)
584    average_time = toc()*1000. / evals
585    #print("I(q)",value)
586    return value, average_time
587
588def make_data(opts):
589    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
590    """
591    Generate an empty dataset, used with the model to set Q points
592    and resolution.
593
594    *opts* contains the options, with 'qmax', 'nq', 'res',
595    'accuracy', 'is2d' and 'view' parsed from the command line.
596    """
597    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
598    if opts['is2d']:
599        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
600        data = empty_data2D(q, resolution=res)
601        data.accuracy = opts['accuracy']
602        set_beam_stop(data, 0.0004)
603        index = ~data.mask
604    else:
605        if opts['view'] == 'log' and not opts['zero']:
606            qmax = math.log10(qmax)
607            q = np.logspace(qmax-3, qmax, nq)
608        else:
609            q = np.linspace(0.001*qmax, qmax, nq)
610        if opts['zero']:
611            q = np.hstack((0, q))
612        data = empty_data1D(q, resolution=res)
613        index = slice(None, None)
614    return data, index
615
616def make_engine(model_info, data, dtype, cutoff):
617    # type: (ModelInfo, Data, str, float) -> Calculator
618    """
619    Generate the appropriate calculation engine for the given datatype.
620
621    Datatypes with '!' appended are evaluated using external C DLLs rather
622    than OpenCL.
623    """
624    if dtype == 'sasview':
625        return eval_sasview(model_info, data)
626    elif dtype.endswith('!'):
627        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
628    else:
629        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
630
631def _show_invalid(data, theory):
632    # type: (Data, np.ma.ndarray) -> None
633    """
634    Display a list of the non-finite values in theory.
635    """
636    if not theory.mask.any():
637        return
638
639    if hasattr(data, 'x'):
640        bad = zip(data.x[theory.mask], theory[theory.mask])
641        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
642
643
644def compare(opts, limits=None):
645    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
646    """
647    Preform a comparison using options from the command line.
648
649    *limits* are the limits on the values to use, either to set the y-axis
650    for 1D or to set the colormap scale for 2D.  If None, then they are
651    inferred from the data and returned. When exploring using Bumps,
652    the limits are set when the model is initially called, and maintained
653    as the values are adjusted, making it easier to see the effects of the
654    parameters.
655    """
656    result = run_models(opts, verbose=True)
657    if opts['plot']:  # Note: never called from explore
658        plot_models(opts, result, limits=limits)
659
660def run_models(opts, verbose=False):
661    # type: (Dict[str, Any]) -> Dict[str, Any]
662
663    n_base, n_comp = opts['count']
664    pars, pars2 = opts['pars']
665    data = opts['data']
666
667    # silence the linter
668    base = opts['engines'][0] if n_base else None
669    comp = opts['engines'][1] if n_comp else None
670
671    base_time = comp_time = None
672    base_value = comp_value = resid = relerr = None
673
674    # Base calculation
675    if n_base > 0:
676        try:
677            base_raw, base_time = time_calculation(base, pars, n_base)
678            base_value = np.ma.masked_invalid(base_raw)
679            if verbose:
680                print("%s t=%.2f ms, intensity=%.0f"
681                      % (base.engine, base_time, base_value.sum()))
682            _show_invalid(data, base_value)
683        except ImportError:
684            traceback.print_exc()
685            n_base = 0
686
687    # Comparison calculation
688    if n_comp > 0:
689        try:
690            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
691            comp_value = np.ma.masked_invalid(comp_raw)
692            if verbose:
693                print("%s t=%.2f ms, intensity=%.0f"
694                      % (comp.engine, comp_time, comp_value.sum()))
695            _show_invalid(data, comp_value)
696        except ImportError:
697            traceback.print_exc()
698            n_comp = 0
699
700    # Compare, but only if computing both forms
701    if n_base > 0 and n_comp > 0:
702        resid = (base_value - comp_value)
703        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
704        if verbose:
705            _print_stats("|%s-%s|"
706                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
707                         resid)
708            _print_stats("|(%s-%s)/%s|"
709                         % (base.engine, comp.engine, comp.engine),
710                         relerr)
711
712    return dict(base_value=base_value, comp_value=comp_value,
713                base_time=base_time, comp_time=comp_time,
714                resid=resid, relerr=relerr)
715
716
717def _print_stats(label, err):
718    # type: (str, np.ma.ndarray) -> None
719    # work with trimmed data, not the full set
720    sorted_err = np.sort(abs(err.compressed()))
721    if len(sorted_err) == 0.:
722        print(label + "  no valid values")
723        return
724
725    p50 = int((len(sorted_err)-1)*0.50)
726    p98 = int((len(sorted_err)-1)*0.98)
727    data = [
728        "max:%.3e"%sorted_err[-1],
729        "median:%.3e"%sorted_err[p50],
730        "98%%:%.3e"%sorted_err[p98],
731        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
732        "zero-offset:%+.3e"%np.mean(sorted_err),
733        ]
734    print(label+"  "+"  ".join(data))
735
736
737def plot_models(opts, result, limits=None):
738    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
739    base_value, comp_value= result['base_value'], result['comp_value']
740    base_time, comp_time = result['base_time'], result['comp_time']
741    resid, relerr = result['resid'], result['relerr']
742
743    have_base, have_comp = (base_value is not None), (comp_value is not None)
744    base = opts['engines'][0] if have_base else None
745    comp = opts['engines'][1] if have_comp else None
746    data = opts['data']
747
748    # Plot if requested
749    view = opts['view']
750    import matplotlib.pyplot as plt
751    if limits is None:
752        vmin, vmax = np.Inf, -np.Inf
753        if have_base:
754            vmin = min(vmin, base_value.min())
755            vmax = max(vmax, base_value.max())
756        if have_comp:
757            vmin = min(vmin, comp_value.min())
758            vmax = max(vmax, comp_value.max())
759        limits = vmin, vmax
760
761    if have_base:
762        if have_comp: plt.subplot(131)
763        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
764        plt.title("%s t=%.2f ms"%(base.engine, base_time))
765        #cbar_title = "log I"
766    if have_comp:
767        if have_base: plt.subplot(132)
768        if not opts['is2d'] and have_base:
769            plot_theory(data, base_value, view=view, use_data=False, limits=limits)
770        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
771        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
772        #cbar_title = "log I"
773    if have_base and have_comp:
774        plt.subplot(133)
775        if not opts['rel_err']:
776            err, errstr, errview = resid, "abs err", "linear"
777        else:
778            err, errstr, errview = abs(relerr), "rel err", "log"
779        if 0:  # 95% cutoff
780            sorted = np.sort(err.flatten())
781            cutoff = sorted[int(sorted.size*0.95)]
782            err[err>cutoff] = cutoff
783        #err,errstr = base/comp,"ratio"
784        plot_theory(data, None, resid=err, view=errview, use_data=False)
785        if view == 'linear':
786            plt.xscale('linear')
787        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
788        #cbar_title = errstr if errview=="linear" else "log "+errstr
789    #if is2D:
790    #    h = plt.colorbar()
791    #    h.ax.set_title(cbar_title)
792    fig = plt.gcf()
793    extra_title = ' '+opts['title'] if opts['title'] else ''
794    fig.suptitle(":".join(opts['name']) + extra_title)
795
796    if have_base and have_comp and opts['show_hist']:
797        plt.figure()
798        v = relerr
799        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
800        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
801        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
802                   % (base.engine, comp.engine, comp.engine))
803        plt.ylabel('P(err)')
804        plt.title('Distribution of relative error between calculation engines')
805
806    if not opts['explore']:
807        plt.show()
808
809    return limits
810
811
812
813
814# ===========================================================================
815#
816NAME_OPTIONS = set([
817    'plot', 'noplot',
818    'half', 'fast', 'single', 'double',
819    'single!', 'double!', 'quad!', 'sasview',
820    'lowq', 'midq', 'highq', 'exq', 'zero',
821    '2d', '1d',
822    'preset', 'random',
823    'poly', 'mono',
824    'magnetic', 'nonmagnetic',
825    'nopars', 'pars',
826    'rel', 'abs',
827    'linear', 'log', 'q4',
828    'hist', 'nohist',
829    'edit', 'html',
830    'demo', 'default',
831    ])
832VALUE_OPTIONS = [
833    # Note: random is both a name option and a value option
834    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
835    ]
836
837def columnize(items, indent="", width=79):
838    # type: (List[str], str, int) -> str
839    """
840    Format a list of strings into columns.
841
842    Returns a string with carriage returns ready for printing.
843    """
844    column_width = max(len(w) for w in items) + 1
845    num_columns = (width - len(indent)) // column_width
846    num_rows = len(items) // num_columns
847    items = items + [""] * (num_rows * num_columns - len(items))
848    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
849    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
850             for row in zip(*columns)]
851    output = indent + ("\n"+indent).join(lines)
852    return output
853
854
855def get_pars(model_info, use_demo=False):
856    # type: (ModelInfo, bool) -> ParameterSet
857    """
858    Extract demo parameters from the model definition.
859    """
860    # Get the default values for the parameters
861    pars = {}
862    for p in model_info.parameters.call_parameters:
863        parts = [('', p.default)]
864        if p.polydisperse:
865            parts.append(('_pd', 0.0))
866            parts.append(('_pd_n', 0))
867            parts.append(('_pd_nsigma', 3.0))
868            parts.append(('_pd_type', "gaussian"))
869        for ext, val in parts:
870            if p.length > 1:
871                dict(("%s%d%s" % (p.id, k, ext), val)
872                     for k in range(1, p.length+1))
873            else:
874                pars[p.id + ext] = val
875
876    # Plug in values given in demo
877    if use_demo:
878        pars.update(model_info.demo)
879    return pars
880
881INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
882def isnumber(str):
883    match = FLOAT_RE.match(str)
884    isfloat = (match and not str[match.end():])
885    return isfloat or INTEGER_RE.match(str)
886
887def parse_opts(argv):
888    # type: (List[str]) -> Dict[str, Any]
889    """
890    Parse command line options.
891    """
892    MODELS = core.list_models()
893    flags = [arg for arg in argv
894             if arg.startswith('-')]
895    values = [arg for arg in argv
896              if not arg.startswith('-') and '=' in arg]
897    positional_args = [arg for arg in argv
898            if not arg.startswith('-') and '=' not in arg]
899    models = "\n    ".join("%-15s"%v for v in MODELS)
900    if len(positional_args) == 0:
901        print(USAGE)
902        print("\nAvailable models:")
903        print(columnize(MODELS, indent="  "))
904        return None
905    if len(positional_args) > 3:
906        print("expected parameters: model N1 N2")
907
908    invalid = [o[1:] for o in flags
909               if o[1:] not in NAME_OPTIONS
910               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
911    if invalid:
912        print("Invalid options: %s"%(", ".join(invalid)))
913        return None
914
915    name = positional_args[0]
916    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
917    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
918
919    # pylint: disable=bad-whitespace
920    # Interpret the flags
921    opts = {
922        'plot'      : True,
923        'view'      : 'log',
924        'is2d'      : False,
925        'qmax'      : 0.05,
926        'nq'        : 128,
927        'res'       : 0.0,
928        'accuracy'  : 'Low',
929        'cutoff'    : 0.0,
930        'seed'      : -1,  # default to preset
931        'mono'      : False,
932        # Default to magnetic a magnetic moment is set on the command line
933        'magnetic'  : False,
934        'show_pars' : False,
935        'show_hist' : False,
936        'rel_err'   : True,
937        'explore'   : False,
938        'use_demo'  : True,
939        'zero'      : False,
940        'html'      : False,
941        'title'     : None,
942    }
943    engines = []
944    for arg in flags:
945        if arg == '-noplot':    opts['plot'] = False
946        elif arg == '-plot':    opts['plot'] = True
947        elif arg == '-linear':  opts['view'] = 'linear'
948        elif arg == '-log':     opts['view'] = 'log'
949        elif arg == '-q4':      opts['view'] = 'q4'
950        elif arg == '-1d':      opts['is2d'] = False
951        elif arg == '-2d':      opts['is2d'] = True
952        elif arg == '-exq':     opts['qmax'] = 10.0
953        elif arg == '-highq':   opts['qmax'] = 1.0
954        elif arg == '-midq':    opts['qmax'] = 0.2
955        elif arg == '-lowq':    opts['qmax'] = 0.05
956        elif arg == '-zero':    opts['zero'] = True
957        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
958        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
959        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
960        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
961        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
962        elif arg.startswith('-title'):     opts['title'] = arg[7:]
963        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
964        elif arg == '-preset':  opts['seed'] = -1
965        elif arg == '-mono':    opts['mono'] = True
966        elif arg == '-poly':    opts['mono'] = False
967        elif arg == '-magnetic':       opts['magnetic'] = True
968        elif arg == '-nonmagnetic':    opts['magnetic'] = False
969        elif arg == '-pars':    opts['show_pars'] = True
970        elif arg == '-nopars':  opts['show_pars'] = False
971        elif arg == '-hist':    opts['show_hist'] = True
972        elif arg == '-nohist':  opts['show_hist'] = False
973        elif arg == '-rel':     opts['rel_err'] = True
974        elif arg == '-abs':     opts['rel_err'] = False
975        elif arg == '-half':    engines.append(arg[1:])
976        elif arg == '-fast':    engines.append(arg[1:])
977        elif arg == '-single':  engines.append(arg[1:])
978        elif arg == '-double':  engines.append(arg[1:])
979        elif arg == '-single!': engines.append(arg[1:])
980        elif arg == '-double!': engines.append(arg[1:])
981        elif arg == '-quad!':   engines.append(arg[1:])
982        elif arg == '-sasview': engines.append(arg[1:])
983        elif arg == '-edit':    opts['explore'] = True
984        elif arg == '-demo':    opts['use_demo'] = True
985        elif arg == '-default':    opts['use_demo'] = False
986        elif arg == '-html':    opts['html'] = True
987    # pylint: enable=bad-whitespace
988
989    if ':' in name:
990        name, name2 = name.split(':',2)
991    else:
992        name2 = name
993    try:
994        model_info = core.load_model_info(name)
995        model_info2 = core.load_model_info(name2) if name2 != name else model_info
996    except ImportError as exc:
997        print(str(exc))
998        print("Could not find model; use one of:\n    " + models)
999        return None
1000
1001    # Get demo parameters from model definition, or use default parameters
1002    # if model does not define demo parameters
1003    pars = get_pars(model_info, opts['use_demo'])
1004    pars2 = get_pars(model_info2, opts['use_demo'])
1005    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1006    # randomize parameters
1007    #pars.update(set_pars)  # set value before random to control range
1008    if opts['seed'] > -1:
1009        pars = randomize_pars(model_info, pars, seed=opts['seed'])
1010        if model_info != model_info2:
1011            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
1012            # Share values for parameters with the same name
1013            for k, v in pars.items():
1014                if k in pars2:
1015                    pars2[k] = v
1016        else:
1017            pars2 = pars.copy()
1018        constrain_pars(model_info, pars)
1019        constrain_pars(model_info2, pars2)
1020        print("Randomize using -random=%i"%opts['seed'])
1021    if opts['mono']:
1022        pars = suppress_pd(pars)
1023        pars2 = suppress_pd(pars2)
1024    if not opts['magnetic']:
1025        pars = suppress_magnetism(pars)
1026        pars2 = suppress_magnetism(pars2)
1027
1028    # Fill in parameters given on the command line
1029    presets = {}
1030    presets2 = {}
1031    for arg in values:
1032        k, v = arg.split('=', 1)
1033        if k not in pars and k not in pars2:
1034            # extract base name without polydispersity info
1035            s = set(p.split('_pd')[0] for p in pars)
1036            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1037            return None
1038        v1, v2 = v.split(':',2) if ':' in v else (v,v)
1039        if v1 and k in pars:
1040            presets[k] = float(v1) if isnumber(v1) else v1
1041        if v2 and k in pars2:
1042            presets2[k] = float(v2) if isnumber(v2) else v2
1043
1044    # If pd given on the command line, default pd_n to 35
1045    for k, v in list(presets.items()):
1046        if k.endswith('_pd'):
1047            presets.setdefault(k+'_n', 35.)
1048    for k, v in list(presets2.items()):
1049        if k.endswith('_pd'):
1050            presets2.setdefault(k+'_n', 35.)
1051
1052    # Evaluate preset parameter expressions
1053    context = MATH.copy()
1054    context.update(pars)
1055    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
1056    for k, v in presets.items():
1057        if not isinstance(v, float) and not k.endswith('_type'):
1058            presets[k] = eval(v, context)
1059    context.update(presets)
1060    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
1061    for k, v in presets2.items():
1062        if not isinstance(v, float) and not k.endswith('_type'):
1063            presets2[k] = eval(v, context)
1064
1065    # update parameters with presets
1066    pars.update(presets)  # set value after random to control value
1067    pars2.update(presets2)  # set value after random to control value
1068    #import pprint; pprint.pprint(model_info)
1069
1070    same_model = name == name2 and pars == pars
1071    if len(engines) == 0:
1072        if same_model:
1073            engines.extend(['single', 'double'])
1074        else:
1075            engines.extend(['single', 'single'])
1076    elif len(engines) == 1:
1077        if not same_model:
1078            engines.append(engines[0])
1079        elif engines[0] == 'double':
1080            engines.append('single')
1081        else:
1082            engines.append('double')
1083    elif len(engines) > 2:
1084        del engines[2:]
1085
1086    use_sasview = any(engine == 'sasview' and count > 0
1087                      for engine, count in zip(engines, [n1, n2]))
1088    if use_sasview:
1089        constrain_new_to_old(model_info, pars)
1090        constrain_new_to_old(model_info2, pars2)
1091
1092    if opts['show_pars']:
1093        if not same_model:
1094            print("==== %s ====="%model_info.name)
1095            print(str(parlist(model_info, pars, opts['is2d'])))
1096            print("==== %s ====="%model_info2.name)
1097            print(str(parlist(model_info2, pars2, opts['is2d'])))
1098        else:
1099            print(str(parlist(model_info, pars, opts['is2d'])))
1100
1101    # Create the computational engines
1102    data, _ = make_data(opts)
1103    if n1:
1104        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1105    else:
1106        base = None
1107    if n2:
1108        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1109    else:
1110        comp = None
1111
1112    # pylint: disable=bad-whitespace
1113    # Remember it all
1114    opts.update({
1115        'data'      : data,
1116        'name'      : [name, name2],
1117        'def'       : [model_info, model_info2],
1118        'count'     : [n1, n2],
1119        'presets'   : [presets, presets2],
1120        'pars'      : [pars, pars2],
1121        'engines'   : [base, comp],
1122    })
1123    # pylint: enable=bad-whitespace
1124
1125    return opts
1126
1127def show_docs(opts):
1128    # type: (Dict[str, Any]) -> None
1129    """
1130    show html docs for the model
1131    """
1132    import wx  # type: ignore
1133    from .generate import view_html_from_info
1134    app = wx.App() if wx.GetApp() is None else None
1135    view_html_from_info(opts['def'][0])
1136    if app: app.MainLoop()
1137
1138
1139def explore(opts):
1140    # type: (Dict[str, Any]) -> None
1141    """
1142    explore the model using the bumps gui.
1143    """
1144    import wx  # type: ignore
1145    from bumps.names import FitProblem  # type: ignore
1146    from bumps.gui.app_frame import AppFrame  # type: ignore
1147    from bumps.gui import signal
1148
1149    is_mac = "cocoa" in wx.version()
1150    # Create an app if not running embedded
1151    app = wx.App() if wx.GetApp() is None else None
1152    model = Explore(opts)
1153    problem = FitProblem(model)
1154    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1155    if not is_mac: frame.Show()
1156    frame.panel.set_model(model=problem)
1157    frame.panel.Layout()
1158    frame.panel.aui.Split(0, wx.TOP)
1159    def reset_parameters(event):
1160        model.revert_values()
1161        signal.update_parameters(problem)
1162    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1163    if is_mac: frame.Show()
1164    # If running withing an app, start the main loop
1165    if app: app.MainLoop()
1166
1167class Explore(object):
1168    """
1169    Bumps wrapper for a SAS model comparison.
1170
1171    The resulting object can be used as a Bumps fit problem so that
1172    parameters can be adjusted in the GUI, with plots updated on the fly.
1173    """
1174    def __init__(self, opts):
1175        # type: (Dict[str, Any]) -> None
1176        from bumps.cli import config_matplotlib  # type: ignore
1177        from . import bumps_model
1178        config_matplotlib()
1179        self.opts = opts
1180        p1, p2 = opts['pars']
1181        m1, m2 = opts['def']
1182        self.fix_p2 = m1 != m2 or p1 != p2
1183        model_info = m1
1184        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1185        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1186        if not opts['is2d']:
1187            for p in model_info.parameters.user_parameters({}, is2d=False):
1188                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1189                    k = p.name+ext
1190                    v = pars.get(k, None)
1191                    if v is not None:
1192                        v.range(*parameter_range(k, v.value))
1193        else:
1194            for k, v in pars.items():
1195                v.range(*parameter_range(k, v.value))
1196
1197        self.pars = pars
1198        self.starting_values = dict((k, v.value) for k, v in pars.items())
1199        self.pd_types = pd_types
1200        self.limits = None
1201
1202    def revert_values(self):
1203        for k, v in self.starting_values.items():
1204            self.pars[k].value = v
1205
1206    def model_update(self):
1207        pass
1208
1209    def numpoints(self):
1210        # type: () -> int
1211        """
1212        Return the number of points.
1213        """
1214        return len(self.pars) + 1  # so dof is 1
1215
1216    def parameters(self):
1217        # type: () -> Any   # Dict/List hierarchy of parameters
1218        """
1219        Return a dictionary of parameters.
1220        """
1221        return self.pars
1222
1223    def nllf(self):
1224        # type: () -> float
1225        """
1226        Return cost.
1227        """
1228        # pylint: disable=no-self-use
1229        return 0.  # No nllf
1230
1231    def plot(self, view='log'):
1232        # type: (str) -> None
1233        """
1234        Plot the data and residuals.
1235        """
1236        pars = dict((k, v.value) for k, v in self.pars.items())
1237        pars.update(self.pd_types)
1238        self.opts['pars'][0] = pars
1239        if not self.fix_p2:
1240            self.opts['pars'][1] = pars
1241        result = run_models(self.opts)
1242        limits = plot_models(self.opts, result, limits=self.limits)
1243        if self.limits is None:
1244            vmin, vmax = limits
1245            self.limits = vmax*1e-7, 1.3*vmax
1246            import pylab; pylab.clf()
1247            plot_models(self.opts, result, limits=self.limits)
1248
1249
1250def main(*argv):
1251    # type: (*str) -> None
1252    """
1253    Main program.
1254    """
1255    opts = parse_opts(argv)
1256    if opts is not None:
1257        if opts['html']:
1258            show_docs(opts)
1259        elif opts['explore']:
1260            explore(opts)
1261        else:
1262            compare(opts)
1263
1264if __name__ == "__main__":
1265    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.