source: sasmodels/sasmodels/compare.py @ f67f26c

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since f67f26c was f67f26c, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

multilayer vesicle: shell thickness is not an sld

  • Property mode set to 100755
File size: 34.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets you
10make sure that the model calculations are stable, or find out if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from .data import plot_theory, empty_data1D, empty_data2D
41from .direct_model import DirectModel
42from .convert import revert_name, revert_pars, constrain_new_to_old
43
44try:
45    from typing import Optional, Dict, Any, Callable, Tuple
46except:
47    pass
48else:
49    from .modelinfo import ModelInfo, Parameter, ParameterSet
50    from .data import Data
51    Calculator = Callable[[float], np.ndarray]
52
# Command line help.  Kept as a separate string (rather than only in the
# module doc string) so that it can be printed at run time on error.
USAGE = """
usage: compare.py model N1 N2 [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
sasmodels rewrite.

model is the name of the model to compare (see below).
N1 is the number of times to run sasmodels (default=1).
N2 is the number of times to run sasview (default=1).

Options (* for default):

    -plot*/-noplot plots or suppress the plot of the model
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -nq=128 sets the number of Q points in the data set
    -zero indicates that q=0 should be included
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot relative or absolute error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
    -edit starts the parameter explorer
    -default/-demo* use demo vs default parameters

Any two calculation engines can be selected for comparison:

    -single/-double/-half/-fast sets an OpenCL calculation engine
    -single!/-double!/-quad! sets an OpenMP calculation engine
    -sasview sets the sasview calculation engine

The default is -single -double.  Note that the interpretation of quad
precision depends on architecture, and may vary from 64-bit to 128-bit,
with 80-bit floats being common (1e-19 precision).

Key=value pairs allow you to set specific values for the model parameters.
"""

# Update docs with command line usage string.   This is separate from the usual
# doc string so that we can display it at run time if there is an error.
# __doc__ is None when doc strings have been stripped (python -OO), so fall
# back to an empty string rather than failing on None + str at import time.
__doc__ = ((__doc__ or "")  # pylint: disable=redefined-builtin
           + """
Program description
-------------------

"""
           + USAGE)
105
# Turn on the kerneldll module flag for single precision DLLs.
# NOTE(review): presumably this is what lets the "-single!" engine option
# build a single precision DLL model; the flag is defined in kerneldll -- confirm.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
107
# CRUFT python 2.6: fall back to arithmetic on the timedelta fields when
# timedelta.total_seconds() is not available.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the date-time delta *dt* as a number of seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the date-time delta *dt* as a number of seconds."""
        whole_seconds = dt.days * 86400 + dt.seconds
        return whole_seconds + 1e-6 * dt.microseconds
117
118
class push_seed(object):
    """
    Reseed the numpy random number generator, remembering its prior state.

    The generator is reseeded as soon as the object is created.  When the
    object is used as a context manager, the state that was in effect
    before reseeding is put back when the with statement completes.

    :Parameters:

    *seed* : int or array_like, optional
        Seed for RandomState

    :Example:

    Seed can be used directly to set the seed::

        >>> from numpy.random import randint
        >>> push_seed(24)
        <...push_seed object at...>
        >>> print(randint(0,1000000,3))
        [242082    899 211136]

    Seed can also be used in a with statement, which sets the random
    number generator state for the enclosed computations and restores
    it to the previous state on completion::

        >>> with push_seed(24):
        ...    print(randint(0,1000000,3))
        [242082    899 211136]

    Using nested contexts, we can demonstrate that state is indeed
    restored after the block completes::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    with push_seed(24):
        ...        print(randint(0,1000000,3))
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        899

    The restore step is protected against exceptions in the block::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    try:
        ...        with push_seed(24):
        ...            print(randint(0,1000000,3))
        ...            raise Exception()
        ...    except Exception:
        ...        print("Exception raised")
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        Exception raised
        899
    """
    def __init__(self, seed=None):
        # type: (Optional[int]) -> None
        # Capture the generator state before reseeding so __exit__ can
        # put it back.
        self._saved_state = np.random.get_state()
        np.random.seed(seed)

    def __enter__(self):
        # type: () -> None
        # The seed was already applied by the constructor; nothing to do.
        return None

    def __exit__(self, type, value, traceback):
        # type: (Any, BaseException, Any) -> None
        # TODO: better typing for __exit__ method
        # Runs whether or not the block raised; the exception (if any) is
        # not suppressed since nothing truthy is returned.
        np.random.set_state(self._saved_state)
190
def tic():
    # type: () -> Callable[[], float]
    """
    Timer function.

    Use "toc=tic()" to start the clock and "toc()" to measure
    a time interval.  The interval is returned in seconds.
    """
    # Use the timer that the timeit module selects for benchmarking rather
    # than datetime.now(), which follows the wall clock and so can jump
    # if the system time is adjusted while a calculation is running.
    # Imported locally so the function carries its own dependency.
    from timeit import default_timer

    start = default_timer()
    return lambda: default_timer() - start
