source: sasmodels/sasmodels/compare.py @ 2a7e20e

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 2a7e20e was 2a7e20e, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

update developer docs with current interpretation of orientation; describe the scripts in the explore directory

  • Property mode set to 100755
File size: 52.6 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from .data import plot_theory, empty_data1D, empty_data2D, load_data
43from .direct_model import DirectModel, get_mesh
44from .generate import FLOAT_RE, set_integration_size
45from .weights import plot_weights
46
47# pylint: disable=unused-import
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except ImportError:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56# pylint: enable=unused-import
57
58USAGE = """
59usage: sascomp model [options...] [key=val]
60
61Generate and compare SAS models.  If a single model is specified it shows
62a plot of that model.  Different models can be compared, or the same model
63with different parameters.  The same model with the same parameters can
64be compared with different calculation engines to see the effects of precision
65on the resultant values.
66
67model or model1,model2 are the names of the models to compare (see below).
68
69Options (* for default):
70
71    === data generation ===
72    -data="path" uses q, dq from the data file
73    -noise=0 sets the measurement error dI/I
74    -res=0 sets the resolution width dQ/Q if calculating with resolution
75    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
76    -q=min:max alternative specification of qrange
77    -nq=128 sets the number of Q points in the data set
78    -1d*/-2d computes 1d or 2d data
79    -zero indicates that q=0 should be included
80
81    === model parameters ===
82    -preset*/-random[=seed] preset or random parameters
83    -sets=n generates n random datasets with the seed given by -random=seed
84    -pars/-nopars* prints the parameter set or not
85    -default/-demo* use demo vs default parameters
86    -sphere[=150] set up spherical integration over theta/phi using n points
87
88    === calculation options ===
89    -mono*/-poly force monodisperse or allow polydisperse random parameters
90    -cutoff=1e-5* cutoff value for including a point in polydispersity
91    -magnetic/-nonmagnetic* suppress magnetism
92    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
93    -neval=1 sets the number of evals for more accurate timing
94    -ngauss=0 overrides the number of points in the 1-D gaussian quadrature
95
96    === precision options ===
97    -engine=default uses the default calcution precision
98    -single/-double/-half/-fast sets an OpenCL calculation engine
99    -single!/-double!/-quad! sets an OpenMP calculation engine
100
101    === plotting ===
102    -plot*/-noplot plots or suppress the plot of the model
103    -linear/-log*/-q4 intensity scaling on plots
104    -hist/-nohist* plot histogram of relative error
105    -abs/-rel* plot relative or absolute error
106    -title="note" adds note to the plot title, after the model name
107    -weights shows weights plots for the polydisperse parameters
108
109    === output options ===
110    -edit starts the parameter explorer
111    -help/-html shows the model docs instead of running the model
112
113The interpretation of quad precision depends on architecture, and may
114vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
115On unix and mac you may need single quotes around the DLL computation
116engines, such as -engine='single!,double!' since !, is treated as a history
117expansion request in the shell.
118
119Key=value pairs allow you to set specific values for the model parameters.
120Key=value1,value2 to compare different values of the same parameter. The
121value can be an expression including other parameters.
122
123Items later on the command line override those that appear earlier.
124
125Examples:
126
127    # compare single and double precision calculation for a barbell
128    sascomp barbell -engine=single,double
129
130    # generate 10 random lorentz models, with seed=27
131    sascomp lorentz -sets=10 -seed=27
132
133    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
134    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
135
136    # model timing test requires multiple evals to perform the estimate
137    sascomp pringle -engine=single,double -timing=100,100 -noplot
138"""
139
140# Update docs with command line usage string.   This is separate from the usual
141# doc string so that we can display it at run time if there is an error.
142# lin
143__doc__ = (__doc__  # pylint: disable=redefined-builtin
144           + """
145Program description
146-------------------
147
148""" + USAGE)
149
150kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
151
152def build_math_context():
153    # type: () -> Dict[str, Callable]
154    """build dictionary of functions from math module"""
155    return dict((k, getattr(math, k))
156                for k in dir(math) if not k.startswith('_'))
157
158#: list of math functions for use in evaluating parameters
159MATH = build_math_context()
160
161# CRUFT python 2.6
162if not hasattr(datetime.timedelta, 'total_seconds'):
163    def delay(dt):
164        """Return number date-time delta as number seconds"""
165        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
166else:
167    def delay(dt):
168        """Return number date-time delta as number seconds"""
169        return dt.total_seconds()
170
171
172class push_seed(object):
173    """
174    Set the seed value for the random number generator.
175
176    When used in a with statement, the random number generator state is
177    restored after the with statement is complete.
178
179    :Parameters:
180
181    *seed* : int or array_like, optional
182        Seed for RandomState
183
184    :Example:
185
186    Seed can be used directly to set the seed::
187
188        >>> from numpy.random import randint
189        >>> push_seed(24)
190        <...push_seed object at...>
191        >>> print(randint(0,1000000,3))
192        [242082    899 211136]
193
194    Seed can also be used in a with statement, which sets the random
195    number generator state for the enclosed computations and restores
196    it to the previous state on completion::
197
198        >>> with push_seed(24):
199        ...    print(randint(0,1000000,3))
200        [242082    899 211136]
201
202    Using nested contexts, we can demonstrate that state is indeed
203    restored after the block completes::
204
205        >>> with push_seed(24):
206        ...    print(randint(0,1000000))
207        ...    with push_seed(24):
208        ...        print(randint(0,1000000,3))
209        ...    print(randint(0,1000000))
210        242082
211        [242082    899 211136]
212        899
213
214    The restore step is protected against exceptions in the block::
215
216        >>> with push_seed(24):
217        ...    print(randint(0,1000000))
218        ...    try:
219        ...        with push_seed(24):
220        ...            print(randint(0,1000000,3))
221        ...            raise Exception()
222        ...    except Exception:
223        ...        print("Exception raised")
224        ...    print(randint(0,1000000))
225        242082
226        [242082    899 211136]
227        Exception raised
228        899
229    """
230    def __init__(self, seed=None):
231        # type: (Optional[int]) -> None
232        self._state = np.random.get_state()
233        np.random.seed(seed)
234
235    def __enter__(self):
236        # type: () -> None
237        pass
238
239    def __exit__(self, exc_type, exc_value, trace):
240        # type: (Any, BaseException, Any) -> None
241        np.random.set_state(self._state)
242
243def tic():
244    # type: () -> Callable[[], float]
245    """
246    Timer function.
247
248    Use "toc=tic()" to start the clock and "toc()" to measure
249    a time interval.
250    """
251    then = datetime.datetime.now()
252    return lambda: delay(datetime.datetime.now() - then)
253
254
255def set_beam_stop(data, radius, outer=None):
256    # type: (Data, float, float) -> None
257    """
258    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
259    """
260    if hasattr(data, 'qx_data'):
261        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
262        data.mask = (q < radius)
263        if outer is not None:
264            data.mask |= (q >= outer)
265    else:
266        data.mask = (data.x < radius)
267        if outer is not None:
268            data.mask |= (data.x >= outer)
269
270
271def parameter_range(p, v):
272    # type: (str, float) -> Tuple[float, float]
273    """
274    Choose a parameter range based on parameter name and initial value.
275    """
276    # process the polydispersity options
277    if p.endswith('_pd_n'):
278        return 0., 100.
279    elif p.endswith('_pd_nsigma'):
280        return 0., 5.
281    elif p.endswith('_pd_type'):
282        raise ValueError("Cannot return a range for a string value")
283    elif any(s in p for s in ('theta', 'phi', 'psi')):
284        # orientation in [-180,180], orientation pd in [0,45]
285        if p.endswith('_pd'):
286            return 0., 180.
287        else:
288            return -180., 180.
289    elif p.endswith('_pd'):
290        return 0., 1.
291    elif 'sld' in p:
292        return -0.5, 10.
293    elif p == 'background':
294        return 0., 10.
295    elif p == 'scale':
296        return 0., 1.e3
297    elif v < 0.:
298        return 2.*v, -2.*v
299    else:
300        return 0., (2.*v if v > 0. else 1.)
301
302
303def _randomize_one(model_info, name, value):
304    # type: (ModelInfo, str, float) -> float
305    # type: (ModelInfo, str, str) -> str
306    """
307    Randomize a single parameter.
308    """
309    # Set the amount of polydispersity/angular dispersion, but by default pd_n
310    # is zero so there is no polydispersity.  This allows us to turn on/off
311    # pd by setting pd_n, and still have randomly generated values
312    if name.endswith('_pd'):
313        par = model_info.parameters[name[:-3]]
314        if par.type == 'orientation':
315            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
316            return 180*np.random.beta(2.5, 20)
317        else:
318            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
319            return np.random.beta(1.5, 7)
320
321    # pd is selected globally rather than per parameter, so set to 0 for no pd
322    # In particular, when multiple pd dimensions, want to decrease the number
323    # of points per dimension for faster computation
324    if name.endswith('_pd_n'):
325        return 0
326
327    # Don't mess with distribution type for now
328    if name.endswith('_pd_type'):
329        return 'gaussian'
330
331    # type-dependent value of number of sigmas; for gaussian use 3.
332    if name.endswith('_pd_nsigma'):
333        return 3.
334
335    # background in the range [0.01, 1]
336    if name == 'background':
337        return 10**np.random.uniform(-2, 0)
338
339    # scale defaults to 0.1% to 30% volume fraction
340    if name == 'scale':
341        return 10**np.random.uniform(-3, -0.5)
342
343    # If it is a list of choices, pick one at random with equal probability
344    # In practice, the model specific random generator will override.
345    par = model_info.parameters[name]
346    if len(par.limits) > 2:  # choice list
347        return np.random.randint(len(par.limits))
348
349    # If it is a fixed range, pick from it with equal probability.
350    # For logarithmic ranges, the model will have to override.
351    if np.isfinite(par.limits).all():
352        return np.random.uniform(*par.limits)
353
354    # If the paramter is marked as an sld use the range of neutron slds
355    # TODO: ought to randomly contrast match a pair of SLDs
356    if par.type == 'sld':
357        return np.random.uniform(-0.5, 12)
358
359    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
360    if par.name.startswith('M0:'):
361        return np.random.uniform(0, 5)
362
363    # Guess at the random length/radius/thickness.  In practice, all models
364    # are going to set their own reasonable ranges.
365    if par.type == 'volume':
366        if ('length' in par.name or
367                'radius' in par.name or
368                'thick' in par.name):
369            return 10**np.random.uniform(2, 4)
370
371    # In the absence of any other info, select a value in [0, 2v], or
372    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
373    # model random parameter generators will override this default.
374    low, high = parameter_range(par.name, value)
375    limits = (max(par.limits[0], low), min(par.limits[1], high))
376    return np.random.uniform(*limits)
377
378def _random_pd(model_info, pars):
379    # type: (ModelInfo, Dict[str, float]) -> None
380    """
381    Generate a random dispersity distribution for the model.
382
383    1% no shape dispersity
384    85% single shape parameter
385    13% two shape parameters
386    1% three shape parameters
387
388    If oriented, then put dispersity in theta, add phi and psi dispersity
389    with 10% probability for each.
390    """
391    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
392    pd_volume = []
393    pd_oriented = []
394    for p in pd:
395        if p.type == 'orientation':
396            pd_oriented.append(p.name)
397        elif p.length_control is not None:
398            n = int(pars.get(p.length_control, 1) + 0.5)
399            pd_volume.extend(p.name+str(k+1) for k in range(n))
400        elif p.length > 1:
401            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
402        else:
403            pd_volume.append(p.name)
404    u = np.random.rand()
405    n = len(pd_volume)
406    if u < 0.01 or n < 1:
407        pass  # 1% chance of no polydispersity
408    elif u < 0.86 or n < 2:
409        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
410    elif u < 0.99 or n < 3:
411        choices = np.random.choice(len(pd_volume), size=2)
412        pars[pd_volume[choices[0]]+"_pd_n"] = 25
413        pars[pd_volume[choices[1]]+"_pd_n"] = 10
414    else:
415        choices = np.random.choice(len(pd_volume), size=3)
416        pars[pd_volume[choices[0]]+"_pd_n"] = 25
417        pars[pd_volume[choices[1]]+"_pd_n"] = 10
418        pars[pd_volume[choices[2]]+"_pd_n"] = 5
419    if pd_oriented:
420        pars['theta_pd_n'] = 20
421        if np.random.rand() < 0.1:
422            pars['phi_pd_n'] = 5
423        if np.random.rand() < 0.1:
424            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
425                #print("generating psi_pd_n")
426                pars['psi_pd_n'] = 5
427
428    ## Show selected polydispersity
429    #for name, value in pars.items():
430    #    if name.endswith('_pd_n') and value > 0:
431    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
432
433
434def randomize_pars(model_info, pars):
435    # type: (ModelInfo, ParameterSet) -> ParameterSet
436    """
437    Generate random values for all of the parameters.
438
439    Valid ranges for the random number generator are guessed from the name of
440    the parameter; this will not account for constraints such as cap radius
441    greater than cylinder radius in the capped_cylinder model, so
442    :func:`constrain_pars` needs to be called afterward..
443    """
444    # Note: the sort guarantees order of calls to random number generator
445    random_pars = dict((p, _randomize_one(model_info, p, v))
446                       for p, v in sorted(pars.items()))
447    if model_info.random is not None:
448        random_pars.update(model_info.random())
449    _random_pd(model_info, random_pars)
450    return random_pars
451
452
453def limit_dimensions(model_info, pars, maxdim):
454    # type: (ModelInfo, ParameterSet, float) -> None
455    """
456    Limit parameters of units of Ang to maxdim.
457    """
458    for p in model_info.parameters.call_parameters:
459        value = pars[p.name]
460        if p.units == 'Ang' and value > maxdim:
461            pars[p.name] = maxdim*10**np.random.uniform(-3, 0)
462
463def constrain_pars(model_info, pars):
464    # type: (ModelInfo, ParameterSet) -> None
465    """
466    Restrict parameters to valid values.
467
468    This includes model specific code for models such as capped_cylinder
469    which need to support within model constraints (cap radius more than
470    cylinder radius in this case).
471
472    Warning: this updates the *pars* dictionary in place.
473    """
474    # TODO: move the model specific code to the individual models
475    name = model_info.id
476    # if it is a product model, then just look at the form factor since
477    # none of the structure factors need any constraints.
478    if '*' in name:
479        name = name.split('*')[0]
480
481    # Suppress magnetism for python models (not yet implemented)
482    if callable(model_info.Iq):
483        pars.update(suppress_magnetism(pars))
484
485    if name == 'barbell':
486        if pars['radius_bell'] < pars['radius']:
487            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
488
489    elif name == 'capped_cylinder':
490        if pars['radius_cap'] < pars['radius']:
491            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
492
493    elif name == 'guinier':
494        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
495        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
496        # => ln A - (Rg^2 q^2/3) > -30 ln 10
497        # => Rg^2 q^2/3 < 30 ln 10 + ln A
498        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
499        #q_max = 0.2  # mid q maximum
500        q_max = 1.0  # high q maximum
501        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
502        pars['rg'] = min(pars['rg'], rg_max)
503
504    elif name == 'pearl_necklace':
505        if pars['radius'] < pars['thick_string']:
506            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
507
508    elif name == 'rpa':
509        # Make sure phi sums to 1.0
510        if pars['case_num'] < 2:
511            pars['Phi1'] = 0.
512            pars['Phi2'] = 0.
513        elif pars['case_num'] < 5:
514            pars['Phi1'] = 0.
515        total = sum(pars['Phi'+c] for c in '1234')
516        for c in '1234':
517            pars['Phi'+c] /= total
518
519def parlist(model_info, pars, is2d):
520    # type: (ModelInfo, ParameterSet, bool) -> str
521    """
522    Format the parameter list for printing.
523    """
524    is2d = True
525    lines = []
526    parameters = model_info.parameters
527    magnetic = False
528    magnetic_pars = []
529    for p in parameters.user_parameters(pars, is2d):
530        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
531            continue
532        if p.id.startswith('up:'):
533            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
534            continue
535        fields = dict(
536            value=pars.get(p.id, p.default),
537            pd=pars.get(p.id+"_pd", 0.),
538            n=int(pars.get(p.id+"_pd_n", 0)),
539            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
540            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
541            relative_pd=p.relative_pd,
542            M0=pars.get('M0:'+p.id, 0.),
543            mphi=pars.get('mphi:'+p.id, 0.),
544            mtheta=pars.get('mtheta:'+p.id, 0.),
545        )
546        lines.append(_format_par(p.name, **fields))
547        magnetic = magnetic or fields['M0'] != 0.
548    if magnetic and magnetic_pars:
549        lines.append(" ".join(magnetic_pars))
550    return "\n".join(lines)
551
552    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
553
554def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
555                relative_pd=False, M0=0., mphi=0., mtheta=0.):
556    # type: (str, float, float, int, float, str) -> str
557    line = "%s: %g"%(name, value)
558    if pd != 0.  and n != 0:
559        if relative_pd:
560            pd *= value
561        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
562                % (pd, n, nsigma, nsigma, pdtype)
563    if M0 != 0.:
564        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
565    return line
566
567def suppress_pd(pars, suppress=True):
568    # type: (ParameterSet) -> ParameterSet
569    """
570    If suppress is True complete eliminate polydispersity of the model to test
571    models more quickly.  If suppress is False, make sure at least one
572    parameter is polydisperse, setting the first polydispersity parameter to
573    15% if no polydispersity is given (with no explicit demo parameters given
574    in the model, there will be no default polydispersity).
575    """
576    pars = pars.copy()
577    #print("pars=", pars)
578    if suppress:
579        for p in pars:
580            if p.endswith("_pd_n"):
581                pars[p] = 0
582    else:
583        any_pd = False
584        first_pd = None
585        for p in pars:
586            if p.endswith("_pd_n"):
587                pd = pars.get(p[:-2], 0.)
588                any_pd |= (pars[p] != 0 and pd != 0.)
589                if first_pd is None:
590                    first_pd = p
591        if not any_pd and first_pd is not None:
592            if pars[first_pd] == 0:
593                pars[first_pd] = 35
594            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
595                pars[first_pd[:-2]] = 0.15
596    return pars
597
598def suppress_magnetism(pars, suppress=True):
599    # type: (ParameterSet) -> ParameterSet
600    """
601    If suppress is True complete eliminate magnetism of the model to test
602    models more quickly.  If suppress is False, make sure at least one sld
603    parameter is magnetic, setting the first parameter to have a strong
604    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
605    in the model, there will be no default magnetism).
606    """
607    pars = pars.copy()
608    if suppress:
609        for p in pars:
610            if p.startswith("M0:"):
611                pars[p] = 0
612    else:
613        any_mag = False
614        first_mag = None
615        for p in pars:
616            if p.startswith("M0:"):
617                any_mag |= (pars[p] != 0)
618                if first_mag is None:
619                    first_mag = p
620        if not any_mag and first_mag is not None:
621            pars[first_mag] = 8.
622    return pars
623
624
625DTYPE_MAP = {
626    'half': '16',
627    'fast': 'fast',
628    'single': '32',
629    'double': '64',
630    'quad': '128',
631    'f16': '16',
632    'f32': '32',
633    'f64': '64',
634    'float16': '16',
635    'float32': '32',
636    'float64': '64',
637    'float128': '128',
638    'longdouble': '128',
639}
640def eval_opencl(model_info, data, dtype='single', cutoff=0.):
641    # type: (ModelInfo, Data, str, float) -> Calculator
642    """
643    Return a model calculator using the OpenCL calculation engine.
644    """
645    if not core.HAVE_OPENCL:
646        raise RuntimeError("OpenCL not available")
647    model = core.build_model(model_info, dtype=dtype, platform="ocl")
648    calculator = DirectModel(data, model, cutoff=cutoff)
649    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
650    return calculator
651
652def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
653    # type: (ModelInfo, Data, str, float) -> Calculator
654    """
655    Return a model calculator using the DLL calculation engine.
656    """
657    model = core.build_model(model_info, dtype=dtype, platform="dll")
658    calculator = DirectModel(data, model, cutoff=cutoff)
659    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
660    return calculator
661
662def time_calculation(calculator, pars, evals=1):
663    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
664    """
665    Compute the average calculation time over N evaluations.
666
667    An additional call is generated without polydispersity in order to
668    initialize the calculation engine, and make the average more stable.
669    """
670    # initialize the code so time is more accurate
671    if evals > 1:
672        calculator(**suppress_pd(pars))
673    toc = tic()
674    # make sure there is at least one eval
675    value = calculator(**pars)
676    for _ in range(evals-1):
677        value = calculator(**pars)
678    average_time = toc()*1000. / evals
679    #print("I(q)",value)
680    return value, average_time
681
682def make_data(opts):
683    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
684    """
685    Generate an empty dataset, used with the model to set Q points
686    and resolution.
687
688    *opts* contains the options, with 'qmax', 'nq', 'res',
689    'accuracy', 'is2d' and 'view' parsed from the command line.
690    """
691    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
692    if opts['is2d']:
693        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
694        data = empty_data2D(q, resolution=res)
695        data.accuracy = opts['accuracy']
696        set_beam_stop(data, qmin)
697        index = ~data.mask
698    else:
699        if opts['view'] == 'log' and not opts['zero']:
700            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
701        else:
702            q = np.linspace(qmin, qmax, nq)
703        if opts['zero']:
704            q = np.hstack((0, q))
705        data = empty_data1D(q, resolution=res)
706        index = slice(None, None)
707    return data, index
708
709def make_engine(model_info, data, dtype, cutoff, ngauss=0):
710    # type: (ModelInfo, Data, str, float) -> Calculator
711    """
712    Generate the appropriate calculation engine for the given datatype.
713
714    Datatypes with '!' appended are evaluated using external C DLLs rather
715    than OpenCL.
716    """
717    if ngauss:
718        set_integration_size(model_info, ngauss)
719
720    if dtype is None or not dtype.endswith('!'):
721        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
722    else:
723        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
724
725def _show_invalid(data, theory):
726    # type: (Data, np.ma.ndarray) -> None
727    """
728    Display a list of the non-finite values in theory.
729    """
730    if not theory.mask.any():
731        return
732
733    if hasattr(data, 'x'):
734        bad = zip(data.x[theory.mask], theory[theory.mask])
735        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
736
737
738def compare(opts, limits=None, maxdim=np.inf):
739    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
740    """
741    Preform a comparison using options from the command line.
742
743    *limits* are the limits on the values to use, either to set the y-axis
744    for 1D or to set the colormap scale for 2D.  If None, then they are
745    inferred from the data and returned. When exploring using Bumps,
746    the limits are set when the model is initially called, and maintained
747    as the values are adjusted, making it easier to see the effects of the
748    parameters.
749
750    *maxdim* is the maximum value for any parameter with units of Angstrom.
751    """
752    for k in range(opts['sets']):
753        if k > 1:
754            # print a separate seed for each dataset for better reproducibility
755            new_seed = np.random.randint(1000000)
756            print("Set %d uses -random=%i"%(k+1, new_seed))
757            np.random.seed(new_seed)
758        opts['pars'] = parse_pars(opts, maxdim=maxdim)
759        if opts['pars'] is None:
760            return
761        result = run_models(opts, verbose=True)
762        if opts['plot']:
763            limits = plot_models(opts, result, limits=limits, setnum=k)
764        if opts['show_weights']:
765            base, _ = opts['engines']
766            base_pars, _ = opts['pars']
767            model_info = base._kernel.info
768            dim = base._kernel.dim
769            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
770    if opts['plot']:
771        import matplotlib.pyplot as plt
772        plt.show()
773    return limits
774
775def run_models(opts, verbose=False):
776    # type: (Dict[str, Any]) -> Dict[str, Any]
777    """
778    Process a parameter set, return calculation results and times.
779    """
780
781    base, comp = opts['engines']
782    base_n, comp_n = opts['count']
783    base_pars, comp_pars = opts['pars']
784    data = opts['data']
785
786    comparison = comp is not None
787
788    base_time = comp_time = None
789    base_value = comp_value = resid = relerr = None
790
791    # Base calculation
792    try:
793        base_raw, base_time = time_calculation(base, base_pars, base_n)
794        base_value = np.ma.masked_invalid(base_raw)
795        if verbose:
796            print("%s t=%.2f ms, intensity=%.0f"
797                  % (base.engine, base_time, base_value.sum()))
798        _show_invalid(data, base_value)
799    except ImportError:
800        traceback.print_exc()
801
802    # Comparison calculation
803    if comparison:
804        try:
805            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
806            comp_value = np.ma.masked_invalid(comp_raw)
807            if verbose:
808                print("%s t=%.2f ms, intensity=%.0f"
809                      % (comp.engine, comp_time, comp_value.sum()))
810            _show_invalid(data, comp_value)
811        except ImportError:
812            traceback.print_exc()
813
814    # Compare, but only if computing both forms
815    if comparison:
816        resid = (base_value - comp_value)
817        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
818        if verbose:
819            _print_stats("|%s-%s|"
820                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
821                         resid)
822            _print_stats("|(%s-%s)/%s|"
823                         % (base.engine, comp.engine, comp.engine),
824                         relerr)
825
826    return dict(base_value=base_value, comp_value=comp_value,
827                base_time=base_time, comp_time=comp_time,
828                resid=resid, relerr=relerr)
829
830
831def _print_stats(label, err):
832    # type: (str, np.ma.ndarray) -> None
833    # work with trimmed data, not the full set
834    sorted_err = np.sort(abs(err.compressed()))
835    if len(sorted_err) == 0:
836        print(label + "  no valid values")
837        return
838
839    p50 = int((len(sorted_err)-1)*0.50)
840    p98 = int((len(sorted_err)-1)*0.98)
841    data = [
842        "max:%.3e"%sorted_err[-1],
843        "median:%.3e"%sorted_err[p50],
844        "98%%:%.3e"%sorted_err[p98],
845        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
846        "zero-offset:%+.3e"%np.mean(sorted_err),
847        ]
848    print(label+"  "+"  ".join(data))
849
850
851def plot_models(opts, result, limits=None, setnum=0):
852    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
853    """
854    Plot the results from :func:`run_model`.
855    """
856    import matplotlib.pyplot as plt
857
858    base_value, comp_value = result['base_value'], result['comp_value']
859    base_time, comp_time = result['base_time'], result['comp_time']
860    resid, relerr = result['resid'], result['relerr']
861
862    have_base, have_comp = (base_value is not None), (comp_value is not None)
863    base, comp = opts['engines']
864    data = opts['data']
865    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
866
867    # Plot if requested
868    view = opts['view']
869    if limits is None:
870        vmin, vmax = np.inf, -np.inf
871        if have_base:
872            vmin = min(vmin, base_value.min())
873            vmax = max(vmax, base_value.max())
874        if have_comp:
875            vmin = min(vmin, comp_value.min())
876            vmax = max(vmax, comp_value.max())
877        limits = vmin, vmax
878
879    if have_base:
880        if have_comp:
881            plt.subplot(131)
882        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
883        plt.title("%s t=%.2f ms"%(base.engine, base_time))
884        #cbar_title = "log I"
885    if have_comp:
886        if have_base:
887            plt.subplot(132)
888        if not opts['is2d'] and have_base:
889            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
890        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
891        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
892        #cbar_title = "log I"
893    if have_base and have_comp:
894        plt.subplot(133)
895        if not opts['rel_err']:
896            err, errstr, errview = resid, "abs err", "linear"
897        else:
898            err, errstr, errview = abs(relerr), "rel err", "log"
899            if (err == 0.).all():
900                errview = 'linear'
901        if 0:  # 95% cutoff
902            sorted_err = np.sort(err.flatten())
903            cutoff = sorted_err[int(sorted_err.size*0.95)]
904            err[err > cutoff] = cutoff
905        #err,errstr = base/comp,"ratio"
906        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
907        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
908        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
909        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
910        #cbar_title = errstr if errview=="linear" else "log "+errstr
911    #if is2D:
912    #    h = plt.colorbar()
913    #    h.ax.set_title(cbar_title)
914    fig = plt.gcf()
915    extra_title = ' '+opts['title'] if opts['title'] else ''
916    fig.suptitle(":".join(opts['name']) + extra_title)
917
918    if have_base and have_comp and opts['show_hist']:
919        plt.figure()
920        v = relerr
921        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
922        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
923        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
924                   % (base.engine, comp.engine, comp.engine))
925        plt.ylabel('P(err)')
926        plt.title('Distribution of relative error between calculation engines')
927
928    return limits
929
930
931# ===========================================================================
932#
933
934# Set of command line options.
935# Normal options such as -plot/-noplot are specified as 'name'.
936# For options such as -nq=500 which require a value use 'name='.
937#
938OPTIONS = [
939    # Plotting
940    'plot', 'noplot', 'weights',
941    'linear', 'log', 'q4',
942    'rel', 'abs',
943    'hist', 'nohist',
944    'title=',
945
946    # Data generation
947    'data=', 'noise=', 'res=', 'nq=', 'q=',
948    'lowq', 'midq', 'highq', 'exq', 'zero',
949    '2d', '1d',
950
951    # Parameter set
952    'preset', 'random', 'random=', 'sets=',
953    'demo', 'default',  # TODO: remove demo/default
954    'nopars', 'pars',
955    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
956
957    # Calculation options
958    'poly', 'mono', 'cutoff=',
959    'magnetic', 'nonmagnetic',
960    'accuracy=', 'ngauss=',
961    'neval=',  # for timing...
962
963    # Precision options
964    'engine=',
965    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
966
967    # Output options
968    'help', 'html', 'edit',
969    ]
970
971NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))()
972VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])()
973
974
975def columnize(items, indent="", width=79):
976    # type: (List[str], str, int) -> str
977    """
978    Format a list of strings into columns.
979
980    Returns a string with carriage returns ready for printing.
981    """
982    column_width = max(len(w) for w in items) + 1
983    num_columns = (width - len(indent)) // column_width
984    num_rows = len(items) // num_columns
985    items = items + [""] * (num_rows * num_columns - len(items))
986    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
987    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
988             for row in zip(*columns)]
989    output = indent + ("\n"+indent).join(lines)
990    return output
991
992
993def get_pars(model_info, use_demo=False):
994    # type: (ModelInfo, bool) -> ParameterSet
995    """
996    Extract demo parameters from the model definition.
997    """
998    # Get the default values for the parameters
999    pars = {}
1000    for p in model_info.parameters.call_parameters:
1001        parts = [('', p.default)]
1002        if p.polydisperse:
1003            parts.append(('_pd', 0.0))
1004            parts.append(('_pd_n', 0))
1005            parts.append(('_pd_nsigma', 3.0))
1006            parts.append(('_pd_type', "gaussian"))
1007        for ext, val in parts:
1008            if p.length > 1:
1009                dict(("%s%d%s" % (p.id, k, ext), val)
1010                     for k in range(1, p.length+1))
1011            else:
1012                pars[p.id + ext] = val
1013
1014    # Plug in values given in demo
1015    if use_demo and model_info.demo:
1016        pars.update(model_info.demo)
1017    return pars
1018
1019INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1020def isnumber(s):
1021    # type: (str) -> bool
1022    """Return True if string contains an int or float"""
1023    match = FLOAT_RE.match(s)
1024    isfloat = (match and not s[match.end():])
1025    return isfloat or INTEGER_RE.match(s)
1026
1027# For distinguishing pairs of models for comparison
1028# key-value pair separator =
1029# shell characters  | & ; <> $ % ' " \ # `
1030# model and parameter names _
1031# parameter expressions - + * / . ( )
1032# path characters including tilde expansion and windows drive ~ / :
1033# not sure about brackets [] {}
1034# maybe one of the following @ ? ^ ! ,
1035PAR_SPLIT = ','
1036def parse_opts(argv):
1037    # type: (List[str]) -> Dict[str, Any]
1038    """
1039    Parse command line options.
1040    """
1041    MODELS = core.list_models()
1042    flags = [arg for arg in argv
1043             if arg.startswith('-')]
1044    values = [arg for arg in argv
1045              if not arg.startswith('-') and '=' in arg]
1046    positional_args = [arg for arg in argv
1047                       if not arg.startswith('-') and '=' not in arg]
1048    models = "\n    ".join("%-15s"%v for v in MODELS)
1049    if len(positional_args) == 0:
1050        print(USAGE)
1051        print("\nAvailable models:")
1052        print(columnize(MODELS, indent="  "))
1053        return None
1054
1055    invalid = [o[1:] for o in flags
1056               if o[1:] not in NAME_OPTIONS
1057               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1058    if invalid:
1059        print("Invalid options: %s"%(", ".join(invalid)))
1060        return None
1061
1062    name = positional_args[-1]
1063
1064    # pylint: disable=bad-whitespace,C0321
1065    # Interpret the flags
1066    opts = {
1067        'plot'      : True,
1068        'view'      : 'log',
1069        'is2d'      : False,
1070        'qmin'      : None,
1071        'qmax'      : 0.05,
1072        'nq'        : 128,
1073        'res'       : 0.0,
1074        'noise'     : 0.0,
1075        'accuracy'  : 'Low',
1076        'cutoff'    : '0.0',
1077        'seed'      : -1,  # default to preset
1078        'mono'      : True,
1079        # Default to magnetic a magnetic moment is set on the command line
1080        'magnetic'  : False,
1081        'show_pars' : False,
1082        'show_hist' : False,
1083        'rel_err'   : True,
1084        'explore'   : False,
1085        'use_demo'  : True,
1086        'zero'      : False,
1087        'html'      : False,
1088        'title'     : None,
1089        'datafile'  : None,
1090        'sets'      : 0,
1091        'engine'    : 'default',
1092        'count'     : '1',
1093        'show_weights' : False,
1094        'sphere'    : 0,
1095        'ngauss'    : '0',
1096    }
1097    for arg in flags:
1098        if arg == '-noplot':    opts['plot'] = False
1099        elif arg == '-plot':    opts['plot'] = True
1100        elif arg == '-linear':  opts['view'] = 'linear'
1101        elif arg == '-log':     opts['view'] = 'log'
1102        elif arg == '-q4':      opts['view'] = 'q4'
1103        elif arg == '-1d':      opts['is2d'] = False
1104        elif arg == '-2d':      opts['is2d'] = True
1105        elif arg == '-exq':     opts['qmax'] = 10.0
1106        elif arg == '-highq':   opts['qmax'] = 1.0
1107        elif arg == '-midq':    opts['qmax'] = 0.2
1108        elif arg == '-lowq':    opts['qmax'] = 0.05
1109        elif arg == '-zero':    opts['zero'] = True
1110        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1111        elif arg.startswith('-q='):
1112            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1113        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1114        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1115        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1116        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1117        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1118        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1119        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1120        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1121        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1122        elif arg.startswith('-ngauss='):   opts['ngauss'] = arg[8:]
1123        elif arg.startswith('-random='):
1124            opts['seed'] = int(arg[8:])
1125            opts['sets'] = 0
1126        elif arg == '-random':
1127            opts['seed'] = np.random.randint(1000000)
1128            opts['sets'] = 0
1129        elif arg.startswith('-sphere'):
1130            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1131            opts['is2d'] = True
1132        elif arg == '-preset':  opts['seed'] = -1
1133        elif arg == '-mono':    opts['mono'] = True
1134        elif arg == '-poly':    opts['mono'] = False
1135        elif arg == '-magnetic':       opts['magnetic'] = True
1136        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1137        elif arg == '-pars':    opts['show_pars'] = True
1138        elif arg == '-nopars':  opts['show_pars'] = False
1139        elif arg == '-hist':    opts['show_hist'] = True
1140        elif arg == '-nohist':  opts['show_hist'] = False
1141        elif arg == '-rel':     opts['rel_err'] = True
1142        elif arg == '-abs':     opts['rel_err'] = False
1143        elif arg == '-half':    opts['engine'] = 'half'
1144        elif arg == '-fast':    opts['engine'] = 'fast'
1145        elif arg == '-single':  opts['engine'] = 'single'
1146        elif arg == '-double':  opts['engine'] = 'double'
1147        elif arg == '-single!': opts['engine'] = 'single!'
1148        elif arg == '-double!': opts['engine'] = 'double!'
1149        elif arg == '-quad!':   opts['engine'] = 'quad!'
1150        elif arg == '-edit':    opts['explore'] = True
1151        elif arg == '-demo':    opts['use_demo'] = True
1152        elif arg == '-default': opts['use_demo'] = False
1153        elif arg == '-weights': opts['show_weights'] = True
1154        elif arg == '-html':    opts['html'] = True
1155        elif arg == '-help':    opts['html'] = True
1156    # pylint: enable=bad-whitespace,C0321
1157
1158    # Magnetism forces 2D for now
1159    if opts['magnetic']:
1160        opts['is2d'] = True
1161
1162    # Force random if sets is used
1163    if opts['sets'] >= 1 and opts['seed'] < 0:
1164        opts['seed'] = np.random.randint(1000000)
1165    if opts['sets'] == 0:
1166        opts['sets'] = 1
1167
1168    # Create the computational engines
1169    if opts['qmin'] is None:
1170        opts['qmin'] = 0.001*opts['qmax']
1171    if opts['datafile'] is not None:
1172        data = load_data(os.path.expanduser(opts['datafile']))
1173    else:
1174        data, _ = make_data(opts)
1175
1176    comparison = any(PAR_SPLIT in v for v in values)
1177
1178    if PAR_SPLIT in name:
1179        names = name.split(PAR_SPLIT, 2)
1180        comparison = True
1181    else:
1182        names = [name]*2
1183    try:
1184        model_info = [core.load_model_info(k) for k in names]
1185    except ImportError as exc:
1186        print(str(exc))
1187        print("Could not find model; use one of:\n    " + models)
1188        return None
1189
1190    if PAR_SPLIT in opts['ngauss']:
1191        opts['ngauss'] = [int(k) for k in opts['ngauss'].split(PAR_SPLIT, 2)]
1192        comparison = True
1193    else:
1194        opts['ngauss'] = [int(opts['ngauss'])]*2
1195
1196    if PAR_SPLIT in opts['engine']:
1197        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1198        comparison = True
1199    else:
1200        opts['engine'] = [opts['engine']]*2
1201
1202    if PAR_SPLIT in opts['count']:
1203        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1204        comparison = True
1205    else:
1206        opts['count'] = [int(opts['count'])]*2
1207
1208    if PAR_SPLIT in opts['cutoff']:
1209        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1210        comparison = True
1211    else:
1212        opts['cutoff'] = [float(opts['cutoff'])]*2
1213
1214    base = make_engine(model_info[0], data, opts['engine'][0],
1215                       opts['cutoff'][0], opts['ngauss'][0])
1216    if comparison:
1217        comp = make_engine(model_info[1], data, opts['engine'][1],
1218                           opts['cutoff'][1], opts['ngauss'][1])
1219    else:
1220        comp = None
1221
1222    # pylint: disable=bad-whitespace
1223    # Remember it all
1224    opts.update({
1225        'data'      : data,
1226        'name'      : names,
1227        'info'      : model_info,
1228        'engines'   : [base, comp],
1229        'values'    : values,
1230    })
1231    # pylint: enable=bad-whitespace
1232
1233    # Set the integration parameters to the half sphere
1234    if opts['sphere'] > 0:
1235        set_spherical_integration_parameters(opts, opts['sphere'])
1236
1237    return opts
1238
1239def set_spherical_integration_parameters(opts, steps):
1240    # type: (Dict[str, Any], int) -> None
1241    """
1242    Set integration parameters for spherical integration over the entire
1243    surface in theta-phi coordinates.
1244    """
1245    # Set the integration parameters to the half sphere
1246    opts['values'].extend([
1247        #'theta=90',
1248        'theta_pd=%g'%(90/np.sqrt(3)),
1249        'theta_pd_n=%d'%steps,
1250        'theta_pd_type=rectangle',
1251        #'phi=0',
1252        'phi_pd=%g'%(180/np.sqrt(3)),
1253        'phi_pd_n=%d'%(2*steps),
1254        'phi_pd_type=rectangle',
1255        #'background=0',
1256    ])
1257    if 'psi' in opts['info'][0].parameters:
1258        opts['values'].extend([
1259            #'psi=0',
1260            'psi_pd=%g'%(180/np.sqrt(3)),
1261            'psi_pd_n=%d'%(2*steps),
1262            'psi_pd_type=rectangle',
1263        ])
1264
1265def parse_pars(opts, maxdim=np.inf):
1266    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]]
1267    """
1268    Generate a parameter set.
1269
1270    The default values come from the model, or a randomized model if a seed
1271    value is given.  Next, evaluate any parameter expressions, constraining
1272    the value of the parameter within and between models.  If *maxdim* is
1273    given, limit parameters with units of Angstrom to this value.
1274
1275    Returns a pair of parameter dictionaries for base and comparison models.
1276    """
1277    model_info, model_info2 = opts['info']
1278
1279    # Get demo parameters from model definition, or use default parameters
1280    # if model does not define demo parameters
1281    pars = get_pars(model_info, opts['use_demo'])
1282    pars2 = get_pars(model_info2, opts['use_demo'])
1283    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1284    # randomize parameters
1285    #pars.update(set_pars)  # set value before random to control range
1286    if opts['seed'] > -1:
1287        pars = randomize_pars(model_info, pars)
1288        limit_dimensions(model_info, pars, maxdim)
1289        if model_info != model_info2:
1290            pars2 = randomize_pars(model_info2, pars2)
1291            limit_dimensions(model_info2, pars2, maxdim)
1292            # Share values for parameters with the same name
1293            for k, v in pars.items():
1294                if k in pars2:
1295                    pars2[k] = v
1296        else:
1297            pars2 = pars.copy()
1298        constrain_pars(model_info, pars)
1299        constrain_pars(model_info2, pars2)
1300    pars = suppress_pd(pars, opts['mono'])
1301    pars2 = suppress_pd(pars2, opts['mono'])
1302    pars = suppress_magnetism(pars, not opts['magnetic'])
1303    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1304
1305    # Fill in parameters given on the command line
1306    presets = {}
1307    presets2 = {}
1308    for arg in opts['values']:
1309        k, v = arg.split('=', 1)
1310        if k not in pars and k not in pars2:
1311            # extract base name without polydispersity info
1312            s = set(p.split('_pd')[0] for p in pars)
1313            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1314            return None
1315        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
1316        if v1 and k in pars:
1317            presets[k] = float(v1) if isnumber(v1) else v1
1318        if v2 and k in pars2:
1319            presets2[k] = float(v2) if isnumber(v2) else v2
1320
1321    # If pd given on the command line, default pd_n to 35
1322    for k, v in list(presets.items()):
1323        if k.endswith('_pd'):
1324            presets.setdefault(k+'_n', 35.)
1325    for k, v in list(presets2.items()):
1326        if k.endswith('_pd'):
1327            presets2.setdefault(k+'_n', 35.)
1328
1329    # Evaluate preset parameter expressions
1330    context = MATH.copy()
1331    context['np'] = np
1332    context.update(pars)
1333    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1334    for k, v in presets.items():
1335        if not isinstance(v, float) and not k.endswith('_type'):
1336            presets[k] = eval(v, context)
1337    context.update(presets)
1338    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1339    for k, v in presets2.items():
1340        if not isinstance(v, float) and not k.endswith('_type'):
1341            presets2[k] = eval(v, context)
1342
1343    # update parameters with presets
1344    pars.update(presets)  # set value after random to control value
1345    pars2.update(presets2)  # set value after random to control value
1346    #import pprint; pprint.pprint(model_info)
1347
1348    if opts['show_pars']:
1349        if model_info.name != model_info2.name or pars != pars2:
1350            print("==== %s ====="%model_info.name)
1351            print(str(parlist(model_info, pars, opts['is2d'])))
1352            print("==== %s ====="%model_info2.name)
1353            print(str(parlist(model_info2, pars2, opts['is2d'])))
1354        else:
1355            print(str(parlist(model_info, pars, opts['is2d'])))
1356
1357    return pars, pars2
1358
1359def show_docs(opts):
1360    # type: (Dict[str, Any]) -> None
1361    """
1362    show html docs for the model
1363    """
1364    from .generate import make_html
1365    from . import rst2html
1366
1367    info = opts['info'][0]
1368    html = make_html(info)
1369    path = os.path.dirname(info.filename)
1370    url = "file://" + path.replace("\\", "/")[2:] + "/"
1371    rst2html.view_html_qtapp(html, url)
1372
1373def explore(opts):
1374    # type: (Dict[str, Any]) -> None
1375    """
1376    explore the model using the bumps gui.
1377    """
1378    import wx  # type: ignore
1379    from bumps.names import FitProblem  # type: ignore
1380    from bumps.gui.app_frame import AppFrame  # type: ignore
1381    from bumps.gui import signal
1382
1383    is_mac = "cocoa" in wx.version()
1384    # Create an app if not running embedded
1385    app = wx.App() if wx.GetApp() is None else None
1386    model = Explore(opts)
1387    problem = FitProblem(model)
1388    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1389    if not is_mac:
1390        frame.Show()
1391    frame.panel.set_model(model=problem)
1392    frame.panel.Layout()
1393    frame.panel.aui.Split(0, wx.TOP)
1394    def _reset_parameters(event):
1395        model.revert_values()
1396        signal.update_parameters(problem)
1397    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1))
1398    if is_mac:
1399        frame.Show()
1400    # If running withing an app, start the main loop
1401    if app:
1402        app.MainLoop()
1403
1404class Explore(object):
1405    """
1406    Bumps wrapper for a SAS model comparison.
1407
1408    The resulting object can be used as a Bumps fit problem so that
1409    parameters can be adjusted in the GUI, with plots updated on the fly.
1410    """
1411    def __init__(self, opts):
1412        # type: (Dict[str, Any]) -> None
1413        from bumps.cli import config_matplotlib  # type: ignore
1414        from . import bumps_model
1415        config_matplotlib()
1416        self.opts = opts
1417        opts['pars'] = list(opts['pars'])
1418        p1, p2 = opts['pars']
1419        m1, m2 = opts['info']
1420        self.fix_p2 = m1 != m2 or p1 != p2
1421        model_info = m1
1422        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1423        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1424        if not opts['is2d']:
1425            for p in model_info.parameters.user_parameters({}, is2d=False):
1426                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1427                    k = p.name+ext
1428                    v = pars.get(k, None)
1429                    if v is not None:
1430                        v.range(*parameter_range(k, v.value))
1431        else:
1432            for k, v in pars.items():
1433                v.range(*parameter_range(k, v.value))
1434
1435        self.pars = pars
1436        self.starting_values = dict((k, v.value) for k, v in pars.items())
1437        self.pd_types = pd_types
1438        self.limits = None
1439
1440    def revert_values(self):
1441        # type: () -> None
1442        """
1443        Restore starting values of the parameters.
1444        """
1445        for k, v in self.starting_values.items():
1446            self.pars[k].value = v
1447
1448    def model_update(self):
1449        # type: () -> None
1450        """
1451        Respond to signal that model parameters have been changed.
1452        """
1453        pass
1454
1455    def numpoints(self):
1456        # type: () -> int
1457        """
1458        Return the number of points.
1459        """
1460        return len(self.pars) + 1  # so dof is 1
1461
1462    def parameters(self):
1463        # type: () -> Any   # Dict/List hierarchy of parameters
1464        """
1465        Return a dictionary of parameters.
1466        """
1467        return self.pars
1468
1469    def nllf(self):
1470        # type: () -> float
1471        """
1472        Return cost.
1473        """
1474        # pylint: disable=no-self-use
1475        return 0.  # No nllf
1476
1477    def plot(self, view='log'):
1478        # type: (str) -> None
1479        """
1480        Plot the data and residuals.
1481        """
1482        pars = dict((k, v.value) for k, v in self.pars.items())
1483        pars.update(self.pd_types)
1484        self.opts['pars'][0] = pars
1485        if not self.fix_p2:
1486            self.opts['pars'][1] = pars
1487        result = run_models(self.opts)
1488        limits = plot_models(self.opts, result, limits=self.limits)
1489        if self.limits is None:
1490            vmin, vmax = limits
1491            self.limits = vmax*1e-7, 1.3*vmax
1492            import pylab; pylab.clf()
1493            plot_models(self.opts, result, limits=self.limits)
1494
1495
1496def main(*argv):
1497    # type: (*str) -> None
1498    """
1499    Main program.
1500    """
1501    opts = parse_opts(argv)
1502    if opts is not None:
1503        if opts['seed'] > -1:
1504            print("Randomize using -random=%i"%opts['seed'])
1505            np.random.seed(opts['seed'])
1506        if opts['html']:
1507            show_docs(opts)
1508        elif opts['explore']:
1509            opts['pars'] = parse_pars(opts)
1510            if opts['pars'] is None:
1511                return
1512            explore(opts)
1513        else:
1514            compare(opts)
1515
1516if __name__ == "__main__":
1517    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.