source: sasmodels/sasmodels/compare.py @ 050c2c8

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 050c2c8 was 050c2c8, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

sasview compare with alternate polydispersity functions

  • Property mode set to 100755
File size: 35.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from . import weights
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_name, revert_pars, constrain_new_to_old
44
45try:
46    from typing import Optional, Dict, Any, Callable, Tuple
47except:
48    pass
49else:
50    from .modelinfo import ModelInfo, Parameter, ParameterSet
51    from .data import Data
52    Calculator = Callable[[float], np.ndarray]
53
54USAGE = """
55usage: compare.py model N1 N2 [options...] [key=val]
56
57Compare the speed and value for a model between the SasView original and the
58sasmodels rewrite.
59
60model is the name of the model to compare (see below).
61N1 is the number of times to run sasmodels (default=1).
62N2 is the number times to run sasview (default=1).
63
64Options (* for default):
65
66    -plot*/-noplot plots or suppress the plot of the model
67    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
68    -nq=128 sets the number of Q points in the data set
69    -zero indicates that q=0 should be included
70    -1d*/-2d computes 1d or 2d data
71    -preset*/-random[=seed] preset or random parameters
72    -mono/-poly* force monodisperse/polydisperse
73    -cutoff=1e-5* cutoff value for including a point in polydispersity
74    -pars/-nopars* prints the parameter set or not
75    -abs/-rel* plot relative or absolute error
76    -linear/-log*/-q4 intensity scaling
77    -hist/-nohist* plot histogram of relative error
78    -res=0 sets the resolution width dQ/Q if calculating with resolution
79    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
80    -edit starts the parameter explorer
81    -default/-demo* use demo vs default parameters
82
83Any two calculation engines can be selected for comparison:
84
85    -single/-double/-half/-fast sets an OpenCL calculation engine
86    -single!/-double!/-quad! sets an OpenMP calculation engine
87    -sasview sets the sasview calculation engine
88
89The default is -single -sasview.  Note that the interpretation of quad
90precision depends on architecture, and may vary from 64-bit to 128-bit,
91with 80-bit floats being common (1e-19 precision).
92
93Key=value pairs allow you to set specific values for the model parameters.
94"""
95
96# Update docs with command line usage string.   This is separate from the usual
97# doc string so that we can display it at run time if there is an error.
98# lin
99__doc__ = (__doc__  # pylint: disable=redefined-builtin
100           + """
101Program description
102-------------------
103
104"""
105           + USAGE)
106
107kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
108
109# CRUFT python 2.6
110if not hasattr(datetime.timedelta, 'total_seconds'):
111    def delay(dt):
112        """Return number date-time delta as number seconds"""
113        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
114else:
115    def delay(dt):
116        """Return number date-time delta as number seconds"""
117        return dt.total_seconds()
118
119
120class push_seed(object):
121    """
122    Set the seed value for the random number generator.
123
124    When used in a with statement, the random number generator state is
125    restored after the with statement is complete.
126
127    :Parameters:
128
129    *seed* : int or array_like, optional
130        Seed for RandomState
131
132    :Example:
133
134    Seed can be used directly to set the seed::
135
136        >>> from numpy.random import randint
137        >>> push_seed(24)
138        <...push_seed object at...>
139        >>> print(randint(0,1000000,3))
140        [242082    899 211136]
141
142    Seed can also be used in a with statement, which sets the random
143    number generator state for the enclosed computations and restores
144    it to the previous state on completion::
145
146        >>> with push_seed(24):
147        ...    print(randint(0,1000000,3))
148        [242082    899 211136]
149
150    Using nested contexts, we can demonstrate that state is indeed
151    restored after the block completes::
152
153        >>> with push_seed(24):
154        ...    print(randint(0,1000000))
155        ...    with push_seed(24):
156        ...        print(randint(0,1000000,3))
157        ...    print(randint(0,1000000))
158        242082
159        [242082    899 211136]
160        899
161
162    The restore step is protected against exceptions in the block::
163
164        >>> with push_seed(24):
165        ...    print(randint(0,1000000))
166        ...    try:
167        ...        with push_seed(24):
168        ...            print(randint(0,1000000,3))
169        ...            raise Exception()
170        ...    except Exception:
171        ...        print("Exception raised")
172        ...    print(randint(0,1000000))
173        242082
174        [242082    899 211136]
175        Exception raised
176        899
177    """
178    def __init__(self, seed=None):
179        # type: (Optional[int]) -> None
180        self._state = np.random.get_state()
181        np.random.seed(seed)
182
183    def __enter__(self):
184        # type: () -> None
185        pass
186
187    def __exit__(self, exc_type, exc_value, traceback):
188        # type: (Any, BaseException, Any) -> None
189        # TODO: better typing for __exit__ method
190        np.random.set_state(self._state)
191
192def tic():
193    # type: () -> Callable[[], float]
194    """
195    Timer function.
196
197    Use "toc=tic()" to start the clock and "toc()" to measure
198    a time interval.
199    """
200    then = datetime.datetime.now()
201    return lambda: delay(datetime.datetime.now() - then)
202
203
204def set_beam_stop(data, radius, outer=None):
205    # type: (Data, float, float) -> None
206    """
207    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
208
209    Note: this function does not require sasview
210    """
211    if hasattr(data, 'qx_data'):
212        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
213        data.mask = (q < radius)
214        if outer is not None:
215            data.mask |= (q >= outer)
216    else:
217        data.mask = (data.x < radius)
218        if outer is not None:
219            data.mask |= (data.x >= outer)
220
221
222def parameter_range(p, v):
223    # type: (str, float) -> Tuple[float, float]
224    """
225    Choose a parameter range based on parameter name and initial value.
226    """
227    # process the polydispersity options
228    if p.endswith('_pd_n'):
229        return 0., 100.
230    elif p.endswith('_pd_nsigma'):
231        return 0., 5.
232    elif p.endswith('_pd_type'):
233        raise ValueError("Cannot return a range for a string value")
234    elif any(s in p for s in ('theta', 'phi', 'psi')):
235        # orientation in [-180,180], orientation pd in [0,45]
236        if p.endswith('_pd'):
237            return 0., 45.
238        else:
239            return -180., 180.
240    elif p.endswith('_pd'):
241        return 0., 1.
242    elif 'sld' in p:
243        return -0.5, 10.
244    elif p == 'background':
245        return 0., 10.
246    elif p == 'scale':
247        return 0., 1.e3
248    elif v < 0.:
249        return 2.*v, -2.*v
250    else:
251        return 0., (2.*v if v > 0. else 1.)
252
253
254def _randomize_one(model_info, p, v):
255    # type: (ModelInfo, str, float) -> float
256    # type: (ModelInfo, str, str) -> str
257    """
258    Randomize a single parameter.
259    """
260    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
261        return v
262
263    # Find the parameter definition
264    for par in model_info.parameters.call_parameters:
265        if par.name == p:
266            break
267    else:
268        raise ValueError("unknown parameter %r"%p)
269
270    if len(par.limits) > 2:  # choice list
271        return np.random.randint(len(par.limits))
272
273    limits = par.limits
274    if not np.isfinite(limits).all():
275        low, high = parameter_range(p, v)
276        limits = (max(limits[0], low), min(limits[1], high))
277
278    return np.random.uniform(*limits)
279
280
281def randomize_pars(model_info, pars, seed=None):
282    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
283    """
284    Generate random values for all of the parameters.
285
286    Valid ranges for the random number generator are guessed from the name of
287    the parameter; this will not account for constraints such as cap radius
288    greater than cylinder radius in the capped_cylinder model, so
289    :func:`constrain_pars` needs to be called afterward..
290    """
291    with push_seed(seed):
292        # Note: the sort guarantees order `of calls to random number generator
293        random_pars = dict((p, _randomize_one(model_info, p, v))
294                           for p, v in sorted(pars.items()))
295    return random_pars
296
297def constrain_pars(model_info, pars):
298    # type: (ModelInfo, ParameterSet) -> None
299    """
300    Restrict parameters to valid values.
301
302    This includes model specific code for models such as capped_cylinder
303    which need to support within model constraints (cap radius more than
304    cylinder radius in this case).
305
306    Warning: this updates the *pars* dictionary in place.
307    """
308    name = model_info.id
309    # if it is a product model, then just look at the form factor since
310    # none of the structure factors need any constraints.
311    if '*' in name:
312        name = name.split('*')[0]
313
314    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
315        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
316    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
317        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
318
319    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
320    if name == 'guinier':
321        #q_max = 0.2  # mid q maximum
322        q_max = 1.0  # high q maximum
323        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
324        pars['rg'] = min(pars['rg'], rg_max)
325
326    if name == 'rpa':
327        # Make sure phi sums to 1.0
328        if pars['case_num'] < 2:
329            pars['Phi1'] = 0.
330            pars['Phi2'] = 0.
331        elif pars['case_num'] < 5:
332            pars['Phi1'] = 0.
333        total = sum(pars['Phi'+c] for c in '1234')
334        for c in '1234':
335            pars['Phi'+c] /= total
336
337def parlist(model_info, pars, is2d):
338    # type: (ModelInfo, ParameterSet, bool) -> str
339    """
340    Format the parameter list for printing.
341    """
342    lines = []
343    parameters = model_info.parameters
344    for p in parameters.user_parameters(pars, is2d):
345        fields = dict(
346            value=pars.get(p.id, p.default),
347            pd=pars.get(p.id+"_pd", 0.),
348            n=int(pars.get(p.id+"_pd_n", 0)),
349            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
350            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
351            relative_pd=p.relative_pd,
352        )
353        lines.append(_format_par(p.name, **fields))
354    return "\n".join(lines)
355
356    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
357
358def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
359                relative_pd=False):
360    # type: (str, float, float, int, float, str) -> str
361    line = "%s: %g"%(name, value)
362    if pd != 0.  and n != 0:
363        if relative_pd:
364            pd *= value
365        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
366                % (pd, n, nsigma, nsigma, pdtype)
367    return line
368
369def suppress_pd(pars):
370    # type: (ParameterSet) -> ParameterSet
371    """
372    Suppress theta_pd for now until the normalization is resolved.
373
374    May also suppress complete polydispersity of the model to test
375    models more quickly.
376    """
377    pars = pars.copy()
378    for p in pars:
379        if p.endswith("_pd_n"): pars[p] = 0
380    return pars
381
382def eval_sasview(model_info, data):
383    # type: (Modelinfo, Data) -> Calculator
384    """
385    Return a model calculator using the pre-4.0 SasView models.
386    """
387    # importing sas here so that the error message will be that sas failed to
388    # import rather than the more obscure smear_selection not imported error
389    import sas
390    import sas.models
391    from sas.models.qsmearing import smear_selection
392    from sas.models.MultiplicationModel import MultiplicationModel
393    from sas.models.dispersion_models import models as dispersers
394
395    def get_model_class(name):
396        # type: (str) -> "sas.models.BaseComponent"
397        #print("new",sorted(_pars.items()))
398        __import__('sas.models.' + name)
399        ModelClass = getattr(getattr(sas.models, name, None), name, None)
400        if ModelClass is None:
401            raise ValueError("could not find model %r in sas.models"%name)
402        return ModelClass
403
404    # WARNING: ugly hack when handling model!
405    # Sasview models with multiplicity need to be created with the target
406    # multiplicity, so we cannot create the target model ahead of time for
407    # for multiplicity models.  Instead we store the model in a list and
408    # update the first element of that list with the new multiplicity model
409    # every time we evaluate.
410
411    # grab the sasview model, or create it if it is a product model
412    if model_info.composition:
413        composition_type, parts = model_info.composition
414        if composition_type == 'product':
415            P, S = [get_model_class(revert_name(p))() for p in parts]
416            model = [MultiplicationModel(P, S)]
417        else:
418            raise ValueError("sasview mixture models not supported by compare")
419    else:
420        old_name = revert_name(model_info)
421        if old_name is None:
422            raise ValueError("model %r does not exist in old sasview"
423                            % model_info.id)
424        ModelClass = get_model_class(old_name)
425        model = [ModelClass()]
426    model[0].disperser_handles = {}
427
428    # build a smearer with which to call the model, if necessary
429    smearer = smear_selection(data, model=model)
430    if hasattr(data, 'qx_data'):
431        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
432        index = ((~data.mask) & (~np.isnan(data.data))
433                 & (q >= data.qmin) & (q <= data.qmax))
434        if smearer is not None:
435            smearer.model = model  # because smear_selection has a bug
436            smearer.accuracy = data.accuracy
437            smearer.set_index(index)
438            def _call_smearer():
439                smearer.model = model[0]
440                return smearer.get_value()
441            theory = _call_smearer
442        else:
443            theory = lambda: model[0].evalDistribution([data.qx_data[index],
444                                                        data.qy_data[index]])
445    elif smearer is not None:
446        theory = lambda: smearer(model[0].evalDistribution(data.x))
447    else:
448        theory = lambda: model[0].evalDistribution(data.x)
449
450    def calculator(**pars):
451        # type: (float, ...) -> np.ndarray
452        """
453        Sasview calculator for model.
454        """
455        oldpars = revert_pars(model_info, pars)
456        # For multiplicity models, create a model with the correct multiplicity
457        control = oldpars.pop("CONTROL", None)
458        if control is not None:
459            # sphericalSLD has one fewer multiplicity.  This update should
460            # happen in revert_pars, but it hasn't been called yet.
461            model[0] = ModelClass(control)
462        # paying for parameter conversion each time to keep life simple, if not fast
463        for k, v in oldpars.items():
464            if k.endswith('.type'):
465                par = k[:-5]
466                cls = dispersers[v if v != 'rectangle' else 'rectangula']
467                handle = cls()
468                model[0].disperser_handles[par] = handle
469                model[0].set_dispersion(par, handle)
470
471        #print("sasview pars",oldpars)
472        for k, v in oldpars.items():
473            name_attr = k.split('.')  # polydispersity components
474            if len(name_attr) == 2:
475                par, disp_par = name_attr
476                model[0].dispersion[par][disp_par] = v
477            else:
478                model[0].setParam(k, v)
479        return theory()
480
481    calculator.engine = "sasview"
482    return calculator
483
484DTYPE_MAP = {
485    'half': '16',
486    'fast': 'fast',
487    'single': '32',
488    'double': '64',
489    'quad': '128',
490    'f16': '16',
491    'f32': '32',
492    'f64': '64',
493    'longdouble': '128',
494}
495def eval_opencl(model_info, data, dtype='single', cutoff=0.):
496    # type: (ModelInfo, Data, str, float) -> Calculator
497    """
498    Return a model calculator using the OpenCL calculation engine.
499    """
500    if not core.HAVE_OPENCL:
501        raise RuntimeError("OpenCL not available")
502    model = core.build_model(model_info, dtype=dtype, platform="ocl")
503    calculator = DirectModel(data, model, cutoff=cutoff)
504    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
505    return calculator
506
507def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
508    # type: (ModelInfo, Data, str, float) -> Calculator
509    """
510    Return a model calculator using the DLL calculation engine.
511    """
512    model = core.build_model(model_info, dtype=dtype, platform="dll")
513    calculator = DirectModel(data, model, cutoff=cutoff)
514    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
515    return calculator
516
517def time_calculation(calculator, pars, evals=1):
518    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
519    """
520    Compute the average calculation time over N evaluations.
521
522    An additional call is generated without polydispersity in order to
523    initialize the calculation engine, and make the average more stable.
524    """
525    # initialize the code so time is more accurate
526    if evals > 1:
527        calculator(**suppress_pd(pars))
528    toc = tic()
529    # make sure there is at least one eval
530    value = calculator(**pars)
531    for _ in range(evals-1):
532        value = calculator(**pars)
533    average_time = toc()*1000. / evals
534    #print("I(q)",value)
535    return value, average_time
536
537def make_data(opts):
538    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
539    """
540    Generate an empty dataset, used with the model to set Q points
541    and resolution.
542
543    *opts* contains the options, with 'qmax', 'nq', 'res',
544    'accuracy', 'is2d' and 'view' parsed from the command line.
545    """
546    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
547    if opts['is2d']:
548        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
549        data = empty_data2D(q, resolution=res)
550        data.accuracy = opts['accuracy']
551        set_beam_stop(data, 0.0004)
552        index = ~data.mask
553    else:
554        if opts['view'] == 'log' and not opts['zero']:
555            qmax = math.log10(qmax)
556            q = np.logspace(qmax-3, qmax, nq)
557        else:
558            q = np.linspace(0.001*qmax, qmax, nq)
559        if opts['zero']:
560            q = np.hstack((0, q))
561        data = empty_data1D(q, resolution=res)
562        index = slice(None, None)
563    return data, index
564
565def make_engine(model_info, data, dtype, cutoff):
566    # type: (ModelInfo, Data, str, float) -> Calculator
567    """
568    Generate the appropriate calculation engine for the given datatype.
569
570    Datatypes with '!' appended are evaluated using external C DLLs rather
571    than OpenCL.
572    """
573    if dtype == 'sasview':
574        return eval_sasview(model_info, data)
575    elif dtype.endswith('!'):
576        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
577    else:
578        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
579
580def _show_invalid(data, theory):
581    # type: (Data, np.ma.ndarray) -> None
582    """
583    Display a list of the non-finite values in theory.
584    """
585    if not theory.mask.any():
586        return
587
588    if hasattr(data, 'x'):
589        bad = zip(data.x[theory.mask], theory[theory.mask])
590        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
591
592
593def compare(opts, limits=None):
594    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
595    """
596    Preform a comparison using options from the command line.
597
598    *limits* are the limits on the values to use, either to set the y-axis
599    for 1D or to set the colormap scale for 2D.  If None, then they are
600    inferred from the data and returned. When exploring using Bumps,
601    the limits are set when the model is initially called, and maintained
602    as the values are adjusted, making it easier to see the effects of the
603    parameters.
604    """
605    n_base, n_comp = opts['n1'], opts['n2']
606    pars = opts['pars']
607    data = opts['data']
608
609    # silence the linter
610    base = opts['engines'][0] if n_base else None
611    comp = opts['engines'][1] if n_comp else None
612    base_time = comp_time = None
613    base_value = comp_value = resid = relerr = None
614
615    # Base calculation
616    if n_base > 0:
617        try:
618            base_raw, base_time = time_calculation(base, pars, n_base)
619            base_value = np.ma.masked_invalid(base_raw)
620            print("%s t=%.2f ms, intensity=%.0f"
621                  % (base.engine, base_time, base_value.sum()))
622            _show_invalid(data, base_value)
623        except ImportError:
624            traceback.print_exc()
625            n_base = 0
626
627    # Comparison calculation
628    if n_comp > 0:
629        try:
630            comp_raw, comp_time = time_calculation(comp, pars, n_comp)
631            comp_value = np.ma.masked_invalid(comp_raw)
632            print("%s t=%.2f ms, intensity=%.0f"
633                  % (comp.engine, comp_time, comp_value.sum()))
634            _show_invalid(data, comp_value)
635        except ImportError:
636            traceback.print_exc()
637            n_comp = 0
638
639    # Compare, but only if computing both forms
640    if n_base > 0 and n_comp > 0:
641        resid = (base_value - comp_value)
642        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
643        _print_stats("|%s-%s|"
644                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
645                     resid)
646        _print_stats("|(%s-%s)/%s|"
647                     % (base.engine, comp.engine, comp.engine),
648                     relerr)
649
650    # Plot if requested
651    if not opts['plot'] and not opts['explore']: return
652    view = opts['view']
653    import matplotlib.pyplot as plt
654    if limits is None:
655        vmin, vmax = np.Inf, -np.Inf
656        if n_base > 0:
657            vmin = min(vmin, base_value.min())
658            vmax = max(vmax, base_value.max())
659        if n_comp > 0:
660            vmin = min(vmin, comp_value.min())
661            vmax = max(vmax, comp_value.max())
662        limits = vmin, vmax
663
664    if n_base > 0:
665        if n_comp > 0: plt.subplot(131)
666        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
667        plt.title("%s t=%.2f ms"%(base.engine, base_time))
668        #cbar_title = "log I"
669    if n_comp > 0:
670        if n_base > 0: plt.subplot(132)
671        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
672        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
673        #cbar_title = "log I"
674    if n_comp > 0 and n_base > 0:
675        plt.subplot(133)
676        if not opts['rel_err']:
677            err, errstr, errview = resid, "abs err", "linear"
678        else:
679            err, errstr, errview = abs(relerr), "rel err", "log"
680        #err,errstr = base/comp,"ratio"
681        plot_theory(data, None, resid=err, view=errview, use_data=False)
682        if view == 'linear':
683            plt.xscale('linear')
684        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
685        #cbar_title = errstr if errview=="linear" else "log "+errstr
686    #if is2D:
687    #    h = plt.colorbar()
688    #    h.ax.set_title(cbar_title)
689    fig = plt.gcf()
690    fig.suptitle(opts['name'])
691
692    if n_comp > 0 and n_base > 0 and '-hist' in opts:
693        plt.figure()
694        v = relerr
695        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
696        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
697        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
698                   % (base.engine, comp.engine, comp.engine))
699        plt.ylabel('P(err)')
700        plt.title('Distribution of relative error between calculation engines')
701
702    if not opts['explore']:
703        plt.show()
704
705    return limits
706
707def _print_stats(label, err):
708    # type: (str, np.ma.ndarray) -> None
709    # work with trimmed data, not the full set
710    sorted_err = np.sort(abs(err.compressed()))
711    p50 = int((len(sorted_err)-1)*0.50)
712    p98 = int((len(sorted_err)-1)*0.98)
713    data = [
714        "max:%.3e"%sorted_err[-1],
715        "median:%.3e"%sorted_err[p50],
716        "98%%:%.3e"%sorted_err[p98],
717        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
718        "zero-offset:%+.3e"%np.mean(sorted_err),
719        ]
720    print(label+"  "+"  ".join(data))
721
722
723
724# ===========================================================================
725#
726NAME_OPTIONS = set([
727    'plot', 'noplot',
728    'half', 'fast', 'single', 'double',
729    'single!', 'double!', 'quad!', 'sasview',
730    'lowq', 'midq', 'highq', 'exq', 'zero',
731    '2d', '1d',
732    'preset', 'random',
733    'poly', 'mono',
734    'nopars', 'pars',
735    'rel', 'abs',
736    'linear', 'log', 'q4',
737    'hist', 'nohist',
738    'edit',
739    'demo', 'default',
740    ])
741VALUE_OPTIONS = [
742    # Note: random is both a name option and a value option
743    'cutoff', 'random', 'nq', 'res', 'accuracy',
744    ]
745
746def columnize(items, indent="", width=79):
747    # type: (List[str], str, int) -> str
748    """
749    Format a list of strings into columns.
750
751    Returns a string with carriage returns ready for printing.
752    """
753    column_width = max(len(w) for w in items) + 1
754    num_columns = (width - len(indent)) // column_width
755    num_rows = len(items) // num_columns
756    items = items + [""] * (num_rows * num_columns - len(items))
757    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
758    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
759             for row in zip(*columns)]
760    output = indent + ("\n"+indent).join(lines)
761    return output
762
763
764def get_pars(model_info, use_demo=False):
765    # type: (ModelInfo, bool) -> ParameterSet
766    """
767    Extract demo parameters from the model definition.
768    """
769    # Get the default values for the parameters
770    pars = {}
771    for p in model_info.parameters.call_parameters:
772        parts = [('', p.default)]
773        if p.polydisperse:
774            parts.append(('_pd', 0.0))
775            parts.append(('_pd_n', 0))
776            parts.append(('_pd_nsigma', 3.0))
777            parts.append(('_pd_type', "gaussian"))
778        for ext, val in parts:
779            if p.length > 1:
780                dict(("%s%d%s" % (p.id, k, ext), val)
781                     for k in range(1, p.length+1))
782            else:
783                pars[p.id + ext] = val
784
785    # Plug in values given in demo
786    if use_demo:
787        pars.update(model_info.demo)
788    return pars
789
790
791def parse_opts():
792    # type: () -> Dict[str, Any]
793    """
794    Parse command line options.
795    """
796    MODELS = core.list_models()
797    flags = [arg for arg in sys.argv[1:]
798             if arg.startswith('-')]
799    values = [arg for arg in sys.argv[1:]
800              if not arg.startswith('-') and '=' in arg]
801    args = [arg for arg in sys.argv[1:]
802            if not arg.startswith('-') and '=' not in arg]
803    models = "\n    ".join("%-15s"%v for v in MODELS)
804    if len(args) == 0:
805        print(USAGE)
806        print("\nAvailable models:")
807        print(columnize(MODELS, indent="  "))
808        sys.exit(1)
809    if len(args) > 3:
810        print("expected parameters: model N1 N2")
811
812    name = args[0]
813    try:
814        model_info = core.load_model_info(name)
815    except ImportError as exc:
816        print(str(exc))
817        print("Could not find model; use one of:\n    " + models)
818        sys.exit(1)
819
820    invalid = [o[1:] for o in flags
821               if o[1:] not in NAME_OPTIONS
822               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
823    if invalid:
824        print("Invalid options: %s"%(", ".join(invalid)))
825        sys.exit(1)
826
827
828    # pylint: disable=bad-whitespace
829    # Interpret the flags
830    opts = {
831        'plot'      : True,
832        'view'      : 'log',
833        'is2d'      : False,
834        'qmax'      : 0.05,
835        'nq'        : 128,
836        'res'       : 0.0,
837        'accuracy'  : 'Low',
838        'cutoff'    : 0.0,
839        'seed'      : -1,  # default to preset
840        'mono'      : False,
841        'show_pars' : False,
842        'show_hist' : False,
843        'rel_err'   : True,
844        'explore'   : False,
845        'use_demo'  : True,
846        'zero'      : False,
847    }
848    engines = []
849    for arg in flags:
850        if arg == '-noplot':    opts['plot'] = False
851        elif arg == '-plot':    opts['plot'] = True
852        elif arg == '-linear':  opts['view'] = 'linear'
853        elif arg == '-log':     opts['view'] = 'log'
854        elif arg == '-q4':      opts['view'] = 'q4'
855        elif arg == '-1d':      opts['is2d'] = False
856        elif arg == '-2d':      opts['is2d'] = True
857        elif arg == '-exq':     opts['qmax'] = 10.0
858        elif arg == '-highq':   opts['qmax'] = 1.0
859        elif arg == '-midq':    opts['qmax'] = 0.2
860        elif arg == '-lowq':    opts['qmax'] = 0.05
861        elif arg == '-zero':    opts['zero'] = True
862        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
863        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
864        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
865        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
866        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
867        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
868        elif arg == '-preset':  opts['seed'] = -1
869        elif arg == '-mono':    opts['mono'] = True
870        elif arg == '-poly':    opts['mono'] = False
871        elif arg == '-pars':    opts['show_pars'] = True
872        elif arg == '-nopars':  opts['show_pars'] = False
873        elif arg == '-hist':    opts['show_hist'] = True
874        elif arg == '-nohist':  opts['show_hist'] = False
875        elif arg == '-rel':     opts['rel_err'] = True
876        elif arg == '-abs':     opts['rel_err'] = False
877        elif arg == '-half':    engines.append(arg[1:])
878        elif arg == '-fast':    engines.append(arg[1:])
879        elif arg == '-single':  engines.append(arg[1:])
880        elif arg == '-double':  engines.append(arg[1:])
881        elif arg == '-single!': engines.append(arg[1:])
882        elif arg == '-double!': engines.append(arg[1:])
883        elif arg == '-quad!':   engines.append(arg[1:])
884        elif arg == '-sasview': engines.append(arg[1:])
885        elif arg == '-edit':    opts['explore'] = True
886        elif arg == '-demo':    opts['use_demo'] = True
887        elif arg == '-default':    opts['use_demo'] = False
888    # pylint: enable=bad-whitespace
889
890    if len(engines) == 0:
891        engines.extend(['single', 'double'])
892    elif len(engines) == 1:
893        if engines[0][0] == 'double':
894            engines.append('single')
895        else:
896            engines.append('double')
897    elif len(engines) > 2:
898        del engines[2:]
899
900    n1 = int(args[1]) if len(args) > 1 else 1
901    n2 = int(args[2]) if len(args) > 2 else 1
902    use_sasview = any(engine == 'sasview' and count > 0
903                      for engine, count in zip(engines, [n1, n2]))
904
905    # Get demo parameters from model definition, or use default parameters
906    # if model does not define demo parameters
907    pars = get_pars(model_info, opts['use_demo'])
908
909
910    # Fill in parameters given on the command line
911    presets = {}
912    for arg in values:
913        k, v = arg.split('=', 1)
914        if k not in pars:
915            # extract base name without polydispersity info
916            s = set(p.split('_pd')[0] for p in pars)
917            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
918            sys.exit(1)
919        presets[k] = float(v) if not k.endswith('type') else v
920
921    # randomize parameters
922    #pars.update(set_pars)  # set value before random to control range
923    if opts['seed'] > -1:
924        pars = randomize_pars(model_info, pars, seed=opts['seed'])
925        print("Randomize using -random=%i"%opts['seed'])
926    if opts['mono']:
927        pars = suppress_pd(pars)
928    pars.update(presets)  # set value after random to control value
929    #import pprint; pprint.pprint(model_info)
930    constrain_pars(model_info, pars)
931    if use_sasview:
932        constrain_new_to_old(model_info, pars)
933    if opts['show_pars']:
934        print(str(parlist(model_info, pars, opts['is2d'])))
935
936    # Create the computational engines
937    data, _ = make_data(opts)
938    if n1:
939        base = make_engine(model_info, data, engines[0], opts['cutoff'])
940    else:
941        base = None
942    if n2:
943        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
944    else:
945        comp = None
946
947    # pylint: disable=bad-whitespace
948    # Remember it all
949    opts.update({
950        'name'      : name,
951        'def'       : model_info,
952        'n1'        : n1,
953        'n2'        : n2,
954        'presets'   : presets,
955        'pars'      : pars,
956        'data'      : data,
957        'engines'   : [base, comp],
958    })
959    # pylint: enable=bad-whitespace
960
961    return opts
962
963def explore(opts):
964    # type: (Dict[str, Any]) -> None
965    """
966    Explore the model using the Bumps GUI.
967    """
968    import wx  # type: ignore
969    from bumps.names import FitProblem  # type: ignore
970    from bumps.gui.app_frame import AppFrame  # type: ignore
971
972    problem = FitProblem(Explore(opts))
973    is_mac = "cocoa" in wx.version()
974    app = wx.App()
975    frame = AppFrame(parent=None, title="explore")
976    if not is_mac: frame.Show()
977    frame.panel.set_model(model=problem)
978    frame.panel.Layout()
979    frame.panel.aui.Split(0, wx.TOP)
980    if is_mac: frame.Show()
981    app.MainLoop()
982
983class Explore(object):
984    """
985    Bumps wrapper for a SAS model comparison.
986
987    The resulting object can be used as a Bumps fit problem so that
988    parameters can be adjusted in the GUI, with plots updated on the fly.
989    """
990    def __init__(self, opts):
991        # type: (Dict[str, Any]) -> None
992        from bumps.cli import config_matplotlib  # type: ignore
993        from . import bumps_model
994        config_matplotlib()
995        self.opts = opts
996        model_info = opts['def']
997        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
998        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
999        if not opts['is2d']:
1000            for p in model_info.parameters.user_parameters(is2d=False):
1001                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1002                    k = p.name+ext
1003                    v = pars.get(k, None)
1004                    if v is not None:
1005                        v.range(*parameter_range(k, v.value))
1006        else:
1007            for k, v in pars.items():
1008                v.range(*parameter_range(k, v.value))
1009
1010        self.pars = pars
1011        self.pd_types = pd_types
1012        self.limits = None
1013
1014    def numpoints(self):
1015        # type: () -> int
1016        """
1017        Return the number of points.
1018        """
1019        return len(self.pars) + 1  # so dof is 1
1020
1021    def parameters(self):
1022        # type: () -> Any   # Dict/List hierarchy of parameters
1023        """
1024        Return a dictionary of parameters.
1025        """
1026        return self.pars
1027
1028    def nllf(self):
1029        # type: () -> float
1030        """
1031        Return cost.
1032        """
1033        # pylint: disable=no-self-use
1034        return 0.  # No nllf
1035
1036    def plot(self, view='log'):
1037        # type: (str) -> None
1038        """
1039        Plot the data and residuals.
1040        """
1041        pars = dict((k, v.value) for k, v in self.pars.items())
1042        pars.update(self.pd_types)
1043        self.opts['pars'] = pars
1044        limits = compare(self.opts, limits=self.limits)
1045        if self.limits is None:
1046            vmin, vmax = limits
1047            self.limits = vmax*1e-7, 1.3*vmax
1048
1049
1050def main():
1051    # type: () -> None
1052    """
1053    Main program.
1054    """
1055    opts = parse_opts()
1056    if opts['explore']:
1057        explore(opts)
1058    else:
1059        compare(opts)
1060
1061if __name__ == "__main__":
1062    main()
Note: See TracBrowser for help on using the repository browser.