201
202
def set_beam_stop(data, radius, outer=None):
    # type: (Data, float, float) -> None
    """
    Mask the points of *data* that lie inside a beam stop of size *radius*.

    If *outer* is given, points at or beyond *outer* are masked as well,
    leaving an annulus.  The result is stored in *data.mask*.

    Data with a *qx_data* attribute is masked on sqrt(qx^2 + qy^2);
    otherwise the mask is computed from *data.x*.

    Note: this function does not require sasview
    """
    if hasattr(data, 'qx_data'):
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q = data.x

    mask = (q < radius)
    if outer is not None:
        mask |= (q >= outer)
    data.mask = mask
219
220
def parameter_range(p, v):
    # type: (str, float) -> Tuple[float, float]
    """
    Guess a (low, high) range for parameter *p* given its initial value *v*.

    The range is chosen from the parameter name where the name is
    recognized, and otherwise from the sign and size of *v*.  Raises
    *ValueError* for the string valued "_pd_type" parameters.
    """
    # Polydispersity controls are tested first since their names also
    # contain the name of the underlying parameter.
    if p.endswith('_pd_n'):
        return 0., 100.
    if p.endswith('_pd_nsigma'):
        return 0., 5.
    if p.endswith('_pd_type'):
        raise ValueError("Cannot return a range for a string value")

    is_width = p.endswith('_pd')
    if 'theta' in p or 'phi' in p or 'psi' in p:
        # orientation in [-180,180], orientation pd in [0,45]
        return (0., 45.) if is_width else (-180., 180.)
    if is_width:
        return 0., 1.
    if 'sld' in p:
        return -0.5, 10.

    by_name = {'background': (0., 10.), 'scale': (0., 1.e3)}
    if p in by_name:
        return by_name[p]

    # No name-based rule applies: span zero out to twice the value.
    if v < 0.:
        return 2.*v, -2.*v
    if v > 0.:
        return 0., 2.*v
    return 0., 1.
