source: sasmodels/sasmodels/compare.py @ d504bcd

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since d504bcd was d504bcd, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

remove restriction to integer n for random stacked_disk and pearl_necklace models

  • Property mode set to 100755
File size: 43.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model or model1,model2 are the names of the models to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -magnetic/-nonmagnetic* suppress magnetism
76    -cutoff=1e-5* cutoff value for including a point in polydispersity
77    -pars/-nopars* prints the parameter set or not
78    -abs/-rel* plot relative or absolute error
79    -linear/-log*/-q4 intensity scaling
80    -hist/-nohist* plot histogram of relative error
81    -res=0 sets the resolution width dQ/Q if calculating with resolution
82    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
83    -edit starts the parameter explorer
84    -default/-demo* use demo vs default parameters
85    -html shows the model docs instead of running the model
86    -title="note" adds note to the plot title, after the model name
87
88Any two calculation engines can be selected for comparison:
89
90    -single/-double/-half/-fast sets an OpenCL calculation engine
91    -single!/-double!/-quad! sets an OpenMP calculation engine
92    -sasview sets the sasview calculation engine
93
94The default is -single -double.  Note that the interpretation of quad
95precision depends on architecture, and may vary from 64-bit to 128-bit,
96with 80-bit floats being common (1e-19 precision).
97
98Key=value pairs allow you to set specific values for the model parameters.
99Key=value1,value2 to compare different values of the same parameter.
100value can be an expression including other parameters
101"""
102
103# Update docs with command line usage string.   This is separate from the usual
104# doc string so that we can display it at run time if there is an error.
105# lin
106__doc__ = (__doc__  # pylint: disable=redefined-builtin
107           + """
108Program description
109-------------------
110
111"""
112           + USAGE)
113
114kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
115
116# list of math functions for use in evaluating parameters
117MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
118
119# CRUFT python 2.6
120if not hasattr(datetime.timedelta, 'total_seconds'):
121    def delay(dt):
122        """Return number date-time delta as number seconds"""
123        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
124else:
125    def delay(dt):
126        """Return number date-time delta as number seconds"""
127        return dt.total_seconds()
128
129
130class push_seed(object):
131    """
132    Set the seed value for the random number generator.
133
134    When used in a with statement, the random number generator state is
135    restored after the with statement is complete.
136
137    :Parameters:
138
139    *seed* : int or array_like, optional
140        Seed for RandomState
141
142    :Example:
143
144    Seed can be used directly to set the seed::
145
146        >>> from numpy.random import randint
147        >>> push_seed(24)
148        <...push_seed object at...>
149        >>> print(randint(0,1000000,3))
150        [242082    899 211136]
151
152    Seed can also be used in a with statement, which sets the random
153    number generator state for the enclosed computations and restores
154    it to the previous state on completion::
155
156        >>> with push_seed(24):
157        ...    print(randint(0,1000000,3))
158        [242082    899 211136]
159
160    Using nested contexts, we can demonstrate that state is indeed
161    restored after the block completes::
162
163        >>> with push_seed(24):
164        ...    print(randint(0,1000000))
165        ...    with push_seed(24):
166        ...        print(randint(0,1000000,3))
167        ...    print(randint(0,1000000))
168        242082
169        [242082    899 211136]
170        899
171
172    The restore step is protected against exceptions in the block::
173
174        >>> with push_seed(24):
175        ...    print(randint(0,1000000))
176        ...    try:
177        ...        with push_seed(24):
178        ...            print(randint(0,1000000,3))
179        ...            raise Exception()
180        ...    except Exception:
181        ...        print("Exception raised")
182        ...    print(randint(0,1000000))
183        242082
184        [242082    899 211136]
185        Exception raised
186        899
187    """
188    def __init__(self, seed=None):
189        # type: (Optional[int]) -> None
190        self._state = np.random.get_state()
191        np.random.seed(seed)
192
193    def __enter__(self):
194        # type: () -> None
195        pass
196
197    def __exit__(self, exc_type, exc_value, traceback):
198        # type: (Any, BaseException, Any) -> None
199        # TODO: better typing for __exit__ method
200        np.random.set_state(self._state)
201
202def tic():
203    # type: () -> Callable[[], float]
204    """
205    Timer function.
206
207    Use "toc=tic()" to start the clock and "toc()" to measure
208    a time interval.
209    """
210    then = datetime.datetime.now()
211    return lambda: delay(datetime.datetime.now() - then)
212
213
214def set_beam_stop(data, radius, outer=None):
215    # type: (Data, float, float) -> None
216    """
217    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
218
219    Note: this function does not require sasview
220    """
221    if hasattr(data, 'qx_data'):
222        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
223        data.mask = (q < radius)
224        if outer is not None:
225            data.mask |= (q >= outer)
226    else:
227        data.mask = (data.x < radius)
228        if outer is not None:
229            data.mask |= (data.x >= outer)
230
231
232def parameter_range(p, v):
233    # type: (str, float) -> Tuple[float, float]
234    """
235    Choose a parameter range based on parameter name and initial value.
236    """
237    # process the polydispersity options
238    if p.endswith('_pd_n'):
239        return 0., 100.
240    elif p.endswith('_pd_nsigma'):
241        return 0., 5.
242    elif p.endswith('_pd_type'):
243        raise ValueError("Cannot return a range for a string value")
244    elif any(s in p for s in ('theta', 'phi', 'psi')):
245        # orientation in [-180,180], orientation pd in [0,45]
246        if p.endswith('_pd'):
247            return 0., 45.
248        else:
249            return -180., 180.
250    elif p.endswith('_pd'):
251        return 0., 1.
252    elif 'sld' in p:
253        return -0.5, 10.
254    elif p == 'background':
255        return 0., 10.
256    elif p == 'scale':
257        return 0., 1.e3
258    elif v < 0.:
259        return 2.*v, -2.*v
260    else:
261        return 0., (2.*v if v > 0. else 1.)
262
263
264def _randomize_one(model_info, p, v):
265    # type: (ModelInfo, str, float) -> float
266    # type: (ModelInfo, str, str) -> str
267    """
268    Randomize a single parameter.
269    """
270    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
271        return v
272
273    # Find the parameter definition
274    for par in model_info.parameters.call_parameters:
275        if par.name == p:
276            break
277    else:
278        raise ValueError("unknown parameter %r"%p)
279
280    if len(par.limits) > 2:  # choice list
281        return np.random.randint(len(par.limits))
282
283    limits = par.limits
284    if not np.isfinite(limits).all():
285        low, high = parameter_range(p, v)
286        limits = (max(limits[0], low), min(limits[1], high))
287
288    return np.random.uniform(*limits)
289
290
291def randomize_pars(model_info, pars, seed=None):
292    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
293    """
294    Generate random values for all of the parameters.
295
296    Valid ranges for the random number generator are guessed from the name of
297    the parameter; this will not account for constraints such as cap radius
298    greater than cylinder radius in the capped_cylinder model, so
299    :func:`constrain_pars` needs to be called afterward..
300    """
301    with push_seed(seed):
302        # Note: the sort guarantees order `of calls to random number generator
303        random_pars = dict((p, _randomize_one(model_info, p, v))
304                           for p, v in sorted(pars.items()))
305    return random_pars
306
307def constrain_pars(model_info, pars):
308    # type: (ModelInfo, ParameterSet) -> None
309    """
310    Restrict parameters to valid values.
311
312    This includes model specific code for models such as capped_cylinder
313    which need to support within model constraints (cap radius more than
314    cylinder radius in this case).
315
316    Warning: this updates the *pars* dictionary in place.
317    """
318    name = model_info.id
319    # if it is a product model, then just look at the form factor since
320    # none of the structure factors need any constraints.
321    if '*' in name:
322        name = name.split('*')[0]
323
324    if name == 'barbell':
325        if pars['radius_bell'] < pars['radius']:
326            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
327
328    elif name == 'capped_cylinder':
329        if pars['radius_cap'] < pars['radius']:
330            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
331
332    elif name == 'guinier':
333        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
334        #q_max = 0.2  # mid q maximum
335        q_max = 1.0  # high q maximum
336        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
337        pars['rg'] = min(pars['rg'], rg_max)
338
339    elif name == 'pearl_necklace':
340        if pars['radius'] < pars['thick_string']:
341            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
342        pass
343
344    elif name == 'rpa':
345        # Make sure phi sums to 1.0
346        if pars['case_num'] < 2:
347            pars['Phi1'] = 0.
348            pars['Phi2'] = 0.
349        elif pars['case_num'] < 5:
350            pars['Phi1'] = 0.
351        total = sum(pars['Phi'+c] for c in '1234')
352        for c in '1234':
353            pars['Phi'+c] /= total
354
355def parlist(model_info, pars, is2d):
356    # type: (ModelInfo, ParameterSet, bool) -> str
357    """
358    Format the parameter list for printing.
359    """
360    lines = []
361    parameters = model_info.parameters
362    magnetic = False
363    for p in parameters.user_parameters(pars, is2d):
364        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
365            continue
366        if p.id.startswith('up:') and not magnetic:
367            continue
368        fields = dict(
369            value=pars.get(p.id, p.default),
370            pd=pars.get(p.id+"_pd", 0.),
371            n=int(pars.get(p.id+"_pd_n", 0)),
372            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
373            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
374            relative_pd=p.relative_pd,
375            M0=pars.get('M0:'+p.id, 0.),
376            mphi=pars.get('mphi:'+p.id, 0.),
377            mtheta=pars.get('mtheta:'+p.id, 0.),
378        )
379        lines.append(_format_par(p.name, **fields))
380        magnetic = magnetic or fields['M0'] != 0.
381    return "\n".join(lines)
382
383    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
384
385def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
386                relative_pd=False, M0=0., mphi=0., mtheta=0.):
387    # type: (str, float, float, int, float, str) -> str
388    line = "%s: %g"%(name, value)
389    if pd != 0.  and n != 0:
390        if relative_pd:
391            pd *= value
392        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
393                % (pd, n, nsigma, nsigma, pdtype)
394    if M0 != 0.:
395        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
396    return line
397
398def suppress_pd(pars):
399    # type: (ParameterSet) -> ParameterSet
400    """
401    Suppress theta_pd for now until the normalization is resolved.
402
403    May also suppress complete polydispersity of the model to test
404    models more quickly.
405    """
406    pars = pars.copy()
407    for p in pars:
408        if p.endswith("_pd_n"): pars[p] = 0
409    return pars
410
411def suppress_magnetism(pars):
412    # type: (ParameterSet) -> ParameterSet
413    """
414    Suppress theta_pd for now until the normalization is resolved.
415
416    May also suppress complete polydispersity of the model to test
417    models more quickly.
418    """
419    pars = pars.copy()
420    for p in pars:
421        if p.startswith("M0:"): pars[p] = 0
422    return pars
423
424def eval_sasview(model_info, data):
425    # type: (Modelinfo, Data) -> Calculator
426    """
427    Return a model calculator using the pre-4.0 SasView models.
428    """
429    # importing sas here so that the error message will be that sas failed to
430    # import rather than the more obscure smear_selection not imported error
431    import sas
432    import sas.models
433    from sas.models.qsmearing import smear_selection
434    from sas.models.MultiplicationModel import MultiplicationModel
435    from sas.models.dispersion_models import models as dispersers
436
437    def get_model_class(name):
438        # type: (str) -> "sas.models.BaseComponent"
439        #print("new",sorted(_pars.items()))
440        __import__('sas.models.' + name)
441        ModelClass = getattr(getattr(sas.models, name, None), name, None)
442        if ModelClass is None:
443            raise ValueError("could not find model %r in sas.models"%name)
444        return ModelClass
445
446    # WARNING: ugly hack when handling model!
447    # Sasview models with multiplicity need to be created with the target
448    # multiplicity, so we cannot create the target model ahead of time for
449    # for multiplicity models.  Instead we store the model in a list and
450    # update the first element of that list with the new multiplicity model
451    # every time we evaluate.
452
453    # grab the sasview model, or create it if it is a product model
454    if model_info.composition:
455        composition_type, parts = model_info.composition
456        if composition_type == 'product':
457            P, S = [get_model_class(revert_name(p))() for p in parts]
458            model = [MultiplicationModel(P, S)]
459        else:
460            raise ValueError("sasview mixture models not supported by compare")
461    else:
462        old_name = revert_name(model_info)
463        if old_name is None:
464            raise ValueError("model %r does not exist in old sasview"
465                            % model_info.id)
466        ModelClass = get_model_class(old_name)
467        model = [ModelClass()]
468    model[0].disperser_handles = {}
469
470    # build a smearer with which to call the model, if necessary
471    smearer = smear_selection(data, model=model)
472    if hasattr(data, 'qx_data'):
473        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
474        index = ((~data.mask) & (~np.isnan(data.data))
475                 & (q >= data.qmin) & (q <= data.qmax))
476        if smearer is not None:
477            smearer.model = model  # because smear_selection has a bug
478            smearer.accuracy = data.accuracy
479            smearer.set_index(index)
480            def _call_smearer():
481                smearer.model = model[0]
482                return smearer.get_value()
483            theory = _call_smearer
484        else:
485            theory = lambda: model[0].evalDistribution([data.qx_data[index],
486                                                        data.qy_data[index]])
487    elif smearer is not None:
488        theory = lambda: smearer(model[0].evalDistribution(data.x))
489    else:
490        theory = lambda: model[0].evalDistribution(data.x)
491
492    def calculator(**pars):
493        # type: (float, ...) -> np.ndarray
494        """
495        Sasview calculator for model.
496        """
497        oldpars = revert_pars(model_info, pars)
498        # For multiplicity models, create a model with the correct multiplicity
499        control = oldpars.pop("CONTROL", None)
500        if control is not None:
501            # sphericalSLD has one fewer multiplicity.  This update should
502            # happen in revert_pars, but it hasn't been called yet.
503            model[0] = ModelClass(control)
504        # paying for parameter conversion each time to keep life simple, if not fast
505        for k, v in oldpars.items():
506            if k.endswith('.type'):
507                par = k[:-5]
508                if v == 'gaussian': continue
509                cls = dispersers[v if v != 'rectangle' else 'rectangula']
510                handle = cls()
511                model[0].disperser_handles[par] = handle
512                try:
513                    model[0].set_dispersion(par, handle)
514                except Exception:
515                    exception.annotate_exception("while setting %s to %r"
516                                                 %(par, v))
517                    raise
518
519
520        #print("sasview pars",oldpars)
521        for k, v in oldpars.items():
522            name_attr = k.split('.')  # polydispersity components
523            if len(name_attr) == 2:
524                par, disp_par = name_attr
525                model[0].dispersion[par][disp_par] = v
526            else:
527                model[0].setParam(k, v)
528        return theory()
529
530    calculator.engine = "sasview"
531    return calculator
532
533DTYPE_MAP = {
534    'half': '16',
535    'fast': 'fast',
536    'single': '32',
537    'double': '64',
538    'quad': '128',
539    'f16': '16',
540    'f32': '32',
541    'f64': '64',
542    'longdouble': '128',
543}
544def eval_opencl(model_info, data, dtype='single', cutoff=0.):
545    # type: (ModelInfo, Data, str, float) -> Calculator
546    """
547    Return a model calculator using the OpenCL calculation engine.
548    """
549    if not core.HAVE_OPENCL:
550        raise RuntimeError("OpenCL not available")
551    model = core.build_model(model_info, dtype=dtype, platform="ocl")
552    calculator = DirectModel(data, model, cutoff=cutoff)
553    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
554    return calculator
555
556def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
557    # type: (ModelInfo, Data, str, float) -> Calculator
558    """
559    Return a model calculator using the DLL calculation engine.
560    """
561    model = core.build_model(model_info, dtype=dtype, platform="dll")
562    calculator = DirectModel(data, model, cutoff=cutoff)
563    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
564    return calculator
565
566def time_calculation(calculator, pars, evals=1):
567    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
568    """
569    Compute the average calculation time over N evaluations.
570
571    An additional call is generated without polydispersity in order to
572    initialize the calculation engine, and make the average more stable.
573    """
574    # initialize the code so time is more accurate
575    if evals > 1:
576        calculator(**suppress_pd(pars))
577    toc = tic()
578    # make sure there is at least one eval
579    value = calculator(**pars)
580    for _ in range(evals-1):
581        value = calculator(**pars)
582    average_time = toc()*1000. / evals
583    #print("I(q)",value)
584    return value, average_time
585
586def make_data(opts):
587    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
588    """
589    Generate an empty dataset, used with the model to set Q points
590    and resolution.
591
592    *opts* contains the options, with 'qmax', 'nq', 'res',
593    'accuracy', 'is2d' and 'view' parsed from the command line.
594    """
595    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
596    if opts['is2d']:
597        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
598        data = empty_data2D(q, resolution=res)
599        data.accuracy = opts['accuracy']
600        set_beam_stop(data, 0.0004)
601        index = ~data.mask
602    else:
603        if opts['view'] == 'log' and not opts['zero']:
604            qmax = math.log10(qmax)
605            q = np.logspace(qmax-3, qmax, nq)
606        else:
607            q = np.linspace(0.001*qmax, qmax, nq)
608        if opts['zero']:
609            q = np.hstack((0, q))
610        data = empty_data1D(q, resolution=res)
611        index = slice(None, None)
612    return data, index
613
614def make_engine(model_info, data, dtype, cutoff):
615    # type: (ModelInfo, Data, str, float) -> Calculator
616    """
617    Generate the appropriate calculation engine for the given datatype.
618
619    Datatypes with '!' appended are evaluated using external C DLLs rather
620    than OpenCL.
621    """
622    if dtype == 'sasview':
623        return eval_sasview(model_info, data)
624    elif dtype.endswith('!'):
625        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
626    else:
627        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
628
629def _show_invalid(data, theory):
630    # type: (Data, np.ma.ndarray) -> None
631    """
632    Display a list of the non-finite values in theory.
633    """
634    if not theory.mask.any():
635        return
636
637    if hasattr(data, 'x'):
638        bad = zip(data.x[theory.mask], theory[theory.mask])
639        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
640
641
642def compare(opts, limits=None):
643    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
644    """
645    Preform a comparison using options from the command line.
646
647    *limits* are the limits on the values to use, either to set the y-axis
648    for 1D or to set the colormap scale for 2D.  If None, then they are
649    inferred from the data and returned. When exploring using Bumps,
650    the limits are set when the model is initially called, and maintained
651    as the values are adjusted, making it easier to see the effects of the
652    parameters.
653    """
654    result = run_models(opts, verbose=True)
655    if opts['plot']:  # Note: never called from explore
656        plot_models(opts, result, limits=limits)
657
658def run_models(opts, verbose=False):
659    # type: (Dict[str, Any]) -> Dict[str, Any]
660
661    n_base, n_comp = opts['count']
662    pars, pars2 = opts['pars']
663    data = opts['data']
664
665    # silence the linter
666    base = opts['engines'][0] if n_base else None
667    comp = opts['engines'][1] if n_comp else None
668
669    base_time = comp_time = None
670    base_value = comp_value = resid = relerr = None
671
672    # Base calculation
673    if n_base > 0:
674        try:
675            base_raw, base_time = time_calculation(base, pars, n_base)
676            base_value = np.ma.masked_invalid(base_raw)
677            if verbose:
678                print("%s t=%.2f ms, intensity=%.0f"
679                      % (base.engine, base_time, base_value.sum()))
680            _show_invalid(data, base_value)
681        except ImportError:
682            traceback.print_exc()
683            n_base = 0
684
685    # Comparison calculation
686    if n_comp > 0:
687        try:
688            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
689            comp_value = np.ma.masked_invalid(comp_raw)
690            if verbose:
691                print("%s t=%.2f ms, intensity=%.0f"
692                      % (comp.engine, comp_time, comp_value.sum()))
693            _show_invalid(data, comp_value)
694        except ImportError:
695            traceback.print_exc()
696            n_comp = 0
697
698    # Compare, but only if computing both forms
699    if n_base > 0 and n_comp > 0:
700        resid = (base_value - comp_value)
701        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
702        if verbose:
703            _print_stats("|%s-%s|"
704                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
705                         resid)
706            _print_stats("|(%s-%s)/%s|"
707                         % (base.engine, comp.engine, comp.engine),
708                         relerr)
709
710    return dict(base_value=base_value, comp_value=comp_value,
711                base_time=base_time, comp_time=comp_time,
712                resid=resid, relerr=relerr)
713
714
715def _print_stats(label, err):
716    # type: (str, np.ma.ndarray) -> None
717    # work with trimmed data, not the full set
718    sorted_err = np.sort(abs(err.compressed()))
719    if len(sorted_err) == 0.:
720        print(label + "  no valid values")
721        return
722
723    p50 = int((len(sorted_err)-1)*0.50)
724    p98 = int((len(sorted_err)-1)*0.98)
725    data = [
726        "max:%.3e"%sorted_err[-1],
727        "median:%.3e"%sorted_err[p50],
728        "98%%:%.3e"%sorted_err[p98],
729        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
730        "zero-offset:%+.3e"%np.mean(sorted_err),
731        ]
732    print(label+"  "+"  ".join(data))
733
734
735def plot_models(opts, result, limits=None):
736    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
737    base_value, comp_value= result['base_value'], result['comp_value']
738    base_time, comp_time = result['base_time'], result['comp_time']
739    resid, relerr = result['resid'], result['relerr']
740
741    have_base, have_comp = (base_value is not None), (comp_value is not None)
742    base = opts['engines'][0] if have_base else None
743    comp = opts['engines'][1] if have_comp else None
744    data = opts['data']
745
746    # Plot if requested
747    view = opts['view']
748    import matplotlib.pyplot as plt
749    if limits is None:
750        vmin, vmax = np.Inf, -np.Inf
751        if have_base:
752            vmin = min(vmin, base_value.min())
753            vmax = max(vmax, base_value.max())
754        if have_comp:
755            vmin = min(vmin, comp_value.min())
756            vmax = max(vmax, comp_value.max())
757        limits = vmin, vmax
758
759    if have_base:
760        if have_comp: plt.subplot(131)
761        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
762        plt.title("%s t=%.2f ms"%(base.engine, base_time))
763        #cbar_title = "log I"
764    if have_comp:
765        if have_base: plt.subplot(132)
766        if not opts['is2d'] and have_base:
767            plot_theory(data, base_value, view=view, use_data=False, limits=limits)
768        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
769        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
770        #cbar_title = "log I"
771    if have_base and have_comp:
772        plt.subplot(133)
773        if not opts['rel_err']:
774            err, errstr, errview = resid, "abs err", "linear"
775        else:
776            err, errstr, errview = abs(relerr), "rel err", "log"
777        if 0:  # 95% cutoff
778            sorted = np.sort(err.flatten())
779            cutoff = sorted[int(sorted.size*0.95)]
780            err[err>cutoff] = cutoff
781        #err,errstr = base/comp,"ratio"
782        plot_theory(data, None, resid=err, view=errview, use_data=False)
783        if view == 'linear':
784            plt.xscale('linear')
785        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
786        #cbar_title = errstr if errview=="linear" else "log "+errstr
787    #if is2D:
788    #    h = plt.colorbar()
789    #    h.ax.set_title(cbar_title)
790    fig = plt.gcf()
791    extra_title = ' '+opts['title'] if opts['title'] else ''
792    fig.suptitle(":".join(opts['name']) + extra_title)
793
794    if have_base and have_comp and opts['show_hist']:
795        plt.figure()
796        v = relerr
797        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
798        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
799        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
800                   % (base.engine, comp.engine, comp.engine))
801        plt.ylabel('P(err)')
802        plt.title('Distribution of relative error between calculation engines')
803
804    if not opts['explore']:
805        plt.show()
806
807    return limits
808
809
810
811
812# ===========================================================================
813#
814NAME_OPTIONS = set([
815    'plot', 'noplot',
816    'half', 'fast', 'single', 'double',
817    'single!', 'double!', 'quad!', 'sasview',
818    'lowq', 'midq', 'highq', 'exq', 'zero',
819    '2d', '1d',
820    'preset', 'random',
821    'poly', 'mono',
822    'magnetic', 'nonmagnetic',
823    'nopars', 'pars',
824    'rel', 'abs',
825    'linear', 'log', 'q4',
826    'hist', 'nohist',
827    'edit', 'html',
828    'demo', 'default',
829    ])
830VALUE_OPTIONS = [
831    # Note: random is both a name option and a value option
832    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
833    ]
834
835def columnize(items, indent="", width=79):
836    # type: (List[str], str, int) -> str
837    """
838    Format a list of strings into columns.
839
840    Returns a string with carriage returns ready for printing.
841    """
842    column_width = max(len(w) for w in items) + 1
843    num_columns = (width - len(indent)) // column_width
844    num_rows = len(items) // num_columns
845    items = items + [""] * (num_rows * num_columns - len(items))
846    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
847    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
848             for row in zip(*columns)]
849    output = indent + ("\n"+indent).join(lines)
850    return output
851
852
853def get_pars(model_info, use_demo=False):
854    # type: (ModelInfo, bool) -> ParameterSet
855    """
856    Extract demo parameters from the model definition.
857    """
858    # Get the default values for the parameters
859    pars = {}
860    for p in model_info.parameters.call_parameters:
861        parts = [('', p.default)]
862        if p.polydisperse:
863            parts.append(('_pd', 0.0))
864            parts.append(('_pd_n', 0))
865            parts.append(('_pd_nsigma', 3.0))
866            parts.append(('_pd_type', "gaussian"))
867        for ext, val in parts:
868            if p.length > 1:
869                dict(("%s%d%s" % (p.id, k, ext), val)
870                     for k in range(1, p.length+1))
871            else:
872                pars[p.id + ext] = val
873
874    # Plug in values given in demo
875    if use_demo:
876        pars.update(model_info.demo)
877    return pars
878
879INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
880def isnumber(str):
881    match = FLOAT_RE.match(str)
882    isfloat = (match and not str[match.end():])
883    return isfloat or INTEGER_RE.match(str)
884
885# For distinguishing pairs of models for comparison
886# key-value pair separator =
887# shell characters  | & ; <> $ % ' " \ # `
888# model and parameter names _
889# parameter expressions - + * / . ( )
890# path characters including tilde expansion and windows drive ~ / :
891# not sure about brackets [] {}
892# maybe one of the following @ ? ^ ! ,
893MODEL_SPLIT = ','
894def parse_opts(argv):
895    # type: (List[str]) -> Dict[str, Any]
896    """
897    Parse command line options.
898    """
899    MODELS = core.list_models()
900    flags = [arg for arg in argv
901             if arg.startswith('-')]
902    values = [arg for arg in argv
903              if not arg.startswith('-') and '=' in arg]
904    positional_args = [arg for arg in argv
905            if not arg.startswith('-') and '=' not in arg]
906    models = "\n    ".join("%-15s"%v for v in MODELS)
907    if len(positional_args) == 0:
908        print(USAGE)
909        print("\nAvailable models:")
910        print(columnize(MODELS, indent="  "))
911        return None
912    if len(positional_args) > 3:
913        print("expected parameters: model N1 N2")
914
915    invalid = [o[1:] for o in flags
916               if o[1:] not in NAME_OPTIONS
917               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
918    if invalid:
919        print("Invalid options: %s"%(", ".join(invalid)))
920        return None
921
922    name = positional_args[0]
923    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
924    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
925
926    # pylint: disable=bad-whitespace
927    # Interpret the flags
928    opts = {
929        'plot'      : True,
930        'view'      : 'log',
931        'is2d'      : False,
932        'qmax'      : 0.05,
933        'nq'        : 128,
934        'res'       : 0.0,
935        'accuracy'  : 'Low',
936        'cutoff'    : 0.0,
937        'seed'      : -1,  # default to preset
938        'mono'      : False,
939        # Default to magnetic a magnetic moment is set on the command line
940        'magnetic'  : False,
941        'show_pars' : False,
942        'show_hist' : False,
943        'rel_err'   : True,
944        'explore'   : False,
945        'use_demo'  : True,
946        'zero'      : False,
947        'html'      : False,
948        'title'     : None,
949    }
950    engines = []
951    for arg in flags:
952        if arg == '-noplot':    opts['plot'] = False
953        elif arg == '-plot':    opts['plot'] = True
954        elif arg == '-linear':  opts['view'] = 'linear'
955        elif arg == '-log':     opts['view'] = 'log'
956        elif arg == '-q4':      opts['view'] = 'q4'
957        elif arg == '-1d':      opts['is2d'] = False
958        elif arg == '-2d':      opts['is2d'] = True
959        elif arg == '-exq':     opts['qmax'] = 10.0
960        elif arg == '-highq':   opts['qmax'] = 1.0
961        elif arg == '-midq':    opts['qmax'] = 0.2
962        elif arg == '-lowq':    opts['qmax'] = 0.05
963        elif arg == '-zero':    opts['zero'] = True
964        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
965        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
966        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
967        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
968        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
969        elif arg.startswith('-title'):     opts['title'] = arg[7:]
970        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
971        elif arg == '-preset':  opts['seed'] = -1
972        elif arg == '-mono':    opts['mono'] = True
973        elif arg == '-poly':    opts['mono'] = False
974        elif arg == '-magnetic':       opts['magnetic'] = True
975        elif arg == '-nonmagnetic':    opts['magnetic'] = False
976        elif arg == '-pars':    opts['show_pars'] = True
977        elif arg == '-nopars':  opts['show_pars'] = False
978        elif arg == '-hist':    opts['show_hist'] = True
979        elif arg == '-nohist':  opts['show_hist'] = False
980        elif arg == '-rel':     opts['rel_err'] = True
981        elif arg == '-abs':     opts['rel_err'] = False
982        elif arg == '-half':    engines.append(arg[1:])
983        elif arg == '-fast':    engines.append(arg[1:])
984        elif arg == '-single':  engines.append(arg[1:])
985        elif arg == '-double':  engines.append(arg[1:])
986        elif arg == '-single!': engines.append(arg[1:])
987        elif arg == '-double!': engines.append(arg[1:])
988        elif arg == '-quad!':   engines.append(arg[1:])
989        elif arg == '-sasview': engines.append(arg[1:])
990        elif arg == '-edit':    opts['explore'] = True
991        elif arg == '-demo':    opts['use_demo'] = True
992        elif arg == '-default':    opts['use_demo'] = False
993        elif arg == '-html':    opts['html'] = True
994    # pylint: enable=bad-whitespace
995
996    if MODEL_SPLIT in name:
997        name, name2 = name.split(MODEL_SPLIT, 2)
998    else:
999        name2 = name
1000    try:
1001        model_info = core.load_model_info(name)
1002        model_info2 = core.load_model_info(name2) if name2 != name else model_info
1003    except ImportError as exc:
1004        print(str(exc))
1005        print("Could not find model; use one of:\n    " + models)
1006        return None
1007
1008    # Get demo parameters from model definition, or use default parameters
1009    # if model does not define demo parameters
1010    pars = get_pars(model_info, opts['use_demo'])
1011    pars2 = get_pars(model_info2, opts['use_demo'])
1012    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1013    # randomize parameters
1014    #pars.update(set_pars)  # set value before random to control range
1015    if opts['seed'] > -1:
1016        pars = randomize_pars(model_info, pars, seed=opts['seed'])
1017        if model_info != model_info2:
1018            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
1019            # Share values for parameters with the same name
1020            for k, v in pars.items():
1021                if k in pars2:
1022                    pars2[k] = v
1023        else:
1024            pars2 = pars.copy()
1025        constrain_pars(model_info, pars)
1026        constrain_pars(model_info2, pars2)
1027        print("Randomize using -random=%i"%opts['seed'])
1028    if opts['mono']:
1029        pars = suppress_pd(pars)
1030        pars2 = suppress_pd(pars2)
1031    if not opts['magnetic']:
1032        pars = suppress_magnetism(pars)
1033        pars2 = suppress_magnetism(pars2)
1034
1035    # Fill in parameters given on the command line
1036    presets = {}
1037    presets2 = {}
1038    for arg in values:
1039        k, v = arg.split('=', 1)
1040        if k not in pars and k not in pars2:
1041            # extract base name without polydispersity info
1042            s = set(p.split('_pd')[0] for p in pars)
1043            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1044            return None
1045        v1, v2 = v.split(MODEL_SPLIT, 2) if MODEL_SPLIT in v else (v,v)
1046        if v1 and k in pars:
1047            presets[k] = float(v1) if isnumber(v1) else v1
1048        if v2 and k in pars2:
1049            presets2[k] = float(v2) if isnumber(v2) else v2
1050
1051    # If pd given on the command line, default pd_n to 35
1052    for k, v in list(presets.items()):
1053        if k.endswith('_pd'):
1054            presets.setdefault(k+'_n', 35.)
1055    for k, v in list(presets2.items()):
1056        if k.endswith('_pd'):
1057            presets2.setdefault(k+'_n', 35.)
1058
1059    # Evaluate preset parameter expressions
1060    context = MATH.copy()
1061    context['np'] = np
1062    context.update(pars)
1063    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
1064    for k, v in presets.items():
1065        if not isinstance(v, float) and not k.endswith('_type'):
1066            presets[k] = eval(v, context)
1067    context.update(presets)
1068    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
1069    for k, v in presets2.items():
1070        if not isinstance(v, float) and not k.endswith('_type'):
1071            presets2[k] = eval(v, context)
1072
1073    # update parameters with presets
1074    pars.update(presets)  # set value after random to control value
1075    pars2.update(presets2)  # set value after random to control value
1076    #import pprint; pprint.pprint(model_info)
1077
1078    same_model = name == name2 and pars == pars
1079    if len(engines) == 0:
1080        if same_model:
1081            engines.extend(['single', 'double'])
1082        else:
1083            engines.extend(['single', 'single'])
1084    elif len(engines) == 1:
1085        if not same_model:
1086            engines.append(engines[0])
1087        elif engines[0] == 'double':
1088            engines.append('single')
1089        else:
1090            engines.append('double')
1091    elif len(engines) > 2:
1092        del engines[2:]
1093
1094    use_sasview = any(engine == 'sasview' and count > 0
1095                      for engine, count in zip(engines, [n1, n2]))
1096    if use_sasview:
1097        constrain_new_to_old(model_info, pars)
1098        constrain_new_to_old(model_info2, pars2)
1099
1100    if opts['show_pars']:
1101        if not same_model:
1102            print("==== %s ====="%model_info.name)
1103            print(str(parlist(model_info, pars, opts['is2d'])))
1104            print("==== %s ====="%model_info2.name)
1105            print(str(parlist(model_info2, pars2, opts['is2d'])))
1106        else:
1107            print(str(parlist(model_info, pars, opts['is2d'])))
1108
1109    # Create the computational engines
1110    data, _ = make_data(opts)
1111    if n1:
1112        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1113    else:
1114        base = None
1115    if n2:
1116        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1117    else:
1118        comp = None
1119
1120    # pylint: disable=bad-whitespace
1121    # Remember it all
1122    opts.update({
1123        'data'      : data,
1124        'name'      : [name, name2],
1125        'def'       : [model_info, model_info2],
1126        'count'     : [n1, n2],
1127        'presets'   : [presets, presets2],
1128        'pars'      : [pars, pars2],
1129        'engines'   : [base, comp],
1130    })
1131    # pylint: enable=bad-whitespace
1132
1133    return opts
1134
1135def show_docs(opts):
1136    # type: (Dict[str, Any]) -> None
1137    """
1138    show html docs for the model
1139    """
1140    import wx  # type: ignore
1141    from .generate import view_html_from_info
1142    app = wx.App() if wx.GetApp() is None else None
1143    view_html_from_info(opts['def'][0])
1144    if app: app.MainLoop()
1145
1146
1147def explore(opts):
1148    # type: (Dict[str, Any]) -> None
1149    """
1150    explore the model using the bumps gui.
1151    """
1152    import wx  # type: ignore
1153    from bumps.names import FitProblem  # type: ignore
1154    from bumps.gui.app_frame import AppFrame  # type: ignore
1155    from bumps.gui import signal
1156
1157    is_mac = "cocoa" in wx.version()
1158    # Create an app if not running embedded
1159    app = wx.App() if wx.GetApp() is None else None
1160    model = Explore(opts)
1161    problem = FitProblem(model)
1162    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1163    if not is_mac: frame.Show()
1164    frame.panel.set_model(model=problem)
1165    frame.panel.Layout()
1166    frame.panel.aui.Split(0, wx.TOP)
1167    def reset_parameters(event):
1168        model.revert_values()
1169        signal.update_parameters(problem)
1170    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1171    if is_mac: frame.Show()
1172    # If running withing an app, start the main loop
1173    if app: app.MainLoop()
1174
1175class Explore(object):
1176    """
1177    Bumps wrapper for a SAS model comparison.
1178
1179    The resulting object can be used as a Bumps fit problem so that
1180    parameters can be adjusted in the GUI, with plots updated on the fly.
1181    """
1182    def __init__(self, opts):
1183        # type: (Dict[str, Any]) -> None
1184        from bumps.cli import config_matplotlib  # type: ignore
1185        from . import bumps_model
1186        config_matplotlib()
1187        self.opts = opts
1188        p1, p2 = opts['pars']
1189        m1, m2 = opts['def']
1190        self.fix_p2 = m1 != m2 or p1 != p2
1191        model_info = m1
1192        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1193        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1194        if not opts['is2d']:
1195            for p in model_info.parameters.user_parameters({}, is2d=False):
1196                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1197                    k = p.name+ext
1198                    v = pars.get(k, None)
1199                    if v is not None:
1200                        v.range(*parameter_range(k, v.value))
1201        else:
1202            for k, v in pars.items():
1203                v.range(*parameter_range(k, v.value))
1204
1205        self.pars = pars
1206        self.starting_values = dict((k, v.value) for k, v in pars.items())
1207        self.pd_types = pd_types
1208        self.limits = None
1209
1210    def revert_values(self):
1211        for k, v in self.starting_values.items():
1212            self.pars[k].value = v
1213
1214    def model_update(self):
1215        pass
1216
1217    def numpoints(self):
1218        # type: () -> int
1219        """
1220        Return the number of points.
1221        """
1222        return len(self.pars) + 1  # so dof is 1
1223
1224    def parameters(self):
1225        # type: () -> Any   # Dict/List hierarchy of parameters
1226        """
1227        Return a dictionary of parameters.
1228        """
1229        return self.pars
1230
1231    def nllf(self):
1232        # type: () -> float
1233        """
1234        Return cost.
1235        """
1236        # pylint: disable=no-self-use
1237        return 0.  # No nllf
1238
1239    def plot(self, view='log'):
1240        # type: (str) -> None
1241        """
1242        Plot the data and residuals.
1243        """
1244        pars = dict((k, v.value) for k, v in self.pars.items())
1245        pars.update(self.pd_types)
1246        self.opts['pars'][0] = pars
1247        if not self.fix_p2:
1248            self.opts['pars'][1] = pars
1249        result = run_models(self.opts)
1250        limits = plot_models(self.opts, result, limits=self.limits)
1251        if self.limits is None:
1252            vmin, vmax = limits
1253            self.limits = vmax*1e-7, 1.3*vmax
1254            import pylab; pylab.clf()
1255            plot_models(self.opts, result, limits=self.limits)
1256
1257
1258def main(*argv):
1259    # type: (*str) -> None
1260    """
1261    Main program.
1262    """
1263    opts = parse_opts(argv)
1264    if opts is not None:
1265        if opts['html']:
1266            show_docs(opts)
1267        elif opts['explore']:
1268            explore(opts)
1269        else:
1270            compare(opts)
1271
1272if __name__ == "__main__":
1273    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.