source: sasmodels/sasmodels/compare.py @ c52f9da

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since c52f9da was 8407d8c, checked in by wojciech, 8 years ago

Changed gpu flag to opencl as it seems more appropriate to what the code is doing

  • Property mode set to 100755
File size: 35.3 KB
RevLine 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Program to compare models using different compute engines.

This program lets you compare results between OpenCL and DLL versions
of the code and between precisions (half, fast, single, double, quad),
where fast precision is single precision using native functions for
trig, etc., and may not be completely IEEE 754 compliant.  This lets you
make sure that the model calculations are stable, or tells you whether
you need to tag the model as double precision only.

Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
sasmodels root to see the command line options.

Note that there is no way within sasmodels to select between an
OpenCL CPU device and a GPU device, but you can do so by setting the
PYOPENCL_CTX environment variable ahead of time.  Start a python
interpreter and enter::

    import pyopencl as cl
    cl.create_some_context()

This will prompt you to select from the available OpenCL devices
and tell you which string to use for the PYOPENCL_CTX variable.
On Windows you will need to remove the quotes.
"""
28
29from __future__ import print_function
30
[190fc2b]31import sys
32import math
33import datetime
34import traceback
35
[7ae2b7f]36import numpy as np  # type: ignore
[190fc2b]37
38from . import core
39from . import kerneldll
40from .data import plot_theory, empty_data1D, empty_data2D
41from .direct_model import DirectModel
[f247314]42from .convert import revert_name, revert_pars, constrain_new_to_old
[190fc2b]43
[dd7fc12]44try:
45    from typing import Optional, Dict, Any, Callable, Tuple
46except:
47    pass
48else:
49    from .modelinfo import ModelInfo, Parameter, ParameterSet
50    from .data import Data
[8d62008]51    Calculator = Callable[[float], np.ndarray]
[dd7fc12]52
# Command line help text; printed when compare.py is run without a model name.
USAGE = """
usage: compare.py model N1 N2 [options...] [key=val]

Compare the speed and value for a model between the SasView original and the
sasmodels rewrite.

model is the name of the model to compare (see below).
N1 is the number of times to run sasmodels (default=1).
N2 is the number of times to run sasview (default=1).

Options (* for default):

    -plot*/-noplot plot or suppress the plot of the model
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -nq=128 sets the number of Q points in the data set
    -zero indicates that q=0 should be included
    -1d*/-2d computes 1d or 2d data
    -preset*/-random[=seed] preset or random parameters
    -mono/-poly* force monodisperse/polydisperse
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -pars/-nopars* prints the parameter set or not
    -abs/-rel* plot absolute or relative error
    -linear/-log*/-q4 intensity scaling
    -hist/-nohist* plot histogram of relative error
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
    -edit starts the parameter explorer
    -default/-demo* use demo vs default parameters

Any two calculation engines can be selected for comparison:

    -single/-double/-half/-fast sets an OpenCL calculation engine
    -single!/-double!/-quad! sets an OpenMP calculation engine
    -sasview sets the sasview calculation engine

The default is -single -sasview.  Note that the interpretation of quad
precision depends on architecture, and may vary from 64-bit to 128-bit,
with 80-bit floats being common (1e-19 precision).

Key=value pairs allow you to set specific values for the model parameters.
"""
94
# Update docs with command line usage string.  This is separate from the usual
# doc string so that we can display it at run time if there is an error.
# Under "python -OO" docstrings are stripped and __doc__ is None; fall back to
# an empty string so that importing the module does not raise TypeError.
__doc__ = ((__doc__ or "")  # pylint: disable=redefined-builtin
           + """
Program description
-------------------

"""
           + USAGE)
[caeb06d]105
# Turn on single precision support in the DLL engine so that single precision
# DLL calculations (such as -single! in USAGE) can be compared.
# NOTE(review): the exact effect of this flag is defined in kerneldll; confirm.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]107
# CRUFT python 2.6: timedelta.total_seconds() is missing there, so compute
# the number of seconds from the individual timedelta fields instead.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Convert the date-time delta *dt* to a number of seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Convert the date-time delta *dt* to a number of seconds."""
        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
117
118
class push_seed(object):
    """
    Set the seed value for the random number generator.

    The global numpy random state is saved and the generator is reseeded as
    soon as the object is constructed.  When used in a with statement, the
    saved state is put back after the with statement completes, whether or
    not the block raised.

    :Parameters:

    *seed* : int or array_like, optional
        Seed for RandomState

    :Example:

    Seed can be used directly to set the seed::

        >>> from numpy.random import randint
        >>> push_seed(24)
        <...push_seed object at...>
        >>> print(randint(0,1000000,3))
        [242082    899 211136]

    Seed can also be used in a with statement, which sets the random
    number generator state for the enclosed computations and restores
    it to the previous state on completion::

        >>> with push_seed(24):
        ...    print(randint(0,1000000,3))
        [242082    899 211136]

    Using nested contexts, we can demonstrate that state is indeed
    restored after the block completes::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    with push_seed(24):
        ...        print(randint(0,1000000,3))
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        899

    The restore step is protected against exceptions in the block::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    try:
        ...        with push_seed(24):
        ...            print(randint(0,1000000,3))
        ...            raise Exception()
        ...    except Exception:
        ...        print("Exception raised")
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        Exception raised
        899
    """
    def __init__(self, seed=None):
        # type: (Optional[int]) -> None
        saved = np.random.get_state()
        np.random.seed(seed)
        # generator state from *before* reseeding; restored by __exit__
        self._state = saved

    def __enter__(self):
        # type: () -> None
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Any, BaseException, Any) -> None
        # Restore the saved state.  The implicit None return is falsy, so any
        # exception raised inside the with block still propagates.
        np.random.set_state(self._state)