251
252
def _randomize_one(model_info, p, v):
    # type: (ModelInfo, str, float) -> float
    # type: (ModelInfo, str, str) -> str
    """
    Return a random value for the parameter named *p*.

    Polydispersity settings are returned unchanged as *v*.  Parameters with
    more than two limit entries are treated as a choice list and a random
    index is returned.  Otherwise the value is drawn uniformly from the
    parameter limits, with any infinite limit replaced by the range guessed
    by :func:`parameter_range`.  Raises *ValueError* if *p* is not one of
    the model call parameters.
    """
    if p.endswith(('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
        return v

    # Find the parameter definition
    par = next((candidate
                for candidate in model_info.parameters.call_parameters
                if candidate.name == p), None)
    if par is None:
        raise ValueError("unknown parameter %r"%p)

    limits = par.limits
    if len(limits) > 2:  # choice list
        return np.random.randint(len(limits))

    if not np.isfinite(limits).all():
        guess_low, guess_high = parameter_range(p, v)
        limits = (max(limits[0], guess_low), min(limits[1], guess_high))

    return np.random.uniform(*limits)
278
279
def randomize_pars(model_info, pars, seed=None):
    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
    """
    Return a new parameter set with a random value for each entry of *pars*.

    The random number generator is seeded with *seed* for the duration of
    the call and restored afterward.

    The individual values do not account for constraints between
    parameters, such as cap radius greater than cylinder radius in the
    capped_cylinder model, so :func:`constrain_pars` needs to be called
    afterward.
    """
    random_pars = {}
    with push_seed(seed):
        # Note: visiting the names in sorted order guarantees the order of
        # calls to the random number generator.
        for name in sorted(pars):
            random_pars[name] = _randomize_one(model_info, name, pars[name])
    return random_pars
295
def constrain_pars(model_info, pars):
    # type: (ModelInfo, ParameterSet) -> None
    """
    Adjust *pars* so that model specific constraints are satisfied.

    Handles capped_cylinder and barbell (cap/bell radius not less than the
    cylinder radius), guinier (upper bound on rg) and rpa (volume fractions
    normalized to sum to one).  Other models are left untouched.

    Warning: this updates the *pars* dictionary in place.
    """
    # For a product model "P*S" only the form factor P is examined, since
    # none of the structure factors need any constraints.
    name = model_info.id.split('*')[0]

    if name == 'capped_cylinder':
        if pars['cap_radius'] < pars['radius']:
            small, large = pars['cap_radius'], pars['radius']
            pars['radius'], pars['cap_radius'] = small, large

    elif name == 'barbell':
        if pars['bell_radius'] < pars['radius']:
            small, large = pars['bell_radius'], pars['radius']
            pars['radius'], pars['bell_radius'] = small, large

    elif name == 'guinier':
        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
        #q_max = 0.2  # mid q maximum
        q_max = 1.0  # high q maximum
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)

    elif name == 'rpa':
        # Zero the fractions selected by case_num, then make sure phi
        # sums to 1.0
        if pars['case_num'] < 2:
            pars['Phi1'] = 0.
            pars['Phi2'] = 0.
        elif pars['case_num'] < 5:
            pars['Phi1'] = 0.
        phi_keys = ['Phi'+c for c in '1234']
        total = sum(pars[key] for key in phi_keys)
        for key in phi_keys:
            pars[key] /= total
335
def parlist(model_info, pars, is2d):
    # type: (ModelInfo, ParameterSet, bool) -> str
    """
    Format the parameter list for printing.

    Returns one line per user parameter of the model, joined with newlines.
    Values missing from *pars* fall back to the parameter default, and
    missing polydispersity settings fall back to no polydispersity.
    """
    lines = []
    parameters = model_info.parameters
    for p in parameters.user_parameters(pars, is2d):
        fields = dict(
            value=pars.get(p.id, p.default),
            pd=pars.get(p.id+"_pd", 0.),
            n=int(pars.get(p.id+"_pd_n", 0)),
            # The key must be spelled "_pd_nsigma" to match the parameter
            # set; a misspelled key silently reports the default of 3.
            nsigma=pars.get(p.id+"_pd_nsigma", 3.),
            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
        )
        lines.append(_format_par(p.name, **fields))
    return "\n".join(lines)

def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian'):
    # type: (str, float, float, int, float, str) -> str
    """
    Format one parameter as "name: value".

    The polydispersity description is appended only when both the width
    *pd* and the number of points *n* are non-zero.
    """
    line = "%s: %g"%(name, value)
    if pd != 0.  and n != 0:
        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
                % (pd, n, nsigma, nsigma, pdtype)
    return line
363
def suppress_pd(pars):
    # type: (ParameterSet) -> ParameterSet
    """
    Return a copy of *pars* with polydispersity turned off.

    Every entry whose name ends in "_pd_n" is set to zero in the copy; the
    original parameter set is not modified.
    """
    monodisperse = pars.copy()
    for name in [key for key in pars if key.endswith("_pd_n")]:
        monodisperse[name] = 0
    return monodisperse
376
def eval_sasview(model_info, data):
    # type: (ModelInfo, Data) -> Calculator
    """
    Return a model calculator using the pre-4.0 SasView models.

    The returned function takes the model parameters as keyword arguments,
    converts them to the old sasview names with *revert_pars*, applies them
    to the sasview model and returns the computed theory, smeared if a
    smearer is available for *data*.  Its *engine* attribute is "sasview".

    Raises *ValueError* if the model cannot be found in sas.models, has no
    old sasview name, or is a composition other than a product.
    """
    # importing sas here so that the error message will be that sas failed to
    # import rather than the more obscure smear_selection not imported error
    import sas
    import sas.models
    from sas.models.qsmearing import smear_selection
    from sas.models.MultiplicationModel import MultiplicationModel

    def get_model_class(name):
        # type: (str) -> "sas.models.BaseComponent"
        # Import the module sas.models.<name> and fetch the class of the
        # same name from it.
        #print("new",sorted(_pars.items()))
        __import__('sas.models.' + name)
        ModelClass = getattr(getattr(sas.models, name, None), name, None)
        if ModelClass is None:
            raise ValueError("could not find model %r in sas.models"%name)
        return ModelClass

    # WARNING: ugly hack when handling model!
    # Sasview models with multiplicity need to be created with the target
    # multiplicity, so we cannot create the target model ahead of time
    # for multiplicity models.  Instead we store the model in a list and
    # update the first element of that list with the new multiplicity model
    # every time we evaluate.

    # grab the sasview model, or create it if it is a product model
    if model_info.composition:
        composition_type, parts = model_info.composition
        if composition_type == 'product':
            P, S = [get_model_class(revert_name(p))() for p in parts]
            model = [MultiplicationModel(P, S)]
        else:
            raise ValueError("sasview mixture models not supported by compare")
    else:
        old_name = revert_name(model_info)
        if old_name is None:
            raise ValueError("model %r does not exist in old sasview"
                            % model_info.id)
        ModelClass = get_model_class(old_name)
        model = [ModelClass()]

    # build a smearer with which to call the model, if necessary
    smearer = smear_selection(data, model=model)
    if hasattr(data, 'qx_data'):
        # 2D data: evaluate only the points which are unmasked, not NaN and
        # within [qmin, qmax].
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
        index = ((~data.mask) & (~np.isnan(data.data))
                 & (q >= data.qmin) & (q <= data.qmax))
        if smearer is not None:
            smearer.model = model  # because smear_selection has a bug
            smearer.accuracy = data.accuracy
            smearer.set_index(index)
            def _call_smearer():
                # Point the smearer at the current model, which may have
                # been replaced in model[0] since the smearer was built.
                smearer.model = model[0]
                return smearer.get_value()
            theory = lambda: _call_smearer()
        else:
            theory = lambda: model[0].evalDistribution([data.qx_data[index],
                                                        data.qy_data[index]])
    elif smearer is not None:
        theory = lambda: smearer(model[0].evalDistribution(data.x))
    else:
        theory = lambda: model[0].evalDistribution(data.x)

    def calculator(**pars):
        # type: (float, ...) -> np.ndarray
        """
        Sasview calculator for model.
        """
        # For multiplicity models, recreate the model with the requested
        # multiplicity on every evaluation.
        # NOTE(review): ModelClass is only assigned in the non-composition
        # branch above, so a product model with a control parameter would
        # raise NameError here -- confirm whether that combination can occur.
        if model_info.control:
            model[0] = ModelClass(int(pars[model_info.control]))
        # paying for parameter conversion each time to keep life simple, if not fast
        oldpars = revert_pars(model_info, pars)
        #print("sasview pars",oldpars)
        for k, v in oldpars.items():
            name_attr = k.split('.')  # polydispersity components
            if len(name_attr) == 2:
                # "par.attr" keys are stored in the model dispersion table
                model[0].dispersion[name_attr[0]][name_attr[1]] = v
            else:
                model[0].setParam(k, v)
        return theory()

    calculator.engine = "sasview"
    return calculator
464
# Map from precision name to the suffix used when labelling a calculation
# engine ("OCL<suffix>" for OpenCL, "OMP<suffix>" for DLL).  Several names
# map to the same suffix.
DTYPE_MAP = {
    'half': '16',
    'fast': 'fast',
    'single': '32',
    'double': '64',
    'quad': '128',
    'f16': '16',
    'f32': '32',
    'f64': '64',
    'longdouble': '128',
}
def eval_opencl(model_info, data, dtype='single', cutoff=0.):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Build a calculator for the model on the OpenCL calculation engine.

    *dtype* is the precision name and *cutoff* is passed through to the
    calculator.  The calculator is tagged with an *engine* label of the
    form "OCL<precision>".  Raises *RuntimeError* if OpenCL is not available.
    """
    if not core.HAVE_OPENCL:
        raise RuntimeError("OpenCL not available")
    kernel_model = core.build_model(model_info, dtype=dtype, platform="ocl")
    engine = DirectModel(data, kernel_model, cutoff=cutoff)
    engine.engine = "OCL" + DTYPE_MAP[dtype]
    return engine
487
def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Build a calculator for the model on the DLL calculation engine.

    *dtype* is the precision name and *cutoff* is passed through to the
    calculator.  The calculator is tagged with an *engine* label of the
    form "OMP<precision>".
    """
    kernel_model = core.build_model(model_info, dtype=dtype, platform="dll")
    engine = DirectModel(data, kernel_model, cutoff=cutoff)
    engine.engine = "OMP" + DTYPE_MAP[dtype]
    return engine
497
def time_calculation(calculator, pars, Nevals=1):
    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
    """
    Compute the average calculation time over N evaluations.

    An additional call is generated without polydispersity in order to
    initialize the calculation engine, and make the average more stable.
    This warm-up call is only made when *Nevals* is greater than one.

    Returns the value from the final evaluation and the average time per
    evaluation in milliseconds.
    """
    # initialize the code so time is more accurate
    if Nevals > 1:
        calculator(**suppress_pd(pars))

    # The model is always evaluated at least once so that there is a value
    # to return; average over the number of evaluations actually performed
    # so that Nevals < 1 does not divide by zero.
    num_evals = max(Nevals, 1)
    toc = tic()
    for _ in range(num_evals):
        value = calculator(**pars)
    average_time = toc()*1000./num_evals
    return value, average_time
517
def make_data(opts):
    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
    """
    Generate an empty dataset, used with the model to set Q points
    and resolution.

    *opts* contains the options, with 'qmax', 'nq', 'res',
    'accuracy', 'is2d', 'view' and 'zero' parsed from the command line.

    Returns the data set together with an index selecting the points to
    use: the unmasked points for 2D data, or everything for 1D data.
    """
    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']

    if opts['is2d']:
        # Square grid on [-qmax, qmax] with a small beam stop at the center
        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
        data = empty_data2D(q, resolution=res)
        data.accuracy = opts['accuracy']
        set_beam_stop(data, 0.0004)
        return data, ~data.mask

    if opts['view'] == 'log' and not opts['zero']:
        # Log spaced points covering the three decades below qmax
        log_qmax = math.log10(qmax)
        q = np.logspace(log_qmax-3, log_qmax, nq)
    else:
        q = np.linspace(0.001*qmax, qmax, nq)
    if opts['zero']:
        q = np.hstack((0, q))
    data = empty_data1D(q, resolution=res)
    return data, slice(None, None)
545
def make_engine(model_info, data, dtype, cutoff):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Create the calculation engine selected by *dtype*.

    The name 'sasview' selects the sasview engine.  Names ending in '!'
    are evaluated using external C DLLs with the '!' removed from the
    datatype.  Anything else is evaluated using OpenCL.
    """
    if dtype == 'sasview':
        return eval_sasview(model_info, data)
    if dtype.endswith('!'):
        precision = dtype[:-1]
        return eval_ctypes(model_info, data, dtype=precision, cutoff=cutoff)
    return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
560
def _show_invalid(data, theory):
    # type: (Data, np.ma.ndarray) -> None
    """
    Print the masked values of *theory* along with their x positions.

    Nothing is printed if no values are masked, or if *data* has no *x*
    attribute.
    """
    invalid = theory.mask
    if invalid.any() and hasattr(data, 'x'):
        pairs = zip(data.x[invalid], theory[invalid])
        listing = ", ".join("I(%g)=%g"%(x, y) for x, y in pairs)
        print("   *** ", listing)
572
573
def compare(opts, limits=None):
    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
    """
    Perform a comparison using options from the command line.

    *limits* are the limits on the values to use, either to set the y-axis
    for 1D or to set the colormap scale for 2D.  If None, then they are
    inferred from the data and returned. When exploring using Bumps,
    the limits are set when the model is initially called, and maintained
    as the values are adjusted, making it easier to see the effects of the
    parameters.

    Timing and error statistics are printed for each engine that is run.
    Returns None without plotting if neither 'plot' nor 'explore' is set
    in *opts*; otherwise returns the limits used for the plots.
    """
    Nbase, Ncomp = opts['n1'], opts['n2']
    pars = opts['pars']
    data = opts['data']

    # silence the linter
    base = opts['engines'][0] if Nbase else None
    comp = opts['engines'][1] if Ncomp else None
    base_time = comp_time = None
    base_value = comp_value = resid = relerr = None

    # Base calculation
    if Nbase > 0:
        try:
            base_raw, base_time = time_calculation(base, pars, Nbase)
            base_value = np.ma.masked_invalid(base_raw)
            print("%s t=%.2f ms, intensity=%.0f"
                  % (base.engine, base_time, base_value.sum()))
            _show_invalid(data, base_value)
        except ImportError:
            # Engine not available; report it and carry on without it.
            traceback.print_exc()
            Nbase = 0

    # Comparison calculation
    if Ncomp > 0:
        try:
            comp_raw, comp_time = time_calculation(comp, pars, Ncomp)
            comp_value = np.ma.masked_invalid(comp_raw)
            print("%s t=%.2f ms, intensity=%.0f"
                  % (comp.engine, comp_time, comp_value.sum()))
            _show_invalid(data, comp_value)
        except ImportError:
            traceback.print_exc()
            Ncomp = 0

    # Compare, but only if computing both forms
    if Nbase > 0 and Ncomp > 0:
        resid = (base_value - comp_value)
        # Divide by 1 where the comparison value is zero to avoid 0/0
        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
        _print_stats("|%s-%s|"
                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
                     resid)
        _print_stats("|(%s-%s)/%s|"
                     % (base.engine, comp.engine, comp.engine),
                     relerr)

    # Plot if requested
    if not opts['plot'] and not opts['explore']:
        return
    view = opts['view']
    # Import here so that matplotlib is only needed when plotting
    import matplotlib.pyplot as plt
    if limits is None:
        # np.inf is the spelling available in every numpy version
        vmin, vmax = np.inf, -np.inf
        if Nbase > 0:
            vmin = min(vmin, base_value.min())
            vmax = max(vmax, base_value.max())
        if Ncomp > 0:
            vmin = min(vmin, comp_value.min())
            vmax = max(vmax, comp_value.max())
        limits = vmin, vmax

    if Nbase > 0:
        if Ncomp > 0:
            plt.subplot(131)
        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
        plt.title("%s t=%.2f ms"%(base.engine, base_time))
        #cbar_title = "log I"
    if Ncomp > 0:
        if Nbase > 0:
            plt.subplot(132)
        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
        #cbar_title = "log I"
    if Ncomp > 0 and Nbase > 0:
        plt.subplot(133)
        if not opts['rel_err']:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        #err,errstr = base/comp,"ratio"
        plot_theory(data, None, resid=err, view=errview, use_data=False)
        if view == 'linear':
            plt.xscale('linear')
        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
        #cbar_title = errstr if errview=="linear" else "log "+errstr
    #if is2D:
    #    h = plt.colorbar()
    #    h.ax.set_title(cbar_title)
    fig = plt.gcf()
    fig.suptitle(opts['name'])

    # The -hist/-nohist flags are stored under the 'show_hist' key; testing
    # for a '-hist' key never matches, so the histogram was never shown.
    if Ncomp > 0 and Nbase > 0 and opts.get('show_hist', False):
        plt.figure()
        v = relerr
        # Replace exact zeros by half the smallest non-zero error so that
        # log10 is defined for every point.
        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
        # NOTE(review): newer matplotlib releases spell the normalization
        # keyword density=True instead of normed=1 -- confirm against the
        # matplotlib versions that need to be supported.
        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
                   % (base.engine, comp.engine, comp.engine))
        plt.ylabel('P(err)')
        plt.title('Distribution of relative error between calculation engines')

    if not opts['explore']:
        plt.show()

    return limits
687
def _print_stats(label, err):
    # type: (str, np.ma.ndarray) -> None
    """
    Print summary statistics for the magnitude of the masked array *err*.

    The line starts with *label* and reports the max, median, 98th
    percentile, rms and mean of the absolute values of the unmasked points.
    """
    # work with trimmed data, not the full set
    sorted_err = np.sort(abs(err.compressed()))
    if len(sorted_err) == 0:
        # Every point is masked, so there is nothing to index or average.
        print(label+"  no valid values")
        return

    p50 = int((len(sorted_err)-1)*0.50)
    p98 = int((len(sorted_err)-1)*0.98)
    data = [
        "max:%.3e"%sorted_err[-1],
        "median:%.3e"%sorted_err[p50],
        "98%%:%.3e"%sorted_err[p98],
        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
        "zero-offset:%+.3e"%np.mean(sorted_err),
        ]
    print(label+"  "+"  ".join(data))
702
703
704
# ===========================================================================
#
# Command line flags which are given by name alone, as "-name".
NAME_OPTIONS = set([
    'plot', 'noplot',
    'half', 'fast', 'single', 'double',
    'single!', 'double!', 'quad!', 'sasview',
    'lowq', 'midq', 'highq', 'exq', 'zero',
    '2d', '1d',
    'preset', 'random',
    'poly', 'mono',
    'nopars', 'pars',
    'rel', 'abs',
    'linear', 'log', 'q4',
    'hist', 'nohist',
    'edit',
    'demo', 'default',
    ])
# Command line flags which carry a value, as "-name=value".
VALUE_OPTIONS = [
    # Note: random is both a name option and a value option
    'cutoff', 'random', 'nq', 'res', 'accuracy',
    ]
726
def columnize(L, indent="", width=79):
    # type: (List[str], str, int) -> str
    """
    Format a list of strings into columns.

    Entries are listed down the first column, then down the second, and
    so on.  Each line starts with *indent*, and the number of columns is
    chosen from *width* and the length of the longest entry.

    Returns a string with carriage returns ready for printing.  An empty
    list gives an empty string.
    """
    if not L:
        return ""
    column_width = max(len(w) for w in L) + 1
    # Always use at least one column, even when the longest entry is wider
    # than the available width.
    num_columns = max((width - len(indent)) // column_width, 1)
    # Round up so that entries in a partially filled last row are kept
    # rather than silently dropped.
    num_rows = (len(L) + num_columns - 1) // num_columns
    padded = list(L) + [""] * (num_rows*num_columns - len(L))
    columns = [padded[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    return indent + ("\n"+indent).join(lines)
743
744
def get_pars(model_info, use_demo=False):
    # type: (ModelInfo, bool) -> ParameterSet
    """
    Build the parameter set for the model from its parameter definitions.

    Every call parameter gets its default value.  Polydisperse parameters
    also get "_pd", "_pd_n", "_pd_nsigma" and "_pd_type" entries which
    default to no polydispersity.  Parameters with length greater than one
    are expanded into numbered entries, counting from 1.

    If *use_demo* is true, the values in *model_info.demo* override the
    defaults.
    """
    # Get the default values for the parameters
    pars = {}
    for p in model_info.parameters.call_parameters:
        parts = [('', p.default)]
        if p.polydisperse:
            parts.append(('_pd', 0.0))
            parts.append(('_pd_n', 0))
            parts.append(('_pd_nsigma', 3.0))
            parts.append(('_pd_type', "gaussian"))
        for ext, val in parts:
            if p.length > 1:
                # Store the numbered entries in pars; building the mapping
                # without adding it to pars would lose vector parameters.
                pars.update(("%s%d%s"%(p.id, k, ext), val)
                            for k in range(1, p.length+1))
            else:
                pars[p.id+ext] = val

    # Plug in values given in demo
    if use_demo:
        pars.update(model_info.demo)
    return pars
769
770
771def parse_opts():
772    # type: () -> Dict[str, Any]
773    """
774    Parse command line options.
775    """
776    MODELS = core.list_models()
777    flags = [arg for arg in sys.argv[1:]
778             if arg.startswith('-')]
779    values = [arg for arg in sys.argv[1:]
780              if not arg.startswith('-') and '=' in arg]
781    args = [arg for arg in sys.argv[1:]
782            if not arg.startswith('-') and '=' not in arg]
783    models = "\n    ".join("%-15s"%v for v in MODELS)
784    if len(args) == 0:
785        print(USAGE)
786        print("\nAvailable models:")
787        print(columnize(MODELS, indent="  "))
788        sys.exit(1)
789    if len(args) > 3:
790        print("expected parameters: model N1 N2")
791
792    name = args[0]
793    try:
794        model_info = core.load_model_info(name)
795    except ImportError as exc:
796        print(str(exc))
797        print("Could not find model; use one of:\n    " + models)
798        sys.exit(1)
799
800    invalid = [o[1:] for o in flags
801               if o[1:] not in NAME_OPTIONS
802               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
803    if invalid:
804        print("Invalid options: %s"%(", ".join(invalid)))
805        sys.exit(1)
806
807
808    # pylint: disable=bad-whitespace
809    # Interpret the flags
810    opts = {
811        'plot'      : True,
812        'view'      : 'log',
813        'is2d'      : False,
814        'qmax'      : 0.05,
815        'nq'        : 128,
816        'res'       : 0.0,
817        'accuracy'  : 'Low',
818        'cutoff'    : 0.0,
819        'seed'      : -1,  # default to preset
820        'mono'      : False,
821        'show_pars' : False,
822        'show_hist' : False,
823        'rel_err'   : True,
824        'explore'   : False,
825        'use_demo'  : True,
826        'zero'      : False,
827    }
828    engines = []
829    for arg in flags:
830        if arg == '-noplot':    opts['plot'] = False
831        elif arg == '-plot':    opts['plot'] = True
832        elif arg == '-linear':  opts['view'] = 'linear'
833        elif arg == '-log':     opts['view'] = 'log'
834        elif arg == '-q4':      opts['view'] = 'q4'
835        elif arg == '-1d':      opts['is2d'] = False
836        elif arg == '-2d':      opts['is2d'] = True
837        elif arg == '-exq':     opts['qmax'] = 10.0
838        elif arg == '-highq':   opts['qmax'] = 1.0
839        elif arg == '-midq':    opts['qmax'] = 0.2
840        elif arg == '-lowq':    opts['qmax'] = 0.05
841        elif arg == '-zero':    opts['zero'] = True
842        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
843        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
844        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
845        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
846        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
847        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
848        elif arg == '-preset':  opts['seed'] = -1
849        elif arg == '-mono':    opts['mono'] = True
850        elif arg == '-poly':    opts['mono'] = False
851        elif arg == '-pars':    opts['show_pars'] = True
852        elif arg == '-nopars':  opts['show_pars'] = False
853        elif arg == '-hist':    opts['show_hist'] = True
854        elif arg == '-nohist':  opts['show_hist'] = False
855        elif arg == '-rel':     opts['rel_err'] = True
856        elif arg == '-abs':     opts['rel_err'] = False
857        elif arg == '-half':    engines.append(arg[1:])
858        elif arg == '-fast':    engines.append(arg[1:])
859        elif arg == '-single':  engines.append(arg[1:])
860        elif arg == '-double':  engines.append(arg[1:])
861        elif arg == '-single!': engines.append(arg[1:])
862        elif arg == '-double!': engines.append(arg[1:])
863        elif arg == '-quad!':   engines.append(arg[1:])
864        elif arg == '-sasview': engines.append(arg[1:])
865        elif arg == '-edit':    opts['explore'] = True
866        elif arg == '-demo':    opts['use_demo'] = True
867        elif arg == '-default':    opts['use_demo'] = False
868    # pylint: enable=bad-whitespace
869
870    if len(engines) == 0:
871        engines.extend(['single', 'double'])
872    elif len(engines) == 1:
873        if engines[0][0] == 'double':
874            engines.append('single')
875        else:
876            engines.append('double')
877    elif len(engines) > 2:
878        del engines[2:]
879
880    n1 = int(args[1]) if len(args) > 1 else 1
881    n2 = int(args[2]) if len(args) > 2 else 1
882    use_sasview = any(engine=='sasview' and count>0
883                      for engine, count in zip(engines, [n1, n2]))
884
885    # Get demo parameters from model definition, or use default parameters
886    # if model does not define demo parameters
887    pars = get_pars(model_info, opts['use_demo'])
888
889
890    # Fill in parameters given on the command line
891    presets = {}
892    for arg in values:
893        k, v = arg.split('=', 1)
894        if k not in pars:
895            # extract base name without polydispersity info
896            s = set(p.split('_pd')[0] for p in pars)
897            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
898            sys.exit(1)
899        presets[k] = float(v) if not k.endswith('type') else v
900
901    # randomize parameters
902    #pars.update(set_pars)  # set value before random to control range
903    if opts['seed'] > -1:
904        pars = randomize_pars(model_info, pars, seed=opts['seed'])
905        print("Randomize using -random=%i"%opts['seed'])
906    if opts['mono']:
907        pars = suppress_pd(pars)
908    pars.update(presets)  # set value after random to control value
909    #import pprint; pprint.pprint(model_info)
910    constrain_pars(model_info, pars)
911    if use_sasview:
912        constrain_new_to_old(model_info, pars)
913    if opts['show_pars']:
914        print(str(parlist(model_info, pars, opts['is2d'])))
915
916    # Create the computational engines
917    data, _ = make_data(opts)
918    if n1:
919        base = make_engine(model_info, data, engines[0], opts['cutoff'])
920    else:
921        base = None
922    if n2:
923        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
924    else:
925        comp = None
926
927    # pylint: disable=bad-whitespace
928    # Remember it all
929    opts.update({
930        'name'      : name,
931        'def'       : model_info,
932        'n1'        : n1,
933        'n2'        : n2,
934        'presets'   : presets,
935        'pars'      : pars,
936        'data'      : data,
937        'engines'   : [base, comp],
938    })
939    # pylint: enable=bad-whitespace
940
941    return opts
942
def explore(opts):
    # type: (Dict[str, Any]) -> None
    """
    Open the Bumps GUI on the model comparison described by *opts*.

    *opts* is wrapped in an :class:`Explore` object, which is handed to
    Bumps as a fit problem.  The call blocks in the wx main loop until the
    application exits.
    """
    import wx  # type: ignore
    from bumps.names import FitProblem  # type: ignore
    from bumps.gui.app_frame import AppFrame  # type: ignore

    fit_problem = FitProblem(Explore(opts))
    on_cocoa = "cocoa" in wx.version()
    gui = wx.App()
    window = AppFrame(parent=None, title="explore")

    # The frame is shown before the model is attached everywhere except on
    # cocoa builds of wx, where it is shown only after the panel is laid out.
    # NOTE(review): presumably a platform-specific wx workaround -- confirm.
    if not on_cocoa:
        window.Show()
    window.panel.set_model(model=fit_problem)
    window.panel.Layout()
    window.panel.aui.Split(0, wx.TOP)
    if on_cocoa:
        window.Show()

    gui.MainLoop()
962
class Explore(object):
    """
    Adapter exposing a SAS model comparison through the Bumps fitness API.

    An instance can be wrapped in a Bumps fit problem so that the model
    parameters can be edited in the GUI; each call to :meth:`plot` reruns
    the comparison with the current parameter values.
    """
    def __init__(self, opts):
        # type: (Dict[str, Any]) -> None
        from bumps.cli import config_matplotlib  # type: ignore
        from . import bumps_model
        config_matplotlib()
        self.opts = opts
        info = opts['def']
        fit_pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])

        # Assign a range to each parameter.  For 1D data only the parameters
        # reported by user_parameters(is2d=False), together with their
        # polydispersity companions, receive a range; the rest are untouched.
        if opts['is2d']:
            for name, par in fit_pars.items():
                par.range(*parameter_range(name, par.value))
        else:
            suffixes = ['', '_pd', '_pd_n', '_pd_nsigma']
            for p in info.parameters.user_parameters(is2d=False):
                for suffix in suffixes:
                    name = p.name + suffix
                    par = fit_pars.get(name, None)
                    if par is None:
                        continue
                    par.range(*parameter_range(name, par.value))

        self.pars = fit_pars
        self.pd_types = pd_types
        # Plot limits; filled in by the first call to plot().
        self.limits = None

    def numpoints(self):
        # type: () -> int
        """
        Return a point count one larger than the number of parameters.
        """
        n_pars = len(self.pars)
        return n_pars + 1  # so dof is 1

    def parameters(self):
        # type: () -> Any   # Dict/List hierarchy of parameters
        """
        Return the parameter dictionary built in the constructor.
        """
        return self.pars

    def nllf(self):
        # type: () -> float
        """
        Return the cost, which is always zero since nothing is being fit.
        """
        # pylint: disable=no-self-use
        return 0.  # No nllf

    def plot(self, view='log'):
        # type: (str) -> None
        """
        Rerun the comparison with the current parameter values and plot it.

        *view* is accepted but not used in this method body.
        """
        current = {name: par.value for name, par in self.pars.items()}
        current.update(self.pd_types)
        self.opts['pars'] = current
        limits = compare(self.opts, limits=self.limits)
        if self.limits is None:
            # Derive the plot limits from the first comparison and keep them
            # for every later call.
            _, vmax = limits
            self.limits = vmax*1e-7, 1.3*vmax
1028
1029
def main():
    # type: () -> None
    """
    Command line entry point.

    Parses the options, then either opens the interactive explorer or runs
    a single comparison, depending on the ``explore`` option.
    """
    opts = parse_opts()
    run = explore if opts['explore'] else compare
    run(opts)


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.