source: sasmodels/sasmodels/compare.py @ 108e70e

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 108e70e was a261a83, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

Merge branch 'ticket-786' into generic_integration_loop

  • Property mode set to 100755
File size: 52.5 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from .data import plot_theory, empty_data1D, empty_data2D, load_data
43from .direct_model import DirectModel, get_mesh
44from .generate import FLOAT_RE, set_integration_size
45from .weights import plot_weights
46
47# pylint: disable=unused-import
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except ImportError:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56# pylint: enable=unused-import
57
58USAGE = """
59usage: sascomp model [options...] [key=val]
60
61Generate and compare SAS models.  If a single model is specified it shows
62a plot of that model.  Different models can be compared, or the same model
63with different parameters.  The same model with the same parameters can
64be compared with different calculation engines to see the effects of precision
65on the resultant values.
66
67model or model1,model2 are the names of the models to compare (see below).
68
69Options (* for default):
70
71    === data generation ===
72    -data="path" uses q, dq from the data file
73    -noise=0 sets the measurement error dI/I
74    -res=0 sets the resolution width dQ/Q if calculating with resolution
75    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
76    -q=min:max alternative specification of qrange
77    -nq=128 sets the number of Q points in the data set
78    -1d*/-2d computes 1d or 2d data
79    -zero indicates that q=0 should be included
80
81    === model parameters ===
82    -preset*/-random[=seed] preset or random parameters
83    -sets=n generates n random datasets with the seed given by -random=seed
84    -pars/-nopars* prints the parameter set or not
85    -default/-demo* use demo vs default parameters
86    -sphere[=150] set up spherical integration over theta/phi using n points
87
88    === calculation options ===
89    -mono*/-poly force monodisperse or allow polydisperse random parameters
90    -cutoff=1e-5* cutoff value for including a point in polydispersity
91    -magnetic/-nonmagnetic* suppress magnetism
92    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
93    -neval=1 sets the number of evals for more accurate timing
94
95    === precision options ===
96    -engine=default uses the default calcution precision
97    -single/-double/-half/-fast sets an OpenCL calculation engine
98    -single!/-double!/-quad! sets an OpenMP calculation engine
99
100    === plotting ===
101    -plot*/-noplot plots or suppress the plot of the model
102    -linear/-log*/-q4 intensity scaling on plots
103    -hist/-nohist* plot histogram of relative error
104    -abs/-rel* plot relative or absolute error
105    -title="note" adds note to the plot title, after the model name
106    -weights shows weights plots for the polydisperse parameters
107
108    === output options ===
109    -edit starts the parameter explorer
110    -help/-html shows the model docs instead of running the model
111
112The interpretation of quad precision depends on architecture, and may
113vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
114On unix and mac you may need single quotes around the DLL computation
115engines, such as -engine='single!,double!' since !, is treated as a history
116expansion request in the shell.
117
118Key=value pairs allow you to set specific values for the model parameters.
119Key=value1,value2 to compare different values of the same parameter. The
120value can be an expression including other parameters.
121
122Items later on the command line override those that appear earlier.
123
124Examples:
125
126    # compare single and double precision calculation for a barbell
127    sascomp barbell -engine=single,double
128
129    # generate 10 random lorentz models, with seed=27
130    sascomp lorentz -sets=10 -seed=27
131
132    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
133    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
134
135    # model timing test requires multiple evals to perform the estimate
136    sascomp pringle -engine=single,double -timing=100,100 -noplot
137"""
138
139# Update docs with command line usage string.   This is separate from the usual
140# doc string so that we can display it at run time if there is an error.
141# lin
142__doc__ = (__doc__  # pylint: disable=redefined-builtin
143           + """
144Program description
145-------------------
146
147""" + USAGE)
148
149kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
150
151def build_math_context():
152    # type: () -> Dict[str, Callable]
153    """build dictionary of functions from math module"""
154    return dict((k, getattr(math, k))
155                for k in dir(math) if not k.startswith('_'))
156
157#: list of math functions for use in evaluating parameters
158MATH = build_math_context()
159
160# CRUFT python 2.6
161if not hasattr(datetime.timedelta, 'total_seconds'):
162    def delay(dt):
163        """Return number date-time delta as number seconds"""
164        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
165else:
166    def delay(dt):
167        """Return number date-time delta as number seconds"""
168        return dt.total_seconds()
169
170
171class push_seed(object):
172    """
173    Set the seed value for the random number generator.
174
175    When used in a with statement, the random number generator state is
176    restored after the with statement is complete.
177
178    :Parameters:
179
180    *seed* : int or array_like, optional
181        Seed for RandomState
182
183    :Example:
184
185    Seed can be used directly to set the seed::
186
187        >>> from numpy.random import randint
188        >>> push_seed(24)
189        <...push_seed object at...>
190        >>> print(randint(0,1000000,3))
191        [242082    899 211136]
192
193    Seed can also be used in a with statement, which sets the random
194    number generator state for the enclosed computations and restores
195    it to the previous state on completion::
196
197        >>> with push_seed(24):
198        ...    print(randint(0,1000000,3))
199        [242082    899 211136]
200
201    Using nested contexts, we can demonstrate that state is indeed
202    restored after the block completes::
203
204        >>> with push_seed(24):
205        ...    print(randint(0,1000000))
206        ...    with push_seed(24):
207        ...        print(randint(0,1000000,3))
208        ...    print(randint(0,1000000))
209        242082
210        [242082    899 211136]
211        899
212
213    The restore step is protected against exceptions in the block::
214
215        >>> with push_seed(24):
216        ...    print(randint(0,1000000))
217        ...    try:
218        ...        with push_seed(24):
219        ...            print(randint(0,1000000,3))
220        ...            raise Exception()
221        ...    except Exception:
222        ...        print("Exception raised")
223        ...    print(randint(0,1000000))
224        242082
225        [242082    899 211136]
226        Exception raised
227        899
228    """
229    def __init__(self, seed=None):
230        # type: (Optional[int]) -> None
231        self._state = np.random.get_state()
232        np.random.seed(seed)
233
234    def __enter__(self):
235        # type: () -> None
236        pass
237
238    def __exit__(self, exc_type, exc_value, trace):
239        # type: (Any, BaseException, Any) -> None
240        np.random.set_state(self._state)
241
242def tic():
243    # type: () -> Callable[[], float]
244    """
245    Timer function.
246
247    Use "toc=tic()" to start the clock and "toc()" to measure
248    a time interval.
249    """
250    then = datetime.datetime.now()
251    return lambda: delay(datetime.datetime.now() - then)
252
253
254def set_beam_stop(data, radius, outer=None):
255    # type: (Data, float, float) -> None
256    """
257    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
258    """
259    if hasattr(data, 'qx_data'):
260        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
261        data.mask = (q < radius)
262        if outer is not None:
263            data.mask |= (q >= outer)
264    else:
265        data.mask = (data.x < radius)
266        if outer is not None:
267            data.mask |= (data.x >= outer)
268
269
270def parameter_range(p, v):
271    # type: (str, float) -> Tuple[float, float]
272    """
273    Choose a parameter range based on parameter name and initial value.
274    """
275    # process the polydispersity options
276    if p.endswith('_pd_n'):
277        return 0., 100.
278    elif p.endswith('_pd_nsigma'):
279        return 0., 5.
280    elif p.endswith('_pd_type'):
281        raise ValueError("Cannot return a range for a string value")
282    elif any(s in p for s in ('theta', 'phi', 'psi')):
283        # orientation in [-180,180], orientation pd in [0,45]
284        if p.endswith('_pd'):
285            return 0., 180.
286        else:
287            return -180., 180.
288    elif p.endswith('_pd'):
289        return 0., 1.
290    elif 'sld' in p:
291        return -0.5, 10.
292    elif p == 'background':
293        return 0., 10.
294    elif p == 'scale':
295        return 0., 1.e3
296    elif v < 0.:
297        return 2.*v, -2.*v
298    else:
299        return 0., (2.*v if v > 0. else 1.)
300
301
302def _randomize_one(model_info, name, value):
303    # type: (ModelInfo, str, float) -> float
304    # type: (ModelInfo, str, str) -> str
305    """
306    Randomize a single parameter.
307    """
308    # Set the amount of polydispersity/angular dispersion, but by default pd_n
309    # is zero so there is no polydispersity.  This allows us to turn on/off
310    # pd by setting pd_n, and still have randomly generated values
311    if name.endswith('_pd'):
312        par = model_info.parameters[name[:-3]]
313        if par.type == 'orientation':
314            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
315            return 180*np.random.beta(2.5, 20)
316        else:
317            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
318            return np.random.beta(1.5, 7)
319
320    # pd is selected globally rather than per parameter, so set to 0 for no pd
321    # In particular, when multiple pd dimensions, want to decrease the number
322    # of points per dimension for faster computation
323    if name.endswith('_pd_n'):
324        return 0
325
326    # Don't mess with distribution type for now
327    if name.endswith('_pd_type'):
328        return 'gaussian'
329
330    # type-dependent value of number of sigmas; for gaussian use 3.
331    if name.endswith('_pd_nsigma'):
332        return 3.
333
334    # background in the range [0.01, 1]
335    if name == 'background':
336        return 10**np.random.uniform(-2, 0)
337
338    # scale defaults to 0.1% to 30% volume fraction
339    if name == 'scale':
340        return 10**np.random.uniform(-3, -0.5)
341
342    # If it is a list of choices, pick one at random with equal probability
343    # In practice, the model specific random generator will override.
344    par = model_info.parameters[name]
345    if len(par.limits) > 2:  # choice list
346        return np.random.randint(len(par.limits))
347
348    # If it is a fixed range, pick from it with equal probability.
349    # For logarithmic ranges, the model will have to override.
350    if np.isfinite(par.limits).all():
351        return np.random.uniform(*par.limits)
352
353    # If the paramter is marked as an sld use the range of neutron slds
354    # TODO: ought to randomly contrast match a pair of SLDs
355    if par.type == 'sld':
356        return np.random.uniform(-0.5, 12)
357
358    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
359    if par.name.startswith('M0:'):
360        return np.random.uniform(0, 5)
361
362    # Guess at the random length/radius/thickness.  In practice, all models
363    # are going to set their own reasonable ranges.
364    if par.type == 'volume':
365        if ('length' in par.name or
366                'radius' in par.name or
367                'thick' in par.name):
368            return 10**np.random.uniform(2, 4)
369
370    # In the absence of any other info, select a value in [0, 2v], or
371    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
372    # model random parameter generators will override this default.
373    low, high = parameter_range(par.name, value)
374    limits = (max(par.limits[0], low), min(par.limits[1], high))
375    return np.random.uniform(*limits)
376
377def _random_pd(model_info, pars):
378    # type: (ModelInfo, Dict[str, float]) -> None
379    """
380    Generate a random dispersity distribution for the model.
381
382    1% no shape dispersity
383    85% single shape parameter
384    13% two shape parameters
385    1% three shape parameters
386
387    If oriented, then put dispersity in theta, add phi and psi dispersity
388    with 10% probability for each.
389    """
390    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
391    pd_volume = []
392    pd_oriented = []
393    for p in pd:
394        if p.type == 'orientation':
395            pd_oriented.append(p.name)
396        elif p.length_control is not None:
397            n = int(pars.get(p.length_control, 1) + 0.5)
398            pd_volume.extend(p.name+str(k+1) for k in range(n))
399        elif p.length > 1:
400            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
401        else:
402            pd_volume.append(p.name)
403    u = np.random.rand()
404    n = len(pd_volume)
405    if u < 0.01 or n < 1:
406        pass  # 1% chance of no polydispersity
407    elif u < 0.86 or n < 2:
408        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
409    elif u < 0.99 or n < 3:
410        choices = np.random.choice(len(pd_volume), size=2)
411        pars[pd_volume[choices[0]]+"_pd_n"] = 25
412        pars[pd_volume[choices[1]]+"_pd_n"] = 10
413    else:
414        choices = np.random.choice(len(pd_volume), size=3)
415        pars[pd_volume[choices[0]]+"_pd_n"] = 25
416        pars[pd_volume[choices[1]]+"_pd_n"] = 10
417        pars[pd_volume[choices[2]]+"_pd_n"] = 5
418    if pd_oriented:
419        pars['theta_pd_n'] = 20
420        if np.random.rand() < 0.1:
421            pars['phi_pd_n'] = 5
422        if np.random.rand() < 0.1:
423            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
424                #print("generating psi_pd_n")
425                pars['psi_pd_n'] = 5
426
427    ## Show selected polydispersity
428    #for name, value in pars.items():
429    #    if name.endswith('_pd_n') and value > 0:
430    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
431
432
433def randomize_pars(model_info, pars):
434    # type: (ModelInfo, ParameterSet) -> ParameterSet
435    """
436    Generate random values for all of the parameters.
437
438    Valid ranges for the random number generator are guessed from the name of
439    the parameter; this will not account for constraints such as cap radius
440    greater than cylinder radius in the capped_cylinder model, so
441    :func:`constrain_pars` needs to be called afterward..
442    """
443    # Note: the sort guarantees order of calls to random number generator
444    random_pars = dict((p, _randomize_one(model_info, p, v))
445                       for p, v in sorted(pars.items()))
446    if model_info.random is not None:
447        random_pars.update(model_info.random())
448    _random_pd(model_info, random_pars)
449    return random_pars
450
451
452def limit_dimensions(model_info, pars, maxdim):
453    # type: (ModelInfo, ParameterSet, float) -> None
454    """
455    Limit parameters of units of Ang to maxdim.
456    """
457    for p in model_info.parameters.call_parameters:
458        value = pars[p.name]
459        if p.units == 'Ang' and value > maxdim:
460            pars[p.name] = maxdim*10**np.random.uniform(-3, 0)
461
462def constrain_pars(model_info, pars):
463    # type: (ModelInfo, ParameterSet) -> None
464    """
465    Restrict parameters to valid values.
466
467    This includes model specific code for models such as capped_cylinder
468    which need to support within model constraints (cap radius more than
469    cylinder radius in this case).
470
471    Warning: this updates the *pars* dictionary in place.
472    """
473    # TODO: move the model specific code to the individual models
474    name = model_info.id
475    # if it is a product model, then just look at the form factor since
476    # none of the structure factors need any constraints.
477    if '*' in name:
478        name = name.split('*')[0]
479
480    # Suppress magnetism for python models (not yet implemented)
481    if callable(model_info.Iq):
482        pars.update(suppress_magnetism(pars))
483
484    if name == 'barbell':
485        if pars['radius_bell'] < pars['radius']:
486            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
487
488    elif name == 'capped_cylinder':
489        if pars['radius_cap'] < pars['radius']:
490            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
491
492    elif name == 'guinier':
493        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
494        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
495        # => ln A - (Rg^2 q^2/3) > -30 ln 10
496        # => Rg^2 q^2/3 < 30 ln 10 + ln A
497        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
498        #q_max = 0.2  # mid q maximum
499        q_max = 1.0  # high q maximum
500        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
501        pars['rg'] = min(pars['rg'], rg_max)
502
503    elif name == 'pearl_necklace':
504        if pars['radius'] < pars['thick_string']:
505            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
506
507    elif name == 'rpa':
508        # Make sure phi sums to 1.0
509        if pars['case_num'] < 2:
510            pars['Phi1'] = 0.
511            pars['Phi2'] = 0.
512        elif pars['case_num'] < 5:
513            pars['Phi1'] = 0.
514        total = sum(pars['Phi'+c] for c in '1234')
515        for c in '1234':
516            pars['Phi'+c] /= total
517
518def parlist(model_info, pars, is2d):
519    # type: (ModelInfo, ParameterSet, bool) -> str
520    """
521    Format the parameter list for printing.
522    """
523    is2d = True
524    lines = []
525    parameters = model_info.parameters
526    magnetic = False
527    magnetic_pars = []
528    for p in parameters.user_parameters(pars, is2d):
529        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
530            continue
531        if p.id.startswith('up:'):
532            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
533            continue
534        fields = dict(
535            value=pars.get(p.id, p.default),
536            pd=pars.get(p.id+"_pd", 0.),
537            n=int(pars.get(p.id+"_pd_n", 0)),
538            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
539            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
540            relative_pd=p.relative_pd,
541            M0=pars.get('M0:'+p.id, 0.),
542            mphi=pars.get('mphi:'+p.id, 0.),
543            mtheta=pars.get('mtheta:'+p.id, 0.),
544        )
545        lines.append(_format_par(p.name, **fields))
546        magnetic = magnetic or fields['M0'] != 0.
547    if magnetic and magnetic_pars:
548        lines.append(" ".join(magnetic_pars))
549    return "\n".join(lines)
550
551    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
552
553def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
554                relative_pd=False, M0=0., mphi=0., mtheta=0.):
555    # type: (str, float, float, int, float, str) -> str
556    line = "%s: %g"%(name, value)
557    if pd != 0.  and n != 0:
558        if relative_pd:
559            pd *= value
560        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
561                % (pd, n, nsigma, nsigma, pdtype)
562    if M0 != 0.:
563        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
564    return line
565
566def suppress_pd(pars, suppress=True):
567    # type: (ParameterSet) -> ParameterSet
568    """
569    If suppress is True complete eliminate polydispersity of the model to test
570    models more quickly.  If suppress is False, make sure at least one
571    parameter is polydisperse, setting the first polydispersity parameter to
572    15% if no polydispersity is given (with no explicit demo parameters given
573    in the model, there will be no default polydispersity).
574    """
575    pars = pars.copy()
576    #print("pars=", pars)
577    if suppress:
578        for p in pars:
579            if p.endswith("_pd_n"):
580                pars[p] = 0
581    else:
582        any_pd = False
583        first_pd = None
584        for p in pars:
585            if p.endswith("_pd_n"):
586                pd = pars.get(p[:-2], 0.)
587                any_pd |= (pars[p] != 0 and pd != 0.)
588                if first_pd is None:
589                    first_pd = p
590        if not any_pd and first_pd is not None:
591            if pars[first_pd] == 0:
592                pars[first_pd] = 35
593            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
594                pars[first_pd[:-2]] = 0.15
595    return pars
596
597def suppress_magnetism(pars, suppress=True):
598    # type: (ParameterSet) -> ParameterSet
599    """
600    If suppress is True complete eliminate magnetism of the model to test
601    models more quickly.  If suppress is False, make sure at least one sld
602    parameter is magnetic, setting the first parameter to have a strong
603    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
604    in the model, there will be no default magnetism).
605    """
606    pars = pars.copy()
607    if suppress:
608        for p in pars:
609            if p.startswith("M0:"):
610                pars[p] = 0
611    else:
612        any_mag = False
613        first_mag = None
614        for p in pars:
615            if p.startswith("M0:"):
616                any_mag |= (pars[p] != 0)
617                if first_mag is None:
618                    first_mag = p
619        if not any_mag and first_mag is not None:
620            pars[first_mag] = 8.
621    return pars
622
623
624DTYPE_MAP = {
625    'half': '16',
626    'fast': 'fast',
627    'single': '32',
628    'double': '64',
629    'quad': '128',
630    'f16': '16',
631    'f32': '32',
632    'f64': '64',
633    'float16': '16',
634    'float32': '32',
635    'float64': '64',
636    'float128': '128',
637    'longdouble': '128',
638}
639def eval_opencl(model_info, data, dtype='single', cutoff=0.):
640    # type: (ModelInfo, Data, str, float) -> Calculator
641    """
642    Return a model calculator using the OpenCL calculation engine.
643    """
644    if not core.HAVE_OPENCL:
645        raise RuntimeError("OpenCL not available")
646    model = core.build_model(model_info, dtype=dtype, platform="ocl")
647    calculator = DirectModel(data, model, cutoff=cutoff)
648    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
649    return calculator
650
651def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
652    # type: (ModelInfo, Data, str, float) -> Calculator
653    """
654    Return a model calculator using the DLL calculation engine.
655    """
656    model = core.build_model(model_info, dtype=dtype, platform="dll")
657    calculator = DirectModel(data, model, cutoff=cutoff)
658    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
659    return calculator
660
661def time_calculation(calculator, pars, evals=1):
662    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
663    """
664    Compute the average calculation time over N evaluations.
665
666    An additional call is generated without polydispersity in order to
667    initialize the calculation engine, and make the average more stable.
668    """
669    # initialize the code so time is more accurate
670    if evals > 1:
671        calculator(**suppress_pd(pars))
672    toc = tic()
673    # make sure there is at least one eval
674    value = calculator(**pars)
675    for _ in range(evals-1):
676        value = calculator(**pars)
677    average_time = toc()*1000. / evals
678    #print("I(q)",value)
679    return value, average_time
680
681def make_data(opts):
682    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
683    """
684    Generate an empty dataset, used with the model to set Q points
685    and resolution.
686
687    *opts* contains the options, with 'qmax', 'nq', 'res',
688    'accuracy', 'is2d' and 'view' parsed from the command line.
689    """
690    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
691    if opts['is2d']:
692        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
693        data = empty_data2D(q, resolution=res)
694        data.accuracy = opts['accuracy']
695        set_beam_stop(data, qmin)
696        index = ~data.mask
697    else:
698        if opts['view'] == 'log' and not opts['zero']:
699            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
700        else:
701            q = np.linspace(qmin, qmax, nq)
702        if opts['zero']:
703            q = np.hstack((0, q))
704        data = empty_data1D(q, resolution=res)
705        index = slice(None, None)
706    return data, index
707
708def make_engine(model_info, data, dtype, cutoff, ngauss=0):
709    # type: (ModelInfo, Data, str, float) -> Calculator
710    """
711    Generate the appropriate calculation engine for the given datatype.
712
713    Datatypes with '!' appended are evaluated using external C DLLs rather
714    than OpenCL.
715    """
716    if ngauss:
717        set_integration_size(model_info, ngauss)
718
719    if dtype is None or not dtype.endswith('!'):
720        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
721    else:
722        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
723
724def _show_invalid(data, theory):
725    # type: (Data, np.ma.ndarray) -> None
726    """
727    Display a list of the non-finite values in theory.
728    """
729    if not theory.mask.any():
730        return
731
732    if hasattr(data, 'x'):
733        bad = zip(data.x[theory.mask], theory[theory.mask])
734        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
735
736
737def compare(opts, limits=None, maxdim=np.inf):
738    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
739    """
740    Preform a comparison using options from the command line.
741
742    *limits* are the limits on the values to use, either to set the y-axis
743    for 1D or to set the colormap scale for 2D.  If None, then they are
744    inferred from the data and returned. When exploring using Bumps,
745    the limits are set when the model is initially called, and maintained
746    as the values are adjusted, making it easier to see the effects of the
747    parameters.
748
749    *maxdim* is the maximum value for any parameter with units of Angstrom.
750    """
751    for k in range(opts['sets']):
752        if k > 1:
753            # print a separate seed for each dataset for better reproducibility
754            new_seed = np.random.randint(1000000)
755            print("Set %d uses -random=%i"%(k+1, new_seed))
756            np.random.seed(new_seed)
757        opts['pars'] = parse_pars(opts, maxdim=maxdim)
758        if opts['pars'] is None:
759            return
760        result = run_models(opts, verbose=True)
761        if opts['plot']:
762            limits = plot_models(opts, result, limits=limits, setnum=k)
763        if opts['show_weights']:
764            base, _ = opts['engines']
765            base_pars, _ = opts['pars']
766            model_info = base._kernel.info
767            dim = base._kernel.dim
768            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
769    if opts['plot']:
770        import matplotlib.pyplot as plt
771        plt.show()
772    return limits
773
774def run_models(opts, verbose=False):
775    # type: (Dict[str, Any]) -> Dict[str, Any]
776    """
777    Process a parameter set, return calculation results and times.
778    """
779
780    base, comp = opts['engines']
781    base_n, comp_n = opts['count']
782    base_pars, comp_pars = opts['pars']
783    data = opts['data']
784
785    comparison = comp is not None
786
787    base_time = comp_time = None
788    base_value = comp_value = resid = relerr = None
789
790    # Base calculation
791    try:
792        base_raw, base_time = time_calculation(base, base_pars, base_n)
793        base_value = np.ma.masked_invalid(base_raw)
794        if verbose:
795            print("%s t=%.2f ms, intensity=%.0f"
796                  % (base.engine, base_time, base_value.sum()))
797        _show_invalid(data, base_value)
798    except ImportError:
799        traceback.print_exc()
800
801    # Comparison calculation
802    if comparison:
803        try:
804            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
805            comp_value = np.ma.masked_invalid(comp_raw)
806            if verbose:
807                print("%s t=%.2f ms, intensity=%.0f"
808                      % (comp.engine, comp_time, comp_value.sum()))
809            _show_invalid(data, comp_value)
810        except ImportError:
811            traceback.print_exc()
812
813    # Compare, but only if computing both forms
814    if comparison:
815        resid = (base_value - comp_value)
816        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
817        if verbose:
818            _print_stats("|%s-%s|"
819                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
820                         resid)
821            _print_stats("|(%s-%s)/%s|"
822                         % (base.engine, comp.engine, comp.engine),
823                         relerr)
824
825    return dict(base_value=base_value, comp_value=comp_value,
826                base_time=base_time, comp_time=comp_time,
827                resid=resid, relerr=relerr)
828
829
830def _print_stats(label, err):
831    # type: (str, np.ma.ndarray) -> None
832    # work with trimmed data, not the full set
833    sorted_err = np.sort(abs(err.compressed()))
834    if len(sorted_err) == 0:
835        print(label + "  no valid values")
836        return
837
838    p50 = int((len(sorted_err)-1)*0.50)
839    p98 = int((len(sorted_err)-1)*0.98)
840    data = [
841        "max:%.3e"%sorted_err[-1],
842        "median:%.3e"%sorted_err[p50],
843        "98%%:%.3e"%sorted_err[p98],
844        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
845        "zero-offset:%+.3e"%np.mean(sorted_err),
846        ]
847    print(label+"  "+"  ".join(data))
848
849
850def plot_models(opts, result, limits=None, setnum=0):
851    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
852    """
853    Plot the results from :func:`run_model`.
854    """
855    import matplotlib.pyplot as plt
856
857    base_value, comp_value = result['base_value'], result['comp_value']
858    base_time, comp_time = result['base_time'], result['comp_time']
859    resid, relerr = result['resid'], result['relerr']
860
861    have_base, have_comp = (base_value is not None), (comp_value is not None)
862    base, comp = opts['engines']
863    data = opts['data']
864    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
865
866    # Plot if requested
867    view = opts['view']
868    if limits is None:
869        vmin, vmax = np.inf, -np.inf
870        if have_base:
871            vmin = min(vmin, base_value.min())
872            vmax = max(vmax, base_value.max())
873        if have_comp:
874            vmin = min(vmin, comp_value.min())
875            vmax = max(vmax, comp_value.max())
876        limits = vmin, vmax
877
878    if have_base:
879        if have_comp:
880            plt.subplot(131)
881        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
882        plt.title("%s t=%.2f ms"%(base.engine, base_time))
883        #cbar_title = "log I"
884    if have_comp:
885        if have_base:
886            plt.subplot(132)
887        if not opts['is2d'] and have_base:
888            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
889        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
890        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
891        #cbar_title = "log I"
892    if have_base and have_comp:
893        plt.subplot(133)
894        if not opts['rel_err']:
895            err, errstr, errview = resid, "abs err", "linear"
896        else:
897            err, errstr, errview = abs(relerr), "rel err", "log"
898            if (err == 0.).all():
899                errview = 'linear'
900        if 0:  # 95% cutoff
901            sorted_err = np.sort(err.flatten())
902            cutoff = sorted_err[int(sorted_err.size*0.95)]
903            err[err > cutoff] = cutoff
904        #err,errstr = base/comp,"ratio"
905        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
906        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
907        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
908        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
909        #cbar_title = errstr if errview=="linear" else "log "+errstr
910    #if is2D:
911    #    h = plt.colorbar()
912    #    h.ax.set_title(cbar_title)
913    fig = plt.gcf()
914    extra_title = ' '+opts['title'] if opts['title'] else ''
915    fig.suptitle(":".join(opts['name']) + extra_title)
916
917    if have_base and have_comp and opts['show_hist']:
918        plt.figure()
919        v = relerr
920        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
921        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
922        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
923                   % (base.engine, comp.engine, comp.engine))
924        plt.ylabel('P(err)')
925        plt.title('Distribution of relative error between calculation engines')
926
927    return limits
928
929
930# ===========================================================================
931#
932
933# Set of command line options.
934# Normal options such as -plot/-noplot are specified as 'name'.
935# For options such as -nq=500 which require a value use 'name='.
936#
937OPTIONS = [
938    # Plotting
939    'plot', 'noplot', 'weights',
940    'linear', 'log', 'q4',
941    'rel', 'abs',
942    'hist', 'nohist',
943    'title=',
944
945    # Data generation
946    'data=', 'noise=', 'res=', 'nq=', 'q=',
947    'lowq', 'midq', 'highq', 'exq', 'zero',
948    '2d', '1d',
949
950    # Parameter set
951    'preset', 'random', 'random=', 'sets=',
952    'demo', 'default',  # TODO: remove demo/default
953    'nopars', 'pars',
954    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
955
956    # Calculation options
957    'poly', 'mono', 'cutoff=',
958    'magnetic', 'nonmagnetic',
959    'accuracy=', 'ngauss=',
960    'neval=',  # for timing...
961
962    # Precision options
963    'engine=',
964    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
965
966    # Output options
967    'help', 'html', 'edit',
968    ]
969
970NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))()
971VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])()
972
973
974def columnize(items, indent="", width=79):
975    # type: (List[str], str, int) -> str
976    """
977    Format a list of strings into columns.
978
979    Returns a string with carriage returns ready for printing.
980    """
981    column_width = max(len(w) for w in items) + 1
982    num_columns = (width - len(indent)) // column_width
983    num_rows = len(items) // num_columns
984    items = items + [""] * (num_rows * num_columns - len(items))
985    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
986    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
987             for row in zip(*columns)]
988    output = indent + ("\n"+indent).join(lines)
989    return output
990
991
992def get_pars(model_info, use_demo=False):
993    # type: (ModelInfo, bool) -> ParameterSet
994    """
995    Extract demo parameters from the model definition.
996    """
997    # Get the default values for the parameters
998    pars = {}
999    for p in model_info.parameters.call_parameters:
1000        parts = [('', p.default)]
1001        if p.polydisperse:
1002            parts.append(('_pd', 0.0))
1003            parts.append(('_pd_n', 0))
1004            parts.append(('_pd_nsigma', 3.0))
1005            parts.append(('_pd_type', "gaussian"))
1006        for ext, val in parts:
1007            if p.length > 1:
1008                dict(("%s%d%s" % (p.id, k, ext), val)
1009                     for k in range(1, p.length+1))
1010            else:
1011                pars[p.id + ext] = val
1012
1013    # Plug in values given in demo
1014    if use_demo and model_info.demo:
1015        pars.update(model_info.demo)
1016    return pars
1017
1018INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1019def isnumber(s):
1020    # type: (str) -> bool
1021    """Return True if string contains an int or float"""
1022    match = FLOAT_RE.match(s)
1023    isfloat = (match and not s[match.end():])
1024    return isfloat or INTEGER_RE.match(s)
1025
1026# For distinguishing pairs of models for comparison
1027# key-value pair separator =
1028# shell characters  | & ; <> $ % ' " \ # `
1029# model and parameter names _
1030# parameter expressions - + * / . ( )
1031# path characters including tilde expansion and windows drive ~ / :
1032# not sure about brackets [] {}
1033# maybe one of the following @ ? ^ ! ,
1034PAR_SPLIT = ','
1035def parse_opts(argv):
1036    # type: (List[str]) -> Dict[str, Any]
1037    """
1038    Parse command line options.
1039    """
1040    MODELS = core.list_models()
1041    flags = [arg for arg in argv
1042             if arg.startswith('-')]
1043    values = [arg for arg in argv
1044              if not arg.startswith('-') and '=' in arg]
1045    positional_args = [arg for arg in argv
1046                       if not arg.startswith('-') and '=' not in arg]
1047    models = "\n    ".join("%-15s"%v for v in MODELS)
1048    if len(positional_args) == 0:
1049        print(USAGE)
1050        print("\nAvailable models:")
1051        print(columnize(MODELS, indent="  "))
1052        return None
1053
1054    invalid = [o[1:] for o in flags
1055               if o[1:] not in NAME_OPTIONS
1056               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1057    if invalid:
1058        print("Invalid options: %s"%(", ".join(invalid)))
1059        return None
1060
1061    name = positional_args[-1]
1062
1063    # pylint: disable=bad-whitespace,C0321
1064    # Interpret the flags
1065    opts = {
1066        'plot'      : True,
1067        'view'      : 'log',
1068        'is2d'      : False,
1069        'qmin'      : None,
1070        'qmax'      : 0.05,
1071        'nq'        : 128,
1072        'res'       : 0.0,
1073        'noise'     : 0.0,
1074        'accuracy'  : 'Low',
1075        'cutoff'    : '0.0',
1076        'seed'      : -1,  # default to preset
1077        'mono'      : True,
1078        # Default to magnetic a magnetic moment is set on the command line
1079        'magnetic'  : False,
1080        'show_pars' : False,
1081        'show_hist' : False,
1082        'rel_err'   : True,
1083        'explore'   : False,
1084        'use_demo'  : True,
1085        'zero'      : False,
1086        'html'      : False,
1087        'title'     : None,
1088        'datafile'  : None,
1089        'sets'      : 0,
1090        'engine'    : 'default',
1091        'count'     : '1',
1092        'show_weights' : False,
1093        'sphere'    : 0,
1094        'ngauss'    : '0',
1095    }
1096    for arg in flags:
1097        if arg == '-noplot':    opts['plot'] = False
1098        elif arg == '-plot':    opts['plot'] = True
1099        elif arg == '-linear':  opts['view'] = 'linear'
1100        elif arg == '-log':     opts['view'] = 'log'
1101        elif arg == '-q4':      opts['view'] = 'q4'
1102        elif arg == '-1d':      opts['is2d'] = False
1103        elif arg == '-2d':      opts['is2d'] = True
1104        elif arg == '-exq':     opts['qmax'] = 10.0
1105        elif arg == '-highq':   opts['qmax'] = 1.0
1106        elif arg == '-midq':    opts['qmax'] = 0.2
1107        elif arg == '-lowq':    opts['qmax'] = 0.05
1108        elif arg == '-zero':    opts['zero'] = True
1109        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1110        elif arg.startswith('-q='):
1111            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1112        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1113        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1114        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1115        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1116        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1117        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1118        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1119        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1120        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1121        elif arg.startswith('-ngauss='):   opts['ngauss'] = arg[8:]
1122        elif arg.startswith('-random='):
1123            opts['seed'] = int(arg[8:])
1124            opts['sets'] = 0
1125        elif arg == '-random':
1126            opts['seed'] = np.random.randint(1000000)
1127            opts['sets'] = 0
1128        elif arg.startswith('-sphere'):
1129            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1130            opts['is2d'] = True
1131        elif arg == '-preset':  opts['seed'] = -1
1132        elif arg == '-mono':    opts['mono'] = True
1133        elif arg == '-poly':    opts['mono'] = False
1134        elif arg == '-magnetic':       opts['magnetic'] = True
1135        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1136        elif arg == '-pars':    opts['show_pars'] = True
1137        elif arg == '-nopars':  opts['show_pars'] = False
1138        elif arg == '-hist':    opts['show_hist'] = True
1139        elif arg == '-nohist':  opts['show_hist'] = False
1140        elif arg == '-rel':     opts['rel_err'] = True
1141        elif arg == '-abs':     opts['rel_err'] = False
1142        elif arg == '-half':    opts['engine'] = 'half'
1143        elif arg == '-fast':    opts['engine'] = 'fast'
1144        elif arg == '-single':  opts['engine'] = 'single'
1145        elif arg == '-double':  opts['engine'] = 'double'
1146        elif arg == '-single!': opts['engine'] = 'single!'
1147        elif arg == '-double!': opts['engine'] = 'double!'
1148        elif arg == '-quad!':   opts['engine'] = 'quad!'
1149        elif arg == '-edit':    opts['explore'] = True
1150        elif arg == '-demo':    opts['use_demo'] = True
1151        elif arg == '-default': opts['use_demo'] = False
1152        elif arg == '-weights': opts['show_weights'] = True
1153        elif arg == '-html':    opts['html'] = True
1154        elif arg == '-help':    opts['html'] = True
1155    # pylint: enable=bad-whitespace,C0321
1156
1157    # Magnetism forces 2D for now
1158    if opts['magnetic']:
1159        opts['is2d'] = True
1160
1161    # Force random if sets is used
1162    if opts['sets'] >= 1 and opts['seed'] < 0:
1163        opts['seed'] = np.random.randint(1000000)
1164    if opts['sets'] == 0:
1165        opts['sets'] = 1
1166
1167    # Create the computational engines
1168    if opts['qmin'] is None:
1169        opts['qmin'] = 0.001*opts['qmax']
1170    if opts['datafile'] is not None:
1171        data = load_data(os.path.expanduser(opts['datafile']))
1172    else:
1173        data, _ = make_data(opts)
1174
1175    comparison = any(PAR_SPLIT in v for v in values)
1176
1177    if PAR_SPLIT in name:
1178        names = name.split(PAR_SPLIT, 2)
1179        comparison = True
1180    else:
1181        names = [name]*2
1182    try:
1183        model_info = [core.load_model_info(k) for k in names]
1184    except ImportError as exc:
1185        print(str(exc))
1186        print("Could not find model; use one of:\n    " + models)
1187        return None
1188
1189    if PAR_SPLIT in opts['ngauss']:
1190        opts['ngauss'] = [int(k) for k in opts['ngauss'].split(PAR_SPLIT, 2)]
1191        comparison = True
1192    else:
1193        opts['ngauss'] = [int(opts['ngauss'])]*2
1194
1195    if PAR_SPLIT in opts['engine']:
1196        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1197        comparison = True
1198    else:
1199        opts['engine'] = [opts['engine']]*2
1200
1201    if PAR_SPLIT in opts['count']:
1202        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1203        comparison = True
1204    else:
1205        opts['count'] = [int(opts['count'])]*2
1206
1207    if PAR_SPLIT in opts['cutoff']:
1208        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1209        comparison = True
1210    else:
1211        opts['cutoff'] = [float(opts['cutoff'])]*2
1212
1213    base = make_engine(model_info[0], data, opts['engine'][0],
1214                       opts['cutoff'][0], opts['ngauss'][0])
1215    if comparison:
1216        comp = make_engine(model_info[1], data, opts['engine'][1],
1217                           opts['cutoff'][1], opts['ngauss'][1])
1218    else:
1219        comp = None
1220
1221    # pylint: disable=bad-whitespace
1222    # Remember it all
1223    opts.update({
1224        'data'      : data,
1225        'name'      : names,
1226        'info'      : model_info,
1227        'engines'   : [base, comp],
1228        'values'    : values,
1229    })
1230    # pylint: enable=bad-whitespace
1231
1232    # Set the integration parameters to the half sphere
1233    if opts['sphere'] > 0:
1234        set_spherical_integration_parameters(opts, opts['sphere'])
1235
1236    return opts
1237
1238def set_spherical_integration_parameters(opts, steps):
1239    # type: (Dict[str, Any], int) -> None
1240    """
1241    Set integration parameters for spherical integration over the entire
1242    surface in theta-phi coordinates.
1243    """
1244    # Set the integration parameters to the half sphere
1245    opts['values'].extend([
1246        #'theta=90',
1247        'theta_pd=%g'%(90/np.sqrt(3)),
1248        'theta_pd_n=%d'%steps,
1249        'theta_pd_type=rectangle',
1250        #'phi=0',
1251        'phi_pd=%g'%(180/np.sqrt(3)),
1252        'phi_pd_n=%d'%(2*steps),
1253        'phi_pd_type=rectangle',
1254        #'background=0',
1255    ])
1256    if 'psi' in opts['info'][0].parameters:
1257        opts['values'].extend([
1258            #'psi=0',
1259            'psi_pd=%g'%(180/np.sqrt(3)),
1260            'psi_pd_n=%d'%(2*steps),
1261            'psi_pd_type=rectangle',
1262        ])
1263
1264def parse_pars(opts, maxdim=np.inf):
1265    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]]
1266    """
1267    Generate a parameter set.
1268
1269    The default values come from the model, or a randomized model if a seed
1270    value is given.  Next, evaluate any parameter expressions, constraining
1271    the value of the parameter within and between models.  If *maxdim* is
1272    given, limit parameters with units of Angstrom to this value.
1273
1274    Returns a pair of parameter dictionaries for base and comparison models.
1275    """
1276    model_info, model_info2 = opts['info']
1277
1278    # Get demo parameters from model definition, or use default parameters
1279    # if model does not define demo parameters
1280    pars = get_pars(model_info, opts['use_demo'])
1281    pars2 = get_pars(model_info2, opts['use_demo'])
1282    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1283    # randomize parameters
1284    #pars.update(set_pars)  # set value before random to control range
1285    if opts['seed'] > -1:
1286        pars = randomize_pars(model_info, pars)
1287        limit_dimensions(model_info, pars, maxdim)
1288        if model_info != model_info2:
1289            pars2 = randomize_pars(model_info2, pars2)
1290            limit_dimensions(model_info2, pars2, maxdim)
1291            # Share values for parameters with the same name
1292            for k, v in pars.items():
1293                if k in pars2:
1294                    pars2[k] = v
1295        else:
1296            pars2 = pars.copy()
1297        constrain_pars(model_info, pars)
1298        constrain_pars(model_info2, pars2)
1299    pars = suppress_pd(pars, opts['mono'])
1300    pars2 = suppress_pd(pars2, opts['mono'])
1301    pars = suppress_magnetism(pars, not opts['magnetic'])
1302    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1303
1304    # Fill in parameters given on the command line
1305    presets = {}
1306    presets2 = {}
1307    for arg in opts['values']:
1308        k, v = arg.split('=', 1)
1309        if k not in pars and k not in pars2:
1310            # extract base name without polydispersity info
1311            s = set(p.split('_pd')[0] for p in pars)
1312            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1313            return None
1314        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
1315        if v1 and k in pars:
1316            presets[k] = float(v1) if isnumber(v1) else v1
1317        if v2 and k in pars2:
1318            presets2[k] = float(v2) if isnumber(v2) else v2
1319
1320    # If pd given on the command line, default pd_n to 35
1321    for k, v in list(presets.items()):
1322        if k.endswith('_pd'):
1323            presets.setdefault(k+'_n', 35.)
1324    for k, v in list(presets2.items()):
1325        if k.endswith('_pd'):
1326            presets2.setdefault(k+'_n', 35.)
1327
1328    # Evaluate preset parameter expressions
1329    context = MATH.copy()
1330    context['np'] = np
1331    context.update(pars)
1332    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1333    for k, v in presets.items():
1334        if not isinstance(v, float) and not k.endswith('_type'):
1335            presets[k] = eval(v, context)
1336    context.update(presets)
1337    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1338    for k, v in presets2.items():
1339        if not isinstance(v, float) and not k.endswith('_type'):
1340            presets2[k] = eval(v, context)
1341
1342    # update parameters with presets
1343    pars.update(presets)  # set value after random to control value
1344    pars2.update(presets2)  # set value after random to control value
1345    #import pprint; pprint.pprint(model_info)
1346
1347    if opts['show_pars']:
1348        if model_info.name != model_info2.name or pars != pars2:
1349            print("==== %s ====="%model_info.name)
1350            print(str(parlist(model_info, pars, opts['is2d'])))
1351            print("==== %s ====="%model_info2.name)
1352            print(str(parlist(model_info2, pars2, opts['is2d'])))
1353        else:
1354            print(str(parlist(model_info, pars, opts['is2d'])))
1355
1356    return pars, pars2
1357
1358def show_docs(opts):
1359    # type: (Dict[str, Any]) -> None
1360    """
1361    show html docs for the model
1362    """
1363    from .generate import make_html
1364    from . import rst2html
1365
1366    info = opts['info'][0]
1367    html = make_html(info)
1368    path = os.path.dirname(info.filename)
1369    url = "file://" + path.replace("\\", "/")[2:] + "/"
1370    rst2html.view_html_qtapp(html, url)
1371
1372def explore(opts):
1373    # type: (Dict[str, Any]) -> None
1374    """
1375    explore the model using the bumps gui.
1376    """
1377    import wx  # type: ignore
1378    from bumps.names import FitProblem  # type: ignore
1379    from bumps.gui.app_frame import AppFrame  # type: ignore
1380    from bumps.gui import signal
1381
1382    is_mac = "cocoa" in wx.version()
1383    # Create an app if not running embedded
1384    app = wx.App() if wx.GetApp() is None else None
1385    model = Explore(opts)
1386    problem = FitProblem(model)
1387    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1388    if not is_mac:
1389        frame.Show()
1390    frame.panel.set_model(model=problem)
1391    frame.panel.Layout()
1392    frame.panel.aui.Split(0, wx.TOP)
1393    def _reset_parameters(event):
1394        model.revert_values()
1395        signal.update_parameters(problem)
1396    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1))
1397    if is_mac:
1398        frame.Show()
1399    # If running withing an app, start the main loop
1400    if app:
1401        app.MainLoop()
1402
1403class Explore(object):
1404    """
1405    Bumps wrapper for a SAS model comparison.
1406
1407    The resulting object can be used as a Bumps fit problem so that
1408    parameters can be adjusted in the GUI, with plots updated on the fly.
1409    """
1410    def __init__(self, opts):
1411        # type: (Dict[str, Any]) -> None
1412        from bumps.cli import config_matplotlib  # type: ignore
1413        from . import bumps_model
1414        config_matplotlib()
1415        self.opts = opts
1416        opts['pars'] = list(opts['pars'])
1417        p1, p2 = opts['pars']
1418        m1, m2 = opts['info']
1419        self.fix_p2 = m1 != m2 or p1 != p2
1420        model_info = m1
1421        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1422        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1423        if not opts['is2d']:
1424            for p in model_info.parameters.user_parameters({}, is2d=False):
1425                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1426                    k = p.name+ext
1427                    v = pars.get(k, None)
1428                    if v is not None:
1429                        v.range(*parameter_range(k, v.value))
1430        else:
1431            for k, v in pars.items():
1432                v.range(*parameter_range(k, v.value))
1433
1434        self.pars = pars
1435        self.starting_values = dict((k, v.value) for k, v in pars.items())
1436        self.pd_types = pd_types
1437        self.limits = None
1438
1439    def revert_values(self):
1440        # type: () -> None
1441        """
1442        Restore starting values of the parameters.
1443        """
1444        for k, v in self.starting_values.items():
1445            self.pars[k].value = v
1446
1447    def model_update(self):
1448        # type: () -> None
1449        """
1450        Respond to signal that model parameters have been changed.
1451        """
1452        pass
1453
1454    def numpoints(self):
1455        # type: () -> int
1456        """
1457        Return the number of points.
1458        """
1459        return len(self.pars) + 1  # so dof is 1
1460
1461    def parameters(self):
1462        # type: () -> Any   # Dict/List hierarchy of parameters
1463        """
1464        Return a dictionary of parameters.
1465        """
1466        return self.pars
1467
1468    def nllf(self):
1469        # type: () -> float
1470        """
1471        Return cost.
1472        """
1473        # pylint: disable=no-self-use
1474        return 0.  # No nllf
1475
1476    def plot(self, view='log'):
1477        # type: (str) -> None
1478        """
1479        Plot the data and residuals.
1480        """
1481        pars = dict((k, v.value) for k, v in self.pars.items())
1482        pars.update(self.pd_types)
1483        self.opts['pars'][0] = pars
1484        if not self.fix_p2:
1485            self.opts['pars'][1] = pars
1486        result = run_models(self.opts)
1487        limits = plot_models(self.opts, result, limits=self.limits)
1488        if self.limits is None:
1489            vmin, vmax = limits
1490            self.limits = vmax*1e-7, 1.3*vmax
1491            import pylab; pylab.clf()
1492            plot_models(self.opts, result, limits=self.limits)
1493
1494
1495def main(*argv):
1496    # type: (*str) -> None
1497    """
1498    Main program.
1499    """
1500    opts = parse_opts(argv)
1501    if opts is not None:
1502        if opts['seed'] > -1:
1503            print("Randomize using -random=%i"%opts['seed'])
1504            np.random.seed(opts['seed'])
1505        if opts['html']:
1506            show_docs(opts)
1507        elif opts['explore']:
1508            opts['pars'] = parse_pars(opts)
1509            if opts['pars'] is None:
1510                return
1511            explore(opts)
1512        else:
1513            compare(opts)
1514
1515if __name__ == "__main__":
1516    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.