source: sasmodels/sasmodels/compare.py @ 6831fa0

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 6831fa0 was 6831fa0, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

stacked_disks: enable 2D modeling for stacked disks

  • Property mode set to 100755
File size: 36.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from . import exception
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_name, revert_pars, constrain_new_to_old
44
45try:
46    from typing import Optional, Dict, Any, Callable, Tuple
47except Exception:
48    pass
49else:
50    from .modelinfo import ModelInfo, Parameter, ParameterSet
51    from .data import Data
52    Calculator = Callable[[float], np.ndarray]
53
54USAGE = """
55usage: compare.py model N1 N2 [options...] [key=val]
56
57Compare the speed and value for a model between the SasView original and the
58sasmodels rewrite.
59
60model is the name of the model to compare (see below).
61N1 is the number of times to run sasmodels (default=1).
62N2 is the number times to run sasview (default=1).
63
64Options (* for default):
65
66    -plot*/-noplot plots or suppress the plot of the model
67    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
68    -nq=128 sets the number of Q points in the data set
69    -zero indicates that q=0 should be included
70    -1d*/-2d computes 1d or 2d data
71    -preset*/-random[=seed] preset or random parameters
72    -mono/-poly* force monodisperse/polydisperse
73    -cutoff=1e-5* cutoff value for including a point in polydispersity
74    -pars/-nopars* prints the parameter set or not
75    -abs/-rel* plot relative or absolute error
76    -linear/-log*/-q4 intensity scaling
77    -hist/-nohist* plot histogram of relative error
78    -res=0 sets the resolution width dQ/Q if calculating with resolution
79    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
80    -edit starts the parameter explorer
81    -default/-demo* use demo vs default parameters
82    -html shows the model docs instead of running the model
83    -title="note" adds note to the plot title, after the model name
84
85Any two calculation engines can be selected for comparison:
86
87    -single/-double/-half/-fast sets an OpenCL calculation engine
88    -single!/-double!/-quad! sets an OpenMP calculation engine
89    -sasview sets the sasview calculation engine
90
91The default is -single -sasview.  Note that the interpretation of quad
92precision depends on architecture, and may vary from 64-bit to 128-bit,
93with 80-bit floats being common (1e-19 precision).
94
95Key=value pairs allow you to set specific values for the model parameters.
96"""
97
98# Update docs with command line usage string.   This is separate from the usual
99# doc string so that we can display it at run time if there is an error.
100# lin
101__doc__ = (__doc__  # pylint: disable=redefined-builtin
102           + """
103Program description
104-------------------
105
106"""
107           + USAGE)
108
109kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
110
111# CRUFT python 2.6
112if not hasattr(datetime.timedelta, 'total_seconds'):
113    def delay(dt):
114        """Return number date-time delta as number seconds"""
115        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
116else:
117    def delay(dt):
118        """Return number date-time delta as number seconds"""
119        return dt.total_seconds()
120
121
122class push_seed(object):
123    """
124    Set the seed value for the random number generator.
125
126    When used in a with statement, the random number generator state is
127    restored after the with statement is complete.
128
129    :Parameters:
130
131    *seed* : int or array_like, optional
132        Seed for RandomState
133
134    :Example:
135
136    Seed can be used directly to set the seed::
137
138        >>> from numpy.random import randint
139        >>> push_seed(24)
140        <...push_seed object at...>
141        >>> print(randint(0,1000000,3))
142        [242082    899 211136]
143
144    Seed can also be used in a with statement, which sets the random
145    number generator state for the enclosed computations and restores
146    it to the previous state on completion::
147
148        >>> with push_seed(24):
149        ...    print(randint(0,1000000,3))
150        [242082    899 211136]
151
152    Using nested contexts, we can demonstrate that state is indeed
153    restored after the block completes::
154
155        >>> with push_seed(24):
156        ...    print(randint(0,1000000))
157        ...    with push_seed(24):
158        ...        print(randint(0,1000000,3))
159        ...    print(randint(0,1000000))
160        242082
161        [242082    899 211136]
162        899
163
164    The restore step is protected against exceptions in the block::
165
166        >>> with push_seed(24):
167        ...    print(randint(0,1000000))
168        ...    try:
169        ...        with push_seed(24):
170        ...            print(randint(0,1000000,3))
171        ...            raise Exception()
172        ...    except Exception:
173        ...        print("Exception raised")
174        ...    print(randint(0,1000000))
175        242082
176        [242082    899 211136]
177        Exception raised
178        899
179    """
180    def __init__(self, seed=None):
181        # type: (Optional[int]) -> None
182        self._state = np.random.get_state()
183        np.random.seed(seed)
184
185    def __enter__(self):
186        # type: () -> None
187        pass
188
189    def __exit__(self, exc_type, exc_value, traceback):
190        # type: (Any, BaseException, Any) -> None
191        # TODO: better typing for __exit__ method
192        np.random.set_state(self._state)
193
194def tic():
195    # type: () -> Callable[[], float]
196    """
197    Timer function.
198
199    Use "toc=tic()" to start the clock and "toc()" to measure
200    a time interval.
201    """
202    then = datetime.datetime.now()
203    return lambda: delay(datetime.datetime.now() - then)
204
205
206def set_beam_stop(data, radius, outer=None):
207    # type: (Data, float, float) -> None
208    """
209    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
210
211    Note: this function does not require sasview
212    """
213    if hasattr(data, 'qx_data'):
214        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
215        data.mask = (q < radius)
216        if outer is not None:
217            data.mask |= (q >= outer)
218    else:
219        data.mask = (data.x < radius)
220        if outer is not None:
221            data.mask |= (data.x >= outer)
222
223
224def parameter_range(p, v):
225    # type: (str, float) -> Tuple[float, float]
226    """
227    Choose a parameter range based on parameter name and initial value.
228    """
229    # process the polydispersity options
230    if p.endswith('_pd_n'):
231        return 0., 100.
232    elif p.endswith('_pd_nsigma'):
233        return 0., 5.
234    elif p.endswith('_pd_type'):
235        raise ValueError("Cannot return a range for a string value")
236    elif any(s in p for s in ('theta', 'phi', 'psi')):
237        # orientation in [-180,180], orientation pd in [0,45]
238        if p.endswith('_pd'):
239            return 0., 45.
240        else:
241            return -180., 180.
242    elif p.endswith('_pd'):
243        return 0., 1.
244    elif 'sld' in p:
245        return -0.5, 10.
246    elif p == 'background':
247        return 0., 10.
248    elif p == 'scale':
249        return 0., 1.e3
250    elif v < 0.:
251        return 2.*v, -2.*v
252    else:
253        return 0., (2.*v if v > 0. else 1.)
254
255
256def _randomize_one(model_info, p, v):
257    # type: (ModelInfo, str, float) -> float
258    # type: (ModelInfo, str, str) -> str
259    """
260    Randomize a single parameter.
261    """
262    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
263        return v
264
265    # Find the parameter definition
266    for par in model_info.parameters.call_parameters:
267        if par.name == p:
268            break
269    else:
270        raise ValueError("unknown parameter %r"%p)
271
272    if len(par.limits) > 2:  # choice list
273        return np.random.randint(len(par.limits))
274
275    limits = par.limits
276    if not np.isfinite(limits).all():
277        low, high = parameter_range(p, v)
278        limits = (max(limits[0], low), min(limits[1], high))
279
280    return np.random.uniform(*limits)
281
282
283def randomize_pars(model_info, pars, seed=None):
284    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
285    """
286    Generate random values for all of the parameters.
287
288    Valid ranges for the random number generator are guessed from the name of
289    the parameter; this will not account for constraints such as cap radius
290    greater than cylinder radius in the capped_cylinder model, so
291    :func:`constrain_pars` needs to be called afterward..
292    """
293    with push_seed(seed):
294        # Note: the sort guarantees order `of calls to random number generator
295        random_pars = dict((p, _randomize_one(model_info, p, v))
296                           for p, v in sorted(pars.items()))
297    return random_pars
298
299def constrain_pars(model_info, pars):
300    # type: (ModelInfo, ParameterSet) -> None
301    """
302    Restrict parameters to valid values.
303
304    This includes model specific code for models such as capped_cylinder
305    which need to support within model constraints (cap radius more than
306    cylinder radius in this case).
307
308    Warning: this updates the *pars* dictionary in place.
309    """
310    name = model_info.id
311    # if it is a product model, then just look at the form factor since
312    # none of the structure factors need any constraints.
313    if '*' in name:
314        name = name.split('*')[0]
315
316    if name == 'capped_cylinder' and pars['radius_cap'] < pars['radius']:
317        pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
318    if name == 'barbell' and pars['radius_bell'] < pars['radius']:
319        pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
320
321    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
322    if name == 'guinier':
323        #q_max = 0.2  # mid q maximum
324        q_max = 1.0  # high q maximum
325        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
326        pars['rg'] = min(pars['rg'], rg_max)
327
328    if name == 'rpa':
329        # Make sure phi sums to 1.0
330        if pars['case_num'] < 2:
331            pars['Phi1'] = 0.
332            pars['Phi2'] = 0.
333        elif pars['case_num'] < 5:
334            pars['Phi1'] = 0.
335        total = sum(pars['Phi'+c] for c in '1234')
336        for c in '1234':
337            pars['Phi'+c] /= total
338
339def parlist(model_info, pars, is2d):
340    # type: (ModelInfo, ParameterSet, bool) -> str
341    """
342    Format the parameter list for printing.
343    """
344    lines = []
345    parameters = model_info.parameters
346    for p in parameters.user_parameters(pars, is2d):
347        fields = dict(
348            value=pars.get(p.id, p.default),
349            pd=pars.get(p.id+"_pd", 0.),
350            n=int(pars.get(p.id+"_pd_n", 0)),
351            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
352            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
353            relative_pd=p.relative_pd,
354        )
355        lines.append(_format_par(p.name, **fields))
356    return "\n".join(lines)
357
358    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
359
360def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
361                relative_pd=False):
362    # type: (str, float, float, int, float, str) -> str
363    line = "%s: %g"%(name, value)
364    if pd != 0.  and n != 0:
365        if relative_pd:
366            pd *= value
367        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
368                % (pd, n, nsigma, nsigma, pdtype)
369    return line
370
371def suppress_pd(pars):
372    # type: (ParameterSet) -> ParameterSet
373    """
374    Suppress theta_pd for now until the normalization is resolved.
375
376    May also suppress complete polydispersity of the model to test
377    models more quickly.
378    """
379    pars = pars.copy()
380    for p in pars:
381        if p.endswith("_pd_n"): pars[p] = 0
382    return pars
383
384def eval_sasview(model_info, data):
385    # type: (Modelinfo, Data) -> Calculator
386    """
387    Return a model calculator using the pre-4.0 SasView models.
388    """
389    # importing sas here so that the error message will be that sas failed to
390    # import rather than the more obscure smear_selection not imported error
391    import sas
392    import sas.models
393    from sas.models.qsmearing import smear_selection
394    from sas.models.MultiplicationModel import MultiplicationModel
395    from sas.models.dispersion_models import models as dispersers
396
397    def get_model_class(name):
398        # type: (str) -> "sas.models.BaseComponent"
399        #print("new",sorted(_pars.items()))
400        __import__('sas.models.' + name)
401        ModelClass = getattr(getattr(sas.models, name, None), name, None)
402        if ModelClass is None:
403            raise ValueError("could not find model %r in sas.models"%name)
404        return ModelClass
405
406    # WARNING: ugly hack when handling model!
407    # Sasview models with multiplicity need to be created with the target
408    # multiplicity, so we cannot create the target model ahead of time for
409    # for multiplicity models.  Instead we store the model in a list and
410    # update the first element of that list with the new multiplicity model
411    # every time we evaluate.
412
413    # grab the sasview model, or create it if it is a product model
414    if model_info.composition:
415        composition_type, parts = model_info.composition
416        if composition_type == 'product':
417            P, S = [get_model_class(revert_name(p))() for p in parts]
418            model = [MultiplicationModel(P, S)]
419        else:
420            raise ValueError("sasview mixture models not supported by compare")
421    else:
422        old_name = revert_name(model_info)
423        if old_name is None:
424            raise ValueError("model %r does not exist in old sasview"
425                            % model_info.id)
426        ModelClass = get_model_class(old_name)
427        model = [ModelClass()]
428    model[0].disperser_handles = {}
429
430    # build a smearer with which to call the model, if necessary
431    smearer = smear_selection(data, model=model)
432    if hasattr(data, 'qx_data'):
433        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
434        index = ((~data.mask) & (~np.isnan(data.data))
435                 & (q >= data.qmin) & (q <= data.qmax))
436        if smearer is not None:
437            smearer.model = model  # because smear_selection has a bug
438            smearer.accuracy = data.accuracy
439            smearer.set_index(index)
440            def _call_smearer():
441                smearer.model = model[0]
442                return smearer.get_value()
443            theory = _call_smearer
444        else:
445            theory = lambda: model[0].evalDistribution([data.qx_data[index],
446                                                        data.qy_data[index]])
447    elif smearer is not None:
448        theory = lambda: smearer(model[0].evalDistribution(data.x))
449    else:
450        theory = lambda: model[0].evalDistribution(data.x)
451
452    def calculator(**pars):
453        # type: (float, ...) -> np.ndarray
454        """
455        Sasview calculator for model.
456        """
457        oldpars = revert_pars(model_info, pars)
458        # For multiplicity models, create a model with the correct multiplicity
459        control = oldpars.pop("CONTROL", None)
460        if control is not None:
461            # sphericalSLD has one fewer multiplicity.  This update should
462            # happen in revert_pars, but it hasn't been called yet.
463            model[0] = ModelClass(control)
464        # paying for parameter conversion each time to keep life simple, if not fast
465        for k, v in oldpars.items():
466            if k.endswith('.type'):
467                par = k[:-5]
468                if v == 'gaussian': continue
469                cls = dispersers[v if v != 'rectangle' else 'rectangula']
470                handle = cls()
471                model[0].disperser_handles[par] = handle
472                try:
473                    model[0].set_dispersion(par, handle)
474                except Exception:
475                    exception.annotate_exception("while setting %s to %r"
476                                                 %(par, v))
477                    raise
478
479
480        #print("sasview pars",oldpars)
481        for k, v in oldpars.items():
482            name_attr = k.split('.')  # polydispersity components
483            if len(name_attr) == 2:
484                par, disp_par = name_attr
485                model[0].dispersion[par][disp_par] = v
486            else:
487                model[0].setParam(k, v)
488        return theory()
489
490    calculator.engine = "sasview"
491    return calculator
492
493DTYPE_MAP = {
494    'half': '16',
495    'fast': 'fast',
496    'single': '32',
497    'double': '64',
498    'quad': '128',
499    'f16': '16',
500    'f32': '32',
501    'f64': '64',
502    'longdouble': '128',
503}
504def eval_opencl(model_info, data, dtype='single', cutoff=0.):
505    # type: (ModelInfo, Data, str, float) -> Calculator
506    """
507    Return a model calculator using the OpenCL calculation engine.
508    """
509    if not core.HAVE_OPENCL:
510        raise RuntimeError("OpenCL not available")
511    model = core.build_model(model_info, dtype=dtype, platform="ocl")
512    calculator = DirectModel(data, model, cutoff=cutoff)
513    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
514    return calculator
515
516def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
517    # type: (ModelInfo, Data, str, float) -> Calculator
518    """
519    Return a model calculator using the DLL calculation engine.
520    """
521    model = core.build_model(model_info, dtype=dtype, platform="dll")
522    calculator = DirectModel(data, model, cutoff=cutoff)
523    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
524    return calculator
525
526def time_calculation(calculator, pars, evals=1):
527    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
528    """
529    Compute the average calculation time over N evaluations.
530
531    An additional call is generated without polydispersity in order to
532    initialize the calculation engine, and make the average more stable.
533    """
534    # initialize the code so time is more accurate
535    if evals > 1:
536        calculator(**suppress_pd(pars))
537    toc = tic()
538    # make sure there is at least one eval
539    value = calculator(**pars)
540    for _ in range(evals-1):
541        value = calculator(**pars)
542    average_time = toc()*1000. / evals
543    #print("I(q)",value)
544    return value, average_time
545
546def make_data(opts):
547    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
548    """
549    Generate an empty dataset, used with the model to set Q points
550    and resolution.
551
552    *opts* contains the options, with 'qmax', 'nq', 'res',
553    'accuracy', 'is2d' and 'view' parsed from the command line.
554    """
555    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
556    if opts['is2d']:
557        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
558        data = empty_data2D(q, resolution=res)
559        data.accuracy = opts['accuracy']
560        set_beam_stop(data, 0.0004)
561        index = ~data.mask
562    else:
563        if opts['view'] == 'log' and not opts['zero']:
564            qmax = math.log10(qmax)
565            q = np.logspace(qmax-3, qmax, nq)
566        else:
567            q = np.linspace(0.001*qmax, qmax, nq)
568        if opts['zero']:
569            q = np.hstack((0, q))
570        data = empty_data1D(q, resolution=res)
571        index = slice(None, None)
572    return data, index
573
574def make_engine(model_info, data, dtype, cutoff):
575    # type: (ModelInfo, Data, str, float) -> Calculator
576    """
577    Generate the appropriate calculation engine for the given datatype.
578
579    Datatypes with '!' appended are evaluated using external C DLLs rather
580    than OpenCL.
581    """
582    if dtype == 'sasview':
583        return eval_sasview(model_info, data)
584    elif dtype.endswith('!'):
585        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
586    else:
587        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
588
589def _show_invalid(data, theory):
590    # type: (Data, np.ma.ndarray) -> None
591    """
592    Display a list of the non-finite values in theory.
593    """
594    if not theory.mask.any():
595        return
596
597    if hasattr(data, 'x'):
598        bad = zip(data.x[theory.mask], theory[theory.mask])
599        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
600
601
602def compare(opts, limits=None):
603    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
604    """
605    Preform a comparison using options from the command line.
606
607    *limits* are the limits on the values to use, either to set the y-axis
608    for 1D or to set the colormap scale for 2D.  If None, then they are
609    inferred from the data and returned. When exploring using Bumps,
610    the limits are set when the model is initially called, and maintained
611    as the values are adjusted, making it easier to see the effects of the
612    parameters.
613    """
614    n_base, n_comp = opts['n1'], opts['n2']
615    pars = opts['pars']
616    data = opts['data']
617
618    # silence the linter
619    base = opts['engines'][0] if n_base else None
620    comp = opts['engines'][1] if n_comp else None
621    base_time = comp_time = None
622    base_value = comp_value = resid = relerr = None
623
624    # Base calculation
625    if n_base > 0:
626        try:
627            base_raw, base_time = time_calculation(base, pars, n_base)
628            base_value = np.ma.masked_invalid(base_raw)
629            print("%s t=%.2f ms, intensity=%.0f"
630                  % (base.engine, base_time, base_value.sum()))
631            _show_invalid(data, base_value)
632        except ImportError:
633            traceback.print_exc()
634            n_base = 0
635
636    # Comparison calculation
637    if n_comp > 0:
638        try:
639            comp_raw, comp_time = time_calculation(comp, pars, n_comp)
640            comp_value = np.ma.masked_invalid(comp_raw)
641            print("%s t=%.2f ms, intensity=%.0f"
642                  % (comp.engine, comp_time, comp_value.sum()))
643            _show_invalid(data, comp_value)
644        except ImportError:
645            traceback.print_exc()
646            n_comp = 0
647
648    # Compare, but only if computing both forms
649    if n_base > 0 and n_comp > 0:
650        resid = (base_value - comp_value)
651        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
652        _print_stats("|%s-%s|"
653                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
654                     resid)
655        _print_stats("|(%s-%s)/%s|"
656                     % (base.engine, comp.engine, comp.engine),
657                     relerr)
658
659    # Plot if requested
660    if not opts['plot'] and not opts['explore']: return
661    view = opts['view']
662    import matplotlib.pyplot as plt
663    if limits is None:
664        vmin, vmax = np.Inf, -np.Inf
665        if n_base > 0:
666            vmin = min(vmin, base_value.min())
667            vmax = max(vmax, base_value.max())
668        if n_comp > 0:
669            vmin = min(vmin, comp_value.min())
670            vmax = max(vmax, comp_value.max())
671        limits = vmin, vmax
672
673    if n_base > 0:
674        if n_comp > 0: plt.subplot(131)
675        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
676        plt.title("%s t=%.2f ms"%(base.engine, base_time))
677        #cbar_title = "log I"
678    if n_comp > 0:
679        if n_base > 0: plt.subplot(132)
680        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
681        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
682        #cbar_title = "log I"
683    if n_comp > 0 and n_base > 0:
684        plt.subplot(133)
685        if not opts['rel_err']:
686            err, errstr, errview = resid, "abs err", "linear"
687        else:
688            err, errstr, errview = abs(relerr), "rel err", "log"
689        #sorted = np.sort(err.flatten())
690        #cutoff = sorted[int(sorted.size*0.95)]
691        #err[err>cutoff] = cutoff
692        #err,errstr = base/comp,"ratio"
693        plot_theory(data, None, resid=err, view=errview, use_data=False)
694        if view == 'linear':
695            plt.xscale('linear')
696        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
697        #cbar_title = errstr if errview=="linear" else "log "+errstr
698    #if is2D:
699    #    h = plt.colorbar()
700    #    h.ax.set_title(cbar_title)
701    fig = plt.gcf()
702    fig.suptitle(opts['name'])
703
704    if n_comp > 0 and n_base > 0 and opts['show_hist']:
705        plt.figure()
706        v = relerr
707        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
708        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
709        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
710                   % (base.engine, comp.engine, comp.engine))
711        plt.ylabel('P(err)')
712        plt.title('Distribution of relative error between calculation engines')
713
714    if not opts['explore']:
715        plt.show()
716
717    return limits
718
719def _print_stats(label, err):
720    # type: (str, np.ma.ndarray) -> None
721    # work with trimmed data, not the full set
722    sorted_err = np.sort(abs(err.compressed()))
723    p50 = int((len(sorted_err)-1)*0.50)
724    p98 = int((len(sorted_err)-1)*0.98)
725    data = [
726        "max:%.3e"%sorted_err[-1],
727        "median:%.3e"%sorted_err[p50],
728        "98%%:%.3e"%sorted_err[p98],
729        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
730        "zero-offset:%+.3e"%np.mean(sorted_err),
731        ]
732    print(label+"  "+"  ".join(data))
733
734
735
736# ===========================================================================
737#
738NAME_OPTIONS = set([
739    'plot', 'noplot',
740    'half', 'fast', 'single', 'double',
741    'single!', 'double!', 'quad!', 'sasview',
742    'lowq', 'midq', 'highq', 'exq', 'zero',
743    '2d', '1d',
744    'preset', 'random',
745    'poly', 'mono',
746    'nopars', 'pars',
747    'rel', 'abs',
748    'linear', 'log', 'q4',
749    'hist', 'nohist',
750    'edit', 'html',
751    'demo', 'default',
752    ])
753VALUE_OPTIONS = [
754    # Note: random is both a name option and a value option
755    'cutoff', 'random', 'nq', 'res', 'accuracy',
756    ]
757
758def columnize(items, indent="", width=79):
759    # type: (List[str], str, int) -> str
760    """
761    Format a list of strings into columns.
762
763    Returns a string with carriage returns ready for printing.
764    """
765    column_width = max(len(w) for w in items) + 1
766    num_columns = (width - len(indent)) // column_width
767    num_rows = len(items) // num_columns
768    items = items + [""] * (num_rows * num_columns - len(items))
769    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
770    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
771             for row in zip(*columns)]
772    output = indent + ("\n"+indent).join(lines)
773    return output
774
775
776def get_pars(model_info, use_demo=False):
777    # type: (ModelInfo, bool) -> ParameterSet
778    """
779    Extract demo parameters from the model definition.
780    """
781    # Get the default values for the parameters
782    pars = {}
783    for p in model_info.parameters.call_parameters:
784        parts = [('', p.default)]
785        if p.polydisperse:
786            parts.append(('_pd', 0.0))
787            parts.append(('_pd_n', 0))
788            parts.append(('_pd_nsigma', 3.0))
789            parts.append(('_pd_type', "gaussian"))
790        for ext, val in parts:
791            if p.length > 1:
792                dict(("%s%d%s" % (p.id, k, ext), val)
793                     for k in range(1, p.length+1))
794            else:
795                pars[p.id + ext] = val
796
797    # Plug in values given in demo
798    if use_demo:
799        pars.update(model_info.demo)
800    return pars
801
802
803def parse_opts(argv):
804    # type: (List[str]) -> Dict[str, Any]
805    """
806    Parse command line options.
807    """
808    MODELS = core.list_models()
809    flags = [arg for arg in argv
810             if arg.startswith('-')]
811    values = [arg for arg in argv
812              if not arg.startswith('-') and '=' in arg]
813    positional_args = [arg for arg in argv
814            if not arg.startswith('-') and '=' not in arg]
815    models = "\n    ".join("%-15s"%v for v in MODELS)
816    if len(positional_args) == 0:
817        print(USAGE)
818        print("\nAvailable models:")
819        print(columnize(MODELS, indent="  "))
820        return None
821    if len(positional_args) > 3:
822        print("expected parameters: model N1 N2")
823
824    name = positional_args[0]
825    try:
826        model_info = core.load_model_info(name)
827    except ImportError as exc:
828        print(str(exc))
829        print("Could not find model; use one of:\n    " + models)
830        return None
831
832    invalid = [o[1:] for o in flags
833               if o[1:] not in NAME_OPTIONS
834               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
835    if invalid:
836        print("Invalid options: %s"%(", ".join(invalid)))
837        return None
838
839
840    # pylint: disable=bad-whitespace
841    # Interpret the flags
842    opts = {
843        'plot'      : True,
844        'view'      : 'log',
845        'is2d'      : False,
846        'qmax'      : 0.05,
847        'nq'        : 128,
848        'res'       : 0.0,
849        'accuracy'  : 'Low',
850        'cutoff'    : 0.0,
851        'seed'      : -1,  # default to preset
852        'mono'      : False,
853        'show_pars' : False,
854        'show_hist' : False,
855        'rel_err'   : True,
856        'explore'   : False,
857        'use_demo'  : True,
858        'zero'      : False,
859        'html'      : False,
860    }
861    engines = []
862    for arg in flags:
863        if arg == '-noplot':    opts['plot'] = False
864        elif arg == '-plot':    opts['plot'] = True
865        elif arg == '-linear':  opts['view'] = 'linear'
866        elif arg == '-log':     opts['view'] = 'log'
867        elif arg == '-q4':      opts['view'] = 'q4'
868        elif arg == '-1d':      opts['is2d'] = False
869        elif arg == '-2d':      opts['is2d'] = True
870        elif arg == '-exq':     opts['qmax'] = 10.0
871        elif arg == '-highq':   opts['qmax'] = 1.0
872        elif arg == '-midq':    opts['qmax'] = 0.2
873        elif arg == '-lowq':    opts['qmax'] = 0.05
874        elif arg == '-zero':    opts['zero'] = True
875        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
876        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
877        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
878        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
879        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
880        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
881        elif arg == '-preset':  opts['seed'] = -1
882        elif arg == '-mono':    opts['mono'] = True
883        elif arg == '-poly':    opts['mono'] = False
884        elif arg == '-pars':    opts['show_pars'] = True
885        elif arg == '-nopars':  opts['show_pars'] = False
886        elif arg == '-hist':    opts['show_hist'] = True
887        elif arg == '-nohist':  opts['show_hist'] = False
888        elif arg == '-rel':     opts['rel_err'] = True
889        elif arg == '-abs':     opts['rel_err'] = False
890        elif arg == '-half':    engines.append(arg[1:])
891        elif arg == '-fast':    engines.append(arg[1:])
892        elif arg == '-single':  engines.append(arg[1:])
893        elif arg == '-double':  engines.append(arg[1:])
894        elif arg == '-single!': engines.append(arg[1:])
895        elif arg == '-double!': engines.append(arg[1:])
896        elif arg == '-quad!':   engines.append(arg[1:])
897        elif arg == '-sasview': engines.append(arg[1:])
898        elif arg == '-edit':    opts['explore'] = True
899        elif arg == '-demo':    opts['use_demo'] = True
900        elif arg == '-default':    opts['use_demo'] = False
901        elif arg == '-html':    opts['html'] = True
902    # pylint: enable=bad-whitespace
903
904    if len(engines) == 0:
905        engines.extend(['single', 'double'])
906    elif len(engines) == 1:
907        if engines[0][0] == 'double':
908            engines.append('single')
909        else:
910            engines.append('double')
911    elif len(engines) > 2:
912        del engines[2:]
913
914    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
915    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
916    use_sasview = any(engine == 'sasview' and count > 0
917                      for engine, count in zip(engines, [n1, n2]))
918
919    # Get demo parameters from model definition, or use default parameters
920    # if model does not define demo parameters
921    pars = get_pars(model_info, opts['use_demo'])
922
923
924    # Fill in parameters given on the command line
925    presets = {}
926    for arg in values:
927        k, v = arg.split('=', 1)
928        if k not in pars:
929            # extract base name without polydispersity info
930            s = set(p.split('_pd')[0] for p in pars)
931            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
932            return None
933        presets[k] = float(v) if not k.endswith('type') else v
934
935    # randomize parameters
936    #pars.update(set_pars)  # set value before random to control range
937    if opts['seed'] > -1:
938        pars = randomize_pars(model_info, pars, seed=opts['seed'])
939        print("Randomize using -random=%i"%opts['seed'])
940    if opts['mono']:
941        pars = suppress_pd(pars)
942    pars.update(presets)  # set value after random to control value
943    #import pprint; pprint.pprint(model_info)
944    constrain_pars(model_info, pars)
945    if use_sasview:
946        constrain_new_to_old(model_info, pars)
947    if opts['show_pars']:
948        print(str(parlist(model_info, pars, opts['is2d'])))
949
950    # Create the computational engines
951    data, _ = make_data(opts)
952    if n1:
953        base = make_engine(model_info, data, engines[0], opts['cutoff'])
954    else:
955        base = None
956    if n2:
957        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
958    else:
959        comp = None
960
961    # pylint: disable=bad-whitespace
962    # Remember it all
963    opts.update({
964        'name'      : name,
965        'def'       : model_info,
966        'n1'        : n1,
967        'n2'        : n2,
968        'presets'   : presets,
969        'pars'      : pars,
970        'data'      : data,
971        'engines'   : [base, comp],
972    })
973    # pylint: enable=bad-whitespace
974
975    return opts
976
977def show_docs(opts):
978    # type: (Dict[str, Any]) -> None
979    """
980    show html docs for the model
981    """
982    import wx  # type: ignore
983    from .generate import view_html_from_info
984    app = wx.App() if wx.GetApp() is None else None
985    view_html_from_info(opts['def'])
986    if app: app.MainLoop()
987
988
989def explore(opts):
990    # type: (Dict[str, Any]) -> None
991    """
992    explore the model using the bumps gui.
993    """
994    import wx  # type: ignore
995    from bumps.names import FitProblem  # type: ignore
996    from bumps.gui.app_frame import AppFrame  # type: ignore
997
998    is_mac = "cocoa" in wx.version()
999    # Create an app if not running embedded
1000    app = wx.App() if wx.GetApp() is None else None
1001    problem = FitProblem(Explore(opts))
1002    frame = AppFrame(parent=None, title="explore", size=(1000,700))
1003    if not is_mac: frame.Show()
1004    frame.panel.set_model(model=problem)
1005    frame.panel.Layout()
1006    frame.panel.aui.Split(0, wx.TOP)
1007    if is_mac: frame.Show()
1008    # If running withing an app, start the main loop
1009    if app: app.MainLoop()
1010
1011class Explore(object):
1012    """
1013    Bumps wrapper for a SAS model comparison.
1014
1015    The resulting object can be used as a Bumps fit problem so that
1016    parameters can be adjusted in the GUI, with plots updated on the fly.
1017    """
1018    def __init__(self, opts):
1019        # type: (Dict[str, Any]) -> None
1020        from bumps.cli import config_matplotlib  # type: ignore
1021        from . import bumps_model
1022        config_matplotlib()
1023        self.opts = opts
1024        model_info = opts['def']
1025        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
1026        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1027        if not opts['is2d']:
1028            for p in model_info.parameters.user_parameters(is2d=False):
1029                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1030                    k = p.name+ext
1031                    v = pars.get(k, None)
1032                    if v is not None:
1033                        v.range(*parameter_range(k, v.value))
1034        else:
1035            for k, v in pars.items():
1036                v.range(*parameter_range(k, v.value))
1037
1038        self.pars = pars
1039        self.pd_types = pd_types
1040        self.limits = None
1041
1042    def numpoints(self):
1043        # type: () -> int
1044        """
1045        Return the number of points.
1046        """
1047        return len(self.pars) + 1  # so dof is 1
1048
1049    def parameters(self):
1050        # type: () -> Any   # Dict/List hierarchy of parameters
1051        """
1052        Return a dictionary of parameters.
1053        """
1054        return self.pars
1055
1056    def nllf(self):
1057        # type: () -> float
1058        """
1059        Return cost.
1060        """
1061        # pylint: disable=no-self-use
1062        return 0.  # No nllf
1063
1064    def plot(self, view='log'):
1065        # type: (str) -> None
1066        """
1067        Plot the data and residuals.
1068        """
1069        pars = dict((k, v.value) for k, v in self.pars.items())
1070        pars.update(self.pd_types)
1071        self.opts['pars'] = pars
1072        limits = compare(self.opts, limits=self.limits)
1073        if self.limits is None:
1074            vmin, vmax = limits
1075            self.limits = vmax*1e-7, 1.3*vmax
1076
1077
1078def main(*argv):
1079    # type: (*str) -> None
1080    """
1081    Main program.
1082    """
1083    opts = parse_opts(argv)
1084    if opts is not None:
1085        if opts['html']:
1086            show_docs(opts)
1087        elif opts['explore']:
1088            explore(opts)
1089        else:
1090            compare(opts)
1091
1092if __name__ == "__main__":
1093    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.