source: sasmodels/sasmodels/compare.py @ ff1fff5

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ff1fff5 was ff1fff5, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

allow comparison of different models, such as 'sascomp sphere:ellipsoid radius_polar=:radius radius_equatorial=:radius -mono'

  • Property mode set to 100755
File size: 38.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model is the name of the model to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -cutoff=1e-5* cutoff value for including a point in polydispersity
76    -pars/-nopars* prints the parameter set or not
77    -abs/-rel* plot relative or absolute error
78    -linear/-log*/-q4 intensity scaling
79    -hist/-nohist* plot histogram of relative error
80    -res=0 sets the resolution width dQ/Q if calculating with resolution
81    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
82    -edit starts the parameter explorer
83    -default/-demo* use demo vs default parameters
84    -html shows the model docs instead of running the model
85    -title="note" adds note to the plot title, after the model name
86
87Any two calculation engines can be selected for comparison:
88
89    -single/-double/-half/-fast sets an OpenCL calculation engine
90    -single!/-double!/-quad! sets an OpenMP calculation engine
91    -sasview sets the sasview calculation engine
92
93The default is -single -sasview.  Note that the interpretation of quad
94precision depends on architecture, and may vary from 64-bit to 128-bit,
95with 80-bit floats being common (1e-19 precision).
96
97Key=value pairs allow you to set specific values for the model parameters.
98"""
99
100# Update docs with command line usage string.   This is separate from the usual
101# doc string so that we can display it at run time if there is an error.
102# lin
103__doc__ = (__doc__  # pylint: disable=redefined-builtin
104           + """
105Program description
106-------------------
107
108"""
109           + USAGE)
110
111kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
112
113# CRUFT python 2.6
114if not hasattr(datetime.timedelta, 'total_seconds'):
115    def delay(dt):
116        """Return number date-time delta as number seconds"""
117        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
118else:
119    def delay(dt):
120        """Return number date-time delta as number seconds"""
121        return dt.total_seconds()
122
123
124class push_seed(object):
125    """
126    Set the seed value for the random number generator.
127
128    When used in a with statement, the random number generator state is
129    restored after the with statement is complete.
130
131    :Parameters:
132
133    *seed* : int or array_like, optional
134        Seed for RandomState
135
136    :Example:
137
138    Seed can be used directly to set the seed::
139
140        >>> from numpy.random import randint
141        >>> push_seed(24)
142        <...push_seed object at...>
143        >>> print(randint(0,1000000,3))
144        [242082    899 211136]
145
146    Seed can also be used in a with statement, which sets the random
147    number generator state for the enclosed computations and restores
148    it to the previous state on completion::
149
150        >>> with push_seed(24):
151        ...    print(randint(0,1000000,3))
152        [242082    899 211136]
153
154    Using nested contexts, we can demonstrate that state is indeed
155    restored after the block completes::
156
157        >>> with push_seed(24):
158        ...    print(randint(0,1000000))
159        ...    with push_seed(24):
160        ...        print(randint(0,1000000,3))
161        ...    print(randint(0,1000000))
162        242082
163        [242082    899 211136]
164        899
165
166    The restore step is protected against exceptions in the block::
167
168        >>> with push_seed(24):
169        ...    print(randint(0,1000000))
170        ...    try:
171        ...        with push_seed(24):
172        ...            print(randint(0,1000000,3))
173        ...            raise Exception()
174        ...    except Exception:
175        ...        print("Exception raised")
176        ...    print(randint(0,1000000))
177        242082
178        [242082    899 211136]
179        Exception raised
180        899
181    """
182    def __init__(self, seed=None):
183        # type: (Optional[int]) -> None
184        self._state = np.random.get_state()
185        np.random.seed(seed)
186
187    def __enter__(self):
188        # type: () -> None
189        pass
190
191    def __exit__(self, exc_type, exc_value, traceback):
192        # type: (Any, BaseException, Any) -> None
193        # TODO: better typing for __exit__ method
194        np.random.set_state(self._state)
195
196def tic():
197    # type: () -> Callable[[], float]
198    """
199    Timer function.
200
201    Use "toc=tic()" to start the clock and "toc()" to measure
202    a time interval.
203    """
204    then = datetime.datetime.now()
205    return lambda: delay(datetime.datetime.now() - then)
206
207
208def set_beam_stop(data, radius, outer=None):
209    # type: (Data, float, float) -> None
210    """
211    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
212
213    Note: this function does not require sasview
214    """
215    if hasattr(data, 'qx_data'):
216        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
217        data.mask = (q < radius)
218        if outer is not None:
219            data.mask |= (q >= outer)
220    else:
221        data.mask = (data.x < radius)
222        if outer is not None:
223            data.mask |= (data.x >= outer)
224
225
226def parameter_range(p, v):
227    # type: (str, float) -> Tuple[float, float]
228    """
229    Choose a parameter range based on parameter name and initial value.
230    """
231    # process the polydispersity options
232    if p.endswith('_pd_n'):
233        return 0., 100.
234    elif p.endswith('_pd_nsigma'):
235        return 0., 5.
236    elif p.endswith('_pd_type'):
237        raise ValueError("Cannot return a range for a string value")
238    elif any(s in p for s in ('theta', 'phi', 'psi')):
239        # orientation in [-180,180], orientation pd in [0,45]
240        if p.endswith('_pd'):
241            return 0., 45.
242        else:
243            return -180., 180.
244    elif p.endswith('_pd'):
245        return 0., 1.
246    elif 'sld' in p:
247        return -0.5, 10.
248    elif p == 'background':
249        return 0., 10.
250    elif p == 'scale':
251        return 0., 1.e3
252    elif v < 0.:
253        return 2.*v, -2.*v
254    else:
255        return 0., (2.*v if v > 0. else 1.)
256
257
258def _randomize_one(model_info, p, v):
259    # type: (ModelInfo, str, float) -> float
260    # type: (ModelInfo, str, str) -> str
261    """
262    Randomize a single parameter.
263    """
264    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
265        return v
266
267    # Find the parameter definition
268    for par in model_info.parameters.call_parameters:
269        if par.name == p:
270            break
271    else:
272        raise ValueError("unknown parameter %r"%p)
273
274    if len(par.limits) > 2:  # choice list
275        return np.random.randint(len(par.limits))
276
277    limits = par.limits
278    if not np.isfinite(limits).all():
279        low, high = parameter_range(p, v)
280        limits = (max(limits[0], low), min(limits[1], high))
281
282    return np.random.uniform(*limits)
283
284
285def randomize_pars(model_info, pars, seed=None):
286    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
287    """
288    Generate random values for all of the parameters.
289
290    Valid ranges for the random number generator are guessed from the name of
291    the parameter; this will not account for constraints such as cap radius
292    greater than cylinder radius in the capped_cylinder model, so
293    :func:`constrain_pars` needs to be called afterward..
294    """
295    with push_seed(seed):
296        # Note: the sort guarantees order `of calls to random number generator
297        random_pars = dict((p, _randomize_one(model_info, p, v))
298                           for p, v in sorted(pars.items()))
299    return random_pars
300
301def constrain_pars(model_info, pars):
302    # type: (ModelInfo, ParameterSet) -> None
303    """
304    Restrict parameters to valid values.
305
306    This includes model specific code for models such as capped_cylinder
307    which need to support within model constraints (cap radius more than
308    cylinder radius in this case).
309
310    Warning: this updates the *pars* dictionary in place.
311    """
312    name = model_info.id
313    # if it is a product model, then just look at the form factor since
314    # none of the structure factors need any constraints.
315    if '*' in name:
316        name = name.split('*')[0]
317
318    if name == 'capped_cylinder' and pars['radius_cap'] < pars['radius']:
319        pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
320    if name == 'barbell' and pars['radius_bell'] < pars['radius']:
321        pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
322
323    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
324    if name == 'guinier':
325        #q_max = 0.2  # mid q maximum
326        q_max = 1.0  # high q maximum
327        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
328        pars['rg'] = min(pars['rg'], rg_max)
329
330    if name == 'rpa':
331        # Make sure phi sums to 1.0
332        if pars['case_num'] < 2:
333            pars['Phi1'] = 0.
334            pars['Phi2'] = 0.
335        elif pars['case_num'] < 5:
336            pars['Phi1'] = 0.
337        total = sum(pars['Phi'+c] for c in '1234')
338        for c in '1234':
339            pars['Phi'+c] /= total
340
341def parlist(model_info, pars, is2d):
342    # type: (ModelInfo, ParameterSet, bool) -> str
343    """
344    Format the parameter list for printing.
345    """
346    lines = []
347    parameters = model_info.parameters
348    for p in parameters.user_parameters(pars, is2d):
349        fields = dict(
350            value=pars.get(p.id, p.default),
351            pd=pars.get(p.id+"_pd", 0.),
352            n=int(pars.get(p.id+"_pd_n", 0)),
353            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
354            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
355            relative_pd=p.relative_pd,
356        )
357        lines.append(_format_par(p.name, **fields))
358    return "\n".join(lines)
359
360    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
361
362def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
363                relative_pd=False):
364    # type: (str, float, float, int, float, str) -> str
365    line = "%s: %g"%(name, value)
366    if pd != 0.  and n != 0:
367        if relative_pd:
368            pd *= value
369        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
370                % (pd, n, nsigma, nsigma, pdtype)
371    return line
372
373def suppress_pd(pars):
374    # type: (ParameterSet) -> ParameterSet
375    """
376    Suppress theta_pd for now until the normalization is resolved.
377
378    May also suppress complete polydispersity of the model to test
379    models more quickly.
380    """
381    pars = pars.copy()
382    for p in pars:
383        if p.endswith("_pd_n"): pars[p] = 0
384    return pars
385
386def eval_sasview(model_info, data):
387    # type: (Modelinfo, Data) -> Calculator
388    """
389    Return a model calculator using the pre-4.0 SasView models.
390    """
391    # importing sas here so that the error message will be that sas failed to
392    # import rather than the more obscure smear_selection not imported error
393    import sas
394    import sas.models
395    from sas.models.qsmearing import smear_selection
396    from sas.models.MultiplicationModel import MultiplicationModel
397    from sas.models.dispersion_models import models as dispersers
398
399    def get_model_class(name):
400        # type: (str) -> "sas.models.BaseComponent"
401        #print("new",sorted(_pars.items()))
402        __import__('sas.models.' + name)
403        ModelClass = getattr(getattr(sas.models, name, None), name, None)
404        if ModelClass is None:
405            raise ValueError("could not find model %r in sas.models"%name)
406        return ModelClass
407
408    # WARNING: ugly hack when handling model!
409    # Sasview models with multiplicity need to be created with the target
410    # multiplicity, so we cannot create the target model ahead of time for
411    # for multiplicity models.  Instead we store the model in a list and
412    # update the first element of that list with the new multiplicity model
413    # every time we evaluate.
414
415    # grab the sasview model, or create it if it is a product model
416    if model_info.composition:
417        composition_type, parts = model_info.composition
418        if composition_type == 'product':
419            P, S = [get_model_class(revert_name(p))() for p in parts]
420            model = [MultiplicationModel(P, S)]
421        else:
422            raise ValueError("sasview mixture models not supported by compare")
423    else:
424        old_name = revert_name(model_info)
425        if old_name is None:
426            raise ValueError("model %r does not exist in old sasview"
427                            % model_info.id)
428        ModelClass = get_model_class(old_name)
429        model = [ModelClass()]
430    model[0].disperser_handles = {}
431
432    # build a smearer with which to call the model, if necessary
433    smearer = smear_selection(data, model=model)
434    if hasattr(data, 'qx_data'):
435        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
436        index = ((~data.mask) & (~np.isnan(data.data))
437                 & (q >= data.qmin) & (q <= data.qmax))
438        if smearer is not None:
439            smearer.model = model  # because smear_selection has a bug
440            smearer.accuracy = data.accuracy
441            smearer.set_index(index)
442            def _call_smearer():
443                smearer.model = model[0]
444                return smearer.get_value()
445            theory = _call_smearer
446        else:
447            theory = lambda: model[0].evalDistribution([data.qx_data[index],
448                                                        data.qy_data[index]])
449    elif smearer is not None:
450        theory = lambda: smearer(model[0].evalDistribution(data.x))
451    else:
452        theory = lambda: model[0].evalDistribution(data.x)
453
454    def calculator(**pars):
455        # type: (float, ...) -> np.ndarray
456        """
457        Sasview calculator for model.
458        """
459        oldpars = revert_pars(model_info, pars)
460        # For multiplicity models, create a model with the correct multiplicity
461        control = oldpars.pop("CONTROL", None)
462        if control is not None:
463            # sphericalSLD has one fewer multiplicity.  This update should
464            # happen in revert_pars, but it hasn't been called yet.
465            model[0] = ModelClass(control)
466        # paying for parameter conversion each time to keep life simple, if not fast
467        for k, v in oldpars.items():
468            if k.endswith('.type'):
469                par = k[:-5]
470                if v == 'gaussian': continue
471                cls = dispersers[v if v != 'rectangle' else 'rectangula']
472                handle = cls()
473                model[0].disperser_handles[par] = handle
474                try:
475                    model[0].set_dispersion(par, handle)
476                except Exception:
477                    exception.annotate_exception("while setting %s to %r"
478                                                 %(par, v))
479                    raise
480
481
482        #print("sasview pars",oldpars)
483        for k, v in oldpars.items():
484            name_attr = k.split('.')  # polydispersity components
485            if len(name_attr) == 2:
486                par, disp_par = name_attr
487                model[0].dispersion[par][disp_par] = v
488            else:
489                model[0].setParam(k, v)
490        return theory()
491
492    calculator.engine = "sasview"
493    return calculator
494
495DTYPE_MAP = {
496    'half': '16',
497    'fast': 'fast',
498    'single': '32',
499    'double': '64',
500    'quad': '128',
501    'f16': '16',
502    'f32': '32',
503    'f64': '64',
504    'longdouble': '128',
505}
506def eval_opencl(model_info, data, dtype='single', cutoff=0.):
507    # type: (ModelInfo, Data, str, float) -> Calculator
508    """
509    Return a model calculator using the OpenCL calculation engine.
510    """
511    if not core.HAVE_OPENCL:
512        raise RuntimeError("OpenCL not available")
513    model = core.build_model(model_info, dtype=dtype, platform="ocl")
514    calculator = DirectModel(data, model, cutoff=cutoff)
515    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
516    return calculator
517
518def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
519    # type: (ModelInfo, Data, str, float) -> Calculator
520    """
521    Return a model calculator using the DLL calculation engine.
522    """
523    model = core.build_model(model_info, dtype=dtype, platform="dll")
524    calculator = DirectModel(data, model, cutoff=cutoff)
525    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
526    return calculator
527
528def time_calculation(calculator, pars, evals=1):
529    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
530    """
531    Compute the average calculation time over N evaluations.
532
533    An additional call is generated without polydispersity in order to
534    initialize the calculation engine, and make the average more stable.
535    """
536    # initialize the code so time is more accurate
537    if evals > 1:
538        calculator(**suppress_pd(pars))
539    toc = tic()
540    # make sure there is at least one eval
541    value = calculator(**pars)
542    for _ in range(evals-1):
543        value = calculator(**pars)
544    average_time = toc()*1000. / evals
545    #print("I(q)",value)
546    return value, average_time
547
548def make_data(opts):
549    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
550    """
551    Generate an empty dataset, used with the model to set Q points
552    and resolution.
553
554    *opts* contains the options, with 'qmax', 'nq', 'res',
555    'accuracy', 'is2d' and 'view' parsed from the command line.
556    """
557    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
558    if opts['is2d']:
559        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
560        data = empty_data2D(q, resolution=res)
561        data.accuracy = opts['accuracy']
562        set_beam_stop(data, 0.0004)
563        index = ~data.mask
564    else:
565        if opts['view'] == 'log' and not opts['zero']:
566            qmax = math.log10(qmax)
567            q = np.logspace(qmax-3, qmax, nq)
568        else:
569            q = np.linspace(0.001*qmax, qmax, nq)
570        if opts['zero']:
571            q = np.hstack((0, q))
572        data = empty_data1D(q, resolution=res)
573        index = slice(None, None)
574    return data, index
575
576def make_engine(model_info, data, dtype, cutoff):
577    # type: (ModelInfo, Data, str, float) -> Calculator
578    """
579    Generate the appropriate calculation engine for the given datatype.
580
581    Datatypes with '!' appended are evaluated using external C DLLs rather
582    than OpenCL.
583    """
584    if dtype == 'sasview':
585        return eval_sasview(model_info, data)
586    elif dtype.endswith('!'):
587        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
588    else:
589        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
590
591def _show_invalid(data, theory):
592    # type: (Data, np.ma.ndarray) -> None
593    """
594    Display a list of the non-finite values in theory.
595    """
596    if not theory.mask.any():
597        return
598
599    if hasattr(data, 'x'):
600        bad = zip(data.x[theory.mask], theory[theory.mask])
601        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
602
603
604def compare(opts, limits=None):
605    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
606    """
607    Preform a comparison using options from the command line.
608
609    *limits* are the limits on the values to use, either to set the y-axis
610    for 1D or to set the colormap scale for 2D.  If None, then they are
611    inferred from the data and returned. When exploring using Bumps,
612    the limits are set when the model is initially called, and maintained
613    as the values are adjusted, making it easier to see the effects of the
614    parameters.
615    """
616    n_base, n_comp = opts['count']
617    pars, pars2 = opts['pars']
618    data = opts['data']
619
620    # silence the linter
621    base = opts['engines'][0] if n_base else None
622    comp = opts['engines'][1] if n_comp else None
623    base_time = comp_time = None
624    base_value = comp_value = resid = relerr = None
625
626    # Base calculation
627    if n_base > 0:
628        try:
629            base_raw, base_time = time_calculation(base, pars, n_base)
630            base_value = np.ma.masked_invalid(base_raw)
631            print("%s t=%.2f ms, intensity=%.0f"
632                  % (base.engine, base_time, base_value.sum()))
633            _show_invalid(data, base_value)
634        except ImportError:
635            traceback.print_exc()
636            n_base = 0
637
638    # Comparison calculation
639    if n_comp > 0:
640        try:
641            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
642            comp_value = np.ma.masked_invalid(comp_raw)
643            print("%s t=%.2f ms, intensity=%.0f"
644                  % (comp.engine, comp_time, comp_value.sum()))
645            _show_invalid(data, comp_value)
646        except ImportError:
647            traceback.print_exc()
648            n_comp = 0
649
650    # Compare, but only if computing both forms
651    if n_base > 0 and n_comp > 0:
652        resid = (base_value - comp_value)
653        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
654        _print_stats("|%s-%s|"
655                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
656                     resid)
657        _print_stats("|(%s-%s)/%s|"
658                     % (base.engine, comp.engine, comp.engine),
659                     relerr)
660
661    # Plot if requested
662    if not opts['plot'] and not opts['explore']: return
663    view = opts['view']
664    import matplotlib.pyplot as plt
665    if limits is None:
666        vmin, vmax = np.Inf, -np.Inf
667        if n_base > 0:
668            vmin = min(vmin, base_value.min())
669            vmax = max(vmax, base_value.max())
670        if n_comp > 0:
671            vmin = min(vmin, comp_value.min())
672            vmax = max(vmax, comp_value.max())
673        limits = vmin, vmax
674
675    if n_base > 0:
676        if n_comp > 0: plt.subplot(131)
677        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
678        plt.title("%s t=%.2f ms"%(base.engine, base_time))
679        #cbar_title = "log I"
680    if n_comp > 0:
681        if n_base > 0: plt.subplot(132)
682        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
683        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
684        #cbar_title = "log I"
685    if n_comp > 0 and n_base > 0:
686        plt.subplot(133)
687        if not opts['rel_err']:
688            err, errstr, errview = resid, "abs err", "linear"
689        else:
690            err, errstr, errview = abs(relerr), "rel err", "log"
691        #sorted = np.sort(err.flatten())
692        #cutoff = sorted[int(sorted.size*0.95)]
693        #err[err>cutoff] = cutoff
694        #err,errstr = base/comp,"ratio"
695        plot_theory(data, None, resid=err, view=errview, use_data=False)
696        if view == 'linear':
697            plt.xscale('linear')
698        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
699        #cbar_title = errstr if errview=="linear" else "log "+errstr
700    #if is2D:
701    #    h = plt.colorbar()
702    #    h.ax.set_title(cbar_title)
703    fig = plt.gcf()
704    extra_title = ' '+opts['title'] if opts['title'] else ''
705    fig.suptitle(":".join(opts['name']) + extra_title)
706
707    if n_comp > 0 and n_base > 0 and opts['show_hist']:
708        plt.figure()
709        v = relerr
710        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
711        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
712        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
713                   % (base.engine, comp.engine, comp.engine))
714        plt.ylabel('P(err)')
715        plt.title('Distribution of relative error between calculation engines')
716
717    if not opts['explore']:
718        plt.show()
719
720    return limits
721
722def _print_stats(label, err):
723    # type: (str, np.ma.ndarray) -> None
724    # work with trimmed data, not the full set
725    sorted_err = np.sort(abs(err.compressed()))
726    p50 = int((len(sorted_err)-1)*0.50)
727    p98 = int((len(sorted_err)-1)*0.98)
728    data = [
729        "max:%.3e"%sorted_err[-1],
730        "median:%.3e"%sorted_err[p50],
731        "98%%:%.3e"%sorted_err[p98],
732        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
733        "zero-offset:%+.3e"%np.mean(sorted_err),
734        ]
735    print(label+"  "+"  ".join(data))
736
737
738
739# ===========================================================================
740#
741NAME_OPTIONS = set([
742    'plot', 'noplot',
743    'half', 'fast', 'single', 'double',
744    'single!', 'double!', 'quad!', 'sasview',
745    'lowq', 'midq', 'highq', 'exq', 'zero',
746    '2d', '1d',
747    'preset', 'random',
748    'poly', 'mono',
749    'nopars', 'pars',
750    'rel', 'abs',
751    'linear', 'log', 'q4',
752    'hist', 'nohist',
753    'edit', 'html',
754    'demo', 'default',
755    ])
756VALUE_OPTIONS = [
757    # Note: random is both a name option and a value option
758    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
759    ]
760
761def columnize(items, indent="", width=79):
762    # type: (List[str], str, int) -> str
763    """
764    Format a list of strings into columns.
765
766    Returns a string with carriage returns ready for printing.
767    """
768    column_width = max(len(w) for w in items) + 1
769    num_columns = (width - len(indent)) // column_width
770    num_rows = len(items) // num_columns
771    items = items + [""] * (num_rows * num_columns - len(items))
772    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
773    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
774             for row in zip(*columns)]
775    output = indent + ("\n"+indent).join(lines)
776    return output
777
778
779def get_pars(model_info, use_demo=False):
780    # type: (ModelInfo, bool) -> ParameterSet
781    """
782    Extract demo parameters from the model definition.
783    """
784    # Get the default values for the parameters
785    pars = {}
786    for p in model_info.parameters.call_parameters:
787        parts = [('', p.default)]
788        if p.polydisperse:
789            parts.append(('_pd', 0.0))
790            parts.append(('_pd_n', 0))
791            parts.append(('_pd_nsigma', 3.0))
792            parts.append(('_pd_type', "gaussian"))
793        for ext, val in parts:
794            if p.length > 1:
795                dict(("%s%d%s" % (p.id, k, ext), val)
796                     for k in range(1, p.length+1))
797            else:
798                pars[p.id + ext] = val
799
800    # Plug in values given in demo
801    if use_demo:
802        pars.update(model_info.demo)
803    return pars
804
805INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
806def isnumber(str):
807    match = FLOAT_RE.match(str)
808    isfloat = (match and not str[match.end():])
809    return isfloat or INTEGER_RE.match(str)
810
811def parse_opts(argv):
812    # type: (List[str]) -> Dict[str, Any]
813    """
814    Parse command line options.
815    """
816    MODELS = core.list_models()
817    flags = [arg for arg in argv
818             if arg.startswith('-')]
819    values = [arg for arg in argv
820              if not arg.startswith('-') and '=' in arg]
821    positional_args = [arg for arg in argv
822            if not arg.startswith('-') and '=' not in arg]
823    models = "\n    ".join("%-15s"%v for v in MODELS)
824    if len(positional_args) == 0:
825        print(USAGE)
826        print("\nAvailable models:")
827        print(columnize(MODELS, indent="  "))
828        return None
829    if len(positional_args) > 3:
830        print("expected parameters: model N1 N2")
831
832    invalid = [o[1:] for o in flags
833               if o[1:] not in NAME_OPTIONS
834               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
835    if invalid:
836        print("Invalid options: %s"%(", ".join(invalid)))
837        return None
838
839    name = positional_args[0]
840    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
841    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
842
843    # pylint: disable=bad-whitespace
844    # Interpret the flags
845    opts = {
846        'plot'      : True,
847        'view'      : 'log',
848        'is2d'      : False,
849        'qmax'      : 0.05,
850        'nq'        : 128,
851        'res'       : 0.0,
852        'accuracy'  : 'Low',
853        'cutoff'    : 0.0,
854        'seed'      : -1,  # default to preset
855        'mono'      : False,
856        'show_pars' : False,
857        'show_hist' : False,
858        'rel_err'   : True,
859        'explore'   : False,
860        'use_demo'  : True,
861        'zero'      : False,
862        'html'      : False,
863        'title'     : None,
864    }
865    engines = []
866    for arg in flags:
867        if arg == '-noplot':    opts['plot'] = False
868        elif arg == '-plot':    opts['plot'] = True
869        elif arg == '-linear':  opts['view'] = 'linear'
870        elif arg == '-log':     opts['view'] = 'log'
871        elif arg == '-q4':      opts['view'] = 'q4'
872        elif arg == '-1d':      opts['is2d'] = False
873        elif arg == '-2d':      opts['is2d'] = True
874        elif arg == '-exq':     opts['qmax'] = 10.0
875        elif arg == '-highq':   opts['qmax'] = 1.0
876        elif arg == '-midq':    opts['qmax'] = 0.2
877        elif arg == '-lowq':    opts['qmax'] = 0.05
878        elif arg == '-zero':    opts['zero'] = True
879        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
880        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
881        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
882        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
883        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
884        elif arg.startswith('-title'):     opts['title'] = arg[7:]
885        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
886        elif arg == '-preset':  opts['seed'] = -1
887        elif arg == '-mono':    opts['mono'] = True
888        elif arg == '-poly':    opts['mono'] = False
889        elif arg == '-pars':    opts['show_pars'] = True
890        elif arg == '-nopars':  opts['show_pars'] = False
891        elif arg == '-hist':    opts['show_hist'] = True
892        elif arg == '-nohist':  opts['show_hist'] = False
893        elif arg == '-rel':     opts['rel_err'] = True
894        elif arg == '-abs':     opts['rel_err'] = False
895        elif arg == '-half':    engines.append(arg[1:])
896        elif arg == '-fast':    engines.append(arg[1:])
897        elif arg == '-single':  engines.append(arg[1:])
898        elif arg == '-double':  engines.append(arg[1:])
899        elif arg == '-single!': engines.append(arg[1:])
900        elif arg == '-double!': engines.append(arg[1:])
901        elif arg == '-quad!':   engines.append(arg[1:])
902        elif arg == '-sasview': engines.append(arg[1:])
903        elif arg == '-edit':    opts['explore'] = True
904        elif arg == '-demo':    opts['use_demo'] = True
905        elif arg == '-default':    opts['use_demo'] = False
906        elif arg == '-html':    opts['html'] = True
907    # pylint: enable=bad-whitespace
908
909    if ':' in name:
910        name, name2 = name.split(':',2)
911    else:
912        name2 = name
913    try:
914        model_info = core.load_model_info(name)
915        model_info2 = core.load_model_info(name2) if name2 != name else model_info
916    except ImportError as exc:
917        print(str(exc))
918        print("Could not find model; use one of:\n    " + models)
919        return None
920
921    # Get demo parameters from model definition, or use default parameters
922    # if model does not define demo parameters
923    pars = get_pars(model_info, opts['use_demo'])
924    pars2 = get_pars(model_info2, opts['use_demo'])
925    # randomize parameters
926    #pars.update(set_pars)  # set value before random to control range
927    if opts['seed'] > -1:
928        pars = randomize_pars(model_info, pars, seed=opts['seed'])
929        if model_info != model_info2:
930            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
931        else:
932            pars2 = pars.copy()
933        print("Randomize using -random=%i"%opts['seed'])
934    if opts['mono']:
935        pars = suppress_pd(pars)
936        pars2 = suppress_pd(pars2)
937
938    # Fill in parameters given on the command line
939    presets = {}
940    presets2 = {}
941    for arg in values:
942        k, v = arg.split('=', 1)
943        if k not in pars and k not in pars2:
944            # extract base name without polydispersity info
945            s = set(p.split('_pd')[0] for p in pars)
946            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
947            return None
948        v1, v2 = v.split(':',2) if ':' in v else (v,v)
949        if v1 and k in pars:
950            presets[k] = float(v1) if isnumber(v1) else v1
951        if v2 and k in pars2:
952            presets2[k] = float(v2) if isnumber(v2) else v2
953
954    # Evaluate preset parameter expressions
955    context = pars.copy()
956    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
957    for k, v in presets.items():
958        if not isinstance(v, float) and not k.endswith('_type'):
959            presets[k] = eval(v, context)
960    context = pars.copy()
961    context.update(presets)
962    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
963    for k, v in presets2.items():
964        if not isinstance(v, float) and not k.endswith('_type'):
965            presets2[k] = eval(v, context)
966
967    # update parameters with presets
968    pars.update(presets)  # set value after random to control value
969    pars2.update(presets2)  # set value after random to control value
970    #import pprint; pprint.pprint(model_info)
971    constrain_pars(model_info, pars)
972    constrain_pars(model_info2, pars2)
973
974    same_model = name == name2 and pars == pars
975    if len(engines) == 0:
976        if same_model:
977            engines.extend(['single', 'double'])
978        else:
979            engines.extend(['single', 'single'])
980    elif len(engines) == 1:
981        if not same_model:
982            engines.append(engines[0])
983        elif engines[0] == 'double':
984            engines.append('single')
985        else:
986            engines.append('double')
987    elif len(engines) > 2:
988        del engines[2:]
989
990    use_sasview = any(engine == 'sasview' and count > 0
991                      for engine, count in zip(engines, [n1, n2]))
992    if use_sasview:
993        constrain_new_to_old(model_info, pars)
994        constrain_new_to_old(model_info2, pars2)
995
996    if opts['show_pars']:
997        print(str(parlist(model_info, pars, opts['is2d'])))
998
999    # Create the computational engines
1000    data, _ = make_data(opts)
1001    if n1:
1002        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1003    else:
1004        base = None
1005    if n2:
1006        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1007    else:
1008        comp = None
1009
1010    # pylint: disable=bad-whitespace
1011    # Remember it all
1012    opts.update({
1013        'data'      : data,
1014        'name'      : [name, name2],
1015        'def'       : [model_info, model_info2],
1016        'count'     : [n1, n2],
1017        'presets'   : [presets, presets2],
1018        'pars'      : [pars, pars2],
1019        'engines'   : [base, comp],
1020    })
1021    # pylint: enable=bad-whitespace
1022
1023    return opts
1024
1025def show_docs(opts):
1026    # type: (Dict[str, Any]) -> None
1027    """
1028    show html docs for the model
1029    """
1030    import wx  # type: ignore
1031    from .generate import view_html_from_info
1032    app = wx.App() if wx.GetApp() is None else None
1033    view_html_from_info(opts['def'][0])
1034    if app: app.MainLoop()
1035
1036
1037def explore(opts):
1038    # type: (Dict[str, Any]) -> None
1039    """
1040    explore the model using the bumps gui.
1041    """
1042    import wx  # type: ignore
1043    from bumps.names import FitProblem  # type: ignore
1044    from bumps.gui.app_frame import AppFrame  # type: ignore
1045
1046    is_mac = "cocoa" in wx.version()
1047    # Create an app if not running embedded
1048    app = wx.App() if wx.GetApp() is None else None
1049    problem = FitProblem(Explore(opts))
1050    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1051    if not is_mac: frame.Show()
1052    frame.panel.set_model(model=problem)
1053    frame.panel.Layout()
1054    frame.panel.aui.Split(0, wx.TOP)
1055    if is_mac: frame.Show()
1056    # If running withing an app, start the main loop
1057    if app: app.MainLoop()
1058
1059class Explore(object):
1060    """
1061    Bumps wrapper for a SAS model comparison.
1062
1063    The resulting object can be used as a Bumps fit problem so that
1064    parameters can be adjusted in the GUI, with plots updated on the fly.
1065    """
1066    def __init__(self, opts):
1067        # type: (Dict[str, Any]) -> None
1068        from bumps.cli import config_matplotlib  # type: ignore
1069        from . import bumps_model
1070        config_matplotlib()
1071        self.opts = opts
1072        model_info = opts['def'][0]
1073        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0])
1074        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1075        if not opts['is2d']:
1076            for p in model_info.parameters.user_parameters(is2d=False):
1077                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1078                    k = p.name+ext
1079                    v = pars.get(k, None)
1080                    if v is not None:
1081                        v.range(*parameter_range(k, v.value))
1082        else:
1083            for k, v in pars.items():
1084                v.range(*parameter_range(k, v.value))
1085
1086        self.pars = pars
1087        self.pd_types = pd_types
1088        self.limits = None
1089
1090    def numpoints(self):
1091        # type: () -> int
1092        """
1093        Return the number of points.
1094        """
1095        return len(self.pars) + 1  # so dof is 1
1096
1097    def parameters(self):
1098        # type: () -> Any   # Dict/List hierarchy of parameters
1099        """
1100        Return a dictionary of parameters.
1101        """
1102        return self.pars
1103
1104    def nllf(self):
1105        # type: () -> float
1106        """
1107        Return cost.
1108        """
1109        # pylint: disable=no-self-use
1110        return 0.  # No nllf
1111
1112    def plot(self, view='log'):
1113        # type: (str) -> None
1114        """
1115        Plot the data and residuals.
1116        """
1117        pars = dict((k, v.value) for k, v in self.pars.items())
1118        pars.update(self.pd_types)
1119        self.opts['pars'][0] = pars
1120        self.opts['pars'][1] = pars
1121        limits = compare(self.opts, limits=self.limits)
1122        if self.limits is None:
1123            vmin, vmax = limits
1124            self.limits = vmax*1e-7, 1.3*vmax
1125
1126
1127def main(*argv):
1128    # type: (*str) -> None
1129    """
1130    Main program.
1131    """
1132    opts = parse_opts(argv)
1133    if opts is not None:
1134        if opts['html']:
1135            show_docs(opts)
1136        elif opts['explore']:
1137            explore(opts)
1138        else:
1139            compare(opts)
1140
1141if __name__ == "__main__":
1142    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.