190
def tic():
    # type: () -> Callable[[], float]
    """
    Start a timer.

    Returns a function that reports, each time it is called, the number of
    seconds elapsed since tic() was called.  Use "toc=tic()" to start the
    clock and "toc()" to measure a time interval.
    """
    start = datetime.datetime.now()

    def toc():
        """Seconds elapsed since the matching call to tic()."""
        return delay(datetime.datetime.now() - start)

    return toc
201
202
def set_beam_stop(data, radius, outer=None):
    # type: (Data, float, float) -> None
    """
    Add a beam stop of the given *radius*.  If *outer*, make an annulus.

    For 2D data (anything with a *qx_data* attribute) the distance from the
    beam is |q| computed from *qx_data* and *qy_data*; otherwise *data.x*
    is used.  *data.mask* is replaced, with True marking points inside the
    beam stop or (when *outer* is given) at or beyond *outer*.

    Note: this function does not require sasview
    """
    if hasattr(data, 'qx_data'):
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        q = data.x
    mask = (q < radius)
    if outer is not None:
        mask |= (q >= outer)
    data.mask = mask
219
[8a20be5]220
def parameter_range(p, v):
    # type: (str, float) -> Tuple[float, float]
    """
    Choose a parameter range based on parameter name and initial value.

    Returns *(low, high)*.  The rules are applied in order:

    * polydispersity controls (_pd_n, _pd_nsigma) get fixed ranges;
    * orientation angles (theta/phi/psi) span [-180, 180], and their
      dispersity spans [0, 45];
    * other _pd widths, sld, background and scale get fixed ranges;
    * anything else gets a range based on the initial value *v*.

    Raises ValueError for _pd_type, which is a string rather than a number.
    """
    # polydispersity control parameters
    if p.endswith('_pd_n'):
        return 0., 100.
    if p.endswith('_pd_nsigma'):
        return 0., 5.
    if p.endswith('_pd_type'):
        raise ValueError("Cannot return a range for a string value")

    # orientation angles and their dispersity
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        return (0., 45.) if p.endswith('_pd') else (-180., 180.)

    if p.endswith('_pd'):
        return 0., 1.
    if 'sld' in p:
        return -0.5, 10.
    if p == 'background':
        return 0., 10.
    if p == 'scale':
        return 0., 1.e3

    # fall back to a range derived from the size of the initial value
    if v < 0.:
        return 2.*v, -2.*v
    return 0., (2.*v if v > 0. else 1.)
[87985ca]251
[4f2478e]252
[8bd7b77]253def _randomize_one(model_info, p, v):
[dd7fc12]254    # type: (ModelInfo, str, float) -> float
255    # type: (ModelInfo, str, str) -> str
[ec7e360]256    """
[caeb06d]257    Randomize a single parameter.
[ec7e360]258    """
[f3bd37f]259    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
[ec7e360]260        return v
[8bd7b77]261
262    # Find the parameter definition
[6d6508e]263    for par in model_info.parameters.call_parameters:
[8bd7b77]264        if par.name == p:
265            break
[ec7e360]266    else:
[8bd7b77]267        raise ValueError("unknown parameter %r"%p)
268
269    if len(par.limits) > 2:  # choice list
270        return np.random.randint(len(par.limits))
271
272    limits = par.limits
273    if not np.isfinite(limits).all():
274        low, high = parameter_range(p, v)
275        limits = (max(limits[0], low), min(limits[1], high))
276
277    return np.random.uniform(*limits)
[cd3dba0]278
[4f2478e]279
def randomize_pars(model_info, pars, seed=None):
    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
    """
    Generate random values for all of the parameters.

    Valid ranges for the random number generator are guessed from the name of
    the parameter; this will not account for constraints such as cap radius
    greater than cylinder radius in the capped_cylinder model, so
    :func:`constrain_pars` needs to be called afterward.

    The global random state is restored on return (see :class:`push_seed`),
    and *pars* itself is not modified.
    """
    random_pars = {}
    with push_seed(seed):
        # Visit the parameters in sorted order so that a given seed always
        # feeds the random number generator calls in the same sequence.
        for name, value in sorted(pars.items()):
            random_pars[name] = _randomize_one(model_info, name, value)
    return random_pars
