source: sasmodels/sasmodels/compare.py @ 3e8ea5d

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 3e8ea5d was 3e8ea5d, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

constrain pearls to be bigger than string

  • Property mode set to 100755
File size: 41.6 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model is the name of the model to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -magnetic/-nonmagnetic* suppress magnetism
76    -cutoff=1e-5* cutoff value for including a point in polydispersity
77    -pars/-nopars* prints the parameter set or not
78    -abs/-rel* plot relative or absolute error
79    -linear/-log*/-q4 intensity scaling
80    -hist/-nohist* plot histogram of relative error
81    -res=0 sets the resolution width dQ/Q if calculating with resolution
82    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
83    -edit starts the parameter explorer
84    -default/-demo* use demo vs default parameters
85    -html shows the model docs instead of running the model
86    -title="note" adds note to the plot title, after the model name
87
88Any two calculation engines can be selected for comparison:
89
90    -single/-double/-half/-fast sets an OpenCL calculation engine
91    -single!/-double!/-quad! sets an OpenMP calculation engine
92    -sasview sets the sasview calculation engine
93
94The default is -single -sasview.  Note that the interpretation of quad
95precision depends on architecture, and may vary from 64-bit to 128-bit,
96with 80-bit floats being common (1e-19 precision).
97
98Key=value pairs allow you to set specific values for the model parameters.
99"""
100
101# Update docs with command line usage string.   This is separate from the usual
102# doc string so that we can display it at run time if there is an error.
103# lin
104__doc__ = (__doc__  # pylint: disable=redefined-builtin
105           + """
106Program description
107-------------------
108
109"""
110           + USAGE)
111
112kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
113
114# list of math functions for use in evaluating parameters
115MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
116
117# CRUFT python 2.6
118if not hasattr(datetime.timedelta, 'total_seconds'):
119    def delay(dt):
120        """Return number date-time delta as number seconds"""
121        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
122else:
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.total_seconds()
126
127
128class push_seed(object):
129    """
130    Set the seed value for the random number generator.
131
132    When used in a with statement, the random number generator state is
133    restored after the with statement is complete.
134
135    :Parameters:
136
137    *seed* : int or array_like, optional
138        Seed for RandomState
139
140    :Example:
141
142    Seed can be used directly to set the seed::
143
144        >>> from numpy.random import randint
145        >>> push_seed(24)
146        <...push_seed object at...>
147        >>> print(randint(0,1000000,3))
148        [242082    899 211136]
149
150    Seed can also be used in a with statement, which sets the random
151    number generator state for the enclosed computations and restores
152    it to the previous state on completion::
153
154        >>> with push_seed(24):
155        ...    print(randint(0,1000000,3))
156        [242082    899 211136]
157
158    Using nested contexts, we can demonstrate that state is indeed
159    restored after the block completes::
160
161        >>> with push_seed(24):
162        ...    print(randint(0,1000000))
163        ...    with push_seed(24):
164        ...        print(randint(0,1000000,3))
165        ...    print(randint(0,1000000))
166        242082
167        [242082    899 211136]
168        899
169
170    The restore step is protected against exceptions in the block::
171
172        >>> with push_seed(24):
173        ...    print(randint(0,1000000))
174        ...    try:
175        ...        with push_seed(24):
176        ...            print(randint(0,1000000,3))
177        ...            raise Exception()
178        ...    except Exception:
179        ...        print("Exception raised")
180        ...    print(randint(0,1000000))
181        242082
182        [242082    899 211136]
183        Exception raised
184        899
185    """
186    def __init__(self, seed=None):
187        # type: (Optional[int]) -> None
188        self._state = np.random.get_state()
189        np.random.seed(seed)
190
191    def __enter__(self):
192        # type: () -> None
193        pass
194
195    def __exit__(self, exc_type, exc_value, traceback):
196        # type: (Any, BaseException, Any) -> None
197        # TODO: better typing for __exit__ method
198        np.random.set_state(self._state)
199
200def tic():
201    # type: () -> Callable[[], float]
202    """
203    Timer function.
204
205    Use "toc=tic()" to start the clock and "toc()" to measure
206    a time interval.
207    """
208    then = datetime.datetime.now()
209    return lambda: delay(datetime.datetime.now() - then)
210
211
212def set_beam_stop(data, radius, outer=None):
213    # type: (Data, float, float) -> None
214    """
215    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
216
217    Note: this function does not require sasview
218    """
219    if hasattr(data, 'qx_data'):
220        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
221        data.mask = (q < radius)
222        if outer is not None:
223            data.mask |= (q >= outer)
224    else:
225        data.mask = (data.x < radius)
226        if outer is not None:
227            data.mask |= (data.x >= outer)
228
229
230def parameter_range(p, v):
231    # type: (str, float) -> Tuple[float, float]
232    """
233    Choose a parameter range based on parameter name and initial value.
234    """
235    # process the polydispersity options
236    if p.endswith('_pd_n'):
237        return 0., 100.
238    elif p.endswith('_pd_nsigma'):
239        return 0., 5.
240    elif p.endswith('_pd_type'):
241        raise ValueError("Cannot return a range for a string value")
242    elif any(s in p for s in ('theta', 'phi', 'psi')):
243        # orientation in [-180,180], orientation pd in [0,45]
244        if p.endswith('_pd'):
245            return 0., 45.
246        else:
247            return -180., 180.
248    elif p.endswith('_pd'):
249        return 0., 1.
250    elif 'sld' in p:
251        return -0.5, 10.
252    elif p == 'background':
253        return 0., 10.
254    elif p == 'scale':
255        return 0., 1.e3
256    elif v < 0.:
257        return 2.*v, -2.*v
258    else:
259        return 0., (2.*v if v > 0. else 1.)
260
261
262def _randomize_one(model_info, p, v):
263    # type: (ModelInfo, str, float) -> float
264    # type: (ModelInfo, str, str) -> str
265    """
266    Randomize a single parameter.
267    """
268    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
269        return v
270
271    # Find the parameter definition
272    for par in model_info.parameters.call_parameters:
273        if par.name == p:
274            break
275    else:
276        raise ValueError("unknown parameter %r"%p)
277
278    if len(par.limits) > 2:  # choice list
279        return np.random.randint(len(par.limits))
280
281    limits = par.limits
282    if not np.isfinite(limits).all():
283        low, high = parameter_range(p, v)
284        limits = (max(limits[0], low), min(limits[1], high))
285
286    return np.random.uniform(*limits)
287
288
289def randomize_pars(model_info, pars, seed=None):
290    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
291    """
292    Generate random values for all of the parameters.
293
294    Valid ranges for the random number generator are guessed from the name of
295    the parameter; this will not account for constraints such as cap radius
296    greater than cylinder radius in the capped_cylinder model, so
297    :func:`constrain_pars` needs to be called afterward..
298    """
299    with push_seed(seed):
300        # Note: the sort guarantees order `of calls to random number generator
301        random_pars = dict((p, _randomize_one(model_info, p, v))
302                           for p, v in sorted(pars.items()))
303    return random_pars
304
305def constrain_pars(model_info, pars):
306    # type: (ModelInfo, ParameterSet) -> None
307    """
308    Restrict parameters to valid values.
309
310    This includes model specific code for models such as capped_cylinder
311    which need to support within model constraints (cap radius more than
312    cylinder radius in this case).
313
314    Warning: this updates the *pars* dictionary in place.
315    """
316    name = model_info.id
317    # if it is a product model, then just look at the form factor since
318    # none of the structure factors need any constraints.
319    if '*' in name:
320        name = name.split('*')[0]
321
322    if name == 'barbell':
323        if pars['radius_bell'] < pars['radius']:
324            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
325
326    elif name == 'capped_cylinder':
327        if pars['radius_cap'] < pars['radius']:
328            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
329
330    elif name == 'guinier':
331        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
332        #q_max = 0.2  # mid q maximum
333        q_max = 1.0  # high q maximum
334        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
335        pars['rg'] = min(pars['rg'], rg_max)
336
337    elif name == 'pearl_necklace':
338        if pars['radius'] < pars['thick_string']:
339            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
340        pars['num_pearls'] = math.ceil(pars['num_pearls'])
341        pass
342
343    elif name == 'rpa':
344        # Make sure phi sums to 1.0
345        if pars['case_num'] < 2:
346            pars['Phi1'] = 0.
347            pars['Phi2'] = 0.
348        elif pars['case_num'] < 5:
349            pars['Phi1'] = 0.
350        total = sum(pars['Phi'+c] for c in '1234')
351        for c in '1234':
352            pars['Phi'+c] /= total
353
354    elif name == 'stacked_disks':
355        pars['n_stacking'] = math.ceil(pars['n_stacking'])
356
357def parlist(model_info, pars, is2d):
358    # type: (ModelInfo, ParameterSet, bool) -> str
359    """
360    Format the parameter list for printing.
361    """
362    lines = []
363    parameters = model_info.parameters
364    magnetic = False
365    for p in parameters.user_parameters(pars, is2d):
366        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
367            continue
368        if p.id.startswith('up:') and not magnetic:
369            continue
370        fields = dict(
371            value=pars.get(p.id, p.default),
372            pd=pars.get(p.id+"_pd", 0.),
373            n=int(pars.get(p.id+"_pd_n", 0)),
374            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
375            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
376            relative_pd=p.relative_pd,
377            M0=pars.get('M0:'+p.id, 0.),
378            mphi=pars.get('mphi:'+p.id, 0.),
379            mtheta=pars.get('mtheta:'+p.id, 0.),
380        )
381        lines.append(_format_par(p.name, **fields))
382        magnetic = magnetic or fields['M0'] != 0.
383    return "\n".join(lines)
384
385    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
386
387def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
388                relative_pd=False, M0=0., mphi=0., mtheta=0.):
389    # type: (str, float, float, int, float, str) -> str
390    line = "%s: %g"%(name, value)
391    if pd != 0.  and n != 0:
392        if relative_pd:
393            pd *= value
394        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
395                % (pd, n, nsigma, nsigma, pdtype)
396    if M0 != 0.:
397        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
398    return line
399
400def suppress_pd(pars):
401    # type: (ParameterSet) -> ParameterSet
402    """
403    Suppress theta_pd for now until the normalization is resolved.
404
405    May also suppress complete polydispersity of the model to test
406    models more quickly.
407    """
408    pars = pars.copy()
409    for p in pars:
410        if p.endswith("_pd_n"): pars[p] = 0
411    return pars
412
413def suppress_magnetism(pars):
414    # type: (ParameterSet) -> ParameterSet
415    """
416    Suppress theta_pd for now until the normalization is resolved.
417
418    May also suppress complete polydispersity of the model to test
419    models more quickly.
420    """
421    pars = pars.copy()
422    for p in pars:
423        if p.startswith("M0:"): pars[p] = 0
424    return pars
425
426def eval_sasview(model_info, data):
427    # type: (Modelinfo, Data) -> Calculator
428    """
429    Return a model calculator using the pre-4.0 SasView models.
430    """
431    # importing sas here so that the error message will be that sas failed to
432    # import rather than the more obscure smear_selection not imported error
433    import sas
434    import sas.models
435    from sas.models.qsmearing import smear_selection
436    from sas.models.MultiplicationModel import MultiplicationModel
437    from sas.models.dispersion_models import models as dispersers
438
439    def get_model_class(name):
440        # type: (str) -> "sas.models.BaseComponent"
441        #print("new",sorted(_pars.items()))
442        __import__('sas.models.' + name)
443        ModelClass = getattr(getattr(sas.models, name, None), name, None)
444        if ModelClass is None:
445            raise ValueError("could not find model %r in sas.models"%name)
446        return ModelClass
447
448    # WARNING: ugly hack when handling model!
449    # Sasview models with multiplicity need to be created with the target
450    # multiplicity, so we cannot create the target model ahead of time for
451    # for multiplicity models.  Instead we store the model in a list and
452    # update the first element of that list with the new multiplicity model
453    # every time we evaluate.
454
455    # grab the sasview model, or create it if it is a product model
456    if model_info.composition:
457        composition_type, parts = model_info.composition
458        if composition_type == 'product':
459            P, S = [get_model_class(revert_name(p))() for p in parts]
460            model = [MultiplicationModel(P, S)]
461        else:
462            raise ValueError("sasview mixture models not supported by compare")
463    else:
464        old_name = revert_name(model_info)
465        if old_name is None:
466            raise ValueError("model %r does not exist in old sasview"
467                            % model_info.id)
468        ModelClass = get_model_class(old_name)
469        model = [ModelClass()]
470    model[0].disperser_handles = {}
471
472    # build a smearer with which to call the model, if necessary
473    smearer = smear_selection(data, model=model)
474    if hasattr(data, 'qx_data'):
475        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
476        index = ((~data.mask) & (~np.isnan(data.data))
477                 & (q >= data.qmin) & (q <= data.qmax))
478        if smearer is not None:
479            smearer.model = model  # because smear_selection has a bug
480            smearer.accuracy = data.accuracy
481            smearer.set_index(index)
482            def _call_smearer():
483                smearer.model = model[0]
484                return smearer.get_value()
485            theory = _call_smearer
486        else:
487            theory = lambda: model[0].evalDistribution([data.qx_data[index],
488                                                        data.qy_data[index]])
489    elif smearer is not None:
490        theory = lambda: smearer(model[0].evalDistribution(data.x))
491    else:
492        theory = lambda: model[0].evalDistribution(data.x)
493
494    def calculator(**pars):
495        # type: (float, ...) -> np.ndarray
496        """
497        Sasview calculator for model.
498        """
499        oldpars = revert_pars(model_info, pars)
500        # For multiplicity models, create a model with the correct multiplicity
501        control = oldpars.pop("CONTROL", None)
502        if control is not None:
503            # sphericalSLD has one fewer multiplicity.  This update should
504            # happen in revert_pars, but it hasn't been called yet.
505            model[0] = ModelClass(control)
506        # paying for parameter conversion each time to keep life simple, if not fast
507        for k, v in oldpars.items():
508            if k.endswith('.type'):
509                par = k[:-5]
510                if v == 'gaussian': continue
511                cls = dispersers[v if v != 'rectangle' else 'rectangula']
512                handle = cls()
513                model[0].disperser_handles[par] = handle
514                try:
515                    model[0].set_dispersion(par, handle)
516                except Exception:
517                    exception.annotate_exception("while setting %s to %r"
518                                                 %(par, v))
519                    raise
520
521
522        #print("sasview pars",oldpars)
523        for k, v in oldpars.items():
524            name_attr = k.split('.')  # polydispersity components
525            if len(name_attr) == 2:
526                par, disp_par = name_attr
527                model[0].dispersion[par][disp_par] = v
528            else:
529                model[0].setParam(k, v)
530        return theory()
531
532    calculator.engine = "sasview"
533    return calculator
534
535DTYPE_MAP = {
536    'half': '16',
537    'fast': 'fast',
538    'single': '32',
539    'double': '64',
540    'quad': '128',
541    'f16': '16',
542    'f32': '32',
543    'f64': '64',
544    'longdouble': '128',
545}
546def eval_opencl(model_info, data, dtype='single', cutoff=0.):
547    # type: (ModelInfo, Data, str, float) -> Calculator
548    """
549    Return a model calculator using the OpenCL calculation engine.
550    """
551    if not core.HAVE_OPENCL:
552        raise RuntimeError("OpenCL not available")
553    model = core.build_model(model_info, dtype=dtype, platform="ocl")
554    calculator = DirectModel(data, model, cutoff=cutoff)
555    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
556    return calculator
557
558def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
559    # type: (ModelInfo, Data, str, float) -> Calculator
560    """
561    Return a model calculator using the DLL calculation engine.
562    """
563    model = core.build_model(model_info, dtype=dtype, platform="dll")
564    calculator = DirectModel(data, model, cutoff=cutoff)
565    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
566    return calculator
567
568def time_calculation(calculator, pars, evals=1):
569    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
570    """
571    Compute the average calculation time over N evaluations.
572
573    An additional call is generated without polydispersity in order to
574    initialize the calculation engine, and make the average more stable.
575    """
576    # initialize the code so time is more accurate
577    if evals > 1:
578        calculator(**suppress_pd(pars))
579    toc = tic()
580    # make sure there is at least one eval
581    value = calculator(**pars)
582    for _ in range(evals-1):
583        value = calculator(**pars)
584    average_time = toc()*1000. / evals
585    #print("I(q)",value)
586    return value, average_time
587
588def make_data(opts):
589    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
590    """
591    Generate an empty dataset, used with the model to set Q points
592    and resolution.
593
594    *opts* contains the options, with 'qmax', 'nq', 'res',
595    'accuracy', 'is2d' and 'view' parsed from the command line.
596    """
597    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
598    if opts['is2d']:
599        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
600        data = empty_data2D(q, resolution=res)
601        data.accuracy = opts['accuracy']
602        set_beam_stop(data, 0.0004)
603        index = ~data.mask
604    else:
605        if opts['view'] == 'log' and not opts['zero']:
606            qmax = math.log10(qmax)
607            q = np.logspace(qmax-3, qmax, nq)
608        else:
609            q = np.linspace(0.001*qmax, qmax, nq)
610        if opts['zero']:
611            q = np.hstack((0, q))
612        data = empty_data1D(q, resolution=res)
613        index = slice(None, None)
614    return data, index
615
616def make_engine(model_info, data, dtype, cutoff):
617    # type: (ModelInfo, Data, str, float) -> Calculator
618    """
619    Generate the appropriate calculation engine for the given datatype.
620
621    Datatypes with '!' appended are evaluated using external C DLLs rather
622    than OpenCL.
623    """
624    if dtype == 'sasview':
625        return eval_sasview(model_info, data)
626    elif dtype.endswith('!'):
627        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
628    else:
629        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
630
631def _show_invalid(data, theory):
632    # type: (Data, np.ma.ndarray) -> None
633    """
634    Display a list of the non-finite values in theory.
635    """
636    if not theory.mask.any():
637        return
638
639    if hasattr(data, 'x'):
640        bad = zip(data.x[theory.mask], theory[theory.mask])
641        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
642
643
644def compare(opts, limits=None):
645    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
646    """
647    Preform a comparison using options from the command line.
648
649    *limits* are the limits on the values to use, either to set the y-axis
650    for 1D or to set the colormap scale for 2D.  If None, then they are
651    inferred from the data and returned. When exploring using Bumps,
652    the limits are set when the model is initially called, and maintained
653    as the values are adjusted, making it easier to see the effects of the
654    parameters.
655    """
656    n_base, n_comp = opts['count']
657    pars, pars2 = opts['pars']
658    data = opts['data']
659
660    # silence the linter
661    base = opts['engines'][0] if n_base else None
662    comp = opts['engines'][1] if n_comp else None
663    base_time = comp_time = None
664    base_value = comp_value = resid = relerr = None
665
666    # Base calculation
667    if n_base > 0:
668        try:
669            base_raw, base_time = time_calculation(base, pars, n_base)
670            base_value = np.ma.masked_invalid(base_raw)
671            print("%s t=%.2f ms, intensity=%.0f"
672                  % (base.engine, base_time, base_value.sum()))
673            _show_invalid(data, base_value)
674        except ImportError:
675            traceback.print_exc()
676            n_base = 0
677
678    # Comparison calculation
679    if n_comp > 0:
680        try:
681            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
682            comp_value = np.ma.masked_invalid(comp_raw)
683            print("%s t=%.2f ms, intensity=%.0f"
684                  % (comp.engine, comp_time, comp_value.sum()))
685            _show_invalid(data, comp_value)
686        except ImportError:
687            traceback.print_exc()
688            n_comp = 0
689
690    # Compare, but only if computing both forms
691    if n_base > 0 and n_comp > 0:
692        resid = (base_value - comp_value)
693        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
694        _print_stats("|%s-%s|"
695                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
696                     resid)
697        _print_stats("|(%s-%s)/%s|"
698                     % (base.engine, comp.engine, comp.engine),
699                     relerr)
700
701    # Plot if requested
702    if not opts['plot'] and not opts['explore']: return
703    view = opts['view']
704    import matplotlib.pyplot as plt
705    if limits is None:
706        vmin, vmax = np.Inf, -np.Inf
707        if n_base > 0:
708            vmin = min(vmin, base_value.min())
709            vmax = max(vmax, base_value.max())
710        if n_comp > 0:
711            vmin = min(vmin, comp_value.min())
712            vmax = max(vmax, comp_value.max())
713        limits = vmin, vmax
714
715    if n_base > 0:
716        if n_comp > 0: plt.subplot(131)
717        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
718        plt.title("%s t=%.2f ms"%(base.engine, base_time))
719        #cbar_title = "log I"
720    if n_comp > 0:
721        if n_base > 0: plt.subplot(132)
722        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
723        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
724        #cbar_title = "log I"
725    if n_comp > 0 and n_base > 0:
726        if not opts['is2d']:
727            plot_theory(data, base_value, view=view, use_data=False, limits=limits)
728        plt.subplot(133)
729        if not opts['rel_err']:
730            err, errstr, errview = resid, "abs err", "linear"
731        else:
732            err, errstr, errview = abs(relerr), "rel err", "log"
733        if 0:  # 95% cutoff
734            sorted = np.sort(err.flatten())
735            cutoff = sorted[int(sorted.size*0.95)]
736            err[err>cutoff] = cutoff
737        #err,errstr = base/comp,"ratio"
738        plot_theory(data, None, resid=err, view=errview, use_data=False)
739        if view == 'linear':
740            plt.xscale('linear')
741        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
742        #cbar_title = errstr if errview=="linear" else "log "+errstr
743    #if is2D:
744    #    h = plt.colorbar()
745    #    h.ax.set_title(cbar_title)
746    fig = plt.gcf()
747    extra_title = ' '+opts['title'] if opts['title'] else ''
748    fig.suptitle(":".join(opts['name']) + extra_title)
749
750    if n_comp > 0 and n_base > 0 and opts['show_hist']:
751        plt.figure()
752        v = relerr
753        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
754        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
755        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
756                   % (base.engine, comp.engine, comp.engine))
757        plt.ylabel('P(err)')
758        plt.title('Distribution of relative error between calculation engines')
759
760    if not opts['explore']:
761        plt.show()
762
763    return limits
764
765def _print_stats(label, err):
766    # type: (str, np.ma.ndarray) -> None
767    # work with trimmed data, not the full set
768    sorted_err = np.sort(abs(err.compressed()))
769    p50 = int((len(sorted_err)-1)*0.50)
770    p98 = int((len(sorted_err)-1)*0.98)
771    data = [
772        "max:%.3e"%sorted_err[-1],
773        "median:%.3e"%sorted_err[p50],
774        "98%%:%.3e"%sorted_err[p98],
775        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
776        "zero-offset:%+.3e"%np.mean(sorted_err),
777        ]
778    print(label+"  "+"  ".join(data))
779
780
781
782# ===========================================================================
783#
784NAME_OPTIONS = set([
785    'plot', 'noplot',
786    'half', 'fast', 'single', 'double',
787    'single!', 'double!', 'quad!', 'sasview',
788    'lowq', 'midq', 'highq', 'exq', 'zero',
789    '2d', '1d',
790    'preset', 'random',
791    'poly', 'mono',
792    'magnetic', 'nonmagnetic',
793    'nopars', 'pars',
794    'rel', 'abs',
795    'linear', 'log', 'q4',
796    'hist', 'nohist',
797    'edit', 'html',
798    'demo', 'default',
799    ])
800VALUE_OPTIONS = [
801    # Note: random is both a name option and a value option
802    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
803    ]
804
805def columnize(items, indent="", width=79):
806    # type: (List[str], str, int) -> str
807    """
808    Format a list of strings into columns.
809
810    Returns a string with carriage returns ready for printing.
811    """
812    column_width = max(len(w) for w in items) + 1
813    num_columns = (width - len(indent)) // column_width
814    num_rows = len(items) // num_columns
815    items = items + [""] * (num_rows * num_columns - len(items))
816    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
817    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
818             for row in zip(*columns)]
819    output = indent + ("\n"+indent).join(lines)
820    return output
821
822
823def get_pars(model_info, use_demo=False):
824    # type: (ModelInfo, bool) -> ParameterSet
825    """
826    Extract demo parameters from the model definition.
827    """
828    # Get the default values for the parameters
829    pars = {}
830    for p in model_info.parameters.call_parameters:
831        parts = [('', p.default)]
832        if p.polydisperse:
833            parts.append(('_pd', 0.0))
834            parts.append(('_pd_n', 0))
835            parts.append(('_pd_nsigma', 3.0))
836            parts.append(('_pd_type', "gaussian"))
837        for ext, val in parts:
838            if p.length > 1:
839                dict(("%s%d%s" % (p.id, k, ext), val)
840                     for k in range(1, p.length+1))
841            else:
842                pars[p.id + ext] = val
843
844    # Plug in values given in demo
845    if use_demo:
846        pars.update(model_info.demo)
847    return pars
848
849INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
850def isnumber(str):
851    match = FLOAT_RE.match(str)
852    isfloat = (match and not str[match.end():])
853    return isfloat or INTEGER_RE.match(str)
854
855def parse_opts(argv):
856    # type: (List[str]) -> Dict[str, Any]
857    """
858    Parse command line options.
859    """
860    MODELS = core.list_models()
861    flags = [arg for arg in argv
862             if arg.startswith('-')]
863    values = [arg for arg in argv
864              if not arg.startswith('-') and '=' in arg]
865    positional_args = [arg for arg in argv
866            if not arg.startswith('-') and '=' not in arg]
867    models = "\n    ".join("%-15s"%v for v in MODELS)
868    if len(positional_args) == 0:
869        print(USAGE)
870        print("\nAvailable models:")
871        print(columnize(MODELS, indent="  "))
872        return None
873    if len(positional_args) > 3:
874        print("expected parameters: model N1 N2")
875
876    invalid = [o[1:] for o in flags
877               if o[1:] not in NAME_OPTIONS
878               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
879    if invalid:
880        print("Invalid options: %s"%(", ".join(invalid)))
881        return None
882
883    name = positional_args[0]
884    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
885    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
886
887    # pylint: disable=bad-whitespace
888    # Interpret the flags
889    opts = {
890        'plot'      : True,
891        'view'      : 'log',
892        'is2d'      : False,
893        'qmax'      : 0.05,
894        'nq'        : 128,
895        'res'       : 0.0,
896        'accuracy'  : 'Low',
897        'cutoff'    : 0.0,
898        'seed'      : -1,  # default to preset
899        'mono'      : False,
900        # Default to magnetic a magnetic moment is set on the command line
901        'magnetic'  : False,
902        'show_pars' : False,
903        'show_hist' : False,
904        'rel_err'   : True,
905        'explore'   : False,
906        'use_demo'  : True,
907        'zero'      : False,
908        'html'      : False,
909        'title'     : None,
910    }
911    engines = []
912    for arg in flags:
913        if arg == '-noplot':    opts['plot'] = False
914        elif arg == '-plot':    opts['plot'] = True
915        elif arg == '-linear':  opts['view'] = 'linear'
916        elif arg == '-log':     opts['view'] = 'log'
917        elif arg == '-q4':      opts['view'] = 'q4'
918        elif arg == '-1d':      opts['is2d'] = False
919        elif arg == '-2d':      opts['is2d'] = True
920        elif arg == '-exq':     opts['qmax'] = 10.0
921        elif arg == '-highq':   opts['qmax'] = 1.0
922        elif arg == '-midq':    opts['qmax'] = 0.2
923        elif arg == '-lowq':    opts['qmax'] = 0.05
924        elif arg == '-zero':    opts['zero'] = True
925        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
926        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
927        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
928        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
929        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
930        elif arg.startswith('-title'):     opts['title'] = arg[7:]
931        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
932        elif arg == '-preset':  opts['seed'] = -1
933        elif arg == '-mono':    opts['mono'] = True
934        elif arg == '-poly':    opts['mono'] = False
935        elif arg == '-magnetic':       opts['magnetic'] = True
936        elif arg == '-nonmagnetic':    opts['magnetic'] = False
937        elif arg == '-pars':    opts['show_pars'] = True
938        elif arg == '-nopars':  opts['show_pars'] = False
939        elif arg == '-hist':    opts['show_hist'] = True
940        elif arg == '-nohist':  opts['show_hist'] = False
941        elif arg == '-rel':     opts['rel_err'] = True
942        elif arg == '-abs':     opts['rel_err'] = False
943        elif arg == '-half':    engines.append(arg[1:])
944        elif arg == '-fast':    engines.append(arg[1:])
945        elif arg == '-single':  engines.append(arg[1:])
946        elif arg == '-double':  engines.append(arg[1:])
947        elif arg == '-single!': engines.append(arg[1:])
948        elif arg == '-double!': engines.append(arg[1:])
949        elif arg == '-quad!':   engines.append(arg[1:])
950        elif arg == '-sasview': engines.append(arg[1:])
951        elif arg == '-edit':    opts['explore'] = True
952        elif arg == '-demo':    opts['use_demo'] = True
953        elif arg == '-default':    opts['use_demo'] = False
954        elif arg == '-html':    opts['html'] = True
955    # pylint: enable=bad-whitespace
956
957    if ':' in name:
958        name, name2 = name.split(':',2)
959    else:
960        name2 = name
961    try:
962        model_info = core.load_model_info(name)
963        model_info2 = core.load_model_info(name2) if name2 != name else model_info
964    except ImportError as exc:
965        print(str(exc))
966        print("Could not find model; use one of:\n    " + models)
967        return None
968
969    # Get demo parameters from model definition, or use default parameters
970    # if model does not define demo parameters
971    pars = get_pars(model_info, opts['use_demo'])
972    pars2 = get_pars(model_info2, opts['use_demo'])
973    pars2.update((k, v) for k, v in pars.items() if k in pars2)
974    # randomize parameters
975    #pars.update(set_pars)  # set value before random to control range
976    if opts['seed'] > -1:
977        pars = randomize_pars(model_info, pars, seed=opts['seed'])
978        if model_info != model_info2:
979            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
980            # Share values for parameters with the same name
981            for k, v in pars.items():
982                if k in pars2:
983                    pars2[k] = v
984        else:
985            pars2 = pars.copy()
986        constrain_pars(model_info, pars)
987        constrain_pars(model_info2, pars2)
988        print("Randomize using -random=%i"%opts['seed'])
989    if opts['mono']:
990        pars = suppress_pd(pars)
991        pars2 = suppress_pd(pars2)
992    if not opts['magnetic']:
993        pars = suppress_magnetism(pars)
994        pars2 = suppress_magnetism(pars2)
995
996    # Fill in parameters given on the command line
997    presets = {}
998    presets2 = {}
999    for arg in values:
1000        k, v = arg.split('=', 1)
1001        if k not in pars and k not in pars2:
1002            # extract base name without polydispersity info
1003            s = set(p.split('_pd')[0] for p in pars)
1004            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1005            return None
1006        v1, v2 = v.split(':',2) if ':' in v else (v,v)
1007        if v1 and k in pars:
1008            presets[k] = float(v1) if isnumber(v1) else v1
1009        if v2 and k in pars2:
1010            presets2[k] = float(v2) if isnumber(v2) else v2
1011
1012    # If pd given on the command line, default pd_n to 35
1013    for k, v in list(presets.items()):
1014        if k.endswith('_pd'):
1015            presets.setdefault(k+'_n', 35.)
1016    for k, v in list(presets2.items()):
1017        if k.endswith('_pd'):
1018            presets2.setdefault(k+'_n', 35.)
1019
1020    # Evaluate preset parameter expressions
1021    context = MATH.copy()
1022    context.update(pars)
1023    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
1024    for k, v in presets.items():
1025        if not isinstance(v, float) and not k.endswith('_type'):
1026            presets[k] = eval(v, context)
1027    context.update(presets)
1028    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
1029    for k, v in presets2.items():
1030        if not isinstance(v, float) and not k.endswith('_type'):
1031            presets2[k] = eval(v, context)
1032
1033    # update parameters with presets
1034    pars.update(presets)  # set value after random to control value
1035    pars2.update(presets2)  # set value after random to control value
1036    #import pprint; pprint.pprint(model_info)
1037
1038    same_model = name == name2 and pars == pars
1039    if len(engines) == 0:
1040        if same_model:
1041            engines.extend(['single', 'double'])
1042        else:
1043            engines.extend(['single', 'single'])
1044    elif len(engines) == 1:
1045        if not same_model:
1046            engines.append(engines[0])
1047        elif engines[0] == 'double':
1048            engines.append('single')
1049        else:
1050            engines.append('double')
1051    elif len(engines) > 2:
1052        del engines[2:]
1053
1054    use_sasview = any(engine == 'sasview' and count > 0
1055                      for engine, count in zip(engines, [n1, n2]))
1056    if use_sasview:
1057        constrain_new_to_old(model_info, pars)
1058        constrain_new_to_old(model_info2, pars2)
1059
1060    if opts['show_pars']:
1061        if not same_model:
1062            print("==== %s ====="%model_info.name)
1063            print(str(parlist(model_info, pars, opts['is2d'])))
1064            print("==== %s ====="%model_info2.name)
1065            print(str(parlist(model_info2, pars2, opts['is2d'])))
1066        else:
1067            print(str(parlist(model_info, pars, opts['is2d'])))
1068
1069    # Create the computational engines
1070    data, _ = make_data(opts)
1071    if n1:
1072        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1073    else:
1074        base = None
1075    if n2:
1076        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1077    else:
1078        comp = None
1079
1080    # pylint: disable=bad-whitespace
1081    # Remember it all
1082    opts.update({
1083        'data'      : data,
1084        'name'      : [name, name2],
1085        'def'       : [model_info, model_info2],
1086        'count'     : [n1, n2],
1087        'presets'   : [presets, presets2],
1088        'pars'      : [pars, pars2],
1089        'engines'   : [base, comp],
1090    })
1091    # pylint: enable=bad-whitespace
1092
1093    return opts
1094
1095def show_docs(opts):
1096    # type: (Dict[str, Any]) -> None
1097    """
1098    show html docs for the model
1099    """
1100    import wx  # type: ignore
1101    from .generate import view_html_from_info
1102    app = wx.App() if wx.GetApp() is None else None
1103    view_html_from_info(opts['def'][0])
1104    if app: app.MainLoop()
1105
1106
1107def explore(opts):
1108    # type: (Dict[str, Any]) -> None
1109    """
1110    explore the model using the bumps gui.
1111    """
1112    import wx  # type: ignore
1113    from bumps.names import FitProblem  # type: ignore
1114    from bumps.gui.app_frame import AppFrame  # type: ignore
1115
1116    is_mac = "cocoa" in wx.version()
1117    # Create an app if not running embedded
1118    app = wx.App() if wx.GetApp() is None else None
1119    problem = FitProblem(Explore(opts))
1120    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1121    if not is_mac: frame.Show()
1122    frame.panel.set_model(model=problem)
1123    frame.panel.Layout()
1124    frame.panel.aui.Split(0, wx.TOP)
1125    if is_mac: frame.Show()
1126    # If running withing an app, start the main loop
1127    if app: app.MainLoop()
1128
1129class Explore(object):
1130    """
1131    Bumps wrapper for a SAS model comparison.
1132
1133    The resulting object can be used as a Bumps fit problem so that
1134    parameters can be adjusted in the GUI, with plots updated on the fly.
1135    """
1136    def __init__(self, opts):
1137        # type: (Dict[str, Any]) -> None
1138        from bumps.cli import config_matplotlib  # type: ignore
1139        from . import bumps_model
1140        config_matplotlib()
1141        self.opts = opts
1142        model_info = opts['def'][0]
1143        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0])
1144        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1145        if not opts['is2d']:
1146            for p in model_info.parameters.user_parameters(is2d=False):
1147                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1148                    k = p.name+ext
1149                    v = pars.get(k, None)
1150                    if v is not None:
1151                        v.range(*parameter_range(k, v.value))
1152        else:
1153            for k, v in pars.items():
1154                v.range(*parameter_range(k, v.value))
1155
1156        self.pars = pars
1157        self.pd_types = pd_types
1158        self.limits = None
1159
1160    def numpoints(self):
1161        # type: () -> int
1162        """
1163        Return the number of points.
1164        """
1165        return len(self.pars) + 1  # so dof is 1
1166
1167    def parameters(self):
1168        # type: () -> Any   # Dict/List hierarchy of parameters
1169        """
1170        Return a dictionary of parameters.
1171        """
1172        return self.pars
1173
1174    def nllf(self):
1175        # type: () -> float
1176        """
1177        Return cost.
1178        """
1179        # pylint: disable=no-self-use
1180        return 0.  # No nllf
1181
1182    def plot(self, view='log'):
1183        # type: (str) -> None
1184        """
1185        Plot the data and residuals.
1186        """
1187        pars = dict((k, v.value) for k, v in self.pars.items())
1188        pars.update(self.pd_types)
1189        self.opts['pars'][0] = pars
1190        self.opts['pars'][1] = pars
1191        limits = compare(self.opts, limits=self.limits)
1192        if self.limits is None:
1193            vmin, vmax = limits
1194            self.limits = vmax*1e-7, 1.3*vmax
1195
1196
1197def main(*argv):
1198    # type: (*str) -> None
1199    """
1200    Main program.
1201    """
1202    opts = parse_opts(argv)
1203    if opts is not None:
1204        if opts['html']:
1205            show_docs(opts)
1206        elif opts['explore']:
1207            explore(opts)
1208        else:
1209            compare(opts)
1210
1211if __name__ == "__main__":
1212    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.