source: sasmodels/sasmodels/compare.py @ 0b040de

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 0b040de was 0b040de, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

nicer formatting for magnetic parameters in parameter summary table

  • Property mode set to 100755
File size: 40.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model is the name of the model to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -magnetic/-nonmagnetic* suppress magnetism
76    -cutoff=1e-5* cutoff value for including a point in polydispersity
77    -pars/-nopars* prints the parameter set or not
78    -abs/-rel* plot relative or absolute error
79    -linear/-log*/-q4 intensity scaling
80    -hist/-nohist* plot histogram of relative error
81    -res=0 sets the resolution width dQ/Q if calculating with resolution
82    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
83    -edit starts the parameter explorer
84    -default/-demo* use demo vs default parameters
85    -html shows the model docs instead of running the model
86    -title="note" adds note to the plot title, after the model name
87
88Any two calculation engines can be selected for comparison:
89
90    -single/-double/-half/-fast sets an OpenCL calculation engine
91    -single!/-double!/-quad! sets an OpenMP calculation engine
92    -sasview sets the sasview calculation engine
93
94The default is -single -sasview.  Note that the interpretation of quad
95precision depends on architecture, and may vary from 64-bit to 128-bit,
96with 80-bit floats being common (1e-19 precision).
97
98Key=value pairs allow you to set specific values for the model parameters.
99"""
100
101# Update docs with command line usage string.   This is separate from the usual
102# doc string so that we can display it at run time if there is an error.
103# lin
104__doc__ = (__doc__  # pylint: disable=redefined-builtin
105           + """
106Program description
107-------------------
108
109"""
110           + USAGE)
111
112kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
113
114# list of math functions for use in evaluating parameters
115MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
116
117# CRUFT python 2.6
118if not hasattr(datetime.timedelta, 'total_seconds'):
119    def delay(dt):
120        """Return number date-time delta as number seconds"""
121        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
122else:
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.total_seconds()
126
127
128class push_seed(object):
129    """
130    Set the seed value for the random number generator.
131
132    When used in a with statement, the random number generator state is
133    restored after the with statement is complete.
134
135    :Parameters:
136
137    *seed* : int or array_like, optional
138        Seed for RandomState
139
140    :Example:
141
142    Seed can be used directly to set the seed::
143
144        >>> from numpy.random import randint
145        >>> push_seed(24)
146        <...push_seed object at...>
147        >>> print(randint(0,1000000,3))
148        [242082    899 211136]
149
150    Seed can also be used in a with statement, which sets the random
151    number generator state for the enclosed computations and restores
152    it to the previous state on completion::
153
154        >>> with push_seed(24):
155        ...    print(randint(0,1000000,3))
156        [242082    899 211136]
157
158    Using nested contexts, we can demonstrate that state is indeed
159    restored after the block completes::
160
161        >>> with push_seed(24):
162        ...    print(randint(0,1000000))
163        ...    with push_seed(24):
164        ...        print(randint(0,1000000,3))
165        ...    print(randint(0,1000000))
166        242082
167        [242082    899 211136]
168        899
169
170    The restore step is protected against exceptions in the block::
171
172        >>> with push_seed(24):
173        ...    print(randint(0,1000000))
174        ...    try:
175        ...        with push_seed(24):
176        ...            print(randint(0,1000000,3))
177        ...            raise Exception()
178        ...    except Exception:
179        ...        print("Exception raised")
180        ...    print(randint(0,1000000))
181        242082
182        [242082    899 211136]
183        Exception raised
184        899
185    """
186    def __init__(self, seed=None):
187        # type: (Optional[int]) -> None
188        self._state = np.random.get_state()
189        np.random.seed(seed)
190
191    def __enter__(self):
192        # type: () -> None
193        pass
194
195    def __exit__(self, exc_type, exc_value, traceback):
196        # type: (Any, BaseException, Any) -> None
197        # TODO: better typing for __exit__ method
198        np.random.set_state(self._state)
199
200def tic():
201    # type: () -> Callable[[], float]
202    """
203    Timer function.
204
205    Use "toc=tic()" to start the clock and "toc()" to measure
206    a time interval.
207    """
208    then = datetime.datetime.now()
209    return lambda: delay(datetime.datetime.now() - then)
210
211
212def set_beam_stop(data, radius, outer=None):
213    # type: (Data, float, float) -> None
214    """
215    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
216
217    Note: this function does not require sasview
218    """
219    if hasattr(data, 'qx_data'):
220        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
221        data.mask = (q < radius)
222        if outer is not None:
223            data.mask |= (q >= outer)
224    else:
225        data.mask = (data.x < radius)
226        if outer is not None:
227            data.mask |= (data.x >= outer)
228
229
230def parameter_range(p, v):
231    # type: (str, float) -> Tuple[float, float]
232    """
233    Choose a parameter range based on parameter name and initial value.
234    """
235    # process the polydispersity options
236    if p.endswith('_pd_n'):
237        return 0., 100.
238    elif p.endswith('_pd_nsigma'):
239        return 0., 5.
240    elif p.endswith('_pd_type'):
241        raise ValueError("Cannot return a range for a string value")
242    elif any(s in p for s in ('theta', 'phi', 'psi')):
243        # orientation in [-180,180], orientation pd in [0,45]
244        if p.endswith('_pd'):
245            return 0., 45.
246        else:
247            return -180., 180.
248    elif p.endswith('_pd'):
249        return 0., 1.
250    elif 'sld' in p:
251        return -0.5, 10.
252    elif p == 'background':
253        return 0., 10.
254    elif p == 'scale':
255        return 0., 1.e3
256    elif v < 0.:
257        return 2.*v, -2.*v
258    else:
259        return 0., (2.*v if v > 0. else 1.)
260
261
262def _randomize_one(model_info, p, v):
263    # type: (ModelInfo, str, float) -> float
264    # type: (ModelInfo, str, str) -> str
265    """
266    Randomize a single parameter.
267    """
268    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
269        return v
270
271    # Find the parameter definition
272    for par in model_info.parameters.call_parameters:
273        if par.name == p:
274            break
275    else:
276        raise ValueError("unknown parameter %r"%p)
277
278    if len(par.limits) > 2:  # choice list
279        return np.random.randint(len(par.limits))
280
281    limits = par.limits
282    if not np.isfinite(limits).all():
283        low, high = parameter_range(p, v)
284        limits = (max(limits[0], low), min(limits[1], high))
285
286    return np.random.uniform(*limits)
287
288
289def randomize_pars(model_info, pars, seed=None):
290    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
291    """
292    Generate random values for all of the parameters.
293
294    Valid ranges for the random number generator are guessed from the name of
295    the parameter; this will not account for constraints such as cap radius
296    greater than cylinder radius in the capped_cylinder model, so
297    :func:`constrain_pars` needs to be called afterward..
298    """
299    with push_seed(seed):
300        # Note: the sort guarantees order `of calls to random number generator
301        random_pars = dict((p, _randomize_one(model_info, p, v))
302                           for p, v in sorted(pars.items()))
303    return random_pars
304
305def constrain_pars(model_info, pars):
306    # type: (ModelInfo, ParameterSet) -> None
307    """
308    Restrict parameters to valid values.
309
310    This includes model specific code for models such as capped_cylinder
311    which need to support within model constraints (cap radius more than
312    cylinder radius in this case).
313
314    Warning: this updates the *pars* dictionary in place.
315    """
316    name = model_info.id
317    # if it is a product model, then just look at the form factor since
318    # none of the structure factors need any constraints.
319    if '*' in name:
320        name = name.split('*')[0]
321
322    if name == 'capped_cylinder' and pars['radius_cap'] < pars['radius']:
323        pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
324    if name == 'barbell' and pars['radius_bell'] < pars['radius']:
325        pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
326
327    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
328    if name == 'guinier':
329        #q_max = 0.2  # mid q maximum
330        q_max = 1.0  # high q maximum
331        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
332        pars['rg'] = min(pars['rg'], rg_max)
333
334    if name == 'rpa':
335        # Make sure phi sums to 1.0
336        if pars['case_num'] < 2:
337            pars['Phi1'] = 0.
338            pars['Phi2'] = 0.
339        elif pars['case_num'] < 5:
340            pars['Phi1'] = 0.
341        total = sum(pars['Phi'+c] for c in '1234')
342        for c in '1234':
343            pars['Phi'+c] /= total
344
345def parlist(model_info, pars, is2d):
346    # type: (ModelInfo, ParameterSet, bool) -> str
347    """
348    Format the parameter list for printing.
349    """
350    lines = []
351    parameters = model_info.parameters
352    magnetic = False
353    for p in parameters.user_parameters(pars, is2d):
354        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
355            continue
356        if p.id.startswith('up:') and not magnetic:
357            continue
358        fields = dict(
359            value=pars.get(p.id, p.default),
360            pd=pars.get(p.id+"_pd", 0.),
361            n=int(pars.get(p.id+"_pd_n", 0)),
362            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
363            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
364            relative_pd=p.relative_pd,
365            M0=pars.get('M0:'+p.id, 0.),
366            mphi=pars.get('mphi:'+p.id, 0.),
367            mtheta=pars.get('mtheta:'+p.id, 0.),
368        )
369        lines.append(_format_par(p.name, **fields))
370        magnetic = magnetic or fields['M0'] != 0.
371    return "\n".join(lines)
372
373    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
374
375def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
376                relative_pd=False, M0=0., mphi=0., mtheta=0.):
377    # type: (str, float, float, int, float, str) -> str
378    line = "%s: %g"%(name, value)
379    if pd != 0.  and n != 0:
380        if relative_pd:
381            pd *= value
382        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
383                % (pd, n, nsigma, nsigma, pdtype)
384    if M0 != 0.:
385        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
386    return line
387
388def suppress_pd(pars):
389    # type: (ParameterSet) -> ParameterSet
390    """
391    Suppress theta_pd for now until the normalization is resolved.
392
393    May also suppress complete polydispersity of the model to test
394    models more quickly.
395    """
396    pars = pars.copy()
397    for p in pars:
398        if p.endswith("_pd_n"): pars[p] = 0
399    return pars
400
401def suppress_magnetism(pars):
402    # type: (ParameterSet) -> ParameterSet
403    """
404    Suppress theta_pd for now until the normalization is resolved.
405
406    May also suppress complete polydispersity of the model to test
407    models more quickly.
408    """
409    pars = pars.copy()
410    for p in pars:
411        if p.startswith("M0:"): pars[p] = 0
412    return pars
413
414def eval_sasview(model_info, data):
415    # type: (Modelinfo, Data) -> Calculator
416    """
417    Return a model calculator using the pre-4.0 SasView models.
418    """
419    # importing sas here so that the error message will be that sas failed to
420    # import rather than the more obscure smear_selection not imported error
421    import sas
422    import sas.models
423    from sas.models.qsmearing import smear_selection
424    from sas.models.MultiplicationModel import MultiplicationModel
425    from sas.models.dispersion_models import models as dispersers
426
427    def get_model_class(name):
428        # type: (str) -> "sas.models.BaseComponent"
429        #print("new",sorted(_pars.items()))
430        __import__('sas.models.' + name)
431        ModelClass = getattr(getattr(sas.models, name, None), name, None)
432        if ModelClass is None:
433            raise ValueError("could not find model %r in sas.models"%name)
434        return ModelClass
435
436    # WARNING: ugly hack when handling model!
437    # Sasview models with multiplicity need to be created with the target
438    # multiplicity, so we cannot create the target model ahead of time for
439    # for multiplicity models.  Instead we store the model in a list and
440    # update the first element of that list with the new multiplicity model
441    # every time we evaluate.
442
443    # grab the sasview model, or create it if it is a product model
444    if model_info.composition:
445        composition_type, parts = model_info.composition
446        if composition_type == 'product':
447            P, S = [get_model_class(revert_name(p))() for p in parts]
448            model = [MultiplicationModel(P, S)]
449        else:
450            raise ValueError("sasview mixture models not supported by compare")
451    else:
452        old_name = revert_name(model_info)
453        if old_name is None:
454            raise ValueError("model %r does not exist in old sasview"
455                            % model_info.id)
456        ModelClass = get_model_class(old_name)
457        model = [ModelClass()]
458    model[0].disperser_handles = {}
459
460    # build a smearer with which to call the model, if necessary
461    smearer = smear_selection(data, model=model)
462    if hasattr(data, 'qx_data'):
463        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
464        index = ((~data.mask) & (~np.isnan(data.data))
465                 & (q >= data.qmin) & (q <= data.qmax))
466        if smearer is not None:
467            smearer.model = model  # because smear_selection has a bug
468            smearer.accuracy = data.accuracy
469            smearer.set_index(index)
470            def _call_smearer():
471                smearer.model = model[0]
472                return smearer.get_value()
473            theory = _call_smearer
474        else:
475            theory = lambda: model[0].evalDistribution([data.qx_data[index],
476                                                        data.qy_data[index]])
477    elif smearer is not None:
478        theory = lambda: smearer(model[0].evalDistribution(data.x))
479    else:
480        theory = lambda: model[0].evalDistribution(data.x)
481
482    def calculator(**pars):
483        # type: (float, ...) -> np.ndarray
484        """
485        Sasview calculator for model.
486        """
487        oldpars = revert_pars(model_info, pars)
488        # For multiplicity models, create a model with the correct multiplicity
489        control = oldpars.pop("CONTROL", None)
490        if control is not None:
491            # sphericalSLD has one fewer multiplicity.  This update should
492            # happen in revert_pars, but it hasn't been called yet.
493            model[0] = ModelClass(control)
494        # paying for parameter conversion each time to keep life simple, if not fast
495        for k, v in oldpars.items():
496            if k.endswith('.type'):
497                par = k[:-5]
498                if v == 'gaussian': continue
499                cls = dispersers[v if v != 'rectangle' else 'rectangula']
500                handle = cls()
501                model[0].disperser_handles[par] = handle
502                try:
503                    model[0].set_dispersion(par, handle)
504                except Exception:
505                    exception.annotate_exception("while setting %s to %r"
506                                                 %(par, v))
507                    raise
508
509
510        #print("sasview pars",oldpars)
511        for k, v in oldpars.items():
512            name_attr = k.split('.')  # polydispersity components
513            if len(name_attr) == 2:
514                par, disp_par = name_attr
515                model[0].dispersion[par][disp_par] = v
516            else:
517                model[0].setParam(k, v)
518        return theory()
519
520    calculator.engine = "sasview"
521    return calculator
522
523DTYPE_MAP = {
524    'half': '16',
525    'fast': 'fast',
526    'single': '32',
527    'double': '64',
528    'quad': '128',
529    'f16': '16',
530    'f32': '32',
531    'f64': '64',
532    'longdouble': '128',
533}
534def eval_opencl(model_info, data, dtype='single', cutoff=0.):
535    # type: (ModelInfo, Data, str, float) -> Calculator
536    """
537    Return a model calculator using the OpenCL calculation engine.
538    """
539    if not core.HAVE_OPENCL:
540        raise RuntimeError("OpenCL not available")
541    model = core.build_model(model_info, dtype=dtype, platform="ocl")
542    calculator = DirectModel(data, model, cutoff=cutoff)
543    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
544    return calculator
545
546def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
547    # type: (ModelInfo, Data, str, float) -> Calculator
548    """
549    Return a model calculator using the DLL calculation engine.
550    """
551    model = core.build_model(model_info, dtype=dtype, platform="dll")
552    calculator = DirectModel(data, model, cutoff=cutoff)
553    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
554    return calculator
555
556def time_calculation(calculator, pars, evals=1):
557    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
558    """
559    Compute the average calculation time over N evaluations.
560
561    An additional call is generated without polydispersity in order to
562    initialize the calculation engine, and make the average more stable.
563    """
564    # initialize the code so time is more accurate
565    if evals > 1:
566        calculator(**suppress_pd(pars))
567    toc = tic()
568    # make sure there is at least one eval
569    value = calculator(**pars)
570    for _ in range(evals-1):
571        value = calculator(**pars)
572    average_time = toc()*1000. / evals
573    #print("I(q)",value)
574    return value, average_time
575
576def make_data(opts):
577    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
578    """
579    Generate an empty dataset, used with the model to set Q points
580    and resolution.
581
582    *opts* contains the options, with 'qmax', 'nq', 'res',
583    'accuracy', 'is2d' and 'view' parsed from the command line.
584    """
585    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
586    if opts['is2d']:
587        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
588        data = empty_data2D(q, resolution=res)
589        data.accuracy = opts['accuracy']
590        set_beam_stop(data, 0.0004)
591        index = ~data.mask
592    else:
593        if opts['view'] == 'log' and not opts['zero']:
594            qmax = math.log10(qmax)
595            q = np.logspace(qmax-3, qmax, nq)
596        else:
597            q = np.linspace(0.001*qmax, qmax, nq)
598        if opts['zero']:
599            q = np.hstack((0, q))
600        data = empty_data1D(q, resolution=res)
601        index = slice(None, None)
602    return data, index
603
604def make_engine(model_info, data, dtype, cutoff):
605    # type: (ModelInfo, Data, str, float) -> Calculator
606    """
607    Generate the appropriate calculation engine for the given datatype.
608
609    Datatypes with '!' appended are evaluated using external C DLLs rather
610    than OpenCL.
611    """
612    if dtype == 'sasview':
613        return eval_sasview(model_info, data)
614    elif dtype.endswith('!'):
615        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
616    else:
617        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
618
619def _show_invalid(data, theory):
620    # type: (Data, np.ma.ndarray) -> None
621    """
622    Display a list of the non-finite values in theory.
623    """
624    if not theory.mask.any():
625        return
626
627    if hasattr(data, 'x'):
628        bad = zip(data.x[theory.mask], theory[theory.mask])
629        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
630
631
632def compare(opts, limits=None):
633    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
634    """
635    Preform a comparison using options from the command line.
636
637    *limits* are the limits on the values to use, either to set the y-axis
638    for 1D or to set the colormap scale for 2D.  If None, then they are
639    inferred from the data and returned. When exploring using Bumps,
640    the limits are set when the model is initially called, and maintained
641    as the values are adjusted, making it easier to see the effects of the
642    parameters.
643    """
644    n_base, n_comp = opts['count']
645    pars, pars2 = opts['pars']
646    data = opts['data']
647
648    # silence the linter
649    base = opts['engines'][0] if n_base else None
650    comp = opts['engines'][1] if n_comp else None
651    base_time = comp_time = None
652    base_value = comp_value = resid = relerr = None
653
654    # Base calculation
655    if n_base > 0:
656        try:
657            base_raw, base_time = time_calculation(base, pars, n_base)
658            base_value = np.ma.masked_invalid(base_raw)
659            print("%s t=%.2f ms, intensity=%.0f"
660                  % (base.engine, base_time, base_value.sum()))
661            _show_invalid(data, base_value)
662        except ImportError:
663            traceback.print_exc()
664            n_base = 0
665
666    # Comparison calculation
667    if n_comp > 0:
668        try:
669            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
670            comp_value = np.ma.masked_invalid(comp_raw)
671            print("%s t=%.2f ms, intensity=%.0f"
672                  % (comp.engine, comp_time, comp_value.sum()))
673            _show_invalid(data, comp_value)
674        except ImportError:
675            traceback.print_exc()
676            n_comp = 0
677
678    # Compare, but only if computing both forms
679    if n_base > 0 and n_comp > 0:
680        resid = (base_value - comp_value)
681        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
682        _print_stats("|%s-%s|"
683                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
684                     resid)
685        _print_stats("|(%s-%s)/%s|"
686                     % (base.engine, comp.engine, comp.engine),
687                     relerr)
688
689    # Plot if requested
690    if not opts['plot'] and not opts['explore']: return
691    view = opts['view']
692    import matplotlib.pyplot as plt
693    if limits is None:
694        vmin, vmax = np.Inf, -np.Inf
695        if n_base > 0:
696            vmin = min(vmin, base_value.min())
697            vmax = max(vmax, base_value.max())
698        if n_comp > 0:
699            vmin = min(vmin, comp_value.min())
700            vmax = max(vmax, comp_value.max())
701        limits = vmin, vmax
702
703    if n_base > 0:
704        if n_comp > 0: plt.subplot(131)
705        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
706        plt.title("%s t=%.2f ms"%(base.engine, base_time))
707        #cbar_title = "log I"
708    if n_comp > 0:
709        if n_base > 0: plt.subplot(132)
710        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
711        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
712        #cbar_title = "log I"
713    if n_comp > 0 and n_base > 0:
714        if not opts['is2d']:
715            plot_theory(data, base_value, view=view, use_data=False, limits=limits)
716        plt.subplot(133)
717        if not opts['rel_err']:
718            err, errstr, errview = resid, "abs err", "linear"
719        else:
720            err, errstr, errview = abs(relerr), "rel err", "log"
721        #sorted = np.sort(err.flatten())
722        #cutoff = sorted[int(sorted.size*0.95)]
723        #err[err>cutoff] = cutoff
724        #err,errstr = base/comp,"ratio"
725        plot_theory(data, None, resid=err, view=errview, use_data=False)
726        if view == 'linear':
727            plt.xscale('linear')
728        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
729        #cbar_title = errstr if errview=="linear" else "log "+errstr
730    #if is2D:
731    #    h = plt.colorbar()
732    #    h.ax.set_title(cbar_title)
733    fig = plt.gcf()
734    extra_title = ' '+opts['title'] if opts['title'] else ''
735    fig.suptitle(":".join(opts['name']) + extra_title)
736
737    if n_comp > 0 and n_base > 0 and opts['show_hist']:
738        plt.figure()
739        v = relerr
740        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
741        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
742        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
743                   % (base.engine, comp.engine, comp.engine))
744        plt.ylabel('P(err)')
745        plt.title('Distribution of relative error between calculation engines')
746
747    if not opts['explore']:
748        plt.show()
749
750    return limits
751
752def _print_stats(label, err):
753    # type: (str, np.ma.ndarray) -> None
754    # work with trimmed data, not the full set
755    sorted_err = np.sort(abs(err.compressed()))
756    p50 = int((len(sorted_err)-1)*0.50)
757    p98 = int((len(sorted_err)-1)*0.98)
758    data = [
759        "max:%.3e"%sorted_err[-1],
760        "median:%.3e"%sorted_err[p50],
761        "98%%:%.3e"%sorted_err[p98],
762        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
763        "zero-offset:%+.3e"%np.mean(sorted_err),
764        ]
765    print(label+"  "+"  ".join(data))
766
767
768
769# ===========================================================================
770#
771NAME_OPTIONS = set([
772    'plot', 'noplot',
773    'half', 'fast', 'single', 'double',
774    'single!', 'double!', 'quad!', 'sasview',
775    'lowq', 'midq', 'highq', 'exq', 'zero',
776    '2d', '1d',
777    'preset', 'random',
778    'poly', 'mono',
779    'magnetic', 'nonmagnetic',
780    'nopars', 'pars',
781    'rel', 'abs',
782    'linear', 'log', 'q4',
783    'hist', 'nohist',
784    'edit', 'html',
785    'demo', 'default',
786    ])
787VALUE_OPTIONS = [
788    # Note: random is both a name option and a value option
789    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
790    ]
791
792def columnize(items, indent="", width=79):
793    # type: (List[str], str, int) -> str
794    """
795    Format a list of strings into columns.
796
797    Returns a string with carriage returns ready for printing.
798    """
799    column_width = max(len(w) for w in items) + 1
800    num_columns = (width - len(indent)) // column_width
801    num_rows = len(items) // num_columns
802    items = items + [""] * (num_rows * num_columns - len(items))
803    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
804    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
805             for row in zip(*columns)]
806    output = indent + ("\n"+indent).join(lines)
807    return output
808
809
810def get_pars(model_info, use_demo=False):
811    # type: (ModelInfo, bool) -> ParameterSet
812    """
813    Extract demo parameters from the model definition.
814    """
815    # Get the default values for the parameters
816    pars = {}
817    for p in model_info.parameters.call_parameters:
818        parts = [('', p.default)]
819        if p.polydisperse:
820            parts.append(('_pd', 0.0))
821            parts.append(('_pd_n', 0))
822            parts.append(('_pd_nsigma', 3.0))
823            parts.append(('_pd_type', "gaussian"))
824        for ext, val in parts:
825            if p.length > 1:
826                dict(("%s%d%s" % (p.id, k, ext), val)
827                     for k in range(1, p.length+1))
828            else:
829                pars[p.id + ext] = val
830
831    # Plug in values given in demo
832    if use_demo:
833        pars.update(model_info.demo)
834    return pars
835
836INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
837def isnumber(str):
838    match = FLOAT_RE.match(str)
839    isfloat = (match and not str[match.end():])
840    return isfloat or INTEGER_RE.match(str)
841
842def parse_opts(argv):
843    # type: (List[str]) -> Dict[str, Any]
844    """
845    Parse command line options.
846    """
847    MODELS = core.list_models()
848    flags = [arg for arg in argv
849             if arg.startswith('-')]
850    values = [arg for arg in argv
851              if not arg.startswith('-') and '=' in arg]
852    positional_args = [arg for arg in argv
853            if not arg.startswith('-') and '=' not in arg]
854    models = "\n    ".join("%-15s"%v for v in MODELS)
855    if len(positional_args) == 0:
856        print(USAGE)
857        print("\nAvailable models:")
858        print(columnize(MODELS, indent="  "))
859        return None
860    if len(positional_args) > 3:
861        print("expected parameters: model N1 N2")
862
863    invalid = [o[1:] for o in flags
864               if o[1:] not in NAME_OPTIONS
865               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
866    if invalid:
867        print("Invalid options: %s"%(", ".join(invalid)))
868        return None
869
870    name = positional_args[0]
871    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
872    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
873
874    # pylint: disable=bad-whitespace
875    # Interpret the flags
876    opts = {
877        'plot'      : True,
878        'view'      : 'log',
879        'is2d'      : False,
880        'qmax'      : 0.05,
881        'nq'        : 128,
882        'res'       : 0.0,
883        'accuracy'  : 'Low',
884        'cutoff'    : 0.0,
885        'seed'      : -1,  # default to preset
886        'mono'      : False,
887        # Default to magnetic a magnetic moment is set on the command line
888        'magnetic'  : any(v.startswith('M0:') for v in values),
889        'show_pars' : False,
890        'show_hist' : False,
891        'rel_err'   : True,
892        'explore'   : False,
893        'use_demo'  : True,
894        'zero'      : False,
895        'html'      : False,
896        'title'     : None,
897    }
898    engines = []
899    for arg in flags:
900        if arg == '-noplot':    opts['plot'] = False
901        elif arg == '-plot':    opts['plot'] = True
902        elif arg == '-linear':  opts['view'] = 'linear'
903        elif arg == '-log':     opts['view'] = 'log'
904        elif arg == '-q4':      opts['view'] = 'q4'
905        elif arg == '-1d':      opts['is2d'] = False
906        elif arg == '-2d':      opts['is2d'] = True
907        elif arg == '-exq':     opts['qmax'] = 10.0
908        elif arg == '-highq':   opts['qmax'] = 1.0
909        elif arg == '-midq':    opts['qmax'] = 0.2
910        elif arg == '-lowq':    opts['qmax'] = 0.05
911        elif arg == '-zero':    opts['zero'] = True
912        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
913        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
914        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
915        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
916        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
917        elif arg.startswith('-title'):     opts['title'] = arg[7:]
918        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
919        elif arg == '-preset':  opts['seed'] = -1
920        elif arg == '-mono':    opts['mono'] = True
921        elif arg == '-poly':    opts['mono'] = False
922        elif arg == '-magnetic':       opts['magnetic'] = True
923        elif arg == '-nonmagnetic':    opts['magnetic'] = False
924        elif arg == '-pars':    opts['show_pars'] = True
925        elif arg == '-nopars':  opts['show_pars'] = False
926        elif arg == '-hist':    opts['show_hist'] = True
927        elif arg == '-nohist':  opts['show_hist'] = False
928        elif arg == '-rel':     opts['rel_err'] = True
929        elif arg == '-abs':     opts['rel_err'] = False
930        elif arg == '-half':    engines.append(arg[1:])
931        elif arg == '-fast':    engines.append(arg[1:])
932        elif arg == '-single':  engines.append(arg[1:])
933        elif arg == '-double':  engines.append(arg[1:])
934        elif arg == '-single!': engines.append(arg[1:])
935        elif arg == '-double!': engines.append(arg[1:])
936        elif arg == '-quad!':   engines.append(arg[1:])
937        elif arg == '-sasview': engines.append(arg[1:])
938        elif arg == '-edit':    opts['explore'] = True
939        elif arg == '-demo':    opts['use_demo'] = True
940        elif arg == '-default':    opts['use_demo'] = False
941        elif arg == '-html':    opts['html'] = True
942    # pylint: enable=bad-whitespace
943
944    if ':' in name:
945        name, name2 = name.split(':',2)
946    else:
947        name2 = name
948    try:
949        model_info = core.load_model_info(name)
950        model_info2 = core.load_model_info(name2) if name2 != name else model_info
951    except ImportError as exc:
952        print(str(exc))
953        print("Could not find model; use one of:\n    " + models)
954        return None
955
956    # Get demo parameters from model definition, or use default parameters
957    # if model does not define demo parameters
958    pars = get_pars(model_info, opts['use_demo'])
959    pars2 = get_pars(model_info2, opts['use_demo'])
960    pars2.update((k, v) for k, v in pars.items() if k in pars2)
961    # randomize parameters
962    #pars.update(set_pars)  # set value before random to control range
963    if opts['seed'] > -1:
964        pars = randomize_pars(model_info, pars, seed=opts['seed'])
965        if model_info != model_info2:
966            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
967        else:
968            pars2 = pars.copy()
969        print("Randomize using -random=%i"%opts['seed'])
970    if opts['mono']:
971        pars = suppress_pd(pars)
972        pars2 = suppress_pd(pars2)
973
974    # Fill in parameters given on the command line
975    presets = {}
976    presets2 = {}
977    for arg in values:
978        k, v = arg.split('=', 1)
979        if k not in pars and k not in pars2:
980            # extract base name without polydispersity info
981            s = set(p.split('_pd')[0] for p in pars)
982            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
983            return None
984        v1, v2 = v.split(':',2) if ':' in v else (v,v)
985        if v1 and k in pars:
986            presets[k] = float(v1) if isnumber(v1) else v1
987        if v2 and k in pars2:
988            presets2[k] = float(v2) if isnumber(v2) else v2
989
990    # Evaluate preset parameter expressions
991    context = MATH.copy()
992    context.update(pars)
993    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
994    for k, v in presets.items():
995        if not isinstance(v, float) and not k.endswith('_type'):
996            presets[k] = eval(v, context)
997    context.update(presets)
998    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
999    for k, v in presets2.items():
1000        if not isinstance(v, float) and not k.endswith('_type'):
1001            presets2[k] = eval(v, context)
1002
1003    # update parameters with presets
1004    pars.update(presets)  # set value after random to control value
1005    pars2.update(presets2)  # set value after random to control value
1006    #import pprint; pprint.pprint(model_info)
1007    constrain_pars(model_info, pars)
1008    constrain_pars(model_info2, pars2)
1009    if not opts['magnetic']:
1010        pars = suppress_magnetism(pars)
1011        pars2 = suppress_magnetism(pars2)
1012
1013    same_model = name == name2 and pars == pars
1014    if len(engines) == 0:
1015        if same_model:
1016            engines.extend(['single', 'double'])
1017        else:
1018            engines.extend(['single', 'single'])
1019    elif len(engines) == 1:
1020        if not same_model:
1021            engines.append(engines[0])
1022        elif engines[0] == 'double':
1023            engines.append('single')
1024        else:
1025            engines.append('double')
1026    elif len(engines) > 2:
1027        del engines[2:]
1028
1029    use_sasview = any(engine == 'sasview' and count > 0
1030                      for engine, count in zip(engines, [n1, n2]))
1031    if use_sasview:
1032        constrain_new_to_old(model_info, pars)
1033        constrain_new_to_old(model_info2, pars2)
1034
1035    if opts['show_pars']:
1036        if not same_model:
1037            print("==== %s ====="%model_info.name)
1038            print(str(parlist(model_info, pars, opts['is2d'])))
1039            print("==== %s ====="%model_info2.name)
1040            print(str(parlist(model_info2, pars2, opts['is2d'])))
1041        else:
1042            print(str(parlist(model_info, pars, opts['is2d'])))
1043
1044    # Create the computational engines
1045    data, _ = make_data(opts)
1046    if n1:
1047        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1048    else:
1049        base = None
1050    if n2:
1051        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1052    else:
1053        comp = None
1054
1055    # pylint: disable=bad-whitespace
1056    # Remember it all
1057    opts.update({
1058        'data'      : data,
1059        'name'      : [name, name2],
1060        'def'       : [model_info, model_info2],
1061        'count'     : [n1, n2],
1062        'presets'   : [presets, presets2],
1063        'pars'      : [pars, pars2],
1064        'engines'   : [base, comp],
1065    })
1066    # pylint: enable=bad-whitespace
1067
1068    return opts
1069
1070def show_docs(opts):
1071    # type: (Dict[str, Any]) -> None
1072    """
1073    show html docs for the model
1074    """
1075    import wx  # type: ignore
1076    from .generate import view_html_from_info
1077    app = wx.App() if wx.GetApp() is None else None
1078    view_html_from_info(opts['def'][0])
1079    if app: app.MainLoop()
1080
1081
1082def explore(opts):
1083    # type: (Dict[str, Any]) -> None
1084    """
1085    explore the model using the bumps gui.
1086    """
1087    import wx  # type: ignore
1088    from bumps.names import FitProblem  # type: ignore
1089    from bumps.gui.app_frame import AppFrame  # type: ignore
1090
1091    is_mac = "cocoa" in wx.version()
1092    # Create an app if not running embedded
1093    app = wx.App() if wx.GetApp() is None else None
1094    problem = FitProblem(Explore(opts))
1095    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1096    if not is_mac: frame.Show()
1097    frame.panel.set_model(model=problem)
1098    frame.panel.Layout()
1099    frame.panel.aui.Split(0, wx.TOP)
1100    if is_mac: frame.Show()
1101    # If running withing an app, start the main loop
1102    if app: app.MainLoop()
1103
1104class Explore(object):
1105    """
1106    Bumps wrapper for a SAS model comparison.
1107
1108    The resulting object can be used as a Bumps fit problem so that
1109    parameters can be adjusted in the GUI, with plots updated on the fly.
1110    """
1111    def __init__(self, opts):
1112        # type: (Dict[str, Any]) -> None
1113        from bumps.cli import config_matplotlib  # type: ignore
1114        from . import bumps_model
1115        config_matplotlib()
1116        self.opts = opts
1117        model_info = opts['def'][0]
1118        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0])
1119        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1120        if not opts['is2d']:
1121            for p in model_info.parameters.user_parameters(is2d=False):
1122                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1123                    k = p.name+ext
1124                    v = pars.get(k, None)
1125                    if v is not None:
1126                        v.range(*parameter_range(k, v.value))
1127        else:
1128            for k, v in pars.items():
1129                v.range(*parameter_range(k, v.value))
1130
1131        self.pars = pars
1132        self.pd_types = pd_types
1133        self.limits = None
1134
1135    def numpoints(self):
1136        # type: () -> int
1137        """
1138        Return the number of points.
1139        """
1140        return len(self.pars) + 1  # so dof is 1
1141
1142    def parameters(self):
1143        # type: () -> Any   # Dict/List hierarchy of parameters
1144        """
1145        Return a dictionary of parameters.
1146        """
1147        return self.pars
1148
1149    def nllf(self):
1150        # type: () -> float
1151        """
1152        Return cost.
1153        """
1154        # pylint: disable=no-self-use
1155        return 0.  # No nllf
1156
1157    def plot(self, view='log'):
1158        # type: (str) -> None
1159        """
1160        Plot the data and residuals.
1161        """
1162        pars = dict((k, v.value) for k, v in self.pars.items())
1163        pars.update(self.pd_types)
1164        self.opts['pars'][0] = pars
1165        self.opts['pars'][1] = pars
1166        limits = compare(self.opts, limits=self.limits)
1167        if self.limits is None:
1168            vmin, vmax = limits
1169            self.limits = vmax*1e-7, 1.3*vmax
1170
1171
1172def main(*argv):
1173    # type: (*str) -> None
1174    """
1175    Main program.
1176    """
1177    opts = parse_opts(argv)
1178    if opts is not None:
1179        if opts['html']:
1180            show_docs(opts)
1181        elif opts['explore']:
1182            explore(opts)
1183        else:
1184            compare(opts)
1185
1186if __name__ == "__main__":
1187    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.