[cd3dba0]295
def constrain_pars(model_info, pars):
    # type: (ModelInfo, ParameterSet) -> None
    """
    Restrict parameters to valid values.

    This includes model specific code for models such as capped_cylinder
    which need to support within model constraints (cap radius more than
    cylinder radius in this case).

    Warning: this updates the *pars* dictionary in place.
    """
    # For a product model only the form factor (the part before '*') is
    # checked since none of the structure factors need any constraints.
    name = model_info.id.split('*')[0]

    if name == 'capped_cylinder':
        # the cap radius may not be smaller than the cylinder radius
        if pars['cap_radius'] < pars['radius']:
            pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
    elif name == 'barbell':
        # the bell radius may not be smaller than the cylinder radius
        if pars['bell_radius'] < pars['radius']:
            pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
    elif name == 'guinier':
        # Limit guinier to an Rg such that Iq > 1e-30 (single precision
        # cutoff) at the high q maximum; the mid q maximum would be 0.2.
        q_max = 1.0
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)
    elif name == 'rpa':
        # Zero the volume fractions unused by this case, then make the
        # phi values sum to 1.0.
        case_num = pars['case_num']
        if case_num < 2:
            pars['Phi1'] = 0.
            pars['Phi2'] = 0.
        elif case_num < 5:
            pars['Phi1'] = 0.
        phi_keys = ['Phi'+c for c in '1234']
        total = sum(pars[key] for key in phi_keys)
        for key in phi_keys:
            pars[key] /= total
335
def parlist(model_info, pars, is2d):
    # type: (ModelInfo, ParameterSet, bool) -> str
    """
    Format the parameter list for printing.

    Returns one line per parameter from
    *model_info.parameters.user_parameters(pars, is2d)*, formatted by
    :func:`_format_par`.  Polydispersity details are read from the
    '_pd', '_pd_n', '_pd_nsigma' and '_pd_type' entries of *pars*,
    falling back to "no polydispersity" defaults when absent.
    """
    lines = []
    parameters = model_info.parameters
    for p in parameters.user_parameters(pars, is2d):
        fields = dict(
            value=pars.get(p.id, p.default),
            pd=pars.get(p.id+"_pd", 0.),
            n=int(pars.get(p.id+"_pd_n", 0)),
            # key must match the '_pd_nsigma' entries used elsewhere in this
            # module (e.g. get_pars); a misspelled key always gives 3 sigma
            nsigma=pars.get(p.id+"_pd_nsigma", 3.),
            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
            relative_pd=p.relative_pd,
        )
        lines.append(_format_par(p.name, **fields))
    return "\n".join(lines)
356
[bd49c79]357def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
358                relative_pd=False):
[dd7fc12]359    # type: (str, float, float, int, float, str) -> str
[a4a7308]360    line = "%s: %g"%(name, value)
361    if pd != 0.  and n != 0:
[bd49c79]362        if relative_pd:
363            pd *= value
[a4a7308]364        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
[dd7fc12]365                % (pd, n, nsigma, nsigma, pdtype)
[a4a7308]366    return line
[87985ca]367
def suppress_pd(pars):
    # type: (ParameterSet) -> ParameterSet
    """
    Return a copy of *pars* with every '_pd_n' entry set to 0.

    This turns off polydispersity so that models can be evaluated more
    quickly.  The input dictionary is left unchanged.
    """
    result = pars.copy()
    result.update((key, 0) for key in pars if key.endswith("_pd_n"))
    return result
[87985ca]380
def eval_sasview(model_info, data):
    # type: (ModelInfo, Data) -> Calculator
    """
    Return a model calculator using the pre-4.0 SasView models.

    The returned *calculator(**pars)* converts the sasmodels parameters to
    the old sasview names with :func:`revert_pars`, pushes them into the
    sasview model and evaluates it on the q points of *data*, applying the
    smearer from *smear_selection* when one is returned.  The calculator
    has an *engine* attribute set to "sasview".

    Raises ImportError if sasview cannot be imported, and ValueError if the
    model has no old sasview equivalent or is a (non-product) mixture.
    """
    # importing sas here so that the error message will be that sas failed to
    # import rather than the more obscure smear_selection not imported error
    import sas
    import sas.models
    from sas.models.qsmearing import smear_selection
    from sas.models.MultiplicationModel import MultiplicationModel

    def get_model_class(name):
        # type: (str) -> "sas.models.BaseComponent"
        """Import sas.models.<name> and return the class <name> defined there."""
        #print("new",sorted(_pars.items()))
        __import__('sas.models.' + name)
        ModelClass = getattr(getattr(sas.models, name, None), name, None)
        if ModelClass is None:
            raise ValueError("could not find model %r in sas.models"%name)
        return ModelClass

    # WARNING: ugly hack when handling model!
    # Sasview models with multiplicity need to be created with the target
    # multiplicity, so we cannot create the target model ahead of time for
    # multiplicity models.  Instead we store the model in a list and
    # update the first element of that list with the new multiplicity model
    # every time we evaluate.

    # grab the sasview model, or create it if it is a product model
    if model_info.composition:
        composition_type, parts = model_info.composition
        if composition_type == 'product':
            P, S = [get_model_class(revert_name(p))() for p in parts]
            model = [MultiplicationModel(P, S)]
        else:
            raise ValueError("sasview mixture models not supported by compare")
    else:
        old_name = revert_name(model_info)
        if old_name is None:
            raise ValueError("model %r does not exist in old sasview"
                            % model_info.id)
        ModelClass = get_model_class(old_name)
        model = [ModelClass()]

    # build a smearer with which to call the model, if necessary
    # NOTE(review): the list *model* (not model[0]) is passed here; the 2D
    # branch below re-assigns smearer.model to work around this.
    smearer = smear_selection(data, model=model)
    if hasattr(data, 'qx_data'):
        # 2D: only evaluate unmasked, non-nan points with qmin <= |q| <= qmax
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
        index = ((~data.mask) & (~np.isnan(data.data))
                 & (q >= data.qmin) & (q <= data.qmax))
        if smearer is not None:
            smearer.model = model  # because smear_selection has a bug
            smearer.accuracy = data.accuracy
            smearer.set_index(index)
            def _call_smearer():
                # point the smearer at the current model, which calculator()
                # may have replaced for multiplicity models
                smearer.model = model[0]
                return smearer.get_value()
            theory = _call_smearer
        else:
            theory = lambda: model[0].evalDistribution([data.qx_data[index],
                                                        data.qy_data[index]])
    elif smearer is not None:
        theory = lambda: smearer(model[0].evalDistribution(data.x))
    else:
        theory = lambda: model[0].evalDistribution(data.x)

    def calculator(**pars):
        # type: (float, ...) -> np.ndarray
        """
        Sasview calculator for model.

        Converts *pars* to sasview names, sets them on the current model
        (dotted names such as "radius.width" go into the dispersion
        dictionary) and returns the evaluated theory.
        """
        oldpars = revert_pars(model_info, pars)
        # For multiplicity models, create a model with the correct multiplicity
        control = oldpars.pop("CONTROL", None)
        if control is not None:
            # sphericalSLD has one fewer multiplicity.  This update should
            # happen in revert_pars, but it hasn't been called yet.
            # NOTE(review): ModelClass is only bound in the non-product branch
            # above, so a product model whose reverted pars contain CONTROL
            # would raise NameError here.
            model[0] = ModelClass(control)
        # paying for parameter conversion each time to keep life simple, if not fast
        #print("sasview pars",oldpars)
        for k, v in oldpars.items():
            name_attr = k.split('.')  # polydispersity components
            if len(name_attr) == 2:
                model[0].dispersion[name_attr[0]][name_attr[1]] = v
            else:
                model[0].setParam(k, v)
        return theory()

    calculator.engine = "sasview"
    return calculator
