source: sasmodels/sasmodels/compare.py @ 8c65a33

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 8c65a33 was 8c65a33, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

change model-model comparison to use , because : is the windows drive separator, and models may be specified with the full path

  • Property mode set to 100755
File size: 43.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model or model1,model2 are the names of the models to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -magnetic/-nonmagnetic* suppress magnetism
76    -cutoff=1e-5* cutoff value for including a point in polydispersity
77    -pars/-nopars* prints the parameter set or not
78    -abs/-rel* plot relative or absolute error
79    -linear/-log*/-q4 intensity scaling
80    -hist/-nohist* plot histogram of relative error
81    -res=0 sets the resolution width dQ/Q if calculating with resolution
82    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
83    -edit starts the parameter explorer
84    -default/-demo* use demo vs default parameters
85    -html shows the model docs instead of running the model
86    -title="note" adds note to the plot title, after the model name
87
88Any two calculation engines can be selected for comparison:
89
90    -single/-double/-half/-fast sets an OpenCL calculation engine
91    -single!/-double!/-quad! sets an OpenMP calculation engine
92    -sasview sets the sasview calculation engine
93
94The default is -single -double.  Note that the interpretation of quad
95precision depends on architecture, and may vary from 64-bit to 128-bit,
96with 80-bit floats being common (1e-19 precision).
97
98Key=value pairs allow you to set specific values for the model parameters.
99Key=value1,value2 to compare different values of the same parameter.
100value can be an expression including other parameters
101"""
102
103# Update docs with command line usage string.   This is separate from the usual
104# doc string so that we can display it at run time if there is an error.
105# lin
106__doc__ = (__doc__  # pylint: disable=redefined-builtin
107           + """
108Program description
109-------------------
110
111"""
112           + USAGE)
113
114kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
115
116# list of math functions for use in evaluating parameters
117MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
118
119# CRUFT python 2.6
120if not hasattr(datetime.timedelta, 'total_seconds'):
121    def delay(dt):
122        """Return number date-time delta as number seconds"""
123        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
124else:
125    def delay(dt):
126        """Return number date-time delta as number seconds"""
127        return dt.total_seconds()
128
129
130class push_seed(object):
131    """
132    Set the seed value for the random number generator.
133
134    When used in a with statement, the random number generator state is
135    restored after the with statement is complete.
136
137    :Parameters:
138
139    *seed* : int or array_like, optional
140        Seed for RandomState
141
142    :Example:
143
144    Seed can be used directly to set the seed::
145
146        >>> from numpy.random import randint
147        >>> push_seed(24)
148        <...push_seed object at...>
149        >>> print(randint(0,1000000,3))
150        [242082    899 211136]
151
152    Seed can also be used in a with statement, which sets the random
153    number generator state for the enclosed computations and restores
154    it to the previous state on completion::
155
156        >>> with push_seed(24):
157        ...    print(randint(0,1000000,3))
158        [242082    899 211136]
159
160    Using nested contexts, we can demonstrate that state is indeed
161    restored after the block completes::
162
163        >>> with push_seed(24):
164        ...    print(randint(0,1000000))
165        ...    with push_seed(24):
166        ...        print(randint(0,1000000,3))
167        ...    print(randint(0,1000000))
168        242082
169        [242082    899 211136]
170        899
171
172    The restore step is protected against exceptions in the block::
173
174        >>> with push_seed(24):
175        ...    print(randint(0,1000000))
176        ...    try:
177        ...        with push_seed(24):
178        ...            print(randint(0,1000000,3))
179        ...            raise Exception()
180        ...    except Exception:
181        ...        print("Exception raised")
182        ...    print(randint(0,1000000))
183        242082
184        [242082    899 211136]
185        Exception raised
186        899
187    """
188    def __init__(self, seed=None):
189        # type: (Optional[int]) -> None
190        self._state = np.random.get_state()
191        np.random.seed(seed)
192
193    def __enter__(self):
194        # type: () -> None
195        pass
196
197    def __exit__(self, exc_type, exc_value, traceback):
198        # type: (Any, BaseException, Any) -> None
199        # TODO: better typing for __exit__ method
200        np.random.set_state(self._state)
201
202def tic():
203    # type: () -> Callable[[], float]
204    """
205    Timer function.
206
207    Use "toc=tic()" to start the clock and "toc()" to measure
208    a time interval.
209    """
210    then = datetime.datetime.now()
211    return lambda: delay(datetime.datetime.now() - then)
212
213
214def set_beam_stop(data, radius, outer=None):
215    # type: (Data, float, float) -> None
216    """
217    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
218
219    Note: this function does not require sasview
220    """
221    if hasattr(data, 'qx_data'):
222        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
223        data.mask = (q < radius)
224        if outer is not None:
225            data.mask |= (q >= outer)
226    else:
227        data.mask = (data.x < radius)
228        if outer is not None:
229            data.mask |= (data.x >= outer)
230
231
232def parameter_range(p, v):
233    # type: (str, float) -> Tuple[float, float]
234    """
235    Choose a parameter range based on parameter name and initial value.
236    """
237    # process the polydispersity options
238    if p.endswith('_pd_n'):
239        return 0., 100.
240    elif p.endswith('_pd_nsigma'):
241        return 0., 5.
242    elif p.endswith('_pd_type'):
243        raise ValueError("Cannot return a range for a string value")
244    elif any(s in p for s in ('theta', 'phi', 'psi')):
245        # orientation in [-180,180], orientation pd in [0,45]
246        if p.endswith('_pd'):
247            return 0., 45.
248        else:
249            return -180., 180.
250    elif p.endswith('_pd'):
251        return 0., 1.
252    elif 'sld' in p:
253        return -0.5, 10.
254    elif p == 'background':
255        return 0., 10.
256    elif p == 'scale':
257        return 0., 1.e3
258    elif v < 0.:
259        return 2.*v, -2.*v
260    else:
261        return 0., (2.*v if v > 0. else 1.)
262
263
264def _randomize_one(model_info, p, v):
265    # type: (ModelInfo, str, float) -> float
266    # type: (ModelInfo, str, str) -> str
267    """
268    Randomize a single parameter.
269    """
270    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
271        return v
272
273    # Find the parameter definition
274    for par in model_info.parameters.call_parameters:
275        if par.name == p:
276            break
277    else:
278        raise ValueError("unknown parameter %r"%p)
279
280    if len(par.limits) > 2:  # choice list
281        return np.random.randint(len(par.limits))
282
283    limits = par.limits
284    if not np.isfinite(limits).all():
285        low, high = parameter_range(p, v)
286        limits = (max(limits[0], low), min(limits[1], high))
287
288    return np.random.uniform(*limits)
289
290
291def randomize_pars(model_info, pars, seed=None):
292    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
293    """
294    Generate random values for all of the parameters.
295
296    Valid ranges for the random number generator are guessed from the name of
297    the parameter; this will not account for constraints such as cap radius
298    greater than cylinder radius in the capped_cylinder model, so
299    :func:`constrain_pars` needs to be called afterward..
300    """
301    with push_seed(seed):
302        # Note: the sort guarantees order `of calls to random number generator
303        random_pars = dict((p, _randomize_one(model_info, p, v))
304                           for p, v in sorted(pars.items()))
305    return random_pars
306
307def constrain_pars(model_info, pars):
308    # type: (ModelInfo, ParameterSet) -> None
309    """
310    Restrict parameters to valid values.
311
312    This includes model specific code for models such as capped_cylinder
313    which need to support within model constraints (cap radius more than
314    cylinder radius in this case).
315
316    Warning: this updates the *pars* dictionary in place.
317    """
318    name = model_info.id
319    # if it is a product model, then just look at the form factor since
320    # none of the structure factors need any constraints.
321    if '*' in name:
322        name = name.split('*')[0]
323
324    if name == 'barbell':
325        if pars['radius_bell'] < pars['radius']:
326            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
327
328    elif name == 'capped_cylinder':
329        if pars['radius_cap'] < pars['radius']:
330            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
331
332    elif name == 'guinier':
333        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
334        #q_max = 0.2  # mid q maximum
335        q_max = 1.0  # high q maximum
336        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
337        pars['rg'] = min(pars['rg'], rg_max)
338
339    elif name == 'pearl_necklace':
340        if pars['radius'] < pars['thick_string']:
341            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
342        pars['num_pearls'] = math.ceil(pars['num_pearls'])
343        pass
344
345    elif name == 'rpa':
346        # Make sure phi sums to 1.0
347        if pars['case_num'] < 2:
348            pars['Phi1'] = 0.
349            pars['Phi2'] = 0.
350        elif pars['case_num'] < 5:
351            pars['Phi1'] = 0.
352        total = sum(pars['Phi'+c] for c in '1234')
353        for c in '1234':
354            pars['Phi'+c] /= total
355
356    elif name == 'stacked_disks':
357        pars['n_stacking'] = math.ceil(pars['n_stacking'])
358
359def parlist(model_info, pars, is2d):
360    # type: (ModelInfo, ParameterSet, bool) -> str
361    """
362    Format the parameter list for printing.
363    """
364    lines = []
365    parameters = model_info.parameters
366    magnetic = False
367    for p in parameters.user_parameters(pars, is2d):
368        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
369            continue
370        if p.id.startswith('up:') and not magnetic:
371            continue
372        fields = dict(
373            value=pars.get(p.id, p.default),
374            pd=pars.get(p.id+"_pd", 0.),
375            n=int(pars.get(p.id+"_pd_n", 0)),
376            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
377            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
378            relative_pd=p.relative_pd,
379            M0=pars.get('M0:'+p.id, 0.),
380            mphi=pars.get('mphi:'+p.id, 0.),
381            mtheta=pars.get('mtheta:'+p.id, 0.),
382        )
383        lines.append(_format_par(p.name, **fields))
384        magnetic = magnetic or fields['M0'] != 0.
385    return "\n".join(lines)
386
387    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
388
389def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
390                relative_pd=False, M0=0., mphi=0., mtheta=0.):
391    # type: (str, float, float, int, float, str) -> str
392    line = "%s: %g"%(name, value)
393    if pd != 0.  and n != 0:
394        if relative_pd:
395            pd *= value
396        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
397                % (pd, n, nsigma, nsigma, pdtype)
398    if M0 != 0.:
399        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
400    return line
401
402def suppress_pd(pars):
403    # type: (ParameterSet) -> ParameterSet
404    """
405    Suppress theta_pd for now until the normalization is resolved.
406
407    May also suppress complete polydispersity of the model to test
408    models more quickly.
409    """
410    pars = pars.copy()
411    for p in pars:
412        if p.endswith("_pd_n"): pars[p] = 0
413    return pars
414
415def suppress_magnetism(pars):
416    # type: (ParameterSet) -> ParameterSet
417    """
418    Suppress theta_pd for now until the normalization is resolved.
419
420    May also suppress complete polydispersity of the model to test
421    models more quickly.
422    """
423    pars = pars.copy()
424    for p in pars:
425        if p.startswith("M0:"): pars[p] = 0
426    return pars
427
428def eval_sasview(model_info, data):
429    # type: (Modelinfo, Data) -> Calculator
430    """
431    Return a model calculator using the pre-4.0 SasView models.
432    """
433    # importing sas here so that the error message will be that sas failed to
434    # import rather than the more obscure smear_selection not imported error
435    import sas
436    import sas.models
437    from sas.models.qsmearing import smear_selection
438    from sas.models.MultiplicationModel import MultiplicationModel
439    from sas.models.dispersion_models import models as dispersers
440
441    def get_model_class(name):
442        # type: (str) -> "sas.models.BaseComponent"
443        #print("new",sorted(_pars.items()))
444        __import__('sas.models.' + name)
445        ModelClass = getattr(getattr(sas.models, name, None), name, None)
446        if ModelClass is None:
447            raise ValueError("could not find model %r in sas.models"%name)
448        return ModelClass
449
450    # WARNING: ugly hack when handling model!
451    # Sasview models with multiplicity need to be created with the target
452    # multiplicity, so we cannot create the target model ahead of time for
453    # for multiplicity models.  Instead we store the model in a list and
454    # update the first element of that list with the new multiplicity model
455    # every time we evaluate.
456
457    # grab the sasview model, or create it if it is a product model
458    if model_info.composition:
459        composition_type, parts = model_info.composition
460        if composition_type == 'product':
461            P, S = [get_model_class(revert_name(p))() for p in parts]
462            model = [MultiplicationModel(P, S)]
463        else:
464            raise ValueError("sasview mixture models not supported by compare")
465    else:
466        old_name = revert_name(model_info)
467        if old_name is None:
468            raise ValueError("model %r does not exist in old sasview"
469                            % model_info.id)
470        ModelClass = get_model_class(old_name)
471        model = [ModelClass()]
472    model[0].disperser_handles = {}
473
474    # build a smearer with which to call the model, if necessary
475    smearer = smear_selection(data, model=model)
476    if hasattr(data, 'qx_data'):
477        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
478        index = ((~data.mask) & (~np.isnan(data.data))
479                 & (q >= data.qmin) & (q <= data.qmax))
480        if smearer is not None:
481            smearer.model = model  # because smear_selection has a bug
482            smearer.accuracy = data.accuracy
483            smearer.set_index(index)
484            def _call_smearer():
485                smearer.model = model[0]
486                return smearer.get_value()
487            theory = _call_smearer
488        else:
489            theory = lambda: model[0].evalDistribution([data.qx_data[index],
490                                                        data.qy_data[index]])
491    elif smearer is not None:
492        theory = lambda: smearer(model[0].evalDistribution(data.x))
493    else:
494        theory = lambda: model[0].evalDistribution(data.x)
495
496    def calculator(**pars):
497        # type: (float, ...) -> np.ndarray
498        """
499        Sasview calculator for model.
500        """
501        oldpars = revert_pars(model_info, pars)
502        # For multiplicity models, create a model with the correct multiplicity
503        control = oldpars.pop("CONTROL", None)
504        if control is not None:
505            # sphericalSLD has one fewer multiplicity.  This update should
506            # happen in revert_pars, but it hasn't been called yet.
507            model[0] = ModelClass(control)
508        # paying for parameter conversion each time to keep life simple, if not fast
509        for k, v in oldpars.items():
510            if k.endswith('.type'):
511                par = k[:-5]
512                if v == 'gaussian': continue
513                cls = dispersers[v if v != 'rectangle' else 'rectangula']
514                handle = cls()
515                model[0].disperser_handles[par] = handle
516                try:
517                    model[0].set_dispersion(par, handle)
518                except Exception:
519                    exception.annotate_exception("while setting %s to %r"
520                                                 %(par, v))
521                    raise
522
523
524        #print("sasview pars",oldpars)
525        for k, v in oldpars.items():
526            name_attr = k.split('.')  # polydispersity components
527            if len(name_attr) == 2:
528                par, disp_par = name_attr
529                model[0].dispersion[par][disp_par] = v
530            else:
531                model[0].setParam(k, v)
532        return theory()
533
534    calculator.engine = "sasview"
535    return calculator
536
537DTYPE_MAP = {
538    'half': '16',
539    'fast': 'fast',
540    'single': '32',
541    'double': '64',
542    'quad': '128',
543    'f16': '16',
544    'f32': '32',
545    'f64': '64',
546    'longdouble': '128',
547}
548def eval_opencl(model_info, data, dtype='single', cutoff=0.):
549    # type: (ModelInfo, Data, str, float) -> Calculator
550    """
551    Return a model calculator using the OpenCL calculation engine.
552    """
553    if not core.HAVE_OPENCL:
554        raise RuntimeError("OpenCL not available")
555    model = core.build_model(model_info, dtype=dtype, platform="ocl")
556    calculator = DirectModel(data, model, cutoff=cutoff)
557    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
558    return calculator
559
560def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
561    # type: (ModelInfo, Data, str, float) -> Calculator
562    """
563    Return a model calculator using the DLL calculation engine.
564    """
565    model = core.build_model(model_info, dtype=dtype, platform="dll")
566    calculator = DirectModel(data, model, cutoff=cutoff)
567    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
568    return calculator
569
570def time_calculation(calculator, pars, evals=1):
571    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
572    """
573    Compute the average calculation time over N evaluations.
574
575    An additional call is generated without polydispersity in order to
576    initialize the calculation engine, and make the average more stable.
577    """
578    # initialize the code so time is more accurate
579    if evals > 1:
580        calculator(**suppress_pd(pars))
581    toc = tic()
582    # make sure there is at least one eval
583    value = calculator(**pars)
584    for _ in range(evals-1):
585        value = calculator(**pars)
586    average_time = toc()*1000. / evals
587    #print("I(q)",value)
588    return value, average_time
589
590def make_data(opts):
591    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
592    """
593    Generate an empty dataset, used with the model to set Q points
594    and resolution.
595
596    *opts* contains the options, with 'qmax', 'nq', 'res',
597    'accuracy', 'is2d' and 'view' parsed from the command line.
598    """
599    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
600    if opts['is2d']:
601        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
602        data = empty_data2D(q, resolution=res)
603        data.accuracy = opts['accuracy']
604        set_beam_stop(data, 0.0004)
605        index = ~data.mask
606    else:
607        if opts['view'] == 'log' and not opts['zero']:
608            qmax = math.log10(qmax)
609            q = np.logspace(qmax-3, qmax, nq)
610        else:
611            q = np.linspace(0.001*qmax, qmax, nq)
612        if opts['zero']:
613            q = np.hstack((0, q))
614        data = empty_data1D(q, resolution=res)
615        index = slice(None, None)
616    return data, index
617
618def make_engine(model_info, data, dtype, cutoff):
619    # type: (ModelInfo, Data, str, float) -> Calculator
620    """
621    Generate the appropriate calculation engine for the given datatype.
622
623    Datatypes with '!' appended are evaluated using external C DLLs rather
624    than OpenCL.
625    """
626    if dtype == 'sasview':
627        return eval_sasview(model_info, data)
628    elif dtype.endswith('!'):
629        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
630    else:
631        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
632
633def _show_invalid(data, theory):
634    # type: (Data, np.ma.ndarray) -> None
635    """
636    Display a list of the non-finite values in theory.
637    """
638    if not theory.mask.any():
639        return
640
641    if hasattr(data, 'x'):
642        bad = zip(data.x[theory.mask], theory[theory.mask])
643        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
644
645
646def compare(opts, limits=None):
647    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
648    """
649    Preform a comparison using options from the command line.
650
651    *limits* are the limits on the values to use, either to set the y-axis
652    for 1D or to set the colormap scale for 2D.  If None, then they are
653    inferred from the data and returned. When exploring using Bumps,
654    the limits are set when the model is initially called, and maintained
655    as the values are adjusted, making it easier to see the effects of the
656    parameters.
657    """
658    result = run_models(opts, verbose=True)
659    if opts['plot']:  # Note: never called from explore
660        plot_models(opts, result, limits=limits)
661
662def run_models(opts, verbose=False):
663    # type: (Dict[str, Any]) -> Dict[str, Any]
664
665    n_base, n_comp = opts['count']
666    pars, pars2 = opts['pars']
667    data = opts['data']
668
669    # silence the linter
670    base = opts['engines'][0] if n_base else None
671    comp = opts['engines'][1] if n_comp else None
672
673    base_time = comp_time = None
674    base_value = comp_value = resid = relerr = None
675
676    # Base calculation
677    if n_base > 0:
678        try:
679            base_raw, base_time = time_calculation(base, pars, n_base)
680            base_value = np.ma.masked_invalid(base_raw)
681            if verbose:
682                print("%s t=%.2f ms, intensity=%.0f"
683                      % (base.engine, base_time, base_value.sum()))
684            _show_invalid(data, base_value)
685        except ImportError:
686            traceback.print_exc()
687            n_base = 0
688
689    # Comparison calculation
690    if n_comp > 0:
691        try:
692            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
693            comp_value = np.ma.masked_invalid(comp_raw)
694            if verbose:
695                print("%s t=%.2f ms, intensity=%.0f"
696                      % (comp.engine, comp_time, comp_value.sum()))
697            _show_invalid(data, comp_value)
698        except ImportError:
699            traceback.print_exc()
700            n_comp = 0
701
702    # Compare, but only if computing both forms
703    if n_base > 0 and n_comp > 0:
704        resid = (base_value - comp_value)
705        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
706        if verbose:
707            _print_stats("|%s-%s|"
708                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
709                         resid)
710            _print_stats("|(%s-%s)/%s|"
711                         % (base.engine, comp.engine, comp.engine),
712                         relerr)
713
714    return dict(base_value=base_value, comp_value=comp_value,
715                base_time=base_time, comp_time=comp_time,
716                resid=resid, relerr=relerr)
717
718
719def _print_stats(label, err):
720    # type: (str, np.ma.ndarray) -> None
721    # work with trimmed data, not the full set
722    sorted_err = np.sort(abs(err.compressed()))
723    if len(sorted_err) == 0.:
724        print(label + "  no valid values")
725        return
726
727    p50 = int((len(sorted_err)-1)*0.50)
728    p98 = int((len(sorted_err)-1)*0.98)
729    data = [
730        "max:%.3e"%sorted_err[-1],
731        "median:%.3e"%sorted_err[p50],
732        "98%%:%.3e"%sorted_err[p98],
733        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
734        "zero-offset:%+.3e"%np.mean(sorted_err),
735        ]
736    print(label+"  "+"  ".join(data))
737
738
739def plot_models(opts, result, limits=None):
740    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
741    base_value, comp_value= result['base_value'], result['comp_value']
742    base_time, comp_time = result['base_time'], result['comp_time']
743    resid, relerr = result['resid'], result['relerr']
744
745    have_base, have_comp = (base_value is not None), (comp_value is not None)
746    base = opts['engines'][0] if have_base else None
747    comp = opts['engines'][1] if have_comp else None
748    data = opts['data']
749
750    # Plot if requested
751    view = opts['view']
752    import matplotlib.pyplot as plt
753    if limits is None:
754        vmin, vmax = np.Inf, -np.Inf
755        if have_base:
756            vmin = min(vmin, base_value.min())
757            vmax = max(vmax, base_value.max())
758        if have_comp:
759            vmin = min(vmin, comp_value.min())
760            vmax = max(vmax, comp_value.max())
761        limits = vmin, vmax
762
763    if have_base:
764        if have_comp: plt.subplot(131)
765        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
766        plt.title("%s t=%.2f ms"%(base.engine, base_time))
767        #cbar_title = "log I"
768    if have_comp:
769        if have_base: plt.subplot(132)
770        if not opts['is2d'] and have_base:
771            plot_theory(data, base_value, view=view, use_data=False, limits=limits)
772        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
773        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
774        #cbar_title = "log I"
775    if have_base and have_comp:
776        plt.subplot(133)
777        if not opts['rel_err']:
778            err, errstr, errview = resid, "abs err", "linear"
779        else:
780            err, errstr, errview = abs(relerr), "rel err", "log"
781        if 0:  # 95% cutoff
782            sorted = np.sort(err.flatten())
783            cutoff = sorted[int(sorted.size*0.95)]
784            err[err>cutoff] = cutoff
785        #err,errstr = base/comp,"ratio"
786        plot_theory(data, None, resid=err, view=errview, use_data=False)
787        if view == 'linear':
788            plt.xscale('linear')
789        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
790        #cbar_title = errstr if errview=="linear" else "log "+errstr
791    #if is2D:
792    #    h = plt.colorbar()
793    #    h.ax.set_title(cbar_title)
794    fig = plt.gcf()
795    extra_title = ' '+opts['title'] if opts['title'] else ''
796    fig.suptitle(":".join(opts['name']) + extra_title)
797
798    if have_base and have_comp and opts['show_hist']:
799        plt.figure()
800        v = relerr
801        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
802        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
803        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
804                   % (base.engine, comp.engine, comp.engine))
805        plt.ylabel('P(err)')
806        plt.title('Distribution of relative error between calculation engines')
807
808    if not opts['explore']:
809        plt.show()
810
811    return limits
812
813
814
815
816# ===========================================================================
817#
818NAME_OPTIONS = set([
819    'plot', 'noplot',
820    'half', 'fast', 'single', 'double',
821    'single!', 'double!', 'quad!', 'sasview',
822    'lowq', 'midq', 'highq', 'exq', 'zero',
823    '2d', '1d',
824    'preset', 'random',
825    'poly', 'mono',
826    'magnetic', 'nonmagnetic',
827    'nopars', 'pars',
828    'rel', 'abs',
829    'linear', 'log', 'q4',
830    'hist', 'nohist',
831    'edit', 'html',
832    'demo', 'default',
833    ])
834VALUE_OPTIONS = [
835    # Note: random is both a name option and a value option
836    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
837    ]
838
839def columnize(items, indent="", width=79):
840    # type: (List[str], str, int) -> str
841    """
842    Format a list of strings into columns.
843
844    Returns a string with carriage returns ready for printing.
845    """
846    column_width = max(len(w) for w in items) + 1
847    num_columns = (width - len(indent)) // column_width
848    num_rows = len(items) // num_columns
849    items = items + [""] * (num_rows * num_columns - len(items))
850    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
851    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
852             for row in zip(*columns)]
853    output = indent + ("\n"+indent).join(lines)
854    return output
855
856
857def get_pars(model_info, use_demo=False):
858    # type: (ModelInfo, bool) -> ParameterSet
859    """
860    Extract demo parameters from the model definition.
861    """
862    # Get the default values for the parameters
863    pars = {}
864    for p in model_info.parameters.call_parameters:
865        parts = [('', p.default)]
866        if p.polydisperse:
867            parts.append(('_pd', 0.0))
868            parts.append(('_pd_n', 0))
869            parts.append(('_pd_nsigma', 3.0))
870            parts.append(('_pd_type', "gaussian"))
871        for ext, val in parts:
872            if p.length > 1:
873                dict(("%s%d%s" % (p.id, k, ext), val)
874                     for k in range(1, p.length+1))
875            else:
876                pars[p.id + ext] = val
877
878    # Plug in values given in demo
879    if use_demo:
880        pars.update(model_info.demo)
881    return pars
882
883INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
884def isnumber(str):
885    match = FLOAT_RE.match(str)
886    isfloat = (match and not str[match.end():])
887    return isfloat or INTEGER_RE.match(str)
888
889# For distinguishing pairs of models for comparison
890# key-value pair separator =
891# shell characters  | & ; <> $ % ' " \ # `
892# model and parameter names _
893# parameter expressions - + * / . ( )
894# path characters including tilde expansion and windows drive ~ / :
895# not sure about brackets [] {}
896# maybe one of the following @ ? ^ ! ,
897MODEL_SPLIT = ','
898def parse_opts(argv):
899    # type: (List[str]) -> Dict[str, Any]
900    """
901    Parse command line options.
902    """
903    MODELS = core.list_models()
904    flags = [arg for arg in argv
905             if arg.startswith('-')]
906    values = [arg for arg in argv
907              if not arg.startswith('-') and '=' in arg]
908    positional_args = [arg for arg in argv
909            if not arg.startswith('-') and '=' not in arg]
910    models = "\n    ".join("%-15s"%v for v in MODELS)
911    if len(positional_args) == 0:
912        print(USAGE)
913        print("\nAvailable models:")
914        print(columnize(MODELS, indent="  "))
915        return None
916    if len(positional_args) > 3:
917        print("expected parameters: model N1 N2")
918
919    invalid = [o[1:] for o in flags
920               if o[1:] not in NAME_OPTIONS
921               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
922    if invalid:
923        print("Invalid options: %s"%(", ".join(invalid)))
924        return None
925
926    name = positional_args[0]
927    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
928    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
929
930    # pylint: disable=bad-whitespace
931    # Interpret the flags
932    opts = {
933        'plot'      : True,
934        'view'      : 'log',
935        'is2d'      : False,
936        'qmax'      : 0.05,
937        'nq'        : 128,
938        'res'       : 0.0,
939        'accuracy'  : 'Low',
940        'cutoff'    : 0.0,
941        'seed'      : -1,  # default to preset
942        'mono'      : False,
943        # Default to magnetic a magnetic moment is set on the command line
944        'magnetic'  : False,
945        'show_pars' : False,
946        'show_hist' : False,
947        'rel_err'   : True,
948        'explore'   : False,
949        'use_demo'  : True,
950        'zero'      : False,
951        'html'      : False,
952        'title'     : None,
953    }
954    engines = []
955    for arg in flags:
956        if arg == '-noplot':    opts['plot'] = False
957        elif arg == '-plot':    opts['plot'] = True
958        elif arg == '-linear':  opts['view'] = 'linear'
959        elif arg == '-log':     opts['view'] = 'log'
960        elif arg == '-q4':      opts['view'] = 'q4'
961        elif arg == '-1d':      opts['is2d'] = False
962        elif arg == '-2d':      opts['is2d'] = True
963        elif arg == '-exq':     opts['qmax'] = 10.0
964        elif arg == '-highq':   opts['qmax'] = 1.0
965        elif arg == '-midq':    opts['qmax'] = 0.2
966        elif arg == '-lowq':    opts['qmax'] = 0.05
967        elif arg == '-zero':    opts['zero'] = True
968        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
969        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
970        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
971        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
972        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
973        elif arg.startswith('-title'):     opts['title'] = arg[7:]
974        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
975        elif arg == '-preset':  opts['seed'] = -1
976        elif arg == '-mono':    opts['mono'] = True
977        elif arg == '-poly':    opts['mono'] = False
978        elif arg == '-magnetic':       opts['magnetic'] = True
979        elif arg == '-nonmagnetic':    opts['magnetic'] = False
980        elif arg == '-pars':    opts['show_pars'] = True
981        elif arg == '-nopars':  opts['show_pars'] = False
982        elif arg == '-hist':    opts['show_hist'] = True
983        elif arg == '-nohist':  opts['show_hist'] = False
984        elif arg == '-rel':     opts['rel_err'] = True
985        elif arg == '-abs':     opts['rel_err'] = False
986        elif arg == '-half':    engines.append(arg[1:])
987        elif arg == '-fast':    engines.append(arg[1:])
988        elif arg == '-single':  engines.append(arg[1:])
989        elif arg == '-double':  engines.append(arg[1:])
990        elif arg == '-single!': engines.append(arg[1:])
991        elif arg == '-double!': engines.append(arg[1:])
992        elif arg == '-quad!':   engines.append(arg[1:])
993        elif arg == '-sasview': engines.append(arg[1:])
994        elif arg == '-edit':    opts['explore'] = True
995        elif arg == '-demo':    opts['use_demo'] = True
996        elif arg == '-default':    opts['use_demo'] = False
997        elif arg == '-html':    opts['html'] = True
998    # pylint: enable=bad-whitespace
999
1000    if MODEL_SPLIT in name:
1001        name, name2 = name.split(MODEL_SPLIT, 2)
1002    else:
1003        name2 = name
1004    try:
1005        model_info = core.load_model_info(name)
1006        model_info2 = core.load_model_info(name2) if name2 != name else model_info
1007    except ImportError as exc:
1008        print(str(exc))
1009        print("Could not find model; use one of:\n    " + models)
1010        return None
1011
1012    # Get demo parameters from model definition, or use default parameters
1013    # if model does not define demo parameters
1014    pars = get_pars(model_info, opts['use_demo'])
1015    pars2 = get_pars(model_info2, opts['use_demo'])
1016    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1017    # randomize parameters
1018    #pars.update(set_pars)  # set value before random to control range
1019    if opts['seed'] > -1:
1020        pars = randomize_pars(model_info, pars, seed=opts['seed'])
1021        if model_info != model_info2:
1022            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
1023            # Share values for parameters with the same name
1024            for k, v in pars.items():
1025                if k in pars2:
1026                    pars2[k] = v
1027        else:
1028            pars2 = pars.copy()
1029        constrain_pars(model_info, pars)
1030        constrain_pars(model_info2, pars2)
1031        print("Randomize using -random=%i"%opts['seed'])
1032    if opts['mono']:
1033        pars = suppress_pd(pars)
1034        pars2 = suppress_pd(pars2)
1035    if not opts['magnetic']:
1036        pars = suppress_magnetism(pars)
1037        pars2 = suppress_magnetism(pars2)
1038
1039    # Fill in parameters given on the command line
1040    presets = {}
1041    presets2 = {}
1042    for arg in values:
1043        k, v = arg.split('=', 1)
1044        if k not in pars and k not in pars2:
1045            # extract base name without polydispersity info
1046            s = set(p.split('_pd')[0] for p in pars)
1047            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1048            return None
1049        v1, v2 = v.split(MODEL_SPLIT, 2) if MODEL_SPLIT in v else (v,v)
1050        if v1 and k in pars:
1051            presets[k] = float(v1) if isnumber(v1) else v1
1052        if v2 and k in pars2:
1053            presets2[k] = float(v2) if isnumber(v2) else v2
1054
1055    # If pd given on the command line, default pd_n to 35
1056    for k, v in list(presets.items()):
1057        if k.endswith('_pd'):
1058            presets.setdefault(k+'_n', 35.)
1059    for k, v in list(presets2.items()):
1060        if k.endswith('_pd'):
1061            presets2.setdefault(k+'_n', 35.)
1062
1063    # Evaluate preset parameter expressions
1064    context = MATH.copy()
1065    context.update(pars)
1066    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
1067    for k, v in presets.items():
1068        if not isinstance(v, float) and not k.endswith('_type'):
1069            presets[k] = eval(v, context)
1070    context.update(presets)
1071    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
1072    for k, v in presets2.items():
1073        if not isinstance(v, float) and not k.endswith('_type'):
1074            presets2[k] = eval(v, context)
1075
1076    # update parameters with presets
1077    pars.update(presets)  # set value after random to control value
1078    pars2.update(presets2)  # set value after random to control value
1079    #import pprint; pprint.pprint(model_info)
1080
1081    same_model = name == name2 and pars == pars
1082    if len(engines) == 0:
1083        if same_model:
1084            engines.extend(['single', 'double'])
1085        else:
1086            engines.extend(['single', 'single'])
1087    elif len(engines) == 1:
1088        if not same_model:
1089            engines.append(engines[0])
1090        elif engines[0] == 'double':
1091            engines.append('single')
1092        else:
1093            engines.append('double')
1094    elif len(engines) > 2:
1095        del engines[2:]
1096
1097    use_sasview = any(engine == 'sasview' and count > 0
1098                      for engine, count in zip(engines, [n1, n2]))
1099    if use_sasview:
1100        constrain_new_to_old(model_info, pars)
1101        constrain_new_to_old(model_info2, pars2)
1102
1103    if opts['show_pars']:
1104        if not same_model:
1105            print("==== %s ====="%model_info.name)
1106            print(str(parlist(model_info, pars, opts['is2d'])))
1107            print("==== %s ====="%model_info2.name)
1108            print(str(parlist(model_info2, pars2, opts['is2d'])))
1109        else:
1110            print(str(parlist(model_info, pars, opts['is2d'])))
1111
1112    # Create the computational engines
1113    data, _ = make_data(opts)
1114    if n1:
1115        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1116    else:
1117        base = None
1118    if n2:
1119        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1120    else:
1121        comp = None
1122
1123    # pylint: disable=bad-whitespace
1124    # Remember it all
1125    opts.update({
1126        'data'      : data,
1127        'name'      : [name, name2],
1128        'def'       : [model_info, model_info2],
1129        'count'     : [n1, n2],
1130        'presets'   : [presets, presets2],
1131        'pars'      : [pars, pars2],
1132        'engines'   : [base, comp],
1133    })
1134    # pylint: enable=bad-whitespace
1135
1136    return opts
1137
1138def show_docs(opts):
1139    # type: (Dict[str, Any]) -> None
1140    """
1141    show html docs for the model
1142    """
1143    import wx  # type: ignore
1144    from .generate import view_html_from_info
1145    app = wx.App() if wx.GetApp() is None else None
1146    view_html_from_info(opts['def'][0])
1147    if app: app.MainLoop()
1148
1149
1150def explore(opts):
1151    # type: (Dict[str, Any]) -> None
1152    """
1153    explore the model using the bumps gui.
1154    """
1155    import wx  # type: ignore
1156    from bumps.names import FitProblem  # type: ignore
1157    from bumps.gui.app_frame import AppFrame  # type: ignore
1158    from bumps.gui import signal
1159
1160    is_mac = "cocoa" in wx.version()
1161    # Create an app if not running embedded
1162    app = wx.App() if wx.GetApp() is None else None
1163    model = Explore(opts)
1164    problem = FitProblem(model)
1165    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1166    if not is_mac: frame.Show()
1167    frame.panel.set_model(model=problem)
1168    frame.panel.Layout()
1169    frame.panel.aui.Split(0, wx.TOP)
1170    def reset_parameters(event):
1171        model.revert_values()
1172        signal.update_parameters(problem)
1173    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1174    if is_mac: frame.Show()
1175    # If running withing an app, start the main loop
1176    if app: app.MainLoop()
1177
1178class Explore(object):
1179    """
1180    Bumps wrapper for a SAS model comparison.
1181
1182    The resulting object can be used as a Bumps fit problem so that
1183    parameters can be adjusted in the GUI, with plots updated on the fly.
1184    """
1185    def __init__(self, opts):
1186        # type: (Dict[str, Any]) -> None
1187        from bumps.cli import config_matplotlib  # type: ignore
1188        from . import bumps_model
1189        config_matplotlib()
1190        self.opts = opts
1191        p1, p2 = opts['pars']
1192        m1, m2 = opts['def']
1193        self.fix_p2 = m1 != m2 or p1 != p2
1194        model_info = m1
1195        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1196        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1197        if not opts['is2d']:
1198            for p in model_info.parameters.user_parameters({}, is2d=False):
1199                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1200                    k = p.name+ext
1201                    v = pars.get(k, None)
1202                    if v is not None:
1203                        v.range(*parameter_range(k, v.value))
1204        else:
1205            for k, v in pars.items():
1206                v.range(*parameter_range(k, v.value))
1207
1208        self.pars = pars
1209        self.starting_values = dict((k, v.value) for k, v in pars.items())
1210        self.pd_types = pd_types
1211        self.limits = None
1212
1213    def revert_values(self):
1214        for k, v in self.starting_values.items():
1215            self.pars[k].value = v
1216
1217    def model_update(self):
1218        pass
1219
1220    def numpoints(self):
1221        # type: () -> int
1222        """
1223        Return the number of points.
1224        """
1225        return len(self.pars) + 1  # so dof is 1
1226
1227    def parameters(self):
1228        # type: () -> Any   # Dict/List hierarchy of parameters
1229        """
1230        Return a dictionary of parameters.
1231        """
1232        return self.pars
1233
1234    def nllf(self):
1235        # type: () -> float
1236        """
1237        Return cost.
1238        """
1239        # pylint: disable=no-self-use
1240        return 0.  # No nllf
1241
1242    def plot(self, view='log'):
1243        # type: (str) -> None
1244        """
1245        Plot the data and residuals.
1246        """
1247        pars = dict((k, v.value) for k, v in self.pars.items())
1248        pars.update(self.pd_types)
1249        self.opts['pars'][0] = pars
1250        if not self.fix_p2:
1251            self.opts['pars'][1] = pars
1252        result = run_models(self.opts)
1253        limits = plot_models(self.opts, result, limits=self.limits)
1254        if self.limits is None:
1255            vmin, vmax = limits
1256            self.limits = vmax*1e-7, 1.3*vmax
1257            import pylab; pylab.clf()
1258            plot_models(self.opts, result, limits=self.limits)
1259
1260
1261def main(*argv):
1262    # type: (*str) -> None
1263    """
1264    Main program.
1265    """
1266    opts = parse_opts(argv)
1267    if opts is not None:
1268        if opts['html']:
1269            show_docs(opts)
1270        elif opts['explore']:
1271            explore(opts)
1272        else:
1273            compare(opts)
1274
1275if __name__ == "__main__":
1276    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.