source: sasmodels/sasmodels/compare.py @ df0d2ca

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since df0d2ca was 1198f90, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

limit pinhole resolution integral to ± 2.5 dq

  • Property mode set to 100755
File size: 53.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import kernelcl
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .generate import FLOAT_RE, set_integration_size
46from .weights import plot_weights
47
48# pylint: disable=unused-import
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except ImportError:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57# pylint: enable=unused-import
58
59USAGE = """
60usage: sascomp model [options...] [key=val]
61
62Generate and compare SAS models.  If a single model is specified it shows
63a plot of that model.  Different models can be compared, or the same model
64with different parameters.  The same model with the same parameters can
65be compared with different calculation engines to see the effects of precision
66on the resultant values.
67
68model or model1,model2 are the names of the models to compare (see below).
69
70Options (* for default):
71
72    === data generation ===
73    -data="path" uses q, dq from the data file
74    -noise=0 sets the measurement error dI/I
75    -res=0 sets the resolution width dQ/Q if calculating with resolution
76    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
77    -q=min:max alternative specification of qrange
78    -nq=128 sets the number of Q points in the data set
79    -1d*/-2d computes 1d or 2d data
80    -zero indicates that q=0 should be included
81
82    === model parameters ===
83    -preset*/-random[=seed] preset or random parameters
84    -sets=n generates n random datasets with the seed given by -random=seed
85    -pars/-nopars* prints the parameter set or not
86    -default/-demo* use demo vs default parameters
87    -sphere[=150] set up spherical integration over theta/phi using n points
88
89    === calculation options ===
90    -mono*/-poly force monodisperse or allow polydisperse random parameters
91    -cutoff=1e-5* cutoff value for including a point in polydispersity
92    -magnetic/-nonmagnetic* suppress magnetism
93    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
94    -neval=1 sets the number of evals for more accurate timing
95    -ngauss=0 overrides the number of points in the 1-D gaussian quadrature
96
97    === precision options ===
98    -engine=default uses the default calcution precision
99    -single/-double/-half/-fast sets an OpenCL calculation engine
100    -single!/-double!/-quad! sets an OpenMP calculation engine
101
102    === plotting ===
103    -plot*/-noplot plots or suppress the plot of the model
104    -linear/-log*/-q4 intensity scaling on plots
105    -hist/-nohist* plot histogram of relative error
106    -abs/-rel* plot relative or absolute error
107    -title="note" adds note to the plot title, after the model name
108    -weights shows weights plots for the polydisperse parameters
109
110    === output options ===
111    -edit starts the parameter explorer
112    -help/-html shows the model docs instead of running the model
113
114The interpretation of quad precision depends on architecture, and may
115vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
116On unix and mac you may need single quotes around the DLL computation
117engines, such as -engine='single!,double!' since !, is treated as a history
118expansion request in the shell.
119
120Key=value pairs allow you to set specific values for the model parameters.
121Key=value1,value2 to compare different values of the same parameter. The
122value can be an expression including other parameters.
123
124Items later on the command line override those that appear earlier.
125
126Examples:
127
128    # compare single and double precision calculation for a barbell
129    sascomp barbell -engine=single,double
130
131    # generate 10 random lorentz models, with seed=27
132    sascomp lorentz -sets=10 -seed=27
133
134    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
135    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
136
137    # model timing test requires multiple evals to perform the estimate
138    sascomp pringle -engine=single,double -timing=100,100 -noplot
139"""
140
141# Update docs with command line usage string.   This is separate from the usual
142# doc string so that we can display it at run time if there is an error.
143# lin
144__doc__ = (__doc__  # pylint: disable=redefined-builtin
145           + """
146Program description
147-------------------
148
149""" + USAGE)
150
151kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
152
153def build_math_context():
154    # type: () -> Dict[str, Callable]
155    """build dictionary of functions from math module"""
156    return dict((k, getattr(math, k))
157                for k in dir(math) if not k.startswith('_'))
158
159#: list of math functions for use in evaluating parameters
160MATH = build_math_context()
161
162# CRUFT python 2.6
163if not hasattr(datetime.timedelta, 'total_seconds'):
164    def delay(dt):
165        """Return number date-time delta as number seconds"""
166        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
167else:
168    def delay(dt):
169        """Return number date-time delta as number seconds"""
170        return dt.total_seconds()
171
172
173class push_seed(object):
174    """
175    Set the seed value for the random number generator.
176
177    When used in a with statement, the random number generator state is
178    restored after the with statement is complete.
179
180    :Parameters:
181
182    *seed* : int or array_like, optional
183        Seed for RandomState
184
185    :Example:
186
187    Seed can be used directly to set the seed::
188
189        >>> from numpy.random import randint
190        >>> push_seed(24)
191        <...push_seed object at...>
192        >>> print(randint(0,1000000,3))
193        [242082    899 211136]
194
195    Seed can also be used in a with statement, which sets the random
196    number generator state for the enclosed computations and restores
197    it to the previous state on completion::
198
199        >>> with push_seed(24):
200        ...    print(randint(0,1000000,3))
201        [242082    899 211136]
202
203    Using nested contexts, we can demonstrate that state is indeed
204    restored after the block completes::
205
206        >>> with push_seed(24):
207        ...    print(randint(0,1000000))
208        ...    with push_seed(24):
209        ...        print(randint(0,1000000,3))
210        ...    print(randint(0,1000000))
211        242082
212        [242082    899 211136]
213        899
214
215    The restore step is protected against exceptions in the block::
216
217        >>> with push_seed(24):
218        ...    print(randint(0,1000000))
219        ...    try:
220        ...        with push_seed(24):
221        ...            print(randint(0,1000000,3))
222        ...            raise Exception()
223        ...    except Exception:
224        ...        print("Exception raised")
225        ...    print(randint(0,1000000))
226        242082
227        [242082    899 211136]
228        Exception raised
229        899
230    """
231    def __init__(self, seed=None):
232        # type: (Optional[int]) -> None
233        self._state = np.random.get_state()
234        np.random.seed(seed)
235
236    def __enter__(self):
237        # type: () -> None
238        pass
239
240    def __exit__(self, exc_type, exc_value, trace):
241        # type: (Any, BaseException, Any) -> None
242        np.random.set_state(self._state)
243
244def tic():
245    # type: () -> Callable[[], float]
246    """
247    Timer function.
248
249    Use "toc=tic()" to start the clock and "toc()" to measure
250    a time interval.
251    """
252    then = datetime.datetime.now()
253    return lambda: delay(datetime.datetime.now() - then)
254
255
256def set_beam_stop(data, radius, outer=None):
257    # type: (Data, float, float) -> None
258    """
259    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
260    """
261    if hasattr(data, 'qx_data'):
262        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
263        data.mask = (q < radius)
264        if outer is not None:
265            data.mask |= (q >= outer)
266    else:
267        data.mask = (data.x < radius)
268        if outer is not None:
269            data.mask |= (data.x >= outer)
270
271
272def parameter_range(p, v):
273    # type: (str, float) -> Tuple[float, float]
274    """
275    Choose a parameter range based on parameter name and initial value.
276    """
277    # process the polydispersity options
278    if p.endswith('_pd_n'):
279        return 0., 100.
280    elif p.endswith('_pd_nsigma'):
281        return 0., 5.
282    elif p.endswith('_pd_type'):
283        raise ValueError("Cannot return a range for a string value")
284    elif any(s in p for s in ('theta', 'phi', 'psi')):
285        # orientation in [-180,180], orientation pd in [0,45]
286        if p.endswith('_pd'):
287            return 0., 180.
288        else:
289            return -180., 180.
290    elif p.endswith('_pd'):
291        return 0., 1.
292    elif 'sld' in p:
293        return -0.5, 10.
294    elif p == 'background':
295        return 0., 10.
296    elif p == 'scale':
297        return 0., 1.e3
298    elif v < 0.:
299        return 2.*v, -2.*v
300    else:
301        return 0., (2.*v if v > 0. else 1.)
302
303
304def _randomize_one(model_info, name, value):
305    # type: (ModelInfo, str, float) -> float
306    # type: (ModelInfo, str, str) -> str
307    """
308    Randomize a single parameter.
309    """
310    # Set the amount of polydispersity/angular dispersion, but by default pd_n
311    # is zero so there is no polydispersity.  This allows us to turn on/off
312    # pd by setting pd_n, and still have randomly generated values
313    if name.endswith('_pd'):
314        par = model_info.parameters[name[:-3]]
315        if par.type == 'orientation':
316            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
317            return 180*np.random.beta(2.5, 20)
318        else:
319            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
320            return np.random.beta(1.5, 7)
321
322    # pd is selected globally rather than per parameter, so set to 0 for no pd
323    # In particular, when multiple pd dimensions, want to decrease the number
324    # of points per dimension for faster computation
325    if name.endswith('_pd_n'):
326        return 0
327
328    # Don't mess with distribution type for now
329    if name.endswith('_pd_type'):
330        return 'gaussian'
331
332    # type-dependent value of number of sigmas; for gaussian use 3.
333    if name.endswith('_pd_nsigma'):
334        return 3.
335
336    # background in the range [0.01, 1]
337    if name == 'background':
338        return 10**np.random.uniform(-2, 0)
339
340    # scale defaults to 0.1% to 30% volume fraction
341    if name == 'scale':
342        return 10**np.random.uniform(-3, -0.5)
343
344    # If it is a list of choices, pick one at random with equal probability
345    # In practice, the model specific random generator will override.
346    par = model_info.parameters[name]
347    if len(par.limits) > 2:  # choice list
348        return np.random.randint(len(par.limits))
349
350    # If it is a fixed range, pick from it with equal probability.
351    # For logarithmic ranges, the model will have to override.
352    if np.isfinite(par.limits).all():
353        return np.random.uniform(*par.limits)
354
355    # If the paramter is marked as an sld use the range of neutron slds
356    # TODO: ought to randomly contrast match a pair of SLDs
357    if par.type == 'sld':
358        return np.random.uniform(-0.5, 12)
359
360    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
361    if par.name.startswith('M0:'):
362        return np.random.uniform(0, 5)
363
364    # Guess at the random length/radius/thickness.  In practice, all models
365    # are going to set their own reasonable ranges.
366    if par.type == 'volume':
367        if ('length' in par.name or
368                'radius' in par.name or
369                'thick' in par.name):
370            return 10**np.random.uniform(2, 4)
371
372    # In the absence of any other info, select a value in [0, 2v], or
373    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
374    # model random parameter generators will override this default.
375    low, high = parameter_range(par.name, value)
376    limits = (max(par.limits[0], low), min(par.limits[1], high))
377    return np.random.uniform(*limits)
378
379def _random_pd(model_info, pars):
380    # type: (ModelInfo, Dict[str, float]) -> None
381    """
382    Generate a random dispersity distribution for the model.
383
384    1% no shape dispersity
385    85% single shape parameter
386    13% two shape parameters
387    1% three shape parameters
388
389    If oriented, then put dispersity in theta, add phi and psi dispersity
390    with 10% probability for each.
391    """
392    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
393    pd_volume = []
394    pd_oriented = []
395    for p in pd:
396        if p.type == 'orientation':
397            pd_oriented.append(p.name)
398        elif p.length_control is not None:
399            n = int(pars.get(p.length_control, 1) + 0.5)
400            pd_volume.extend(p.name+str(k+1) for k in range(n))
401        elif p.length > 1:
402            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
403        else:
404            pd_volume.append(p.name)
405    u = np.random.rand()
406    n = len(pd_volume)
407    if u < 0.01 or n < 1:
408        pass  # 1% chance of no polydispersity
409    elif u < 0.86 or n < 2:
410        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
411    elif u < 0.99 or n < 3:
412        choices = np.random.choice(len(pd_volume), size=2)
413        pars[pd_volume[choices[0]]+"_pd_n"] = 25
414        pars[pd_volume[choices[1]]+"_pd_n"] = 10
415    else:
416        choices = np.random.choice(len(pd_volume), size=3)
417        pars[pd_volume[choices[0]]+"_pd_n"] = 25
418        pars[pd_volume[choices[1]]+"_pd_n"] = 10
419        pars[pd_volume[choices[2]]+"_pd_n"] = 5
420    if pd_oriented:
421        pars['theta_pd_n'] = 20
422        if np.random.rand() < 0.1:
423            pars['phi_pd_n'] = 5
424        if np.random.rand() < 0.1:
425            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
426                #print("generating psi_pd_n")
427                pars['psi_pd_n'] = 5
428
429    ## Show selected polydispersity
430    #for name, value in pars.items():
431    #    if name.endswith('_pd_n') and value > 0:
432    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
433
434
435def randomize_pars(model_info, pars):
436    # type: (ModelInfo, ParameterSet) -> ParameterSet
437    """
438    Generate random values for all of the parameters.
439
440    Valid ranges for the random number generator are guessed from the name of
441    the parameter; this will not account for constraints such as cap radius
442    greater than cylinder radius in the capped_cylinder model, so
443    :func:`constrain_pars` needs to be called afterward..
444    """
445    # Note: the sort guarantees order of calls to random number generator
446    random_pars = dict((p, _randomize_one(model_info, p, v))
447                       for p, v in sorted(pars.items()))
448    if model_info.random is not None:
449        random_pars.update(model_info.random())
450    _random_pd(model_info, random_pars)
451    return random_pars
452
453
454def limit_dimensions(model_info, pars, maxdim):
455    # type: (ModelInfo, ParameterSet, float) -> None
456    """
457    Limit parameters of units of Ang to maxdim.
458    """
459    for p in model_info.parameters.call_parameters:
460        value = pars[p.name]
461        if p.units == 'Ang' and value > maxdim:
462            pars[p.name] = maxdim*10**np.random.uniform(-3, 0)
463
464def constrain_pars(model_info, pars):
465    # type: (ModelInfo, ParameterSet) -> None
466    """
467    Restrict parameters to valid values.
468
469    This includes model specific code for models such as capped_cylinder
470    which need to support within model constraints (cap radius more than
471    cylinder radius in this case).
472
473    Warning: this updates the *pars* dictionary in place.
474    """
475    # TODO: move the model specific code to the individual models
476    name = model_info.id
477    # if it is a product model, then just look at the form factor since
478    # none of the structure factors need any constraints.
479    if '*' in name:
480        name = name.split('*')[0]
481
482    # Suppress magnetism for python models (not yet implemented)
483    if callable(model_info.Iq):
484        pars.update(suppress_magnetism(pars))
485
486    if name == 'barbell':
487        if pars['radius_bell'] < pars['radius']:
488            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
489
490    elif name == 'capped_cylinder':
491        if pars['radius_cap'] < pars['radius']:
492            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
493
494    elif name == 'guinier':
495        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
496        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
497        # => ln A - (Rg^2 q^2/3) > -30 ln 10
498        # => Rg^2 q^2/3 < 30 ln 10 + ln A
499        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
500        #q_max = 0.2  # mid q maximum
501        q_max = 1.0  # high q maximum
502        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
503        pars['rg'] = min(pars['rg'], rg_max)
504
505    elif name == 'pearl_necklace':
506        if pars['radius'] < pars['thick_string']:
507            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
508
509    elif name == 'rpa':
510        # Make sure phi sums to 1.0
511        if pars['case_num'] < 2:
512            pars['Phi1'] = 0.
513            pars['Phi2'] = 0.
514        elif pars['case_num'] < 5:
515            pars['Phi1'] = 0.
516        total = sum(pars['Phi'+c] for c in '1234')
517        for c in '1234':
518            pars['Phi'+c] /= total
519
520def parlist(model_info, pars, is2d):
521    # type: (ModelInfo, ParameterSet, bool) -> str
522    """
523    Format the parameter list for printing.
524    """
525    is2d = True
526    lines = []
527    parameters = model_info.parameters
528    magnetic = False
529    magnetic_pars = []
530    for p in parameters.user_parameters(pars, is2d):
531        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
532            continue
533        if p.id.startswith('up:'):
534            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
535            continue
536        fields = dict(
537            value=pars.get(p.id, p.default),
538            pd=pars.get(p.id+"_pd", 0.),
539            n=int(pars.get(p.id+"_pd_n", 0)),
540            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
541            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
542            relative_pd=p.relative_pd,
543            M0=pars.get('M0:'+p.id, 0.),
544            mphi=pars.get('mphi:'+p.id, 0.),
545            mtheta=pars.get('mtheta:'+p.id, 0.),
546        )
547        lines.append(_format_par(p.name, **fields))
548        magnetic = magnetic or fields['M0'] != 0.
549    if magnetic and magnetic_pars:
550        lines.append(" ".join(magnetic_pars))
551    return "\n".join(lines)
552
553    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
554
555def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
556                relative_pd=False, M0=0., mphi=0., mtheta=0.):
557    # type: (str, float, float, int, float, str) -> str
558    line = "%s: %g"%(name, value)
559    if pd != 0.  and n != 0:
560        if relative_pd:
561            pd *= value
562        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
563                % (pd, n, nsigma, nsigma, pdtype)
564    if M0 != 0.:
565        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
566    return line
567
568def suppress_pd(pars, suppress=True):
569    # type: (ParameterSet) -> ParameterSet
570    """
571    If suppress is True complete eliminate polydispersity of the model to test
572    models more quickly.  If suppress is False, make sure at least one
573    parameter is polydisperse, setting the first polydispersity parameter to
574    15% if no polydispersity is given (with no explicit demo parameters given
575    in the model, there will be no default polydispersity).
576    """
577    pars = pars.copy()
578    #print("pars=", pars)
579    if suppress:
580        for p in pars:
581            if p.endswith("_pd_n"):
582                pars[p] = 0
583    else:
584        any_pd = False
585        first_pd = None
586        for p in pars:
587            if p.endswith("_pd_n"):
588                pd = pars.get(p[:-2], 0.)
589                any_pd |= (pars[p] != 0 and pd != 0.)
590                if first_pd is None:
591                    first_pd = p
592        if not any_pd and first_pd is not None:
593            if pars[first_pd] == 0:
594                pars[first_pd] = 35
595            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
596                pars[first_pd[:-2]] = 0.15
597    return pars
598
599def suppress_magnetism(pars, suppress=True):
600    # type: (ParameterSet) -> ParameterSet
601    """
602    If suppress is True complete eliminate magnetism of the model to test
603    models more quickly.  If suppress is False, make sure at least one sld
604    parameter is magnetic, setting the first parameter to have a strong
605    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
606    in the model, there will be no default magnetism).
607    """
608    pars = pars.copy()
609    if suppress:
610        for p in pars:
611            if p.startswith("M0:"):
612                pars[p] = 0
613    else:
614        any_mag = False
615        first_mag = None
616        for p in pars:
617            if p.startswith("M0:"):
618                any_mag |= (pars[p] != 0)
619                if first_mag is None:
620                    first_mag = p
621        if not any_mag and first_mag is not None:
622            pars[first_mag] = 8.
623    return pars
624
625
626def time_calculation(calculator, pars, evals=1):
627    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
628    """
629    Compute the average calculation time over N evaluations.
630
631    An additional call is generated without polydispersity in order to
632    initialize the calculation engine, and make the average more stable.
633    """
634    # initialize the code so time is more accurate
635    if evals > 1:
636        calculator(**suppress_pd(pars))
637    toc = tic()
638    # make sure there is at least one eval
639    value = calculator(**pars)
640    for _ in range(evals-1):
641        value = calculator(**pars)
642    average_time = toc()*1000. / evals
643    #print("I(q)",value)
644    return value, average_time
645
646def make_data(opts):
647    # type: (Dict[str, Any], float) -> Tuple[Data, np.ndarray]
648    """
649    Generate an empty dataset, used with the model to set Q points
650    and resolution.
651
652    *opts* contains the options, with 'qmax', 'nq', 'res',
653    'accuracy', 'is2d' and 'view' parsed from the command line.
654    """
655    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
656    if opts['is2d']:
657        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
658        data = empty_data2D(q, resolution=res)
659        data.accuracy = opts['accuracy']
660        set_beam_stop(data, qmin)
661        index = ~data.mask
662    else:
663        if opts['view'] == 'log' and not opts['zero']:
664            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
665        else:
666            q = np.linspace(qmin, qmax, nq)
667        if opts['zero']:
668            q = np.hstack((0, q))
669        data = empty_data1D(q, resolution=res)
670        index = slice(None, None)
671    return data, index
672
673DTYPE_MAP = {
674    'half': '16',
675    'fast': 'fast',
676    'single': '32',
677    'double': '64',
678    'quad': '128',
679    'f16': '16',
680    'f32': '32',
681    'f64': '64',
682    'float16': '16',
683    'float32': '32',
684    'float64': '64',
685    'float128': '128',
686    'longdouble': '128',
687}
688def eval_opencl(model_info, data, dtype='single', cutoff=0.):
689    # type: (ModelInfo, Data, str, float) -> Calculator
690    """
691    Return a model calculator using the OpenCL calculation engine.
692    """
693
694def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
695    # type: (ModelInfo, Data, str, float) -> Calculator
696    """
697    Return a model calculator using the DLL calculation engine.
698    """
699    model = core.build_model(model_info, dtype=dtype, platform="dll")
700    calculator = DirectModel(data, model, cutoff=cutoff)
701    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
702    return calculator
703
704def make_engine(model_info, data, dtype, cutoff, ngauss=0):
705    # type: (ModelInfo, Data, str, float) -> Calculator
706    """
707    Generate the appropriate calculation engine for the given datatype.
708
709    Datatypes with '!' appended are evaluated using external C DLLs rather
710    than OpenCL.
711    """
712    if ngauss:
713        set_integration_size(model_info, ngauss)
714
715    if dtype != "default" and not dtype.endswith('!') and not kernelcl.use_opencl():
716        raise RuntimeError("OpenCL not available " + kernelcl.OPENCL_ERROR)
717
718    model = core.build_model(model_info, dtype=dtype, platform="ocl")
719    calculator = DirectModel(data, model, cutoff=cutoff)
720    engine_type = calculator._model.__class__.__name__.replace('Model', '').upper()
721    bits = calculator._model.dtype.itemsize*8
722    precision = "fast" if getattr(calculator._model, 'fast', False) else str(bits)
723    calculator.engine = "%s[%s]" % (engine_type, precision)
724    return calculator
725
726def _show_invalid(data, theory):
727    # type: (Data, np.ma.ndarray) -> None
728    """
729    Display a list of the non-finite values in theory.
730    """
731    if not theory.mask.any():
732        return
733
734    if hasattr(data, 'x'):
735        bad = zip(data.x[theory.mask], theory[theory.mask])
736        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
737
738
739def compare(opts, limits=None, maxdim=np.inf):
740    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
741    """
742    Preform a comparison using options from the command line.
743
744    *limits* are the limits on the values to use, either to set the y-axis
745    for 1D or to set the colormap scale for 2D.  If None, then they are
746    inferred from the data and returned. When exploring using Bumps,
747    the limits are set when the model is initially called, and maintained
748    as the values are adjusted, making it easier to see the effects of the
749    parameters.
750
751    *maxdim* is the maximum value for any parameter with units of Angstrom.
752    """
753    for k in range(opts['sets']):
754        if k > 0:
755            # print a separate seed for each dataset for better reproducibility
756            new_seed = np.random.randint(1000000)
757            print("=== Set %d uses -random=%i ==="%(k+1, new_seed))
758            np.random.seed(new_seed)
759        opts['pars'] = parse_pars(opts, maxdim=maxdim)
760        if opts['pars'] is None:
761            return
762        result = run_models(opts, verbose=True)
763        if opts['plot']:
764            if opts['is2d'] and k > 0:
765                import matplotlib.pyplot as plt
766                plt.figure()
767            limits = plot_models(opts, result, limits=limits, setnum=k)
768        if opts['show_weights']:
769            base, _ = opts['engines']
770            base_pars, _ = opts['pars']
771            model_info = base._kernel.info
772            dim = base._kernel.dim
773            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
774    if opts['plot']:
775        import matplotlib.pyplot as plt
776        plt.show()
777    return limits
778
779def run_models(opts, verbose=False):
780    # type: (Dict[str, Any]) -> Dict[str, Any]
781    """
782    Process a parameter set, return calculation results and times.
783    """
784
785    base, comp = opts['engines']
786    base_n, comp_n = opts['count']
787    base_pars, comp_pars = opts['pars']
788    base_data, comp_data = opts['data']
789
790    comparison = comp is not None
791
792    base_time = comp_time = None
793    base_value = comp_value = resid = relerr = None
794
795    # Base calculation
796    try:
797        base_raw, base_time = time_calculation(base, base_pars, base_n)
798        base_value = np.ma.masked_invalid(base_raw)
799        if verbose:
800            print("%s t=%.2f ms, intensity=%.0f"
801                  % (base.engine, base_time, base_value.sum()))
802        _show_invalid(base_data, base_value)
803    except ImportError:
804        traceback.print_exc()
805
806    # Comparison calculation
807    if comparison:
808        try:
809            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
810            comp_value = np.ma.masked_invalid(comp_raw)
811            if verbose:
812                print("%s t=%.2f ms, intensity=%.0f"
813                      % (comp.engine, comp_time, comp_value.sum()))
814            _show_invalid(base_data, comp_value)
815        except ImportError:
816            traceback.print_exc()
817
818    # Compare, but only if computing both forms
819    if comparison:
820        resid = (base_value - comp_value)
821        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
822        if verbose:
823            _print_stats("|%s-%s|"
824                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
825                         resid)
826            _print_stats("|(%s-%s)/%s|"
827                         % (base.engine, comp.engine, comp.engine),
828                         relerr)
829
830    return dict(base_value=base_value, comp_value=comp_value,
831                base_time=base_time, comp_time=comp_time,
832                resid=resid, relerr=relerr)
833
834
835def _print_stats(label, err):
836    # type: (str, np.ma.ndarray) -> None
837    # work with trimmed data, not the full set
838    sorted_err = np.sort(abs(err.compressed()))
839    if len(sorted_err) == 0:
840        print(label + "  no valid values")
841        return
842
843    p50 = int((len(sorted_err)-1)*0.50)
844    p98 = int((len(sorted_err)-1)*0.98)
845    data = [
846        "max:%.3e"%sorted_err[-1],
847        "median:%.3e"%sorted_err[p50],
848        "98%%:%.3e"%sorted_err[p98],
849        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
850        "zero-offset:%+.3e"%np.mean(sorted_err),
851        ]
852    print(label+"  "+"  ".join(data))
853
854
855def plot_models(opts, result, limits=None, setnum=0):
856    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
857    """
858    Plot the results from :func:`run_model`.
859    """
860    import matplotlib.pyplot as plt
861
862    base_value, comp_value = result['base_value'], result['comp_value']
863    base_time, comp_time = result['base_time'], result['comp_time']
864    resid, relerr = result['resid'], result['relerr']
865
866    have_base, have_comp = (base_value is not None), (comp_value is not None)
867    base, comp = opts['engines']
868    base_data, comp_data = opts['data']
869    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
870
871    # Plot if requested
872    view = opts['view']
873    if limits is None:
874        vmin, vmax = np.inf, -np.inf
875        if have_base:
876            vmin = min(vmin, base_value.min())
877            vmax = max(vmax, base_value.max())
878        if have_comp:
879            vmin = min(vmin, comp_value.min())
880            vmax = max(vmax, comp_value.max())
881        limits = vmin, vmax
882
883    if have_base:
884        if have_comp:
885            plt.subplot(131)
886        plot_theory(base_data, base_value, view=view, use_data=use_data, limits=limits)
887        plt.title("%s t=%.2f ms"%(base.engine, base_time))
888        #cbar_title = "log I"
889    if have_comp:
890        if have_base:
891            plt.subplot(132)
892        if not opts['is2d'] and have_base:
893            plot_theory(comp_data, base_value, view=view, use_data=use_data, limits=limits)
894        plot_theory(comp_data, comp_value, view=view, use_data=use_data, limits=limits)
895        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
896        #cbar_title = "log I"
897    if have_base and have_comp:
898        plt.subplot(133)
899        if not opts['rel_err']:
900            err, errstr, errview = resid, "abs err", "linear"
901        else:
902            err, errstr, errview = abs(relerr), "rel err", "log"
903            if (err == 0.).all():
904                errview = 'linear'
905        if 0:  # 95% cutoff
906            sorted_err = np.sort(err.flatten())
907            cutoff = sorted_err[int(sorted_err.size*0.95)]
908            err[err > cutoff] = cutoff
909        #err,errstr = base/comp,"ratio"
910        # Note: base_data only since base and comp have same q values (though
911        # perhaps different resolution), and we are plotting the difference
912        # at each q
913        plot_theory(base_data, None, resid=err, view=errview, use_data=use_data)
914        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
915        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
916        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
917        #cbar_title = errstr if errview=="linear" else "log "+errstr
918    #if is2D:
919    #    h = plt.colorbar()
920    #    h.ax.set_title(cbar_title)
921    fig = plt.gcf()
922    extra_title = ' '+opts['title'] if opts['title'] else ''
923    fig.suptitle(":".join(opts['name']) + extra_title)
924
925    if have_base and have_comp and opts['show_hist']:
926        plt.figure()
927        v = relerr
928        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
929        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
930        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
931                   % (base.engine, comp.engine, comp.engine))
932        plt.ylabel('P(err)')
933        plt.title('Distribution of relative error between calculation engines')
934
935    return limits
936
937
938# ===========================================================================
939#
940
941# Set of command line options.
942# Normal options such as -plot/-noplot are specified as 'name'.
943# For options such as -nq=500 which require a value use 'name='.
944#
945OPTIONS = [
946    # Plotting
947    'plot', 'noplot', 'weights',
948    'linear', 'log', 'q4',
949    'rel', 'abs',
950    'hist', 'nohist',
951    'title=',
952
953    # Data generation
954    'data=', 'noise=', 'res=', 'nq=', 'q=',
955    'lowq', 'midq', 'highq', 'exq', 'zero',
956    '2d', '1d',
957
958    # Parameter set
959    'preset', 'random', 'random=', 'sets=',
960    'demo', 'default',  # TODO: remove demo/default
961    'nopars', 'pars',
962    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
963
964    # Calculation options
965    'poly', 'mono', 'cutoff=',
966    'magnetic', 'nonmagnetic',
967    'accuracy=', 'ngauss=',
968    'neval=',  # for timing...
969
970    # Precision options
971    'engine=',
972    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
973
974    # Output options
975    'help', 'html', 'edit',
976    ]
977
978NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))()
979VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])()
980
981
982def columnize(items, indent="", width=79):
983    # type: (List[str], str, int) -> str
984    """
985    Format a list of strings into columns.
986
987    Returns a string with carriage returns ready for printing.
988    """
989    column_width = max(len(w) for w in items) + 1
990    num_columns = (width - len(indent)) // column_width
991    num_rows = len(items) // num_columns
992    items = items + [""] * (num_rows * num_columns - len(items))
993    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
994    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
995             for row in zip(*columns)]
996    output = indent + ("\n"+indent).join(lines)
997    return output
998
999
1000def get_pars(model_info, use_demo=False):
1001    # type: (ModelInfo, bool) -> ParameterSet
1002    """
1003    Extract demo parameters from the model definition.
1004    """
1005    # Get the default values for the parameters
1006    pars = {}
1007    for p in model_info.parameters.call_parameters:
1008        parts = [('', p.default)]
1009        if p.polydisperse:
1010            parts.append(('_pd', 0.0))
1011            parts.append(('_pd_n', 0))
1012            parts.append(('_pd_nsigma', 3.0))
1013            parts.append(('_pd_type', "gaussian"))
1014        for ext, val in parts:
1015            if p.length > 1:
1016                dict(("%s%d%s" % (p.id, k, ext), val)
1017                     for k in range(1, p.length+1))
1018            else:
1019                pars[p.id + ext] = val
1020
1021    # Plug in values given in demo
1022    if use_demo and model_info.demo:
1023        pars.update(model_info.demo)
1024    return pars
1025
1026INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1027def isnumber(s):
1028    # type: (str) -> bool
1029    """Return True if string contains an int or float"""
1030    match = FLOAT_RE.match(s)
1031    isfloat = (match and not s[match.end():])
1032    return isfloat or INTEGER_RE.match(s)
1033
1034# For distinguishing pairs of models for comparison
1035# key-value pair separator =
1036# shell characters  | & ; <> $ % ' " \ # `
1037# model and parameter names _
1038# parameter expressions - + * / . ( )
1039# path characters including tilde expansion and windows drive ~ / :
1040# not sure about brackets [] {}
1041# maybe one of the following @ ? ^ ! ,
1042PAR_SPLIT = ','
1043def parse_opts(argv):
1044    # type: (List[str]) -> Dict[str, Any]
1045    """
1046    Parse command line options.
1047    """
1048    MODELS = core.list_models()
1049    flags = [arg for arg in argv
1050             if arg.startswith('-')]
1051    values = [arg for arg in argv
1052              if not arg.startswith('-') and '=' in arg]
1053    positional_args = [arg for arg in argv
1054                       if not arg.startswith('-') and '=' not in arg]
1055    models = "\n    ".join("%-15s"%v for v in MODELS)
1056    if len(positional_args) == 0:
1057        print(USAGE)
1058        print("\nAvailable models:")
1059        print(columnize(MODELS, indent="  "))
1060        return None
1061
1062    invalid = [o[1:] for o in flags
1063               if o[1:] not in NAME_OPTIONS
1064               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1065    if invalid:
1066        print("Invalid options: %s"%(", ".join(invalid)))
1067        return None
1068
1069    name = positional_args[-1]
1070
1071    # pylint: disable=bad-whitespace,C0321
1072    # Interpret the flags
1073    opts = {
1074        'plot'      : True,
1075        'view'      : 'log',
1076        'is2d'      : False,
1077        'qmin'      : None,
1078        'qmax'      : 0.05,
1079        'nq'        : 128,
1080        'res'       : '0.0',
1081        'noise'     : 0.0,
1082        'accuracy'  : 'Low',
1083        'cutoff'    : '0.0',
1084        'seed'      : -1,  # default to preset
1085        'mono'      : True,
1086        # Default to magnetic a magnetic moment is set on the command line
1087        'magnetic'  : False,
1088        'show_pars' : False,
1089        'show_hist' : False,
1090        'rel_err'   : True,
1091        'explore'   : False,
1092        'use_demo'  : True,
1093        'zero'      : False,
1094        'html'      : False,
1095        'title'     : None,
1096        'datafile'  : None,
1097        'sets'      : 0,
1098        'engine'    : 'default',
1099        'count'     : '1',
1100        'show_weights' : False,
1101        'sphere'    : 0,
1102        'ngauss'    : '0',
1103    }
1104    for arg in flags:
1105        if arg == '-noplot':    opts['plot'] = False
1106        elif arg == '-plot':    opts['plot'] = True
1107        elif arg == '-linear':  opts['view'] = 'linear'
1108        elif arg == '-log':     opts['view'] = 'log'
1109        elif arg == '-q4':      opts['view'] = 'q4'
1110        elif arg == '-1d':      opts['is2d'] = False
1111        elif arg == '-2d':      opts['is2d'] = True
1112        elif arg == '-exq':     opts['qmax'] = 10.0
1113        elif arg == '-highq':   opts['qmax'] = 1.0
1114        elif arg == '-midq':    opts['qmax'] = 0.2
1115        elif arg == '-lowq':    opts['qmax'] = 0.05
1116        elif arg == '-zero':    opts['zero'] = True
1117        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1118        elif arg.startswith('-q='):
1119            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1120        elif arg.startswith('-res='):      opts['res'] = arg[5:]
1121        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1122        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1123        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1124        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1125        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1126        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1127        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1128        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1129        elif arg.startswith('-ngauss='):   opts['ngauss'] = arg[8:]
1130        elif arg.startswith('-random='):
1131            opts['seed'] = int(arg[8:])
1132            opts['sets'] = 0
1133        elif arg == '-random':
1134            opts['seed'] = np.random.randint(1000000)
1135            opts['sets'] = 0
1136        elif arg.startswith('-sphere'):
1137            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1138            opts['is2d'] = True
1139        elif arg == '-preset':  opts['seed'] = -1
1140        elif arg == '-mono':    opts['mono'] = True
1141        elif arg == '-poly':    opts['mono'] = False
1142        elif arg == '-magnetic':       opts['magnetic'] = True
1143        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1144        elif arg == '-pars':    opts['show_pars'] = True
1145        elif arg == '-nopars':  opts['show_pars'] = False
1146        elif arg == '-hist':    opts['show_hist'] = True
1147        elif arg == '-nohist':  opts['show_hist'] = False
1148        elif arg == '-rel':     opts['rel_err'] = True
1149        elif arg == '-abs':     opts['rel_err'] = False
1150        elif arg == '-half':    opts['engine'] = 'half'
1151        elif arg == '-fast':    opts['engine'] = 'fast'
1152        elif arg == '-single':  opts['engine'] = 'single'
1153        elif arg == '-double':  opts['engine'] = 'double'
1154        elif arg == '-single!': opts['engine'] = 'single!'
1155        elif arg == '-double!': opts['engine'] = 'double!'
1156        elif arg == '-quad!':   opts['engine'] = 'quad!'
1157        elif arg == '-edit':    opts['explore'] = True
1158        elif arg == '-demo':    opts['use_demo'] = True
1159        elif arg == '-default': opts['use_demo'] = False
1160        elif arg == '-weights': opts['show_weights'] = True
1161        elif arg == '-html':    opts['html'] = True
1162        elif arg == '-help':    opts['html'] = True
1163    # pylint: enable=bad-whitespace,C0321
1164
1165    # Magnetism forces 2D for now
1166    if opts['magnetic']:
1167        opts['is2d'] = True
1168
1169    # Force random if sets is used
1170    if opts['sets'] >= 1 and opts['seed'] < 0:
1171        opts['seed'] = np.random.randint(1000000)
1172    if opts['sets'] == 0:
1173        opts['sets'] = 1
1174
1175    # Create the computational engines
1176    if opts['qmin'] is None:
1177        opts['qmin'] = 0.001*opts['qmax']
1178
1179    comparison = any(PAR_SPLIT in v for v in values)
1180
1181    if PAR_SPLIT in name:
1182        names = name.split(PAR_SPLIT, 2)
1183        comparison = True
1184    else:
1185        names = [name]*2
1186    try:
1187        model_info = [core.load_model_info(k) for k in names]
1188    except ImportError as exc:
1189        print(str(exc))
1190        print("Could not find model; use one of:\n    " + models)
1191        return None
1192
1193    if PAR_SPLIT in opts['ngauss']:
1194        opts['ngauss'] = [int(k) for k in opts['ngauss'].split(PAR_SPLIT, 2)]
1195        comparison = True
1196    else:
1197        opts['ngauss'] = [int(opts['ngauss'])]*2
1198
1199    if PAR_SPLIT in opts['engine']:
1200        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1201        comparison = True
1202    else:
1203        opts['engine'] = [opts['engine']]*2
1204
1205    if PAR_SPLIT in opts['count']:
1206        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1207        comparison = True
1208    else:
1209        opts['count'] = [int(opts['count'])]*2
1210
1211    if PAR_SPLIT in opts['cutoff']:
1212        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1213        comparison = True
1214    else:
1215        opts['cutoff'] = [float(opts['cutoff'])]*2
1216
1217    if PAR_SPLIT in opts['res']:
1218        opts['res'] = [float(k) for k in opts['res'].split(PAR_SPLIT, 2)]
1219        comparison = True
1220    else:
1221        opts['res'] = [float(opts['res'])]*2
1222
1223    if opts['datafile'] is not None:
1224        data = load_data(os.path.expanduser(opts['datafile']))
1225    else:
1226        # Hack around the fact that make_data doesn't take a pair of resolutions
1227        res = opts['res']
1228        opts['res'] = res[0]
1229        data0, _ = make_data(opts)
1230        if res[0] != res[1]:
1231            opts['res'] = res[1]
1232            data1, _ = make_data(opts)
1233        else:
1234            data1 = data0
1235        opts['res'] = res
1236        data = data0, data1
1237
1238    base = make_engine(model_info[0], data[0], opts['engine'][0],
1239                       opts['cutoff'][0], opts['ngauss'][0])
1240    if comparison:
1241        comp = make_engine(model_info[1], data[1], opts['engine'][1],
1242                           opts['cutoff'][1], opts['ngauss'][1])
1243    else:
1244        comp = None
1245
1246    # pylint: disable=bad-whitespace
1247    # Remember it all
1248    opts.update({
1249        'data'      : data,
1250        'name'      : names,
1251        'info'      : model_info,
1252        'engines'   : [base, comp],
1253        'values'    : values,
1254    })
1255    # pylint: enable=bad-whitespace
1256
1257    # Set the integration parameters to the half sphere
1258    if opts['sphere'] > 0:
1259        set_spherical_integration_parameters(opts, opts['sphere'])
1260
1261    return opts
1262
1263def set_spherical_integration_parameters(opts, steps):
1264    # type: (Dict[str, Any], int) -> None
1265    """
1266    Set integration parameters for spherical integration over the entire
1267    surface in theta-phi coordinates.
1268    """
1269    # Set the integration parameters to the half sphere
1270    opts['values'].extend([
1271        #'theta=90',
1272        'theta_pd=%g'%(90/np.sqrt(3)),
1273        'theta_pd_n=%d'%steps,
1274        'theta_pd_type=rectangle',
1275        #'phi=0',
1276        'phi_pd=%g'%(180/np.sqrt(3)),
1277        'phi_pd_n=%d'%(2*steps),
1278        'phi_pd_type=rectangle',
1279        #'background=0',
1280    ])
1281    if 'psi' in opts['info'][0].parameters:
1282        opts['values'].extend([
1283            #'psi=0',
1284            'psi_pd=%g'%(180/np.sqrt(3)),
1285            'psi_pd_n=%d'%(2*steps),
1286            'psi_pd_type=rectangle',
1287        ])
1288
1289def parse_pars(opts, maxdim=np.inf):
1290    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]]
1291    """
1292    Generate a parameter set.
1293
1294    The default values come from the model, or a randomized model if a seed
1295    value is given.  Next, evaluate any parameter expressions, constraining
1296    the value of the parameter within and between models.  If *maxdim* is
1297    given, limit parameters with units of Angstrom to this value.
1298
1299    Returns a pair of parameter dictionaries for base and comparison models.
1300    """
1301    model_info, model_info2 = opts['info']
1302
1303    # Get demo parameters from model definition, or use default parameters
1304    # if model does not define demo parameters
1305    pars = get_pars(model_info, opts['use_demo'])
1306    pars2 = get_pars(model_info2, opts['use_demo'])
1307    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1308    # randomize parameters
1309    #pars.update(set_pars)  # set value before random to control range
1310    if opts['seed'] > -1:
1311        pars = randomize_pars(model_info, pars)
1312        limit_dimensions(model_info, pars, maxdim)
1313        if model_info != model_info2:
1314            pars2 = randomize_pars(model_info2, pars2)
1315            limit_dimensions(model_info2, pars2, maxdim)
1316            # Share values for parameters with the same name
1317            for k, v in pars.items():
1318                if k in pars2:
1319                    pars2[k] = v
1320        else:
1321            pars2 = pars.copy()
1322        constrain_pars(model_info, pars)
1323        constrain_pars(model_info2, pars2)
1324    pars = suppress_pd(pars, opts['mono'])
1325    pars2 = suppress_pd(pars2, opts['mono'])
1326    pars = suppress_magnetism(pars, not opts['magnetic'])
1327    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1328
1329    # Fill in parameters given on the command line
1330    presets = {}
1331    presets2 = {}
1332    for arg in opts['values']:
1333        k, v = arg.split('=', 1)
1334        if k not in pars and k not in pars2:
1335            # extract base name without polydispersity info
1336            s = set(p.split('_pd')[0] for p in pars)
1337            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1338            return None
1339        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
1340        if v1 and k in pars:
1341            presets[k] = float(v1) if isnumber(v1) else v1
1342        if v2 and k in pars2:
1343            presets2[k] = float(v2) if isnumber(v2) else v2
1344
1345    # If pd given on the command line, default pd_n to 35
1346    for k, v in list(presets.items()):
1347        if k.endswith('_pd'):
1348            presets.setdefault(k+'_n', 35.)
1349    for k, v in list(presets2.items()):
1350        if k.endswith('_pd'):
1351            presets2.setdefault(k+'_n', 35.)
1352
1353    # Evaluate preset parameter expressions
1354    # Note: need to replace ':' with '_' in parameter names and expressions
1355    # in order to support math on magnetic parameters.
1356    context = MATH.copy()
1357    context['np'] = np
1358    context.update((k.replace(':', '_'), v) for k, v in pars.items())
1359    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1360    #for k,v in sorted(context.items()): print(k, v)
1361    for k, v in presets.items():
1362        if not isinstance(v, float) and not k.endswith('_type'):
1363            presets[k] = eval(v.replace(':', '_'), context)
1364    context.update(presets)
1365    context.update((k.replace(':', '_'), v) for k, v in presets2.items() if isinstance(v, float))
1366    for k, v in presets2.items():
1367        if not isinstance(v, float) and not k.endswith('_type'):
1368            presets2[k] = eval(v.replace(':', '_'), context)
1369
1370    # update parameters with presets
1371    pars.update(presets)  # set value after random to control value
1372    pars2.update(presets2)  # set value after random to control value
1373    #import pprint; pprint.pprint(model_info)
1374
1375    if opts['show_pars']:
1376        if model_info.name != model_info2.name or pars != pars2:
1377            print("==== %s ====="%model_info.name)
1378            print(str(parlist(model_info, pars, opts['is2d'])))
1379            print("==== %s ====="%model_info2.name)
1380            print(str(parlist(model_info2, pars2, opts['is2d'])))
1381        else:
1382            print(str(parlist(model_info, pars, opts['is2d'])))
1383
1384    return pars, pars2
1385
1386def show_docs(opts):
1387    # type: (Dict[str, Any]) -> None
1388    """
1389    show html docs for the model
1390    """
1391    from .generate import make_html
1392    from . import rst2html
1393
1394    info = opts['info'][0]
1395    html = make_html(info)
1396    path = os.path.dirname(info.filename)
1397    url = "file://" + path.replace("\\", "/")[2:] + "/"
1398    rst2html.view_html_wxapp(html, url)
1399
1400def explore(opts):
1401    # type: (Dict[str, Any]) -> None
1402    """
1403    explore the model using the bumps gui.
1404    """
1405    import wx  # type: ignore
1406    from bumps.names import FitProblem  # type: ignore
1407    from bumps.gui.app_frame import AppFrame  # type: ignore
1408    from bumps.gui import signal
1409
1410    is_mac = "cocoa" in wx.version()
1411    # Create an app if not running embedded
1412    app = wx.App() if wx.GetApp() is None else None
1413    model = Explore(opts)
1414    problem = FitProblem(model)
1415    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1416    if not is_mac:
1417        frame.Show()
1418    frame.panel.set_model(model=problem)
1419    frame.panel.Layout()
1420    frame.panel.aui.Split(0, wx.TOP)
1421    def _reset_parameters(event):
1422        model.revert_values()
1423        signal.update_parameters(problem)
1424    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1))
1425    if is_mac:
1426        frame.Show()
1427    # If running withing an app, start the main loop
1428    if app:
1429        app.MainLoop()
1430
1431class Explore(object):
1432    """
1433    Bumps wrapper for a SAS model comparison.
1434
1435    The resulting object can be used as a Bumps fit problem so that
1436    parameters can be adjusted in the GUI, with plots updated on the fly.
1437    """
1438    def __init__(self, opts):
1439        # type: (Dict[str, Any]) -> None
1440        from bumps.cli import config_matplotlib  # type: ignore
1441        from . import bumps_model
1442        config_matplotlib()
1443        self.opts = opts
1444        opts['pars'] = list(opts['pars'])
1445        p1, p2 = opts['pars']
1446        m1, m2 = opts['info']
1447        self.fix_p2 = m1 != m2 or p1 != p2
1448        model_info = m1
1449        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1450        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1451        if not opts['is2d']:
1452            for p in model_info.parameters.user_parameters({}, is2d=False):
1453                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1454                    k = p.name+ext
1455                    v = pars.get(k, None)
1456                    if v is not None:
1457                        v.range(*parameter_range(k, v.value))
1458        else:
1459            for k, v in pars.items():
1460                v.range(*parameter_range(k, v.value))
1461
1462        self.pars = pars
1463        self.starting_values = dict((k, v.value) for k, v in pars.items())
1464        self.pd_types = pd_types
1465        self.limits = None
1466
1467    def revert_values(self):
1468        # type: () -> None
1469        """
1470        Restore starting values of the parameters.
1471        """
1472        for k, v in self.starting_values.items():
1473            self.pars[k].value = v
1474
1475    def model_update(self):
1476        # type: () -> None
1477        """
1478        Respond to signal that model parameters have been changed.
1479        """
1480        pass
1481
1482    def numpoints(self):
1483        # type: () -> int
1484        """
1485        Return the number of points.
1486        """
1487        return len(self.pars) + 1  # so dof is 1
1488
1489    def parameters(self):
1490        # type: () -> Any   # Dict/List hierarchy of parameters
1491        """
1492        Return a dictionary of parameters.
1493        """
1494        return self.pars
1495
1496    def nllf(self):
1497        # type: () -> float
1498        """
1499        Return cost.
1500        """
1501        # pylint: disable=no-self-use
1502        return 0.  # No nllf
1503
1504    def plot(self, view='log'):
1505        # type: (str) -> None
1506        """
1507        Plot the data and residuals.
1508        """
1509        pars = dict((k, v.value) for k, v in self.pars.items())
1510        pars.update(self.pd_types)
1511        self.opts['pars'][0] = pars
1512        if not self.fix_p2:
1513            self.opts['pars'][1] = pars
1514        result = run_models(self.opts)
1515        limits = plot_models(self.opts, result, limits=self.limits)
1516        if self.limits is None:
1517            vmin, vmax = limits
1518            self.limits = vmax*1e-7, 1.3*vmax
1519            import pylab
1520            pylab.clf()
1521            plot_models(self.opts, result, limits=self.limits)
1522
1523
1524def main(*argv):
1525    # type: (*str) -> None
1526    """
1527    Main program.
1528    """
1529    opts = parse_opts(argv)
1530    if opts is not None:
1531        if opts['seed'] > -1:
1532            print("Randomize using -random=%i"%opts['seed'])
1533            np.random.seed(opts['seed'])
1534        if opts['html']:
1535            show_docs(opts)
1536        elif opts['explore']:
1537            opts['pars'] = parse_pars(opts)
1538            if opts['pars'] is None:
1539                return
1540            explore(opts)
1541        else:
1542            compare(opts)
1543
1544if __name__ == "__main__":
1545    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.