source: sasmodels/sasmodels/compare.py @ fb7c176

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since fb7c176 was fb7c176, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

use integer for choice list with random models in sascomp

  • Property mode set to 100755
File size: 56.2 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import kernelcl
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .generate import FLOAT_RE, set_integration_size
46from .weights import plot_weights
47
48# pylint: disable=unused-import
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except ImportError:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57# pylint: enable=unused-import
58
59USAGE = """
60usage: sascomp model [options...] [key=val]
61
62Generate and compare SAS models.  If a single model is specified it shows
63a plot of that model.  Different models can be compared, or the same model
64with different parameters.  The same model with the same parameters can
65be compared with different calculation engines to see the effects of precision
66on the resultant values.
67
68model or model1,model2 are the names of the models to compare (see below).
69
70Options (* for default):
71
72    === data generation ===
73    -data="path" uses q, dq from the data file
74    -noise=0 sets the measurement error dI/I
75    -res=0 sets the resolution width dQ/Q if calculating with resolution
76    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
77    -q=min:max alternative specification of qrange
78    -nq=128 sets the number of Q points in the data set
79    -1d*/-2d computes 1d or 2d data
80    -zero indicates that q=0 should be included
81
82    === model parameters ===
83    -preset*/-random[=seed] preset or random parameters
84    -sets=n generates n random datasets with the seed given by -random=seed
85    -pars/-nopars* prints the parameter set or not
86    -default/-demo* use demo vs default parameters
87    -sphere[=150] set up spherical integration over theta/phi using n points
88
89    === calculation options ===
90    -mono*/-poly force monodisperse or allow polydisperse random parameters
91    -cutoff=1e-5* cutoff value for including a point in polydispersity
92    -magnetic/-nonmagnetic* suppress magnetism
93    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
94    -neval=1 sets the number of evals for more accurate timing
95    -ngauss=0 overrides the number of points in the 1-D gaussian quadrature
96
97    === precision options ===
98    -engine=default uses the default calcution precision
99    -single/-double/-half/-fast sets an OpenCL calculation engine
100    -single!/-double!/-quad! sets an OpenMP calculation engine
101
102    === plotting ===
103    -plot*/-noplot plots or suppress the plot of the model
104    -linear/-log*/-q4 intensity scaling on plots
105    -hist/-nohist* plot histogram of relative error
106    -abs/-rel* plot relative or absolute error
107    -title="note" adds note to the plot title, after the model name
108    -weights shows weights plots for the polydisperse parameters
109    -profile shows the sld profile if the model has a plottable sld profile
110
111    === output options ===
112    -edit starts the parameter explorer
113    -help/-html shows the model docs instead of running the model
114
115    === environment variables ===
116    -DSAS_MODELPATH=path sets directory containing custom models
117    -DSAS_OPENCL=vendor:device|none sets the target OpenCL device
118    -DXDG_CACHE_HOME=~/.cache sets the pyopencl cache root (linux only)
119    -DSAS_COMPILER=tinycc|msvc|mingw|unix sets the DLL compiler
120    -DSAS_OPENMP=1 turns on OpenMP for the DLLs
121    -DSAS_DLL_PATH=path sets the path to the compiled modules
122
123The interpretation of quad precision depends on architecture, and may
124vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
125On unix and mac you may need single quotes around the DLL computation
126engines, such as -engine='single!,double!' since !, is treated as a history
127expansion request in the shell.
128
129Key=value pairs allow you to set specific values for the model parameters.
130Key=value1,value2 to compare different values of the same parameter. The
131value can be an expression including other parameters.
132
133Items later on the command line override those that appear earlier.
134
135Examples:
136
137    # compare single and double precision calculation for a barbell
138    sascomp barbell -engine=single,double
139
140    # generate 10 random lorentz models, with seed=27
141    sascomp lorentz -sets=10 -seed=27
142
143    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
144    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
145
146    # model timing test requires multiple evals to perform the estimate
147    sascomp pringle -engine=single,double -timing=100,100 -noplot
148"""
149
150# Update docs with command line usage string.   This is separate from the usual
151# doc string so that we can display it at run time if there is an error.
152# lin
153__doc__ = (__doc__  # pylint: disable=redefined-builtin
154           + """
155Program description
156-------------------
157
158""" + USAGE)
159
160kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
161
162def build_math_context():
163    # type: () -> Dict[str, Callable]
164    """build dictionary of functions from math module"""
165    return dict((k, getattr(math, k))
166                for k in dir(math) if not k.startswith('_'))
167
168#: list of math functions for use in evaluating parameters
169MATH = build_math_context()
170
171# CRUFT python 2.6
172if not hasattr(datetime.timedelta, 'total_seconds'):
173    def delay(dt):
174        """Return number date-time delta as number seconds"""
175        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
176else:
177    def delay(dt):
178        """Return number date-time delta as number seconds"""
179        return dt.total_seconds()
180
181
182class push_seed(object):
183    """
184    Set the seed value for the random number generator.
185
186    When used in a with statement, the random number generator state is
187    restored after the with statement is complete.
188
189    :Parameters:
190
191    *seed* : int or array_like, optional
192        Seed for RandomState
193
194    :Example:
195
196    Seed can be used directly to set the seed::
197
198        >>> from numpy.random import randint
199        >>> push_seed(24)
200        <...push_seed object at...>
201        >>> print(randint(0,1000000,3))
202        [242082    899 211136]
203
204    Seed can also be used in a with statement, which sets the random
205    number generator state for the enclosed computations and restores
206    it to the previous state on completion::
207
208        >>> with push_seed(24):
209        ...    print(randint(0,1000000,3))
210        [242082    899 211136]
211
212    Using nested contexts, we can demonstrate that state is indeed
213    restored after the block completes::
214
215        >>> with push_seed(24):
216        ...    print(randint(0,1000000))
217        ...    with push_seed(24):
218        ...        print(randint(0,1000000,3))
219        ...    print(randint(0,1000000))
220        242082
221        [242082    899 211136]
222        899
223
224    The restore step is protected against exceptions in the block::
225
226        >>> with push_seed(24):
227        ...    print(randint(0,1000000))
228        ...    try:
229        ...        with push_seed(24):
230        ...            print(randint(0,1000000,3))
231        ...            raise Exception()
232        ...    except Exception:
233        ...        print("Exception raised")
234        ...    print(randint(0,1000000))
235        242082
236        [242082    899 211136]
237        Exception raised
238        899
239    """
240    def __init__(self, seed=None):
241        # type: (Optional[int]) -> None
242        self._state = np.random.get_state()
243        np.random.seed(seed)
244
245    def __enter__(self):
246        # type: () -> None
247        pass
248
249    def __exit__(self, exc_type, exc_value, trace):
250        # type: (Any, BaseException, Any) -> None
251        np.random.set_state(self._state)
252
253def tic():
254    # type: () -> Callable[[], float]
255    """
256    Timer function.
257
258    Use "toc=tic()" to start the clock and "toc()" to measure
259    a time interval.
260    """
261    then = datetime.datetime.now()
262    return lambda: delay(datetime.datetime.now() - then)
263
264
265def set_beam_stop(data, radius, outer=None):
266    # type: (Data, float, float) -> None
267    """
268    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
269    """
270    if hasattr(data, 'qx_data'):
271        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
272        data.mask = (q < radius)
273        if outer is not None:
274            data.mask |= (q >= outer)
275    else:
276        data.mask = (data.x < radius)
277        if outer is not None:
278            data.mask |= (data.x >= outer)
279
280
281def parameter_range(p, v):
282    # type: (str, float) -> Tuple[float, float]
283    """
284    Choose a parameter range based on parameter name and initial value.
285    """
286    # process the polydispersity options
287    if p.endswith('_pd_n'):
288        return 0., 100.
289    elif p.endswith('_pd_nsigma'):
290        return 0., 5.
291    elif p.endswith('_pd_type'):
292        raise ValueError("Cannot return a range for a string value")
293    elif any(s in p for s in ('theta', 'phi', 'psi')):
294        # orientation in [-180,180], orientation pd in [0,45]
295        if p.endswith('_pd'):
296            return 0., 180.
297        else:
298            return -180., 180.
299    elif p.endswith('_pd'):
300        return 0., 1.
301    elif 'sld' in p:
302        return -0.5, 10.
303    elif p == 'background':
304        return 0., 10.
305    elif p == 'scale':
306        return 0., 1.e3
307    elif v < 0.:
308        return 2.*v, -2.*v
309    else:
310        return 0., (2.*v if v > 0. else 1.)
311
312
313def _randomize_one(model_info, name, value):
314    # type: (ModelInfo, str, float) -> float
315    # type: (ModelInfo, str, str) -> str
316    """
317    Randomize a single parameter.
318    """
319    # Set the amount of polydispersity/angular dispersion, but by default pd_n
320    # is zero so there is no polydispersity.  This allows us to turn on/off
321    # pd by setting pd_n, and still have randomly generated values
322    if name.endswith('_pd'):
323        par = model_info.parameters[name[:-3]]
324        if par.type == 'orientation':
325            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
326            return 180*np.random.beta(2.5, 20)
327        else:
328            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
329            return np.random.beta(1.5, 7)
330
331    # pd is selected globally rather than per parameter, so set to 0 for no pd
332    # In particular, when multiple pd dimensions, want to decrease the number
333    # of points per dimension for faster computation
334    if name.endswith('_pd_n'):
335        return 0
336
337    # Don't mess with distribution type for now
338    if name.endswith('_pd_type'):
339        return 'gaussian'
340
341    # type-dependent value of number of sigmas; for gaussian use 3.
342    if name.endswith('_pd_nsigma'):
343        return 3.
344
345    # background in the range [0.01, 1]
346    if name == 'background':
347        return 10**np.random.uniform(-2, 0)
348
349    # scale defaults to 0.1% to 30% volume fraction
350    if name == 'scale':
351        return 10**np.random.uniform(-3, -0.5)
352
353    # If it is a list of choices, pick one at random with equal probability
354    par = model_info.parameters[name]
355    if par.choices:  # choice list
356        return np.random.randint(len(par.choices))
357
358    # If it is a fixed range, pick from it with equal probability.
359    # For logarithmic ranges, the model will have to override.
360    if np.isfinite(par.limits).all():
361        return np.random.uniform(*par.limits)
362
363    # If the paramter is marked as an sld use the range of neutron slds
364    # TODO: ought to randomly contrast match a pair of SLDs
365    if par.type == 'sld':
366        return np.random.uniform(-0.5, 12)
367
368    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
369    if par.name.startswith('M0:'):
370        return np.random.uniform(0, 5)
371
372    # Guess at the random length/radius/thickness.  In practice, all models
373    # are going to set their own reasonable ranges.
374    if par.type == 'volume':
375        if ('length' in par.name or
376                'radius' in par.name or
377                'thick' in par.name):
378            return 10**np.random.uniform(2, 4)
379
380    # In the absence of any other info, select a value in [0, 2v], or
381    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
382    # model random parameter generators will override this default.
383    low, high = parameter_range(par.name, value)
384    limits = (max(par.limits[0], low), min(par.limits[1], high))
385    return np.random.uniform(*limits)
386
387def _random_pd(model_info, pars):
388    # type: (ModelInfo, Dict[str, float]) -> None
389    """
390    Generate a random dispersity distribution for the model.
391
392    1% no shape dispersity
393    85% single shape parameter
394    13% two shape parameters
395    1% three shape parameters
396
397    If oriented, then put dispersity in theta, add phi and psi dispersity
398    with 10% probability for each.
399    """
400    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
401    pd_volume = []
402    pd_oriented = []
403    for p in pd:
404        if p.type == 'orientation':
405            pd_oriented.append(p.name)
406        elif p.length_control is not None:
407            n = int(pars.get(p.length_control, 1) + 0.5)
408            pd_volume.extend(p.name+str(k+1) for k in range(n))
409        elif p.length > 1:
410            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
411        else:
412            pd_volume.append(p.name)
413    u = np.random.rand()
414    n = len(pd_volume)
415    if u < 0.01 or n < 1:
416        pass  # 1% chance of no polydispersity
417    elif u < 0.86 or n < 2:
418        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
419    elif u < 0.99 or n < 3:
420        choices = np.random.choice(len(pd_volume), size=2)
421        pars[pd_volume[choices[0]]+"_pd_n"] = 25
422        pars[pd_volume[choices[1]]+"_pd_n"] = 10
423    else:
424        choices = np.random.choice(len(pd_volume), size=3)
425        pars[pd_volume[choices[0]]+"_pd_n"] = 25
426        pars[pd_volume[choices[1]]+"_pd_n"] = 10
427        pars[pd_volume[choices[2]]+"_pd_n"] = 5
428    if pd_oriented:
429        pars['theta_pd_n'] = 20
430        if np.random.rand() < 0.1:
431            pars['phi_pd_n'] = 5
432        if np.random.rand() < 0.1:
433            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
434                #print("generating psi_pd_n")
435                pars['psi_pd_n'] = 5
436
437    ## Show selected polydispersity
438    #for name, value in pars.items():
439    #    if name.endswith('_pd_n') and value > 0:
440    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
441
442
443def randomize_pars(model_info, pars):
444    # type: (ModelInfo, ParameterSet) -> ParameterSet
445    """
446    Generate random values for all of the parameters.
447
448    Valid ranges for the random number generator are guessed from the name of
449    the parameter; this will not account for constraints such as cap radius
450    greater than cylinder radius in the capped_cylinder model, so
451    :func:`constrain_pars` needs to be called afterward..
452    """
453    # Note: the sort guarantees order of calls to random number generator
454    random_pars = dict((p, _randomize_one(model_info, p, v))
455                       for p, v in sorted(pars.items()))
456    if model_info.random is not None:
457        random_pars.update(model_info.random())
458    _random_pd(model_info, random_pars)
459    return random_pars
460
461
462def limit_dimensions(model_info, pars, maxdim):
463    # type: (ModelInfo, ParameterSet, float) -> None
464    """
465    Limit parameters of units of Ang to maxdim.
466    """
467    for p in model_info.parameters.call_parameters:
468        value = pars[p.name]
469        if p.units == 'Ang' and value > maxdim:
470            pars[p.name] = maxdim*10**np.random.uniform(-3, 0)
471
472def constrain_pars(model_info, pars):
473    # type: (ModelInfo, ParameterSet) -> None
474    """
475    Restrict parameters to valid values.
476
477    This includes model specific code for models such as capped_cylinder
478    which need to support within model constraints (cap radius more than
479    cylinder radius in this case).
480
481    Warning: this updates the *pars* dictionary in place.
482    """
483    # TODO: move the model specific code to the individual models
484    name = model_info.id
485    # if it is a product model, then just look at the form factor since
486    # none of the structure factors need any constraints.
487    if '*' in name:
488        name = name.split('*')[0]
489
490    # Suppress magnetism for python models (not yet implemented)
491    if callable(model_info.Iq):
492        pars.update(suppress_magnetism(pars))
493
494    if name == 'barbell':
495        if pars['radius_bell'] < pars['radius']:
496            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
497
498    elif name == 'capped_cylinder':
499        if pars['radius_cap'] < pars['radius']:
500            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
501
502    elif name == 'guinier':
503        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
504        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
505        # => ln A - (Rg^2 q^2/3) > -30 ln 10
506        # => Rg^2 q^2/3 < 30 ln 10 + ln A
507        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
508        #q_max = 0.2  # mid q maximum
509        q_max = 1.0  # high q maximum
510        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
511        pars['rg'] = min(pars['rg'], rg_max)
512
513    elif name == 'pearl_necklace':
514        if pars['radius'] < pars['thick_string']:
515            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
516
517    elif name == 'rpa':
518        # Make sure phi sums to 1.0
519        if pars['case_num'] < 2:
520            pars['Phi1'] = 0.
521            pars['Phi2'] = 0.
522        elif pars['case_num'] < 5:
523            pars['Phi1'] = 0.
524        total = sum(pars['Phi'+c] for c in '1234')
525        for c in '1234':
526            pars['Phi'+c] /= total
527
528def parlist(model_info, pars, is2d):
529    # type: (ModelInfo, ParameterSet, bool) -> str
530    """
531    Format the parameter list for printing.
532    """
533    is2d = True
534    lines = []
535    parameters = model_info.parameters
536    magnetic = False
537    magnetic_pars = []
538    for p in parameters.user_parameters(pars, is2d):
539        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
540            continue
541        if p.id.startswith('up:'):
542            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
543            continue
544        fields = dict(
545            value=pars.get(p.id, p.default),
546            pd=pars.get(p.id+"_pd", 0.),
547            n=int(pars.get(p.id+"_pd_n", 0)),
548            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
549            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
550            relative_pd=p.relative_pd,
551            M0=pars.get('M0:'+p.id, 0.),
552            mphi=pars.get('mphi:'+p.id, 0.),
553            mtheta=pars.get('mtheta:'+p.id, 0.),
554        )
555        lines.append(_format_par(p.name, **fields))
556        magnetic = magnetic or fields['M0'] != 0.
557    if magnetic and magnetic_pars:
558        lines.append(" ".join(magnetic_pars))
559    return "\n".join(lines)
560
561    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
562
563def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
564                relative_pd=False, M0=0., mphi=0., mtheta=0.):
565    # type: (str, float, float, int, float, str) -> str
566    line = "%s: %g"%(name, value)
567    if pd != 0.  and n != 0:
568        if relative_pd:
569            pd *= value
570        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
571                % (pd, n, nsigma, nsigma, pdtype)
572    if M0 != 0.:
573        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
574    return line
575
576def suppress_pd(pars, suppress=True):
577    # type: (ParameterSet) -> ParameterSet
578    """
579    If suppress is True complete eliminate polydispersity of the model to test
580    models more quickly.  If suppress is False, make sure at least one
581    parameter is polydisperse, setting the first polydispersity parameter to
582    15% if no polydispersity is given (with no explicit demo parameters given
583    in the model, there will be no default polydispersity).
584    """
585    pars = pars.copy()
586    #print("pars=", pars)
587    if suppress:
588        for p in pars:
589            if p.endswith("_pd_n"):
590                pars[p] = 0
591    else:
592        any_pd = False
593        first_pd = None
594        for p in pars:
595            if p.endswith("_pd_n"):
596                pd = pars.get(p[:-2], 0.)
597                any_pd |= (pars[p] != 0 and pd != 0.)
598                if first_pd is None:
599                    first_pd = p
600        if not any_pd and first_pd is not None:
601            if pars[first_pd] == 0:
602                pars[first_pd] = 35
603            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
604                pars[first_pd[:-2]] = 0.15
605    return pars
606
607def suppress_magnetism(pars, suppress=True):
608    # type: (ParameterSet) -> ParameterSet
609    """
610    If suppress is True complete eliminate magnetism of the model to test
611    models more quickly.  If suppress is False, make sure at least one sld
612    parameter is magnetic, setting the first parameter to have a strong
613    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
614    in the model, there will be no default magnetism).
615    """
616    pars = pars.copy()
617    if suppress:
618        for p in pars:
619            if p.startswith("M0:"):
620                pars[p] = 0
621    else:
622        any_mag = False
623        first_mag = None
624        for p in pars:
625            if p.startswith("M0:"):
626                any_mag |= (pars[p] != 0)
627                if first_mag is None:
628                    first_mag = p
629        if not any_mag and first_mag is not None:
630            pars[first_mag] = 8.
631    return pars
632
633
634def time_calculation(calculator, pars, evals=1):
635    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
636    """
637    Compute the average calculation time over N evaluations.
638
639    An additional call is generated without polydispersity in order to
640    initialize the calculation engine, and make the average more stable.
641    """
642    # initialize the code so time is more accurate
643    if evals > 1:
644        calculator(**suppress_pd(pars))
645    toc = tic()
646    # make sure there is at least one eval
647    value = calculator(**pars)
648    for _ in range(evals-1):
649        value = calculator(**pars)
650    average_time = toc()*1000. / evals
651    #print("I(q)",value)
652    return value, average_time
653
654def make_data(opts):
655    # type: (Dict[str, Any], float) -> Tuple[Data, np.ndarray]
656    """
657    Generate an empty dataset, used with the model to set Q points
658    and resolution.
659
660    *opts* contains the options, with 'qmax', 'nq', 'res',
661    'accuracy', 'is2d' and 'view' parsed from the command line.
662    """
663    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
664    if opts['is2d']:
665        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
666        data = empty_data2D(q, resolution=res)
667        data.accuracy = opts['accuracy']
668        set_beam_stop(data, qmin)
669        index = ~data.mask
670    else:
671        if opts['view'] == 'log' and not opts['zero']:
672            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
673        else:
674            q = np.linspace(qmin, qmax, nq)
675        if opts['zero']:
676            q = np.hstack((0, q))
677        # TODO: provide command line control of lambda and Delta lambda/lambda
678        #L, dLoL = 5, 0.14/np.sqrt(6)  # wavelength and 14% triangular FWHM
679        L, dLoL = 0, 0
680        data = empty_data1D(q, resolution=res, L=L, dL=L*dLoL)
681        index = slice(None, None)
682    return data, index
683
684DTYPE_MAP = {
685    'half': '16',
686    'fast': 'fast',
687    'single': '32',
688    'double': '64',
689    'quad': '128',
690    'f16': '16',
691    'f32': '32',
692    'f64': '64',
693    'float16': '16',
694    'float32': '32',
695    'float64': '64',
696    'float128': '128',
697    'longdouble': '128',
698}
699def eval_opencl(model_info, data, dtype='single', cutoff=0.):
700    # type: (ModelInfo, Data, str, float) -> Calculator
701    """
702    Return a model calculator using the OpenCL calculation engine.
703    """
704
705def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
706    # type: (ModelInfo, Data, str, float) -> Calculator
707    """
708    Return a model calculator using the DLL calculation engine.
709    """
710    model = core.build_model(model_info, dtype=dtype, platform="dll")
711    calculator = DirectModel(data, model, cutoff=cutoff)
712    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
713    return calculator
714
715def make_engine(model_info, data, dtype, cutoff, ngauss=0):
716    # type: (ModelInfo, Data, str, float) -> Calculator
717    """
718    Generate the appropriate calculation engine for the given datatype.
719
720    Datatypes with '!' appended are evaluated using external C DLLs rather
721    than OpenCL.
722    """
723    if ngauss:
724        set_integration_size(model_info, ngauss)
725
726    if dtype != "default" and not dtype.endswith('!') and not kernelcl.use_opencl():
727        raise RuntimeError("OpenCL not available " + kernelcl.OPENCL_ERROR)
728
729    model = core.build_model(model_info, dtype=dtype, platform="ocl")
730    calculator = DirectModel(data, model, cutoff=cutoff)
731    engine_type = calculator._model.__class__.__name__.replace('Model', '').upper()
732    bits = calculator._model.dtype.itemsize*8
733    precision = "fast" if getattr(calculator._model, 'fast', False) else str(bits)
734    calculator.engine = "%s[%s]" % (engine_type, precision)
735    return calculator
736
737def _show_invalid(data, theory):
738    # type: (Data, np.ma.ndarray) -> None
739    """
740    Display a list of the non-finite values in theory.
741    """
742    if not theory.mask.any():
743        return
744
745    if hasattr(data, 'x'):
746        bad = zip(data.x[theory.mask], theory[theory.mask])
747        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
748
749
750def compare(opts, limits=None, maxdim=np.inf):
751    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
752    """
753    Preform a comparison using options from the command line.
754
755    *limits* are the limits on the values to use, either to set the y-axis
756    for 1D or to set the colormap scale for 2D.  If None, then they are
757    inferred from the data and returned. When exploring using Bumps,
758    the limits are set when the model is initially called, and maintained
759    as the values are adjusted, making it easier to see the effects of the
760    parameters.
761
762    *maxdim* is the maximum value for any parameter with units of Angstrom.
763    """
764    for k in range(opts['sets']):
765        if k > 0:
766            # print a separate seed for each dataset for better reproducibility
767            new_seed = np.random.randint(1000000)
768            print("=== Set %d uses -random=%i ==="%(k+1, new_seed))
769            np.random.seed(new_seed)
770        opts['pars'] = parse_pars(opts, maxdim=maxdim)
771        if opts['pars'] is None:
772            return
773        result = run_models(opts, verbose=True)
774        if opts['plot']:
775            if opts['is2d'] and k > 0:
776                import matplotlib.pyplot as plt
777                plt.figure()
778            limits = plot_models(opts, result, limits=limits, setnum=k)
779        if opts['show_weights']:
780            base, _ = opts['engines']
781            base_pars, _ = opts['pars']
782            model_info = base._kernel.info
783            dim = base._kernel.dim
784            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
785        if opts['show_profile']:
786            import pylab
787            base, comp = opts['engines']
788            base_pars, comp_pars = opts['pars']
789            have_base = base._kernel.info.profile is not None
790            have_comp = (
791                comp is not None
792                and comp._kernel.info.profile is not None
793                and base_pars != comp_pars
794            )
795            if have_base or have_comp:
796                pylab.figure()
797            if have_base:
798                plot_profile(base._kernel.info, **base_pars)
799            if have_comp:
800                plot_profile(comp._kernel.info, label='comp', **comp_pars)
801                pylab.legend()
802    if opts['plot']:
803        import matplotlib.pyplot as plt
804        plt.show()
805    return limits
806
807def plot_profile(model_info, label='base', **args):
808    # type: (ModelInfo, List[Tuple[float, np.ndarray, np.ndarray]]) -> None
809    """
810    Plot the profile returned by the model profile method.
811
812    *model_info* defines model parameters, etc.
813
814    *mesh* is a list of tuples containing (*value*, *dispersity*, *weights*)
815    for each parameter, where (*dispersity*, *weights*) pairs are the
816    distributions to be plotted.
817    """
818    import pylab
819
820    args = dict((k, v) for k, v in args.items()
821                if "_pd" not in k
822                   and ":" not in k
823                   and k not in ("background", "scale", "theta", "phi", "psi"))
824    args = args.copy()
825
826    args.pop('scale', 1.)
827    args.pop('background', 0.)
828    z, rho = model_info.profile(**args)
829    #pylab.interactive(True)
830    pylab.plot(z, rho, '-', label=label)
831    pylab.grid(True)
832    #pylab.show()
833
834
835
836def run_models(opts, verbose=False):
837    # type: (Dict[str, Any]) -> Dict[str, Any]
838    """
839    Process a parameter set, return calculation results and times.
840    """
841
842    base, comp = opts['engines']
843    base_n, comp_n = opts['count']
844    base_pars, comp_pars = opts['pars']
845    base_data, comp_data = opts['data']
846
847    comparison = comp is not None
848
849    base_time = comp_time = None
850    base_value = comp_value = resid = relerr = None
851
852    # Base calculation
853    try:
854        base_raw, base_time = time_calculation(base, base_pars, base_n)
855        base_value = np.ma.masked_invalid(base_raw)
856        if verbose:
857            print("%s t=%.2f ms, intensity=%.0f"
858                  % (base.engine, base_time, base_value.sum()))
859        _show_invalid(base_data, base_value)
860    except ImportError:
861        traceback.print_exc()
862
863    # Comparison calculation
864    if comparison:
865        try:
866            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
867            comp_value = np.ma.masked_invalid(comp_raw)
868            if verbose:
869                print("%s t=%.2f ms, intensity=%.0f"
870                      % (comp.engine, comp_time, comp_value.sum()))
871            _show_invalid(base_data, comp_value)
872        except ImportError:
873            traceback.print_exc()
874
875    # Compare, but only if computing both forms
876    if comparison:
877        resid = (base_value - comp_value)
878        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
879        if verbose:
880            _print_stats("|%s-%s|"
881                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
882                         resid)
883            _print_stats("|(%s-%s)/%s|"
884                         % (base.engine, comp.engine, comp.engine),
885                         relerr)
886
887    return dict(base_value=base_value, comp_value=comp_value,
888                base_time=base_time, comp_time=comp_time,
889                resid=resid, relerr=relerr)
890
891
892def _print_stats(label, err):
893    # type: (str, np.ma.ndarray) -> None
894    # work with trimmed data, not the full set
895    sorted_err = np.sort(abs(err.compressed()))
896    if len(sorted_err) == 0:
897        print(label + "  no valid values")
898        return
899
900    p50 = int((len(sorted_err)-1)*0.50)
901    p98 = int((len(sorted_err)-1)*0.98)
902    data = [
903        "max:%.3e"%sorted_err[-1],
904        "median:%.3e"%sorted_err[p50],
905        "98%%:%.3e"%sorted_err[p98],
906        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
907        "zero-offset:%+.3e"%np.mean(sorted_err),
908        ]
909    print(label+"  "+"  ".join(data))
910
911
912def plot_models(opts, result, limits=None, setnum=0):
913    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
914    """
915    Plot the results from :func:`run_model`.
916    """
917    import matplotlib.pyplot as plt
918
919    base_value, comp_value = result['base_value'], result['comp_value']
920    base_time, comp_time = result['base_time'], result['comp_time']
921    resid, relerr = result['resid'], result['relerr']
922
923    have_base, have_comp = (base_value is not None), (comp_value is not None)
924    base, comp = opts['engines']
925    base_data, comp_data = opts['data']
926    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
927
928    # Plot if requested
929    view = opts['view']
930    #view = 'log'
931    if limits is None:
932        vmin, vmax = np.inf, -np.inf
933        if have_base:
934            vmin = min(vmin, base_value.min())
935            vmax = max(vmax, base_value.max())
936        if have_comp:
937            vmin = min(vmin, comp_value.min())
938            vmax = max(vmax, comp_value.max())
939        limits = vmin, vmax
940
941    if have_base:
942        if have_comp:
943            plt.subplot(131)
944        plot_theory(base_data, base_value, view=view, use_data=use_data, limits=limits)
945        plt.title("%s t=%.2f ms"%(base.engine, base_time))
946        #cbar_title = "log I"
947    if have_comp:
948        if have_base:
949            plt.subplot(132)
950        if not opts['is2d'] and have_base:
951            plot_theory(comp_data, base_value, view=view, use_data=use_data, limits=limits)
952        plot_theory(comp_data, comp_value, view=view, use_data=use_data, limits=limits)
953        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
954        #cbar_title = "log I"
955    if have_base and have_comp:
956        plt.subplot(133)
957        if not opts['rel_err']:
958            err, errstr, errview = resid, "abs err", "linear"
959        else:
960            err, errstr, errview = abs(relerr), "rel err", "log"
961            if (err == 0.).all():
962                errview = 'linear'
963        if 0:  # 95% cutoff
964            sorted_err = np.sort(err.flatten())
965            cutoff = sorted_err[int(sorted_err.size*0.95)]
966            err[err > cutoff] = cutoff
967        #err,errstr = base/comp,"ratio"
968        # Note: base_data only since base and comp have same q values (though
969        # perhaps different resolution), and we are plotting the difference
970        # at each q
971        plot_theory(base_data, None, resid=err, view=errview, use_data=use_data)
972        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
973        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
974        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
975        #cbar_title = errstr if errview=="linear" else "log "+errstr
976    #if is2D:
977    #    h = plt.colorbar()
978    #    h.ax.set_title(cbar_title)
979    fig = plt.gcf()
980    extra_title = ' '+opts['title'] if opts['title'] else ''
981    fig.suptitle(":".join(opts['name']) + extra_title)
982
983    if have_base and have_comp and opts['show_hist']:
984        plt.figure()
985        v = relerr
986        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
987        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
988        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
989                   % (base.engine, comp.engine, comp.engine))
990        plt.ylabel('P(err)')
991        plt.title('Distribution of relative error between calculation engines')
992
993    return limits
994
995
996# ===========================================================================
997#
998
999# Set of command line options.
1000# Normal options such as -plot/-noplot are specified as 'name'.
1001# For options such as -nq=500 which require a value use 'name='.
1002#
1003OPTIONS = [
1004    # Plotting
1005    'plot', 'noplot',
1006    'weights', 'profile',
1007    'linear', 'log', 'q4',
1008    'rel', 'abs',
1009    'hist', 'nohist',
1010    'title=',
1011
1012    # Data generation
1013    'data=', 'noise=', 'res=', 'nq=', 'q=',
1014    'lowq', 'midq', 'highq', 'exq', 'zero',
1015    '2d', '1d',
1016
1017    # Parameter set
1018    'preset', 'random', 'random=', 'sets=',
1019    'demo', 'default',  # TODO: remove demo/default
1020    'nopars', 'pars',
1021    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
1022
1023    # Calculation options
1024    'poly', 'mono', 'cutoff=',
1025    'magnetic', 'nonmagnetic',
1026    'accuracy=', 'ngauss=',
1027    'neval=',  # for timing...
1028
1029    # Precision options
1030    'engine=',
1031    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1032
1033    # Output options
1034    'help', 'html', 'edit',
1035    ]
1036
1037NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))()
1038VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])()
1039
1040
1041def columnize(items, indent="", width=79):
1042    # type: (List[str], str, int) -> str
1043    """
1044    Format a list of strings into columns.
1045
1046    Returns a string with carriage returns ready for printing.
1047    """
1048    column_width = max(len(w) for w in items) + 1
1049    num_columns = (width - len(indent)) // column_width
1050    num_rows = len(items) // num_columns
1051    items = items + [""] * (num_rows * num_columns - len(items))
1052    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1053    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1054             for row in zip(*columns)]
1055    output = indent + ("\n"+indent).join(lines)
1056    return output
1057
1058
1059def get_pars(model_info, use_demo=False):
1060    # type: (ModelInfo, bool) -> ParameterSet
1061    """
1062    Extract demo parameters from the model definition.
1063    """
1064    # Get the default values for the parameters
1065    pars = {}
1066    for p in model_info.parameters.call_parameters:
1067        parts = [('', p.default)]
1068        if p.polydisperse:
1069            parts.append(('_pd', 0.0))
1070            parts.append(('_pd_n', 0))
1071            parts.append(('_pd_nsigma', 3.0))
1072            parts.append(('_pd_type', "gaussian"))
1073        for ext, val in parts:
1074            if p.length > 1:
1075                dict(("%s%d%s" % (p.id, k, ext), val)
1076                     for k in range(1, p.length+1))
1077            else:
1078                pars[p.id + ext] = val
1079
1080    # Plug in values given in demo
1081    if use_demo and model_info.demo:
1082        pars.update(model_info.demo)
1083    return pars
1084
1085INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1086def isnumber(s):
1087    # type: (str) -> bool
1088    """Return True if string contains an int or float"""
1089    match = FLOAT_RE.match(s)
1090    isfloat = (match and not s[match.end():])
1091    return isfloat or INTEGER_RE.match(s)
1092
1093# For distinguishing pairs of models for comparison
1094# key-value pair separator =
1095# shell characters  | & ; <> $ % ' " \ # `
1096# model and parameter names _
1097# parameter expressions - + * / . ( )
1098# path characters including tilde expansion and windows drive ~ / :
1099# not sure about brackets [] {}
1100# maybe one of the following @ ? ^ ! ,
1101PAR_SPLIT = ','
1102def parse_opts(argv):
1103    # type: (List[str]) -> Dict[str, Any]
1104    """
1105    Parse command line options.
1106    """
1107    MODELS = core.list_models()
1108    flags = [arg for arg in argv
1109             if arg.startswith('-')]
1110    values = [arg for arg in argv
1111              if not arg.startswith('-') and '=' in arg]
1112    positional_args = [arg for arg in argv
1113                       if not arg.startswith('-') and '=' not in arg]
1114    models = "\n    ".join("%-15s"%v for v in MODELS)
1115    if len(positional_args) == 0:
1116        print(USAGE)
1117        print("\nAvailable models:")
1118        print(columnize(MODELS, indent="  "))
1119        return None
1120
1121    invalid = [o[1:] for o in flags
1122               if not (o[1:] in NAME_OPTIONS
1123                       or any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)
1124                       or o.startswith('-D'))]
1125    if invalid:
1126        print("Invalid options: %s"%(", ".join(invalid)))
1127        return None
1128
1129    name = positional_args[-1]
1130
1131    # pylint: disable=bad-whitespace,C0321
1132    # Interpret the flags
1133    opts = {
1134        'plot'      : True,
1135        'view'      : 'log',
1136        'is2d'      : False,
1137        'qmin'      : None,
1138        'qmax'      : 0.05,
1139        'nq'        : 128,
1140        'res'       : '0.0',
1141        'noise'     : 0.0,
1142        'accuracy'  : 'Low',
1143        'cutoff'    : '0.0',
1144        'seed'      : -1,  # default to preset
1145        'mono'      : True,
1146        # Default to magnetic a magnetic moment is set on the command line
1147        'magnetic'  : False,
1148        'show_pars' : False,
1149        'show_hist' : False,
1150        'rel_err'   : True,
1151        'explore'   : False,
1152        'use_demo'  : True,
1153        'zero'      : False,
1154        'html'      : False,
1155        'title'     : None,
1156        'datafile'  : None,
1157        'sets'      : 0,
1158        'engine'    : 'default',
1159        'count'     : '1',
1160        'show_weights' : False,
1161        'show_profile' : False,
1162        'sphere'    : 0,
1163        'ngauss'    : '0',
1164    }
1165    for arg in flags:
1166        if arg == '-noplot':    opts['plot'] = False
1167        elif arg == '-plot':    opts['plot'] = True
1168        elif arg == '-linear':  opts['view'] = 'linear'
1169        elif arg == '-log':     opts['view'] = 'log'
1170        elif arg == '-q4':      opts['view'] = 'q4'
1171        elif arg == '-1d':      opts['is2d'] = False
1172        elif arg == '-2d':      opts['is2d'] = True
1173        elif arg == '-exq':     opts['qmax'] = 10.0
1174        elif arg == '-highq':   opts['qmax'] = 1.0
1175        elif arg == '-midq':    opts['qmax'] = 0.2
1176        elif arg == '-lowq':    opts['qmax'] = 0.05
1177        elif arg == '-zero':    opts['zero'] = True
1178        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1179        elif arg.startswith('-q='):
1180            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1181        elif arg.startswith('-res='):      opts['res'] = arg[5:]
1182        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1183        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1184        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1185        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1186        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1187        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1188        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1189        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1190        elif arg.startswith('-ngauss='):   opts['ngauss'] = arg[8:]
1191        elif arg.startswith('-random='):
1192            opts['seed'] = int(arg[8:])
1193            opts['sets'] = 0
1194        elif arg == '-random':
1195            opts['seed'] = np.random.randint(1000000)
1196            opts['sets'] = 0
1197        elif arg.startswith('-sphere'):
1198            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1199            opts['is2d'] = True
1200        elif arg == '-preset':  opts['seed'] = -1
1201        elif arg == '-mono':    opts['mono'] = True
1202        elif arg == '-poly':    opts['mono'] = False
1203        elif arg == '-magnetic':       opts['magnetic'] = True
1204        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1205        elif arg == '-pars':    opts['show_pars'] = True
1206        elif arg == '-nopars':  opts['show_pars'] = False
1207        elif arg == '-hist':    opts['show_hist'] = True
1208        elif arg == '-nohist':  opts['show_hist'] = False
1209        elif arg == '-rel':     opts['rel_err'] = True
1210        elif arg == '-abs':     opts['rel_err'] = False
1211        elif arg == '-half':    opts['engine'] = 'half'
1212        elif arg == '-fast':    opts['engine'] = 'fast'
1213        elif arg == '-single':  opts['engine'] = 'single'
1214        elif arg == '-double':  opts['engine'] = 'double'
1215        elif arg == '-single!': opts['engine'] = 'single!'
1216        elif arg == '-double!': opts['engine'] = 'double!'
1217        elif arg == '-quad!':   opts['engine'] = 'quad!'
1218        elif arg == '-edit':    opts['explore'] = True
1219        elif arg == '-demo':    opts['use_demo'] = True
1220        elif arg == '-default': opts['use_demo'] = False
1221        elif arg == '-weights': opts['show_weights'] = True
1222        elif arg == '-profile': opts['show_profile'] = True
1223        elif arg == '-html':    opts['html'] = True
1224        elif arg == '-help':    opts['html'] = True
1225        elif arg.startswith('-D'):
1226            var, val = arg[2:].split('=')
1227            os.environ[var] = val
1228    # pylint: enable=bad-whitespace,C0321
1229
1230    # Magnetism forces 2D for now
1231    if opts['magnetic']:
1232        opts['is2d'] = True
1233
1234    # Force random if sets is used
1235    if opts['sets'] >= 1 and opts['seed'] < 0:
1236        opts['seed'] = np.random.randint(1000000)
1237    if opts['sets'] == 0:
1238        opts['sets'] = 1
1239
1240    # Create the computational engines
1241    if opts['qmin'] is None:
1242        opts['qmin'] = 0.001*opts['qmax']
1243
1244    comparison = any(PAR_SPLIT in v for v in values)
1245
1246    if PAR_SPLIT in name:
1247        names = name.split(PAR_SPLIT, 2)
1248        comparison = True
1249    else:
1250        names = [name]*2
1251    try:
1252        model_info = [core.load_model_info(k) for k in names]
1253    except ImportError as exc:
1254        print(str(exc))
1255        print("Could not find model; use one of:\n    " + models)
1256        return None
1257
1258    if PAR_SPLIT in opts['ngauss']:
1259        opts['ngauss'] = [int(k) for k in opts['ngauss'].split(PAR_SPLIT, 2)]
1260        comparison = True
1261    else:
1262        opts['ngauss'] = [int(opts['ngauss'])]*2
1263
1264    if PAR_SPLIT in opts['engine']:
1265        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1266        comparison = True
1267    else:
1268        opts['engine'] = [opts['engine']]*2
1269
1270    if PAR_SPLIT in opts['count']:
1271        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1272        comparison = True
1273    else:
1274        opts['count'] = [int(opts['count'])]*2
1275
1276    if PAR_SPLIT in opts['cutoff']:
1277        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1278        comparison = True
1279    else:
1280        opts['cutoff'] = [float(opts['cutoff'])]*2
1281
1282    if PAR_SPLIT in opts['res']:
1283        opts['res'] = [float(k) for k in opts['res'].split(PAR_SPLIT, 2)]
1284        comparison = True
1285    else:
1286        opts['res'] = [float(opts['res'])]*2
1287
1288    if opts['datafile'] is not None:
1289        data0 = load_data(os.path.expanduser(opts['datafile']))
1290        data = data0, data0
1291    else:
1292        # Hack around the fact that make_data doesn't take a pair of resolutions
1293        res = opts['res']
1294        opts['res'] = res[0]
1295        data0, _ = make_data(opts)
1296        if res[0] != res[1]:
1297            opts['res'] = res[1]
1298            data1, _ = make_data(opts)
1299        else:
1300            data1 = data0
1301        opts['res'] = res
1302        data = data0, data1
1303
1304    base = make_engine(model_info[0], data[0], opts['engine'][0],
1305                       opts['cutoff'][0], opts['ngauss'][0])
1306    if comparison:
1307        comp = make_engine(model_info[1], data[1], opts['engine'][1],
1308                           opts['cutoff'][1], opts['ngauss'][1])
1309    else:
1310        comp = None
1311
1312    # pylint: disable=bad-whitespace
1313    # Remember it all
1314    opts.update({
1315        'data'      : data,
1316        'name'      : names,
1317        'info'      : model_info,
1318        'engines'   : [base, comp],
1319        'values'    : values,
1320    })
1321    # pylint: enable=bad-whitespace
1322
1323    # Set the integration parameters to the half sphere
1324    if opts['sphere'] > 0:
1325        set_spherical_integration_parameters(opts, opts['sphere'])
1326
1327    return opts
1328
1329def set_spherical_integration_parameters(opts, steps):
1330    # type: (Dict[str, Any], int) -> None
1331    """
1332    Set integration parameters for spherical integration over the entire
1333    surface in theta-phi coordinates.
1334    """
1335    # Set the integration parameters to the half sphere
1336    opts['values'].extend([
1337        #'theta=90',
1338        'theta_pd=%g'%(90/np.sqrt(3)),
1339        'theta_pd_n=%d'%steps,
1340        'theta_pd_type=rectangle',
1341        #'phi=0',
1342        'phi_pd=%g'%(180/np.sqrt(3)),
1343        'phi_pd_n=%d'%(2*steps),
1344        'phi_pd_type=rectangle',
1345        #'background=0',
1346    ])
1347    if 'psi' in opts['info'][0].parameters:
1348        opts['values'].extend([
1349            #'psi=0',
1350            'psi_pd=%g'%(180/np.sqrt(3)),
1351            'psi_pd_n=%d'%(2*steps),
1352            'psi_pd_type=rectangle',
1353        ])
1354
1355def parse_pars(opts, maxdim=np.inf):
1356    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]]
1357    """
1358    Generate a parameter set.
1359
1360    The default values come from the model, or a randomized model if a seed
1361    value is given.  Next, evaluate any parameter expressions, constraining
1362    the value of the parameter within and between models.  If *maxdim* is
1363    given, limit parameters with units of Angstrom to this value.
1364
1365    Returns a pair of parameter dictionaries for base and comparison models.
1366    """
1367    model_info, model_info2 = opts['info']
1368
1369    # Get demo parameters from model definition, or use default parameters
1370    # if model does not define demo parameters
1371    pars = get_pars(model_info, opts['use_demo'])
1372    pars2 = get_pars(model_info2, opts['use_demo'])
1373    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1374    # randomize parameters
1375    #pars.update(set_pars)  # set value before random to control range
1376    if opts['seed'] > -1:
1377        pars = randomize_pars(model_info, pars)
1378        limit_dimensions(model_info, pars, maxdim)
1379        if model_info != model_info2:
1380            pars2 = randomize_pars(model_info2, pars2)
1381            limit_dimensions(model_info2, pars2, maxdim)
1382            # Share values for parameters with the same name
1383            for k, v in pars.items():
1384                if k in pars2:
1385                    pars2[k] = v
1386        else:
1387            pars2 = pars.copy()
1388        constrain_pars(model_info, pars)
1389        constrain_pars(model_info2, pars2)
1390    pars = suppress_pd(pars, opts['mono'])
1391    pars2 = suppress_pd(pars2, opts['mono'])
1392    pars = suppress_magnetism(pars, not opts['magnetic'])
1393    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1394
1395    # Fill in parameters given on the command line
1396    presets = {}
1397    presets2 = {}
1398    for arg in opts['values']:
1399        k, v = arg.split('=', 1)
1400        if k not in pars and k not in pars2:
1401            # extract base name without polydispersity info
1402            s = set(p.split('_pd')[0] for p in pars)
1403            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1404            return None
1405        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
1406        if v1 and k in pars:
1407            presets[k] = float(v1) if isnumber(v1) else v1
1408        if v2 and k in pars2:
1409            presets2[k] = float(v2) if isnumber(v2) else v2
1410
1411    # If pd given on the command line, default pd_n to 35
1412    for k, v in list(presets.items()):
1413        if k.endswith('_pd'):
1414            presets.setdefault(k+'_n', 35.)
1415    for k, v in list(presets2.items()):
1416        if k.endswith('_pd'):
1417            presets2.setdefault(k+'_n', 35.)
1418
1419    # Evaluate preset parameter expressions
1420    # Note: need to replace ':' with '_' in parameter names and expressions
1421    # in order to support math on magnetic parameters.
1422    context = MATH.copy()
1423    context['np'] = np
1424    context.update((k.replace(':', '_'), v) for k, v in pars.items())
1425    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1426    #for k,v in sorted(context.items()): print(k, v)
1427    for k, v in presets.items():
1428        if not isinstance(v, float) and not k.endswith('_type'):
1429            presets[k] = eval(v.replace(':', '_'), context)
1430    context.update(presets)
1431    context.update((k.replace(':', '_'), v) for k, v in presets2.items() if isinstance(v, float))
1432    for k, v in presets2.items():
1433        if not isinstance(v, float) and not k.endswith('_type'):
1434            presets2[k] = eval(v.replace(':', '_'), context)
1435
1436    # update parameters with presets
1437    pars.update(presets)  # set value after random to control value
1438    pars2.update(presets2)  # set value after random to control value
1439    #import pprint; pprint.pprint(model_info)
1440
1441    if opts['show_pars']:
1442        if model_info.name != model_info2.name or pars != pars2:
1443            print("==== %s ====="%model_info.name)
1444            print(str(parlist(model_info, pars, opts['is2d'])))
1445            print("==== %s ====="%model_info2.name)
1446            print(str(parlist(model_info2, pars2, opts['is2d'])))
1447        else:
1448            print(str(parlist(model_info, pars, opts['is2d'])))
1449
1450    return pars, pars2
1451
1452def show_docs(opts):
1453    # type: (Dict[str, Any]) -> None
1454    """
1455    show html docs for the model
1456    """
1457    from .generate import make_html
1458    from . import rst2html
1459
1460    info = opts['info'][0]
1461    html = make_html(info)
1462    path = os.path.dirname(info.filename)
1463    url = "file://" + path.replace("\\", "/")[2:] + "/"
1464    rst2html.view_html_wxapp(html, url)
1465
1466def explore(opts):
1467    # type: (Dict[str, Any]) -> None
1468    """
1469    explore the model using the bumps gui.
1470    """
1471    import wx  # type: ignore
1472    from bumps.names import FitProblem  # type: ignore
1473    from bumps.gui.app_frame import AppFrame  # type: ignore
1474    from bumps.gui import signal
1475
1476    is_mac = "cocoa" in wx.version()
1477    # Create an app if not running embedded
1478    app = wx.App() if wx.GetApp() is None else None
1479    model = Explore(opts)
1480    problem = FitProblem(model)
1481    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1482    if not is_mac:
1483        frame.Show()
1484    frame.panel.set_model(model=problem)
1485    frame.panel.Layout()
1486    frame.panel.aui.Split(0, wx.TOP)
1487    def _reset_parameters(event):
1488        model.revert_values()
1489        signal.update_parameters(problem)
1490    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1))
1491    if is_mac:
1492        frame.Show()
1493    # If running withing an app, start the main loop
1494    if app:
1495        app.MainLoop()
1496
1497class Explore(object):
1498    """
1499    Bumps wrapper for a SAS model comparison.
1500
1501    The resulting object can be used as a Bumps fit problem so that
1502    parameters can be adjusted in the GUI, with plots updated on the fly.
1503    """
1504    def __init__(self, opts):
1505        # type: (Dict[str, Any]) -> None
1506        from bumps.cli import config_matplotlib  # type: ignore
1507        from . import bumps_model
1508        config_matplotlib()
1509        self.opts = opts
1510        opts['pars'] = list(opts['pars'])
1511        p1, p2 = opts['pars']
1512        m1, m2 = opts['info']
1513        self.fix_p2 = m1 != m2 or p1 != p2
1514        model_info = m1
1515        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1516        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1517        if not opts['is2d']:
1518            for p in model_info.parameters.user_parameters({}, is2d=False):
1519                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1520                    k = p.name+ext
1521                    v = pars.get(k, None)
1522                    if v is not None:
1523                        v.range(*parameter_range(k, v.value))
1524        else:
1525            for k, v in pars.items():
1526                v.range(*parameter_range(k, v.value))
1527
1528        self.pars = pars
1529        self.starting_values = dict((k, v.value) for k, v in pars.items())
1530        self.pd_types = pd_types
1531        self.limits = None
1532
1533    def revert_values(self):
1534        # type: () -> None
1535        """
1536        Restore starting values of the parameters.
1537        """
1538        for k, v in self.starting_values.items():
1539            self.pars[k].value = v
1540
1541    def model_update(self):
1542        # type: () -> None
1543        """
1544        Respond to signal that model parameters have been changed.
1545        """
1546        pass
1547
1548    def numpoints(self):
1549        # type: () -> int
1550        """
1551        Return the number of points.
1552        """
1553        return len(self.pars) + 1  # so dof is 1
1554
1555    def parameters(self):
1556        # type: () -> Any   # Dict/List hierarchy of parameters
1557        """
1558        Return a dictionary of parameters.
1559        """
1560        return self.pars
1561
1562    def nllf(self):
1563        # type: () -> float
1564        """
1565        Return cost.
1566        """
1567        # pylint: disable=no-self-use
1568        return 0.  # No nllf
1569
1570    def plot(self, view='log'):
1571        # type: (str) -> None
1572        """
1573        Plot the data and residuals.
1574        """
1575        pars = dict((k, v.value) for k, v in self.pars.items())
1576        pars.update(self.pd_types)
1577        self.opts['pars'][0] = pars
1578        if not self.fix_p2:
1579            self.opts['pars'][1] = pars
1580        result = run_models(self.opts)
1581        limits = plot_models(self.opts, result, limits=self.limits)
1582        if self.limits is None:
1583            vmin, vmax = limits
1584            self.limits = vmax*1e-7, 1.3*vmax
1585            import pylab
1586            pylab.clf()
1587            plot_models(self.opts, result, limits=self.limits)
1588
1589
1590def main(*argv):
1591    # type: (*str) -> None
1592    """
1593    Main program.
1594    """
1595    opts = parse_opts(argv)
1596    if opts is not None:
1597        if opts['seed'] > -1:
1598            print("Randomize using -random=%i"%opts['seed'])
1599            np.random.seed(opts['seed'])
1600        if opts['html']:
1601            show_docs(opts)
1602        elif opts['explore']:
1603            opts['pars'] = parse_pars(opts)
1604            if opts['pars'] is None:
1605                return
1606            explore(opts)
1607        else:
1608            compare(opts)
1609
1610if __name__ == "__main__":
1611    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.