source: sasmodels/sasmodels/compare.py @ 6e5c0b7

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 6e5c0b7 was 6e5c0b7, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

Merge branch 'master' into ticket-890

  • Property mode set to 100755
File size: 44.4 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: compare.py model N1 N2 [options...] [key=val]
59
60Compare the speed and value for a model between the SasView original and the
61sasmodels rewrite.
62
63model or model1,model2 are the names of the models to compare (see below).
64N1 is the number of times to run sasmodels (default=1).
65N2 is the number times to run sasview (default=1).
66
67Options (* for default):
68
69    -plot*/-noplot plots or suppress the plot of the model
70    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
71    -nq=128 sets the number of Q points in the data set
72    -zero indicates that q=0 should be included
73    -1d*/-2d computes 1d or 2d data
74    -preset*/-random[=seed] preset or random parameters
75    -mono/-poly* force monodisperse/polydisperse
76    -magnetic/-nonmagnetic* suppress magnetism
77    -cutoff=1e-5* cutoff value for including a point in polydispersity
78    -pars/-nopars* prints the parameter set or not
79    -abs/-rel* plot relative or absolute error
80    -linear/-log*/-q4 intensity scaling
81    -hist/-nohist* plot histogram of relative error
82    -res=0 sets the resolution width dQ/Q if calculating with resolution
83    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
84    -edit starts the parameter explorer
85    -default/-demo* use demo vs default parameters
86    -help/-html shows the model docs instead of running the model
87    -title="note" adds note to the plot title, after the model name
88    -data="path" uses q, dq from the data file
89
90Any two calculation engines can be selected for comparison:
91
92    -single/-double/-half/-fast sets an OpenCL calculation engine
93    -single!/-double!/-quad! sets an OpenMP calculation engine
94    -sasview sets the sasview calculation engine
95
96The default is -single -double.  Note that the interpretation of quad
97precision depends on architecture, and may vary from 64-bit to 128-bit,
98with 80-bit floats being common (1e-19 precision).
99
100Key=value pairs allow you to set specific values for the model parameters.
101Key=value1,value2 to compare different values of the same parameter.
102value can be an expression including other parameters
103"""
104
105# Update docs with command line usage string.   This is separate from the usual
106# doc string so that we can display it at run time if there is an error.
107# lin
108__doc__ = (__doc__  # pylint: disable=redefined-builtin
109           + """
110Program description
111-------------------
112
113"""
114           + USAGE)
115
116kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
117
118# list of math functions for use in evaluating parameters
119MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
120
121# CRUFT python 2.6
122if not hasattr(datetime.timedelta, 'total_seconds'):
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
126else:
127    def delay(dt):
128        """Return number date-time delta as number seconds"""
129        return dt.total_seconds()
130
131
132class push_seed(object):
133    """
134    Set the seed value for the random number generator.
135
136    When used in a with statement, the random number generator state is
137    restored after the with statement is complete.
138
139    :Parameters:
140
141    *seed* : int or array_like, optional
142        Seed for RandomState
143
144    :Example:
145
146    Seed can be used directly to set the seed::
147
148        >>> from numpy.random import randint
149        >>> push_seed(24)
150        <...push_seed object at...>
151        >>> print(randint(0,1000000,3))
152        [242082    899 211136]
153
154    Seed can also be used in a with statement, which sets the random
155    number generator state for the enclosed computations and restores
156    it to the previous state on completion::
157
158        >>> with push_seed(24):
159        ...    print(randint(0,1000000,3))
160        [242082    899 211136]
161
162    Using nested contexts, we can demonstrate that state is indeed
163    restored after the block completes::
164
165        >>> with push_seed(24):
166        ...    print(randint(0,1000000))
167        ...    with push_seed(24):
168        ...        print(randint(0,1000000,3))
169        ...    print(randint(0,1000000))
170        242082
171        [242082    899 211136]
172        899
173
174    The restore step is protected against exceptions in the block::
175
176        >>> with push_seed(24):
177        ...    print(randint(0,1000000))
178        ...    try:
179        ...        with push_seed(24):
180        ...            print(randint(0,1000000,3))
181        ...            raise Exception()
182        ...    except Exception:
183        ...        print("Exception raised")
184        ...    print(randint(0,1000000))
185        242082
186        [242082    899 211136]
187        Exception raised
188        899
189    """
190    def __init__(self, seed=None):
191        # type: (Optional[int]) -> None
192        self._state = np.random.get_state()
193        np.random.seed(seed)
194
195    def __enter__(self):
196        # type: () -> None
197        pass
198
199    def __exit__(self, exc_type, exc_value, traceback):
200        # type: (Any, BaseException, Any) -> None
201        # TODO: better typing for __exit__ method
202        np.random.set_state(self._state)
203
204def tic():
205    # type: () -> Callable[[], float]
206    """
207    Timer function.
208
209    Use "toc=tic()" to start the clock and "toc()" to measure
210    a time interval.
211    """
212    then = datetime.datetime.now()
213    return lambda: delay(datetime.datetime.now() - then)
214
215
216def set_beam_stop(data, radius, outer=None):
217    # type: (Data, float, float) -> None
218    """
219    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
220
221    Note: this function does not require sasview
222    """
223    if hasattr(data, 'qx_data'):
224        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
225        data.mask = (q < radius)
226        if outer is not None:
227            data.mask |= (q >= outer)
228    else:
229        data.mask = (data.x < radius)
230        if outer is not None:
231            data.mask |= (data.x >= outer)
232
233
234def parameter_range(p, v):
235    # type: (str, float) -> Tuple[float, float]
236    """
237    Choose a parameter range based on parameter name and initial value.
238    """
239    # process the polydispersity options
240    if p.endswith('_pd_n'):
241        return 0., 100.
242    elif p.endswith('_pd_nsigma'):
243        return 0., 5.
244    elif p.endswith('_pd_type'):
245        raise ValueError("Cannot return a range for a string value")
246    elif any(s in p for s in ('theta', 'phi', 'psi')):
247        # orientation in [-180,180], orientation pd in [0,45]
248        if p.endswith('_pd'):
249            return 0., 45.
250        else:
251            return -180., 180.
252    elif p.endswith('_pd'):
253        return 0., 1.
254    elif 'sld' in p:
255        return -0.5, 10.
256    elif p == 'background':
257        return 0., 10.
258    elif p == 'scale':
259        return 0., 1.e3
260    elif v < 0.:
261        return 2.*v, -2.*v
262    else:
263        return 0., (2.*v if v > 0. else 1.)
264
265
266def _randomize_one(model_info, p, v):
267    # type: (ModelInfo, str, float) -> float
268    # type: (ModelInfo, str, str) -> str
269    """
270    Randomize a single parameter.
271    """
272    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
273        return v
274
275    # Find the parameter definition
276    for par in model_info.parameters.call_parameters:
277        if par.name == p:
278            break
279    else:
280        raise ValueError("unknown parameter %r"%p)
281
282    if len(par.limits) > 2:  # choice list
283        return np.random.randint(len(par.limits))
284
285    limits = par.limits
286    if not np.isfinite(limits).all():
287        low, high = parameter_range(p, v)
288        limits = (max(limits[0], low), min(limits[1], high))
289
290    return np.random.uniform(*limits)
291
292
293def randomize_pars(model_info, pars, seed=None):
294    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
295    """
296    Generate random values for all of the parameters.
297
298    Valid ranges for the random number generator are guessed from the name of
299    the parameter; this will not account for constraints such as cap radius
300    greater than cylinder radius in the capped_cylinder model, so
301    :func:`constrain_pars` needs to be called afterward..
302    """
303    with push_seed(seed):
304        # Note: the sort guarantees order `of calls to random number generator
305        random_pars = dict((p, _randomize_one(model_info, p, v))
306                           for p, v in sorted(pars.items()))
307    return random_pars
308
309def constrain_pars(model_info, pars):
310    # type: (ModelInfo, ParameterSet) -> None
311    """
312    Restrict parameters to valid values.
313
314    This includes model specific code for models such as capped_cylinder
315    which need to support within model constraints (cap radius more than
316    cylinder radius in this case).
317
318    Warning: this updates the *pars* dictionary in place.
319    """
320    name = model_info.id
321    # if it is a product model, then just look at the form factor since
322    # none of the structure factors need any constraints.
323    if '*' in name:
324        name = name.split('*')[0]
325
326    # Suppress magnetism for python models (not yet implemented)
327    if callable(model_info.Iq):
328        pars.update(suppress_magnetism(pars))
329
330    if name == 'barbell':
331        if pars['radius_bell'] < pars['radius']:
332            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
333
334    elif name == 'capped_cylinder':
335        if pars['radius_cap'] < pars['radius']:
336            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
337
338    elif name == 'guinier':
339        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
340        #q_max = 0.2  # mid q maximum
341        q_max = 1.0  # high q maximum
342        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
343        pars['rg'] = min(pars['rg'], rg_max)
344
345    elif name == 'pearl_necklace':
346        if pars['radius'] < pars['thick_string']:
347            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
348        pass
349
350    elif name == 'rpa':
351        # Make sure phi sums to 1.0
352        if pars['case_num'] < 2:
353            pars['Phi1'] = 0.
354            pars['Phi2'] = 0.
355        elif pars['case_num'] < 5:
356            pars['Phi1'] = 0.
357        total = sum(pars['Phi'+c] for c in '1234')
358        for c in '1234':
359            pars['Phi'+c] /= total
360
361def parlist(model_info, pars, is2d):
362    # type: (ModelInfo, ParameterSet, bool) -> str
363    """
364    Format the parameter list for printing.
365    """
366    lines = []
367    parameters = model_info.parameters
368    magnetic = False
369    for p in parameters.user_parameters(pars, is2d):
370        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
371            continue
372        if p.id.startswith('up:') and not magnetic:
373            continue
374        fields = dict(
375            value=pars.get(p.id, p.default),
376            pd=pars.get(p.id+"_pd", 0.),
377            n=int(pars.get(p.id+"_pd_n", 0)),
378            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
379            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
380            relative_pd=p.relative_pd,
381            M0=pars.get('M0:'+p.id, 0.),
382            mphi=pars.get('mphi:'+p.id, 0.),
383            mtheta=pars.get('mtheta:'+p.id, 0.),
384        )
385        lines.append(_format_par(p.name, **fields))
386        magnetic = magnetic or fields['M0'] != 0.
387    return "\n".join(lines)
388
389    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
390
391def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
392                relative_pd=False, M0=0., mphi=0., mtheta=0.):
393    # type: (str, float, float, int, float, str) -> str
394    line = "%s: %g"%(name, value)
395    if pd != 0.  and n != 0:
396        if relative_pd:
397            pd *= value
398        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
399                % (pd, n, nsigma, nsigma, pdtype)
400    if M0 != 0.:
401        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
402    return line
403
404def suppress_pd(pars):
405    # type: (ParameterSet) -> ParameterSet
406    """
407    Suppress theta_pd for now until the normalization is resolved.
408
409    May also suppress complete polydispersity of the model to test
410    models more quickly.
411    """
412    pars = pars.copy()
413    for p in pars:
414        if p.endswith("_pd_n"): pars[p] = 0
415    return pars
416
417def suppress_magnetism(pars):
418    # type: (ParameterSet) -> ParameterSet
419    """
420    Suppress theta_pd for now until the normalization is resolved.
421
422    May also suppress complete polydispersity of the model to test
423    models more quickly.
424    """
425    pars = pars.copy()
426    for p in pars:
427        if p.startswith("M0:"): pars[p] = 0
428    return pars
429
430def eval_sasview(model_info, data):
431    # type: (Modelinfo, Data) -> Calculator
432    """
433    Return a model calculator using the pre-4.0 SasView models.
434    """
435    # importing sas here so that the error message will be that sas failed to
436    # import rather than the more obscure smear_selection not imported error
437    import sas
438    import sas.models
439    from sas.models.qsmearing import smear_selection
440    from sas.models.MultiplicationModel import MultiplicationModel
441    from sas.models.dispersion_models import models as dispersers
442
443    def get_model_class(name):
444        # type: (str) -> "sas.models.BaseComponent"
445        #print("new",sorted(_pars.items()))
446        __import__('sas.models.' + name)
447        ModelClass = getattr(getattr(sas.models, name, None), name, None)
448        if ModelClass is None:
449            raise ValueError("could not find model %r in sas.models"%name)
450        return ModelClass
451
452    # WARNING: ugly hack when handling model!
453    # Sasview models with multiplicity need to be created with the target
454    # multiplicity, so we cannot create the target model ahead of time for
455    # for multiplicity models.  Instead we store the model in a list and
456    # update the first element of that list with the new multiplicity model
457    # every time we evaluate.
458
459    # grab the sasview model, or create it if it is a product model
460    if model_info.composition:
461        composition_type, parts = model_info.composition
462        if composition_type == 'product':
463            P, S = [get_model_class(revert_name(p))() for p in parts]
464            model = [MultiplicationModel(P, S)]
465        else:
466            raise ValueError("sasview mixture models not supported by compare")
467    else:
468        old_name = revert_name(model_info)
469        if old_name is None:
470            raise ValueError("model %r does not exist in old sasview"
471                            % model_info.id)
472        ModelClass = get_model_class(old_name)
473        model = [ModelClass()]
474    model[0].disperser_handles = {}
475
476    # build a smearer with which to call the model, if necessary
477    smearer = smear_selection(data, model=model)
478    if hasattr(data, 'qx_data'):
479        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
480        index = ((~data.mask) & (~np.isnan(data.data))
481                 & (q >= data.qmin) & (q <= data.qmax))
482        if smearer is not None:
483            smearer.model = model  # because smear_selection has a bug
484            smearer.accuracy = data.accuracy
485            smearer.set_index(index)
486            def _call_smearer():
487                smearer.model = model[0]
488                return smearer.get_value()
489            theory = _call_smearer
490        else:
491            theory = lambda: model[0].evalDistribution([data.qx_data[index],
492                                                        data.qy_data[index]])
493    elif smearer is not None:
494        theory = lambda: smearer(model[0].evalDistribution(data.x))
495    else:
496        theory = lambda: model[0].evalDistribution(data.x)
497
498    def calculator(**pars):
499        # type: (float, ...) -> np.ndarray
500        """
501        Sasview calculator for model.
502        """
503        oldpars = revert_pars(model_info, pars)
504        # For multiplicity models, create a model with the correct multiplicity
505        control = oldpars.pop("CONTROL", None)
506        if control is not None:
507            # sphericalSLD has one fewer multiplicity.  This update should
508            # happen in revert_pars, but it hasn't been called yet.
509            model[0] = ModelClass(control)
510        # paying for parameter conversion each time to keep life simple, if not fast
511        for k, v in oldpars.items():
512            if k.endswith('.type'):
513                par = k[:-5]
514                if v == 'gaussian': continue
515                cls = dispersers[v if v != 'rectangle' else 'rectangula']
516                handle = cls()
517                model[0].disperser_handles[par] = handle
518                try:
519                    model[0].set_dispersion(par, handle)
520                except Exception:
521                    exception.annotate_exception("while setting %s to %r"
522                                                 %(par, v))
523                    raise
524
525
526        #print("sasview pars",oldpars)
527        for k, v in oldpars.items():
528            name_attr = k.split('.')  # polydispersity components
529            if len(name_attr) == 2:
530                par, disp_par = name_attr
531                model[0].dispersion[par][disp_par] = v
532            else:
533                model[0].setParam(k, v)
534        return theory()
535
536    calculator.engine = "sasview"
537    return calculator
538
539DTYPE_MAP = {
540    'half': '16',
541    'fast': 'fast',
542    'single': '32',
543    'double': '64',
544    'quad': '128',
545    'f16': '16',
546    'f32': '32',
547    'f64': '64',
548    'longdouble': '128',
549}
550def eval_opencl(model_info, data, dtype='single', cutoff=0.):
551    # type: (ModelInfo, Data, str, float) -> Calculator
552    """
553    Return a model calculator using the OpenCL calculation engine.
554    """
555    if not core.HAVE_OPENCL:
556        raise RuntimeError("OpenCL not available")
557    model = core.build_model(model_info, dtype=dtype, platform="ocl")
558    calculator = DirectModel(data, model, cutoff=cutoff)
559    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
560    return calculator
561
562def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
563    # type: (ModelInfo, Data, str, float) -> Calculator
564    """
565    Return a model calculator using the DLL calculation engine.
566    """
567    model = core.build_model(model_info, dtype=dtype, platform="dll")
568    calculator = DirectModel(data, model, cutoff=cutoff)
569    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
570    return calculator
571
572def time_calculation(calculator, pars, evals=1):
573    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
574    """
575    Compute the average calculation time over N evaluations.
576
577    An additional call is generated without polydispersity in order to
578    initialize the calculation engine, and make the average more stable.
579    """
580    # initialize the code so time is more accurate
581    if evals > 1:
582        calculator(**suppress_pd(pars))
583    toc = tic()
584    # make sure there is at least one eval
585    value = calculator(**pars)
586    for _ in range(evals-1):
587        value = calculator(**pars)
588    average_time = toc()*1000. / evals
589    #print("I(q)",value)
590    return value, average_time
591
592def make_data(opts):
593    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
594    """
595    Generate an empty dataset, used with the model to set Q points
596    and resolution.
597
598    *opts* contains the options, with 'qmax', 'nq', 'res',
599    'accuracy', 'is2d' and 'view' parsed from the command line.
600    """
601    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
602    if opts['is2d']:
603        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
604        data = empty_data2D(q, resolution=res)
605        data.accuracy = opts['accuracy']
606        set_beam_stop(data, 0.0004)
607        index = ~data.mask
608    else:
609        if opts['view'] == 'log' and not opts['zero']:
610            qmax = math.log10(qmax)
611            q = np.logspace(qmax-3, qmax, nq)
612        else:
613            q = np.linspace(0.001*qmax, qmax, nq)
614        if opts['zero']:
615            q = np.hstack((0, q))
616        data = empty_data1D(q, resolution=res)
617        index = slice(None, None)
618    return data, index
619
620def make_engine(model_info, data, dtype, cutoff):
621    # type: (ModelInfo, Data, str, float) -> Calculator
622    """
623    Generate the appropriate calculation engine for the given datatype.
624
625    Datatypes with '!' appended are evaluated using external C DLLs rather
626    than OpenCL.
627    """
628    if dtype == 'sasview':
629        return eval_sasview(model_info, data)
630    elif dtype.endswith('!'):
631        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
632    else:
633        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
634
635def _show_invalid(data, theory):
636    # type: (Data, np.ma.ndarray) -> None
637    """
638    Display a list of the non-finite values in theory.
639    """
640    if not theory.mask.any():
641        return
642
643    if hasattr(data, 'x'):
644        bad = zip(data.x[theory.mask], theory[theory.mask])
645        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
646
647
648def compare(opts, limits=None):
649    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
650    """
651    Preform a comparison using options from the command line.
652
653    *limits* are the limits on the values to use, either to set the y-axis
654    for 1D or to set the colormap scale for 2D.  If None, then they are
655    inferred from the data and returned. When exploring using Bumps,
656    the limits are set when the model is initially called, and maintained
657    as the values are adjusted, making it easier to see the effects of the
658    parameters.
659    """
660    result = run_models(opts, verbose=True)
661    if opts['plot']:  # Note: never called from explore
662        plot_models(opts, result, limits=limits)
663
664def run_models(opts, verbose=False):
665    # type: (Dict[str, Any]) -> Dict[str, Any]
666
667    n_base, n_comp = opts['count']
668    pars, pars2 = opts['pars']
669    data = opts['data']
670
671    # silence the linter
672    base = opts['engines'][0] if n_base else None
673    comp = opts['engines'][1] if n_comp else None
674
675    base_time = comp_time = None
676    base_value = comp_value = resid = relerr = None
677
678    # Base calculation
679    if n_base > 0:
680        try:
681            base_raw, base_time = time_calculation(base, pars, n_base)
682            base_value = np.ma.masked_invalid(base_raw)
683            if verbose:
684                print("%s t=%.2f ms, intensity=%.0f"
685                      % (base.engine, base_time, base_value.sum()))
686            _show_invalid(data, base_value)
687        except ImportError:
688            traceback.print_exc()
689            n_base = 0
690
691    # Comparison calculation
692    if n_comp > 0:
693        try:
694            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
695            comp_value = np.ma.masked_invalid(comp_raw)
696            if verbose:
697                print("%s t=%.2f ms, intensity=%.0f"
698                      % (comp.engine, comp_time, comp_value.sum()))
699            _show_invalid(data, comp_value)
700        except ImportError:
701            traceback.print_exc()
702            n_comp = 0
703
704    # Compare, but only if computing both forms
705    if n_base > 0 and n_comp > 0:
706        resid = (base_value - comp_value)
707        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
708        if verbose:
709            _print_stats("|%s-%s|"
710                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
711                         resid)
712            _print_stats("|(%s-%s)/%s|"
713                         % (base.engine, comp.engine, comp.engine),
714                         relerr)
715
716    return dict(base_value=base_value, comp_value=comp_value,
717                base_time=base_time, comp_time=comp_time,
718                resid=resid, relerr=relerr)
719
720
721def _print_stats(label, err):
722    # type: (str, np.ma.ndarray) -> None
723    # work with trimmed data, not the full set
724    sorted_err = np.sort(abs(err.compressed()))
725    if len(sorted_err) == 0.:
726        print(label + "  no valid values")
727        return
728
729    p50 = int((len(sorted_err)-1)*0.50)
730    p98 = int((len(sorted_err)-1)*0.98)
731    data = [
732        "max:%.3e"%sorted_err[-1],
733        "median:%.3e"%sorted_err[p50],
734        "98%%:%.3e"%sorted_err[p98],
735        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
736        "zero-offset:%+.3e"%np.mean(sorted_err),
737        ]
738    print(label+"  "+"  ".join(data))
739
740
741def plot_models(opts, result, limits=None):
742    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
743    base_value, comp_value= result['base_value'], result['comp_value']
744    base_time, comp_time = result['base_time'], result['comp_time']
745    resid, relerr = result['resid'], result['relerr']
746
747    have_base, have_comp = (base_value is not None), (comp_value is not None)
748    base = opts['engines'][0] if have_base else None
749    comp = opts['engines'][1] if have_comp else None
750    data = opts['data']
751    use_data = have_base ^ have_comp
752
753    # Plot if requested
754    view = opts['view']
755    import matplotlib.pyplot as plt
756    if limits is None:
757        vmin, vmax = np.Inf, -np.Inf
758        if have_base:
759            vmin = min(vmin, base_value.min())
760            vmax = max(vmax, base_value.max())
761        if have_comp:
762            vmin = min(vmin, comp_value.min())
763            vmax = max(vmax, comp_value.max())
764        limits = vmin, vmax
765
766    if have_base:
767        if have_comp: plt.subplot(131)
768        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
769        plt.title("%s t=%.2f ms"%(base.engine, base_time))
770        #cbar_title = "log I"
771    if have_comp:
772        if have_base: plt.subplot(132)
773        if not opts['is2d'] and have_base:
774            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
775        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
776        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
777        #cbar_title = "log I"
778    if have_base and have_comp:
779        plt.subplot(133)
780        if not opts['rel_err']:
781            err, errstr, errview = resid, "abs err", "linear"
782        else:
783            err, errstr, errview = abs(relerr), "rel err", "log"
784        if 0:  # 95% cutoff
785            sorted = np.sort(err.flatten())
786            cutoff = sorted[int(sorted.size*0.95)]
787            err[err>cutoff] = cutoff
788        #err,errstr = base/comp,"ratio"
789        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
790        if view == 'linear':
791            plt.xscale('linear')
792        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
793        #cbar_title = errstr if errview=="linear" else "log "+errstr
794    #if is2D:
795    #    h = plt.colorbar()
796    #    h.ax.set_title(cbar_title)
797    fig = plt.gcf()
798    extra_title = ' '+opts['title'] if opts['title'] else ''
799    fig.suptitle(":".join(opts['name']) + extra_title)
800
801    if have_base and have_comp and opts['show_hist']:
802        plt.figure()
803        v = relerr
804        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
805        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
806        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
807                   % (base.engine, comp.engine, comp.engine))
808        plt.ylabel('P(err)')
809        plt.title('Distribution of relative error between calculation engines')
810
811    if not opts['explore']:
812        plt.show()
813
814    return limits
815
816
817
818
819# ===========================================================================
820#
821NAME_OPTIONS = set([
822    'plot', 'noplot',
823    'half', 'fast', 'single', 'double',
824    'single!', 'double!', 'quad!', 'sasview',
825    'lowq', 'midq', 'highq', 'exq', 'zero',
826    '2d', '1d',
827    'preset', 'random',
828    'poly', 'mono',
829    'magnetic', 'nonmagnetic',
830    'nopars', 'pars',
831    'rel', 'abs',
832    'linear', 'log', 'q4',
833    'hist', 'nohist',
834    'edit', 'html', 'help',
835    'demo', 'default',
836    ])
837VALUE_OPTIONS = [
838    # Note: random is both a name option and a value option
839    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title', 'data',
840    ]
841
842def columnize(items, indent="", width=79):
843    # type: (List[str], str, int) -> str
844    """
845    Format a list of strings into columns.
846
847    Returns a string with carriage returns ready for printing.
848    """
849    column_width = max(len(w) for w in items) + 1
850    num_columns = (width - len(indent)) // column_width
851    num_rows = len(items) // num_columns
852    items = items + [""] * (num_rows * num_columns - len(items))
853    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
854    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
855             for row in zip(*columns)]
856    output = indent + ("\n"+indent).join(lines)
857    return output
858
859
860def get_pars(model_info, use_demo=False):
861    # type: (ModelInfo, bool) -> ParameterSet
862    """
863    Extract demo parameters from the model definition.
864    """
865    # Get the default values for the parameters
866    pars = {}
867    for p in model_info.parameters.call_parameters:
868        parts = [('', p.default)]
869        if p.polydisperse:
870            parts.append(('_pd', 0.0))
871            parts.append(('_pd_n', 0))
872            parts.append(('_pd_nsigma', 3.0))
873            parts.append(('_pd_type', "gaussian"))
874        for ext, val in parts:
875            if p.length > 1:
876                dict(("%s%d%s" % (p.id, k, ext), val)
877                     for k in range(1, p.length+1))
878            else:
879                pars[p.id + ext] = val
880
881    # Plug in values given in demo
882    if use_demo:
883        pars.update(model_info.demo)
884    return pars
885
886INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
887def isnumber(str):
888    match = FLOAT_RE.match(str)
889    isfloat = (match and not str[match.end():])
890    return isfloat or INTEGER_RE.match(str)
891
892# For distinguishing pairs of models for comparison
893# key-value pair separator =
894# shell characters  | & ; <> $ % ' " \ # `
895# model and parameter names _
896# parameter expressions - + * / . ( )
897# path characters including tilde expansion and windows drive ~ / :
898# not sure about brackets [] {}
899# maybe one of the following @ ? ^ ! ,
900MODEL_SPLIT = ','
901def parse_opts(argv):
902    # type: (List[str]) -> Dict[str, Any]
903    """
904    Parse command line options.
905    """
906    MODELS = core.list_models()
907    flags = [arg for arg in argv
908             if arg.startswith('-')]
909    values = [arg for arg in argv
910              if not arg.startswith('-') and '=' in arg]
911    positional_args = [arg for arg in argv
912            if not arg.startswith('-') and '=' not in arg]
913    models = "\n    ".join("%-15s"%v for v in MODELS)
914    if len(positional_args) == 0:
915        print(USAGE)
916        print("\nAvailable models:")
917        print(columnize(MODELS, indent="  "))
918        return None
919    if len(positional_args) > 3:
920        print("expected parameters: model N1 N2")
921
922    invalid = [o[1:] for o in flags
923               if o[1:] not in NAME_OPTIONS
924               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
925    if invalid:
926        print("Invalid options: %s"%(", ".join(invalid)))
927        return None
928
929    name = positional_args[0]
930    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
931    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
932
933    # pylint: disable=bad-whitespace
934    # Interpret the flags
935    opts = {
936        'plot'      : True,
937        'view'      : 'log',
938        'is2d'      : False,
939        'qmax'      : 0.05,
940        'nq'        : 128,
941        'res'       : 0.0,
942        'accuracy'  : 'Low',
943        'cutoff'    : 0.0,
944        'seed'      : -1,  # default to preset
945        'mono'      : False,
946        # Default to magnetic a magnetic moment is set on the command line
947        'magnetic'  : False,
948        'show_pars' : False,
949        'show_hist' : False,
950        'rel_err'   : True,
951        'explore'   : False,
952        'use_demo'  : True,
953        'zero'      : False,
954        'html'      : False,
955        'title'     : None,
956        'data'      : None,
957    }
958    engines = []
959    for arg in flags:
960        if arg == '-noplot':    opts['plot'] = False
961        elif arg == '-plot':    opts['plot'] = True
962        elif arg == '-linear':  opts['view'] = 'linear'
963        elif arg == '-log':     opts['view'] = 'log'
964        elif arg == '-q4':      opts['view'] = 'q4'
965        elif arg == '-1d':      opts['is2d'] = False
966        elif arg == '-2d':      opts['is2d'] = True
967        elif arg == '-exq':     opts['qmax'] = 10.0
968        elif arg == '-highq':   opts['qmax'] = 1.0
969        elif arg == '-midq':    opts['qmax'] = 0.2
970        elif arg == '-lowq':    opts['qmax'] = 0.05
971        elif arg == '-zero':    opts['zero'] = True
972        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
973        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
974        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
975        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
976        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
977        elif arg.startswith('-title='):    opts['title'] = arg[7:]
978        elif arg.startswith('-data='):     opts['data'] = arg[6:]
979        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
980        elif arg == '-preset':  opts['seed'] = -1
981        elif arg == '-mono':    opts['mono'] = True
982        elif arg == '-poly':    opts['mono'] = False
983        elif arg == '-magnetic':       opts['magnetic'] = True
984        elif arg == '-nonmagnetic':    opts['magnetic'] = False
985        elif arg == '-pars':    opts['show_pars'] = True
986        elif arg == '-nopars':  opts['show_pars'] = False
987        elif arg == '-hist':    opts['show_hist'] = True
988        elif arg == '-nohist':  opts['show_hist'] = False
989        elif arg == '-rel':     opts['rel_err'] = True
990        elif arg == '-abs':     opts['rel_err'] = False
991        elif arg == '-half':    engines.append(arg[1:])
992        elif arg == '-fast':    engines.append(arg[1:])
993        elif arg == '-single':  engines.append(arg[1:])
994        elif arg == '-double':  engines.append(arg[1:])
995        elif arg == '-single!': engines.append(arg[1:])
996        elif arg == '-double!': engines.append(arg[1:])
997        elif arg == '-quad!':   engines.append(arg[1:])
998        elif arg == '-sasview': engines.append(arg[1:])
999        elif arg == '-edit':    opts['explore'] = True
1000        elif arg == '-demo':    opts['use_demo'] = True
1001        elif arg == '-default':    opts['use_demo'] = False
1002        elif arg == '-html':    opts['html'] = True
1003        elif arg == '-help':    opts['html'] = True
1004    # pylint: enable=bad-whitespace
1005
1006    if MODEL_SPLIT in name:
1007        name, name2 = name.split(MODEL_SPLIT, 2)
1008    else:
1009        name2 = name
1010    try:
1011        model_info = core.load_model_info(name)
1012        model_info2 = core.load_model_info(name2) if name2 != name else model_info
1013    except ImportError as exc:
1014        print(str(exc))
1015        print("Could not find model; use one of:\n    " + models)
1016        return None
1017
1018    # Get demo parameters from model definition, or use default parameters
1019    # if model does not define demo parameters
1020    pars = get_pars(model_info, opts['use_demo'])
1021    pars2 = get_pars(model_info2, opts['use_demo'])
1022    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1023    # randomize parameters
1024    #pars.update(set_pars)  # set value before random to control range
1025    if opts['seed'] > -1:
1026        pars = randomize_pars(model_info, pars, seed=opts['seed'])
1027        if model_info != model_info2:
1028            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
1029            # Share values for parameters with the same name
1030            for k, v in pars.items():
1031                if k in pars2:
1032                    pars2[k] = v
1033        else:
1034            pars2 = pars.copy()
1035        constrain_pars(model_info, pars)
1036        constrain_pars(model_info2, pars2)
1037        print("Randomize using -random=%i"%opts['seed'])
1038    if opts['mono']:
1039        pars = suppress_pd(pars)
1040        pars2 = suppress_pd(pars2)
1041    if not opts['magnetic']:
1042        pars = suppress_magnetism(pars)
1043        pars2 = suppress_magnetism(pars2)
1044
1045    # Fill in parameters given on the command line
1046    presets = {}
1047    presets2 = {}
1048    for arg in values:
1049        k, v = arg.split('=', 1)
1050        if k not in pars and k not in pars2:
1051            # extract base name without polydispersity info
1052            s = set(p.split('_pd')[0] for p in pars)
1053            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1054            return None
1055        v1, v2 = v.split(MODEL_SPLIT, 2) if MODEL_SPLIT in v else (v,v)
1056        if v1 and k in pars:
1057            presets[k] = float(v1) if isnumber(v1) else v1
1058        if v2 and k in pars2:
1059            presets2[k] = float(v2) if isnumber(v2) else v2
1060
1061    # If pd given on the command line, default pd_n to 35
1062    for k, v in list(presets.items()):
1063        if k.endswith('_pd'):
1064            presets.setdefault(k+'_n', 35.)
1065    for k, v in list(presets2.items()):
1066        if k.endswith('_pd'):
1067            presets2.setdefault(k+'_n', 35.)
1068
1069    # Evaluate preset parameter expressions
1070    context = MATH.copy()
1071    context['np'] = np
1072    context.update(pars)
1073    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
1074    for k, v in presets.items():
1075        if not isinstance(v, float) and not k.endswith('_type'):
1076            presets[k] = eval(v, context)
1077    context.update(presets)
1078    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
1079    for k, v in presets2.items():
1080        if not isinstance(v, float) and not k.endswith('_type'):
1081            presets2[k] = eval(v, context)
1082
1083    # update parameters with presets
1084    pars.update(presets)  # set value after random to control value
1085    pars2.update(presets2)  # set value after random to control value
1086    #import pprint; pprint.pprint(model_info)
1087
1088    same_model = name == name2 and pars == pars
1089    if len(engines) == 0:
1090        if same_model:
1091            engines.extend(['single', 'double'])
1092        else:
1093            engines.extend(['single', 'single'])
1094    elif len(engines) == 1:
1095        if not same_model:
1096            engines.append(engines[0])
1097        elif engines[0] == 'double':
1098            engines.append('single')
1099        else:
1100            engines.append('double')
1101    elif len(engines) > 2:
1102        del engines[2:]
1103
1104    use_sasview = any(engine == 'sasview' and count > 0
1105                      for engine, count in zip(engines, [n1, n2]))
1106    if use_sasview:
1107        constrain_new_to_old(model_info, pars)
1108        constrain_new_to_old(model_info2, pars2)
1109
1110    if opts['show_pars']:
1111        if not same_model:
1112            print("==== %s ====="%model_info.name)
1113            print(str(parlist(model_info, pars, opts['is2d'])))
1114            print("==== %s ====="%model_info2.name)
1115            print(str(parlist(model_info2, pars2, opts['is2d'])))
1116        else:
1117            print(str(parlist(model_info, pars, opts['is2d'])))
1118
1119    # Create the computational engines
1120    if opts['data'] is not None:
1121        data = load_data(os.path.expanduser(opts['data']))
1122    else:
1123        data, _ = make_data(opts)
1124    if n1:
1125        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1126    else:
1127        base = None
1128    if n2:
1129        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1130    else:
1131        comp = None
1132
1133    # pylint: disable=bad-whitespace
1134    # Remember it all
1135    opts.update({
1136        'data'      : data,
1137        'name'      : [name, name2],
1138        'def'       : [model_info, model_info2],
1139        'count'     : [n1, n2],
1140        'presets'   : [presets, presets2],
1141        'pars'      : [pars, pars2],
1142        'engines'   : [base, comp],
1143    })
1144    # pylint: enable=bad-whitespace
1145
1146    return opts
1147
1148def show_docs(opts):
1149    # type: (Dict[str, Any]) -> None
1150    """
1151    show html docs for the model
1152    """
1153    import os
1154    from .generate import make_html
1155    from . import rst2html
1156
1157    info = opts['def'][0]
1158    html = make_html(info)
1159    path = os.path.dirname(info.filename)
1160    url = "file://"+path.replace("\\","/")[2:]+"/"
1161    rst2html.view_html_qtapp(html, url)
1162
1163def explore(opts):
1164    # type: (Dict[str, Any]) -> None
1165    """
1166    explore the model using the bumps gui.
1167    """
1168    import wx  # type: ignore
1169    from bumps.names import FitProblem  # type: ignore
1170    from bumps.gui.app_frame import AppFrame  # type: ignore
1171    from bumps.gui import signal
1172
1173    is_mac = "cocoa" in wx.version()
1174    # Create an app if not running embedded
1175    app = wx.App() if wx.GetApp() is None else None
1176    model = Explore(opts)
1177    problem = FitProblem(model)
1178    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1179    if not is_mac: frame.Show()
1180    frame.panel.set_model(model=problem)
1181    frame.panel.Layout()
1182    frame.panel.aui.Split(0, wx.TOP)
1183    def reset_parameters(event):
1184        model.revert_values()
1185        signal.update_parameters(problem)
1186    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1187    if is_mac: frame.Show()
1188    # If running withing an app, start the main loop
1189    if app: app.MainLoop()
1190
1191class Explore(object):
1192    """
1193    Bumps wrapper for a SAS model comparison.
1194
1195    The resulting object can be used as a Bumps fit problem so that
1196    parameters can be adjusted in the GUI, with plots updated on the fly.
1197    """
1198    def __init__(self, opts):
1199        # type: (Dict[str, Any]) -> None
1200        from bumps.cli import config_matplotlib  # type: ignore
1201        from . import bumps_model
1202        config_matplotlib()
1203        self.opts = opts
1204        p1, p2 = opts['pars']
1205        m1, m2 = opts['def']
1206        self.fix_p2 = m1 != m2 or p1 != p2
1207        model_info = m1
1208        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1209        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1210        if not opts['is2d']:
1211            for p in model_info.parameters.user_parameters({}, is2d=False):
1212                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1213                    k = p.name+ext
1214                    v = pars.get(k, None)
1215                    if v is not None:
1216                        v.range(*parameter_range(k, v.value))
1217        else:
1218            for k, v in pars.items():
1219                v.range(*parameter_range(k, v.value))
1220
1221        self.pars = pars
1222        self.starting_values = dict((k, v.value) for k, v in pars.items())
1223        self.pd_types = pd_types
1224        self.limits = None
1225
1226    def revert_values(self):
1227        for k, v in self.starting_values.items():
1228            self.pars[k].value = v
1229
1230    def model_update(self):
1231        pass
1232
1233    def numpoints(self):
1234        # type: () -> int
1235        """
1236        Return the number of points.
1237        """
1238        return len(self.pars) + 1  # so dof is 1
1239
1240    def parameters(self):
1241        # type: () -> Any   # Dict/List hierarchy of parameters
1242        """
1243        Return a dictionary of parameters.
1244        """
1245        return self.pars
1246
1247    def nllf(self):
1248        # type: () -> float
1249        """
1250        Return cost.
1251        """
1252        # pylint: disable=no-self-use
1253        return 0.  # No nllf
1254
1255    def plot(self, view='log'):
1256        # type: (str) -> None
1257        """
1258        Plot the data and residuals.
1259        """
1260        pars = dict((k, v.value) for k, v in self.pars.items())
1261        pars.update(self.pd_types)
1262        self.opts['pars'][0] = pars
1263        if not self.fix_p2:
1264            self.opts['pars'][1] = pars
1265        result = run_models(self.opts)
1266        limits = plot_models(self.opts, result, limits=self.limits)
1267        if self.limits is None:
1268            vmin, vmax = limits
1269            self.limits = vmax*1e-7, 1.3*vmax
1270            import pylab; pylab.clf()
1271            plot_models(self.opts, result, limits=self.limits)
1272
1273
1274def main(*argv):
1275    # type: (*str) -> None
1276    """
1277    Main program.
1278    """
1279    opts = parse_opts(argv)
1280    if opts is not None:
1281        if opts['html']:
1282            show_docs(opts)
1283        elif opts['explore']:
1284            explore(opts)
1285        else:
1286            compare(opts)
1287
1288if __name__ == "__main__":
1289    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.