source: sasmodels/sasmodels/compare.py @ a0d75ce

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since a0d75ce was a0d75ce, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

add title option to sascomp to provide more details on the graph

  • Property mode set to 100755
File size: 36.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from . import weights
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_name, revert_pars, constrain_new_to_old
44
45try:
46    from typing import Optional, Dict, Any, Callable, Tuple
47except:
48    pass
49else:
50    from .modelinfo import ModelInfo, Parameter, ParameterSet
51    from .data import Data
52    Calculator = Callable[[float], np.ndarray]
53
54USAGE = """
55usage: compare.py model N1 N2 [options...] [key=val]
56
57Compare the speed and value for a model between the SasView original and the
58sasmodels rewrite.
59
60model is the name of the model to compare (see below).
61N1 is the number of times to run sasmodels (default=1).
62N2 is the number times to run sasview (default=1).
63
64Options (* for default):
65
66    -plot*/-noplot plots or suppress the plot of the model
67    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
68    -nq=128 sets the number of Q points in the data set
69    -zero indicates that q=0 should be included
70    -1d*/-2d computes 1d or 2d data
71    -preset*/-random[=seed] preset or random parameters
72    -mono/-poly* force monodisperse/polydisperse
73    -cutoff=1e-5* cutoff value for including a point in polydispersity
74    -pars/-nopars* prints the parameter set or not
75    -abs/-rel* plot relative or absolute error
76    -linear/-log*/-q4 intensity scaling
77    -hist/-nohist* plot histogram of relative error
78    -res=0 sets the resolution width dQ/Q if calculating with resolution
79    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
80    -edit starts the parameter explorer
81    -default/-demo* use demo vs default parameters
82    -html shows the model docs instead of running the model
83    -title="graph title" adds a title to the plot
84
85Any two calculation engines can be selected for comparison:
86
87    -single/-double/-half/-fast sets an OpenCL calculation engine
88    -single!/-double!/-quad! sets an OpenMP calculation engine
89    -sasview sets the sasview calculation engine
90
91The default is -single -sasview.  Note that the interpretation of quad
92precision depends on architecture, and may vary from 64-bit to 128-bit,
93with 80-bit floats being common (1e-19 precision).
94
95Key=value pairs allow you to set specific values for the model parameters.
96"""
97
98# Update docs with command line usage string.   This is separate from the usual
99# doc string so that we can display it at run time if there is an error.
100# lin
101__doc__ = (__doc__  # pylint: disable=redefined-builtin
102           + """
103Program description
104-------------------
105
106"""
107           + USAGE)
108
109kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
110
111# CRUFT python 2.6
112if not hasattr(datetime.timedelta, 'total_seconds'):
113    def delay(dt):
114        """Return number date-time delta as number seconds"""
115        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
116else:
117    def delay(dt):
118        """Return number date-time delta as number seconds"""
119        return dt.total_seconds()
120
121
122class push_seed(object):
123    """
124    Set the seed value for the random number generator.
125
126    When used in a with statement, the random number generator state is
127    restored after the with statement is complete.
128
129    :Parameters:
130
131    *seed* : int or array_like, optional
132        Seed for RandomState
133
134    :Example:
135
136    Seed can be used directly to set the seed::
137
138        >>> from numpy.random import randint
139        >>> push_seed(24)
140        <...push_seed object at...>
141        >>> print(randint(0,1000000,3))
142        [242082    899 211136]
143
144    Seed can also be used in a with statement, which sets the random
145    number generator state for the enclosed computations and restores
146    it to the previous state on completion::
147
148        >>> with push_seed(24):
149        ...    print(randint(0,1000000,3))
150        [242082    899 211136]
151
152    Using nested contexts, we can demonstrate that state is indeed
153    restored after the block completes::
154
155        >>> with push_seed(24):
156        ...    print(randint(0,1000000))
157        ...    with push_seed(24):
158        ...        print(randint(0,1000000,3))
159        ...    print(randint(0,1000000))
160        242082
161        [242082    899 211136]
162        899
163
164    The restore step is protected against exceptions in the block::
165
166        >>> with push_seed(24):
167        ...    print(randint(0,1000000))
168        ...    try:
169        ...        with push_seed(24):
170        ...            print(randint(0,1000000,3))
171        ...            raise Exception()
172        ...    except Exception:
173        ...        print("Exception raised")
174        ...    print(randint(0,1000000))
175        242082
176        [242082    899 211136]
177        Exception raised
178        899
179    """
180    def __init__(self, seed=None):
181        # type: (Optional[int]) -> None
182        self._state = np.random.get_state()
183        np.random.seed(seed)
184
185    def __enter__(self):
186        # type: () -> None
187        pass
188
189    def __exit__(self, exc_type, exc_value, traceback):
190        # type: (Any, BaseException, Any) -> None
191        # TODO: better typing for __exit__ method
192        np.random.set_state(self._state)
193
194def tic():
195    # type: () -> Callable[[], float]
196    """
197    Timer function.
198
199    Use "toc=tic()" to start the clock and "toc()" to measure
200    a time interval.
201    """
202    then = datetime.datetime.now()
203    return lambda: delay(datetime.datetime.now() - then)
204
205
206def set_beam_stop(data, radius, outer=None):
207    # type: (Data, float, float) -> None
208    """
209    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
210
211    Note: this function does not require sasview
212    """
213    if hasattr(data, 'qx_data'):
214        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
215        data.mask = (q < radius)
216        if outer is not None:
217            data.mask |= (q >= outer)
218    else:
219        data.mask = (data.x < radius)
220        if outer is not None:
221            data.mask |= (data.x >= outer)
222
223
224def parameter_range(p, v):
225    # type: (str, float) -> Tuple[float, float]
226    """
227    Choose a parameter range based on parameter name and initial value.
228    """
229    # process the polydispersity options
230    if p.endswith('_pd_n'):
231        return 0., 100.
232    elif p.endswith('_pd_nsigma'):
233        return 0., 5.
234    elif p.endswith('_pd_type'):
235        raise ValueError("Cannot return a range for a string value")
236    elif any(s in p for s in ('theta', 'phi', 'psi')):
237        # orientation in [-180,180], orientation pd in [0,45]
238        if p.endswith('_pd'):
239            return 0., 45.
240        else:
241            return -180., 180.
242    elif p.endswith('_pd'):
243        return 0., 1.
244    elif 'sld' in p:
245        return -0.5, 10.
246    elif p == 'background':
247        return 0., 10.
248    elif p == 'scale':
249        return 0., 1.e3
250    elif v < 0.:
251        return 2.*v, -2.*v
252    else:
253        return 0., (2.*v if v > 0. else 1.)
254
255
256def _randomize_one(model_info, p, v):
257    # type: (ModelInfo, str, float) -> float
258    # type: (ModelInfo, str, str) -> str
259    """
260    Randomize a single parameter.
261    """
262    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
263        return v
264
265    # Find the parameter definition
266    for par in model_info.parameters.call_parameters:
267        if par.name == p:
268            break
269    else:
270        raise ValueError("unknown parameter %r"%p)
271
272    if len(par.limits) > 2:  # choice list
273        return np.random.randint(len(par.limits))
274
275    limits = par.limits
276    if not np.isfinite(limits).all():
277        low, high = parameter_range(p, v)
278        limits = (max(limits[0], low), min(limits[1], high))
279
280    return np.random.uniform(*limits)
281
282
283def randomize_pars(model_info, pars, seed=None):
284    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
285    """
286    Generate random values for all of the parameters.
287
288    Valid ranges for the random number generator are guessed from the name of
289    the parameter; this will not account for constraints such as cap radius
290    greater than cylinder radius in the capped_cylinder model, so
291    :func:`constrain_pars` needs to be called afterward..
292    """
293    with push_seed(seed):
294        # Note: the sort guarantees order `of calls to random number generator
295        random_pars = dict((p, _randomize_one(model_info, p, v))
296                           for p, v in sorted(pars.items()))
297    return random_pars
298
299def constrain_pars(model_info, pars):
300    # type: (ModelInfo, ParameterSet) -> None
301    """
302    Restrict parameters to valid values.
303
304    This includes model specific code for models such as capped_cylinder
305    which need to support within model constraints (cap radius more than
306    cylinder radius in this case).
307
308    Warning: this updates the *pars* dictionary in place.
309    """
310    name = model_info.id
311    # if it is a product model, then just look at the form factor since
312    # none of the structure factors need any constraints.
313    if '*' in name:
314        name = name.split('*')[0]
315
316    if name == 'capped_cylinder' and pars['radius_cap'] < pars['radius']:
317        pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
318    if name == 'barbell' and pars['radius_bell'] < pars['radius']:
319        pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
320
321    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
322    if name == 'guinier':
323        #q_max = 0.2  # mid q maximum
324        q_max = 1.0  # high q maximum
325        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
326        pars['rg'] = min(pars['rg'], rg_max)
327
328    if name == 'rpa':
329        # Make sure phi sums to 1.0
330        if pars['case_num'] < 2:
331            pars['Phi1'] = 0.
332            pars['Phi2'] = 0.
333        elif pars['case_num'] < 5:
334            pars['Phi1'] = 0.
335        total = sum(pars['Phi'+c] for c in '1234')
336        for c in '1234':
337            pars['Phi'+c] /= total
338
339def parlist(model_info, pars, is2d):
340    # type: (ModelInfo, ParameterSet, bool) -> str
341    """
342    Format the parameter list for printing.
343    """
344    lines = []
345    parameters = model_info.parameters
346    for p in parameters.user_parameters(pars, is2d):
347        fields = dict(
348            value=pars.get(p.id, p.default),
349            pd=pars.get(p.id+"_pd", 0.),
350            n=int(pars.get(p.id+"_pd_n", 0)),
351            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
352            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
353            relative_pd=p.relative_pd,
354        )
355        lines.append(_format_par(p.name, **fields))
356    return "\n".join(lines)
357
358    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
359
360def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
361                relative_pd=False):
362    # type: (str, float, float, int, float, str) -> str
363    line = "%s: %g"%(name, value)
364    if pd != 0.  and n != 0:
365        if relative_pd:
366            pd *= value
367        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
368                % (pd, n, nsigma, nsigma, pdtype)
369    return line
370
371def suppress_pd(pars):
372    # type: (ParameterSet) -> ParameterSet
373    """
374    Suppress theta_pd for now until the normalization is resolved.
375
376    May also suppress complete polydispersity of the model to test
377    models more quickly.
378    """
379    pars = pars.copy()
380    for p in pars:
381        if p.endswith("_pd_n"): pars[p] = 0
382    return pars
383
384def eval_sasview(model_info, data):
385    # type: (Modelinfo, Data) -> Calculator
386    """
387    Return a model calculator using the pre-4.0 SasView models.
388    """
389    # importing sas here so that the error message will be that sas failed to
390    # import rather than the more obscure smear_selection not imported error
391    import sas
392    import sas.models
393    from sas.models.qsmearing import smear_selection
394    from sas.models.MultiplicationModel import MultiplicationModel
395    from sas.models.dispersion_models import models as dispersers
396
397    def get_model_class(name):
398        # type: (str) -> "sas.models.BaseComponent"
399        #print("new",sorted(_pars.items()))
400        __import__('sas.models.' + name)
401        ModelClass = getattr(getattr(sas.models, name, None), name, None)
402        if ModelClass is None:
403            raise ValueError("could not find model %r in sas.models"%name)
404        return ModelClass
405
406    # WARNING: ugly hack when handling model!
407    # Sasview models with multiplicity need to be created with the target
408    # multiplicity, so we cannot create the target model ahead of time for
409    # for multiplicity models.  Instead we store the model in a list and
410    # update the first element of that list with the new multiplicity model
411    # every time we evaluate.
412
413    # grab the sasview model, or create it if it is a product model
414    if model_info.composition:
415        composition_type, parts = model_info.composition
416        if composition_type == 'product':
417            P, S = [get_model_class(revert_name(p))() for p in parts]
418            model = [MultiplicationModel(P, S)]
419        else:
420            raise ValueError("sasview mixture models not supported by compare")
421    else:
422        old_name = revert_name(model_info)
423        if old_name is None:
424            raise ValueError("model %r does not exist in old sasview"
425                            % model_info.id)
426        ModelClass = get_model_class(old_name)
427        model = [ModelClass()]
428    model[0].disperser_handles = {}
429
430    # build a smearer with which to call the model, if necessary
431    smearer = smear_selection(data, model=model)
432    if hasattr(data, 'qx_data'):
433        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
434        index = ((~data.mask) & (~np.isnan(data.data))
435                 & (q >= data.qmin) & (q <= data.qmax))
436        if smearer is not None:
437            smearer.model = model  # because smear_selection has a bug
438            smearer.accuracy = data.accuracy
439            smearer.set_index(index)
440            def _call_smearer():
441                smearer.model = model[0]
442                return smearer.get_value()
443            theory = _call_smearer
444        else:
445            theory = lambda: model[0].evalDistribution([data.qx_data[index],
446                                                        data.qy_data[index]])
447    elif smearer is not None:
448        theory = lambda: smearer(model[0].evalDistribution(data.x))
449    else:
450        theory = lambda: model[0].evalDistribution(data.x)
451
452    def calculator(**pars):
453        # type: (float, ...) -> np.ndarray
454        """
455        Sasview calculator for model.
456        """
457        oldpars = revert_pars(model_info, pars)
458        # For multiplicity models, create a model with the correct multiplicity
459        control = oldpars.pop("CONTROL", None)
460        if control is not None:
461            # sphericalSLD has one fewer multiplicity.  This update should
462            # happen in revert_pars, but it hasn't been called yet.
463            model[0] = ModelClass(control)
464        # paying for parameter conversion each time to keep life simple, if not fast
465        for k, v in oldpars.items():
466            if k.endswith('.type'):
467                par = k[:-5]
468                cls = dispersers[v if v != 'rectangle' else 'rectangula']
469                handle = cls()
470                model[0].disperser_handles[par] = handle
471                model[0].set_dispersion(par, handle)
472
473        #print("sasview pars",oldpars)
474        for k, v in oldpars.items():
475            name_attr = k.split('.')  # polydispersity components
476            if len(name_attr) == 2:
477                par, disp_par = name_attr
478                model[0].dispersion[par][disp_par] = v
479            else:
480                model[0].setParam(k, v)
481        return theory()
482
483    calculator.engine = "sasview"
484    return calculator
485
486DTYPE_MAP = {
487    'half': '16',
488    'fast': 'fast',
489    'single': '32',
490    'double': '64',
491    'quad': '128',
492    'f16': '16',
493    'f32': '32',
494    'f64': '64',
495    'longdouble': '128',
496}
497def eval_opencl(model_info, data, dtype='single', cutoff=0.):
498    # type: (ModelInfo, Data, str, float) -> Calculator
499    """
500    Return a model calculator using the OpenCL calculation engine.
501    """
502    if not core.HAVE_OPENCL:
503        raise RuntimeError("OpenCL not available")
504    model = core.build_model(model_info, dtype=dtype, platform="ocl")
505    calculator = DirectModel(data, model, cutoff=cutoff)
506    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
507    return calculator
508
509def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
510    # type: (ModelInfo, Data, str, float) -> Calculator
511    """
512    Return a model calculator using the DLL calculation engine.
513    """
514    model = core.build_model(model_info, dtype=dtype, platform="dll")
515    calculator = DirectModel(data, model, cutoff=cutoff)
516    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
517    return calculator
518
519def time_calculation(calculator, pars, evals=1):
520    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
521    """
522    Compute the average calculation time over N evaluations.
523
524    An additional call is generated without polydispersity in order to
525    initialize the calculation engine, and make the average more stable.
526    """
527    # initialize the code so time is more accurate
528    if evals > 1:
529        calculator(**suppress_pd(pars))
530    toc = tic()
531    # make sure there is at least one eval
532    value = calculator(**pars)
533    for _ in range(evals-1):
534        value = calculator(**pars)
535    average_time = toc()*1000. / evals
536    #print("I(q)",value)
537    return value, average_time
538
539def make_data(opts):
540    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
541    """
542    Generate an empty dataset, used with the model to set Q points
543    and resolution.
544
545    *opts* contains the options, with 'qmax', 'nq', 'res',
546    'accuracy', 'is2d' and 'view' parsed from the command line.
547    """
548    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
549    if opts['is2d']:
550        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
551        data = empty_data2D(q, resolution=res)
552        data.accuracy = opts['accuracy']
553        set_beam_stop(data, 0.0004)
554        index = ~data.mask
555    else:
556        if opts['view'] == 'log' and not opts['zero']:
557            qmax = math.log10(qmax)
558            q = np.logspace(qmax-3, qmax, nq)
559        else:
560            q = np.linspace(0.001*qmax, qmax, nq)
561        if opts['zero']:
562            q = np.hstack((0, q))
563        data = empty_data1D(q, resolution=res)
564        index = slice(None, None)
565    return data, index
566
567def make_engine(model_info, data, dtype, cutoff):
568    # type: (ModelInfo, Data, str, float) -> Calculator
569    """
570    Generate the appropriate calculation engine for the given datatype.
571
572    Datatypes with '!' appended are evaluated using external C DLLs rather
573    than OpenCL.
574    """
575    if dtype == 'sasview':
576        return eval_sasview(model_info, data)
577    elif dtype.endswith('!'):
578        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
579    else:
580        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
581
582def _show_invalid(data, theory):
583    # type: (Data, np.ma.ndarray) -> None
584    """
585    Display a list of the non-finite values in theory.
586    """
587    if not theory.mask.any():
588        return
589
590    if hasattr(data, 'x'):
591        bad = zip(data.x[theory.mask], theory[theory.mask])
592        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
593
594
595def compare(opts, limits=None):
596    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
597    """
598    Preform a comparison using options from the command line.
599
600    *limits* are the limits on the values to use, either to set the y-axis
601    for 1D or to set the colormap scale for 2D.  If None, then they are
602    inferred from the data and returned. When exploring using Bumps,
603    the limits are set when the model is initially called, and maintained
604    as the values are adjusted, making it easier to see the effects of the
605    parameters.
606    """
607    n_base, n_comp = opts['n1'], opts['n2']
608    pars = opts['pars']
609    data = opts['data']
610
611    # silence the linter
612    base = opts['engines'][0] if n_base else None
613    comp = opts['engines'][1] if n_comp else None
614    base_time = comp_time = None
615    base_value = comp_value = resid = relerr = None
616
617    # Base calculation
618    if n_base > 0:
619        try:
620            base_raw, base_time = time_calculation(base, pars, n_base)
621            base_value = np.ma.masked_invalid(base_raw)
622            print("%s t=%.2f ms, intensity=%.0f"
623                  % (base.engine, base_time, base_value.sum()))
624            _show_invalid(data, base_value)
625        except ImportError:
626            traceback.print_exc()
627            n_base = 0
628
629    # Comparison calculation
630    if n_comp > 0:
631        try:
632            comp_raw, comp_time = time_calculation(comp, pars, n_comp)
633            comp_value = np.ma.masked_invalid(comp_raw)
634            print("%s t=%.2f ms, intensity=%.0f"
635                  % (comp.engine, comp_time, comp_value.sum()))
636            _show_invalid(data, comp_value)
637        except ImportError:
638            traceback.print_exc()
639            n_comp = 0
640
641    # Compare, but only if computing both forms
642    if n_base > 0 and n_comp > 0:
643        resid = (base_value - comp_value)
644        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
645        _print_stats("|%s-%s|"
646                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
647                     resid)
648        _print_stats("|(%s-%s)/%s|"
649                     % (base.engine, comp.engine, comp.engine),
650                     relerr)
651
652    # Plot if requested
653    if not opts['plot'] and not opts['explore']: return
654    view = opts['view']
655    import matplotlib.pyplot as plt
656    if limits is None:
657        vmin, vmax = np.Inf, -np.Inf
658        if n_base > 0:
659            vmin = min(vmin, base_value.min())
660            vmax = max(vmax, base_value.max())
661        if n_comp > 0:
662            vmin = min(vmin, comp_value.min())
663            vmax = max(vmax, comp_value.max())
664        limits = vmin, vmax
665
666    if n_base > 0:
667        if n_comp > 0: plt.subplot(131)
668        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
669        plt.title("%s t=%.2f ms"%(base.engine, base_time))
670        #cbar_title = "log I"
671    if n_comp > 0:
672        if n_base > 0: plt.subplot(132)
673        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
674        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
675        #cbar_title = "log I"
676    if n_comp > 0 and n_base > 0:
677        plt.subplot(133)
678        if not opts['rel_err']:
679            err, errstr, errview = resid, "abs err", "linear"
680        else:
681            err, errstr, errview = abs(relerr), "rel err", "log"
682        #err,errstr = base/comp,"ratio"
683        plot_theory(data, None, resid=err, view=errview, use_data=False)
684        if view == 'linear':
685            plt.xscale('linear')
686        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
687        #cbar_title = errstr if errview=="linear" else "log "+errstr
688    #if is2D:
689    #    h = plt.colorbar()
690    #    h.ax.set_title(cbar_title)
691    fig = plt.gcf()
692    extra_title = ' '+opts['title'] if opts['title'] else ''
693    fig.suptitle(opts['name'] + extra_title)
694
695    if n_comp > 0 and n_base > 0 and '-hist' in opts:
696        plt.figure()
697        v = relerr
698        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
699        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
700        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
701                   % (base.engine, comp.engine, comp.engine))
702        plt.ylabel('P(err)')
703        plt.title('Distribution of relative error between calculation engines')
704
705    if not opts['explore']:
706        plt.show()
707
708    return limits
709
710def _print_stats(label, err):
711    # type: (str, np.ma.ndarray) -> None
712    # work with trimmed data, not the full set
713    sorted_err = np.sort(abs(err.compressed()))
714    p50 = int((len(sorted_err)-1)*0.50)
715    p98 = int((len(sorted_err)-1)*0.98)
716    data = [
717        "max:%.3e"%sorted_err[-1],
718        "median:%.3e"%sorted_err[p50],
719        "98%%:%.3e"%sorted_err[p98],
720        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
721        "zero-offset:%+.3e"%np.mean(sorted_err),
722        ]
723    print(label+"  "+"  ".join(data))
724
725
726
727# ===========================================================================
728#
729NAME_OPTIONS = set([
730    'plot', 'noplot',
731    'half', 'fast', 'single', 'double',
732    'single!', 'double!', 'quad!', 'sasview',
733    'lowq', 'midq', 'highq', 'exq', 'zero',
734    '2d', '1d',
735    'preset', 'random',
736    'poly', 'mono',
737    'nopars', 'pars',
738    'rel', 'abs',
739    'linear', 'log', 'q4',
740    'hist', 'nohist',
741    'edit', 'html',
742    'demo', 'default',
743    ])
744VALUE_OPTIONS = [
745    # Note: random is both a name option and a value option
746    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title',
747    ]
748
749def columnize(items, indent="", width=79):
750    # type: (List[str], str, int) -> str
751    """
752    Format a list of strings into columns.
753
754    Returns a string with carriage returns ready for printing.
755    """
756    column_width = max(len(w) for w in items) + 1
757    num_columns = (width - len(indent)) // column_width
758    num_rows = len(items) // num_columns
759    items = items + [""] * (num_rows * num_columns - len(items))
760    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
761    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
762             for row in zip(*columns)]
763    output = indent + ("\n"+indent).join(lines)
764    return output
765
766
767def get_pars(model_info, use_demo=False):
768    # type: (ModelInfo, bool) -> ParameterSet
769    """
770    Extract demo parameters from the model definition.
771    """
772    # Get the default values for the parameters
773    pars = {}
774    for p in model_info.parameters.call_parameters:
775        parts = [('', p.default)]
776        if p.polydisperse:
777            parts.append(('_pd', 0.0))
778            parts.append(('_pd_n', 0))
779            parts.append(('_pd_nsigma', 3.0))
780            parts.append(('_pd_type', "gaussian"))
781        for ext, val in parts:
782            if p.length > 1:
783                dict(("%s%d%s" % (p.id, k, ext), val)
784                     for k in range(1, p.length+1))
785            else:
786                pars[p.id + ext] = val
787
788    # Plug in values given in demo
789    if use_demo:
790        pars.update(model_info.demo)
791    return pars
792
793
794def parse_opts(argv):
795    # type: (List[str]) -> Dict[str, Any]
796    """
797    Parse command line options.
798    """
799    MODELS = core.list_models()
800    flags = [arg for arg in argv
801             if arg.startswith('-')]
802    values = [arg for arg in argv
803              if not arg.startswith('-') and '=' in arg]
804    positional_args = [arg for arg in argv
805            if not arg.startswith('-') and '=' not in arg]
806    models = "\n    ".join("%-15s"%v for v in MODELS)
807    if len(positional_args) == 0:
808        print(USAGE)
809        print("\nAvailable models:")
810        print(columnize(MODELS, indent="  "))
811        return None
812    if len(positional_args) > 3:
813        print("expected parameters: model N1 N2")
814
815    name = positional_args[0]
816    try:
817        model_info = core.load_model_info(name)
818    except ImportError as exc:
819        print(str(exc))
820        print("Could not find model; use one of:\n    " + models)
821        return None
822
823    invalid = [o[1:] for o in flags
824               if o[1:] not in NAME_OPTIONS
825               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
826    if invalid:
827        print("Invalid options: %s"%(", ".join(invalid)))
828        return None
829
830
831    # pylint: disable=bad-whitespace
832    # Interpret the flags
833    opts = {
834        'plot'      : True,
835        'view'      : 'log',
836        'is2d'      : False,
837        'qmax'      : 0.05,
838        'nq'        : 128,
839        'res'       : 0.0,
840        'accuracy'  : 'Low',
841        'cutoff'    : 0.0,
842        'seed'      : -1,  # default to preset
843        'mono'      : False,
844        'show_pars' : False,
845        'show_hist' : False,
846        'rel_err'   : True,
847        'explore'   : False,
848        'use_demo'  : True,
849        'zero'      : False,
850        'html'      : False,
851        'title'     : None,
852    }
853    engines = []
854    for arg in flags:
855        if arg == '-noplot':    opts['plot'] = False
856        elif arg == '-plot':    opts['plot'] = True
857        elif arg == '-linear':  opts['view'] = 'linear'
858        elif arg == '-log':     opts['view'] = 'log'
859        elif arg == '-q4':      opts['view'] = 'q4'
860        elif arg == '-1d':      opts['is2d'] = False
861        elif arg == '-2d':      opts['is2d'] = True
862        elif arg == '-exq':     opts['qmax'] = 10.0
863        elif arg == '-highq':   opts['qmax'] = 1.0
864        elif arg == '-midq':    opts['qmax'] = 0.2
865        elif arg == '-lowq':    opts['qmax'] = 0.05
866        elif arg == '-zero':    opts['zero'] = True
867        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
868        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
869        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
870        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
871        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
872        elif arg.startswith('-title'):     opts['title'] = arg[7:]
873        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
874        elif arg == '-preset':  opts['seed'] = -1
875        elif arg == '-mono':    opts['mono'] = True
876        elif arg == '-poly':    opts['mono'] = False
877        elif arg == '-pars':    opts['show_pars'] = True
878        elif arg == '-nopars':  opts['show_pars'] = False
879        elif arg == '-hist':    opts['show_hist'] = True
880        elif arg == '-nohist':  opts['show_hist'] = False
881        elif arg == '-rel':     opts['rel_err'] = True
882        elif arg == '-abs':     opts['rel_err'] = False
883        elif arg == '-half':    engines.append(arg[1:])
884        elif arg == '-fast':    engines.append(arg[1:])
885        elif arg == '-single':  engines.append(arg[1:])
886        elif arg == '-double':  engines.append(arg[1:])
887        elif arg == '-single!': engines.append(arg[1:])
888        elif arg == '-double!': engines.append(arg[1:])
889        elif arg == '-quad!':   engines.append(arg[1:])
890        elif arg == '-sasview': engines.append(arg[1:])
891        elif arg == '-edit':    opts['explore'] = True
892        elif arg == '-demo':    opts['use_demo'] = True
893        elif arg == '-default':    opts['use_demo'] = False
894        elif arg == '-html':    opts['html'] = True
895    # pylint: enable=bad-whitespace
896
897    if len(engines) == 0:
898        engines.extend(['single', 'double'])
899    elif len(engines) == 1:
900        if engines[0][0] == 'double':
901            engines.append('single')
902        else:
903            engines.append('double')
904    elif len(engines) > 2:
905        del engines[2:]
906
907    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
908    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
909    use_sasview = any(engine == 'sasview' and count > 0
910                      for engine, count in zip(engines, [n1, n2]))
911
912    # Get demo parameters from model definition, or use default parameters
913    # if model does not define demo parameters
914    pars = get_pars(model_info, opts['use_demo'])
915
916
917    # Fill in parameters given on the command line
918    presets = {}
919    for arg in values:
920        k, v = arg.split('=', 1)
921        if k not in pars:
922            # extract base name without polydispersity info
923            s = set(p.split('_pd')[0] for p in pars)
924            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
925            return None
926        presets[k] = float(v) if not k.endswith('type') else v
927
928    # randomize parameters
929    #pars.update(set_pars)  # set value before random to control range
930    if opts['seed'] > -1:
931        pars = randomize_pars(model_info, pars, seed=opts['seed'])
932        print("Randomize using -random=%i"%opts['seed'])
933    if opts['mono']:
934        pars = suppress_pd(pars)
935    pars.update(presets)  # set value after random to control value
936    #import pprint; pprint.pprint(model_info)
937    constrain_pars(model_info, pars)
938    if use_sasview:
939        constrain_new_to_old(model_info, pars)
940    if opts['show_pars']:
941        print(str(parlist(model_info, pars, opts['is2d'])))
942
943    # Create the computational engines
944    data, _ = make_data(opts)
945    if n1:
946        base = make_engine(model_info, data, engines[0], opts['cutoff'])
947    else:
948        base = None
949    if n2:
950        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
951    else:
952        comp = None
953
954    # pylint: disable=bad-whitespace
955    # Remember it all
956    opts.update({
957        'name'      : name,
958        'def'       : model_info,
959        'n1'        : n1,
960        'n2'        : n2,
961        'presets'   : presets,
962        'pars'      : pars,
963        'data'      : data,
964        'engines'   : [base, comp],
965    })
966    # pylint: enable=bad-whitespace
967
968    return opts
969
970def show_docs(opts):
971    # type: (Dict[str, Any]) -> None
972    """
973    show html docs for the model
974    """
975    import wx  # type: ignore
976    from .generate import view_html_from_info
977    app = wx.App() if wx.GetApp() is None else None
978    view_html_from_info(opts['def'])
979    if app: app.MainLoop()
980
981
982def explore(opts):
983    # type: (Dict[str, Any]) -> None
984    """
985    explore the model using the bumps gui.
986    """
987    import wx  # type: ignore
988    from bumps.names import FitProblem  # type: ignore
989    from bumps.gui.app_frame import AppFrame  # type: ignore
990
991    is_mac = "cocoa" in wx.version()
992    # Create an app if not running embedded
993    app = wx.App() if wx.GetApp() is None else None
994    problem = FitProblem(Explore(opts))
995    frame = AppFrame(parent=None, title="explore", size=(1000,700))
996    if not is_mac: frame.Show()
997    frame.panel.set_model(model=problem)
998    frame.panel.Layout()
999    frame.panel.aui.Split(0, wx.TOP)
1000    if is_mac: frame.Show()
1001    # If running withing an app, start the main loop
1002    if app: app.MainLoop()
1003
1004class Explore(object):
1005    """
1006    Bumps wrapper for a SAS model comparison.
1007
1008    The resulting object can be used as a Bumps fit problem so that
1009    parameters can be adjusted in the GUI, with plots updated on the fly.
1010    """
1011    def __init__(self, opts):
1012        # type: (Dict[str, Any]) -> None
1013        from bumps.cli import config_matplotlib  # type: ignore
1014        from . import bumps_model
1015        config_matplotlib()
1016        self.opts = opts
1017        model_info = opts['def']
1018        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
1019        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1020        if not opts['is2d']:
1021            for p in model_info.parameters.user_parameters(is2d=False):
1022                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1023                    k = p.name+ext
1024                    v = pars.get(k, None)
1025                    if v is not None:
1026                        v.range(*parameter_range(k, v.value))
1027        else:
1028            for k, v in pars.items():
1029                v.range(*parameter_range(k, v.value))
1030
1031        self.pars = pars
1032        self.pd_types = pd_types
1033        self.limits = None
1034
1035    def numpoints(self):
1036        # type: () -> int
1037        """
1038        Return the number of points.
1039        """
1040        return len(self.pars) + 1  # so dof is 1
1041
1042    def parameters(self):
1043        # type: () -> Any   # Dict/List hierarchy of parameters
1044        """
1045        Return a dictionary of parameters.
1046        """
1047        return self.pars
1048
1049    def nllf(self):
1050        # type: () -> float
1051        """
1052        Return cost.
1053        """
1054        # pylint: disable=no-self-use
1055        return 0.  # No nllf
1056
1057    def plot(self, view='log'):
1058        # type: (str) -> None
1059        """
1060        Plot the data and residuals.
1061        """
1062        pars = dict((k, v.value) for k, v in self.pars.items())
1063        pars.update(self.pd_types)
1064        self.opts['pars'] = pars
1065        limits = compare(self.opts, limits=self.limits)
1066        if self.limits is None:
1067            vmin, vmax = limits
1068            self.limits = vmax*1e-7, 1.3*vmax
1069
1070
1071def main(*argv):
1072    # type: (*str) -> None
1073    """
1074    Main program.
1075    """
1076    opts = parse_opts(argv)
1077    if opts is not None:
1078        if opts['html']:
1079            show_docs(opts)
1080        elif opts['explore']:
1081            explore(opts)
1082        else:
1083            compare(opts)
1084
1085if __name__ == "__main__":
1086    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.