source: sasmodels/sasmodels/compare.py @ 630156b

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 630156b was 630156b, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

sascomp: improve data file handling; add -help as alias to -html; default to -mono

  • Property mode set to 100755
File size: 44.5 KB
RevLine 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Program to compare models using different compute engines.

This program lets you compare results between OpenCL and DLL versions
of the code and between precisions (half, fast, single, double, quad),
where fast precision is single precision using native functions for
trig, etc., and may not be completely IEEE 754 compliant.  This lets
you make sure that the model calculations are stable, or determine
whether the model needs to be tagged as double precision only.

Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
sasmodels root to see the command line options.

Note that there is no way within sasmodels to select between an
OpenCL CPU device and a GPU device, but you can do so by setting the
PYOPENCL_CTX environment variable ahead of time.  Start a python
interpreter and enter::

    import pyopencl as cl
    cl.create_some_context()

This will prompt you to select from the available OpenCL devices
and tell you which string to use for the PYOPENCL_CTX variable.
On Windows you will need to remove the quotes.
"""
28
29from __future__ import print_function
30
[190fc2b]31import sys
[a769b54]32import os
[190fc2b]33import math
34import datetime
35import traceback
[ff1fff5]36import re
[190fc2b]37
[7ae2b7f]38import numpy as np  # type: ignore
[190fc2b]39
40from . import core
41from . import kerneldll
[6831fa0]42from . import exception
[a769b54]43from .data import plot_theory, empty_data1D, empty_data2D, load_data
[190fc2b]44from .direct_model import DirectModel
[f247314]45from .convert import revert_name, revert_pars, constrain_new_to_old
[ff1fff5]46from .generate import FLOAT_RE
[190fc2b]47
[dd7fc12]48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
[6831fa0]50except Exception:
[dd7fc12]51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
[8d62008]55    Calculator = Callable[[float], np.ndarray]
[dd7fc12]56
[caeb06d]57USAGE = """
58usage: compare.py model N1 N2 [options...] [key=val]
59
60Compare the speed and value for a model between the SasView original and the
61sasmodels rewrite.
62
[8c65a33]63model or model1,model2 are the names of the models to compare (see below).
[caeb06d]64N1 is the number of times to run sasmodels (default=1).
65N2 is the number times to run sasview (default=1).
66
67Options (* for default):
68
69    -plot*/-noplot plots or suppress the plot of the model
70    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
71    -nq=128 sets the number of Q points in the data set
[e78edc4]72    -zero indicates that q=0 should be included
[caeb06d]73    -1d*/-2d computes 1d or 2d data
74    -preset*/-random[=seed] preset or random parameters
[630156b]75    -mono*/-poly force monodisperse or allow polydisperse demo parameters
[0b040de]76    -magnetic/-nonmagnetic* suppress magnetism
[caeb06d]77    -cutoff=1e-5* cutoff value for including a point in polydispersity
78    -pars/-nopars* prints the parameter set or not
79    -abs/-rel* plot relative or absolute error
80    -linear/-log*/-q4 intensity scaling
81    -hist/-nohist* plot histogram of relative error
82    -res=0 sets the resolution width dQ/Q if calculating with resolution
83    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
84    -edit starts the parameter explorer
[98d6cfc]85    -default/-demo* use demo vs default parameters
[630156b]86    -help/-html shows the model docs instead of running the model
[9068f4c]87    -title="note" adds note to the plot title, after the model name
[a769b54]88    -data="path" uses q, dq from the data file
[caeb06d]89
90Any two calculation engines can be selected for comparison:
91
92    -single/-double/-half/-fast sets an OpenCL calculation engine
93    -single!/-double!/-quad! sets an OpenMP calculation engine
94    -sasview sets the sasview calculation engine
95
[8c65a33]96The default is -single -double.  Note that the interpretation of quad
[caeb06d]97precision depends on architecture, and may vary from 64-bit to 128-bit,
98with 80-bit floats being common (1e-19 precision).
99
100Key=value pairs allow you to set specific values for the model parameters.
[8c65a33]101Key=value1,value2 to compare different values of the same parameter.
102value can be an expression including other parameters
[caeb06d]103"""
104
105# Update docs with command line usage string.   This is separate from the usual
106# doc string so that we can display it at run time if there is an error.
107# lin
[d15a908]108__doc__ = (__doc__  # pylint: disable=redefined-builtin
109           + """
[caeb06d]110Program description
111-------------------
112
[d15a908]113"""
114           + USAGE)
[caeb06d]115
[750ffa5]116kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]117
[248561a]118# list of math functions for use in evaluating parameters
119MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
120
[7cf2cfd]121# CRUFT python 2.6
122if not hasattr(datetime.timedelta, 'total_seconds'):
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
126else:
127    def delay(dt):
128        """Return number date-time delta as number seconds"""
129        return dt.total_seconds()
130
131
[4f2478e]132class push_seed(object):
133    """
134    Set the seed value for the random number generator.
135
136    When used in a with statement, the random number generator state is
137    restored after the with statement is complete.
138
139    :Parameters:
140
141    *seed* : int or array_like, optional
142        Seed for RandomState
143
144    :Example:
145
146    Seed can be used directly to set the seed::
147
148        >>> from numpy.random import randint
149        >>> push_seed(24)
150        <...push_seed object at...>
151        >>> print(randint(0,1000000,3))
152        [242082    899 211136]
153
154    Seed can also be used in a with statement, which sets the random
155    number generator state for the enclosed computations and restores
156    it to the previous state on completion::
157
158        >>> with push_seed(24):
159        ...    print(randint(0,1000000,3))
160        [242082    899 211136]
161
162    Using nested contexts, we can demonstrate that state is indeed
163    restored after the block completes::
164
165        >>> with push_seed(24):
166        ...    print(randint(0,1000000))
167        ...    with push_seed(24):
168        ...        print(randint(0,1000000,3))
169        ...    print(randint(0,1000000))
170        242082
171        [242082    899 211136]
172        899
173
174    The restore step is protected against exceptions in the block::
175
176        >>> with push_seed(24):
177        ...    print(randint(0,1000000))
178        ...    try:
179        ...        with push_seed(24):
180        ...            print(randint(0,1000000,3))
181        ...            raise Exception()
[dd7fc12]182        ...    except Exception:
[4f2478e]183        ...        print("Exception raised")
184        ...    print(randint(0,1000000))
185        242082
186        [242082    899 211136]
187        Exception raised
188        899
189    """
190    def __init__(self, seed=None):
[dd7fc12]191        # type: (Optional[int]) -> None
[4f2478e]192        self._state = np.random.get_state()
193        np.random.seed(seed)
194
195    def __enter__(self):
[dd7fc12]196        # type: () -> None
197        pass
[4f2478e]198
[b32dafd]199    def __exit__(self, exc_type, exc_value, traceback):
[dd7fc12]200        # type: (Any, BaseException, Any) -> None
201        # TODO: better typing for __exit__ method
[4f2478e]202        np.random.set_state(self._state)
203
[7cf2cfd]204def tic():
[dd7fc12]205    # type: () -> Callable[[], float]
[7cf2cfd]206    """
207    Timer function.
208
209    Use "toc=tic()" to start the clock and "toc()" to measure
210    a time interval.
211    """
212    then = datetime.datetime.now()
213    return lambda: delay(datetime.datetime.now() - then)
214
215
216def set_beam_stop(data, radius, outer=None):
[dd7fc12]217    # type: (Data, float, float) -> None
[7cf2cfd]218    """
219    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
220
[dd7fc12]221    Note: this function does not require sasview
[7cf2cfd]222    """
223    if hasattr(data, 'qx_data'):
224        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
225        data.mask = (q < radius)
226        if outer is not None:
227            data.mask |= (q >= outer)
228    else:
229        data.mask = (data.x < radius)
230        if outer is not None:
231            data.mask |= (data.x >= outer)
232
[8a20be5]233
[ec7e360]234def parameter_range(p, v):
[dd7fc12]235    # type: (str, float) -> Tuple[float, float]
[87985ca]236    """
[ec7e360]237    Choose a parameter range based on parameter name and initial value.
[87985ca]238    """
[8bd7b77]239    # process the polydispersity options
[ec7e360]240    if p.endswith('_pd_n'):
[dd7fc12]241        return 0., 100.
[ec7e360]242    elif p.endswith('_pd_nsigma'):
[dd7fc12]243        return 0., 5.
[ec7e360]244    elif p.endswith('_pd_type'):
[dd7fc12]245        raise ValueError("Cannot return a range for a string value")
[caeb06d]246    elif any(s in p for s in ('theta', 'phi', 'psi')):
[87985ca]247        # orientation in [-180,180], orientation pd in [0,45]
248        if p.endswith('_pd'):
[dd7fc12]249            return 0., 45.
[87985ca]250        else:
[dd7fc12]251            return -180., 180.
[87985ca]252    elif p.endswith('_pd'):
[dd7fc12]253        return 0., 1.
[8bd7b77]254    elif 'sld' in p:
[dd7fc12]255        return -0.5, 10.
[eb46451]256    elif p == 'background':
[dd7fc12]257        return 0., 10.
[eb46451]258    elif p == 'scale':
[dd7fc12]259        return 0., 1.e3
260    elif v < 0.:
261        return 2.*v, -2.*v
[87985ca]262    else:
[dd7fc12]263        return 0., (2.*v if v > 0. else 1.)
[87985ca]264
[4f2478e]265
[8bd7b77]266def _randomize_one(model_info, p, v):
[dd7fc12]267    # type: (ModelInfo, str, float) -> float
268    # type: (ModelInfo, str, str) -> str
[ec7e360]269    """
[caeb06d]270    Randomize a single parameter.
[ec7e360]271    """
[f3bd37f]272    if any(p.endswith(s) for s in ('_pd', '_pd_n', '_pd_nsigma', '_pd_type')):
[ec7e360]273        return v
[8bd7b77]274
275    # Find the parameter definition
[6d6508e]276    for par in model_info.parameters.call_parameters:
[8bd7b77]277        if par.name == p:
278            break
[ec7e360]279    else:
[8bd7b77]280        raise ValueError("unknown parameter %r"%p)
281
282    if len(par.limits) > 2:  # choice list
283        return np.random.randint(len(par.limits))
284
285    limits = par.limits
286    if not np.isfinite(limits).all():
287        low, high = parameter_range(p, v)
288        limits = (max(limits[0], low), min(limits[1], high))
289
290    return np.random.uniform(*limits)
[cd3dba0]291
[4f2478e]292
[8bd7b77]293def randomize_pars(model_info, pars, seed=None):
[dd7fc12]294    # type: (ModelInfo, ParameterSet, int) -> ParameterSet
[caeb06d]295    """
296    Generate random values for all of the parameters.
297
298    Valid ranges for the random number generator are guessed from the name of
299    the parameter; this will not account for constraints such as cap radius
300    greater than cylinder radius in the capped_cylinder model, so
301    :func:`constrain_pars` needs to be called afterward..
302    """
[4f2478e]303    with push_seed(seed):
304        # Note: the sort guarantees order `of calls to random number generator
[dd7fc12]305        random_pars = dict((p, _randomize_one(model_info, p, v))
306                           for p, v in sorted(pars.items()))
307    return random_pars
[cd3dba0]308
[17bbadd]309def constrain_pars(model_info, pars):
[dd7fc12]310    # type: (ModelInfo, ParameterSet) -> None
[9a66e65]311    """
312    Restrict parameters to valid values.
[caeb06d]313
314    This includes model specific code for models such as capped_cylinder
315    which need to support within model constraints (cap radius more than
316    cylinder radius in this case).
[dd7fc12]317
318    Warning: this updates the *pars* dictionary in place.
[9a66e65]319    """
[6d6508e]320    name = model_info.id
[17bbadd]321    # if it is a product model, then just look at the form factor since
322    # none of the structure factors need any constraints.
323    if '*' in name:
324        name = name.split('*')[0]
325
[f72d70a]326    # Suppress magnetism for python models (not yet implemented)
327    if callable(model_info.Iq):
328        pars.update(suppress_magnetism(pars))
329
[158cee4]330    if name == 'barbell':
331        if pars['radius_bell'] < pars['radius']:
332            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
[b514adf]333
[158cee4]334    elif name == 'capped_cylinder':
335        if pars['radius_cap'] < pars['radius']:
336            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
337
338    elif name == 'guinier':
339        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
[b514adf]340        #q_max = 0.2  # mid q maximum
341        q_max = 1.0  # high q maximum
342        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
[caeb06d]343        pars['rg'] = min(pars['rg'], rg_max)
[cd3dba0]344
[3e8ea5d]345    elif name == 'pearl_necklace':
346        if pars['radius'] < pars['thick_string']:
347            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
348        pass
349
[158cee4]350    elif name == 'rpa':
[82c299f]351        # Make sure phi sums to 1.0
352        if pars['case_num'] < 2:
[8bd7b77]353            pars['Phi1'] = 0.
354            pars['Phi2'] = 0.
[82c299f]355        elif pars['case_num'] < 5:
[8bd7b77]356            pars['Phi1'] = 0.
357        total = sum(pars['Phi'+c] for c in '1234')
358        for c in '1234':
[82c299f]359            pars['Phi'+c] /= total
360
[d6850fa]361def parlist(model_info, pars, is2d):
[dd7fc12]362    # type: (ModelInfo, ParameterSet, bool) -> str
[caeb06d]363    """
364    Format the parameter list for printing.
365    """
[a4a7308]366    lines = []
[6d6508e]367    parameters = model_info.parameters
[0b040de]368    magnetic = False
[d19962c]369    for p in parameters.user_parameters(pars, is2d):
[0b040de]370        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
371            continue
372        if p.id.startswith('up:') and not magnetic:
373            continue
[d19962c]374        fields = dict(
375            value=pars.get(p.id, p.default),
376            pd=pars.get(p.id+"_pd", 0.),
377            n=int(pars.get(p.id+"_pd_n", 0)),
378            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
[dd7fc12]379            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
[bd49c79]380            relative_pd=p.relative_pd,
[0b040de]381            M0=pars.get('M0:'+p.id, 0.),
382            mphi=pars.get('mphi:'+p.id, 0.),
383            mtheta=pars.get('mtheta:'+p.id, 0.),
[dd7fc12]384        )
[d19962c]385        lines.append(_format_par(p.name, **fields))
[0b040de]386        magnetic = magnetic or fields['M0'] != 0.
[a4a7308]387    return "\n".join(lines)
388
389    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
390
[bd49c79]391def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
[0b040de]392                relative_pd=False, M0=0., mphi=0., mtheta=0.):
[dd7fc12]393    # type: (str, float, float, int, float, str) -> str
[a4a7308]394    line = "%s: %g"%(name, value)
395    if pd != 0.  and n != 0:
[bd49c79]396        if relative_pd:
397            pd *= value
[a4a7308]398        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
[dd7fc12]399                % (pd, n, nsigma, nsigma, pdtype)
[0b040de]400    if M0 != 0.:
401        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
[a4a7308]402    return line
[87985ca]403
404def suppress_pd(pars):
[dd7fc12]405    # type: (ParameterSet) -> ParameterSet
[87985ca]406    """
407    Suppress theta_pd for now until the normalization is resolved.
408
409    May also suppress complete polydispersity of the model to test
410    models more quickly.
411    """
[f4f3919]412    pars = pars.copy()
[87985ca]413    for p in pars:
[8b25ee1]414        if p.endswith("_pd_n"): pars[p] = 0
[f4f3919]415    return pars
[87985ca]416
[0b040de]417def suppress_magnetism(pars):
418    # type: (ParameterSet) -> ParameterSet
419    """
420    Suppress theta_pd for now until the normalization is resolved.
421
422    May also suppress complete polydispersity of the model to test
423    models more quickly.
424    """
425    pars = pars.copy()
426    for p in pars:
427        if p.startswith("M0:"): pars[p] = 0
428    return pars
429
[17bbadd]430def eval_sasview(model_info, data):
[dd7fc12]431    # type: (Modelinfo, Data) -> Calculator
[caeb06d]432    """
[f247314]433    Return a model calculator using the pre-4.0 SasView models.
[caeb06d]434    """
[dc056b9]435    # importing sas here so that the error message will be that sas failed to
436    # import rather than the more obscure smear_selection not imported error
[2bebe2b]437    import sas
[dd7fc12]438    import sas.models
[8d62008]439    from sas.models.qsmearing import smear_selection
440    from sas.models.MultiplicationModel import MultiplicationModel
[050c2c8]441    from sas.models.dispersion_models import models as dispersers
[ec7e360]442
[256dfe1]443    def get_model_class(name):
[dd7fc12]444        # type: (str) -> "sas.models.BaseComponent"
[17bbadd]445        #print("new",sorted(_pars.items()))
[dd7fc12]446        __import__('sas.models.' + name)
[17bbadd]447        ModelClass = getattr(getattr(sas.models, name, None), name, None)
448        if ModelClass is None:
449            raise ValueError("could not find model %r in sas.models"%name)
[256dfe1]450        return ModelClass
451
452    # WARNING: ugly hack when handling model!
453    # Sasview models with multiplicity need to be created with the target
454    # multiplicity, so we cannot create the target model ahead of time for
455    # for multiplicity models.  Instead we store the model in a list and
456    # update the first element of that list with the new multiplicity model
457    # every time we evaluate.
[17bbadd]458
459    # grab the sasview model, or create it if it is a product model
[6d6508e]460    if model_info.composition:
461        composition_type, parts = model_info.composition
[17bbadd]462        if composition_type == 'product':
[51ec7e8]463            P, S = [get_model_class(revert_name(p))() for p in parts]
[256dfe1]464            model = [MultiplicationModel(P, S)]
[17bbadd]465        else:
[72a081d]466            raise ValueError("sasview mixture models not supported by compare")
[17bbadd]467    else:
[f3bd37f]468        old_name = revert_name(model_info)
469        if old_name is None:
470            raise ValueError("model %r does not exist in old sasview"
471                            % model_info.id)
[256dfe1]472        ModelClass = get_model_class(old_name)
473        model = [ModelClass()]
[050c2c8]474    model[0].disperser_handles = {}
[216a9e1]475
[17bbadd]476    # build a smearer with which to call the model, if necessary
477    smearer = smear_selection(data, model=model)
[ec7e360]478    if hasattr(data, 'qx_data'):
479        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
480        index = ((~data.mask) & (~np.isnan(data.data))
481                 & (q >= data.qmin) & (q <= data.qmax))
482        if smearer is not None:
483            smearer.model = model  # because smear_selection has a bug
484            smearer.accuracy = data.accuracy
485            smearer.set_index(index)
[256dfe1]486            def _call_smearer():
487                smearer.model = model[0]
488                return smearer.get_value()
[b32dafd]489            theory = _call_smearer
[ec7e360]490        else:
[256dfe1]491            theory = lambda: model[0].evalDistribution([data.qx_data[index],
492                                                        data.qy_data[index]])
[ec7e360]493    elif smearer is not None:
[256dfe1]494        theory = lambda: smearer(model[0].evalDistribution(data.x))
[ec7e360]495    else:
[256dfe1]496        theory = lambda: model[0].evalDistribution(data.x)
[ec7e360]497
498    def calculator(**pars):
[dd7fc12]499        # type: (float, ...) -> np.ndarray
[caeb06d]500        """
501        Sasview calculator for model.
502        """
[256dfe1]503        oldpars = revert_pars(model_info, pars)
[bd49c79]504        # For multiplicity models, create a model with the correct multiplicity
505        control = oldpars.pop("CONTROL", None)
506        if control is not None:
507            # sphericalSLD has one fewer multiplicity.  This update should
508            # happen in revert_pars, but it hasn't been called yet.
509            model[0] = ModelClass(control)
510        # paying for parameter conversion each time to keep life simple, if not fast
[050c2c8]511        for k, v in oldpars.items():
512            if k.endswith('.type'):
513                par = k[:-5]
[6831fa0]514                if v == 'gaussian': continue
[050c2c8]515                cls = dispersers[v if v != 'rectangle' else 'rectangula']
516                handle = cls()
517                model[0].disperser_handles[par] = handle
[6831fa0]518                try:
519                    model[0].set_dispersion(par, handle)
520                except Exception:
521                    exception.annotate_exception("while setting %s to %r"
522                                                 %(par, v))
523                    raise
524
[050c2c8]525
[f67f26c]526        #print("sasview pars",oldpars)
[256dfe1]527        for k, v in oldpars.items():
[dd7fc12]528            name_attr = k.split('.')  # polydispersity components
529            if len(name_attr) == 2:
[050c2c8]530                par, disp_par = name_attr
531                model[0].dispersion[par][disp_par] = v
[ec7e360]532            else:
[256dfe1]533                model[0].setParam(k, v)
[ec7e360]534        return theory()
535
536    calculator.engine = "sasview"
537    return calculator
538
539DTYPE_MAP = {
540    'half': '16',
541    'fast': 'fast',
542    'single': '32',
543    'double': '64',
544    'quad': '128',
545    'f16': '16',
546    'f32': '32',
547    'f64': '64',
[650c6d2]548    'float16': '16',
549    'float32': '32',
550    'float64': '64',
551    'float128': '128',
[ec7e360]552    'longdouble': '128',
553}
[17bbadd]554def eval_opencl(model_info, data, dtype='single', cutoff=0.):
[dd7fc12]555    # type: (ModelInfo, Data, str, float) -> Calculator
[caeb06d]556    """
557    Return a model calculator using the OpenCL calculation engine.
558    """
[a738209]559    if not core.HAVE_OPENCL:
560        raise RuntimeError("OpenCL not available")
561    model = core.build_model(model_info, dtype=dtype, platform="ocl")
[7cf2cfd]562    calculator = DirectModel(data, model, cutoff=cutoff)
[ec7e360]563    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
564    return calculator
[216a9e1]565
[17bbadd]566def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
[dd7fc12]567    # type: (ModelInfo, Data, str, float) -> Calculator
[9cfcac8]568    """
569    Return a model calculator using the DLL calculation engine.
570    """
[72a081d]571    model = core.build_model(model_info, dtype=dtype, platform="dll")
[7cf2cfd]572    calculator = DirectModel(data, model, cutoff=cutoff)
[ec7e360]573    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
574    return calculator
575
[b32dafd]576def time_calculation(calculator, pars, evals=1):
[dd7fc12]577    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
[caeb06d]578    """
579    Compute the average calculation time over N evaluations.
580
581    An additional call is generated without polydispersity in order to
582    initialize the calculation engine, and make the average more stable.
583    """
[ec7e360]584    # initialize the code so time is more accurate
[b32dafd]585    if evals > 1:
[dd7fc12]586        calculator(**suppress_pd(pars))
[216a9e1]587    toc = tic()
[dd7fc12]588    # make sure there is at least one eval
589    value = calculator(**pars)
[b32dafd]590    for _ in range(evals-1):
[7cf2cfd]591        value = calculator(**pars)
[b32dafd]592    average_time = toc()*1000. / evals
[f2f67a6]593    #print("I(q)",value)
[216a9e1]594    return value, average_time
595
[ec7e360]596def make_data(opts):
[dd7fc12]597    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
[caeb06d]598    """
599    Generate an empty dataset, used with the model to set Q points
600    and resolution.
601
602    *opts* contains the options, with 'qmax', 'nq', 'res',
603    'accuracy', 'is2d' and 'view' parsed from the command line.
604    """
[ec7e360]605    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
606    if opts['is2d']:
[dd7fc12]607        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
608        data = empty_data2D(q, resolution=res)
[ec7e360]609        data.accuracy = opts['accuracy']
[ea75043]610        set_beam_stop(data, 0.0004)
[87985ca]611        index = ~data.mask
[216a9e1]612    else:
[e78edc4]613        if opts['view'] == 'log' and not opts['zero']:
[b89f519]614            qmax = math.log10(qmax)
[ec7e360]615            q = np.logspace(qmax-3, qmax, nq)
[b89f519]616        else:
[ec7e360]617            q = np.linspace(0.001*qmax, qmax, nq)
[e78edc4]618        if opts['zero']:
619            q = np.hstack((0, q))
[ec7e360]620        data = empty_data1D(q, resolution=res)
[216a9e1]621        index = slice(None, None)
622    return data, index
623
[17bbadd]624def make_engine(model_info, data, dtype, cutoff):
[dd7fc12]625    # type: (ModelInfo, Data, str, float) -> Calculator
[caeb06d]626    """
627    Generate the appropriate calculation engine for the given datatype.
628
629    Datatypes with '!' appended are evaluated using external C DLLs rather
630    than OpenCL.
631    """
[ec7e360]632    if dtype == 'sasview':
[17bbadd]633        return eval_sasview(model_info, data)
[ec7e360]634    elif dtype.endswith('!'):
[17bbadd]635        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
[ec7e360]636    else:
[17bbadd]637        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
[87985ca]638
[e78edc4]639def _show_invalid(data, theory):
[dd7fc12]640    # type: (Data, np.ma.ndarray) -> None
641    """
642    Display a list of the non-finite values in theory.
643    """
[e78edc4]644    if not theory.mask.any():
645        return
646
647    if hasattr(data, 'x'):
648        bad = zip(data.x[theory.mask], theory[theory.mask])
[dd7fc12]649        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
[e78edc4]650
651
[013adb7]652def compare(opts, limits=None):
[dd7fc12]653    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
[caeb06d]654    """
655    Preform a comparison using options from the command line.
656
657    *limits* are the limits on the values to use, either to set the y-axis
658    for 1D or to set the colormap scale for 2D.  If None, then they are
659    inferred from the data and returned. When exploring using Bumps,
660    the limits are set when the model is initially called, and maintained
661    as the values are adjusted, making it easier to see the effects of the
662    parameters.
663    """
[ca9e54e]664    result = run_models(opts, verbose=True)
665    if opts['plot']:  # Note: never called from explore
666        plot_models(opts, result, limits=limits)
667
668def run_models(opts, verbose=False):
669    # type: (Dict[str, Any]) -> Dict[str, Any]
670
[ff1fff5]671    n_base, n_comp = opts['count']
672    pars, pars2 = opts['pars']
[ec7e360]673    data = opts['data']
[87985ca]674
[dd7fc12]675    # silence the linter
[b32dafd]676    base = opts['engines'][0] if n_base else None
677    comp = opts['engines'][1] if n_comp else None
[ca9e54e]678
[dd7fc12]679    base_time = comp_time = None
680    base_value = comp_value = resid = relerr = None
681
[4b41184]682    # Base calculation
[b32dafd]683    if n_base > 0:
[319ab14]684        try:
[b32dafd]685            base_raw, base_time = time_calculation(base, pars, n_base)
[dd7fc12]686            base_value = np.ma.masked_invalid(base_raw)
[ca9e54e]687            if verbose:
688                print("%s t=%.2f ms, intensity=%.0f"
689                      % (base.engine, base_time, base_value.sum()))
[e78edc4]690            _show_invalid(data, base_value)
[319ab14]691        except ImportError:
692            traceback.print_exc()
[b32dafd]693            n_base = 0
[4b41184]694
695    # Comparison calculation
[b32dafd]696    if n_comp > 0:
[7cf2cfd]697        try:
[ff1fff5]698            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
[dd7fc12]699            comp_value = np.ma.masked_invalid(comp_raw)
[ca9e54e]700            if verbose:
701                print("%s t=%.2f ms, intensity=%.0f"
702                      % (comp.engine, comp_time, comp_value.sum()))
[e78edc4]703            _show_invalid(data, comp_value)
[7cf2cfd]704        except ImportError:
[5753e4e]705            traceback.print_exc()
[b32dafd]706            n_comp = 0
[87985ca]707
708    # Compare, but only if computing both forms
[b32dafd]709    if n_base > 0 and n_comp > 0:
[ec7e360]710        resid = (base_value - comp_value)
[b32dafd]711        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
[ca9e54e]712        if verbose:
713            _print_stats("|%s-%s|"
714                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
715                         resid)
716            _print_stats("|(%s-%s)/%s|"
717                         % (base.engine, comp.engine, comp.engine),
718                         relerr)
719
720    return dict(base_value=base_value, comp_value=comp_value,
721                base_time=base_time, comp_time=comp_time,
722                resid=resid, relerr=relerr)
723
724
725def _print_stats(label, err):
726    # type: (str, np.ma.ndarray) -> None
727    # work with trimmed data, not the full set
728    sorted_err = np.sort(abs(err.compressed()))
729    if len(sorted_err) == 0.:
730        print(label + "  no valid values")
731        return
732
733    p50 = int((len(sorted_err)-1)*0.50)
734    p98 = int((len(sorted_err)-1)*0.98)
735    data = [
736        "max:%.3e"%sorted_err[-1],
737        "median:%.3e"%sorted_err[p50],
738        "98%%:%.3e"%sorted_err[p98],
739        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
740        "zero-offset:%+.3e"%np.mean(sorted_err),
741        ]
742    print(label+"  "+"  ".join(data))
743
744
def plot_models(opts, result, limits=None):
    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
    """
    Plot the base and comparison values and the error between them.

    *opts* is the options dictionary built by parse_opts; the keys read
    here are 'engines', 'data', 'datafile', 'view', 'is2d', 'rel_err',
    'title', 'name', 'show_hist' and 'explore'.  *result* holds
    'base_value', 'comp_value', 'base_time', 'comp_time', 'resid' and
    'relerr'; a value of None means that engine was not evaluated.

    *limits* is the (vmin, vmax) range passed to the value plots.  When it
    is None and no data file is being overlaid, it is computed from the
    range of the computed values.  Returns the limits used, which stay None
    if none were given and a data file is being shown.

    Unless opts['explore'] is set, the figures are displayed with
    plt.show().  *result* is not modified.
    """
    base_value, comp_value = result['base_value'], result['comp_value']
    base_time, comp_time = result['base_time'], result['comp_time']
    resid, relerr = result['resid'], result['relerr']

    have_base, have_comp = (base_value is not None), (comp_value is not None)
    base = opts['engines'][0] if have_base else None
    comp = opts['engines'][1] if have_comp else None
    data = opts['data']
    # Show the data file only when exactly one of the two engines was run.
    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)

    view = opts['view']
    import matplotlib.pyplot as plt
    if limits is None and not use_data:
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
        vmin, vmax = np.inf, -np.inf
        if have_base:
            vmin = min(vmin, base_value.min())
            vmax = max(vmax, base_value.max())
        if have_comp:
            vmin = min(vmin, comp_value.min())
            vmax = max(vmax, comp_value.max())
        limits = vmin, vmax

    if have_base:
        if have_comp:
            plt.subplot(131)
        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
        plt.title("%s t=%.2f ms"%(base.engine, base_time))
    if have_comp:
        if have_base:
            plt.subplot(132)
        # For 1D, draw the base values in the same panel before the
        # comparison values.
        if not opts['is2d'] and have_base:
            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
    if have_base and have_comp:
        plt.subplot(133)
        if not opts['rel_err']:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
        if view == 'linear':
            plt.xscale('linear')
        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
    fig = plt.gcf()
    extra_title = ' '+opts['title'] if opts['title'] else ''
    fig.suptitle(":".join(opts['name']) + extra_title)

    if have_base and have_comp and opts['show_hist']:
        plt.figure()
        # Work on a copy so that result['relerr'] is left untouched for any
        # later replot of the same result.
        v = np.abs(relerr).copy()
        nonzero = v[v != 0]
        # log10(0) is -inf, so replace exact zeros with half the smallest
        # nonzero error.  If every error is zero, use the smallest positive
        # double instead of failing on min() of an empty array.
        if np.ma.count(nonzero) > 0:
            v[v == 0] = 0.5*np.min(nonzero)
        else:
            v[v == 0] = np.finfo(float).tiny
        # 'density' replaces 'normed', which matplotlib removed.
        plt.hist(np.log10(v), density=True, bins=50)
        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
                   % (base.engine, comp.engine, comp.engine))
        plt.ylabel('P(err)')
        plt.title('Distribution of relative error between calculation engines')

    if not opts['explore']:
        plt.show()

    return limits
819
[0763009]820
821
822
# ===========================================================================
#
# Flags given on the command line as "-name" with no value.
NAME_OPTIONS = {
    'plot', 'noplot',
    'half', 'fast', 'single', 'double',
    'single!', 'double!', 'quad!', 'sasview',
    'lowq', 'midq', 'highq', 'exq', 'zero',
    '2d', '1d',
    'preset', 'random',
    'poly', 'mono',
    'magnetic', 'nonmagnetic',
    'nopars', 'pars',
    'rel', 'abs',
    'linear', 'log', 'q4',
    'hist', 'nohist',
    'edit', 'html', 'help',
    'demo', 'default',
}
# Flags given on the command line as "-name=value".
VALUE_OPTIONS = [
    # Note: random is both a name option and a value option
    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title', 'data',
]
845
def columnize(items, indent="", width=79):
    # type: (List[str], str, int) -> str
    """
    Format a list of strings into columns.

    Items fill the first column top to bottom, then the next, using as
    many equal-width columns as fit in *width* characters (at least one).
    Every line is prefixed by *indent*.

    Returns a string with newlines, ready for printing; an empty *items*
    gives an empty string.
    """
    if not items:
        return ""
    column_width = max(len(w) for w in items) + 1
    # Always allow at least one column, even if an item is wider than
    # *width*, otherwise the row count below divides by zero.
    num_columns = max(1, (width - len(indent)) // column_width)
    # Round up so a partially filled final column is not dropped.
    num_rows = (len(items) + num_columns - 1) // num_columns
    items = list(items) + [""] * (num_rows * num_columns - len(items))
    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    return indent + ("\n"+indent).join(lines)
862
863
def get_pars(model_info, use_demo=False):
    # type: (ModelInfo, bool) -> ParameterSet
    """
    Extract demo parameters from the model definition.

    Every call parameter starts at its default value.  Polydisperse
    parameters also get '_pd' (0.0), '_pd_n' (0), '_pd_nsigma' (3.0) and
    '_pd_type' ("gaussian") entries.  A parameter with length > 1 is
    expanded into numbered entries id1 ... id<length>, all with the same
    value.  If *use_demo* is True, model_info.demo is applied on top.

    Returns a dictionary mapping parameter name to value.
    """
    # Get the default values for the parameters
    pars = {}
    for p in model_info.parameters.call_parameters:
        parts = [('', p.default)]
        if p.polydisperse:
            parts.append(('_pd', 0.0))
            parts.append(('_pd_n', 0))
            parts.append(('_pd_nsigma', 3.0))
            parts.append(('_pd_type', "gaussian"))
        for ext, val in parts:
            if p.length > 1:
                # Vector parameter: one entry per element, numbered from 1.
                pars.update(("%s%d%s" % (p.id, k, ext), val)
                            for k in range(1, p.length+1))
            else:
                pars[p.id + ext] = val

    # Plug in values given in demo
    if use_demo:
        pars.update(model_info.demo)
    return pars
889
INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
def isnumber(str):
    """
    Return True if *str* is a plain numeric literal.

    The string is a number if FLOAT_RE matches all of it, or if it matches
    INTEGER_RE (an optionally signed integer without a leading zero, so
    "0" on its own does not match INTEGER_RE).  Always returns a bool
    rather than a match object.
    """
    # pylint: disable=redefined-builtin
    match = FLOAT_RE.match(str)
    # re.match anchors at the start; require that nothing follows the match.
    if match and not str[match.end():]:
        return True
    return INTEGER_RE.match(str) is not None
[17bbadd]895
# Separator for giving two alternatives in one argument, so a pair of
# models can be compared, e.g. "model1,model2" or "par=value1,value2".
# Characters that already have a meaning and so were unsuitable:
#   key-value pair separator:         =
#   shell characters:                 | & ; <> $ % ' " \ # `
#   model and parameter names:        _
#   parameter expressions:            - + * / . ( )
#   paths (tilde, windows drive):     ~ / :
#   uncertain:                        [] {}
# The remaining candidates were @ ? ^ ! , and the comma was chosen.
MODEL_SPLIT = ","
def parse_opts(argv):
    # type: (List[str]) -> Dict[str, Any]
    """
    Parse command line options.

    *argv* is the list of command line arguments, excluding the program
    name.  Arguments are classified as:

    * flags: start with '-'.  Each must be in NAME_OPTIONS, or have the
      form -name=value for a name in VALUE_OPTIONS.
    * parameter values: contain '=', e.g. radius=30.  "value1,value2" gives
      different values to the two models.  Values that are not plain
      numbers are evaluated as python expressions.
    * positional: "model N1 N2".  The model may be "model1,model2".  N1 and
      N2 (default 1) are stored as 'count'; a count of 0 means that engine
      is not created.

    Returns the options dictionary, or None if usage information or an
    error message was printed instead.
    """
    MODELS = core.list_models()
    flags = [arg for arg in argv
             if arg.startswith('-')]
    values = [arg for arg in argv
              if not arg.startswith('-') and '=' in arg]
    positional_args = [arg for arg in argv
                       if not arg.startswith('-') and '=' not in arg]
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(positional_args) == 0:
        print(USAGE)
        print("\nAvailable models:")
        print(columnize(MODELS, indent="  "))
        return None
    if len(positional_args) > 3:
        # Extra positional arguments are reported, then ignored.
        print("expected parameters: model N1 N2")

    invalid = [o[1:] for o in flags
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        return None

    name = positional_args[0]
    try:
        n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
        n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
    except ValueError:
        print("expected parameters: model N1 N2 (N1 and N2 are integers)")
        return None

    # pylint: disable=bad-whitespace
    # Interpret the flags
    opts = {
        'plot'      : True,
        'view'      : 'log',
        'is2d'      : False,
        'qmax'      : 0.05,
        'nq'        : 128,
        'res'       : 0.0,
        'accuracy'  : 'Low',
        'cutoff'    : 0.0,
        'seed'      : -1,  # default to preset
        'mono'      : True,
        # Magnetic parameters are suppressed unless -magnetic is given
        'magnetic'  : False,
        'show_pars' : False,
        'show_hist' : False,
        'rel_err'   : True,
        'explore'   : False,
        'use_demo'  : True,
        'zero'      : False,
        'html'      : False,
        'title'     : None,
        'datafile'  : None,
    }
    engines = []
    for arg in flags:
        if arg == '-noplot':    opts['plot'] = False
        elif arg == '-plot':    opts['plot'] = True
        elif arg == '-linear':  opts['view'] = 'linear'
        elif arg == '-log':     opts['view'] = 'log'
        elif arg == '-q4':      opts['view'] = 'q4'
        elif arg == '-1d':      opts['is2d'] = False
        elif arg == '-2d':      opts['is2d'] = True
        elif arg == '-exq':     opts['qmax'] = 10.0
        elif arg == '-highq':   opts['qmax'] = 1.0
        elif arg == '-midq':    opts['qmax'] = 0.2
        elif arg == '-lowq':    opts['qmax'] = 0.05
        elif arg == '-zero':    opts['zero'] = True
        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
        elif arg.startswith('-title='):    opts['title'] = arg[7:]
        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
        elif arg == '-preset':  opts['seed'] = -1
        elif arg == '-mono':    opts['mono'] = True
        elif arg == '-poly':    opts['mono'] = False
        elif arg == '-magnetic':       opts['magnetic'] = True
        elif arg == '-nonmagnetic':    opts['magnetic'] = False
        elif arg == '-pars':    opts['show_pars'] = True
        elif arg == '-nopars':  opts['show_pars'] = False
        elif arg == '-hist':    opts['show_hist'] = True
        elif arg == '-nohist':  opts['show_hist'] = False
        elif arg == '-rel':     opts['rel_err'] = True
        elif arg == '-abs':     opts['rel_err'] = False
        elif arg == '-half':    engines.append(arg[1:])
        elif arg == '-fast':    engines.append(arg[1:])
        elif arg == '-single':  engines.append(arg[1:])
        elif arg == '-double':  engines.append(arg[1:])
        elif arg == '-single!': engines.append(arg[1:])
        elif arg == '-double!': engines.append(arg[1:])
        elif arg == '-quad!':   engines.append(arg[1:])
        elif arg == '-sasview': engines.append(arg[1:])
        elif arg == '-edit':    opts['explore'] = True
        elif arg == '-demo':    opts['use_demo'] = True
        elif arg == '-default':    opts['use_demo'] = False
        elif arg == '-html':    opts['html'] = True
        elif arg == '-help':    opts['html'] = True
    # pylint: enable=bad-whitespace

    # "model1,model2" compares two different models.  Split at the first
    # separator only, so extra separators stay in the second name instead
    # of making the two-name unpacking raise ValueError.
    if MODEL_SPLIT in name:
        name, name2 = name.split(MODEL_SPLIT, 1)
    else:
        name2 = name
    try:
        model_info = core.load_model_info(name)
        model_info2 = core.load_model_info(name2) if name2 != name else model_info
    except ImportError as exc:
        print(str(exc))
        print("Could not find model; use one of:\n    " + models)
        return None

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    pars = get_pars(model_info, opts['use_demo'])
    pars2 = get_pars(model_info2, opts['use_demo'])
    # Parameters shared by the two models start from the same value.
    pars2.update((k, v) for k, v in pars.items() if k in pars2)
    # randomize parameters
    #pars.update(set_pars)  # set value before random to control range
    if opts['seed'] > -1:
        pars = randomize_pars(model_info, pars, seed=opts['seed'])
        if model_info != model_info2:
            pars2 = randomize_pars(model_info2, pars2, seed=opts['seed'])
            # Share values for parameters with the same name
            for k, v in pars.items():
                if k in pars2:
                    pars2[k] = v
        else:
            pars2 = pars.copy()
        constrain_pars(model_info, pars)
        constrain_pars(model_info2, pars2)
        print("Randomize using -random=%i"%opts['seed'])
    if opts['mono']:
        pars = suppress_pd(pars)
        pars2 = suppress_pd(pars2)
    if not opts['magnetic']:
        pars = suppress_magnetism(pars)
        pars2 = suppress_magnetism(pars2)

    # Fill in parameters given on the command line
    presets = {}
    presets2 = {}
    for arg in values:
        k, v = arg.split('=', 1)
        if k not in pars and k not in pars2:
            # extract base name without polydispersity info
            s = set(p.split('_pd')[0] for p in pars)
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            return None
        # "value1,value2" gives different values to the two models; split at
        # the first separator only (see the model-name split above).
        v1, v2 = v.split(MODEL_SPLIT, 1) if MODEL_SPLIT in v else (v, v)
        if v1 and k in pars:
            presets[k] = float(v1) if isnumber(v1) else v1
        if v2 and k in pars2:
            presets2[k] = float(v2) if isnumber(v2) else v2

    # If pd given on the command line, default pd_n to 35
    for k in list(presets):
        if k.endswith('_pd'):
            presets.setdefault(k+'_n', 35.)
    for k in list(presets2):
        if k.endswith('_pd'):
            presets2.setdefault(k+'_n', 35.)

    # Evaluate preset parameter expressions.
    # NOTE: the expressions come straight from the command line and are run
    # through eval(); this is only appropriate for trusted, local input.
    context = MATH.copy()
    context['np'] = np
    context.update(pars)
    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
    for k, v in presets.items():
        if not isinstance(v, float) and not k.endswith('_type'):
            presets[k] = eval(v, context)
    context.update(presets)
    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
    for k, v in presets2.items():
        if not isinstance(v, float) and not k.endswith('_type'):
            presets2[k] = eval(v, context)

    # update parameters with presets
    pars.update(presets)  # set value after random to control value
    pars2.update(presets2)  # set value after random to control value
    #import pprint; pprint.pprint(model_info)

    # Identical model and parameters on both sides: compare precisions
    # (single vs double).  Otherwise compare the two sides at one precision.
    same_model = name == name2 and pars == pars2
    if len(engines) == 0:
        if same_model:
            engines.extend(['single', 'double'])
        else:
            engines.extend(['single', 'single'])
    elif len(engines) == 1:
        if not same_model:
            engines.append(engines[0])
        elif engines[0] == 'double':
            engines.append('single')
        else:
            engines.append('double')
    elif len(engines) > 2:
        del engines[2:]

    use_sasview = any(engine == 'sasview' and count > 0
                      for engine, count in zip(engines, [n1, n2]))
    if use_sasview:
        constrain_new_to_old(model_info, pars)
        constrain_new_to_old(model_info2, pars2)

    if opts['show_pars']:
        if not same_model:
            print("==== %s ====="%model_info.name)
            print(str(parlist(model_info, pars, opts['is2d'])))
            print("==== %s ====="%model_info2.name)
            print(str(parlist(model_info2, pars2, opts['is2d'])))
        else:
            print(str(parlist(model_info, pars, opts['is2d'])))

    # Create the computational engines
    if opts['datafile'] is not None:
        data = load_data(os.path.expanduser(opts['datafile']))
    else:
        data, _ = make_data(opts)
    if n1:
        base = make_engine(model_info, data, engines[0], opts['cutoff'])
    else:
        base = None
    if n2:
        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
    else:
        comp = None

    # pylint: disable=bad-whitespace
    # Remember it all
    opts.update({
        'data'      : data,
        'name'      : [name, name2],
        'def'       : [model_info, model_info2],
        'count'     : [n1, n2],
        'presets'   : [presets, presets2],
        'pars'      : [pars, pars2],
        'engines'   : [base, comp],
    })
    # pylint: enable=bad-whitespace

    return opts
1151
def show_docs(opts):
    # type: (Dict[str, Any]) -> None
    """
    show html docs for the model

    Renders the documentation of the first model in opts['def'] with
    make_html and displays it with rst2html.view_html_qtapp.  The base URL
    is the directory containing the model file, presumably so that
    relative links in the page resolve.
    """
    from .generate import make_html
    from . import rst2html

    info = opts['def'][0]
    html = make_html(info)
    path = os.path.dirname(info.filename)
    # Strip a Windows drive letter ("C:") only when one is present.  Blindly
    # slicing off two characters mangles POSIX paths ("/home" -> "ome").
    _, path = os.path.splitdrive(path)
    url = "file://"+path.replace("\\", "/")+"/"
    rst2html.view_html_qtapp(html, url)
[234c532]1166
def explore(opts):
    # type: (Dict[str, Any]) -> None
    """
    Explore the model interactively using the bumps GUI.

    The comparison described by *opts* is wrapped in an Explore object,
    turned into a bumps FitProblem and shown in a bumps AppFrame.  A wx
    application is created, and its main loop run, only when no wx
    application is already active.
    """
    import wx  # type: ignore
    from bumps.names import FitProblem  # type: ignore
    from bumps.gui.app_frame import AppFrame  # type: ignore
    from bumps.gui import signal

    on_mac = "cocoa" in wx.version()
    # Only create an application when not running embedded in one.
    if wx.GetApp() is None:
        app = wx.App()
    else:
        app = None
    model = Explore(opts)
    problem = FitProblem(model)
    frame = AppFrame(parent=None, title="explore", size=(1000,700))
    # The Mac frame is shown after it is populated; other platforms before.
    if not on_mac:
        frame.Show()
    panel = frame.panel
    panel.set_model(model=problem)
    panel.Layout()
    panel.aui.Split(0, wx.TOP)

    def on_reset(event):
        # Restore the starting parameter values and notify the GUI.
        model.revert_values()
        signal.update_parameters(problem)

    frame.Bind(wx.EVT_TOOL, on_reset, frame.ToolBar.GetToolByPos(1))
    if on_mac:
        frame.Show()
    # Run the event loop only if the application was created here.
    if app:
        app.MainLoop()
[ec7e360]1194
class Explore(object):
    """
    Bumps wrapper for a SAS model comparison.

    The resulting object can be used as a Bumps fit problem so that
    parameters can be adjusted in the GUI, with plots updated on the fly.
    """
    def __init__(self, opts):
        # type: (Dict[str, Any]) -> None
        from bumps.cli import config_matplotlib  # type: ignore
        from . import bumps_model
        config_matplotlib()
        self.opts = opts
        p1, p2 = opts['pars']
        m1, m2 = opts['def']
        # When the two sides differ in model or parameters, the second
        # parameter set is left fixed; otherwise both follow the GUI values.
        self.fix_p2 = m1 != m2 or p1 != p2
        model_info = m1
        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
        if not opts['is2d']:
            for p in model_info.parameters.user_parameters({}, is2d=False):
                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
                    k = p.name+ext
                    v = pars.get(k, None)
                    if v is not None:
                        v.range(*parameter_range(k, v.value))
        else:
            for k, v in pars.items():
                v.range(*parameter_range(k, v.value))

        self.pars = pars
        self.starting_values = dict((k, v.value) for k, v in pars.items())
        self.pd_types = pd_types
        self.limits = None

    def revert_values(self):
        # type: () -> None
        """
        Reset every parameter to the value it had at construction time.
        """
        for k, v in self.starting_values.items():
            self.pars[k].value = v

    def model_update(self):
        # type: () -> None
        """
        No-op; presumably required by the bumps model interface.
        """
        pass

    def numpoints(self):
        # type: () -> int
        """
        Return the number of points.
        """
        return len(self.pars) + 1  # so dof is 1

    def parameters(self):
        # type: () -> Any   # Dict/List hierarchy of parameters
        """
        Return a dictionary of parameters.
        """
        return self.pars

    def nllf(self):
        # type: () -> float
        """
        Return cost.
        """
        # pylint: disable=no-self-use
        return 0.  # No nllf

    def plot(self, view='log'):
        # type: (str) -> None
        """
        Plot the data and residuals.

        *view* is not used; plot_models takes the view from
        self.opts['view'].
        """
        pars = dict((k, v.value) for k, v in self.pars.items())
        pars.update(self.pd_types)
        self.opts['pars'][0] = pars
        if not self.fix_p2:
            self.opts['pars'][1] = pars
        result = run_models(self.opts)
        limits = plot_models(self.opts, result, limits=self.limits)
        # plot_models leaves the limits as None when it overlays a data file,
        # so only freeze the range when it actually produced one.
        if self.limits is None and limits is not None:
            _, vmax = limits
            # Freeze the range on the first draw so it stays stable while
            # parameters are adjusted (seven decades below 1.3*max).
            self.limits = vmax*1e-7, 1.3*vmax
            import matplotlib.pyplot as plt
            plt.clf()
            plot_models(self.opts, result, limits=self.limits)
[87985ca]1276
1277
def main(*argv):
    # type: (*str) -> None
    """
    Main program.

    Parses *argv*, then shows the model documentation (-html or -help),
    opens the interactive explorer (-edit), or runs the comparison.
    Does nothing further if the options could not be parsed.
    """
    opts = parse_opts(argv)
    if opts is None:
        return
    if opts['html']:
        show_docs(opts)
    elif opts['explore']:
        explore(opts)
    else:
        compare(opts)
[d15a908]1291
if __name__ == "__main__":
    # Script entry point: pass every argument after the program name to main().
    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.