source: sasmodels/sasmodels/compare.py @ 110f69c

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 110f69c was 110f69c, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

lint

  • Property mode set to 100755
File size: 56.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47from .weights import plot_weights
48
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except Exception:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57
58USAGE = """
59usage: sascomp model [options...] [key=val]
60
61Generate and compare SAS models.  If a single model is specified it shows
62a plot of that model.  Different models can be compared, or the same model
63with different parameters.  The same model with the same parameters can
64be compared with different calculation engines to see the effects of precision
65on the resultant values.
66
67model or model1,model2 are the names of the models to compare (see below).
68
69Options (* for default):
70
71    === data generation ===
72    -data="path" uses q, dq from the data file
73    -noise=0 sets the measurement error dI/I
74    -res=0 sets the resolution width dQ/Q if calculating with resolution
75    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
76    -q=min:max alternative specification of qrange
77    -nq=128 sets the number of Q points in the data set
78    -1d*/-2d computes 1d or 2d data
79    -zero indicates that q=0 should be included
80
81    === model parameters ===
82    -preset*/-random[=seed] preset or random parameters
83    -sets=n generates n random datasets with the seed given by -random=seed
84    -pars/-nopars* prints the parameter set or not
85    -default/-demo* use demo vs default parameters
86    -sphere[=150] set up spherical integration over theta/phi using n points
87
88    === calculation options ===
89    -mono*/-poly force monodisperse or allow polydisperse random parameters
90    -cutoff=1e-5* cutoff value for including a point in polydispersity
91    -magnetic/-nonmagnetic* suppress magnetism
92    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
93    -neval=1 sets the number of evals for more accurate timing
94
95    === precision options ===
96    -engine=default uses the default calcution precision
97    -single/-double/-half/-fast sets an OpenCL calculation engine
98    -single!/-double!/-quad! sets an OpenMP calculation engine
99    -sasview sets the sasview calculation engine
100
101    === plotting ===
102    -plot*/-noplot plots or suppress the plot of the model
103    -linear/-log*/-q4 intensity scaling on plots
104    -hist/-nohist* plot histogram of relative error
105    -abs/-rel* plot relative or absolute error
106    -title="note" adds note to the plot title, after the model name
107    -weights shows weights plots for the polydisperse parameters
108
109    === output options ===
110    -edit starts the parameter explorer
111    -help/-html shows the model docs instead of running the model
112
113The interpretation of quad precision depends on architecture, and may
114vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
115On unix and mac you may need single quotes around the DLL computation
116engines, such as -engine='single!,double!' since !, is treated as a history
117expansion request in the shell.
118
119Key=value pairs allow you to set specific values for the model parameters.
120Key=value1,value2 to compare different values of the same parameter. The
121value can be an expression including other parameters.
122
123Items later on the command line override those that appear earlier.
124
125Examples:
126
127    # compare single and double precision calculation for a barbell
128    sascomp barbell -engine=single,double
129
130    # generate 10 random lorentz models, with seed=27
131    sascomp lorentz -sets=10 -seed=27
132
133    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
134    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
135
136    # model timing test requires multiple evals to perform the estimate
137    sascomp pringle -engine=single,double -timing=100,100 -noplot
138"""
139
140# Update docs with command line usage string.   This is separate from the usual
141# doc string so that we can display it at run time if there is an error.
142# lin
143__doc__ = (__doc__  # pylint: disable=redefined-builtin
144           + """
145Program description
146-------------------
147
148""" + USAGE)
149
150kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
151
152def build_math_context():
153    # type: () -> Dict[str, Callable]
154    """build dictionary of functions from math module"""
155    return dict((k, getattr(math, k))
156                for k in dir(math) if not k.startswith('_'))
157
158#: list of math functions for use in evaluating parameters
159MATH = build_math_context()
160
161# CRUFT python 2.6
162if not hasattr(datetime.timedelta, 'total_seconds'):
163    def delay(dt):
164        """Return number date-time delta as number seconds"""
165        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
166else:
167    def delay(dt):
168        """Return number date-time delta as number seconds"""
169        return dt.total_seconds()
170
171
172class push_seed(object):
173    """
174    Set the seed value for the random number generator.
175
176    When used in a with statement, the random number generator state is
177    restored after the with statement is complete.
178
179    :Parameters:
180
181    *seed* : int or array_like, optional
182        Seed for RandomState
183
184    :Example:
185
186    Seed can be used directly to set the seed::
187
188        >>> from numpy.random import randint
189        >>> push_seed(24)
190        <...push_seed object at...>
191        >>> print(randint(0,1000000,3))
192        [242082    899 211136]
193
194    Seed can also be used in a with statement, which sets the random
195    number generator state for the enclosed computations and restores
196    it to the previous state on completion::
197
198        >>> with push_seed(24):
199        ...    print(randint(0,1000000,3))
200        [242082    899 211136]
201
202    Using nested contexts, we can demonstrate that state is indeed
203    restored after the block completes::
204
205        >>> with push_seed(24):
206        ...    print(randint(0,1000000))
207        ...    with push_seed(24):
208        ...        print(randint(0,1000000,3))
209        ...    print(randint(0,1000000))
210        242082
211        [242082    899 211136]
212        899
213
214    The restore step is protected against exceptions in the block::
215
216        >>> with push_seed(24):
217        ...    print(randint(0,1000000))
218        ...    try:
219        ...        with push_seed(24):
220        ...            print(randint(0,1000000,3))
221        ...            raise Exception()
222        ...    except Exception:
223        ...        print("Exception raised")
224        ...    print(randint(0,1000000))
225        242082
226        [242082    899 211136]
227        Exception raised
228        899
229    """
230    def __init__(self, seed=None):
231        # type: (Optional[int]) -> None
232        self._state = np.random.get_state()
233        np.random.seed(seed)
234
235    def __enter__(self):
236        # type: () -> None
237        pass
238
239    def __exit__(self, exc_type, exc_value, tb):
240        # type: (Any, BaseException, Any) -> None
241        # TODO: better typing for __exit__ method
242        np.random.set_state(self._state)
243
244def tic():
245    # type: () -> Callable[[], float]
246    """
247    Timer function.
248
249    Use "toc=tic()" to start the clock and "toc()" to measure
250    a time interval.
251    """
252    then = datetime.datetime.now()
253    return lambda: delay(datetime.datetime.now() - then)
254
255
256def set_beam_stop(data, radius, outer=None):
257    # type: (Data, float, float) -> None
258    """
259    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
260
261    Note: this function does not require sasview
262    """
263    if hasattr(data, 'qx_data'):
264        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
265        data.mask = (q < radius)
266        if outer is not None:
267            data.mask |= (q >= outer)
268    else:
269        data.mask = (data.x < radius)
270        if outer is not None:
271            data.mask |= (data.x >= outer)
272
273
274def parameter_range(p, v):
275    # type: (str, float) -> Tuple[float, float]
276    """
277    Choose a parameter range based on parameter name and initial value.
278    """
279    # process the polydispersity options
280    if p.endswith('_pd_n'):
281        return 0., 100.
282    elif p.endswith('_pd_nsigma'):
283        return 0., 5.
284    elif p.endswith('_pd_type'):
285        raise ValueError("Cannot return a range for a string value")
286    elif any(s in p for s in ('theta', 'phi', 'psi')):
287        # orientation in [-180,180], orientation pd in [0,45]
288        if p.endswith('_pd'):
289            return 0., 180.
290        else:
291            return -180., 180.
292    elif p.endswith('_pd'):
293        return 0., 1.
294    elif 'sld' in p:
295        return -0.5, 10.
296    elif p == 'background':
297        return 0., 10.
298    elif p == 'scale':
299        return 0., 1.e3
300    elif v < 0.:
301        return 2.*v, -2.*v
302    else:
303        return 0., (2.*v if v > 0. else 1.)
304
305
306def _randomize_one(model_info, name, value):
307    # type: (ModelInfo, str, float) -> float
308    # type: (ModelInfo, str, str) -> str
309    """
310    Randomize a single parameter.
311    """
312    # Set the amount of polydispersity/angular dispersion, but by default pd_n
313    # is zero so there is no polydispersity.  This allows us to turn on/off
314    # pd by setting pd_n, and still have randomly generated values
315    if name.endswith('_pd'):
316        par = model_info.parameters[name[:-3]]
317        if par.type == 'orientation':
318            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
319            return 180*np.random.beta(2.5, 20)
320        else:
321            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
322            return np.random.beta(1.5, 7)
323
324    # pd is selected globally rather than per parameter, so set to 0 for no pd
325    # In particular, when multiple pd dimensions, want to decrease the number
326    # of points per dimension for faster computation
327    if name.endswith('_pd_n'):
328        return 0
329
330    # Don't mess with distribution type for now
331    if name.endswith('_pd_type'):
332        return 'gaussian'
333
334    # type-dependent value of number of sigmas; for gaussian use 3.
335    if name.endswith('_pd_nsigma'):
336        return 3.
337
338    # background in the range [0.01, 1]
339    if name == 'background':
340        return 10**np.random.uniform(-2, 0)
341
342    # scale defaults to 0.1% to 30% volume fraction
343    if name == 'scale':
344        return 10**np.random.uniform(-3, -0.5)
345
346    # If it is a list of choices, pick one at random with equal probability
347    # In practice, the model specific random generator will override.
348    par = model_info.parameters[name]
349    if len(par.limits) > 2:  # choice list
350        return np.random.randint(len(par.limits))
351
352    # If it is a fixed range, pick from it with equal probability.
353    # For logarithmic ranges, the model will have to override.
354    if np.isfinite(par.limits).all():
355        return np.random.uniform(*par.limits)
356
357    # If the paramter is marked as an sld use the range of neutron slds
358    # TODO: ought to randomly contrast match a pair of SLDs
359    if par.type == 'sld':
360        return np.random.uniform(-0.5, 12)
361
362    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
363    if par.name.startswith('M0:'):
364        return np.random.uniform(0, 5)
365
366    # Guess at the random length/radius/thickness.  In practice, all models
367    # are going to set their own reasonable ranges.
368    if par.type == 'volume':
369        if ('length' in par.name or
370                'radius' in par.name or
371                'thick' in par.name):
372            return 10**np.random.uniform(2, 4)
373
374    # In the absence of any other info, select a value in [0, 2v], or
375    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
376    # model random parameter generators will override this default.
377    low, high = parameter_range(par.name, value)
378    limits = (max(par.limits[0], low), min(par.limits[1], high))
379    return np.random.uniform(*limits)
380
381def _random_pd(model_info, pars):
382    # type: (ModelInfo, Dict[str, float]) -> None
383    """
384    Generate a random dispersity distribution for the model.
385
386    1% no shape dispersity
387    85% single shape parameter
388    13% two shape parameters
389    1% three shape parameters
390
391    If oriented, then put dispersity in theta, add phi and psi dispersity
392    with 10% probability for each.
393    """
394    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
395    pd_volume = []
396    pd_oriented = []
397    for p in pd:
398        if p.type == 'orientation':
399            pd_oriented.append(p.name)
400        elif p.length_control is not None:
401            n = int(pars.get(p.length_control, 1) + 0.5)
402            pd_volume.extend(p.name+str(k+1) for k in range(n))
403        elif p.length > 1:
404            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
405        else:
406            pd_volume.append(p.name)
407    u = np.random.rand()
408    n = len(pd_volume)
409    if u < 0.01 or n < 1:
410        pass  # 1% chance of no polydispersity
411    elif u < 0.86 or n < 2:
412        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
413    elif u < 0.99 or n < 3:
414        choices = np.random.choice(len(pd_volume), size=2)
415        pars[pd_volume[choices[0]]+"_pd_n"] = 25
416        pars[pd_volume[choices[1]]+"_pd_n"] = 10
417    else:
418        choices = np.random.choice(len(pd_volume), size=3)
419        pars[pd_volume[choices[0]]+"_pd_n"] = 25
420        pars[pd_volume[choices[1]]+"_pd_n"] = 10
421        pars[pd_volume[choices[2]]+"_pd_n"] = 5
422    if pd_oriented:
423        pars['theta_pd_n'] = 20
424        if np.random.rand() < 0.1:
425            pars['phi_pd_n'] = 5
426        if np.random.rand() < 0.1:
427            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
428                #print("generating psi_pd_n")
429                pars['psi_pd_n'] = 5
430
431    ## Show selected polydispersity
432    #for name, value in pars.items():
433    #    if name.endswith('_pd_n') and value > 0:
434    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
435
436
437def randomize_pars(model_info, pars):
438    # type: (ModelInfo, ParameterSet) -> ParameterSet
439    """
440    Generate random values for all of the parameters.
441
442    Valid ranges for the random number generator are guessed from the name of
443    the parameter; this will not account for constraints such as cap radius
444    greater than cylinder radius in the capped_cylinder model, so
445    :func:`constrain_pars` needs to be called afterward..
446    """
447    # Note: the sort guarantees order of calls to random number generator
448    random_pars = dict((p, _randomize_one(model_info, p, v))
449                       for p, v in sorted(pars.items()))
450    if model_info.random is not None:
451        random_pars.update(model_info.random())
452    _random_pd(model_info, random_pars)
453    return random_pars
454
455
456def limit_dimensions(model_info, pars, maxdim):
457    # type: (ModelInfo, ParameterSet, float) -> None
458    """
459    Limit parameters of units of Ang to maxdim.
460    """
461    for p in model_info.parameters.call_parameters:
462        value = pars[p.name]
463        if p.units == 'Ang' and value > maxdim:
464            pars[p.name] = maxdim*10**np.random.uniform(-3, 0)
465
466def constrain_pars(model_info, pars):
467    # type: (ModelInfo, ParameterSet) -> None
468    """
469    Restrict parameters to valid values.
470
471    This includes model specific code for models such as capped_cylinder
472    which need to support within model constraints (cap radius more than
473    cylinder radius in this case).
474
475    Warning: this updates the *pars* dictionary in place.
476    """
477    # TODO: move the model specific code to the individual models
478    name = model_info.id
479    # if it is a product model, then just look at the form factor since
480    # none of the structure factors need any constraints.
481    if '*' in name:
482        name = name.split('*')[0]
483
484    # Suppress magnetism for python models (not yet implemented)
485    if callable(model_info.Iq):
486        pars.update(suppress_magnetism(pars))
487
488    if name == 'barbell':
489        if pars['radius_bell'] < pars['radius']:
490            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
491
492    elif name == 'capped_cylinder':
493        if pars['radius_cap'] < pars['radius']:
494            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
495
496    elif name == 'guinier':
497        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
498        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
499        # => ln A - (Rg^2 q^2/3) > -30 ln 10
500        # => Rg^2 q^2/3 < 30 ln 10 + ln A
501        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
502        #q_max = 0.2  # mid q maximum
503        q_max = 1.0  # high q maximum
504        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
505        pars['rg'] = min(pars['rg'], rg_max)
506
507    elif name == 'pearl_necklace':
508        if pars['radius'] < pars['thick_string']:
509            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
510
511    elif name == 'rpa':
512        # Make sure phi sums to 1.0
513        if pars['case_num'] < 2:
514            pars['Phi1'] = 0.
515            pars['Phi2'] = 0.
516        elif pars['case_num'] < 5:
517            pars['Phi1'] = 0.
518        total = sum(pars['Phi'+c] for c in '1234')
519        for c in '1234':
520            pars['Phi'+c] /= total
521
522def parlist(model_info, pars, is2d):
523    # type: (ModelInfo, ParameterSet, bool) -> str
524    """
525    Format the parameter list for printing.
526    """
527    is2d = True
528    lines = []
529    parameters = model_info.parameters
530    magnetic = False
531    magnetic_pars = []
532    for p in parameters.user_parameters(pars, is2d):
533        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
534            continue
535        if p.id.startswith('up:'):
536            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
537            continue
538        fields = dict(
539            value=pars.get(p.id, p.default),
540            pd=pars.get(p.id+"_pd", 0.),
541            n=int(pars.get(p.id+"_pd_n", 0)),
542            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
543            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
544            relative_pd=p.relative_pd,
545            M0=pars.get('M0:'+p.id, 0.),
546            mphi=pars.get('mphi:'+p.id, 0.),
547            mtheta=pars.get('mtheta:'+p.id, 0.),
548        )
549        lines.append(_format_par(p.name, **fields))
550        magnetic = magnetic or fields['M0'] != 0.
551    if magnetic and magnetic_pars:
552        lines.append(" ".join(magnetic_pars))
553    return "\n".join(lines)
554
555    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
556
557def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
558                relative_pd=False, M0=0., mphi=0., mtheta=0.):
559    # type: (str, float, float, int, float, str) -> str
560    line = "%s: %g"%(name, value)
561    if pd != 0.  and n != 0:
562        if relative_pd:
563            pd *= value
564        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
565                % (pd, n, nsigma, nsigma, pdtype)
566    if M0 != 0.:
567        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
568    return line
569
570def suppress_pd(pars, suppress=True):
571    # type: (ParameterSet) -> ParameterSet
572    """
573    If suppress is True complete eliminate polydispersity of the model to test
574    models more quickly.  If suppress is False, make sure at least one
575    parameter is polydisperse, setting the first polydispersity parameter to
576    15% if no polydispersity is given (with no explicit demo parameters given
577    in the model, there will be no default polydispersity).
578    """
579    pars = pars.copy()
580    #print("pars=", pars)
581    if suppress:
582        for p in pars:
583            if p.endswith("_pd_n"):
584                pars[p] = 0
585    else:
586        any_pd = False
587        first_pd = None
588        for p in pars:
589            if p.endswith("_pd_n"):
590                pd = pars.get(p[:-2], 0.)
591                any_pd |= (pars[p] != 0 and pd != 0.)
592                if first_pd is None:
593                    first_pd = p
594        if not any_pd and first_pd is not None:
595            if pars[first_pd] == 0:
596                pars[first_pd] = 35
597            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
598                pars[first_pd[:-2]] = 0.15
599    return pars
600
601def suppress_magnetism(pars, suppress=True):
602    # type: (ParameterSet) -> ParameterSet
603    """
604    If suppress is True complete eliminate magnetism of the model to test
605    models more quickly.  If suppress is False, make sure at least one sld
606    parameter is magnetic, setting the first parameter to have a strong
607    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
608    in the model, there will be no default magnetism).
609    """
610    pars = pars.copy()
611    if suppress:
612        for p in pars:
613            if p.startswith("M0:"):
614                pars[p] = 0
615    else:
616        any_mag = False
617        first_mag = None
618        for p in pars:
619            if p.startswith("M0:"):
620                any_mag |= (pars[p] != 0)
621                if first_mag is None:
622                    first_mag = p
623        if not any_mag and first_mag is not None:
624            pars[first_mag] = 8.
625    return pars
626
627# TODO: remove support for sasview 3.x models
628def eval_sasview(model_info, data):
629    # type: (Modelinfo, Data) -> Calculator
630    """
631    Return a model calculator using the pre-4.0 SasView models.
632    """
633    # importing sas here so that the error message will be that sas failed to
634    # import rather than the more obscure smear_selection not imported error
635    import sas
636    import sas.models
637    from sas.models.qsmearing import smear_selection
638    from sas.models.MultiplicationModel import MultiplicationModel
639    from sas.models.dispersion_models import models as dispersers
640
641    def _get_model_class(name):
642        # type: (str) -> "sas.models.BaseComponent"
643        #print("new",sorted(_pars.items()))
644        __import__('sas.models.' + name)
645        ModelClass = getattr(getattr(sas.models, name, None), name, None)
646        if ModelClass is None:
647            raise ValueError("could not find model %r in sas.models"%name)
648        return ModelClass
649
650    # WARNING: ugly hack when handling model!
651    # Sasview models with multiplicity need to be created with the target
652    # multiplicity, so we cannot create the target model ahead of time for
653    # for multiplicity models.  Instead we store the model in a list and
654    # update the first element of that list with the new multiplicity model
655    # every time we evaluate.
656
657    # grab the sasview model, or create it if it is a product model
658    if model_info.composition:
659        composition_type, parts = model_info.composition
660        if composition_type == 'product':
661            P, S = [_get_model_class(revert_name(p))() for p in parts]
662            model = [MultiplicationModel(P, S)]
663        else:
664            raise ValueError("sasview mixture models not supported by compare")
665    else:
666        old_name = revert_name(model_info)
667        if old_name is None:
668            raise ValueError("model %r does not exist in old sasview"
669                             % model_info.id)
670        ModelClass = _get_model_class(old_name)
671        model = [ModelClass()]
672    model[0].disperser_handles = {}
673
674    # build a smearer with which to call the model, if necessary
675    smearer = smear_selection(data, model=model)
676    if hasattr(data, 'qx_data'):
677        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
678        index = ((~data.mask) & (~np.isnan(data.data))
679                 & (q >= data.qmin) & (q <= data.qmax))
680        if smearer is not None:
681            smearer.model = model  # because smear_selection has a bug
682            smearer.accuracy = data.accuracy
683            smearer.set_index(index)
684            def _call_smearer():
685                smearer.model = model[0]
686                return smearer.get_value()
687            theory = _call_smearer
688        else:
689            theory = lambda: model[0].evalDistribution([data.qx_data[index],
690                                                        data.qy_data[index]])
691    elif smearer is not None:
692        theory = lambda: smearer(model[0].evalDistribution(data.x))
693    else:
694        theory = lambda: model[0].evalDistribution(data.x)
695
696    def calculator(**pars):
697        # type: (float, ...) -> np.ndarray
698        """
699        Sasview calculator for model.
700        """
701        oldpars = revert_pars(model_info, pars)
702        # For multiplicity models, create a model with the correct multiplicity
703        control = oldpars.pop("CONTROL", None)
704        if control is not None:
705            # sphericalSLD has one fewer multiplicity.  This update should
706            # happen in revert_pars, but it hasn't been called yet.
707            model[0] = ModelClass(control)
708        # paying for parameter conversion each time to keep life simple, if not fast
709        for k, v in oldpars.items():
710            if k.endswith('.type'):
711                par = k[:-5]
712                if v == 'gaussian': continue
713                cls = dispersers[v if v != 'rectangle' else 'rectangula']
714                handle = cls()
715                model[0].disperser_handles[par] = handle
716                try:
717                    model[0].set_dispersion(par, handle)
718                except Exception:
719                    exception.annotate_exception("while setting %s to %r"
720                                                 %(par, v))
721                    raise
722
723
724        #print("sasview pars",oldpars)
725        for k, v in oldpars.items():
726            name_attr = k.split('.')  # polydispersity components
727            if len(name_attr) == 2:
728                par, disp_par = name_attr
729                model[0].dispersion[par][disp_par] = v
730            else:
731                model[0].setParam(k, v)
732        return theory()
733
734    calculator.engine = "sasview"
735    return calculator
736
737DTYPE_MAP = {
738    'half': '16',
739    'fast': 'fast',
740    'single': '32',
741    'double': '64',
742    'quad': '128',
743    'f16': '16',
744    'f32': '32',
745    'f64': '64',
746    'float16': '16',
747    'float32': '32',
748    'float64': '64',
749    'float128': '128',
750    'longdouble': '128',
751}
752def eval_opencl(model_info, data, dtype='single', cutoff=0.):
753    # type: (ModelInfo, Data, str, float) -> Calculator
754    """
755    Return a model calculator using the OpenCL calculation engine.
756    """
757    if not core.HAVE_OPENCL:
758        raise RuntimeError("OpenCL not available")
759    model = core.build_model(model_info, dtype=dtype, platform="ocl")
760    calculator = DirectModel(data, model, cutoff=cutoff)
761    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
762    return calculator
763
764def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
765    # type: (ModelInfo, Data, str, float) -> Calculator
766    """
767    Return a model calculator using the DLL calculation engine.
768    """
769    model = core.build_model(model_info, dtype=dtype, platform="dll")
770    calculator = DirectModel(data, model, cutoff=cutoff)
771    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
772    return calculator
773
774def time_calculation(calculator, pars, evals=1):
775    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
776    """
777    Compute the average calculation time over N evaluations.
778
779    An additional call is generated without polydispersity in order to
780    initialize the calculation engine, and make the average more stable.
781    """
782    # initialize the code so time is more accurate
783    if evals > 1:
784        calculator(**suppress_pd(pars))
785    toc = tic()
786    # make sure there is at least one eval
787    value = calculator(**pars)
788    for _ in range(evals-1):
789        value = calculator(**pars)
790    average_time = toc()*1000. / evals
791    #print("I(q)",value)
792    return value, average_time
793
794def make_data(opts):
795    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
796    """
797    Generate an empty dataset, used with the model to set Q points
798    and resolution.
799
800    *opts* contains the options, with 'qmax', 'nq', 'res',
801    'accuracy', 'is2d' and 'view' parsed from the command line.
802    """
803    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
804    if opts['is2d']:
805        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
806        data = empty_data2D(q, resolution=res)
807        data.accuracy = opts['accuracy']
808        set_beam_stop(data, 0.0004)
809        index = ~data.mask
810    else:
811        if opts['view'] == 'log' and not opts['zero']:
812            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
813        else:
814            q = np.linspace(qmin, qmax, nq)
815        if opts['zero']:
816            q = np.hstack((0, q))
817        data = empty_data1D(q, resolution=res)
818        index = slice(None, None)
819    return data, index
820
821def make_engine(model_info, data, dtype, cutoff):
822    # type: (ModelInfo, Data, str, float) -> Calculator
823    """
824    Generate the appropriate calculation engine for the given datatype.
825
826    Datatypes with '!' appended are evaluated using external C DLLs rather
827    than OpenCL.
828    """
829    if dtype == 'sasview':
830        return eval_sasview(model_info, data)
831    elif dtype is None or not dtype.endswith('!'):
832        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
833    else:
834        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
835
836def _show_invalid(data, theory):
837    # type: (Data, np.ma.ndarray) -> None
838    """
839    Display a list of the non-finite values in theory.
840    """
841    if not theory.mask.any():
842        return
843
844    if hasattr(data, 'x'):
845        bad = zip(data.x[theory.mask], theory[theory.mask])
846        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
847
848
849def compare(opts, limits=None, maxdim=np.inf):
850    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
851    """
852    Preform a comparison using options from the command line.
853
854    *limits* are the limits on the values to use, either to set the y-axis
855    for 1D or to set the colormap scale for 2D.  If None, then they are
856    inferred from the data and returned. When exploring using Bumps,
857    the limits are set when the model is initially called, and maintained
858    as the values are adjusted, making it easier to see the effects of the
859    parameters.
860
861    *maxdim* is the maximum value for any parameter with units of Angstrom.
862    """
863    for k in range(opts['sets']):
864        if k > 1:
865            # print a separate seed for each dataset for better reproducibility
866            new_seed = np.random.randint(1000000)
867            print("Set %d uses -random=%i"%(k+1, new_seed))
868            np.random.seed(new_seed)
869        opts['pars'] = parse_pars(opts, maxdim=maxdim)
870        if opts['pars'] is None:
871            return
872        result = run_models(opts, verbose=True)
873        if opts['plot']:
874            limits = plot_models(opts, result, limits=limits, setnum=k)
875        if opts['show_weights']:
876            base, _ = opts['engines']
877            base_pars, _ = opts['pars']
878            model_info = base._kernel.info
879            dim = base._kernel.dim
880            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
881    if opts['plot']:
882        import matplotlib.pyplot as plt
883        plt.show()
884    return limits
885
886def run_models(opts, verbose=False):
887    # type: (Dict[str, Any]) -> Dict[str, Any]
888    """
889    Process a parameter set, return calculation results and times.
890    """
891
892    base, comp = opts['engines']
893    base_n, comp_n = opts['count']
894    base_pars, comp_pars = opts['pars']
895    data = opts['data']
896
897    comparison = comp is not None
898
899    base_time = comp_time = None
900    base_value = comp_value = resid = relerr = None
901
902    # Base calculation
903    try:
904        base_raw, base_time = time_calculation(base, base_pars, base_n)
905        base_value = np.ma.masked_invalid(base_raw)
906        if verbose:
907            print("%s t=%.2f ms, intensity=%.0f"
908                  % (base.engine, base_time, base_value.sum()))
909        _show_invalid(data, base_value)
910    except ImportError:
911        traceback.print_exc()
912
913    # Comparison calculation
914    if comparison:
915        try:
916            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
917            comp_value = np.ma.masked_invalid(comp_raw)
918            if verbose:
919                print("%s t=%.2f ms, intensity=%.0f"
920                      % (comp.engine, comp_time, comp_value.sum()))
921            _show_invalid(data, comp_value)
922        except ImportError:
923            traceback.print_exc()
924
925    # Compare, but only if computing both forms
926    if comparison:
927        resid = (base_value - comp_value)
928        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
929        if verbose:
930            _print_stats("|%s-%s|"
931                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
932                         resid)
933            _print_stats("|(%s-%s)/%s|"
934                         % (base.engine, comp.engine, comp.engine),
935                         relerr)
936
937    return dict(base_value=base_value, comp_value=comp_value,
938                base_time=base_time, comp_time=comp_time,
939                resid=resid, relerr=relerr)
940
941
942def _print_stats(label, err):
943    # type: (str, np.ma.ndarray) -> None
944    # work with trimmed data, not the full set
945    sorted_err = np.sort(abs(err.compressed()))
946    if len(sorted_err) == 0.:
947        print(label + "  no valid values")
948        return
949
950    p50 = int((len(sorted_err)-1)*0.50)
951    p98 = int((len(sorted_err)-1)*0.98)
952    data = [
953        "max:%.3e"%sorted_err[-1],
954        "median:%.3e"%sorted_err[p50],
955        "98%%:%.3e"%sorted_err[p98],
956        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
957        "zero-offset:%+.3e"%np.mean(sorted_err),
958        ]
959    print(label+"  "+"  ".join(data))
960
961
962def plot_models(opts, result, limits=None, setnum=0):
963    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
964    """
965    Plot the results from :func:`run_model`.
966    """
967    import matplotlib.pyplot as plt
968
969    base_value, comp_value = result['base_value'], result['comp_value']
970    base_time, comp_time = result['base_time'], result['comp_time']
971    resid, relerr = result['resid'], result['relerr']
972
973    have_base, have_comp = (base_value is not None), (comp_value is not None)
974    base, comp = opts['engines']
975    data = opts['data']
976    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
977
978    # Plot if requested
979    view = opts['view']
980    if limits is None:
981        vmin, vmax = np.inf, -np.inf
982        if have_base:
983            vmin = min(vmin, base_value.min())
984            vmax = max(vmax, base_value.max())
985        if have_comp:
986            vmin = min(vmin, comp_value.min())
987            vmax = max(vmax, comp_value.max())
988        limits = vmin, vmax
989
990    if have_base:
991        if have_comp:
992            plt.subplot(131)
993        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
994        plt.title("%s t=%.2f ms"%(base.engine, base_time))
995        #cbar_title = "log I"
996    if have_comp:
997        if have_base:
998            plt.subplot(132)
999        if not opts['is2d'] and have_base:
1000            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
1001        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
1002        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
1003        #cbar_title = "log I"
1004    if have_base and have_comp:
1005        plt.subplot(133)
1006        if not opts['rel_err']:
1007            err, errstr, errview = resid, "abs err", "linear"
1008        else:
1009            err, errstr, errview = abs(relerr), "rel err", "log"
1010            if (err == 0.).all():
1011                errview = 'linear'
1012        if 0:  # 95% cutoff
1013            sorted_err = np.sort(err.flatten())
1014            cutoff = sorted_err[int(sorted_err.size*0.95)]
1015            err[err > cutoff] = cutoff
1016        #err,errstr = base/comp,"ratio"
1017        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
1018        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
1019        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
1020        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
1021        #cbar_title = errstr if errview=="linear" else "log "+errstr
1022    #if is2D:
1023    #    h = plt.colorbar()
1024    #    h.ax.set_title(cbar_title)
1025    fig = plt.gcf()
1026    extra_title = ' '+opts['title'] if opts['title'] else ''
1027    fig.suptitle(":".join(opts['name']) + extra_title)
1028
1029    if have_base and have_comp and opts['show_hist']:
1030        plt.figure()
1031        v = relerr
1032        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
1033        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
1034        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
1035                   % (base.engine, comp.engine, comp.engine))
1036        plt.ylabel('P(err)')
1037        plt.title('Distribution of relative error between calculation engines')
1038
1039    return limits
1040
1041
1042# ===========================================================================
1043#
1044
1045# Set of command line options.
1046# Normal options such as -plot/-noplot are specified as 'name'.
1047# For options such as -nq=500 which require a value use 'name='.
1048#
1049OPTIONS = [
1050    # Plotting
1051    'plot', 'noplot', 'weights',
1052    'linear', 'log', 'q4',
1053    'rel', 'abs',
1054    'hist', 'nohist',
1055    'title=',
1056
1057    # Data generation
1058    'data=', 'noise=', 'res=', 'nq=', 'q=',
1059    'lowq', 'midq', 'highq', 'exq', 'zero',
1060    '2d', '1d',
1061
1062    # Parameter set
1063    'preset', 'random', 'random=', 'sets=',
1064    'demo', 'default',  # TODO: remove demo/default
1065    'nopars', 'pars',
1066    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
1067
1068    # Calculation options
1069    'poly', 'mono', 'cutoff=',
1070    'magnetic', 'nonmagnetic',
1071    'accuracy=',
1072    'neval=',  # for timing...
1073
1074    # Precision options
1075    'engine=',
1076    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1077    'sasview',  # TODO: remove sasview 3.x support
1078
1079    # Output options
1080    'help', 'html', 'edit',
1081    ]
1082
1083NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1084VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1085
1086
1087def columnize(items, indent="", width=79):
1088    # type: (List[str], str, int) -> str
1089    """
1090    Format a list of strings into columns.
1091
1092    Returns a string with carriage returns ready for printing.
1093    """
1094    column_width = max(len(w) for w in items) + 1
1095    num_columns = (width - len(indent)) // column_width
1096    num_rows = len(items) // num_columns
1097    items = items + [""] * (num_rows * num_columns - len(items))
1098    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1099    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1100             for row in zip(*columns)]
1101    output = indent + ("\n"+indent).join(lines)
1102    return output
1103
1104
1105def get_pars(model_info, use_demo=False):
1106    # type: (ModelInfo, bool) -> ParameterSet
1107    """
1108    Extract demo parameters from the model definition.
1109    """
1110    # Get the default values for the parameters
1111    pars = {}
1112    for p in model_info.parameters.call_parameters:
1113        parts = [('', p.default)]
1114        if p.polydisperse:
1115            parts.append(('_pd', 0.0))
1116            parts.append(('_pd_n', 0))
1117            parts.append(('_pd_nsigma', 3.0))
1118            parts.append(('_pd_type', "gaussian"))
1119        for ext, val in parts:
1120            if p.length > 1:
1121                dict(("%s%d%s" % (p.id, k, ext), val)
1122                     for k in range(1, p.length+1))
1123            else:
1124                pars[p.id + ext] = val
1125
1126    # Plug in values given in demo
1127    if use_demo and model_info.demo:
1128        pars.update(model_info.demo)
1129    return pars
1130
1131INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1132def isnumber(s):
1133    # type: (str) -> bool
1134    """Return True if string contains an int or float"""
1135    match = FLOAT_RE.match(s)
1136    isfloat = (match and not s[match.end():])
1137    return isfloat or INTEGER_RE.match(s)
1138
1139# For distinguishing pairs of models for comparison
1140# key-value pair separator =
1141# shell characters  | & ; <> $ % ' " \ # `
1142# model and parameter names _
1143# parameter expressions - + * / . ( )
1144# path characters including tilde expansion and windows drive ~ / :
1145# not sure about brackets [] {}
1146# maybe one of the following @ ? ^ ! ,
1147PAR_SPLIT = ','
1148def parse_opts(argv):
1149    # type: (List[str]) -> Dict[str, Any]
1150    """
1151    Parse command line options.
1152    """
1153    MODELS = core.list_models()
1154    flags = [arg for arg in argv
1155             if arg.startswith('-')]
1156    values = [arg for arg in argv
1157              if not arg.startswith('-') and '=' in arg]
1158    positional_args = [arg for arg in argv
1159                       if not arg.startswith('-') and '=' not in arg]
1160    models = "\n    ".join("%-15s"%v for v in MODELS)
1161    if len(positional_args) == 0:
1162        print(USAGE)
1163        print("\nAvailable models:")
1164        print(columnize(MODELS, indent="  "))
1165        return None
1166
1167    invalid = [o[1:] for o in flags
1168               if o[1:] not in NAME_OPTIONS
1169               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1170    if invalid:
1171        print("Invalid options: %s"%(", ".join(invalid)))
1172        return None
1173
1174    name = positional_args[-1]
1175
1176    # pylint: disable=bad-whitespace
1177    # Interpret the flags
1178    opts = {
1179        'plot'      : True,
1180        'view'      : 'log',
1181        'is2d'      : False,
1182        'qmin'      : None,
1183        'qmax'      : 0.05,
1184        'nq'        : 128,
1185        'res'       : 0.0,
1186        'noise'     : 0.0,
1187        'accuracy'  : 'Low',
1188        'cutoff'    : '0.0',
1189        'seed'      : -1,  # default to preset
1190        'mono'      : True,
1191        # Default to magnetic a magnetic moment is set on the command line
1192        'magnetic'  : False,
1193        'show_pars' : False,
1194        'show_hist' : False,
1195        'rel_err'   : True,
1196        'explore'   : False,
1197        'use_demo'  : True,
1198        'zero'      : False,
1199        'html'      : False,
1200        'title'     : None,
1201        'datafile'  : None,
1202        'sets'      : 0,
1203        'engine'    : 'default',
1204        'count'     : '1',
1205        'show_weights' : False,
1206        'sphere'    : 0,
1207    }
1208    for arg in flags:
1209        if arg == '-noplot':    opts['plot'] = False
1210        elif arg == '-plot':    opts['plot'] = True
1211        elif arg == '-linear':  opts['view'] = 'linear'
1212        elif arg == '-log':     opts['view'] = 'log'
1213        elif arg == '-q4':      opts['view'] = 'q4'
1214        elif arg == '-1d':      opts['is2d'] = False
1215        elif arg == '-2d':      opts['is2d'] = True
1216        elif arg == '-exq':     opts['qmax'] = 10.0
1217        elif arg == '-highq':   opts['qmax'] = 1.0
1218        elif arg == '-midq':    opts['qmax'] = 0.2
1219        elif arg == '-lowq':    opts['qmax'] = 0.05
1220        elif arg == '-zero':    opts['zero'] = True
1221        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1222        elif arg.startswith('-q='):
1223            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1224        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1225        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1226        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1227        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1228        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1229        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1230        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1231        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1232        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1233        elif arg.startswith('-random='):
1234            opts['seed'] = int(arg[8:])
1235            opts['sets'] = 0
1236        elif arg == '-random':
1237            opts['seed'] = np.random.randint(1000000)
1238            opts['sets'] = 0
1239        elif arg.startswith('-sphere'):
1240            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1241            opts['is2d'] = True
1242        elif arg == '-preset':  opts['seed'] = -1
1243        elif arg == '-mono':    opts['mono'] = True
1244        elif arg == '-poly':    opts['mono'] = False
1245        elif arg == '-magnetic':       opts['magnetic'] = True
1246        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1247        elif arg == '-pars':    opts['show_pars'] = True
1248        elif arg == '-nopars':  opts['show_pars'] = False
1249        elif arg == '-hist':    opts['show_hist'] = True
1250        elif arg == '-nohist':  opts['show_hist'] = False
1251        elif arg == '-rel':     opts['rel_err'] = True
1252        elif arg == '-abs':     opts['rel_err'] = False
1253        elif arg == '-half':    opts['engine'] = 'half'
1254        elif arg == '-fast':    opts['engine'] = 'fast'
1255        elif arg == '-single':  opts['engine'] = 'single'
1256        elif arg == '-double':  opts['engine'] = 'double'
1257        elif arg == '-single!': opts['engine'] = 'single!'
1258        elif arg == '-double!': opts['engine'] = 'double!'
1259        elif arg == '-quad!':   opts['engine'] = 'quad!'
1260        elif arg == '-sasview': opts['engine'] = 'sasview'
1261        elif arg == '-edit':    opts['explore'] = True
1262        elif arg == '-demo':    opts['use_demo'] = True
1263        elif arg == '-default': opts['use_demo'] = False
1264        elif arg == '-weights': opts['show_weights'] = True
1265        elif arg == '-html':    opts['html'] = True
1266        elif arg == '-help':    opts['html'] = True
1267    # pylint: enable=bad-whitespace
1268
1269    # Magnetism forces 2D for now
1270    if opts['magnetic']:
1271        opts['is2d'] = True
1272
1273    # Force random if sets is used
1274    if opts['sets'] >= 1 and opts['seed'] < 0:
1275        opts['seed'] = np.random.randint(1000000)
1276    if opts['sets'] == 0:
1277        opts['sets'] = 1
1278
1279    # Create the computational engines
1280    if opts['qmin'] is None:
1281        opts['qmin'] = 0.001*opts['qmax']
1282    if opts['datafile'] is not None:
1283        data = load_data(os.path.expanduser(opts['datafile']))
1284    else:
1285        data, _ = make_data(opts)
1286
1287    comparison = any(PAR_SPLIT in v for v in values)
1288    if PAR_SPLIT in name:
1289        names = name.split(PAR_SPLIT, 2)
1290        comparison = True
1291    else:
1292        names = [name]*2
1293    try:
1294        model_info = [core.load_model_info(k) for k in names]
1295    except ImportError as exc:
1296        print(str(exc))
1297        print("Could not find model; use one of:\n    " + models)
1298        return None
1299
1300    if PAR_SPLIT in opts['engine']:
1301        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1302        comparison = True
1303    else:
1304        opts['engine'] = [opts['engine']]*2
1305
1306    if PAR_SPLIT in opts['count']:
1307        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1308        comparison = True
1309    else:
1310        opts['count'] = [int(opts['count'])]*2
1311
1312    if PAR_SPLIT in opts['cutoff']:
1313        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1314        comparison = True
1315    else:
1316        opts['cutoff'] = [float(opts['cutoff'])]*2
1317
1318    base = make_engine(model_info[0], data, opts['engine'][0], opts['cutoff'][0])
1319    if comparison:
1320        comp = make_engine(model_info[1], data, opts['engine'][1], opts['cutoff'][1])
1321    else:
1322        comp = None
1323
1324    # pylint: disable=bad-whitespace
1325    # Remember it all
1326    opts.update({
1327        'data'      : data,
1328        'name'      : names,
1329        'info'      : model_info,
1330        'engines'   : [base, comp],
1331        'values'    : values,
1332    })
1333    # pylint: enable=bad-whitespace
1334
1335    # Set the integration parameters to the half sphere
1336    if opts['sphere'] > 0:
1337        set_spherical_integration_parameters(opts, opts['sphere'])
1338
1339    return opts
1340
1341def set_spherical_integration_parameters(opts, steps):
1342    # type: (Dict[str, Any], int) -> None
1343    """
1344    Set integration parameters for spherical integration over the entire
1345    surface in theta-phi coordinates.
1346    """
1347    # Set the integration parameters to the half sphere
1348    opts['values'].extend([
1349        #'theta=90',
1350        'theta_pd=%g'%(90/np.sqrt(3)),
1351        'theta_pd_n=%d'%steps,
1352        'theta_pd_type=rectangle',
1353        #'phi=0',
1354        'phi_pd=%g'%(180/np.sqrt(3)),
1355        'phi_pd_n=%d'%(2*steps),
1356        'phi_pd_type=rectangle',
1357        #'background=0',
1358    ])
1359    if 'psi' in opts['info'][0].parameters:
1360        opts['values'].extend([
1361            #'psi=0',
1362            'psi_pd=%g'%(180/np.sqrt(3)),
1363            'psi_pd_n=%d'%(2*steps),
1364            'psi_pd_type=rectangle',
1365        ])
1366
1367def parse_pars(opts, maxdim=np.inf):
1368    # type: (Dict[str, Any], float) -> Tuple[Dict[str, float], Dict[str, float]]
1369    """
1370    Generate a parameter set.
1371
1372    The default values come from the model, or a randomized model if a seed
1373    value is given.  Next, evaluate any parameter expressions, constraining
1374    the value of the parameter within and between models.  If *maxdim* is
1375    given, limit parameters with units of Angstrom to this value.
1376
1377    Returns a pair of parameter dictionaries for base and comparison models.
1378    """
1379    model_info, model_info2 = opts['info']
1380
1381    # Get demo parameters from model definition, or use default parameters
1382    # if model does not define demo parameters
1383    pars = get_pars(model_info, opts['use_demo'])
1384    pars2 = get_pars(model_info2, opts['use_demo'])
1385    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1386    # randomize parameters
1387    #pars.update(set_pars)  # set value before random to control range
1388    if opts['seed'] > -1:
1389        pars = randomize_pars(model_info, pars)
1390        limit_dimensions(model_info, pars, maxdim)
1391        if model_info != model_info2:
1392            pars2 = randomize_pars(model_info2, pars2)
1393            limit_dimensions(model_info, pars2, maxdim)
1394            # Share values for parameters with the same name
1395            for k, v in pars.items():
1396                if k in pars2:
1397                    pars2[k] = v
1398        else:
1399            pars2 = pars.copy()
1400        constrain_pars(model_info, pars)
1401        constrain_pars(model_info2, pars2)
1402    pars = suppress_pd(pars, opts['mono'])
1403    pars2 = suppress_pd(pars2, opts['mono'])
1404    pars = suppress_magnetism(pars, not opts['magnetic'])
1405    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1406
1407    # Fill in parameters given on the command line
1408    presets = {}
1409    presets2 = {}
1410    for arg in opts['values']:
1411        k, v = arg.split('=', 1)
1412        if k not in pars and k not in pars2:
1413            # extract base name without polydispersity info
1414            s = set(p.split('_pd')[0] for p in pars)
1415            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1416            return None
1417        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
1418        if v1 and k in pars:
1419            presets[k] = float(v1) if isnumber(v1) else v1
1420        if v2 and k in pars2:
1421            presets2[k] = float(v2) if isnumber(v2) else v2
1422
1423    # If pd given on the command line, default pd_n to 35
1424    for k, v in list(presets.items()):
1425        if k.endswith('_pd'):
1426            presets.setdefault(k+'_n', 35.)
1427    for k, v in list(presets2.items()):
1428        if k.endswith('_pd'):
1429            presets2.setdefault(k+'_n', 35.)
1430
1431    # Evaluate preset parameter expressions
1432    context = MATH.copy()
1433    context['np'] = np
1434    context.update(pars)
1435    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1436    for k, v in presets.items():
1437        if not isinstance(v, float) and not k.endswith('_type'):
1438            presets[k] = eval(v, context)
1439    context.update(presets)
1440    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1441    for k, v in presets2.items():
1442        if not isinstance(v, float) and not k.endswith('_type'):
1443            presets2[k] = eval(v, context)
1444
1445    # update parameters with presets
1446    pars.update(presets)  # set value after random to control value
1447    pars2.update(presets2)  # set value after random to control value
1448    #import pprint; pprint.pprint(model_info)
1449
1450    if opts['show_pars']:
1451        if model_info.name != model_info2.name or pars != pars2:
1452            print("==== %s ====="%model_info.name)
1453            print(str(parlist(model_info, pars, opts['is2d'])))
1454            print("==== %s ====="%model_info2.name)
1455            print(str(parlist(model_info2, pars2, opts['is2d'])))
1456        else:
1457            print(str(parlist(model_info, pars, opts['is2d'])))
1458
1459    return pars, pars2
1460
1461def show_docs(opts):
1462    # type: (Dict[str, Any]) -> None
1463    """
1464    show html docs for the model
1465    """
1466    import os
1467    from .generate import make_html
1468    from . import rst2html
1469
1470    info = opts['info'][0]
1471    html = make_html(info)
1472    path = os.path.dirname(info.filename)
1473    url = "file://" + path.replace("\\", "/")[2:] + "/"
1474    rst2html.view_html_qtapp(html, url)
1475
1476def explore(opts):
1477    # type: (Dict[str, Any]) -> None
1478    """
1479    explore the model using the bumps gui.
1480    """
1481    import wx  # type: ignore
1482    from bumps.names import FitProblem  # type: ignore
1483    from bumps.gui.app_frame import AppFrame  # type: ignore
1484    from bumps.gui import signal
1485
1486    is_mac = "cocoa" in wx.version()
1487    # Create an app if not running embedded
1488    app = wx.App() if wx.GetApp() is None else None
1489    model = Explore(opts)
1490    problem = FitProblem(model)
1491    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1492    if not is_mac:
1493        frame.Show()
1494    frame.panel.set_model(model=problem)
1495    frame.panel.Layout()
1496    frame.panel.aui.Split(0, wx.TOP)
1497    def _reset_parameters(event):
1498        model.revert_values()
1499        signal.update_parameters(problem)
1500    frame.Bind(wx.EVT_TOOL, _reset_parameters, frame.ToolBar.GetToolByPos(1))
1501    if is_mac: frame.Show()
1502    # If running withing an app, start the main loop
1503    if app:
1504        app.MainLoop()
1505
1506class Explore(object):
1507    """
1508    Bumps wrapper for a SAS model comparison.
1509
1510    The resulting object can be used as a Bumps fit problem so that
1511    parameters can be adjusted in the GUI, with plots updated on the fly.
1512    """
1513    def __init__(self, opts):
1514        # type: (Dict[str, Any]) -> None
1515        from bumps.cli import config_matplotlib  # type: ignore
1516        from . import bumps_model
1517        config_matplotlib()
1518        self.opts = opts
1519        opts['pars'] = list(opts['pars'])
1520        p1, p2 = opts['pars']
1521        m1, m2 = opts['info']
1522        self.fix_p2 = m1 != m2 or p1 != p2
1523        model_info = m1
1524        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1525        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1526        if not opts['is2d']:
1527            for p in model_info.parameters.user_parameters({}, is2d=False):
1528                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1529                    k = p.name+ext
1530                    v = pars.get(k, None)
1531                    if v is not None:
1532                        v.range(*parameter_range(k, v.value))
1533        else:
1534            for k, v in pars.items():
1535                v.range(*parameter_range(k, v.value))
1536
1537        self.pars = pars
1538        self.starting_values = dict((k, v.value) for k, v in pars.items())
1539        self.pd_types = pd_types
1540        self.limits = None
1541
1542    def revert_values(self):
1543        # type: () -> None
1544        """
1545        Restore starting values of the parameters.
1546        """
1547        for k, v in self.starting_values.items():
1548            self.pars[k].value = v
1549
1550    def model_update(self):
1551        # type: () -> None
1552        """
1553        Respond to signal that model parameters have been changed.
1554        """
1555        pass
1556
1557    def numpoints(self):
1558        # type: () -> int
1559        """
1560        Return the number of points.
1561        """
1562        return len(self.pars) + 1  # so dof is 1
1563
1564    def parameters(self):
1565        # type: () -> Any   # Dict/List hierarchy of parameters
1566        """
1567        Return a dictionary of parameters.
1568        """
1569        return self.pars
1570
1571    def nllf(self):
1572        # type: () -> float
1573        """
1574        Return cost.
1575        """
1576        # pylint: disable=no-self-use
1577        return 0.  # No nllf
1578
1579    def plot(self, view='log'):
1580        # type: (str) -> None
1581        """
1582        Plot the data and residuals.
1583        """
1584        pars = dict((k, v.value) for k, v in self.pars.items())
1585        pars.update(self.pd_types)
1586        self.opts['pars'][0] = pars
1587        if not self.fix_p2:
1588            self.opts['pars'][1] = pars
1589        result = run_models(self.opts)
1590        limits = plot_models(self.opts, result, limits=self.limits)
1591        if self.limits is None:
1592            vmin, vmax = limits
1593            self.limits = vmax*1e-7, 1.3*vmax
1594            import pylab; pylab.clf()
1595            plot_models(self.opts, result, limits=self.limits)
1596
1597
1598def main(*argv):
1599    # type: (*str) -> None
1600    """
1601    Main program.
1602    """
1603    opts = parse_opts(argv)
1604    if opts is not None:
1605        if opts['seed'] > -1:
1606            print("Randomize using -random=%i"%opts['seed'])
1607            np.random.seed(opts['seed'])
1608        if opts['html']:
1609            show_docs(opts)
1610        elif opts['explore']:
1611            opts['pars'] = parse_pars(opts)
1612            if opts['pars'] is None:
1613                return
1614            explore(opts)
1615        else:
1616            compare(opts)
1617
1618if __name__ == "__main__":
1619    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.