source: sasmodels/sasmodels/compare.py @ 32398dc

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 32398dc was 32398dc, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

remove sascomp support for sasview 3.x models

  • Property mode set to 100755
File size: 52.1 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .convert import revert_name, revert_pars
46from .generate import FLOAT_RE
47from .weights import plot_weights
48
49# pylint: disable=unused-import
50try:
51    from typing import Optional, Dict, Any, Callable, Tuple
52except ImportError:
53    pass
54else:
55    from .modelinfo import ModelInfo, Parameter, ParameterSet
56    from .data import Data
57    Calculator = Callable[[float], np.ndarray]
58# pylint: enable=unused-import
59
60USAGE = """
61usage: sascomp model [options...] [key=val]
62
63Generate and compare SAS models.  If a single model is specified it shows
64a plot of that model.  Different models can be compared, or the same model
65with different parameters.  The same model with the same parameters can
66be compared with different calculation engines to see the effects of precision
67on the resultant values.
68
69model or model1,model2 are the names of the models to compare (see below).
70
71Options (* for default):
72
73    === data generation ===
74    -data="path" uses q, dq from the data file
75    -noise=0 sets the measurement error dI/I
76    -res=0 sets the resolution width dQ/Q if calculating with resolution
77    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
78    -q=min:max alternative specification of qrange
79    -nq=128 sets the number of Q points in the data set
80    -1d*/-2d computes 1d or 2d data
81    -zero indicates that q=0 should be included
82
83    === model parameters ===
84    -preset*/-random[=seed] preset or random parameters
85    -sets=n generates n random datasets with the seed given by -random=seed
86    -pars/-nopars* prints the parameter set or not
87    -default/-demo* use demo vs default parameters
88    -sphere[=150] set up spherical integration over theta/phi using n points
89
90    === calculation options ===
91    -mono*/-poly force monodisperse or allow polydisperse random parameters
92    -cutoff=1e-5* cutoff value for including a point in polydispersity
93    -magnetic/-nonmagnetic* suppress magnetism
94    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
95    -neval=1 sets the number of evals for more accurate timing
96
97    === precision options ===
98    -engine=default uses the default calcution precision
99    -single/-double/-half/-fast sets an OpenCL calculation engine
100    -single!/-double!/-quad! sets an OpenMP calculation engine
101
102    === plotting ===
103    -plot*/-noplot plots or suppress the plot of the model
104    -linear/-log*/-q4 intensity scaling on plots
105    -hist/-nohist* plot histogram of relative error
106    -abs/-rel* plot relative or absolute error
107    -title="note" adds note to the plot title, after the model name
108    -weights shows weights plots for the polydisperse parameters
109
110    === output options ===
111    -edit starts the parameter explorer
112    -help/-html shows the model docs instead of running the model
113
114The interpretation of quad precision depends on architecture, and may
115vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
116On unix and mac you may need single quotes around the DLL computation
117engines, such as -engine='single!,double!' since !, is treated as a history
118expansion request in the shell.
119
120Key=value pairs allow you to set specific values for the model parameters.
121Key=value1,value2 to compare different values of the same parameter. The
122value can be an expression including other parameters.
123
124Items later on the command line override those that appear earlier.
125
126Examples:
127
128    # compare single and double precision calculation for a barbell
129    sascomp barbell -engine=single,double
130
131    # generate 10 random lorentz models, with seed=27
132    sascomp lorentz -sets=10 -seed=27
133
134    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
135    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
136
137    # model timing test requires multiple evals to perform the estimate
138    sascomp pringle -engine=single,double -timing=100,100 -noplot
139"""
140
141# Update docs with command line usage string.   This is separate from the usual
142# doc string so that we can display it at run time if there is an error.
143# lin
144__doc__ = (__doc__  # pylint: disable=redefined-builtin
145           + """
146Program description
147-------------------
148
149""" + USAGE)
150
151kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
152
153def build_math_context():
154    # type: () -> Dict[str, Callable]
155    """build dictionary of functions from math module"""
156    return dict((k, getattr(math, k))
157                for k in dir(math) if not k.startswith('_'))
158
159#: list of math functions for use in evaluating parameters
160MATH = build_math_context()
161
162# CRUFT python 2.6
163if not hasattr(datetime.timedelta, 'total_seconds'):
164    def delay(dt):
165        """Return number date-time delta as number seconds"""
166        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
167else:
168    def delay(dt):
169        """Return number date-time delta as number seconds"""
170        return dt.total_seconds()
171
172
173class push_seed(object):
174    """
175    Set the seed value for the random number generator.
176
177    When used in a with statement, the random number generator state is
178    restored after the with statement is complete.
179
180    :Parameters:
181
182    *seed* : int or array_like, optional
183        Seed for RandomState
184
185    :Example:
186
187    Seed can be used directly to set the seed::
188
189        >>> from numpy.random import randint
190        >>> push_seed(24)
191        <...push_seed object at...>
192        >>> print(randint(0,1000000,3))
193        [242082    899 211136]
194
195    Seed can also be used in a with statement, which sets the random
196    number generator state for the enclosed computations and restores
197    it to the previous state on completion::
198
199        >>> with push_seed(24):
200        ...    print(randint(0,1000000,3))
201        [242082    899 211136]
202
203    Using nested contexts, we can demonstrate that state is indeed
204    restored after the block completes::
205
206        >>> with push_seed(24):
207        ...    print(randint(0,1000000))
208        ...    with push_seed(24):
209        ...        print(randint(0,1000000,3))
210        ...    print(randint(0,1000000))
211        242082
212        [242082    899 211136]
213        899
214
215    The restore step is protected against exceptions in the block::
216
217        >>> with push_seed(24):
218        ...    print(randint(0,1000000))
219        ...    try:
220        ...        with push_seed(24):
221        ...            print(randint(0,1000000,3))
222        ...            raise Exception()
223        ...    except Exception:
224        ...        print("Exception raised")
225        ...    print(randint(0,1000000))
226        242082
227        [242082    899 211136]
228        Exception raised
229        899
230    """
231    def __init__(self, seed=None):
232        # type: (Optional[int]) -> None
233        self._state = np.random.get_state()
234        np.random.seed(seed)
235
236    def __enter__(self):
237        # type: () -> None
238        pass
239
240    def __exit__(self, exc_type, exc_value, trace):
241        # type: (Any, BaseException, Any) -> None
242        np.random.set_state(self._state)
243
244def tic():
245    # type: () -> Callable[[], float]
246    """
247    Timer function.
248
249    Use "toc=tic()" to start the clock and "toc()" to measure
250    a time interval.
251    """
252    then = datetime.datetime.now()
253    return lambda: delay(datetime.datetime.now() - then)
254
255
256def set_beam_stop(data, radius, outer=None):
257    # type: (Data, float, float) -> None
258    """
259    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
260    """
261    if hasattr(data, 'qx_data'):
262        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
263        data.mask = (q < radius)
264        if outer is not None:
265            data.mask |= (q >= outer)
266    else:
267        data.mask = (data.x < radius)
268        if outer is not None:
269            data.mask |= (data.x >= outer)
270
271
272def parameter_range(p, v):
273    # type: (str, float) -> Tuple[float, float]
274    """
275    Choose a parameter range based on parameter name and initial value.
276    """
277    # process the polydispersity options
278    if p.endswith('_pd_n'):
279        return 0., 100.
280    elif p.endswith('_pd_nsigma'):
281        return 0., 5.
282    elif p.endswith('_pd_type'):
283        raise ValueError("Cannot return a range for a string value")
284    elif any(s in p for s in ('theta', 'phi', 'psi')):
285        # orientation in [-180,180], orientation pd in [0,45]
286        if p.endswith('_pd'):
287            return 0., 180.
288        else:
289            return -180., 180.
290    elif p.endswith('_pd'):
291        return 0., 1.
292    elif 'sld' in p:
293        return -0.5, 10.
294    elif p == 'background':
295        return 0., 10.
296    elif p == 'scale':
297        return 0., 1.e3
298    elif v < 0.:
299        return 2.*v, -2.*v
300    else:
301        return 0., (2.*v if v > 0. else 1.)
302
303
304def _randomize_one(model_info, name, value):
305    # type: (ModelInfo, str, float) -> float
306    # type: (ModelInfo, str, str) -> str
307    """
308    Randomize a single parameter.
309    """
310    # Set the amount of polydispersity/angular dispersion, but by default pd_n
311    # is zero so there is no polydispersity.  This allows us to turn on/off
312    # pd by setting pd_n, and still have randomly generated values
313    if name.endswith('_pd'):
314        par = model_info.parameters[name[:-3]]
315        if par.type == 'orientation':
316            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
317            return 180*np.random.beta(2.5, 20)
318        else:
319            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
320            return np.random.beta(1.5, 7)
321
322    # pd is selected globally rather than per parameter, so set to 0 for no pd
323    # In particular, when multiple pd dimensions, want to decrease the number
324    # of points per dimension for faster computation
325    if name.endswith('_pd_n'):
326        return 0
327
328    # Don't mess with distribution type for now
329    if name.endswith('_pd_type'):
330        return 'gaussian'
331
332    # type-dependent value of number of sigmas; for gaussian use 3.
333    if name.endswith('_pd_nsigma'):
334        return 3.
335
336    # background in the range [0.01, 1]
337    if name == 'background':
338        return 10**np.random.uniform(-2, 0)
339
340    # scale defaults to 0.1% to 30% volume fraction
341    if name == 'scale':
342        return 10**np.random.uniform(-3, -0.5)
343
344    # If it is a list of choices, pick one at random with equal probability
345    # In practice, the model specific random generator will override.
346    par = model_info.parameters[name]
347    if len(par.limits) > 2:  # choice list
348        return np.random.randint(len(par.limits))
349
350    # If it is a fixed range, pick from it with equal probability.
351    # For logarithmic ranges, the model will have to override.
352    if np.isfinite(par.limits).all():
353        return np.random.uniform(*par.limits)
354
355    # If the paramter is marked as an sld use the range of neutron slds
356    # TODO: ought to randomly contrast match a pair of SLDs
357    if par.type == 'sld':
358        return np.random.uniform(-0.5, 12)
359
360    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
361    if par.name.startswith('M0:'):
362        return np.random.uniform(0, 5)
363
364    # Guess at the random length/radius/thickness.  In practice, all models
365    # are going to set their own reasonable ranges.
366    if par.type == 'volume':
367        if ('length' in par.name or
368                'radius' in par.name or
369                'thick' in par.name):
370            return 10**np.random.uniform(2, 4)
371
372    # In the absence of any other info, select a value in [0, 2v], or
373    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
374    # model random parameter generators will override this default.
375    low, high = parameter_range(par.name, value)
376    limits = (max(par.limits[0], low), min(par.limits[1], high))
377    return np.random.uniform(*limits)
378
379def _random_pd(model_info, pars):
380    # type: (ModelInfo, Dict[str, float]) -> None
381    """
382    Generate a random dispersity distribution for the model.
383
384    1% no shape dispersity
385    85% single shape parameter
386    13% two shape parameters
387    1% three shape parameters
388
389    If oriented, then put dispersity in theta, add phi and psi dispersity
390    with 10% probability for each.
391    """
392    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
393    pd_volume = []
394    pd_oriented = []
395    for p in pd:
396        if p.type == 'orientation':
397            pd_oriented.append(p.name)
398        elif p.length_control is not None:
399            n = int(pars.get(p.length_control, 1) + 0.5)
400            pd_volume.extend(p.name+str(k+1) for k in range(n))
401        elif p.length > 1:
402            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
403        else:
404            pd_volume.append(p.name)
405    u = np.random.rand()
406    n = len(pd_volume)
407    if u < 0.01 or n < 1:
408        pass  # 1% chance of no polydispersity
409    elif u < 0.86 or n < 2:
410        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
411    elif u < 0.99 or n < 3:
412        choices = np.random.choice(len(pd_volume), size=2)
413        pars[pd_volume[choices[0]]+"_pd_n"] = 25
414        pars[pd_volume[choices[1]]+"_pd_n"] = 10
415    else:
416        choices = np.random.choice(len(pd_volume), size=3)
417        pars[pd_volume[choices[0]]+"_pd_n"] = 25
418        pars[pd_volume[choices[1]]+"_pd_n"] = 10
419        pars[pd_volume[choices[2]]+"_pd_n"] = 5
420    if pd_oriented:
421        pars['theta_pd_n'] = 20
422        if np.random.rand() < 0.1:
423            pars['phi_pd_n'] = 5
424        if np.random.rand() < 0.1:
425            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
426                #print("generating psi_pd_n")
427                pars['psi_pd_n'] = 5
428
429    ## Show selected polydispersity
430    #for name, value in pars.items():
431    #    if name.endswith('_pd_n') and value > 0:
432    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
433
434
435def randomize_pars(model_info, pars):
436    # type: (ModelInfo, ParameterSet) -> ParameterSet
437    """
438    Generate random values for all of the parameters.
439
440    Valid ranges for the random number generator are guessed from the name of
441    the parameter; this will not account for constraints such as cap radius
442    greater than cylinder radius in the capped_cylinder model, so
443    :func:`constrain_pars` needs to be called afterward..
444    """
445    # Note: the sort guarantees order of calls to random number generator
446    random_pars = dict((p, _randomize_one(model_info, p, v))
447                       for p, v in sorted(pars.items()))
448    if model_info.random is not None:
449        random_pars.update(model_info.random())
450    _random_pd(model_info, random_pars)
451    return random_pars
452
453
454def limit_dimensions(model_info, pars, maxdim):
455    # type: (ModelInfo, ParameterSet, float) -> None
456    """
457    Limit parameters of units of Ang to maxdim.
458    """
459    for p in model_info.parameters.call_parameters:
460        value = pars[p.name]
461        if p.units == 'Ang' and value > maxdim:
462            pars[p.name] = maxdim*10**np.random.uniform(-3, 0)
463
464def constrain_pars(model_info, pars):
465    # type: (ModelInfo, ParameterSet) -> None
466    """
467    Restrict parameters to valid values.
468
469    This includes model specific code for models such as capped_cylinder
470    which need to support within model constraints (cap radius more than
471    cylinder radius in this case).
472
473    Warning: this updates the *pars* dictionary in place.
474    """
475    # TODO: move the model specific code to the individual models
476    name = model_info.id
477    # if it is a product model, then just look at the form factor since
478    # none of the structure factors need any constraints.
479    if '*' in name:
480        name = name.split('*')[0]
481
482    # Suppress magnetism for python models (not yet implemented)
483    if callable(model_info.Iq):
484        pars.update(suppress_magnetism(pars))
485
486    if name == 'barbell':
487        if pars['radius_bell'] < pars['radius']:
488            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
489
490    elif name == 'capped_cylinder':
491        if pars['radius_cap'] < pars['radius']:
492            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
493
494    elif name == 'guinier':
495        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
496        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
497        # => ln A - (Rg^2 q^2/3) > -30 ln 10
498        # => Rg^2 q^2/3 < 30 ln 10 + ln A
499        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
500        #q_max = 0.2  # mid q maximum
501        q_max = 1.0  # high q maximum
502        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
503        pars['rg'] = min(pars['rg'], rg_max)
504
505    elif name == 'pearl_necklace':
506        if pars['radius'] < pars['thick_string']:
507            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
508
509    elif name == 'rpa':
510        # Make sure phi sums to 1.0
511        if pars['case_num'] < 2:
512            pars['Phi1'] = 0.
513            pars['Phi2'] = 0.
514        elif pars['case_num'] < 5:
515            pars['Phi1'] = 0.
516        total = sum(pars['Phi'+c] for c in '1234')
517        for c in '1234':
518            pars['Phi'+c] /= total
519
520def parlist(model_info, pars, is2d):
521    # type: (ModelInfo, ParameterSet, bool) -> str
522    """
523    Format the parameter list for printing.
524    """
525    is2d = True
526    lines = []
527    parameters = model_info.parameters
528    magnetic = False
529    magnetic_pars = []
530    for p in parameters.user_parameters(pars, is2d):
531        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
532            continue
533        if p.id.startswith('up:'):
534            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
535            continue
536        fields = dict(
537            value=pars.get(p.id, p.default),
538            pd=pars.get(p.id+"_pd", 0.),
539            n=int(pars.get(p.id+"_pd_n", 0)),
540            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
541            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
542            relative_pd=p.relative_pd,
543            M0=pars.get('M0:'+p.id, 0.),
544            mphi=pars.get('mphi:'+p.id, 0.),
545            mtheta=pars.get('mtheta:'+p.id, 0.),
546        )
547        lines.append(_format_par(p.name, **fields))
548        magnetic = magnetic or fields['M0'] != 0.
549    if magnetic and magnetic_pars:
550        lines.append(" ".join(magnetic_pars))
551    return "\n".join(lines)
552
553    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
554
555def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
556                relative_pd=False, M0=0., mphi=0., mtheta=0.):
557    # type: (str, float, float, int, float, str) -> str
558    line = "%s: %g"%(name, value)
559    if pd != 0.  and n != 0:
560        if relative_pd:
561            pd *= value
562        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
563                % (pd, n, nsigma, nsigma, pdtype)
564    if M0 != 0.:
565        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
566    return line
567
568def suppress_pd(pars, suppress=True):
569    # type: (ParameterSet) -> ParameterSet
570    """
571    If suppress is True complete eliminate polydispersity of the model to test
572    models more quickly.  If suppress is False, make sure at least one
573    parameter is polydisperse, setting the first polydispersity parameter to
574    15% if no polydispersity is given (with no explicit demo parameters given
575    in the model, there will be no default polydispersity).
576    """
577    pars = pars.copy()
578    #print("pars=", pars)
579    if suppress:
580        for p in pars:
581            if p.endswith("_pd_n"):
582                pars[p] = 0
583    else:
584        any_pd = False
585        first_pd = None
586        for p in pars:
587            if p.endswith("_pd_n"):
588                pd = pars.get(p[:-2], 0.)
589                any_pd |= (pars[p] != 0 and pd != 0.)
590                if first_pd is None:
591                    first_pd = p
592        if not any_pd and first_pd is not None:
593            if pars[first_pd] == 0:
594                pars[first_pd] = 35
595            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
596                pars[first_pd[:-2]] = 0.15
597    return pars
598
599def suppress_magnetism(pars, suppress=True):
600    # type: (ParameterSet) -> ParameterSet
601    """
602    If suppress is True complete eliminate magnetism of the model to test
603    models more quickly.  If suppress is False, make sure at least one sld
604    parameter is magnetic, setting the first parameter to have a strong
605    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
606    in the model, there will be no default magnetism).
607    """
608    pars = pars.copy()
609    if suppress:
610        for p in pars:
611            if p.startswith("M0:"):
612                pars[p] = 0
613    else:
614        any_mag = False
615        first_mag = None
616        for p in pars:
617            if p.startswith("M0:"):
618                any_mag |= (pars[p] != 0)
619                if first_mag is None:
620                    first_mag = p
621        if not any_mag and first_mag is not None:
622            pars[first_mag] = 8.
623    return pars
624
625
626DTYPE_MAP = {
627    'half': '16',
628    'fast': 'fast',
629    'single': '32',
630    'double': '64',
631    'quad': '128',
632    'f16': '16',
633    'f32': '32',
634    'f64': '64',
635    'float16': '16',
636    'float32': '32',
637    'float64': '64',
638    'float128': '128',
639    'longdouble': '128',
640}
641def eval_opencl(model_info, data, dtype='single', cutoff=0.):
642    # type: (ModelInfo, Data, str, float) -> Calculator
643    """
644    Return a model calculator using the OpenCL calculation engine.
645    """
646    if not core.HAVE_OPENCL:
647        raise RuntimeError("OpenCL not available")
648    model = core.build_model(model_info, dtype=dtype, platform="ocl")
649    calculator = DirectModel(data, model, cutoff=cutoff)
650    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
651    return calculator
652
653def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
654    # type: (ModelInfo, Data, str, float) -> Calculator
655    """
656    Return a model calculator using the DLL calculation engine.
657    """
658    model = core.build_model(model_info, dtype=dtype, platform="dll")
659    calculator = DirectModel(data, model, cutoff=cutoff)
660    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
661    return calculator
662
663def time_calculation(calculator, pars, evals=1):
664    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
665    """
666    Compute the average calculation time over N evaluations.
667
668    An additional call is generated without polydispersity in order to
669    initialize the calculation engine, and make the average more stable.
670    """
671    # initialize the code so time is more accurate
672    if evals > 1:
673        calculator(**suppress_pd(pars))
674    toc = tic()
675    # make sure there is at least one eval
676    value = calculator(**pars)
677    for _ in range(evals-1):
678        value = calculator(**pars)
679    average_time = toc()*1000. / evals
680    #print("I(q)",value)
681    return value, average_time
682
683def make_data(opts):
684    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
685    """
686    Generate an empty dataset, used with the model to set Q points
687    and resolution.
688
689    *opts* contains the options, with 'qmax', 'nq', 'res',
690    'accuracy', 'is2d' and 'view' parsed from the command line.
691    """
692    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
693    if opts['is2d']:
694        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
695        data = empty_data2D(q, resolution=res)
696        data.accuracy = opts['accuracy']
697        set_beam_stop(data, 0.0004)
698        index = ~data.mask
699    else:
700        if opts['view'] == 'log' and not opts['zero']:
701            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
702        else:
703            q = np.linspace(qmin, qmax, nq)
704        if opts['zero']:
705            q = np.hstack((0, q))
706        data = empty_data1D(q, resolution=res)
707        index = slice(None, None)
708    return data, index
709
710def make_engine(model_info, data, dtype, cutoff):
711    # type: (ModelInfo, Data, str, float) -> Calculator
712    """
713    Generate the appropriate calculation engine for the given datatype.
714
715    Datatypes with '!' appended are evaluated using external C DLLs rather
716    than OpenCL.
717    """
718    if dtype is None or not dtype.endswith('!'):
719        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
720    else:
721        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
722
723def _show_invalid(data, theory):
724    # type: (Data, np.ma.ndarray) -> None
725    """
726    Display a list of the non-finite values in theory.
727    """
728    if not theory.mask.any():
729        return
730
731    if hasattr(data, 'x'):
732        bad = zip(data.x[theory.mask], theory[theory.mask])
733        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
734
735
736def compare(opts, limits=None, maxdim=np.inf):
737    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
738    """
739    Preform a comparison using options from the command line.
740
741    *limits* are the limits on the values to use, either to set the y-axis
742    for 1D or to set the colormap scale for 2D.  If None, then they are
743    inferred from the data and returned. When exploring using Bumps,
744    the limits are set when the model is initially called, and maintained
745    as the values are adjusted, making it easier to see the effects of the
746    parameters.
747
748    *maxdim* is the maximum value for any parameter with units of Angstrom.
749    """
750    for k in range(opts['sets']):
751        if k > 1:
752            # print a separate seed for each dataset for better reproducibility
753            new_seed = np.random.randint(1000000)
754            print("Set %d uses -random=%i"%(k+1, new_seed))
755            np.random.seed(new_seed)
756        opts['pars'] = parse_pars(opts, maxdim=maxdim)
757        if opts['pars'] is None:
758            return
759        result = run_models(opts, verbose=True)
760        if opts['plot']:
761            limits = plot_models(opts, result, limits=limits, setnum=k)
762        if opts['show_weights']:
763            base, _ = opts['engines']
764            base_pars, _ = opts['pars']
765            model_info = base._kernel.info
766            dim = base._kernel.dim
767            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
768    if opts['plot']:
769        import matplotlib.pyplot as plt
770        plt.show()
771    return limits
772
773def run_models(opts, verbose=False):
774    # type: (Dict[str, Any]) -> Dict[str, Any]
775    """
776    Process a parameter set, return calculation results and times.
777    """
778
779    base, comp = opts['engines']
780    base_n, comp_n = opts['count']
781    base_pars, comp_pars = opts['pars']
782    data = opts['data']
783
784    comparison = comp is not None
785
786    base_time = comp_time = None
787    base_value = comp_value = resid = relerr = None
788
789    # Base calculation
790    try:
791        base_raw, base_time = time_calculation(base, base_pars, base_n)
792        base_value = np.ma.masked_invalid(base_raw)
793        if verbose:
794            print("%s t=%.2f ms, intensity=%.0f"
795                  % (base.engine, base_time, base_value.sum()))
796        _show_invalid(data, base_value)
797    except ImportError:
798        traceback.print_exc()
799
800    # Comparison calculation
801    if comparison:
802        try:
803            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
804            comp_value = np.ma.masked_invalid(comp_raw)
805            if verbose:
806                print("%s t=%.2f ms, intensity=%.0f"
807                      % (comp.engine, comp_time, comp_value.sum()))
808            _show_invalid(data, comp_value)
809        except ImportError:
810            traceback.print_exc()
811
812    # Compare, but only if computing both forms
813    if comparison:
814        resid = (base_value - comp_value)
815        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
816        if verbose:
817            _print_stats("|%s-%s|"
818                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
819                         resid)
820            _print_stats("|(%s-%s)/%s|"
821                         % (base.engine, comp.engine, comp.engine),
822                         relerr)
823
824    return dict(base_value=base_value, comp_value=comp_value,
825                base_time=base_time, comp_time=comp_time,
826                resid=resid, relerr=relerr)
827
828
829def _print_stats(label, err):
830    # type: (str, np.ma.ndarray) -> None
831    # work with trimmed data, not the full set
832    sorted_err = np.sort(abs(err.compressed()))
833    if len(sorted_err) == 0:
834        print(label + "  no valid values")
835        return
836
837    p50 = int((len(sorted_err)-1)*0.50)
838    p98 = int((len(sorted_err)-1)*0.98)
839    data = [
840        "max:%.3e"%sorted_err[-1],
841        "median:%.3e"%sorted_err[p50],
842        "98%%:%.3e"%sorted_err[p98],
843        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
844        "zero-offset:%+.3e"%np.mean(sorted_err),
845        ]
846    print(label+"  "+"  ".join(data))
847
848
849def plot_models(opts, result, limits=None, setnum=0):
850    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
851    """
852    Plot the results from :func:`run_model`.
853    """
854    import matplotlib.pyplot as plt
855
856    base_value, comp_value = result['base_value'], result['comp_value']
857    base_time, comp_time = result['base_time'], result['comp_time']
858    resid, relerr = result['resid'], result['relerr']
859
860    have_base, have_comp = (base_value is not None), (comp_value is not None)
861    base, comp = opts['engines']
862    data = opts['data']
863    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
864
865    # Plot if requested
866    view = opts['view']
867    if limits is None:
868        vmin, vmax = np.inf, -np.inf
869        if have_base:
870            vmin = min(vmin, base_value.min())
871            vmax = max(vmax, base_value.max())
872        if have_comp:
873            vmin = min(vmin, comp_value.min())
874            vmax = max(vmax, comp_value.max())
875        limits = vmin, vmax
876
877    if have_base:
878        if have_comp:
879            plt.subplot(131)
880        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
881        plt.title("%s t=%.2f ms"%(base.engine, base_time))
882        #cbar_title = "log I"
883    if have_comp:
884        if have_base:
885            plt.subplot(132)
886        if not opts['is2d'] and have_base:
887            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
888        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
889        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
890        #cbar_title = "log I"
891    if have_base and have_comp:
892        plt.subplot(133)
893        if not opts['rel_err']:
894            err, errstr, errview = resid, "abs err", "linear"
895        else:
896            err, errstr, errview = abs(relerr), "rel err", "log"
897            if (err == 0.).all():
898                errview = 'linear'
899        if 0:  # 95% cutoff
900            sorted_err = np.sort(err.flatten())
901            cutoff = sorted_err[int(sorted_err.size*0.95)]
902            err[err > cutoff] = cutoff
903        #err,errstr = base/comp,"ratio"
904        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
905        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
906        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
907        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
908        #cbar_title = errstr if errview=="linear" else "log "+errstr
909    #if is2D:
910    #    h = plt.colorbar()
911    #    h.ax.set_title(cbar_title)
912    fig = plt.gcf()
913    extra_title = ' '+opts['title'] if opts['title'] else ''
914    fig.suptitle(":".join(opts['name']) + extra_title)
915
916    if have_base and have_comp and opts['show_hist']:
917        plt.figure()
918        v = relerr
919        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
920        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
921        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
922                   % (base.engine, comp.engine, comp.engine))
923        plt.ylabel('P(err)')
924        plt.title('Distribution of relative error between calculation engines')
925
926    return limits
927
928
929# ===========================================================================
930#
931
932# Set of command line options.
933# Normal options such as -plot/-noplot are specified as 'name'.
934# For options such as -nq=500 which require a value use 'name='.
935#
936OPTIONS = [
937    # Plotting
938    'plot', 'noplot', 'weights',
939    'linear', 'log', 'q4',
940    'rel', 'abs',
941    'hist', 'nohist',
942    'title=',
943
944    # Data generation
945    'data=', 'noise=', 'res=', 'nq=', 'q=',
946    'lowq', 'midq', 'highq', 'exq', 'zero',
947    '2d', '1d',
948
949    # Parameter set
950    'preset', 'random', 'random=', 'sets=',
951    'demo', 'default',  # TODO: remove demo/default
952    'nopars', 'pars',
953    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
954
955    # Calculation options
956    'poly', 'mono', 'cutoff=',
957    'magnetic', 'nonmagnetic',
958    'accuracy=',
959    'neval=',  # for timing...
960
961    # Precision options
962    'engine=',
963    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
964
965    # Output options
966    'help', 'html', 'edit',
967    ]
968
969NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))()
970VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])()
971
972
973def columnize(items, indent="", width=79):
974    # type: (List[str], str, int) -> str
975    """
976    Format a list of strings into columns.
977
978    Returns a string with carriage returns ready for printing.
979    """
980    column_width = max(len(w) for w in items) + 1
981    num_columns = (width - len(indent)) // column_width
982    num_rows = len(items) // num_columns
983    items = items + [""] * (num_rows * num_columns - len(items))
984    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
985    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
986             for row in zip(*columns)]
987    output = indent + ("\n"+indent).join(lines)
988    return output
989
990
991def get_pars(model_info, use_demo=False):
992    # type: (ModelInfo, bool) -> ParameterSet
993    """
994    Extract demo parameters from the model definition.
995    """
996    # Get the default values for the parameters
997    pars = {}
998    for p in model_info.parameters.call_parameters:
999        parts = [('', p.default)]
1000        if p.polydisperse:
1001            parts.append(('_pd', 0.0))
1002            parts.append(('_pd_n', 0))
1003            parts.append(('_pd_nsigma', 3.0))
1004            parts.append(('_pd_type', "gaussian"))
1005        for ext, val in parts:
1006            if p.length > 1:
1007                dict(("%s%d%s" % (p.id, k, ext), val)
1008                     for k in range(1, p.length+1))
1009            else:
1010                pars[p.id + ext] = val
1011
1012    # Plug in values given in demo
1013    if use_demo and model_info.demo:
1014        pars.update(model_info.demo)
1015    return pars
1016
1017INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1018def isnumber(s):
1019    # type: (str) -> bool
1020    """Return True if string contains an int or float"""
1021    match = FLOAT_RE.match(s)
1022    isfloat = (match and not s[match.end():])
1023    return isfloat or INTEGER_RE.match(s)
1024
1025# For distinguishing pairs of models for comparison
1026# key-value pair separator =
1027# shell characters  | & ; <> $ % ' " \ # `
1028# model and parameter names _
1029# parameter expressions - + * / . ( )
1030# path characters including tilde expansion and windows drive ~ / :
1031# not sure about brackets [] {}
1032# maybe one of the following @ ? ^ ! ,
1033PAR_SPLIT = ','
1034def parse_opts(argv):
1035    # type: (List[str]) -> Dict[str, Any]
1036    """
1037    Parse command line options.
1038    """
1039    MODELS = core.list_models()
1040    flags = [arg for arg in argv
1041             if arg.startswith('-')]
1042    values = [arg for arg in argv
1043              if not arg.startswith('-') and '=' in arg]
1044    positional_args = [arg for arg in argv
1045                       if not arg.startswith('-') and '=' not in arg]
1046    models = "\n    ".join("%-15s"%v for v in MODELS)
1047    if len(positional_args) == 0:
1048        print(USAGE)
1049        print("\nAvailable models:")
1050        print(columnize(MODELS, indent="  "))
1051        return None
1052
1053    invalid = [o[1:] for o in flags
1054               if o[1:] not in NAME_OPTIONS
1055               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1056    if invalid:
1057        print("Invalid options: %s"%(", ".join(invalid)))
1058        return None
1059
1060    name = positional_args[-1]
1061
1062    # pylint: disable=bad-whitespace,C0321
1063    # Interpret the flags
1064    opts = {
1065        'plot'      : True,
1066        'view'      : 'log',
1067        'is2d'      : False,
1068        'qmin'      : None,
1069        'qmax'      : 0.05,
1070        'nq'        : 128,
1071        'res'       : 0.0,
1072        'noise'     : 0.0,
1073        'accuracy'  : 'Low',
1074        'cutoff'    : '0.0',
1075        'seed'      : -1,  # default to preset
1076        'mono'      : True,
1077        # Default to magnetic a magnetic moment is set on the command line
1078        'magnetic'  : False,
1079        'show_pars' : False,
1080        'show_hist' : False,
1081        'rel_err'   : True,
1082        'explore'   : False,
1083        'use_demo'  : True,
1084        'zero'      : False,
1085        'html'      : False,
1086        'title'     : None,
1087        'datafile'  : None,
1088        'sets'      : 0,
1089        'engine'    : 'default',
1090        'count'     : '1',
1091        'show_weights' : False,
1092        'sphere'    : 0,
1093    }
1094    for arg in flags:
1095        if arg == '-noplot':    opts['plot'] = False
1096        elif arg == '-plot':    opts['plot'] = True
1097        elif arg == '-linear':  opts['view'] = 'linear'
1098        elif arg == '-log':     opts['view'] = 'log'
1099        elif arg == '-q4':      opts['view'] = 'q4'
1100        elif arg == '-1d':      opts['is2d'] = False
1101        elif arg == '-2d':      opts['is2d'] = True
1102        elif arg == '-exq':     opts['qmax'] = 10.0
1103        elif arg == '-highq':   opts['qmax'] = 1.0
1104        elif arg == '-midq':    opts['qmax'] = 0.2
1105        elif arg == '-lowq':    opts['qmax'] = 0.05
1106        elif arg == '-zero':    opts['zero'] = True
1107        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1108        elif arg.startswith('-q='):
1109            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1110        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1111        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1112        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1113        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1114        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1115        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1116        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1117        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1118        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1119        elif arg.startswith('-random='):
1120            opts['seed'] = int(arg[8:])
1121            opts['sets'] = 0
1122        elif arg == '-random':
1123            opts['seed'] = np.random.randint(1000000)
1124            opts['sets'] = 0
1125        elif arg.startswith('-sphere'):
1126            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1127            opts['is2d'] = True
1128        elif arg == '-preset':  opts['seed'] = -1
1129        elif arg == '-mono':    opts['mono'] = True
1130        elif arg == '-poly':    opts['mono'] = False
1131        elif arg == '-magnetic':       opts['magnetic'] = True
1132        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1133        elif arg == '-pars':    opts['show_pars'] = True
1134        elif arg == '-nopars':  opts['show_pars'] = False
1135        elif arg == '-hist':    opts['show_hist'] = True
1136        elif arg == '-nohist':  opts['show_hist'] = False
1137        elif arg == '-rel':     opts['rel_err'] = True
1138        elif arg == '-abs':     opts['rel_err'] = False
1139        elif arg == '-half':    opts['engine'] = 'half'
1140        elif arg == '-fast':    opts['engine'] = 'fast'
1141        elif arg == '-single':  opts['engine'] = 'single'
1142        elif arg == '-double':  opts['engine'] = 'double'
1143        elif arg == '-single!': opts['engine'] = 'single!'
1144        elif arg == '-double!': opts['engine'] = 'double!'
1145        elif arg == '-quad!':   opts['engine'] = 'quad!'
1146        elif arg == '-edit':    opts['explore'] = True
1147        elif arg == '-demo':    opts['use_demo'] = True
1148        elif arg == '-default': opts['use_demo'] = False
1149        elif arg == '-weights': opts['show_weights'] = True
1150        elif arg == '-html':    opts['html'] = True
1151        elif arg == '-help':    opts['html'] = True
1152    # pylint: enable=bad-whitespace,C0321
1153
1154    # Magnetism forces 2D for now
1155    if opts['magnetic']:
1156        opts['is2d'] = True
1157
1158    # Force random if sets is used
1159    if opts['sets'] >= 1 and opts['seed'] < 0:
1160        opts['seed'] = np.random.randint(1000000)
1161    if opts['sets'] == 0:
1162        opts['sets'] = 1
1163
1164    # Create the computational engines
1165    if opts['qmin'] is None:
1166        opts['qmin'] = 0.001*opts['qmax']
1167    if opts['datafile'] is not None:
1168        data = load_data(os.path.expanduser(opts['datafile']))
1169    else:
1170        data, _ = make_data(opts)
1171
1172    comparison = any(PAR_SPLIT in v for v in values)
1173    if PAR_SPLIT in name:
1174        names = name.split(PAR_SPLIT, 2)
1175        comparison = True
1176    else:
1177        names = [name]*2
1178    try:
1179        model_info = [core.load_model_info(k) for k in names]
1180    except ImportError as exc:
1181        print(str(exc))
1182        print("Could not find model; use one of:\n    " + models)
1183        return None
1184
1185    if PAR_SPLIT in opts['engine']:
1186        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1187        comparison = True
1188    else:
1189        opts['engine'] = [opts['engine']]*2
1190
1191    if PAR_SPLIT in opts['count']:
1192        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1193        comparison = True
1194    else:
1195        opts['count'] = [int(opts['count'])]*2
1196
1197    if PAR_SPLIT in opts['cutoff']:
1198        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1199        comparison = True
1200    else:
1201        opts['cutoff'] = [float(opts['cutoff'])]*2
1202
1203    base = make_engine(model_info[0], data, opts['engine'][0], opts['cutoff'][0])
1204    if comparison:
1205        comp = make_engine(model_info[1], data, opts['engine'][1], opts['cutoff'][1])
1206    else:
1207        comp = None
1208
1209    # pylint: disable=bad-whitespace
1210    # Remember it all
1211    opts.update({
1212        'data'      : data,
1213        'name'      : names,
1214        'info'      : model_info,
1215        'engines'   : [base, comp],
1216        'values'    : values,
1217    })
1218    # pylint: enable=bad-whitespace
1219
1220    # Set the integration parameters to the half sphere
1221    if opts['sphere'] > 0:
1222        set_spherical_integration_parameters(opts, opts['sphere'])
1223
1224    return opts
1225
1226def set_spherical_integration_parameters(opts, steps):
1227    # type: (Dict[str, Any], int) -> None
1228    """
1229    Set integration parameters for spherical integration over the entire
1230    surface in theta-phi coordinates.
1231    """
1232    # Set the integration parameters to the half sphere
1233    opts['values'].extend([
1234        #'theta=90',
1235        'theta_pd=%g'%(90/np.sqrt(3)),
1236        'theta_pd_n=%d'%steps,
1237        'theta_pd_type=rectangle',
1238        #'phi=0',
1239        'phi_pd=%g'%(180/np.sqrt(3)),
1240        'phi_pd_n=%d'%(2*steps),
1241        'phi_pd_type=rectangle',
1242        #'background=0',
1243    ])
1244    if 'psi' in opts['info'][0].parameters:
1245        opts['values'].extend([
1246            #'psi=0',
1247            'psi_pd=%g'%(180/np.sqrt(3)),
1248            'psi_pd_n=%d'%(2*steps),
1249            'psi_pd_type=rectangle',
1250        ])
1251
1252def parse_pars(opts, maxdim=np.inf):
1253    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]]
1254    """
1255    Generate a parameter set.
1256
1257    The default values come from the model, or a randomized model if a seed
1258    value is given.  Next, evaluate any parameter expressions, constraining
1259    the value of the parameter within and between models.  If *maxdim* is
1260    given, limit parameters with units of Angstrom to this value.
1261
1262    Returns a pair of parameter dictionaries for base and comparison models.
1263    """
1264    model_info, model_info2 = opts['info']
1265
1266    # Get demo parameters from model definition, or use default parameters
1267    # if model does not define demo parameters
1268    pars = get_pars(model_info, opts['use_demo'])
1269    pars2 = get_pars(model_info2, opts['use_demo'])
1270    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1271    # randomize parameters
1272    #pars.update(set_pars)  # set value before random to control range
1273    if opts['seed'] > -1:
1274        pars = randomize_pars(model_info, pars)
1275        limit_dimensions(model_info, pars, maxdim)
1276        if model_info != model_info2:
1277            pars2 = randomize_pars(model_info2, pars2)
1278            limit_dimensions(model_info, pars2, maxdim)
1279            # Share values for parameters with the same name
1280            for k, v in pars.items():
1281                if k in pars2:
1282                    pars2[k] = v
1283        else:
1284            pars2 = pars.copy()
1285        constrain_pars(model_info, pars)
1286        constrain_pars(model_info2, pars2)
1287    pars = suppress_pd(pars, opts['mono'])
1288    pars2 = suppress_pd(pars2, opts['mono'])
1289    pars = suppress_magnetism(pars, not opts['magnetic'])
1290    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1291
1292    # Fill in parameters given on the command line
1293    presets = {}
1294    presets2 = {}
1295    for arg in opts['values']:
1296        k, v = arg.split('=', 1)
1297        if k not in pars and k not in pars2:
1298            # extract base name without polydispersity info
1299            s = set(p.split('_pd')[0] for p in pars)
1300            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1301            return None
1302        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
1303        if v1 and k in pars:
1304            presets[k] = float(v1) if isnumber(v1) else v1
1305        if v2 and k in pars2:
1306            presets2[k] = float(v2) if isnumber(v2) else v2
1307
1308    # If pd given on the command line, default pd_n to 35
1309    for k, v in list(presets.items()):
1310        if k.endswith('_pd'):
1311            presets.setdefault(k+'_n', 35.)
1312    for k, v in list(presets2.items()):
1313        if k.endswith('_pd'):
1314            presets2.setdefault(k+'_n', 35.)
1315
1316    # Evaluate preset parameter expressions
1317    context = MATH.copy()
1318    context['np'] = np
1319    context.update(pars)
1320    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1321    for k, v in presets.items():
1322        if not isinstance(v, float) and not k.endswith('_type'):
1323            presets[k] = eval(v, context)
1324    context.update(presets)
1325    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1326    for k, v in presets2.items():
1327        if not isinstance(v, float) and not k.endswith('_type'):
1328            presets2[k] = eval(v, context)
1329
1330    # update parameters with presets
1331    pars.update(presets)  # set value after random to control value
1332    pars2.update(presets2)  # set value after random to control value
1333    #import pprint; pprint.pprint(model_info)
1334
1335    if opts['show_pars']:
1336        if model_info.name != model_info2.name or pars != pars2:
1337            print("==== %s ====="%model_info.name)
1338            print(str(parlist(model_info, pars, opts['is2d'])))
1339            print("==== %s ====="%model_info2.name)
1340            print(str(parlist(model_info2, pars2, opts['is2d'])))
1341        else:
1342            print(str(parlist(model_info, pars, opts['is2d'])))
1343
1344    return pars, pars2
1345
1346def show_docs(opts):
1347    # type: (Dict[str, Any]) -> None
1348    """
1349    show html docs for the model
1350    """
1351    from .generate import make_html
1352    from . import rst2html
1353
1354    info = opts['info'][0]
1355    html = make_html(info)
1356    path = os.path.dirname(info.filename)
1357    url = "file://" + path.replace("\\", "/")[2:] + "/"
1358    rst2html.view_html_qtapp(html, url)
1359
1360def explore(opts):
1361    # type: (Dict[str, Any]) -> None
1362    """
1363    explore the model using the bumps gui.
1364    """
1365    import wx  # type: ignore
1366    from bumps.names import FitProblem  # type: ignore
1367    from bumps.gui.app_frame import AppFrame  # type: ignore
1368    from bumps.gui import signal
1369
1370    is_mac = "cocoa" in wx.version()
1371    # Create an app if not running embedded
1372    app = wx.App() if wx.GetApp() is None else None
1373    model = Explore(opts)
1374    problem = FitProblem(model)
1375    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1376    if not is_mac:
1377        frame.Show()
1378    frame.panel.set_model(model=problem)
1379    frame.panel.Layout()
1380    frame.panel.aui.Split(0, wx.TOP)
1381    def _reset_parameters(event):
1382        model.revert_values()
1383        signal.update_parameters(problem)
1384    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1))
1385    if is_mac:
1386        frame.Show()
1387    # If running withing an app, start the main loop
1388    if app:
1389        app.MainLoop()
1390
1391class Explore(object):
1392    """
1393    Bumps wrapper for a SAS model comparison.
1394
1395    The resulting object can be used as a Bumps fit problem so that
1396    parameters can be adjusted in the GUI, with plots updated on the fly.
1397    """
1398    def __init__(self, opts):
1399        # type: (Dict[str, Any]) -> None
1400        from bumps.cli import config_matplotlib  # type: ignore
1401        from . import bumps_model
1402        config_matplotlib()
1403        self.opts = opts
1404        opts['pars'] = list(opts['pars'])
1405        p1, p2 = opts['pars']
1406        m1, m2 = opts['info']
1407        self.fix_p2 = m1 != m2 or p1 != p2
1408        model_info = m1
1409        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1410        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1411        if not opts['is2d']:
1412            for p in model_info.parameters.user_parameters({}, is2d=False):
1413                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1414                    k = p.name+ext
1415                    v = pars.get(k, None)
1416                    if v is not None:
1417                        v.range(*parameter_range(k, v.value))
1418        else:
1419            for k, v in pars.items():
1420                v.range(*parameter_range(k, v.value))
1421
1422        self.pars = pars
1423        self.starting_values = dict((k, v.value) for k, v in pars.items())
1424        self.pd_types = pd_types
1425        self.limits = None
1426
1427    def revert_values(self):
1428        # type: () -> None
1429        """
1430        Restore starting values of the parameters.
1431        """
1432        for k, v in self.starting_values.items():
1433            self.pars[k].value = v
1434
1435    def model_update(self):
1436        # type: () -> None
1437        """
1438        Respond to signal that model parameters have been changed.
1439        """
1440        pass
1441
1442    def numpoints(self):
1443        # type: () -> int
1444        """
1445        Return the number of points.
1446        """
1447        return len(self.pars) + 1  # so dof is 1
1448
1449    def parameters(self):
1450        # type: () -> Any   # Dict/List hierarchy of parameters
1451        """
1452        Return a dictionary of parameters.
1453        """
1454        return self.pars
1455
1456    def nllf(self):
1457        # type: () -> float
1458        """
1459        Return cost.
1460        """
1461        # pylint: disable=no-self-use
1462        return 0.  # No nllf
1463
1464    def plot(self, view='log'):
1465        # type: (str) -> None
1466        """
1467        Plot the data and residuals.
1468        """
1469        pars = dict((k, v.value) for k, v in self.pars.items())
1470        pars.update(self.pd_types)
1471        self.opts['pars'][0] = pars
1472        if not self.fix_p2:
1473            self.opts['pars'][1] = pars
1474        result = run_models(self.opts)
1475        limits = plot_models(self.opts, result, limits=self.limits)
1476        if self.limits is None:
1477            vmin, vmax = limits
1478            self.limits = vmax*1e-7, 1.3*vmax
1479            import pylab; pylab.clf()
1480            plot_models(self.opts, result, limits=self.limits)
1481
1482
1483def main(*argv):
1484    # type: (*str) -> None
1485    """
1486    Main program.
1487    """
1488    opts = parse_opts(argv)
1489    if opts is not None:
1490        if opts['seed'] > -1:
1491            print("Randomize using -random=%i"%opts['seed'])
1492            np.random.seed(opts['seed'])
1493        if opts['html']:
1494            show_docs(opts)
1495        elif opts['explore']:
1496            opts['pars'] = parse_pars(opts)
1497            if opts['pars'] is None:
1498                return
1499            explore(opts)
1500        else:
1501            compare(opts)
1502
1503if __name__ == "__main__":
1504    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.