471
# Precision tag for each accepted dtype name.  It is appended to "OCL" or
# "OMP" to label the calculation engine in eval_opencl and eval_ctypes.
DTYPE_MAP = dict(
    half='16',
    fast='fast',
    single='32',
    double='64',
    quad='128',
    f16='16',
    f32='32',
    f64='64',
    longdouble='128',
)
def eval_opencl(model_info, data, dtype='single', cutoff=0.):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Return a model calculator using the OpenCL calculation engine.

    The calculator's *engine* attribute is "OCL" followed by the precision
    tag for *dtype* from DTYPE_MAP.  Raises RuntimeError when OpenCL is not
    available.
    """
    if core.HAVE_OPENCL:
        kernel_model = core.build_model(model_info, dtype=dtype, platform="ocl")
        calculator = DirectModel(data, kernel_model, cutoff=cutoff)
        calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
        return calculator
    raise RuntimeError("OpenCL not available")
[216a9e1]494
def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Return a model calculator using the DLL calculation engine.

    The calculator's *engine* attribute is "OMP" followed by the precision
    tag for *dtype* from DTYPE_MAP.
    """
    calculator = DirectModel(
        data,
        core.build_model(model_info, dtype=dtype, platform="dll"),
        cutoff=cutoff,
    )
    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
    return calculator
504
def time_calculation(calculator, pars, evals=1):
    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
    """
    Compute the average calculation time over N evaluations.

    Returns *(value, average_ms)*.  *value* comes from the last evaluation
    and *average_ms* is the elapsed time divided by *evals*, in
    milliseconds.  The calculator is always called at least once.  When
    *evals* > 1, an extra untimed call is made first with polydispersity
    suppressed, to initialize the calculation engine and make the average
    more stable.
    """
    if evals > 1:
        # warm-up call, not included in the timing
        calculator(**suppress_pd(pars))
    toc = tic()
    value = None
    for _ in range(max(evals, 1)):
        value = calculator(**pars)
    elapsed_ms = toc()*1000.
    return value, elapsed_ms / evals
524
def make_data(opts):
    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
    """
    Generate an empty dataset, used with the model to set Q points
    and resolution.

    *opts* contains the options, with 'qmax', 'nq', 'res', 'accuracy',
    'is2d', 'view' and 'zero' parsed from the command line.

    Returns *(data, index)*.  For 2D data, *index* is the boolean array of
    points left unmasked by the beam stop.  For 1D data, *index* is a slice
    selecting every point.
    """
    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']

    if opts['is2d']:
        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
        data = empty_data2D(q, resolution=res)
        data.accuracy = opts['accuracy']
        set_beam_stop(data, 0.0004)
        return data, ~data.mask

    use_log = opts['view'] == 'log'
    include_zero = opts['zero']
    if use_log and not include_zero:
        # log spaced points covering the three decades below qmax
        log_qmax = math.log10(qmax)
        q = np.logspace(log_qmax-3, log_qmax, nq)
    else:
        q = np.linspace(0.001*qmax, qmax, nq)
    if include_zero:
        q = np.hstack((0, q))
    data = empty_data1D(q, resolution=res)
    return data, slice(None, None)
