source: sasmodels/sasmodels/compare.py @ b6f10d8

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since b6f10d8 was b6f10d8, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

sascomp: default to pd_n=35 if pd given on command line

  • Property mode set to 100755
File size: 41.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35import re
36
37import numpy as np  # type: ignore
38
39from . import core
40from . import kerneldll
41from . import exception
42from .data import plot_theory, empty_data1D, empty_data2D
43from .direct_model import DirectModel
44from .convert import revert_name, revert_pars, constrain_new_to_old
45from .generate import FLOAT_RE
46
47try:
48    from typing import Optional, Dict, Any, Callable, Tuple
49except Exception:
50    pass
51else:
52    from .modelinfo import ModelInfo, Parameter, ParameterSet
53    from .data import Data
54    Calculator = Callable[[float], np.ndarray]
55
56USAGE = """
57usage: compare.py model N1 N2 [options...] [key=val]
58
59Compare the speed and value for a model between the SasView original and the
60sasmodels rewrite.
61
62model is the name of the model to compare (see below).
63N1 is the number of times to run sasmodels (default=1).
64N2 is the number times to run sasview (default=1).
65
66Options (* for default):
67
68    -plot*/-noplot plots or suppress the plot of the model
69    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
70    -nq=128 sets the number of Q points in the data set
71    -zero indicates that q=0 should be included
72    -1d*/-2d computes 1d or 2d data
73    -preset*/-random[=seed] preset or random parameters
74    -mono/-poly* force monodisperse/polydisperse
75    -magnetic/-nonmagnetic* suppress magnetism
76    -cutoff=1e-5* cutoff value for including a point in polydispersity
77    -pars/-nopars* prints the parameter set or not
78    -abs/-rel* plot relative or absolute error
79    -linear/-log*/-q4 intensity scaling
80    -hist/-nohist* plot histogram of relative error
81    -res=0 sets the resolution width dQ/Q if calculating with resolution
82    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
83    -edit starts the parameter explorer
84    -default/-demo* use demo vs default parameters
85    -html shows the model docs instead of running the model
86    -title="note" adds note to the plot title, after the model name
87
88Any two calculation engines can be selected for comparison:
89
90    -single/-double/-half/-fast sets an OpenCL calculation engine
91    -single!/-double!/-quad! sets an OpenMP calculation engine
92    -sasview sets the sasview calculation engine
93
94The default is -single -sasview.  Note that the interpretation of quad
95precision depends on architecture, and may vary from 64-bit to 128-bit,
96with 80-bit floats being common (1e-19 precision).
97
98Key=value pairs allow you to set specific values for the model parameters.
99"""
100
101# Update docs with command line usage string.   This is separate from the usual
102# doc string so that we can display it at run time if there is an error.
103# lin
104__doc__ = (__doc__  # pylint: disable=redefined-builtin
105           + """
106Program description
107-------------------
108
109"""
110           + USAGE)
111
112kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
113
114# list of math functions for use in evaluating parameters
115MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
116
117# CRUFT python 2.6
118if not hasattr(datetime.timedelta, 'total_seconds'):
119    def delay(dt):
120        """Return number date-time delta as number seconds"""
121        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
122else:
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.total_seconds()
126
127
128class push_seed(object):
129    """
130    Set the seed value for the random number generator.
131
132    When used in a with statement, the random number generator state is
133    restored after the with statement is complete.
134
135    :Parameters:
136
137    *seed* : int or array_like, optional
138        Seed for RandomState
139
140    :Example:
141
142    Seed can be used directly to set the seed::
143
144        >>> from numpy.random import randint
145        >>> push_seed(24)
146        <...push_seed object at...>
147        >>> print(randint(0,1000000,3))
148        [242082    899 211136]
149
150    Seed can also be used in a with statement, which sets the random
151    number generator state for the enclosed computations and restores
152    it to the previous state on completion::
153
154        >>> with push_seed(24):
155        ...    print(randint(0,1000000,3))
156        [242082    899 211136]
157
158    Using nested contexts, we can demonstrate that state is indeed
159    restored after the block completes::
160
161        >>> with push_seed(24):
162        ...    print(randint(0,1000000))
163        ...    with push_seed(24):
164        ...        print(randint(0,1000000,3))
165        ...    print(randint(0,1000000))
166        242082
167        [242082    899 211136]
168        899
169
170    The restore step is protected against exceptions in the block::
171
172        >>> with push_seed(24):
173        ...    print(randint(0,1000000))
174        ...    try:
175        ...        with push_seed(24):
176        ...            print(randint(0,1000000,3))
177        ...            raise Exception()
178        ...    except Exception:
179        ...        print("Exception raised")
180        ...    print(randint(0,1000000))
181        242082
182        [242082    899 211136]
183        Exception raised
184        899
185    """
186    def __init__(self, seed=None):
187        # type: (Optional[int]) -> None
188        self._state = np.random.get_state()
189        np.random.seed(seed)
190
191    def __enter__(self):
192        # type: () -> None
193        pass
194
195    def __exit__(self, exc_type, exc_value, traceback):
196        # type: (Any, BaseException, Any) -> None
197        # TODO: better typing for __exit__ method
198        np.random.set_state(self._state)
199
200def tic():
201    # type: () -> Callable[[], float]
202    """
203    Timer function.
204
205    Use "toc=tic()" to start the clock and "toc()" to measure
206    a time interval.
207    """
208    then = datetime.datetime.now()
209    return lambda: delay(datetime.datetime.now() - then)
210
211
212def set_beam_stop(data, radius, outer=None):
213    # type: (Data, float, float) -> None
214    """
215    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
216
217    Note: this function does not require sasview
218    """
219    if hasattr(data, 'qx_data'):
220        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
221        data.mask = (q < radius)
222        if outer is not None:
223            data.mask |= (q >= outer)
224    else:
225        data.mask = (data.x < radius)
226        if outer is not None:
227            data.mask |= (data.x >= outer)
228
229
230def parameter_range(p, v):
231    # type: (str, float) -> Tuple[float, float]
232    """
233    Choose a parameter range based on parameter name and initial value.
234    """
235    # process the polydispersity options
236    if p.endswith('_pd_n'):
237        return 0., 100.
238    elif p.endswith('_pd_nsigma'):
239        return 0., 5.
240    elif p.endswith('_pd_type'):
241        raise ValueError("Cannot return a range for a string value")
242    elif any(s in p for s in ('theta', 'phi', 'psi')):
243        # orientation in [-180,180], orientation pd in [0,45]
244        if p.endswith('_pd'):
245            return 0., 45.
246        else:
247            return -180., 180.
248    elif p.endswith('_pd'):
249        return 0., 1.
250    elif 'sld' in p:
251        return -0.5, 10.
252    elif p == 'background':
253        return 0., 10.
254    elif p == 'scale':
255        return 0., 1.e3
256    elif v < 0.:
257        return 2.*v, -2.*v
258    else:
259        return 0., (2.*v if v > 0. else 1.)
260
261
262def _randomize_one(model_info, p, v):
263    # type: (ModelInfo, str, float) -> float
264    # type: (ModelInfo, str, str) -> str
265    """
266    Randomize a single parameter.
267    """
268    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
269        return v
270
271    # Find the parameter definition
272    for par in model_info.parameters.call_parameters:
273        if par.name == p:
274            break
275    else:
276        raise ValueError("unknown parameter %r"%p)
277
278    if len(par.limits) > 2:  # choice list
279        return np.random.randint(len(par.limits))
280
281    limits = par.limits
282    if not np.isfinite(limits).all():
283        low, high = parameter_range(p, v)
284        limits = (max(limits[0], low), min(limits[1], high))
285
286    return np.random.uniform(*limits)
287
288
289def randomize_pars(model_info, pars, seed=None):
290    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
291    """
292    Generate random values for all of the parameters.
293
294    Valid ranges for the random number generator are guessed from the name of
295    the parameter; this will not account for constraints such as cap radius
296    greater than cylinder radius in the capped_cylinder model, so
297    :func:`constrain_pars` needs to be called afterward..
298    """
299    with push_seed(seed):
300        # Note: the sort guarantees order `of calls to random number generator
301        random_pars = dict((p, _randomize_one(model_info, p, v))
302                           for p, v in sorted(pars.items()))
303    return random_pars
304
305def constrain_pars(model_info, pars):
306    # type: (ModelInfo, ParameterSet) -> None
307    """
308    Restrict parameters to valid values.
309
310    This includes model specific code for models such as capped_cylinder
311    which need to support within model constraints (cap radius more than
312    cylinder radius in this case).
313
314    Warning: this updates the *pars* dictionary in place.
315    """
316    name = model_info.id
317    # if it is a product model, then just look at the form factor since
318    # none of the structure factors need any constraints.
319    if '*' in name:
320        name = name.split('*')[0]
321
322    if name == 'capped_cylinder' and pars['radius_cap'] < pars['radius']:
323        pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
324    if name == 'barbell' and pars['radius_bell'] < pars['radius']:
325        pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
326
327    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
328    if name == 'guinier':
329        #q_max = 0.2  # mid q maximum
330        q_max = 1.0  # high q maximum
331        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
332        pars['rg'] = min(pars['rg'], rg_max)
333
334    if name == 'rpa':
335        # Make sure phi sums to 1.0
336        if pars['case_num'] < 2:
337            pars['Phi1'] = 0.
338            pars['Phi2'] = 0.
339        elif pars['case_num'] < 5:
340            pars['Phi1'] = 0.
341        total = sum(pars['Phi'+c] for c in '1234')
342        for c in '1234':
343            pars['Phi'+c] /= total
344
345def parlist(model_info, pars, is2d):
346    # type: (ModelInfo, ParameterSet, bool) -> str
347    """
348    Format the parameter list for printing.
349    """
350    lines = []
351    parameters = model_info.parameters
352    magnetic = False
353    for p in parameters.user_parameters(pars, is2d):
354        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
355            continue
356        if p.id.startswith('up:') and not magnetic:
357            continue
358        fields = dict(
359            value=pars.get(p.id, p.default),
360            pd=pars.get(p.id+"_pd", 0.),
361            n=int(pars.get(p.id+"_pd_n", 0)),
362            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
363            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
364            relative_pd=p.relative_pd,
365            M0=pars.get('M0:'+p.id, 0.),
366            mphi=pars.get('mphi:'+p.id, 0.),
367            mtheta=pars.get('mtheta:'+p.id, 0.),
368        )
369        lines.append(_format_par(p.name, **fields))
370        magnetic = magnetic or fields['M0'] != 0.
371    return "\n".join(lines)
372
373    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
374
375def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
376                relative_pd=False, M0=0., mphi=0., mtheta=0.):
377    # type: (str, float, float, int, float, str) -> str
378    line = "%s: %g"%(name, value)
379    if pd != 0.  and n != 0:
380        if relative_pd:
381            pd *= value
382        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
383                % (pd, n, nsigma, nsigma, pdtype)
384    if M0 != 0.:
385        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
386    return line
387
388def suppress_pd(pars):
389    # type: (ParameterSet) -> ParameterSet
390    """
391    Suppress theta_pd for now until the normalization is resolved.
392
393    May also suppress complete polydispersity of the model to test
394    models more quickly.
395    """
396    pars = pars.copy()
397    for p in pars:
398        if p.endswith("_pd_n"): pars[p] = 0
399    return pars
400
401def suppress_magnetism(pars):
402    # type: (ParameterSet) -> ParameterSet
403    """
404    Suppress theta_pd for now until the normalization is resolved.
405
406    May also suppress complete polydispersity of the model to test
407    models more quickly.
408    """
409    pars = pars.copy()
410    for p in pars:
411        if p.startswith("M0:"): pars[p] = 0
412    return pars
413
414def eval_sasview(model_info, data):
415    # type: (Modelinfo, Data) -> Calculator
416    """
417    Return a model calculator using the pre-4.0 SasView models.
418    """
419    # importing sas here so that the error message will be that sas failed to
420    # import rather than the more obscure smear_selection not imported error
421    import sas
422    import sas.models
423    from sas.models.qsmearing import smear_selection
424    from sas.models.MultiplicationModel import MultiplicationModel
425    from sas.models.dispersion_models import models as dispersers
426
427    def get_model_class(name):
428        # type: (str) -> "sas.models.BaseComponent"
429        #print("new",sorted(_pars.items()))
430        __import__('sas.models.' + name)
431        ModelClass = getattr(getattr(sas.models, name, None), name, None)
432        if ModelClass is None:
433            raise ValueError("could not find model %r in sas.models"%name)
434        return ModelClass
435
436    # WARNING: ugly hack when handling model!
437    # Sasview models with multiplicity need to be created with the target
438    # multiplicity, so we cannot create the target model ahead of time for
439    # for multiplicity models.  Instead we store the model in a list and
440    # update the first element of that list with the new multiplicity model
441    # every time we evaluate.
442
443    # grab the sasview model, or create it if it is a product model
444    if model_info.composition:
445        composition_type, parts = model_info.composition
446        if composition_type == 'product':
447            P, S = [get_model_class(revert_name(p))() for p in parts]
448            model = [MultiplicationModel(P, S)]
449        else:
450            raise ValueError("sasview mixture models not supported by compare")
451    else:
452        old_name = revert_name(model_info)
453        if old_name is None:
454            raise ValueError("model %r does not exist in old sasview"
455                            % model_info.id)
456        ModelClass = get_model_class(old_name)
457        model = [ModelClass()]
458    model[0].disperser_handles = {}
459
460    # build a smearer with which to call the model, if necessary
461    smearer = smear_selection(data, model=model)
462    if hasattr(data, 'qx_data'):
463        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
464        index = ((~data.mask) & (~np.isnan(data.data))
465                 & (q >= data.qmin) & (q <= data.qmax))
466        if smearer is not None:
467            smearer.model = model  # because smear_selection has a bug
468            smearer.accuracy = data.accuracy
469            smearer.set_index(index)
470            def _call_smearer():
471                smearer.model = model[0]
472                return smearer.get_value()
473            theory = _call_smearer
474        else:
475            theory = lambda: model[0].evalDistribution([data.qx_data[index],
476                                                        data.qy_data[index]])
477    elif smearer is not None:
478        theory = lambda: smearer(model[0].evalDistribution(data.x))
479    else:
480        theory = lambda: model[0].evalDistribution(data.x)
481
482    def calculator(**pars):
483        # type: (float, ...) -> np.ndarray
484        """
485        Sasview calculator for model.
486        """
487        oldpars = revert_pars(model_info, pars)
488        # For multiplicity models, create a model with the correct multiplicity
489        control = oldpars.pop("CONTROL", None)
490        if control is not None:
491            # sphericalSLD has one fewer multiplicity.  This update should
492            # happen in revert_pars, but it hasn't been called yet.
493            model[0] = ModelClass(control)
494        # paying for parameter conversion each time to keep life simple, if not fast
495        for k, v in oldpars.items():
496            if k.endswith('.type'):
497                par = k[:-5]
498                if v == 'gaussian': continue
499                cls = dispersers[v if v != 'rectangle' else 'rectangula']
500                handle = cls()
501                model[0].disperser_handles[par] = handle
502                try:
503                    model[0].set_dispersion(par, handle)
504                except Exception:
505                    exception.annotate_exception("while setting %s to %r"
506                                                 %(par, v))
507                    raise
508
509
510        #print("sasview pars",oldpars)
511        for k, v in oldpars.items():
512            name_attr = k.split('.')  # polydispersity components
513            if len(name_attr) == 2:
514                par, disp_par = name_attr
515                model[0].dispersion[par][disp_par] = v
516            else:
517                model[0].setParam(k, v)
518        return theory()
519
520    calculator.engine = "sasview"
521    return calculator
522
523DTYPE_MAP = {
524    'half': '16',
525    'fast': 'fast',
526    'single': '32',
527    'double': '64',
528    'quad': '128',
529    'f16': '16',
530    'f32': '32',
531    'f64': '64',
532    'longdouble': '128',
533}
534def eval_opencl(model_info, data, dtype='single', cutoff=0.):
535    # type: (ModelInfo, Data, str, float) -> Calculator
536    """
537    Return a model calculator using the OpenCL calculation engine.
538    """
539    if not core.HAVE_OPENCL:
540        raise RuntimeError("OpenCL not available")
541    model = core.build_model(model_info, dtype=dtype, platform="ocl")
542    calculator = DirectModel(data, model, cutoff=cutoff)
543    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
544    return calculator
545
546def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
547    # type: (ModelInfo, Data, str, float) -> Calculator
548    """
549    Return a model calculator using the DLL calculation engine.
550    """
551    model = core.build_model(model_info, dtype=dtype, platform="dll")
552    calculator = DirectModel(data, model, cutoff=cutoff)
553    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
554    return calculator
555
556def time_calculation(calculator, pars, evals=1):
557    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
558    """
559    Compute the average calculation time over N evaluations.
560
561    An additional call is generated without polydispersity in order to
562    initialize the calculation engine, and make the average more stable.
563    """
564    # initialize the code so time is more accurate
565    if evals > 1:
566        calculator(**suppress_pd(pars))
567    toc = tic()
568    # make sure there is at least one eval
569    value = calculator(**pars)
570    for _ in range(evals-1):
571        value = calculator(**pars)
572    average_time = toc()*1000. / evals
573    #print("I(q)",value)
574    return value, average_time
575
576def make_data(opts):
577    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
578    """
579    Generate an empty dataset, used with the model to set Q points
580    and resolution.
581
582    *opts* contains the options, with 'qmax', 'nq', 'res',
583    'accuracy', 'is2d' and 'view' parsed from the command line.
584    """
585    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
586    if opts['is2d']:
587        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
588        data = empty_data2D(q, resolution=res)
589        data.accuracy = opts['accuracy']
590        set_beam_stop(data, 0.0004)
591        index = ~data.mask
592    else:
593        if opts['view'] == 'log' and not opts['zero']:
594            qmax = math.log10(qmax)
595            q = np.logspace(qmax-3, qmax, nq)
596        else:
597            q = np.linspace(0.001*qmax, qmax, nq)
598        if opts['zero']:
599            q = np.hstack((0, q))
600        data = empty_data1D(q, resolution=res)
601        index = slice(None, None)
602    return data, index
603
604def make_engine(model_info, data, dtype, cutoff):
605    # type: (ModelInfo, Data, str, float) -> Calculator
606    """
607    Generate the appropriate calculation engine for the given datatype.
608
609    Datatypes with '!' appended are evaluated using external C DLLs rather
610    than OpenCL.
611    """
612    if dtype == 'sasview':
613        return eval_sasview(model_info, data)
614    elif dtype.endswith('!'):
615        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
616    else:
617        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
618
619def _show_invalid(data, theory):
620    # type: (Data, np.ma.ndarray) -> None
621    """
622    Display a list of the non-finite values in theory.
623    """
624    if not theory.mask.any():
625        return
626
627    if hasattr(data, 'x'):
628        bad = zip(data.x[theory.mask], theory[theory.mask])
629        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
630
631
632def compare(opts, limits=None):
633    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
634    """
635    Preform a comparison using options from the command line.
636
637    *limits* are the limits on the values to use, either to set the y-axis
638    for 1D or to set the colormap scale for 2D.  If None, then they are
639    inferred from the data and returned. When exploring using Bumps,
640    the limits are set when the model is initially called, and maintained
641    as the values are adjusted, making it easier to see the effects of the
642    parameters.
643    """
644    n_base, n_comp = opts['count']
645    pars, pars2 = opts['pars']
646    data = opts['data']
647
648    # silence the linter
649    base = opts['engines'][0] if n_base else None
650    comp = opts['engines'][1] if n_comp else None
651    base_time = comp_time = None
652    base_value = comp_value = resid = relerr = None
653
654    # Base calculation
655    if n_base > 0:
656        try:
657            base_raw, base_time = time_calculation(base, pars, n_base)
658            base_value = np.ma.masked_invalid(base_raw)
659            print("%s t=%.2f ms, intensity=%.0f"
660                  % (base.engine, base_time, base_value.sum()))
661            _show_invalid(data, base_value)
662        except ImportError:
663            traceback.print_exc()
664            n_base = 0
665
666    # Comparison calculation
667    if n_comp > 0:
668        try:
669            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
670            comp_value = np.ma.masked_invalid(comp_raw)
671            print("%s t=%.2f ms, intensity=%.0f"
672                  % (comp.engine, comp_time, comp_value.sum()))
673            _show_invalid(data, comp_value)
674        except ImportError:
675            traceback.print_exc()
676            n_comp = 0
677
678    # Compare, but only if computing both forms
679    if n_base > 0 and n_comp > 0:
680        resid = (base_value - comp_value)
681        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
682        _print_stats("|%s-%s|"
683                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
684                     resid)
685        _print_stats("|(%s-%s)/%s|"
686                     % (base.engine, comp.engine, comp.engine),
687                     relerr)
688
689    # Plot if requested
690    if not opts['plot'] and not opts['explore']: return
691    view = opts['view']
692    import matplotlib.pyplot as plt
693    if limits is None:
694        vmin, vmax = np.Inf, -np.Inf
695        if n_base > 0:
696            vmin = min(vmin, base_value.min())
697            vmax = max(vmax, base_value.max())
698        if n_comp > 0:
699            vmin = min(vmin, comp_value.min())
700            vmax = max(vmax, comp_value.max())
701        limits = vmin, vmax
702
703    if n_base > 0:
704        if n_comp > 0: plt.subplot(131)
705        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
706        plt.title("%s t=%.2f ms"%(base.engine, base_time))
707        #cbar_title = "log I"
708    if n_comp > 0:
709        if n_base > 0: plt.subplot(132)
710        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
711        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
712        #cbar_title = "log I"
713    if n_comp > 0 and n_base > 0:
714        if not opts['is2d']:
715            plot_theory(data, base_value, view=view, use_data=False, limits=limits)
716        plt.subplot(133)
717        if not opts['rel_err']:
718            err, errstr, errview = resid, "abs err", "linear"
719        else:
720            err, errstr, errview = abs(relerr), "rel err", "log"
721        #sorted = np.sort(err.flatten())
722        #cutoff = sorted[int(sorted.size*0.95)]
723        #err[err>cutoff] = cutoff
724        #err,errstr = base/comp,"ratio"
725        plot_theory(data, None, resid=err, view=errview, use_data=False)
726        if view == 'linear':
727            plt.xscale('linear')
728        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
729        #cbar_title = errstr if errview=="linear" else "log "+errstr
730    #if is2D:
731    #    h = plt.colorbar()
732    #    h.ax.set_title(cbar_title)
733    fig = plt.gcf()
734    extra_title = ' '+opts['title'] if opts['title'] else ''
735    fig.suptitle(":".join(opts['name']) + extra_title)
736
737    if n_comp > 0 and n_base > 0 and opts['show_hist']:
738        plt.figure()
739        v = relerr
740        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
741        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
742        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
743                   % (base.engine, comp.engine, comp.engine))
744        plt.ylabel('P(err)')
745        plt.title('Distribution of relative error between calculation engines')
746
747    if not opts['explore']:
748        plt.show()
749
750    return limits
751
752def _print_stats(label, err):
753    # type: (str, np.ma.ndarray) -> None
754    # work with trimmed data, not the full set
755    sorted_err = np.sort(abs(err.compressed()))
756    p50 = int((len(sorted_err)-1)*0.50)
757    p98 = int((len(sorted_err)-1)*0.98)
758    data = [
759        "max:%.3e"%sorted_err[-1],
760        "median:%.3e"%sorted_err[p50],
761        "98%%:%.3e"%sorted_err[p98],
762        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
763        "zero-offset:%+.3e"%np.mean(sorted_err),
764        ]
765    print(label+"  "+"  ".join(data))
766
767
768
769# ===========================================================================
770#
771NAME_OPTIONS = set([
772    'plot', 'noplot',
773    'half', 'fast', 'single', 'double',
774    'single!', 'double!', 'quad!', 'sasview',
775    'lowq', 'midq', 'highq', 'exq', 'zero',
776    '2d', '1d',
777    'preset', 'random',
778    'poly', 'mono',
779    'magnetic', 'nonmagnetic',
780    'nopars', 'pars',
781    'rel', 'abs',
782    'linear', 'log', 'q4',
783    'hist', 'nohist',
784    'edit', 'html',
785    'demo', 'default',
786    ])
787VALUE_OPTIONS = [
788    # Note: random is both a name option and a value option
789    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
790    ]
791
792def columnize(items, indent="", width=79):
793    # type: (List[str], str, int) -> str
794    """
795    Format a list of strings into columns.
796
797    Returns a string with carriage returns ready for printing.
798    """
799    column_width = max(len(w) for w in items) + 1
800    num_columns = (width - len(indent)) // column_width
801    num_rows = len(items) // num_columns
802    items = items + [""] * (num_rows * num_columns - len(items))
803    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
804    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
805             for row in zip(*columns)]
806    output = indent + ("\n"+indent).join(lines)
807    return output
808
809
810def get_pars(model_info, use_demo=False):
811    # type: (ModelInfo, bool) -> ParameterSet
812    """
813    Extract demo parameters from the model definition.
814    """
815    # Get the default values for the parameters
816    pars = {}
817    for p in model_info.parameters.call_parameters:
818        parts = [('', p.default)]
819        if p.polydisperse:
820            parts.append(('_pd', 0.0))
821            parts.append(('_pd_n', 0))
822            parts.append(('_pd_nsigma', 3.0))
823            parts.append(('_pd_type', "gaussian"))
824        for ext, val in parts:
825            if p.length > 1:
826                dict(("%s%d%s" % (p.id, k, ext), val)
827                     for k in range(1, p.length+1))
828            else:
829                pars[p.id + ext] = val
830
831    # Plug in values given in demo
832    if use_demo:
833        pars.update(model_info.demo)
834    return pars
835
836INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
837def isnumber(str):
838    match = FLOAT_RE.match(str)
839    isfloat = (match and not str[match.end():])
840    return isfloat or INTEGER_RE.match(str)
841
842def parse_opts(argv):
843    # type: (List[str]) -> Dict[str, Any]
844    """
845    Parse command line options.
846    """
847    MODELS = core.list_models()
848    flags = [arg for arg in argv
849             if arg.startswith('-')]
850    values = [arg for arg in argv
851              if not arg.startswith('-') and '=' in arg]
852    positional_args = [arg for arg in argv
853            if not arg.startswith('-') and '=' not in arg]
854    models = "\n    ".join("%-15s"%v for v in MODELS)
855    if len(positional_args) == 0:
856        print(USAGE)
857        print("\nAvailable models:")
858        print(columnize(MODELS, indent="  "))
859        return None
860    if len(positional_args) > 3:
861        print("expected parameters: model N1 N2")
862
863    invalid = [o[1:] for o in flags
864               if o[1:] not in NAME_OPTIONS
865               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
866    if invalid:
867        print("Invalid options: %s"%(", ".join(invalid)))
868        return None
869
870    name = positional_args[0]
871    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
872    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
873
874    # pylint: disable=bad-whitespace
875    # Interpret the flags
876    opts = {
877        'plot'      : True,
878        'view'      : 'log',
879        'is2d'      : False,
880        'qmax'      : 0.05,
881        'nq'        : 128,
882        'res'       : 0.0,
883        'accuracy'  : 'Low',
884        'cutoff'    : 0.0,
885        'seed'      : -1,  # default to preset
886        'mono'      : False,
887        # Default to magnetic a magnetic moment is set on the command line
888        'magnetic'  : False,
889        'show_pars' : False,
890        'show_hist' : False,
891        'rel_err'   : True,
892        'explore'   : False,
893        'use_demo'  : True,
894        'zero'      : False,
895        'html'      : False,
896        'title'     : None,
897    }
898    engines = []
899    for arg in flags:
900        if arg == '-noplot':    opts['plot'] = False
901        elif arg == '-plot':    opts['plot'] = True
902        elif arg == '-linear':  opts['view'] = 'linear'
903        elif arg == '-log':     opts['view'] = 'log'
904        elif arg == '-q4':      opts['view'] = 'q4'
905        elif arg == '-1d':      opts['is2d'] = False
906        elif arg == '-2d':      opts['is2d'] = True
907        elif arg == '-exq':     opts['qmax'] = 10.0
908        elif arg == '-highq':   opts['qmax'] = 1.0
909        elif arg == '-midq':    opts['qmax'] = 0.2
910        elif arg == '-lowq':    opts['qmax'] = 0.05
911        elif arg == '-zero':    opts['zero'] = True
912        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
913        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
914        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
915        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
916        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
917        elif arg.startswith('-title'):     opts['title'] = arg[7:]
918        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
919        elif arg == '-preset':  opts['seed'] = -1
920        elif arg == '-mono':    opts['mono'] = True
921        elif arg == '-poly':    opts['mono'] = False
922        elif arg == '-magnetic':       opts['magnetic'] = True
923        elif arg == '-nonmagnetic':    opts['magnetic'] = False
924        elif arg == '-pars':    opts['show_pars'] = True
925        elif arg == '-nopars':  opts['show_pars'] = False
926        elif arg == '-hist':    opts['show_hist'] = True
927        elif arg == '-nohist':  opts['show_hist'] = False
928        elif arg == '-rel':     opts['rel_err'] = True
929        elif arg == '-abs':     opts['rel_err'] = False
930        elif arg == '-half':    engines.append(arg[1:])
931        elif arg == '-fast':    engines.append(arg[1:])
932        elif arg == '-single':  engines.append(arg[1:])
933        elif arg == '-double':  engines.append(arg[1:])
934        elif arg == '-single!': engines.append(arg[1:])
935        elif arg == '-double!': engines.append(arg[1:])
936        elif arg == '-quad!':   engines.append(arg[1:])
937        elif arg == '-sasview': engines.append(arg[1:])
938        elif arg == '-edit':    opts['explore'] = True
939        elif arg == '-demo':    opts['use_demo'] = True
940        elif arg == '-default':    opts['use_demo'] = False
941        elif arg == '-html':    opts['html'] = True
942    # pylint: enable=bad-whitespace
943
944    if ':' in name:
945        name, name2 = name.split(':',2)
946    else:
947        name2 = name
948    try:
949        model_info = core.load_model_info(name)
950        model_info2 = core.load_model_info(name2) if name2 != name else model_info
951    except ImportError as exc:
952        print(str(exc))
953        print("Could not find model; use one of:\n    " + models)
954        return None
955
956    # Get demo parameters from model definition, or use default parameters
957    # if model does not define demo parameters
958    pars = get_pars(model_info, opts['use_demo'])
959    pars2 = get_pars(model_info2, opts['use_demo'])
960    pars2.update((k, v) for k, v in pars.items() if k in pars2)
961    # randomize parameters
962    #pars.update(set_pars)  # set value before random to control range
963    if opts['seed'] > -1:
964        pars = randomize_pars(model_info, pars, seed=opts['seed'])
965        if model_info != model_info2:
966            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
967        else:
968            pars2 = pars.copy()
969        print("Randomize using -random=%i"%opts['seed'])
970    if opts['mono']:
971        pars = suppress_pd(pars)
972        pars2 = suppress_pd(pars2)
973    if not opts['magnetic']:
974        pars = suppress_magnetism(pars)
975        pars2 = suppress_magnetism(pars2)
976
977    # Fill in parameters given on the command line
978    presets = {}
979    presets2 = {}
980    for arg in values:
981        k, v = arg.split('=', 1)
982        if k not in pars and k not in pars2:
983            # extract base name without polydispersity info
984            s = set(p.split('_pd')[0] for p in pars)
985            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
986            return None
987        v1, v2 = v.split(':',2) if ':' in v else (v,v)
988        if v1 and k in pars:
989            presets[k] = float(v1) if isnumber(v1) else v1
990        if v2 and k in pars2:
991            presets2[k] = float(v2) if isnumber(v2) else v2
992
993    # If pd given on the command line, default pd_n to 35
994    for k, v in list(presets.items()):
995        if k.endswith('_pd'):
996            presets.setdefault(k+'_n', 35.)
997    for k, v in list(presets2.items()):
998        if k.endswith('_pd'):
999            presets2.setdefault(k+'_n', 35.)
1000
1001    # Evaluate preset parameter expressions
1002    context = MATH.copy()
1003    context.update(pars)
1004    context.update((k,v) for k,v in presets.items() if isinstance(v, float))
1005    for k, v in presets.items():
1006        if not isinstance(v, float) and not k.endswith('_type'):
1007            presets[k] = eval(v, context)
1008    context.update(presets)
1009    context.update((k,v) for k,v in presets2.items() if isinstance(v, float))
1010    for k, v in presets2.items():
1011        if not isinstance(v, float) and not k.endswith('_type'):
1012            presets2[k] = eval(v, context)
1013
1014    # update parameters with presets
1015    pars.update(presets)  # set value after random to control value
1016    pars2.update(presets2)  # set value after random to control value
1017    #import pprint; pprint.pprint(model_info)
1018    constrain_pars(model_info, pars)
1019    constrain_pars(model_info2, pars2)
1020
1021    same_model = name == name2 and pars == pars
1022    if len(engines) == 0:
1023        if same_model:
1024            engines.extend(['single', 'double'])
1025        else:
1026            engines.extend(['single', 'single'])
1027    elif len(engines) == 1:
1028        if not same_model:
1029            engines.append(engines[0])
1030        elif engines[0] == 'double':
1031            engines.append('single')
1032        else:
1033            engines.append('double')
1034    elif len(engines) > 2:
1035        del engines[2:]
1036
1037    use_sasview = any(engine == 'sasview' and count > 0
1038                      for engine, count in zip(engines, [n1, n2]))
1039    if use_sasview:
1040        constrain_new_to_old(model_info, pars)
1041        constrain_new_to_old(model_info2, pars2)
1042
1043    if opts['show_pars']:
1044        if not same_model:
1045            print("==== %s ====="%model_info.name)
1046            print(str(parlist(model_info, pars, opts['is2d'])))
1047            print("==== %s ====="%model_info2.name)
1048            print(str(parlist(model_info2, pars2, opts['is2d'])))
1049        else:
1050            print(str(parlist(model_info, pars, opts['is2d'])))
1051
1052    # Create the computational engines
1053    data, _ = make_data(opts)
1054    if n1:
1055        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1056    else:
1057        base = None
1058    if n2:
1059        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1060    else:
1061        comp = None
1062
1063    # pylint: disable=bad-whitespace
1064    # Remember it all
1065    opts.update({
1066        'data'      : data,
1067        'name'      : [name, name2],
1068        'def'       : [model_info, model_info2],
1069        'count'     : [n1, n2],
1070        'presets'   : [presets, presets2],
1071        'pars'      : [pars, pars2],
1072        'engines'   : [base, comp],
1073    })
1074    # pylint: enable=bad-whitespace
1075
1076    return opts
1077
1078def show_docs(opts):
1079    # type: (Dict[str, Any]) -> None
1080    """
1081    show html docs for the model
1082    """
1083    import wx  # type: ignore
1084    from .generate import view_html_from_info
1085    app = wx.App() if wx.GetApp() is None else None
1086    view_html_from_info(opts['def'][0])
1087    if app: app.MainLoop()
1088
1089
1090def explore(opts):
1091    # type: (Dict[str, Any]) -> None
1092    """
1093    explore the model using the bumps gui.
1094    """
1095    import wx  # type: ignore
1096    from bumps.names import FitProblem  # type: ignore
1097    from bumps.gui.app_frame import AppFrame  # type: ignore
1098
1099    is_mac = "cocoa" in wx.version()
1100    # Create an app if not running embedded
1101    app = wx.App() if wx.GetApp() is None else None
1102    problem = FitProblem(Explore(opts))
1103    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1104    if not is_mac: frame.Show()
1105    frame.panel.set_model(model=problem)
1106    frame.panel.Layout()
1107    frame.panel.aui.Split(0, wx.TOP)
1108    if is_mac: frame.Show()
1109    # If running withing an app, start the main loop
1110    if app: app.MainLoop()
1111
1112class Explore(object):
1113    """
1114    Bumps wrapper for a SAS model comparison.
1115
1116    The resulting object can be used as a Bumps fit problem so that
1117    parameters can be adjusted in the GUI, with plots updated on the fly.
1118    """
1119    def __init__(self, opts):
1120        # type: (Dict[str, Any]) -> None
1121        from bumps.cli import config_matplotlib  # type: ignore
1122        from . import bumps_model
1123        config_matplotlib()
1124        self.opts = opts
1125        model_info = opts['def'][0]
1126        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'][0])
1127        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1128        if not opts['is2d']:
1129            for p in model_info.parameters.user_parameters(is2d=False):
1130                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1131                    k = p.name+ext
1132                    v = pars.get(k, None)
1133                    if v is not None:
1134                        v.range(*parameter_range(k, v.value))
1135        else:
1136            for k, v in pars.items():
1137                v.range(*parameter_range(k, v.value))
1138
1139        self.pars = pars
1140        self.pd_types = pd_types
1141        self.limits = None
1142
1143    def numpoints(self):
1144        # type: () -> int
1145        """
1146        Return the number of points.
1147        """
1148        return len(self.pars) + 1  # so dof is 1
1149
1150    def parameters(self):
1151        # type: () -> Any   # Dict/List hierarchy of parameters
1152        """
1153        Return a dictionary of parameters.
1154        """
1155        return self.pars
1156
1157    def nllf(self):
1158        # type: () -> float
1159        """
1160        Return cost.
1161        """
1162        # pylint: disable=no-self-use
1163        return 0.  # No nllf
1164
1165    def plot(self, view='log'):
1166        # type: (str) -> None
1167        """
1168        Plot the data and residuals.
1169        """
1170        pars = dict((k, v.value) for k, v in self.pars.items())
1171        pars.update(self.pd_types)
1172        self.opts['pars'][0] = pars
1173        self.opts['pars'][1] = pars
1174        limits = compare(self.opts, limits=self.limits)
1175        if self.limits is None:
1176            vmin, vmax = limits
1177            self.limits = vmax*1e-7, 1.3*vmax
1178
1179
1180def main(*argv):
1181    # type: (*str) -> None
1182    """
1183    Main program.
1184    """
1185    opts = parse_opts(argv)
1186    if opts is not None:
1187        if opts['html']:
1188            show_docs(opts)
1189        elif opts['explore']:
1190            explore(opts)
1191        else:
1192            compare(opts)
1193
1194if __name__ == "__main__":
1195    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.