source: sasmodels/sasmodels/compare.py @ 1e7b202a

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 1e7b202a was 1e7b202a, checked in by Paul Kienzle <pkienzle@…>, 4 years ago

add profile plot to sascomp

  • Property mode set to 100755
File size: 55.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import kernelcl
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .generate import FLOAT_RE, set_integration_size
46from .weights import plot_weights
47
48# pylint: disable=unused-import
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except ImportError:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57# pylint: enable=unused-import
58
59USAGE = """
60usage: sascomp model [options...] [key=val]
61
62Generate and compare SAS models.  If a single model is specified it shows
63a plot of that model.  Different models can be compared, or the same model
64with different parameters.  The same model with the same parameters can
65be compared with different calculation engines to see the effects of precision
66on the resultant values.
67
68model or model1,model2 are the names of the models to compare (see below).
69
70Options (* for default):
71
72    === data generation ===
73    -data="path" uses q, dq from the data file
74    -noise=0 sets the measurement error dI/I
75    -res=0 sets the resolution width dQ/Q if calculating with resolution
76    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
77    -q=min:max alternative specification of qrange
78    -nq=128 sets the number of Q points in the data set
79    -1d*/-2d computes 1d or 2d data
80    -zero indicates that q=0 should be included
81
82    === model parameters ===
83    -preset*/-random[=seed] preset or random parameters
84    -sets=n generates n random datasets with the seed given by -random=seed
85    -pars/-nopars* prints the parameter set or not
86    -default/-demo* use demo vs default parameters
87    -sphere[=150] set up spherical integration over theta/phi using n points
88
89    === calculation options ===
90    -mono*/-poly force monodisperse or allow polydisperse random parameters
91    -cutoff=1e-5* cutoff value for including a point in polydispersity
92    -magnetic/-nonmagnetic* suppress magnetism
93    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
94    -neval=1 sets the number of evals for more accurate timing
95    -ngauss=0 overrides the number of points in the 1-D gaussian quadrature
96
97    === precision options ===
98    -engine=default uses the default calcution precision
99    -single/-double/-half/-fast sets an OpenCL calculation engine
100    -single!/-double!/-quad! sets an OpenMP calculation engine
101
102    === plotting ===
103    -plot*/-noplot plots or suppress the plot of the model
104    -linear/-log*/-q4 intensity scaling on plots
105    -hist/-nohist* plot histogram of relative error
106    -abs/-rel* plot relative or absolute error
107    -title="note" adds note to the plot title, after the model name
108    -weights shows weights plots for the polydisperse parameters
109    -profile shows the sld profile if the model has a plottable sld profile
110
111    === output options ===
112    -edit starts the parameter explorer
113    -help/-html shows the model docs instead of running the model
114
115The interpretation of quad precision depends on architecture, and may
116vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
117On unix and mac you may need single quotes around the DLL computation
118engines, such as -engine='single!,double!' since !, is treated as a history
119expansion request in the shell.
120
121Key=value pairs allow you to set specific values for the model parameters.
122Key=value1,value2 to compare different values of the same parameter. The
123value can be an expression including other parameters.
124
125Items later on the command line override those that appear earlier.
126
127Examples:
128
129    # compare single and double precision calculation for a barbell
130    sascomp barbell -engine=single,double
131
132    # generate 10 random lorentz models, with seed=27
133    sascomp lorentz -sets=10 -seed=27
134
135    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
136    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
137
138    # model timing test requires multiple evals to perform the estimate
139    sascomp pringle -engine=single,double -timing=100,100 -noplot
140"""
141
142# Update docs with command line usage string.   This is separate from the usual
143# doc string so that we can display it at run time if there is an error.
144# lin
145__doc__ = (__doc__  # pylint: disable=redefined-builtin
146           + """
147Program description
148-------------------
149
150""" + USAGE)
151
152kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
153
154def build_math_context():
155    # type: () -> Dict[str, Callable]
156    """build dictionary of functions from math module"""
157    return dict((k, getattr(math, k))
158                for k in dir(math) if not k.startswith('_'))
159
160#: list of math functions for use in evaluating parameters
161MATH = build_math_context()
162
163# CRUFT python 2.6
164if not hasattr(datetime.timedelta, 'total_seconds'):
165    def delay(dt):
166        """Return number date-time delta as number seconds"""
167        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
168else:
169    def delay(dt):
170        """Return number date-time delta as number seconds"""
171        return dt.total_seconds()
172
173
174class push_seed(object):
175    """
176    Set the seed value for the random number generator.
177
178    When used in a with statement, the random number generator state is
179    restored after the with statement is complete.
180
181    :Parameters:
182
183    *seed* : int or array_like, optional
184        Seed for RandomState
185
186    :Example:
187
188    Seed can be used directly to set the seed::
189
190        >>> from numpy.random import randint
191        >>> push_seed(24)
192        <...push_seed object at...>
193        >>> print(randint(0,1000000,3))
194        [242082    899 211136]
195
196    Seed can also be used in a with statement, which sets the random
197    number generator state for the enclosed computations and restores
198    it to the previous state on completion::
199
200        >>> with push_seed(24):
201        ...    print(randint(0,1000000,3))
202        [242082    899 211136]
203
204    Using nested contexts, we can demonstrate that state is indeed
205    restored after the block completes::
206
207        >>> with push_seed(24):
208        ...    print(randint(0,1000000))
209        ...    with push_seed(24):
210        ...        print(randint(0,1000000,3))
211        ...    print(randint(0,1000000))
212        242082
213        [242082    899 211136]
214        899
215
216    The restore step is protected against exceptions in the block::
217
218        >>> with push_seed(24):
219        ...    print(randint(0,1000000))
220        ...    try:
221        ...        with push_seed(24):
222        ...            print(randint(0,1000000,3))
223        ...            raise Exception()
224        ...    except Exception:
225        ...        print("Exception raised")
226        ...    print(randint(0,1000000))
227        242082
228        [242082    899 211136]
229        Exception raised
230        899
231    """
232    def __init__(self, seed=None):
233        # type: (Optional[int]) -> None
234        self._state = np.random.get_state()
235        np.random.seed(seed)
236
237    def __enter__(self):
238        # type: () -> None
239        pass
240
241    def __exit__(self, exc_type, exc_value, trace):
242        # type: (Any, BaseException, Any) -> None
243        np.random.set_state(self._state)
244
245def tic():
246    # type: () -> Callable[[], float]
247    """
248    Timer function.
249
250    Use "toc=tic()" to start the clock and "toc()" to measure
251    a time interval.
252    """
253    then = datetime.datetime.now()
254    return lambda: delay(datetime.datetime.now() - then)
255
256
257def set_beam_stop(data, radius, outer=None):
258    # type: (Data, float, float) -> None
259    """
260    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
261    """
262    if hasattr(data, 'qx_data'):
263        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
264        data.mask = (q < radius)
265        if outer is not None:
266            data.mask |= (q >= outer)
267    else:
268        data.mask = (data.x < radius)
269        if outer is not None:
270            data.mask |= (data.x >= outer)
271
272
273def parameter_range(p, v):
274    # type: (str, float) -> Tuple[float, float]
275    """
276    Choose a parameter range based on parameter name and initial value.
277    """
278    # process the polydispersity options
279    if p.endswith('_pd_n'):
280        return 0., 100.
281    elif p.endswith('_pd_nsigma'):
282        return 0., 5.
283    elif p.endswith('_pd_type'):
284        raise ValueError("Cannot return a range for a string value")
285    elif any(s in p for s in ('theta', 'phi', 'psi')):
286        # orientation in [-180,180], orientation pd in [0,45]
287        if p.endswith('_pd'):
288            return 0., 180.
289        else:
290            return -180., 180.
291    elif p.endswith('_pd'):
292        return 0., 1.
293    elif 'sld' in p:
294        return -0.5, 10.
295    elif p == 'background':
296        return 0., 10.
297    elif p == 'scale':
298        return 0., 1.e3
299    elif v < 0.:
300        return 2.*v, -2.*v
301    else:
302        return 0., (2.*v if v > 0. else 1.)
303
304
305def _randomize_one(model_info, name, value):
306    # type: (ModelInfo, str, float) -> float
307    # type: (ModelInfo, str, str) -> str
308    """
309    Randomize a single parameter.
310    """
311    # Set the amount of polydispersity/angular dispersion, but by default pd_n
312    # is zero so there is no polydispersity.  This allows us to turn on/off
313    # pd by setting pd_n, and still have randomly generated values
314    if name.endswith('_pd'):
315        par = model_info.parameters[name[:-3]]
316        if par.type == 'orientation':
317            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
318            return 180*np.random.beta(2.5, 20)
319        else:
320            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
321            return np.random.beta(1.5, 7)
322
323    # pd is selected globally rather than per parameter, so set to 0 for no pd
324    # In particular, when multiple pd dimensions, want to decrease the number
325    # of points per dimension for faster computation
326    if name.endswith('_pd_n'):
327        return 0
328
329    # Don't mess with distribution type for now
330    if name.endswith('_pd_type'):
331        return 'gaussian'
332
333    # type-dependent value of number of sigmas; for gaussian use 3.
334    if name.endswith('_pd_nsigma'):
335        return 3.
336
337    # background in the range [0.01, 1]
338    if name == 'background':
339        return 10**np.random.uniform(-2, 0)
340
341    # scale defaults to 0.1% to 30% volume fraction
342    if name == 'scale':
343        return 10**np.random.uniform(-3, -0.5)
344
345    # If it is a list of choices, pick one at random with equal probability
346    # In practice, the model specific random generator will override.
347    par = model_info.parameters[name]
348    if len(par.limits) > 2:  # choice list
349        return np.random.randint(len(par.limits))
350
351    # If it is a fixed range, pick from it with equal probability.
352    # For logarithmic ranges, the model will have to override.
353    if np.isfinite(par.limits).all():
354        return np.random.uniform(*par.limits)
355
356    # If the paramter is marked as an sld use the range of neutron slds
357    # TODO: ought to randomly contrast match a pair of SLDs
358    if par.type == 'sld':
359        return np.random.uniform(-0.5, 12)
360
361    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
362    if par.name.startswith('M0:'):
363        return np.random.uniform(0, 5)
364
365    # Guess at the random length/radius/thickness.  In practice, all models
366    # are going to set their own reasonable ranges.
367    if par.type == 'volume':
368        if ('length' in par.name or
369                'radius' in par.name or
370                'thick' in par.name):
371            return 10**np.random.uniform(2, 4)
372
373    # In the absence of any other info, select a value in [0, 2v], or
374    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
375    # model random parameter generators will override this default.
376    low, high = parameter_range(par.name, value)
377    limits = (max(par.limits[0], low), min(par.limits[1], high))
378    return np.random.uniform(*limits)
379
380def _random_pd(model_info, pars):
381    # type: (ModelInfo, Dict[str, float]) -> None
382    """
383    Generate a random dispersity distribution for the model.
384
385    1% no shape dispersity
386    85% single shape parameter
387    13% two shape parameters
388    1% three shape parameters
389
390    If oriented, then put dispersity in theta, add phi and psi dispersity
391    with 10% probability for each.
392    """
393    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
394    pd_volume = []
395    pd_oriented = []
396    for p in pd:
397        if p.type == 'orientation':
398            pd_oriented.append(p.name)
399        elif p.length_control is not None:
400            n = int(pars.get(p.length_control, 1) + 0.5)
401            pd_volume.extend(p.name+str(k+1) for k in range(n))
402        elif p.length > 1:
403            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
404        else:
405            pd_volume.append(p.name)
406    u = np.random.rand()
407    n = len(pd_volume)
408    if u < 0.01 or n < 1:
409        pass  # 1% chance of no polydispersity
410    elif u < 0.86 or n < 2:
411        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
412    elif u < 0.99 or n < 3:
413        choices = np.random.choice(len(pd_volume), size=2)
414        pars[pd_volume[choices[0]]+"_pd_n"] = 25
415        pars[pd_volume[choices[1]]+"_pd_n"] = 10
416    else:
417        choices = np.random.choice(len(pd_volume), size=3)
418        pars[pd_volume[choices[0]]+"_pd_n"] = 25
419        pars[pd_volume[choices[1]]+"_pd_n"] = 10
420        pars[pd_volume[choices[2]]+"_pd_n"] = 5
421    if pd_oriented:
422        pars['theta_pd_n'] = 20
423        if np.random.rand() < 0.1:
424            pars['phi_pd_n'] = 5
425        if np.random.rand() < 0.1:
426            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
427                #print("generating psi_pd_n")
428                pars['psi_pd_n'] = 5
429
430    ## Show selected polydispersity
431    #for name, value in pars.items():
432    #    if name.endswith('_pd_n') and value > 0:
433    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
434
435
436def randomize_pars(model_info, pars):
437    # type: (ModelInfo, ParameterSet) -> ParameterSet
438    """
439    Generate random values for all of the parameters.
440
441    Valid ranges for the random number generator are guessed from the name of
442    the parameter; this will not account for constraints such as cap radius
443    greater than cylinder radius in the capped_cylinder model, so
444    :func:`constrain_pars` needs to be called afterward..
445    """
446    # Note: the sort guarantees order of calls to random number generator
447    random_pars = dict((p, _randomize_one(model_info, p, v))
448                       for p, v in sorted(pars.items()))
449    if model_info.random is not None:
450        random_pars.update(model_info.random())
451    _random_pd(model_info, random_pars)
452    return random_pars
453
454
455def limit_dimensions(model_info, pars, maxdim):
456    # type: (ModelInfo, ParameterSet, float) -> None
457    """
458    Limit parameters of units of Ang to maxdim.
459    """
460    for p in model_info.parameters.call_parameters:
461        value = pars[p.name]
462        if p.units == 'Ang' and value > maxdim:
463            pars[p.name] = maxdim*10**np.random.uniform(-3, 0)
464
465def constrain_pars(model_info, pars):
466    # type: (ModelInfo, ParameterSet) -> None
467    """
468    Restrict parameters to valid values.
469
470    This includes model specific code for models such as capped_cylinder
471    which need to support within model constraints (cap radius more than
472    cylinder radius in this case).
473
474    Warning: this updates the *pars* dictionary in place.
475    """
476    # TODO: move the model specific code to the individual models
477    name = model_info.id
478    # if it is a product model, then just look at the form factor since
479    # none of the structure factors need any constraints.
480    if '*' in name:
481        name = name.split('*')[0]
482
483    # Suppress magnetism for python models (not yet implemented)
484    if callable(model_info.Iq):
485        pars.update(suppress_magnetism(pars))
486
487    if name == 'barbell':
488        if pars['radius_bell'] < pars['radius']:
489            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
490
491    elif name == 'capped_cylinder':
492        if pars['radius_cap'] < pars['radius']:
493            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
494
495    elif name == 'guinier':
496        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
497        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
498        # => ln A - (Rg^2 q^2/3) > -30 ln 10
499        # => Rg^2 q^2/3 < 30 ln 10 + ln A
500        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
501        #q_max = 0.2  # mid q maximum
502        q_max = 1.0  # high q maximum
503        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
504        pars['rg'] = min(pars['rg'], rg_max)
505
506    elif name == 'pearl_necklace':
507        if pars['radius'] < pars['thick_string']:
508            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
509
510    elif name == 'rpa':
511        # Make sure phi sums to 1.0
512        if pars['case_num'] < 2:
513            pars['Phi1'] = 0.
514            pars['Phi2'] = 0.
515        elif pars['case_num'] < 5:
516            pars['Phi1'] = 0.
517        total = sum(pars['Phi'+c] for c in '1234')
518        for c in '1234':
519            pars['Phi'+c] /= total
520
521def parlist(model_info, pars, is2d):
522    # type: (ModelInfo, ParameterSet, bool) -> str
523    """
524    Format the parameter list for printing.
525    """
526    is2d = True
527    lines = []
528    parameters = model_info.parameters
529    magnetic = False
530    magnetic_pars = []
531    for p in parameters.user_parameters(pars, is2d):
532        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
533            continue
534        if p.id.startswith('up:'):
535            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
536            continue
537        fields = dict(
538            value=pars.get(p.id, p.default),
539            pd=pars.get(p.id+"_pd", 0.),
540            n=int(pars.get(p.id+"_pd_n", 0)),
541            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
542            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
543            relative_pd=p.relative_pd,
544            M0=pars.get('M0:'+p.id, 0.),
545            mphi=pars.get('mphi:'+p.id, 0.),
546            mtheta=pars.get('mtheta:'+p.id, 0.),
547        )
548        lines.append(_format_par(p.name, **fields))
549        magnetic = magnetic or fields['M0'] != 0.
550    if magnetic and magnetic_pars:
551        lines.append(" ".join(magnetic_pars))
552    return "\n".join(lines)
553
554    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
555
556def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
557                relative_pd=False, M0=0., mphi=0., mtheta=0.):
558    # type: (str, float, float, int, float, str) -> str
559    line = "%s: %g"%(name, value)
560    if pd != 0.  and n != 0:
561        if relative_pd:
562            pd *= value
563        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
564                % (pd, n, nsigma, nsigma, pdtype)
565    if M0 != 0.:
566        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
567    return line
568
569def suppress_pd(pars, suppress=True):
570    # type: (ParameterSet) -> ParameterSet
571    """
572    If suppress is True complete eliminate polydispersity of the model to test
573    models more quickly.  If suppress is False, make sure at least one
574    parameter is polydisperse, setting the first polydispersity parameter to
575    15% if no polydispersity is given (with no explicit demo parameters given
576    in the model, there will be no default polydispersity).
577    """
578    pars = pars.copy()
579    #print("pars=", pars)
580    if suppress:
581        for p in pars:
582            if p.endswith("_pd_n"):
583                pars[p] = 0
584    else:
585        any_pd = False
586        first_pd = None
587        for p in pars:
588            if p.endswith("_pd_n"):
589                pd = pars.get(p[:-2], 0.)
590                any_pd |= (pars[p] != 0 and pd != 0.)
591                if first_pd is None:
592                    first_pd = p
593        if not any_pd and first_pd is not None:
594            if pars[first_pd] == 0:
595                pars[first_pd] = 35
596            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
597                pars[first_pd[:-2]] = 0.15
598    return pars
599
600def suppress_magnetism(pars, suppress=True):
601    # type: (ParameterSet) -> ParameterSet
602    """
603    If suppress is True complete eliminate magnetism of the model to test
604    models more quickly.  If suppress is False, make sure at least one sld
605    parameter is magnetic, setting the first parameter to have a strong
606    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
607    in the model, there will be no default magnetism).
608    """
609    pars = pars.copy()
610    if suppress:
611        for p in pars:
612            if p.startswith("M0:"):
613                pars[p] = 0
614    else:
615        any_mag = False
616        first_mag = None
617        for p in pars:
618            if p.startswith("M0:"):
619                any_mag |= (pars[p] != 0)
620                if first_mag is None:
621                    first_mag = p
622        if not any_mag and first_mag is not None:
623            pars[first_mag] = 8.
624    return pars
625
626
627def time_calculation(calculator, pars, evals=1):
628    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
629    """
630    Compute the average calculation time over N evaluations.
631
632    An additional call is generated without polydispersity in order to
633    initialize the calculation engine, and make the average more stable.
634    """
635    # initialize the code so time is more accurate
636    if evals > 1:
637        calculator(**suppress_pd(pars))
638    toc = tic()
639    # make sure there is at least one eval
640    value = calculator(**pars)
641    for _ in range(evals-1):
642        value = calculator(**pars)
643    average_time = toc()*1000. / evals
644    #print("I(q)",value)
645    return value, average_time
646
647def make_data(opts):
648    # type: (Dict[str, Any], float) -> Tuple[Data, np.ndarray]
649    """
650    Generate an empty dataset, used with the model to set Q points
651    and resolution.
652
653    *opts* contains the options, with 'qmax', 'nq', 'res',
654    'accuracy', 'is2d' and 'view' parsed from the command line.
655    """
656    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
657    if opts['is2d']:
658        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
659        data = empty_data2D(q, resolution=res)
660        data.accuracy = opts['accuracy']
661        set_beam_stop(data, qmin)
662        index = ~data.mask
663    else:
664        if opts['view'] == 'log' and not opts['zero']:
665            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
666        else:
667            q = np.linspace(qmin, qmax, nq)
668        if opts['zero']:
669            q = np.hstack((0, q))
670        # TODO: provide command line control of lambda and Delta lambda/lambda
671        #L, dLoL = 5, 0.14/np.sqrt(6)  # wavelength and 14% triangular FWHM
672        L, dLoL = 0, 0
673        data = empty_data1D(q, resolution=res, L=L, dL=L*dLoL)
674        index = slice(None, None)
675    return data, index
676
677DTYPE_MAP = {
678    'half': '16',
679    'fast': 'fast',
680    'single': '32',
681    'double': '64',
682    'quad': '128',
683    'f16': '16',
684    'f32': '32',
685    'f64': '64',
686    'float16': '16',
687    'float32': '32',
688    'float64': '64',
689    'float128': '128',
690    'longdouble': '128',
691}
692def eval_opencl(model_info, data, dtype='single', cutoff=0.):
693    # type: (ModelInfo, Data, str, float) -> Calculator
694    """
695    Return a model calculator using the OpenCL calculation engine.
696    """
697
698def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
699    # type: (ModelInfo, Data, str, float) -> Calculator
700    """
701    Return a model calculator using the DLL calculation engine.
702    """
703    model = core.build_model(model_info, dtype=dtype, platform="dll")
704    calculator = DirectModel(data, model, cutoff=cutoff)
705    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
706    return calculator
707
708def make_engine(model_info, data, dtype, cutoff, ngauss=0):
709    # type: (ModelInfo, Data, str, float) -> Calculator
710    """
711    Generate the appropriate calculation engine for the given datatype.
712
713    Datatypes with '!' appended are evaluated using external C DLLs rather
714    than OpenCL.
715    """
716    if ngauss:
717        set_integration_size(model_info, ngauss)
718
719    if dtype != "default" and not dtype.endswith('!') and not kernelcl.use_opencl():
720        raise RuntimeError("OpenCL not available " + kernelcl.OPENCL_ERROR)
721
722    model = core.build_model(model_info, dtype=dtype, platform="ocl")
723    calculator = DirectModel(data, model, cutoff=cutoff)
724    engine_type = calculator._model.__class__.__name__.replace('Model', '').upper()
725    bits = calculator._model.dtype.itemsize*8
726    precision = "fast" if getattr(calculator._model, 'fast', False) else str(bits)
727    calculator.engine = "%s[%s]" % (engine_type, precision)
728    return calculator
729
730def _show_invalid(data, theory):
731    # type: (Data, np.ma.ndarray) -> None
732    """
733    Display a list of the non-finite values in theory.
734    """
735    if not theory.mask.any():
736        return
737
738    if hasattr(data, 'x'):
739        bad = zip(data.x[theory.mask], theory[theory.mask])
740        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
741
742
743def compare(opts, limits=None, maxdim=np.inf):
744    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
745    """
746    Preform a comparison using options from the command line.
747
748    *limits* are the limits on the values to use, either to set the y-axis
749    for 1D or to set the colormap scale for 2D.  If None, then they are
750    inferred from the data and returned. When exploring using Bumps,
751    the limits are set when the model is initially called, and maintained
752    as the values are adjusted, making it easier to see the effects of the
753    parameters.
754
755    *maxdim* is the maximum value for any parameter with units of Angstrom.
756    """
757    for k in range(opts['sets']):
758        if k > 0:
759            # print a separate seed for each dataset for better reproducibility
760            new_seed = np.random.randint(1000000)
761            print("=== Set %d uses -random=%i ==="%(k+1, new_seed))
762            np.random.seed(new_seed)
763        opts['pars'] = parse_pars(opts, maxdim=maxdim)
764        if opts['pars'] is None:
765            return
766        result = run_models(opts, verbose=True)
767        if opts['plot']:
768            if opts['is2d'] and k > 0:
769                import matplotlib.pyplot as plt
770                plt.figure()
771            limits = plot_models(opts, result, limits=limits, setnum=k)
772        if opts['show_weights']:
773            base, _ = opts['engines']
774            base_pars, _ = opts['pars']
775            model_info = base._kernel.info
776            dim = base._kernel.dim
777            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
778        if opts['show_profile']:
779            import pylab
780            base, comp = opts['engines']
781            base_pars, comp_pars = opts['pars']
782            have_base = base._kernel.info.profile is not None
783            have_comp = (
784                comp is not None
785                and comp._kernel.info.profile is not None
786                and base_pars != comp_pars
787            )
788            if have_base or have_comp:
789                pylab.figure()
790            if have_base:
791                plot_profile(base._kernel.info, **base_pars)
792            if have_comp:
793                plot_profile(comp._kernel.info, label='comp', **comp_pars)
794                pylab.legend()
795    if opts['plot']:
796        import matplotlib.pyplot as plt
797        plt.show()
798    return limits
799
800def plot_profile(model_info, label='base', **args):
801    # type: (ModelInfo, List[Tuple[float, np.ndarray, np.ndarray]]) -> None
802    """
803    Plot the profile returned by the model profile method.
804
805    *model_info* defines model parameters, etc.
806
807    *mesh* is a list of tuples containing (*value*, *dispersity*, *weights*)
808    for each parameter, where (*dispersity*, *weights*) pairs are the
809    distributions to be plotted.
810    """
811    import pylab
812
813    args = dict((k, v) for k, v in args.items()
814                if "_pd" not in k
815                   and ":" not in k
816                   and k not in ("background", "scale", "theta", "phi", "psi"))
817    args = args.copy()
818
819    args.pop('scale', 1.)
820    args.pop('background', 0.)
821    z, rho = model_info.profile(**args)
822    #pylab.interactive(True)
823    pylab.plot(z, rho, '-', label=label)
824    pylab.grid(True)
825    #pylab.show()
826
827
828
829def run_models(opts, verbose=False):
830    # type: (Dict[str, Any]) -> Dict[str, Any]
831    """
832    Process a parameter set, return calculation results and times.
833    """
834
835    base, comp = opts['engines']
836    base_n, comp_n = opts['count']
837    base_pars, comp_pars = opts['pars']
838    base_data, comp_data = opts['data']
839
840    comparison = comp is not None
841
842    base_time = comp_time = None
843    base_value = comp_value = resid = relerr = None
844
845    # Base calculation
846    try:
847        base_raw, base_time = time_calculation(base, base_pars, base_n)
848        base_value = np.ma.masked_invalid(base_raw)
849        if verbose:
850            print("%s t=%.2f ms, intensity=%.0f"
851                  % (base.engine, base_time, base_value.sum()))
852        _show_invalid(base_data, base_value)
853    except ImportError:
854        traceback.print_exc()
855
856    # Comparison calculation
857    if comparison:
858        try:
859            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
860            comp_value = np.ma.masked_invalid(comp_raw)
861            if verbose:
862                print("%s t=%.2f ms, intensity=%.0f"
863                      % (comp.engine, comp_time, comp_value.sum()))
864            _show_invalid(base_data, comp_value)
865        except ImportError:
866            traceback.print_exc()
867
868    # Compare, but only if computing both forms
869    if comparison:
870        resid = (base_value - comp_value)
871        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
872        if verbose:
873            _print_stats("|%s-%s|"
874                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
875                         resid)
876            _print_stats("|(%s-%s)/%s|"
877                         % (base.engine, comp.engine, comp.engine),
878                         relerr)
879
880    return dict(base_value=base_value, comp_value=comp_value,
881                base_time=base_time, comp_time=comp_time,
882                resid=resid, relerr=relerr)
883
884
885def _print_stats(label, err):
886    # type: (str, np.ma.ndarray) -> None
887    # work with trimmed data, not the full set
888    sorted_err = np.sort(abs(err.compressed()))
889    if len(sorted_err) == 0:
890        print(label + "  no valid values")
891        return
892
893    p50 = int((len(sorted_err)-1)*0.50)
894    p98 = int((len(sorted_err)-1)*0.98)
895    data = [
896        "max:%.3e"%sorted_err[-1],
897        "median:%.3e"%sorted_err[p50],
898        "98%%:%.3e"%sorted_err[p98],
899        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
900        "zero-offset:%+.3e"%np.mean(sorted_err),
901        ]
902    print(label+"  "+"  ".join(data))
903
904
905def plot_models(opts, result, limits=None, setnum=0):
906    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
907    """
908    Plot the results from :func:`run_model`.
909    """
910    import matplotlib.pyplot as plt
911
912    base_value, comp_value = result['base_value'], result['comp_value']
913    base_time, comp_time = result['base_time'], result['comp_time']
914    resid, relerr = result['resid'], result['relerr']
915
916    have_base, have_comp = (base_value is not None), (comp_value is not None)
917    base, comp = opts['engines']
918    base_data, comp_data = opts['data']
919    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
920
921    # Plot if requested
922    view = opts['view']
923    #view = 'log'
924    if limits is None:
925        vmin, vmax = np.inf, -np.inf
926        if have_base:
927            vmin = min(vmin, base_value.min())
928            vmax = max(vmax, base_value.max())
929        if have_comp:
930            vmin = min(vmin, comp_value.min())
931            vmax = max(vmax, comp_value.max())
932        limits = vmin, vmax
933
934    if have_base:
935        if have_comp:
936            plt.subplot(131)
937        plot_theory(base_data, base_value, view=view, use_data=use_data, limits=limits)
938        plt.title("%s t=%.2f ms"%(base.engine, base_time))
939        #cbar_title = "log I"
940    if have_comp:
941        if have_base:
942            plt.subplot(132)
943        if not opts['is2d'] and have_base:
944            plot_theory(comp_data, base_value, view=view, use_data=use_data, limits=limits)
945        plot_theory(comp_data, comp_value, view=view, use_data=use_data, limits=limits)
946        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
947        #cbar_title = "log I"
948    if have_base and have_comp:
949        plt.subplot(133)
950        if not opts['rel_err']:
951            err, errstr, errview = resid, "abs err", "linear"
952        else:
953            err, errstr, errview = abs(relerr), "rel err", "log"
954            if (err == 0.).all():
955                errview = 'linear'
956        if 0:  # 95% cutoff
957            sorted_err = np.sort(err.flatten())
958            cutoff = sorted_err[int(sorted_err.size*0.95)]
959            err[err > cutoff] = cutoff
960        #err,errstr = base/comp,"ratio"
961        # Note: base_data only since base and comp have same q values (though
962        # perhaps different resolution), and we are plotting the difference
963        # at each q
964        plot_theory(base_data, None, resid=err, view=errview, use_data=use_data)
965        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
966        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
967        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
968        #cbar_title = errstr if errview=="linear" else "log "+errstr
969    #if is2D:
970    #    h = plt.colorbar()
971    #    h.ax.set_title(cbar_title)
972    fig = plt.gcf()
973    extra_title = ' '+opts['title'] if opts['title'] else ''
974    fig.suptitle(":".join(opts['name']) + extra_title)
975
976    if have_base and have_comp and opts['show_hist']:
977        plt.figure()
978        v = relerr
979        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
980        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
981        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
982                   % (base.engine, comp.engine, comp.engine))
983        plt.ylabel('P(err)')
984        plt.title('Distribution of relative error between calculation engines')
985
986    return limits
987
988
989# ===========================================================================
990#
991
992# Set of command line options.
993# Normal options such as -plot/-noplot are specified as 'name'.
994# For options such as -nq=500 which require a value use 'name='.
995#
996OPTIONS = [
997    # Plotting
998    'plot', 'noplot',
999    'weights', 'profile',
1000    'linear', 'log', 'q4',
1001    'rel', 'abs',
1002    'hist', 'nohist',
1003    'title=',
1004
1005    # Data generation
1006    'data=', 'noise=', 'res=', 'nq=', 'q=',
1007    'lowq', 'midq', 'highq', 'exq', 'zero',
1008    '2d', '1d',
1009
1010    # Parameter set
1011    'preset', 'random', 'random=', 'sets=',
1012    'demo', 'default',  # TODO: remove demo/default
1013    'nopars', 'pars',
1014    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
1015
1016    # Calculation options
1017    'poly', 'mono', 'cutoff=',
1018    'magnetic', 'nonmagnetic',
1019    'accuracy=', 'ngauss=',
1020    'neval=',  # for timing...
1021
1022    # Precision options
1023    'engine=',
1024    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1025
1026    # Output options
1027    'help', 'html', 'edit',
1028    ]
1029
1030NAME_OPTIONS = (lambda: set(k for k in OPTIONS if not k.endswith('=')))()
1031VALUE_OPTIONS = (lambda: [k[:-1] for k in OPTIONS if k.endswith('=')])()
1032
1033
1034def columnize(items, indent="", width=79):
1035    # type: (List[str], str, int) -> str
1036    """
1037    Format a list of strings into columns.
1038
1039    Returns a string with carriage returns ready for printing.
1040    """
1041    column_width = max(len(w) for w in items) + 1
1042    num_columns = (width - len(indent)) // column_width
1043    num_rows = len(items) // num_columns
1044    items = items + [""] * (num_rows * num_columns - len(items))
1045    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1046    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1047             for row in zip(*columns)]
1048    output = indent + ("\n"+indent).join(lines)
1049    return output
1050
1051
1052def get_pars(model_info, use_demo=False):
1053    # type: (ModelInfo, bool) -> ParameterSet
1054    """
1055    Extract demo parameters from the model definition.
1056    """
1057    # Get the default values for the parameters
1058    pars = {}
1059    for p in model_info.parameters.call_parameters:
1060        parts = [('', p.default)]
1061        if p.polydisperse:
1062            parts.append(('_pd', 0.0))
1063            parts.append(('_pd_n', 0))
1064            parts.append(('_pd_nsigma', 3.0))
1065            parts.append(('_pd_type', "gaussian"))
1066        for ext, val in parts:
1067            if p.length > 1:
1068                dict(("%s%d%s" % (p.id, k, ext), val)
1069                     for k in range(1, p.length+1))
1070            else:
1071                pars[p.id + ext] = val
1072
1073    # Plug in values given in demo
1074    if use_demo and model_info.demo:
1075        pars.update(model_info.demo)
1076    return pars
1077
1078INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1079def isnumber(s):
1080    # type: (str) -> bool
1081    """Return True if string contains an int or float"""
1082    match = FLOAT_RE.match(s)
1083    isfloat = (match and not s[match.end():])
1084    return isfloat or INTEGER_RE.match(s)
1085
1086# For distinguishing pairs of models for comparison
1087# key-value pair separator =
1088# shell characters  | & ; <> $ % ' " \ # `
1089# model and parameter names _
1090# parameter expressions - + * / . ( )
1091# path characters including tilde expansion and windows drive ~ / :
1092# not sure about brackets [] {}
1093# maybe one of the following @ ? ^ ! ,
1094PAR_SPLIT = ','
1095def parse_opts(argv):
1096    # type: (List[str]) -> Dict[str, Any]
1097    """
1098    Parse command line options.
1099    """
1100    MODELS = core.list_models()
1101    flags = [arg for arg in argv
1102             if arg.startswith('-')]
1103    values = [arg for arg in argv
1104              if not arg.startswith('-') and '=' in arg]
1105    positional_args = [arg for arg in argv
1106                       if not arg.startswith('-') and '=' not in arg]
1107    models = "\n    ".join("%-15s"%v for v in MODELS)
1108    if len(positional_args) == 0:
1109        print(USAGE)
1110        print("\nAvailable models:")
1111        print(columnize(MODELS, indent="  "))
1112        return None
1113
1114    invalid = [o[1:] for o in flags
1115               if o[1:] not in NAME_OPTIONS
1116               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1117    if invalid:
1118        print("Invalid options: %s"%(", ".join(invalid)))
1119        return None
1120
1121    name = positional_args[-1]
1122
1123    # pylint: disable=bad-whitespace,C0321
1124    # Interpret the flags
1125    opts = {
1126        'plot'      : True,
1127        'view'      : 'log',
1128        'is2d'      : False,
1129        'qmin'      : None,
1130        'qmax'      : 0.05,
1131        'nq'        : 128,
1132        'res'       : '0.0',
1133        'noise'     : 0.0,
1134        'accuracy'  : 'Low',
1135        'cutoff'    : '0.0',
1136        'seed'      : -1,  # default to preset
1137        'mono'      : True,
1138        # Default to magnetic a magnetic moment is set on the command line
1139        'magnetic'  : False,
1140        'show_pars' : False,
1141        'show_hist' : False,
1142        'rel_err'   : True,
1143        'explore'   : False,
1144        'use_demo'  : True,
1145        'zero'      : False,
1146        'html'      : False,
1147        'title'     : None,
1148        'datafile'  : None,
1149        'sets'      : 0,
1150        'engine'    : 'default',
1151        'count'     : '1',
1152        'show_weights' : False,
1153        'show_profile' : False,
1154        'sphere'    : 0,
1155        'ngauss'    : '0',
1156    }
1157    for arg in flags:
1158        if arg == '-noplot':    opts['plot'] = False
1159        elif arg == '-plot':    opts['plot'] = True
1160        elif arg == '-linear':  opts['view'] = 'linear'
1161        elif arg == '-log':     opts['view'] = 'log'
1162        elif arg == '-q4':      opts['view'] = 'q4'
1163        elif arg == '-1d':      opts['is2d'] = False
1164        elif arg == '-2d':      opts['is2d'] = True
1165        elif arg == '-exq':     opts['qmax'] = 10.0
1166        elif arg == '-highq':   opts['qmax'] = 1.0
1167        elif arg == '-midq':    opts['qmax'] = 0.2
1168        elif arg == '-lowq':    opts['qmax'] = 0.05
1169        elif arg == '-zero':    opts['zero'] = True
1170        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1171        elif arg.startswith('-q='):
1172            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1173        elif arg.startswith('-res='):      opts['res'] = arg[5:]
1174        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1175        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1176        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1177        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1178        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1179        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1180        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1181        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1182        elif arg.startswith('-ngauss='):   opts['ngauss'] = arg[8:]
1183        elif arg.startswith('-random='):
1184            opts['seed'] = int(arg[8:])
1185            opts['sets'] = 0
1186        elif arg == '-random':
1187            opts['seed'] = np.random.randint(1000000)
1188            opts['sets'] = 0
1189        elif arg.startswith('-sphere'):
1190            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1191            opts['is2d'] = True
1192        elif arg == '-preset':  opts['seed'] = -1
1193        elif arg == '-mono':    opts['mono'] = True
1194        elif arg == '-poly':    opts['mono'] = False
1195        elif arg == '-magnetic':       opts['magnetic'] = True
1196        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1197        elif arg == '-pars':    opts['show_pars'] = True
1198        elif arg == '-nopars':  opts['show_pars'] = False
1199        elif arg == '-hist':    opts['show_hist'] = True
1200        elif arg == '-nohist':  opts['show_hist'] = False
1201        elif arg == '-rel':     opts['rel_err'] = True
1202        elif arg == '-abs':     opts['rel_err'] = False
1203        elif arg == '-half':    opts['engine'] = 'half'
1204        elif arg == '-fast':    opts['engine'] = 'fast'
1205        elif arg == '-single':  opts['engine'] = 'single'
1206        elif arg == '-double':  opts['engine'] = 'double'
1207        elif arg == '-single!': opts['engine'] = 'single!'
1208        elif arg == '-double!': opts['engine'] = 'double!'
1209        elif arg == '-quad!':   opts['engine'] = 'quad!'
1210        elif arg == '-edit':    opts['explore'] = True
1211        elif arg == '-demo':    opts['use_demo'] = True
1212        elif arg == '-default': opts['use_demo'] = False
1213        elif arg == '-weights': opts['show_weights'] = True
1214        elif arg == '-profile': opts['show_profile'] = True
1215        elif arg == '-html':    opts['html'] = True
1216        elif arg == '-help':    opts['html'] = True
1217    # pylint: enable=bad-whitespace,C0321
1218
1219    # Magnetism forces 2D for now
1220    if opts['magnetic']:
1221        opts['is2d'] = True
1222
1223    # Force random if sets is used
1224    if opts['sets'] >= 1 and opts['seed'] < 0:
1225        opts['seed'] = np.random.randint(1000000)
1226    if opts['sets'] == 0:
1227        opts['sets'] = 1
1228
1229    # Create the computational engines
1230    if opts['qmin'] is None:
1231        opts['qmin'] = 0.001*opts['qmax']
1232
1233    comparison = any(PAR_SPLIT in v for v in values)
1234
1235    if PAR_SPLIT in name:
1236        names = name.split(PAR_SPLIT, 2)
1237        comparison = True
1238    else:
1239        names = [name]*2
1240    try:
1241        model_info = [core.load_model_info(k) for k in names]
1242    except ImportError as exc:
1243        print(str(exc))
1244        print("Could not find model; use one of:\n    " + models)
1245        return None
1246
1247    if PAR_SPLIT in opts['ngauss']:
1248        opts['ngauss'] = [int(k) for k in opts['ngauss'].split(PAR_SPLIT, 2)]
1249        comparison = True
1250    else:
1251        opts['ngauss'] = [int(opts['ngauss'])]*2
1252
1253    if PAR_SPLIT in opts['engine']:
1254        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1255        comparison = True
1256    else:
1257        opts['engine'] = [opts['engine']]*2
1258
1259    if PAR_SPLIT in opts['count']:
1260        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1261        comparison = True
1262    else:
1263        opts['count'] = [int(opts['count'])]*2
1264
1265    if PAR_SPLIT in opts['cutoff']:
1266        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1267        comparison = True
1268    else:
1269        opts['cutoff'] = [float(opts['cutoff'])]*2
1270
1271    if PAR_SPLIT in opts['res']:
1272        opts['res'] = [float(k) for k in opts['res'].split(PAR_SPLIT, 2)]
1273        comparison = True
1274    else:
1275        opts['res'] = [float(opts['res'])]*2
1276
1277    if opts['datafile'] is not None:
1278        data = load_data(os.path.expanduser(opts['datafile']))
1279    else:
1280        # Hack around the fact that make_data doesn't take a pair of resolutions
1281        res = opts['res']
1282        opts['res'] = res[0]
1283        data0, _ = make_data(opts)
1284        if res[0] != res[1]:
1285            opts['res'] = res[1]
1286            data1, _ = make_data(opts)
1287        else:
1288            data1 = data0
1289        opts['res'] = res
1290        data = data0, data1
1291
1292    base = make_engine(model_info[0], data[0], opts['engine'][0],
1293                       opts['cutoff'][0], opts['ngauss'][0])
1294    if comparison:
1295        comp = make_engine(model_info[1], data[1], opts['engine'][1],
1296                           opts['cutoff'][1], opts['ngauss'][1])
1297    else:
1298        comp = None
1299
1300    # pylint: disable=bad-whitespace
1301    # Remember it all
1302    opts.update({
1303        'data'      : data,
1304        'name'      : names,
1305        'info'      : model_info,
1306        'engines'   : [base, comp],
1307        'values'    : values,
1308    })
1309    # pylint: enable=bad-whitespace
1310
1311    # Set the integration parameters to the half sphere
1312    if opts['sphere'] > 0:
1313        set_spherical_integration_parameters(opts, opts['sphere'])
1314
1315    return opts
1316
1317def set_spherical_integration_parameters(opts, steps):
1318    # type: (Dict[str, Any], int) -> None
1319    """
1320    Set integration parameters for spherical integration over the entire
1321    surface in theta-phi coordinates.
1322    """
1323    # Set the integration parameters to the half sphere
1324    opts['values'].extend([
1325        #'theta=90',
1326        'theta_pd=%g'%(90/np.sqrt(3)),
1327        'theta_pd_n=%d'%steps,
1328        'theta_pd_type=rectangle',
1329        #'phi=0',
1330        'phi_pd=%g'%(180/np.sqrt(3)),
1331        'phi_pd_n=%d'%(2*steps),
1332        'phi_pd_type=rectangle',
1333        #'background=0',
1334    ])
1335    if 'psi' in opts['info'][0].parameters:
1336        opts['values'].extend([
1337            #'psi=0',
1338            'psi_pd=%g'%(180/np.sqrt(3)),
1339            'psi_pd_n=%d'%(2*steps),
1340            'psi_pd_type=rectangle',
1341        ])
1342
1343def parse_pars(opts, maxdim=np.inf):
1344    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]]
1345    """
1346    Generate a parameter set.
1347
1348    The default values come from the model, or a randomized model if a seed
1349    value is given.  Next, evaluate any parameter expressions, constraining
1350    the value of the parameter within and between models.  If *maxdim* is
1351    given, limit parameters with units of Angstrom to this value.
1352
1353    Returns a pair of parameter dictionaries for base and comparison models.
1354    """
1355    model_info, model_info2 = opts['info']
1356
1357    # Get demo parameters from model definition, or use default parameters
1358    # if model does not define demo parameters
1359    pars = get_pars(model_info, opts['use_demo'])
1360    pars2 = get_pars(model_info2, opts['use_demo'])
1361    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1362    # randomize parameters
1363    #pars.update(set_pars)  # set value before random to control range
1364    if opts['seed'] > -1:
1365        pars = randomize_pars(model_info, pars)
1366        limit_dimensions(model_info, pars, maxdim)
1367        if model_info != model_info2:
1368            pars2 = randomize_pars(model_info2, pars2)
1369            limit_dimensions(model_info2, pars2, maxdim)
1370            # Share values for parameters with the same name
1371            for k, v in pars.items():
1372                if k in pars2:
1373                    pars2[k] = v
1374        else:
1375            pars2 = pars.copy()
1376        constrain_pars(model_info, pars)
1377        constrain_pars(model_info2, pars2)
1378    pars = suppress_pd(pars, opts['mono'])
1379    pars2 = suppress_pd(pars2, opts['mono'])
1380    pars = suppress_magnetism(pars, not opts['magnetic'])
1381    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1382
1383    # Fill in parameters given on the command line
1384    presets = {}
1385    presets2 = {}
1386    for arg in opts['values']:
1387        k, v = arg.split('=', 1)
1388        if k not in pars and k not in pars2:
1389            # extract base name without polydispersity info
1390            s = set(p.split('_pd')[0] for p in pars)
1391            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1392            return None
1393        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
1394        if v1 and k in pars:
1395            presets[k] = float(v1) if isnumber(v1) else v1
1396        if v2 and k in pars2:
1397            presets2[k] = float(v2) if isnumber(v2) else v2
1398
1399    # If pd given on the command line, default pd_n to 35
1400    for k, v in list(presets.items()):
1401        if k.endswith('_pd'):
1402            presets.setdefault(k+'_n', 35.)
1403    for k, v in list(presets2.items()):
1404        if k.endswith('_pd'):
1405            presets2.setdefault(k+'_n', 35.)
1406
1407    # Evaluate preset parameter expressions
1408    # Note: need to replace ':' with '_' in parameter names and expressions
1409    # in order to support math on magnetic parameters.
1410    context = MATH.copy()
1411    context['np'] = np
1412    context.update((k.replace(':', '_'), v) for k, v in pars.items())
1413    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1414    #for k,v in sorted(context.items()): print(k, v)
1415    for k, v in presets.items():
1416        if not isinstance(v, float) and not k.endswith('_type'):
1417            presets[k] = eval(v.replace(':', '_'), context)
1418    context.update(presets)
1419    context.update((k.replace(':', '_'), v) for k, v in presets2.items() if isinstance(v, float))
1420    for k, v in presets2.items():
1421        if not isinstance(v, float) and not k.endswith('_type'):
1422            presets2[k] = eval(v.replace(':', '_'), context)
1423
1424    # update parameters with presets
1425    pars.update(presets)  # set value after random to control value
1426    pars2.update(presets2)  # set value after random to control value
1427    #import pprint; pprint.pprint(model_info)
1428
1429    if opts['show_pars']:
1430        if model_info.name != model_info2.name or pars != pars2:
1431            print("==== %s ====="%model_info.name)
1432            print(str(parlist(model_info, pars, opts['is2d'])))
1433            print("==== %s ====="%model_info2.name)
1434            print(str(parlist(model_info2, pars2, opts['is2d'])))
1435        else:
1436            print(str(parlist(model_info, pars, opts['is2d'])))
1437
1438    return pars, pars2
1439
1440def show_docs(opts):
1441    # type: (Dict[str, Any]) -> None
1442    """
1443    show html docs for the model
1444    """
1445    from .generate import make_html
1446    from . import rst2html
1447
1448    info = opts['info'][0]
1449    html = make_html(info)
1450    path = os.path.dirname(info.filename)
1451    url = "file://" + path.replace("\\", "/")[2:] + "/"
1452    rst2html.view_html_wxapp(html, url)
1453
1454def explore(opts):
1455    # type: (Dict[str, Any]) -> None
1456    """
1457    explore the model using the bumps gui.
1458    """
1459    import wx  # type: ignore
1460    from bumps.names import FitProblem  # type: ignore
1461    from bumps.gui.app_frame import AppFrame  # type: ignore
1462    from bumps.gui import signal
1463
1464    is_mac = "cocoa" in wx.version()
1465    # Create an app if not running embedded
1466    app = wx.App() if wx.GetApp() is None else None
1467    model = Explore(opts)
1468    problem = FitProblem(model)
1469    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1470    if not is_mac:
1471        frame.Show()
1472    frame.panel.set_model(model=problem)
1473    frame.panel.Layout()
1474    frame.panel.aui.Split(0, wx.TOP)
1475    def _reset_parameters(event):
1476        model.revert_values()
1477        signal.update_parameters(problem)
1478    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1))
1479    if is_mac:
1480        frame.Show()
1481    # If running withing an app, start the main loop
1482    if app:
1483        app.MainLoop()
1484
1485class Explore(object):
1486    """
1487    Bumps wrapper for a SAS model comparison.
1488
1489    The resulting object can be used as a Bumps fit problem so that
1490    parameters can be adjusted in the GUI, with plots updated on the fly.
1491    """
1492    def __init__(self, opts):
1493        # type: (Dict[str, Any]) -> None
1494        from bumps.cli import config_matplotlib  # type: ignore
1495        from . import bumps_model
1496        config_matplotlib()
1497        self.opts = opts
1498        opts['pars'] = list(opts['pars'])
1499        p1, p2 = opts['pars']
1500        m1, m2 = opts['info']
1501        self.fix_p2 = m1 != m2 or p1 != p2
1502        model_info = m1
1503        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1504        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1505        if not opts['is2d']:
1506            for p in model_info.parameters.user_parameters({}, is2d=False):
1507                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1508                    k = p.name+ext
1509                    v = pars.get(k, None)
1510                    if v is not None:
1511                        v.range(*parameter_range(k, v.value))
1512        else:
1513            for k, v in pars.items():
1514                v.range(*parameter_range(k, v.value))
1515
1516        self.pars = pars
1517        self.starting_values = dict((k, v.value) for k, v in pars.items())
1518        self.pd_types = pd_types
1519        self.limits = None
1520
1521    def revert_values(self):
1522        # type: () -> None
1523        """
1524        Restore starting values of the parameters.
1525        """
1526        for k, v in self.starting_values.items():
1527            self.pars[k].value = v
1528
1529    def model_update(self):
1530        # type: () -> None
1531        """
1532        Respond to signal that model parameters have been changed.
1533        """
1534        pass
1535
1536    def numpoints(self):
1537        # type: () -> int
1538        """
1539        Return the number of points.
1540        """
1541        return len(self.pars) + 1  # so dof is 1
1542
1543    def parameters(self):
1544        # type: () -> Any   # Dict/List hierarchy of parameters
1545        """
1546        Return a dictionary of parameters.
1547        """
1548        return self.pars
1549
1550    def nllf(self):
1551        # type: () -> float
1552        """
1553        Return cost.
1554        """
1555        # pylint: disable=no-self-use
1556        return 0.  # No nllf
1557
1558    def plot(self, view='log'):
1559        # type: (str) -> None
1560        """
1561        Plot the data and residuals.
1562        """
1563        pars = dict((k, v.value) for k, v in self.pars.items())
1564        pars.update(self.pd_types)
1565        self.opts['pars'][0] = pars
1566        if not self.fix_p2:
1567            self.opts['pars'][1] = pars
1568        result = run_models(self.opts)
1569        limits = plot_models(self.opts, result, limits=self.limits)
1570        if self.limits is None:
1571            vmin, vmax = limits
1572            self.limits = vmax*1e-7, 1.3*vmax
1573            import pylab
1574            pylab.clf()
1575            plot_models(self.opts, result, limits=self.limits)
1576
1577
1578def main(*argv):
1579    # type: (*str) -> None
1580    """
1581    Main program.
1582    """
1583    opts = parse_opts(argv)
1584    if opts is not None:
1585        if opts['seed'] > -1:
1586            print("Randomize using -random=%i"%opts['seed'])
1587            np.random.seed(opts['seed'])
1588        if opts['html']:
1589            show_docs(opts)
1590        elif opts['explore']:
1591            opts['pars'] = parse_pars(opts)
1592            if opts['pars'] is None:
1593                return
1594            explore(opts)
1595        else:
1596            compare(opts)
1597
1598if __name__ == "__main__":
1599    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.