552
def make_engine(model_info, data, dtype, cutoff):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Generate the appropriate calculation engine for the given datatype.

    'sasview' selects the original SasView models.  Datatypes with '!'
    appended are evaluated using external C DLLs rather than OpenCL.  All
    datatypes use the DLLs for models whose *opencl* flag is false.
    Everything else uses OpenCL.
    """
    if dtype == 'sasview':
        return eval_sasview(model_info, data)
    if dtype.endswith('!'):
        # explicit request for the DLL engine; strip the '!' marker
        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
    engine = eval_opencl if model_info.opencl else eval_ctypes
    return engine(model_info, data, dtype=dtype, cutoff=cutoff)
[87985ca]569
[e78edc4]570def _show_invalid(data, theory):
[dd7fc12]571    # type: (Data, np.ma.ndarray) -> None
572    """
573    Display a list of the non-finite values in theory.
574    """
[e78edc4]575    if not theory.mask.any():
576        return
577
578    if hasattr(data, 'x'):
579        bad = zip(data.x[theory.mask], theory[theory.mask])
[dd7fc12]580        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
[e78edc4]581
582
def compare(opts, limits=None):
    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
    """
    Perform a comparison using options from the command line.

    *limits* are the limits on the values to use, either to set the y-axis
    for 1D or to set the colormap scale for 2D.  If None, then they are
    inferred from the data and returned. When exploring using Bumps,
    the limits are set when the model is initially called, and maintained
    as the values are adjusted, making it easier to see the effects of the
    parameters.

    *opts* supplies 'n1'/'n2' (number of evaluations of the base and
    comparison engines), 'engines', 'pars', 'data', 'plot', 'explore',
    'view', 'rel_err' and 'name'.  An engine whose evaluation raises
    ImportError has its traceback printed and is skipped.  When neither
    plotting nor exploring, only the timings and statistics are printed and
    *limits* is returned as given.
    """
    n_base, n_comp = opts['n1'], opts['n2']
    pars = opts['pars']
    data = opts['data']

    # silence the linter
    base = opts['engines'][0] if n_base else None
    comp = opts['engines'][1] if n_comp else None
    base_time = comp_time = None
    base_value = comp_value = resid = relerr = None

    # Base calculation
    if n_base > 0:
        try:
            base_raw, base_time = time_calculation(base, pars, n_base)
            base_value = np.ma.masked_invalid(base_raw)
            print("%s t=%.2f ms, intensity=%.0f"
                  % (base.engine, base_time, base_value.sum()))
            _show_invalid(data, base_value)
        except ImportError:
            traceback.print_exc()
            n_base = 0

    # Comparison calculation
    if n_comp > 0:
        try:
            comp_raw, comp_time = time_calculation(comp, pars, n_comp)
            comp_value = np.ma.masked_invalid(comp_raw)
            print("%s t=%.2f ms, intensity=%.0f"
                  % (comp.engine, comp_time, comp_value.sum()))
            _show_invalid(data, comp_value)
        except ImportError:
            traceback.print_exc()
            n_comp = 0

    # Compare, but only if computing both forms
    if n_base > 0 and n_comp > 0:
        resid = (base_value - comp_value)
        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
        _print_stats("|%s-%s|"
                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
                     resid)
        _print_stats("|(%s-%s)/%s|"
                     % (base.engine, comp.engine, comp.engine),
                     relerr)

    # Plot if requested
    if not opts['plot'] and not opts['explore']:
        return limits
    view = opts['view']
    import matplotlib.pyplot as plt
    if limits is None:
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed
        vmin, vmax = np.inf, -np.inf
        if n_base > 0:
            vmin = min(vmin, base_value.min())
            vmax = max(vmax, base_value.max())
        if n_comp > 0:
            vmin = min(vmin, comp_value.min())
            vmax = max(vmax, comp_value.max())
        limits = vmin, vmax

    if n_base > 0:
        if n_comp > 0:
            plt.subplot(131)
        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
        plt.title("%s t=%.2f ms"%(base.engine, base_time))
    if n_comp > 0:
        if n_base > 0:
            plt.subplot(132)
        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
    if n_comp > 0 and n_base > 0:
        plt.subplot(133)
        if not opts['rel_err']:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, use_data=False)
        if view == 'linear':
            plt.xscale('linear')
        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
    fig = plt.gcf()
    fig.suptitle(opts['name'])

    # NOTE(review): the other option keys read here are plain names such as
    # 'plot' and 'rel_err'; confirm that parse_opts really stores a '-hist'
    # key, otherwise this histogram is never drawn.
    if n_comp > 0 and n_base > 0 and '-hist' in opts:
        # Histogram of log10 |relative error| over the valid points.  Work on
        # a copy so that relerr, already handed to the error plot, is not
        # modified.  Replace exact zeros by half the smallest non-zero error so
        # log10 stays finite.  Skip the plot if every error is exactly zero.
        errs = np.abs(np.ma.asarray(relerr).compressed())
        nonzero = errs[errs != 0]
        if nonzero.size:
            errs[errs == 0] = 0.5*nonzero.min()
            plt.figure()
            # 'density' replaces the 'normed' keyword removed in matplotlib 3.1
            plt.hist(np.log10(errs), density=True, bins=50)
            plt.xlabel('log10(err), err = |(%s - %s) / %s|'
                       % (base.engine, comp.engine, comp.engine))
            plt.ylabel('P(err)')
            plt.title('Distribution of relative error between calculation engines')

    if not opts['explore']:
        plt.show()

    return limits
696
[0763009]697def _print_stats(label, err):
[dd7fc12]698    # type: (str, np.ma.ndarray) -> None
699    # work with trimmed data, not the full set
[e78edc4]700    sorted_err = np.sort(abs(err.compressed()))
[dd7fc12]701    p50 = int((len(sorted_err)-1)*0.50)
702    p98 = int((len(sorted_err)-1)*0.98)
[0763009]703    data = [
704        "max:%.3e"%sorted_err[-1],
705        "median:%.3e"%sorted_err[p50],
706        "98%%:%.3e"%sorted_err[p98],
[dd7fc12]707        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
708        "zero-offset:%+.3e"%np.mean(sorted_err),
[0763009]709        ]
[caeb06d]710    print(label+"  "+"  ".join(data))
[0763009]711
712
713
# ===========================================================================
#
# Command line flags that are given by name only (-name), grouped by purpose.
NAME_OPTIONS = set(
    ('plot', 'noplot')
    + ('half', 'fast', 'single', 'double')
    + ('single!', 'double!', 'quad!', 'sasview')
    + ('lowq', 'midq', 'highq', 'exq', 'zero')
    + ('2d', '1d')
    + ('preset', 'random')
    + ('poly', 'mono')
    + ('nopars', 'pars')
    + ('rel', 'abs')
    + ('linear', 'log', 'q4')
    + ('hist', 'nohist')
    + ('edit',)
    + ('demo', 'default')
)
# Command line options that take a value (-name=value).
# Note: random is both a name option and a value option
VALUE_OPTIONS = 'cutoff random nq res accuracy'.split()
735
def columnize(items, indent="", width=79):
    # type: (List[str], str, int) -> str
    """
    Format a list of strings into columns.

    Items fill each column top to bottom, then move to the next column.
    Every column is padded to the longest item plus one space.  As many
    columns are used as fit in *width* characters after *indent*, with at
    least one column.  Returns a string with newlines ready for printing,
    or "" when *items* is empty.
    """
    items = list(items)
    if not items:
        return ""
    column_width = max(len(w) for w in items) + 1
    # at least one column, even when a single item is wider than *width*
    num_columns = max(1, (width - len(indent)) // column_width)
    # round up so the last, partially filled column is not dropped
    num_rows = (len(items) + num_columns - 1) // num_columns
    # recompute so that no trailing column is entirely empty
    num_columns = (len(items) + num_rows - 1) // num_rows
    items += [""] * (num_rows * num_columns - len(items))
    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    return indent + ("\n"+indent).join(lines)
752
753
def get_pars(model_info, use_demo=False):
    # type: (ModelInfo, bool) -> ParameterSet
    """
    Extract demo parameters from the model definition.

    Returns a dictionary holding the default value of every call parameter.
    Polydisperse parameters also get '_pd', '_pd_n', '_pd_nsigma' and
    '_pd_type' entries describing no polydispersity.  Vector parameters
    (*length* > 1) get one entry per element, numbered from 1
    (e.g. sld1, sld2).  If *use_demo*, the values in *model_info.demo*
    override the defaults.
    """
    # Get the default values for the parameters
    pars = {}
    for p in model_info.parameters.call_parameters:
        parts = [('', p.default)]
        if p.polydisperse:
            parts.append(('_pd', 0.0))
            parts.append(('_pd_n', 0))
            parts.append(('_pd_nsigma', 3.0))
            parts.append(('_pd_type', "gaussian"))
        for ext, val in parts:
            if p.length > 1:
                # store one entry per vector element (the generated pairs
                # must be added to pars, not just built and discarded)
                pars.update(("%s%d%s" % (p.id, k, ext), val)
                            for k in range(1, p.length+1))
            else:
                pars[p.id + ext] = val

    # Plug in values given in demo
    if use_demo:
        pars.update(model_info.demo)
    return pars
779
[17bbadd]780
[ec7e360]781def parse_opts():
[dd7fc12]782    # type: () -> Dict[str, Any]
[caeb06d]783    """
784    Parse command line options.
785    """
[fc0fcd0]786    MODELS = core.list_models()
[caeb06d]787    flags = [arg for arg in sys.argv[1:]
788             if arg.startswith('-')]
789    values = [arg for arg in sys.argv[1:]
790              if not arg.startswith('-') and '=' in arg]
791    args = [arg for arg in sys.argv[1:]
792            if not arg.startswith('-') and '=' not in arg]
[d547f16]793    models = "\n    ".join("%-15s"%v for v in MODELS)
[87985ca]794    if len(args) == 0:
[7cf2cfd]795        print(USAGE)
[caeb06d]796        print("\nAvailable models:")
[7cf2cfd]797        print(columnize(MODELS, indent="  "))
[87985ca]798        sys.exit(1)
[319ab14]799    if len(args) > 3:
[9cfcac8]800        print("expected parameters: model N1 N2")
[87985ca]801
[17bbadd]802    name = args[0]
[72a081d]803    try:
[d19962c]804        model_info = core.load_model_info(name)
[7ae2b7f]805    except ImportError as exc:
[72a081d]806        print(str(exc))
807        print("Could not find model; use one of:\n    " + models)
808        sys.exit(1)
[17bbadd]809
[ec7e360]810    invalid = [o[1:] for o in flags
[216a9e1]811               if o[1:] not in NAME_OPTIONS
[d15a908]812               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
[87985ca]813    if invalid:
[9404dd3]814        print("Invalid options: %s"%(", ".join(invalid)))
[87985ca]815        sys.exit(1)
816
[ec7e360]817
[d15a908]818    # pylint: disable=bad-whitespace
[ec7e360]819    # Interpret the flags
820    opts = {
821        'plot'      : True,
822        'view'      : 'log',
823        'is2d'      : False,
824        'qmax'      : 0.05,
825        'nq'        : 128,
826        'res'       : 0.0,
827        'accuracy'  : 'Low',
[72a081d]828        'cutoff'    : 0.0,
[ec7e360]829        'seed'      : -1,  # default to preset
830        'mono'      : False,
831        'show_pars' : False,
832        'show_hist' : False,
833        'rel_err'   : True,
834        'explore'   : False,
[98d6cfc]835        'use_demo'  : True,
[dd7fc12]836        'zero'      : False,
[ec7e360]837    }
838    engines = []
839    for arg in flags:
840        if arg == '-noplot':    opts['plot'] = False
841        elif arg == '-plot':    opts['plot'] = True
842        elif arg == '-linear':  opts['view'] = 'linear'
843        elif arg == '-log':     opts['view'] = 'log'
844        elif arg == '-q4':      opts['view'] = 'q4'
845        elif arg == '-1d':      opts['is2d'] = False
846        elif arg == '-2d':      opts['is2d'] = True
847        elif arg == '-exq':     opts['qmax'] = 10.0
848        elif arg == '-highq':   opts['qmax'] = 1.0
849        elif arg == '-midq':    opts['qmax'] = 0.2
[ce0b154]850        elif arg == '-lowq':    opts['qmax'] = 0.05
[e78edc4]851        elif arg == '-zero':    opts['zero'] = True
[ec7e360]852        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
853        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
854        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
855        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
856        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
[dd7fc12]857        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
[ec7e360]858        elif arg == '-preset':  opts['seed'] = -1
859        elif arg == '-mono':    opts['mono'] = True
860        elif arg == '-poly':    opts['mono'] = False
861        elif arg == '-pars':    opts['show_pars'] = True
862        elif arg == '-nopars':  opts['show_pars'] = False
863        elif arg == '-hist':    opts['show_hist'] = True
864        elif arg == '-nohist':  opts['show_hist'] = False
865        elif arg == '-rel':     opts['rel_err'] = True
866        elif arg == '-abs':     opts['rel_err'] = False
867        elif arg == '-half':    engines.append(arg[1:])
868        elif arg == '-fast':    engines.append(arg[1:])
869        elif arg == '-single':  engines.append(arg[1:])
870        elif arg == '-double':  engines.append(arg[1:])
871        elif arg == '-single!': engines.append(arg[1:])
872        elif arg == '-double!': engines.append(arg[1:])
873        elif arg == '-quad!':   engines.append(arg[1:])
874        elif arg == '-sasview': engines.append(arg[1:])
875        elif arg == '-edit':    opts['explore'] = True
[98d6cfc]876        elif arg == '-demo':    opts['use_demo'] = True
877        elif arg == '-default':    opts['use_demo'] = False
[d15a908]878    # pylint: enable=bad-whitespace
[ec7e360]879
880    if len(engines) == 0:
[c499331]881        engines.extend(['single', 'double'])
[ec7e360]882    elif len(engines) == 1:
[def2c1b]883        if engines[0][0] == 'double':
[ec7e360]884            engines.append('single')
[def2c1b]885        else:
886            engines.append('double')
[ec7e360]887    elif len(engines) > 2:
888        del engines[2:]
889
[9cfcac8]890    n1 = int(args[1]) if len(args) > 1 else 1
891    n2 = int(args[2]) if len(args) > 2 else 1
[b32dafd]892    use_sasview = any(engine == 'sasview' and count > 0
[fa1582e]893                      for engine, count in zip(engines, [n1, n2]))
[87985ca]894
[ec7e360]895    # Get demo parameters from model definition, or use default parameters
896    # if model does not define demo parameters
[98d6cfc]897    pars = get_pars(model_info, opts['use_demo'])
898
[87985ca]899
900    # Fill in parameters given on the command line
[ec7e360]901    presets = {}
902    for arg in values:
[d15a908]903        k, v = arg.split('=', 1)
[87985ca]904        if k not in pars:
[ec7e360]905            # extract base name without polydispersity info
[87985ca]906            s = set(p.split('_pd')[0] for p in pars)
[d15a908]907            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
[87985ca]908            sys.exit(1)
[ec7e360]909        presets[k] = float(v) if not k.endswith('type') else v
910
911    # randomize parameters
912    #pars.update(set_pars)  # set value before random to control range
913    if opts['seed'] > -1:
[8bd7b77]914        pars = randomize_pars(model_info, pars, seed=opts['seed'])
[ec7e360]915        print("Randomize using -random=%i"%opts['seed'])
[8b25ee1]916    if opts['mono']:
917        pars = suppress_pd(pars)
[ec7e360]918    pars.update(presets)  # set value after random to control value
[fcd7bbd]919    #import pprint; pprint.pprint(model_info)
[17bbadd]920    constrain_pars(model_info, pars)
[fa1582e]921    if use_sasview:
922        constrain_new_to_old(model_info, pars)
[ec7e360]923    if opts['show_pars']:
[d6850fa]924        print(str(parlist(model_info, pars, opts['is2d'])))
[ec7e360]925
926    # Create the computational engines
[d15a908]927    data, _ = make_data(opts)
[9cfcac8]928    if n1:
[17bbadd]929        base = make_engine(model_info, data, engines[0], opts['cutoff'])
[ec7e360]930    else:
931        base = None
[9cfcac8]932    if n2:
[17bbadd]933        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
[ec7e360]934    else:
935        comp = None
936
[d15a908]937    # pylint: disable=bad-whitespace
[ec7e360]938    # Remember it all
939    opts.update({
940        'name'      : name,
[17bbadd]941        'def'       : model_info,
[9cfcac8]942        'n1'        : n1,
943        'n2'        : n2,
[ec7e360]944        'presets'   : presets,
945        'pars'      : pars,
946        'data'      : data,
947        'engines'   : [base, comp],
948    })
[d15a908]949    # pylint: enable=bad-whitespace
[ec7e360]950
951    return opts
952
953def explore(opts):
[dd7fc12]954    # type: (Dict[str, Any]) -> None
[d15a908]955    """
956    Explore the model using the Bumps GUI.
957    """
[7ae2b7f]958    import wx  # type: ignore
959    from bumps.names import FitProblem  # type: ignore
960    from bumps.gui.app_frame import AppFrame  # type: ignore
[ec7e360]961
962    problem = FitProblem(Explore(opts))
[d15a908]963    is_mac = "cocoa" in wx.version()
[ec7e360]964    app = wx.App()
965    frame = AppFrame(parent=None, title="explore")
[d15a908]966    if not is_mac: frame.Show()
[ec7e360]967    frame.panel.set_model(model=problem)
968    frame.panel.Layout()
969    frame.panel.aui.Split(0, wx.TOP)
[d15a908]970    if is_mac: frame.Show()
[ec7e360]971    app.MainLoop()
972
973class Explore(object):
974    """
[d15a908]975    Bumps wrapper for a SAS model comparison.
976
977    The resulting object can be used as a Bumps fit problem so that
978    parameters can be adjusted in the GUI, with plots updated on the fly.
[ec7e360]979    """
980    def __init__(self, opts):
[dd7fc12]981        # type: (Dict[str, Any]) -> None
[7ae2b7f]982        from bumps.cli import config_matplotlib  # type: ignore
[608e31e]983        from . import bumps_model
[ec7e360]984        config_matplotlib()
985        self.opts = opts
[17bbadd]986        model_info = opts['def']
987        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
[21b116f]988        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
[ec7e360]989        if not opts['is2d']:
[6d6508e]990            for p in model_info.parameters.user_parameters(is2d=False):
[303d8d6]991                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
[69aa451]992                    k = p.name+ext
[303d8d6]993                    v = pars.get(k, None)
994                    if v is not None:
995                        v.range(*parameter_range(k, v.value))
[ec7e360]996        else:
[013adb7]997            for k, v in pars.items():
[ec7e360]998                v.range(*parameter_range(k, v.value))
999
1000        self.pars = pars
1001        self.pd_types = pd_types
[013adb7]1002        self.limits = None
[ec7e360]1003
1004    def numpoints(self):
[dd7fc12]1005        # type: () -> int
[ec7e360]1006        """
[608e31e]1007        Return the number of points.
[ec7e360]1008        """
1009        return len(self.pars) + 1  # so dof is 1
1010
1011    def parameters(self):
[dd7fc12]1012        # type: () -> Any   # Dict/List hierarchy of parameters
[ec7e360]1013        """
[608e31e]1014        Return a dictionary of parameters.
[ec7e360]1015        """
1016        return self.pars
1017
1018    def nllf(self):
[dd7fc12]1019        # type: () -> float
[608e31e]1020        """
1021        Return cost.
1022        """
[d15a908]1023        # pylint: disable=no-self-use
[ec7e360]1024        return 0.  # No nllf
1025
1026    def plot(self, view='log'):
[dd7fc12]1027        # type: (str) -> None
[ec7e360]1028        """
1029        Plot the data and residuals.
1030        """
[608e31e]1031        pars = dict((k, v.value) for k, v in self.pars.items())
[ec7e360]1032        pars.update(self.pd_types)
1033        self.opts['pars'] = pars
[013adb7]1034        limits = compare(self.opts, limits=self.limits)
1035        if self.limits is None:
1036            vmin, vmax = limits
[dd7fc12]1037            self.limits = vmax*1e-7, 1.3*vmax
[87985ca]1038
1039
[d15a908]1040def main():
[dd7fc12]1041    # type: () -> None
[d15a908]1042    """
1043    Main program.
1044    """
1045    opts = parse_opts()
1046    if opts['explore']:
1047        explore(opts)
1048    else:
1049        compare(opts)
1050
[8a20be5]1051if __name__ == "__main__":
[87985ca]1052    main()
Note: See TracBrowser for help on using the repository browser.