source: sasmodels/sasmodels/compare.py @ 376b0ee

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 376b0ee was 376b0ee, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

sascomp: fix bug in -random when comparing models

  • Property mode set to 100755
File size: 55.4 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
[caeb06d]3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets you
10make sure that the model calculations are stable, or whether you need to
[9cfcac8]11tag the model as double precision only.
[caeb06d]12
[9cfcac8]13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
[caeb06d]14sasmodels root to see the command line options.
15
[9cfcac8]16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
[caeb06d]18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
[190fc2b]31import sys
[a769b54]32import os
[190fc2b]33import math
34import datetime
35import traceback
[ff1fff5]36import re
[190fc2b]37
[7ae2b7f]38import numpy as np  # type: ignore
[190fc2b]39
40from . import core
41from . import kerneldll
[6831fa0]42from . import exception
[a769b54]43from .data import plot_theory, empty_data1D, empty_data2D, load_data
[3c24ccd]44from .direct_model import DirectModel, get_mesh
[f247314]45from .convert import revert_name, revert_pars, constrain_new_to_old
[ff1fff5]46from .generate import FLOAT_RE
[3c24ccd]47from .weights import plot_weights
[190fc2b]48
[dd7fc12]49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
[6831fa0]51except Exception:
[dd7fc12]52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
[8d62008]56    Calculator = Callable[[float], np.ndarray]
[dd7fc12]57
# Command line help; also appended to the module docstring below and shown
# at run time when the command line cannot be parsed.
USAGE = """
usage: sascomp model [options...] [key=val]

Generate and compare SAS models.  If a single model is specified it shows
a plot of that model.  Different models can be compared, or the same model
with different parameters.  The same model with the same parameters can
be compared with different calculation engines to see the effects of precision
on the resultant values.

model or model1,model2 are the names of the models to compare (see below).

Options (* for default):

    === data generation ===
    -data="path" uses q, dq from the data file
    -noise=0 sets the measurement error dI/I
    -res=0 sets the resolution width dQ/Q if calculating with resolution
    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
    -q=min:max alternative specification of qrange
    -nq=128 sets the number of Q points in the data set
    -1d*/-2d computes 1d or 2d data
    -zero indicates that q=0 should be included

    === model parameters ===
    -preset*/-random[=seed] preset or random parameters
    -sets=n generates n random datasets with the seed given by -random=seed
    -pars/-nopars* prints the parameter set or not
    -default/-demo* use demo vs default parameters
    -sphere[=150] set up spherical integration over theta/phi using n points

    === calculation options ===
    -mono*/-poly force monodisperse or allow polydisperse random parameters
    -cutoff=1e-5* cutoff value for including a point in polydispersity
    -magnetic/-nonmagnetic* suppress magnetism
    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
    -neval=1 sets the number of evals for more accurate timing

    === precision options ===
    -engine=default uses the default calculation precision
    -single/-double/-half/-fast sets an OpenCL calculation engine
    -single!/-double!/-quad! sets an OpenMP calculation engine
    -sasview sets the sasview calculation engine

    === plotting ===
    -plot*/-noplot plot or suppress the plot of the model
    -linear/-log*/-q4 intensity scaling on plots
    -hist/-nohist* plot histogram of relative error
    -abs/-rel* plot absolute or relative error
    -title="note" adds note to the plot title, after the model name
    -weights shows weights plots for the polydisperse parameters

    === output options ===
    -edit starts the parameter explorer
    -help/-html shows the model docs instead of running the model

The interpretation of quad precision depends on architecture, and may
vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
On unix and mac you may need single quotes around the DLL computation
engines, such as -engine='single!,double!' since !, is treated as a history
expansion request in the shell.

Key=value pairs allow you to set specific values for the model parameters.
Key=value1,value2 to compare different values of the same parameter. The
value can be an expression including other parameters.

Items later on the command line override those that appear earlier.

Examples:

    # compare single and double precision calculation for a barbell
    sascomp barbell -engine=single,double

    # generate 10 random lorentz models, with seed=27
    sascomp lorentz -sets=10 -seed=27

    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius

    # model timing test requires multiple evals to perform the estimate
    sascomp pringle -engine=single,double -timing=100,100 -noplot
"""
139
# Update docs with command line usage string.  This is separate from the usual
# doc string so that we can display it at run time if there is an error.
__doc__ = (__doc__  # pylint: disable=redefined-builtin
           + """
Program description
-------------------

""" + USAGE)

# Allow DLL (OpenMP) kernels to be built in single precision, presumably so
# that the "-single!" engine listed in USAGE can be compared with the others.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]151
# Public functions and constants of the math module (e.g. MATH['sqrt'],
# MATH['pi']), used when evaluating parameter value expressions.
MATH = dict(
    (fname, getattr(math, fname))
    for fname in dir(math)
    if not fname.startswith('_')
)
154
# CRUFT python 2.6: timedelta.total_seconds() only exists from python 2.7 on.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return number date-time delta as number seconds"""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return number date-time delta as number seconds"""
        whole_seconds = dt.days * 86400 + dt.seconds
        return whole_seconds + 1e-6 * dt.microseconds
164
165
class push_seed(object):
    """
    Set the seed value for the random number generator.

    The global numpy random state is saved when the object is created and
    the generator is then reseeded with *seed*.  When used in a with
    statement, the saved state is restored when the block exits, whether
    normally or because of an exception (the exception is not suppressed).
    Note that the state is captured at construction, not on entry to the
    with block.

    :Parameters:

    *seed* : int or array_like, optional
        Seed for RandomState

    :Example:

    Seed can be used directly to set the seed::

        >>> from numpy.random import randint
        >>> push_seed(24)
        <...push_seed object at...>
        >>> print(randint(0,1000000,3))
        [242082    899 211136]

    Seed can also be used in a with statement, which sets the random
    number generator state for the enclosed computations and restores
    it to the previous state on completion::

        >>> with push_seed(24):
        ...    print(randint(0,1000000,3))
        [242082    899 211136]

    Using nested contexts, we can demonstrate that state is indeed
    restored after the block completes::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    with push_seed(24):
        ...        print(randint(0,1000000,3))
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        899

    The restore step is protected against exceptions in the block::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    try:
        ...        with push_seed(24):
        ...            print(randint(0,1000000,3))
        ...            raise Exception()
        ...    except Exception:
        ...        print("Exception raised")
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        Exception raised
        899
    """
    def __init__(self, seed=None):
        # type: (Optional[int]) -> None
        # Snapshot the generator state *before* reseeding so it can be
        # put back by __exit__.
        previous = np.random.get_state()
        np.random.seed(seed)
        self._state = previous

    def __enter__(self):
        # type: () -> None
        """Nothing to do: the seed was already applied by __init__."""

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Any, BaseException, Any) -> None
        # TODO: better typing for __exit__ method
        # Returning None lets any exception raised in the block propagate.
        np.random.set_state(self._state)
237
def tic():
    # type: () -> Callable[[], float]
    """
    Timer function.

    Starts the clock and returns a function which, when called, reports the
    number of seconds elapsed since tic() was called.  Use "toc=tic()" to
    start the clock and "toc()" to measure a time interval.
    """
    start = datetime.datetime.now()

    def toc():
        # type: () -> float
        return delay(datetime.datetime.now() - start)

    return toc
248
249
def set_beam_stop(data, radius, outer=None):
    # type: (Data, float, float) -> None
    """
    Add a beam stop of the given *radius*.  If *outer*, make an annulus.

    Replaces *data.mask* with a boolean array that is True for points with
    |q| < *radius* and, when *outer* is given, also for points with
    |q| >= *outer*.  For 2D data (having *qx_data*) |q| is computed from
    qx and qy; otherwise *data.x* is used.

    Note: this function does not require sasview
    """
    if hasattr(data, 'qx_data'):
        qmag = np.sqrt(data.qx_data**2 + data.qy_data**2)
    else:
        qmag = data.x
    blocked = (qmag < radius)
    if outer is not None:
        blocked |= (qmag >= outer)
    data.mask = blocked
266
[8a20be5]267
def parameter_range(p, v):
    # type: (str, float) -> Tuple[float, float]
    """
    Choose a parameter range based on parameter name and initial value.

    Checks are applied in order: polydispersity settings (_pd_n,
    _pd_nsigma), orientation angles (any name containing theta, phi or psi),
    polydispersity widths (_pd), slds, background and scale.  Anything else
    gets [2v, -2v] for negative v, [0, 2v] for positive v, and [0, 1] for
    zero.  Raises ValueError for the string-valued _pd_type parameters.
    """
    if p.endswith('_pd_n'):
        return 0., 100.
    if p.endswith('_pd_nsigma'):
        return 0., 5.
    if p.endswith('_pd_type'):
        raise ValueError("Cannot return a range for a string value")
    if any(angle in p for angle in ('theta', 'phi', 'psi')):
        # orientation in [-180, 180]; orientation dispersity in [0, 180]
        return (0., 180.) if p.endswith('_pd') else (-180., 180.)
    if p.endswith('_pd'):
        return 0., 1.
    if 'sld' in p:
        return -0.5, 10.
    if p == 'background':
        return 0., 10.
    if p == 'scale':
        return 0., 1.e3
    if v < 0.:
        return 2.*v, -2.*v
    if v > 0.:
        return 0., 2.*v
    return 0., 1.
[87985ca]298
[4f2478e]299
[0bdddc2]300def _randomize_one(model_info, name, value):
[dd7fc12]301    # type: (ModelInfo, str, float) -> float
302    # type: (ModelInfo, str, str) -> str
[ec7e360]303    """
[caeb06d]304    Randomize a single parameter.
[ec7e360]305    """
[31df0c9]306    # Set the amount of polydispersity/angular dispersion, but by default pd_n
307    # is zero so there is no polydispersity.  This allows us to turn on/off
308    # pd by setting pd_n, and still have randomly generated values
[0bdddc2]309    if name.endswith('_pd'):
310        par = model_info.parameters[name[:-3]]
311        if par.type == 'orientation':
312            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
313            return 180*np.random.beta(2.5, 20)
314        else:
315            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
316            return np.random.beta(1.5, 7)
[8bd7b77]317
[31df0c9]318    # pd is selected globally rather than per parameter, so set to 0 for no pd
319    # In particular, when multiple pd dimensions, want to decrease the number
320    # of points per dimension for faster computation
[0bdddc2]321    if name.endswith('_pd_n'):
322        return 0
323
[31df0c9]324    # Don't mess with distribution type for now
[0bdddc2]325    if name.endswith('_pd_type'):
326        return 'gaussian'
327
[31df0c9]328    # type-dependent value of number of sigmas; for gaussian use 3.
[0bdddc2]329    if name.endswith('_pd_nsigma'):
330        return 3.
[8bd7b77]331
[31df0c9]332    # background in the range [0.01, 1]
[0bdddc2]333    if name == 'background':
[31df0c9]334        return 10**np.random.uniform(-2, 0)
[0bdddc2]335
[31df0c9]336    # scale defaults to 0.1% to 30% volume fraction
[0bdddc2]337    if name == 'scale':
[31df0c9]338        return 10**np.random.uniform(-3, -0.5)
[0bdddc2]339
[31df0c9]340    # If it is a list of choices, pick one at random with equal probability
341    # In practice, the model specific random generator will override.
[0bdddc2]342    par = model_info.parameters[name]
[8bd7b77]343    if len(par.limits) > 2:  # choice list
344        return np.random.randint(len(par.limits))
345
[31df0c9]346    # If it is a fixed range, pick from it with equal probability.
347    # For logarithmic ranges, the model will have to override.
[0bdddc2]348    if np.isfinite(par.limits).all():
349        return np.random.uniform(*par.limits)
350
[31df0c9]351    # If the paramter is marked as an sld use the range of neutron slds
[0f6c41c]352    # TODO: ought to randomly contrast match a pair of SLDs
[0bdddc2]353    if par.type == 'sld':
354        return np.random.uniform(-0.5, 12)
[8bd7b77]355
[0f6c41c]356    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
357    if par.name.startswith('M0:'):
358        return np.random.uniform(0, 5)
359
[31df0c9]360    # Guess at the random length/radius/thickness.  In practice, all models
361    # are going to set their own reasonable ranges.
[0bdddc2]362    if par.type == 'volume':
363        if ('length' in par.name or
364                'radius' in par.name or
365                'thick' in par.name):
[31df0c9]366            return 10**np.random.uniform(2, 4)
[0bdddc2]367
[31df0c9]368    # In the absence of any other info, select a value in [0, 2v], or
369    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
370    # model random parameter generators will override this default.
[0bdddc2]371    low, high = parameter_range(par.name, value)
372    limits = (max(par.limits[0], low), min(par.limits[1], high))
[8bd7b77]373    return np.random.uniform(*limits)
[cd3dba0]374
[109d963]375def _random_pd(model_info, pars):
376    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
377    pd_volume = []
378    pd_oriented = []
379    for p in pd:
380        if p.type == 'orientation':
381            pd_oriented.append(p.name)
382        elif p.length_control is not None:
[232bb12]383            n = int(pars.get(p.length_control, 1) + 0.5)
[109d963]384            pd_volume.extend(p.name+str(k+1) for k in range(n))
385        elif p.length > 1:
386            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
387        else:
388            pd_volume.append(p.name)
389    u = np.random.rand()
390    n = len(pd_volume)
391    if u < 0.01 or n < 1:
392        pass  # 1% chance of no polydispersity
393    elif u < 0.86 or n < 2:
394        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
395    elif u < 0.99 or n < 3:
396        choices = np.random.choice(len(pd_volume), size=2)
397        pars[pd_volume[choices[0]]+"_pd_n"] = 25
398        pars[pd_volume[choices[1]]+"_pd_n"] = 10
399    else:
400        choices = np.random.choice(len(pd_volume), size=3)
401        pars[pd_volume[choices[0]]+"_pd_n"] = 25
402        pars[pd_volume[choices[1]]+"_pd_n"] = 10
403        pars[pd_volume[choices[2]]+"_pd_n"] = 5
404    if pd_oriented:
405        pars['theta_pd_n'] = 20
406        if np.random.rand() < 0.1:
407            pars['phi_pd_n'] = 5
408        if np.random.rand() < 0.1:
[4553dae]409            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
410                #print("generating psi_pd_n")
411                pars['psi_pd_n'] = 5
[109d963]412
413    ## Show selected polydispersity
414    #for name, value in pars.items():
415    #    if name.endswith('_pd_n') and value > 0:
416    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
417
418
def randomize_pars(model_info, pars):
    # type: (ModelInfo, ParameterSet) -> ParameterSet
    """
    Generate random values for all of the parameters.

    Each parameter in *pars* receives a value from :func:`_randomize_one`.
    Any model-specific generator (*model_info.random*) then overrides
    those values, and :func:`_random_pd` switches on polydispersity for a
    random subset of parameters.  A new dictionary is returned; *pars* is
    only read.

    Valid ranges for the random number generator are guessed from the name of
    the parameter; this will not account for constraints such as cap radius
    greater than cylinder radius in the capped_cylinder model, so
    :func:`constrain_pars` needs to be called afterward.
    """
    random_pars = {}
    # Visiting the parameters in sorted order fixes the sequence of calls
    # to the random number generator, so a given seed gives the same result.
    for par_name, par_value in sorted(pars.items()):
        random_pars[par_name] = _randomize_one(model_info, par_name, par_value)
    model_random = model_info.random
    if model_random is not None:
        random_pars.update(model_random())
    _random_pd(model_info, random_pars)
    return random_pars
[cd3dba0]436
[109d963]437
def limit_dimensions(model_info, pars, maxdim):
    # type: (ModelInfo, ParameterSet, float) -> None
    """
    Limit parameters of units of Ang to maxdim.

    Every call parameter with units 'Ang' whose value exceeds *maxdim* is
    replaced, in place, by a random value in [maxdim*1e-3, maxdim).
    """
    for par in model_info.parameters.call_parameters:
        # Every call parameter is looked up, not just the 'Ang' ones.
        size = pars[par.name]
        if par.units == 'Ang' and size > maxdim:
            shrink = 10**np.random.uniform(-3, 0)
            pars[par.name] = maxdim*shrink
447
def constrain_pars(model_info, pars):
    # type: (ModelInfo, ParameterSet) -> None
    """
    Restrict parameters to valid values.

    This includes model specific code for models such as capped_cylinder
    which need to support within model constraints (cap radius more than
    cylinder radius in this case).  For product models ("P*S") only the
    form factor P is checked, since none of the structure factors need
    constraints.  Magnetism is suppressed when *model_info.Iq* is a python
    callable (not yet implemented for python models).

    Warning: this updates the *pars* dictionary in place.
    """
    # TODO: move the model specific code to the individual models
    name = model_info.id.split('*')[0]

    if callable(model_info.Iq):
        pars.update(suppress_magnetism(pars))

    def ensure_order(lower, upper):
        """Swap pars[lower] and pars[upper] if pars[upper] < pars[lower]."""
        if pars[upper] < pars[lower]:
            pars[lower], pars[upper] = pars[upper], pars[lower]

    if name == 'barbell':
        ensure_order('radius', 'radius_bell')
    elif name == 'capped_cylinder':
        ensure_order('radius', 'radius_cap')
    elif name == 'pearl_necklace':
        ensure_order('thick_string', 'radius')
    elif name == 'guinier':
        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
        # => ln A - (Rg^2 q^2/3) > -30 ln 10
        # => Rg^2 q^2/3 < 30 ln 10 + ln A
        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
        #q_max = 0.2  # mid q maximum
        q_max = 1.0  # high q maximum
        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
        pars['rg'] = min(pars['rg'], rg_max)
    elif name == 'rpa':
        # Zero the fractions unused by this case, then make phi sum to 1.0.
        # NOTE(review): if all four Phi values are zero this divides by zero.
        case_num = pars['case_num']
        if case_num < 2:
            unused = '12'
        elif case_num < 5:
            unused = '1'
        else:
            unused = ''
        for c in unused:
            pars['Phi'+c] = 0.
        total = sum(pars['Phi'+c] for c in '1234')
        for c in '1234':
            pars['Phi'+c] /= total
504
def parlist(model_info, pars, is2d):
    # type: (ModelInfo, ParameterSet, bool) -> str
    """
    Format the parameter list for printing.

    Returns one line per user parameter (formatted by :func:`_format_par`).
    Each line shows the value, any active polydispersity (width, number of
    points, nsigma and distribution type) and any magnetic settings.

    The magnetic 'M0:', 'mtheta:' and 'mphi:' parameters do not get their
    own lines; they are shown on the line of the parameter they modify.
    The 'up:' parameters are collected onto one final line, which is only
    included when at least one M0 value is nonzero.
    """
    # NOTE(review): the *is2d* argument is overridden, so user_parameters()
    # always receives True; presumably this lists the 2D-only parameters
    # even for 1D data.  Confirm before relying on the argument.
    is2d = True
    lines = []
    parameters = model_info.parameters
    magnetic = False
    magnetic_pars = []
    for p in parameters.user_parameters(pars, is2d):
        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
            # shown on the line of the parameter they modify
            continue
        if p.id.startswith('up:'):
            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
            continue
        fields = dict(
            value=pars.get(p.id, p.default),
            pd=pars.get(p.id+"_pd", 0.),
            n=int(pars.get(p.id+"_pd_n", 0)),
            nsigma=pars.get(p.id+"_pd_nsigma", 3.),
            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
            relative_pd=p.relative_pd,
            M0=pars.get('M0:'+p.id, 0.),
            mphi=pars.get('mphi:'+p.id, 0.),
            mtheta=pars.get('mtheta:'+p.id, 0.),
        )
        lines.append(_format_par(p.name, **fields))
        magnetic = magnetic or fields['M0'] != 0.
    if magnetic and magnetic_pars:
        lines.append(" ".join(magnetic_pars))
    return "\n".join(lines)
539
[bd49c79]540def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
[0b040de]541                relative_pd=False, M0=0., mphi=0., mtheta=0.):
[dd7fc12]542    # type: (str, float, float, int, float, str) -> str
[a4a7308]543    line = "%s: %g"%(name, value)
544    if pd != 0.  and n != 0:
[bd49c79]545        if relative_pd:
546            pd *= value
[a4a7308]547        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
[dd7fc12]548                % (pd, n, nsigma, nsigma, pdtype)
[0b040de]549    if M0 != 0.:
[b76191e]550        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
[a4a7308]551    return line
[87985ca]552
def suppress_pd(pars, suppress=True):
    # type: (ParameterSet, bool) -> ParameterSet
    """
    Return a copy of *pars* with polydispersity switched off or on.

    If *suppress* is True, every *_pd_n* entry is set to 0, which removes
    polydispersity so models can be tested more quickly.

    If *suppress* is False and no parameter has both a nonzero *_pd_n* and a
    nonzero *_pd* width, the first *_pd_n* parameter is made polydisperse.
    Its *_pd_n* is set to 35 if it was 0, and its width to 15% if missing or
    0.  Models without explicit demo parameters have no default
    polydispersity.

    The input dictionary is not modified.
    """
    result = pars.copy()
    pd_n_keys = [key for key in result if key.endswith("_pd_n")]
    if suppress:
        for key in pd_n_keys:
            result[key] = 0
        return result

    has_pd = any(result[key] != 0 and result.get(key[:-2], 0.) != 0.
                 for key in pd_n_keys)
    if pd_n_keys and not has_pd:
        first = pd_n_keys[0]
        width_key = first[:-2]  # "<name>_pd_n" -> "<name>_pd"
        if result[first] == 0:
            result[first] = 35
        if width_key not in result or result[width_key] == 0:
            result[width_key] = 0.15
    return result
[87985ca]583
def suppress_magnetism(pars, suppress=True):
    # type: (ParameterSet, bool) -> ParameterSet
    """
    Return a copy of *pars* with magnetism switched off or on.

    If *suppress* is True, every magnetic sld ("M0:" parameter) is set to 0,
    which eliminates magnetism so models can be tested more quickly.

    If *suppress* is False and no "M0:" parameter is already nonzero, the
    first one is set to a strong magnetic sld of 8/A^2.  The magnetic angles
    are left unchanged.  Models without explicit demo parameters have no
    default magnetism.

    The input dictionary is not modified.
    """
    result = pars.copy()
    m0_keys = [key for key in result if key.startswith("M0:")]
    if suppress:
        for key in m0_keys:
            result[key] = 0
    elif m0_keys and not any(result[key] != 0 for key in m0_keys):
        result[m0_keys[0]] = 8.
    return result
609
def eval_sasview(model_info, data):
    # type: (ModelInfo, Data) -> Calculator
    """
    Return a model calculator using the pre-4.0 SasView models.

    The returned *calculator(**pars)* converts the parameters to old-style
    names with :func:`revert_pars` and loads them into the old SasView model
    object.  It then evaluates the model over *data*, using resolution
    smearing when :func:`smear_selection` returns a smearer.  The
    calculator's *engine* attribute is "sasview".

    Raises ValueError for mixture models, for models that do not exist in
    old SasView, or when the old model class cannot be found.  The legacy
    *sas* package is imported here; if it is missing, the ImportError
    propagates.
    """
    # importing sas here so that the error message will be that sas failed to
    # import rather than the more obscure smear_selection not imported error
    import sas
    import sas.models
    from sas.models.qsmearing import smear_selection
    from sas.models.MultiplicationModel import MultiplicationModel
    from sas.models.dispersion_models import models as dispersers

    def get_model_class(name):
        # type: (str) -> "sas.models.BaseComponent"
        # Return the class sas.models.<name>.<name>, importing the module
        # first; raise ValueError if the class is not there.
        __import__('sas.models.' + name)
        ModelClass = getattr(getattr(sas.models, name, None), name, None)
        if ModelClass is None:
            raise ValueError("could not find model %r in sas.models"%name)
        return ModelClass

    # WARNING: ugly hack when handling model!
    # Sasview models with multiplicity need to be created with the target
    # multiplicity, so we cannot create the target model ahead of time for
    # multiplicity models.  Instead we store the model in a list and
    # update the first element of that list with the new multiplicity model
    # every time we evaluate.

    # grab the sasview model, or create it if it is a product model
    if model_info.composition:
        composition_type, parts = model_info.composition
        if composition_type == 'product':
            P, S = [get_model_class(revert_name(p))() for p in parts]
            model = [MultiplicationModel(P, S)]
        else:
            raise ValueError("sasview mixture models not supported by compare")
    else:
        old_name = revert_name(model_info)
        if old_name is None:
            raise ValueError("model %r does not exist in old sasview"
                            % model_info.id)
        # NOTE(review): ModelClass is only bound on this branch, but
        # calculator() uses it when revert_pars returns a CONTROL value.
        # A product model with CONTROL would raise NameError; confirm
        # that this cannot happen.
        ModelClass = get_model_class(old_name)
        model = [ModelClass()]
    model[0].disperser_handles = {}

    # build a smearer with which to call the model, if necessary
    # NOTE(review): the list *model* (not model[0]) is passed here; the
    # smearer.model assignments below overwrite it.
    smearer = smear_selection(data, model=model)
    if hasattr(data, 'qx_data'):
        # 2D: evaluate only the points that are unmasked, not NaN, and
        # within [qmin, qmax]
        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
        index = ((~data.mask) & (~np.isnan(data.data))
                 & (q >= data.qmin) & (q <= data.qmax))
        if smearer is not None:
            smearer.model = model  # because smear_selection has a bug
            smearer.accuracy = data.accuracy
            smearer.set_index(index)
            def _call_smearer():
                # point the smearer at the current model object, which
                # calculator() may have replaced since the last call
                smearer.model = model[0]
                return smearer.get_value()
            theory = _call_smearer
        else:
            theory = lambda: model[0].evalDistribution([data.qx_data[index],
                                                        data.qy_data[index]])
    elif smearer is not None:
        theory = lambda: smearer(model[0].evalDistribution(data.x))
    else:
        theory = lambda: model[0].evalDistribution(data.x)

    def calculator(**pars):
        # type: (float, ...) -> np.ndarray
        """
        Sasview calculator for model.

        Converts *pars* to old-style names, loads them into model[0]
        (recreating it when a multiplicity CONTROL value is present),
        and returns the evaluated theory.
        """
        oldpars = revert_pars(model_info, pars)
        # For multiplicity models, create a model with the correct multiplicity
        control = oldpars.pop("CONTROL", None)
        if control is not None:
            # sphericalSLD has one fewer multiplicity.  This update should
            # happen in revert_pars, but it hasn't been called yet.
            # NOTE(review): the new object does not get the
            # disperser_handles attribute assigned above; confirm that the
            # old model classes provide it before non-gaussian dispersion
            # is set on a multiplicity model.
            model[0] = ModelClass(control)
        # paying for parameter conversion each time to keep life simple, if not fast
        # First pass: install non-gaussian dispersion distributions
        # ("<par>.type" entries).
        for k, v in oldpars.items():
            if k.endswith('.type'):
                par = k[:-5]
                if v == 'gaussian': continue
                # presumably the old dispersion registry spells this
                # 'rectangula' -- confirm against sas.models.dispersion_models
                cls = dispersers[v if v != 'rectangle' else 'rectangula']
                handle = cls()
                model[0].disperser_handles[par] = handle
                try:
                    model[0].set_dispersion(par, handle)
                except Exception:
                    exception.annotate_exception("while setting %s to %r"
                                                 %(par, v))
                    raise


        # Second pass: "<par>.<attr>" entries go into the dispersion table;
        # everything else is a plain model parameter.
        for k, v in oldpars.items():
            name_attr = k.split('.')  # polydispersity components
            if len(name_attr) == 2:
                par, disp_par = name_attr
                model[0].dispersion[par][disp_par] = v
            else:
                model[0].setParam(k, v)
        return theory()

    calculator.engine = "sasview"
    return calculator
718
# Map a precision name, as given on the command line or as produced by
# str(model.dtype), to the suffix used in engine labels such as "OCL32".
DTYPE_MAP = dict(
    half='16',
    fast='fast',
    single='32',
    double='64',
    quad='128',
    f16='16',
    f32='32',
    f64='64',
    float16='16',
    float32='32',
    float64='64',
    float128='128',
    longdouble='128',
)
def eval_opencl(model_info, data, dtype='single', cutoff=0.):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Return a model calculator using the OpenCL calculation engine.

    Raises RuntimeError if OpenCL is not available.  The calculator's
    *engine* attribute is "OCL" followed by the DTYPE_MAP suffix for the
    precision of the built kernel.
    """
    if not core.HAVE_OPENCL:
        raise RuntimeError("OpenCL not available")
    kernel_model = core.build_model(model_info, dtype=dtype, platform="ocl")
    direct = DirectModel(data, kernel_model, cutoff=cutoff)
    precision = DTYPE_MAP[str(kernel_model.dtype)]
    direct.engine = "OCL%s"%precision
    return direct
[216a9e1]745
def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Build a DirectModel calculator for *data* that evaluates *model_info*
    with the compiled DLL engine at precision *dtype*.

    The calculator's engine label is "OMP" followed by the precision
    suffix from DTYPE_MAP.
    """
    kernel_model = core.build_model(model_info, dtype=dtype, platform="dll")
    engine = DirectModel(data, kernel_model, cutoff=cutoff)
    engine.engine = "OMP" + DTYPE_MAP[str(kernel_model.dtype)]
    return engine
755
def time_calculation(calculator, pars, evals=1):
    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
    """
    Compute the average calculation time over *evals* evaluations.

    Returns the value from the last evaluation and the average time per
    evaluation in milliseconds.  At least one evaluation is always done,
    so *evals* < 1 is treated as 1 (avoiding a division by zero).

    When *evals* > 1, an additional untimed call is made first, without
    polydispersity, in order to initialize the calculation engine and
    make the average more stable.
    """
    evals = max(1, evals)
    # initialize the code so time is more accurate
    if evals > 1:
        calculator(**suppress_pd(pars))
    toc = tic()
    # make sure there is at least one eval
    value = calculator(**pars)
    for _ in range(evals-1):
        value = calculator(**pars)
    average_time = toc()*1000. / evals
    return value, average_time
775
def make_data(opts):
    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
    """
    Generate an empty dataset, used with the model to set Q points
    and resolution.

    *opts* holds the parsed command line options.  'qmin', 'qmax', 'nq',
    'res' and 'is2d' are always used.  For 2D data 'accuracy' is also
    used, and qmin sets the beam stop.  For 1D data, 'view' == 'log'
    selects log-spaced q (unless 'zero' is set), and 'zero' prepends q=0.

    Returns the data and an index of the points to use: the complement of
    the data mask for 2D, or a full slice for 1D.
    """
    qmin, qmax = opts['qmin'], opts['qmax']
    nq, res = opts['nq'], opts['res']

    if not opts['is2d']:
        log_spaced = opts['view'] == 'log' and not opts['zero']
        if log_spaced:
            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
        else:
            q = np.linspace(qmin, qmax, nq)
        if opts['zero']:
            q = np.hstack((0, q))
        return empty_data1D(q, resolution=res), slice(None, None)

    q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
    data = empty_data2D(q, resolution=res)
    data.accuracy = opts['accuracy']
    set_beam_stop(data, qmin)
    return data, ~data.mask
802
def make_engine(model_info, data, dtype, cutoff):
    # type: (ModelInfo, Data, str, float) -> Calculator
    """
    Generate the appropriate calculation engine for the given datatype.

    'sasview' selects the sasview engine.  A datatype ending in '!' is
    evaluated using the compiled C DLL at the precision named before the
    '!'.  Anything else, including None, uses OpenCL.
    """
    if dtype == 'sasview':
        return eval_sasview(model_info, data)
    if dtype is not None and dtype.endswith('!'):
        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
    return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
[87985ca]817
[e78edc4]818def _show_invalid(data, theory):
[dd7fc12]819    # type: (Data, np.ma.ndarray) -> None
820    """
821    Display a list of the non-finite values in theory.
822    """
[e78edc4]823    if not theory.mask.any():
824        return
825
826    if hasattr(data, 'x'):
827        bad = zip(data.x[theory.mask], theory[theory.mask])
[dd7fc12]828        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
[e78edc4]829
830
def compare(opts, limits=None, maxdim=np.inf):
    # type: (Dict[str, Any], Optional[Tuple[float, float]], float) -> Optional[Tuple[float, float]]
    """
    Perform a comparison using options from the command line.

    *limits* are the limits on the values to use, either to set the y-axis
    for 1D or to set the colormap scale for 2D.  If None, then they are
    inferred from the data and returned. When exploring using Bumps,
    the limits are set when the model is initially called, and maintained
    as the values are adjusted, making it easier to see the effects of the
    parameters.

    *maxdim* is the maximum value for any parameter with units of Angstrom.

    One comparison is run for each of the opts['sets'] parameter sets.
    Returns the limits, or None if the parameters could not be parsed.
    """
    for k in range(opts['sets']):
        if k > 0:
            # Every set after the first gets its own printed seed, so that
            # any individual set (including the second) can be regenerated
            # with -random=seed.
            new_seed = np.random.randint(1000000)
            print("Set %d uses -random=%i"%(k+1, new_seed))
            np.random.seed(new_seed)
        opts['pars'] = parse_pars(opts, maxdim=maxdim)
        if opts['pars'] is None:
            return
        result = run_models(opts, verbose=True)
        if opts['plot']:
            limits = plot_models(opts, result, limits=limits, setnum=k)
        if opts['show_weights']:
            base, _ = opts['engines']
            base_pars, _ = opts['pars']
            model_info = base._kernel.info
            dim = base._kernel.dim
            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
    if opts['plot']:
        import matplotlib.pyplot as plt
        plt.show()
    return limits
[ca9e54e]867
868def run_models(opts, verbose=False):
869    # type: (Dict[str, Any]) -> Dict[str, Any]
870
[bb39b4a]871    base, comp = opts['engines']
872    base_n, comp_n = opts['count']
873    base_pars, comp_pars = opts['pars']
[ec7e360]874    data = opts['data']
[87985ca]875
[bb39b4a]876    comparison = comp is not None
[ca9e54e]877
[dd7fc12]878    base_time = comp_time = None
879    base_value = comp_value = resid = relerr = None
880
[4b41184]881    # Base calculation
[bb39b4a]882    try:
883        base_raw, base_time = time_calculation(base, base_pars, base_n)
884        base_value = np.ma.masked_invalid(base_raw)
885        if verbose:
886            print("%s t=%.2f ms, intensity=%.0f"
887                  % (base.engine, base_time, base_value.sum()))
888        _show_invalid(data, base_value)
889    except ImportError:
890        traceback.print_exc()
[4b41184]891
892    # Comparison calculation
[bb39b4a]893    if comparison:
[7cf2cfd]894        try:
[bb39b4a]895            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
[dd7fc12]896            comp_value = np.ma.masked_invalid(comp_raw)
[ca9e54e]897            if verbose:
898                print("%s t=%.2f ms, intensity=%.0f"
899                      % (comp.engine, comp_time, comp_value.sum()))
[e78edc4]900            _show_invalid(data, comp_value)
[7cf2cfd]901        except ImportError:
[5753e4e]902            traceback.print_exc()
[87985ca]903
904    # Compare, but only if computing both forms
[bb39b4a]905    if comparison:
[ec7e360]906        resid = (base_value - comp_value)
[b32dafd]907        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
[ca9e54e]908        if verbose:
909            _print_stats("|%s-%s|"
910                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
911                         resid)
912            _print_stats("|(%s-%s)/%s|"
913                         % (base.engine, comp.engine, comp.engine),
914                         relerr)
915
916    return dict(base_value=base_value, comp_value=comp_value,
917                base_time=base_time, comp_time=comp_time,
918                resid=resid, relerr=relerr)
919
920
921def _print_stats(label, err):
922    # type: (str, np.ma.ndarray) -> None
923    # work with trimmed data, not the full set
924    sorted_err = np.sort(abs(err.compressed()))
925    if len(sorted_err) == 0.:
926        print(label + "  no valid values")
927        return
928
929    p50 = int((len(sorted_err)-1)*0.50)
930    p98 = int((len(sorted_err)-1)*0.98)
931    data = [
932        "max:%.3e"%sorted_err[-1],
933        "median:%.3e"%sorted_err[p50],
934        "98%%:%.3e"%sorted_err[p98],
935        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
936        "zero-offset:%+.3e"%np.mean(sorted_err),
937        ]
938    print(label+"  "+"  ".join(data))
939
940
def plot_models(opts, result, limits=None, setnum=0):
    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]], int) -> Tuple[float, float]
    """
    Plot the base and comparison calculations and their difference.

    *result* is the dictionary returned by run_models.  *limits* are the
    (vmin, vmax) limits for the intensity plots; if None, they are taken
    from the calculated values.  *setnum* is the 0-origin index of the
    parameter set, used for the legend on the error plot.  If
    opts['show_hist'] is set, a histogram of the relative error is drawn
    in a separate figure.

    Returns the limits used, so that later sets can share the same scale.
    """
    import matplotlib.pyplot as plt

    base_value, comp_value = result['base_value'], result['comp_value']
    base_time, comp_time = result['base_time'], result['comp_time']
    resid, relerr = result['resid'], result['relerr']

    have_base, have_comp = (base_value is not None), (comp_value is not None)
    base, comp = opts['engines']
    data = opts['data']
    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)

    # Plot if requested
    view = opts['view']
    if limits is None:
        vmin, vmax = np.inf, -np.inf
        if have_base:
            vmin = min(vmin, base_value.min())
            vmax = max(vmax, base_value.max())
        if have_comp:
            vmin = min(vmin, comp_value.min())
            vmax = max(vmax, comp_value.max())
        limits = vmin, vmax

    if have_base:
        if have_comp:
            plt.subplot(131)
        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
        plt.title("%s t=%.2f ms"%(base.engine, base_time))
    if have_comp:
        if have_base:
            plt.subplot(132)
        if not opts['is2d'] and have_base:
            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
    if have_base and have_comp:
        plt.subplot(133)
        if not opts['rel_err']:
            err, errstr, errview = resid, "abs err", "linear"
        else:
            err, errstr, errview = abs(relerr), "rel err", "log"
            if (err == 0.).all():
                errview = 'linear'
        if 0:  # 95% cutoff (developer toggle)
            sorted_err = np.sort(err.flatten())
            cutoff = sorted_err[int(sorted_err.size*0.95)]
            err[err > cutoff] = cutoff
        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
    fig = plt.gcf()
    extra_title = ' '+opts['title'] if opts['title'] else ''
    fig.suptitle(":".join(opts['name']) + extra_title)

    if have_base and have_comp and opts['show_hist']:
        plt.figure()
        # Work on a compressed copy so result['relerr'] is left untouched.
        v = np.abs(np.ma.compressed(relerr))
        nonzero = v[v != 0]
        if nonzero.size:
            # log10(0) is -inf, so place exact matches just below the
            # smallest nonzero error.
            v[v == 0] = 0.5*nonzero.min()
            # 'density' replaces 'normed', which matplotlib 3.1 removed.
            plt.hist(np.log10(v), density=True, bins=50)
        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
                   % (base.engine, comp.engine, comp.engine))
        plt.ylabel('P(err)')
        plt.title('Distribution of relative error between calculation engines')

    return limits
1016
[0763009]1017
[87985ca]1018# ===========================================================================
1019#
[bb39b4a]1020
1021# Set of command line options.
1022# Normal options such as -plot/-noplot are specified as 'name'.
1023# For options such as -nq=500 which require a value use 'name='.
1024#
1025OPTIONS = [
1026    # Plotting
[3c24ccd]1027    'plot', 'noplot', 'weights',
[b89f519]1028    'linear', 'log', 'q4',
[bb39b4a]1029    'rel', 'abs',
[5d316e9]1030    'hist', 'nohist',
[bb39b4a]1031    'title=',
1032
1033    # Data generation
[ced5bd2]1034    'data=', 'noise=', 'res=', 'nq=', 'q=',
1035    'lowq', 'midq', 'highq', 'exq', 'zero',
[bb39b4a]1036    '2d', '1d',
1037
1038    # Parameter set
1039    'preset', 'random', 'random=', 'sets=',
1040    'demo', 'default',  # TODO: remove demo/default
1041    'nopars', 'pars',
[e3571cb]1042    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
[bb39b4a]1043
1044    # Calculation options
1045    'poly', 'mono', 'cutoff=',
1046    'magnetic', 'nonmagnetic',
1047    'accuracy=',
[765eb0e]1048    'neval=',  # for timing...
[bb39b4a]1049
1050    # Precision options
[8698a0d]1051    'engine=',
[bb39b4a]1052    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1053    'sasview',  # TODO: remove sasview 3.x support
1054
1055    # Output options
1056    'help', 'html', 'edit',
[87985ca]1057    ]
1058
[bb39b4a]1059NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1060VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1061
1062
def columnize(items, indent="", width=79):
    # type: (List[str], str, int) -> str
    """
    Format a list of strings into columns.

    Items are laid out top to bottom, then left to right, in as many
    equal-width columns as fit within *width* characters (at least one).
    Each line is prefixed by *indent*.

    Returns a newline-separated string ready for printing, or "" if
    there are no items.
    """
    items = list(items)
    if not items:
        return ""
    column_width = max(len(w) for w in items) + 1
    # Always allow one column, even if an item is wider than the page.
    num_columns = max(1, (width - len(indent)) // column_width)
    # Round up; floor division would silently drop the trailing items.
    num_rows = (len(items) + num_columns - 1) // num_columns
    items = items + [""] * (num_rows * num_columns - len(items))
    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
             for row in zip(*columns)]
    return indent + ("\n"+indent).join(lines)
1079
1080
def get_pars(model_info, use_demo=False):
    # type: (ModelInfo, bool) -> ParameterSet
    """
    Extract demo parameters from the model definition.

    Returns a dict with the default value of each call parameter.
    Polydisperse parameters also get *_pd* = 0, *_pd_n* = 0,
    *_pd_nsigma* = 3 and *_pd_type* = "gaussian".  A parameter with
    length n > 1 is expanded into entries numbered 1..n (e.g. sld1, sld2).
    If *use_demo* is True and the model defines demo values, those
    override the defaults.
    """
    # Get the default values for the parameters
    pars = {}
    for p in model_info.parameters.call_parameters:
        parts = [('', p.default)]
        if p.polydisperse:
            parts.append(('_pd', 0.0))
            parts.append(('_pd_n', 0))
            parts.append(('_pd_nsigma', 3.0))
            parts.append(('_pd_type', "gaussian"))
        for ext, val in parts:
            if p.length > 1:
                # vector parameter: one entry per element, numbered from 1
                pars.update(("%s%d%s" % (p.id, k, ext), val)
                            for k in range(1, p.length+1))
            else:
                pars[p.id + ext] = val

    # Plug in values given in demo
    if use_demo and model_info.demo:
        pars.update(model_info.demo)
    return pars
1106
# A nonzero integer with optional sign and no leading zeros.
INTEGER_RE = re.compile(r"^[+-]?[1-9][0-9]*$")
def isnumber(str):
    """
    Return a true value if the string *str* is a plain number.

    The string is a number if FLOAT_RE matches all of it, or if it
    matches INTEGER_RE.  The return value is True, a regex match object,
    or None, so only its truthiness is meaningful.
    """
    float_match = FLOAT_RE.match(str)
    if float_match is not None and not str[float_match.end():]:
        return True
    return INTEGER_RE.match(str)
[17bbadd]1112
# Separator for giving different base and comparison values in a single
# argument, e.g. "sphere,cylinder" for the model or "radius=20,30" for a
# parameter.  Characters ruled out as the separator:
#   key-value pair separator: =
#   shell characters: | & ; < > $ % ' " \ # `
#   model and parameter names: _
#   parameter expressions: - + * / . ( )
#   path characters, including tilde expansion and windows drives: ~ / :
# Brackets [] {} are questionable; possible candidates: @ ? ^ ! ,
PAR_SPLIT = ","
def parse_opts(argv):
    # type: (List[str]) -> Optional[Dict[str, Any]]
    """
    Parse command line options.

    *argv* is split into flags (arguments starting with '-'), parameter
    values (name=value arguments) and positional arguments.  The last
    positional argument is the model name, or two model names separated
    by PAR_SPLIT for a comparison.  A comparison is also triggered by
    PAR_SPLIT in any parameter value, or in the engine, neval or cutoff
    options.

    Returns the options dictionary, including the data, model info and
    base/comparison calculation engines.  Returns None after printing a
    message if no model is given, an option is invalid, or the model
    cannot be loaded.
    """
    MODELS = core.list_models()
    flags = [arg for arg in argv
             if arg.startswith('-')]
    values = [arg for arg in argv
              if not arg.startswith('-') and '=' in arg]
    positional_args = [arg for arg in argv
                       if not arg.startswith('-') and '=' not in arg]
    # model list used in the "could not find model" message below
    models = "\n    ".join("%-15s"%v for v in MODELS)
    if len(positional_args) == 0:
        print(USAGE)
        print("\nAvailable models:")
        print(columnize(MODELS, indent="  "))
        return None

    # A flag is valid if it is a bare option name, or NAME=value for an
    # option that takes a value.
    invalid = [o[1:] for o in flags
               if o[1:] not in NAME_OPTIONS
               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
    if invalid:
        print("Invalid options: %s"%(", ".join(invalid)))
        return None

    name = positional_args[-1]

    # pylint: disable=bad-whitespace
    # Interpret the flags
    opts = {
        'plot'      : True,
        'view'      : 'log',
        'is2d'      : False,
        'qmin'      : None,  # filled in from qmax below if not given
        'qmax'      : 0.05,
        'nq'        : 128,
        'res'       : 0.0,
        'noise'     : 0.0,
        'accuracy'  : 'Low',
        'cutoff'    : '0.0',  # string; split on PAR_SPLIT and converted below
        'seed'      : -1,  # default to preset
        'mono'      : True,
        # Magnetism is off unless -magnetic is given
        'magnetic'  : False,
        'show_pars' : False,
        'show_hist' : False,
        'rel_err'   : True,
        'explore'   : False,
        'use_demo'  : True,
        'zero'      : False,
        'html'      : False,
        'title'     : None,
        'datafile'  : None,
        'sets'      : 0,
        'engine'    : 'default',
        'count'     : '1',  # -neval; string, split and converted below
        'show_weights' : False,
        'sphere'    : 0,
    }
    # Flags are processed in order, so a later flag overrides an earlier one.
    # Note that -random and -random=seed reset sets to 0, cancelling an
    # earlier -sets=N.
    for arg in flags:
        if arg == '-noplot':    opts['plot'] = False
        elif arg == '-plot':    opts['plot'] = True
        elif arg == '-linear':  opts['view'] = 'linear'
        elif arg == '-log':     opts['view'] = 'log'
        elif arg == '-q4':      opts['view'] = 'q4'
        elif arg == '-1d':      opts['is2d'] = False
        elif arg == '-2d':      opts['is2d'] = True
        elif arg == '-exq':     opts['qmax'] = 10.0
        elif arg == '-highq':   opts['qmax'] = 1.0
        elif arg == '-midq':    opts['qmax'] = 0.2
        elif arg == '-lowq':    opts['qmax'] = 0.05
        elif arg == '-zero':    opts['zero'] = True
        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
        elif arg.startswith('-q='):
            # -q=qmin:qmax
            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
        elif arg.startswith('-title='):    opts['title'] = arg[7:]
        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
        elif arg.startswith('-random='):
            opts['seed'] = int(arg[8:])
            opts['sets'] = 0
        elif arg == '-random':
            opts['seed'] = np.random.randint(1000000)
            opts['sets'] = 0
        elif arg.startswith('-sphere'):
            # -sphere uses 150 points; -sphere=n uses n points
            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
            opts['is2d'] = True
        elif arg == '-preset':  opts['seed'] = -1
        elif arg == '-mono':    opts['mono'] = True
        elif arg == '-poly':    opts['mono'] = False
        elif arg == '-magnetic':       opts['magnetic'] = True
        elif arg == '-nonmagnetic':    opts['magnetic'] = False
        elif arg == '-pars':    opts['show_pars'] = True
        elif arg == '-nopars':  opts['show_pars'] = False
        elif arg == '-hist':    opts['show_hist'] = True
        elif arg == '-nohist':  opts['show_hist'] = False
        elif arg == '-rel':     opts['rel_err'] = True
        elif arg == '-abs':     opts['rel_err'] = False
        elif arg == '-half':    opts['engine'] = 'half'
        elif arg == '-fast':    opts['engine'] = 'fast'
        elif arg == '-single':  opts['engine'] = 'single'
        elif arg == '-double':  opts['engine'] = 'double'
        elif arg == '-single!': opts['engine'] = 'single!'
        elif arg == '-double!': opts['engine'] = 'double!'
        elif arg == '-quad!':   opts['engine'] = 'quad!'
        elif arg == '-sasview': opts['engine'] = 'sasview'
        elif arg == '-edit':    opts['explore'] = True
        elif arg == '-demo':    opts['use_demo'] = True
        elif arg == '-default': opts['use_demo'] = False
        elif arg == '-weights': opts['show_weights'] = True
        elif arg == '-html':    opts['html'] = True
        elif arg == '-help':    opts['html'] = True
    # pylint: enable=bad-whitespace

    # Magnetism forces 2D for now
    if opts['magnetic']:
        opts['is2d'] = True

    # Force random if sets is used
    if opts['sets'] >= 1 and opts['seed'] < 0:
        opts['seed'] = np.random.randint(1000000)
    # Always run at least one set
    if opts['sets'] == 0:
        opts['sets'] = 1

    # Default qmin from qmax, then build or load the data set
    if opts['qmin'] is None:
        opts['qmin'] = 0.001*opts['qmax']
    if opts['datafile'] is not None:
        data = load_data(os.path.expanduser(opts['datafile']))
    else:
        data, _ = make_data(opts)

    # NOTE(review): split(PAR_SPLIT, 2) can yield three pieces although
    # only a base and a comparison are used; confirm intent.
    comparison = any(PAR_SPLIT in v for v in values)
    if PAR_SPLIT in name:
        names = name.split(PAR_SPLIT, 2)
        comparison = True
    else:
        names = [name]*2
    try:
        model_info = [core.load_model_info(k) for k in names]
    except ImportError as exc:
        print(str(exc))
        print("Could not find model; use one of:\n    " + models)
        return None

    # engine, count and cutoff become [base, comparison] pairs
    if PAR_SPLIT in opts['engine']:
        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
        comparison = True
    else:
        opts['engine'] = [opts['engine']]*2

    if PAR_SPLIT in opts['count']:
        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
        comparison = True
    else:
        opts['count'] = [int(opts['count'])]*2

    if PAR_SPLIT in opts['cutoff']:
        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
        comparison = True
    else:
        opts['cutoff'] = [float(opts['cutoff'])]*2

    # Create the computational engines; comp is None if not comparing
    base = make_engine(model_info[0], data, opts['engine'][0], opts['cutoff'][0])
    if comparison:
        comp = make_engine(model_info[1], data, opts['engine'][1], opts['cutoff'][1])
    else:
        comp = None

    # pylint: disable=bad-whitespace
    # Remember it all
    opts.update({
        'data'      : data,
        'name'      : names,
        'info'      : model_info,
        'engines'   : [base, comp],
        'values'    : values,
    })
    # pylint: enable=bad-whitespace

    # Add the pd settings for spherical integration (-sphere) to the values
    if opts['sphere'] > 0:
        set_spherical_integration_parameters(opts, opts['sphere'])

    return opts
1314
def set_spherical_integration_parameters(opts, steps):
    """
    Set integration parameters for spherical integration over the entire
    surface in theta-phi coordinates.

    Appends rectangle-distribution polydispersity settings to
    opts['values'].  Theta uses *steps* points, and phi uses 2*steps
    points.  Psi also uses 2*steps points, and is only added when the
    base model has a psi parameter.  The central angle values are not
    changed.  The pd widths 90/sqrt(3) and 180/sqrt(3) are presumably
    the standard deviation of a uniform distribution spanning +/-90 and
    +/-180 degrees; confirm against the rectangle dispersion definition.
    """
    def rectangle(angle, span, npts):
        # name=value strings for a rectangular pd on one angle
        return [
            '%s_pd=%g' % (angle, span/np.sqrt(3)),
            '%s_pd_n=%d' % (angle, npts),
            '%s_pd_type=rectangle' % angle,
        ]

    opts['values'].extend(rectangle('theta', 90, steps)
                          + rectangle('phi', 180, 2*steps))
    if 'psi' in opts['info'][0].parameters:
        opts['values'].extend(rectangle('psi', 180, 2*steps))
[e3571cb]1340
def parse_pars(opts, maxdim=np.inf):
    # type: (Dict[str, Any], float) -> Optional[Tuple[ParameterSet, ParameterSet]]
    """
    Generate the parameter sets for the base and comparison models.

    Parameters start from the model demo values (or the defaults if
    opts['use_demo'] is False).  Values shared by name are copied from
    the base model to the comparison model.  When opts['seed'] > -1 the
    parameters are randomized, passed through limit_dimensions with
    *maxdim*, and constrained.  Polydispersity and magnetism are then
    suppressed according to opts['mono'] and opts['magnetic'].

    Finally the name=value pairs in opts['values'] are applied.  A value
    "v1,v2" (split on PAR_SPLIT) gives v1 to the base model and v2 to
    the comparison model.  Values that are not plain numbers (other than
    *_type* values) are evaluated as expressions using the math functions,
    numpy as np, and the parameter values.  Setting a *_pd* value defaults
    the matching *_pd_n* to 35.

    Returns (pars, pars2), or None after printing a message if a
    parameter name is unknown to both models.
    """
    model_info, model_info2 = opts['info']

    # Get demo parameters from model definition, or use default parameters
    # if model does not define demo parameters
    pars = get_pars(model_info, opts['use_demo'])
    pars2 = get_pars(model_info2, opts['use_demo'])
    pars2.update((k, v) for k, v in pars.items() if k in pars2)
    # randomize parameters
    if opts['seed'] > -1:
        pars = randomize_pars(model_info, pars)
        limit_dimensions(model_info, pars, maxdim)
        if model_info != model_info2:
            pars2 = randomize_pars(model_info2, pars2)
            limit_dimensions(model_info2, pars2, maxdim)
            # Share values for parameters with the same name
            for k, v in pars.items():
                if k in pars2:
                    pars2[k] = v
        else:
            pars2 = pars.copy()
        constrain_pars(model_info, pars)
        constrain_pars(model_info2, pars2)
    pars = suppress_pd(pars, opts['mono'])
    pars2 = suppress_pd(pars2, opts['mono'])
    pars = suppress_magnetism(pars, not opts['magnetic'])
    pars2 = suppress_magnetism(pars2, not opts['magnetic'])

    # Fill in parameters given on the command line
    presets = {}
    presets2 = {}
    for arg in opts['values']:
        k, v = arg.split('=', 1)
        if k not in pars and k not in pars2:
            # Extract base names without polydispersity info.  List the
            # parameters of both models, since a name valid for either one
            # is accepted.
            s = set(p.split('_pd')[0] for p in list(pars) + list(pars2))
            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
            return None
        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v, v)
        if v1 and k in pars:
            presets[k] = float(v1) if isnumber(v1) else v1
        if v2 and k in pars2:
            presets2[k] = float(v2) if isnumber(v2) else v2

    # If pd given on the command line, default pd_n to 35
    for k in list(presets):
        if k.endswith('_pd'):
            presets.setdefault(k+'_n', 35.)
    for k in list(presets2):
        if k.endswith('_pd'):
            presets2.setdefault(k+'_n', 35.)

    # Evaluate preset parameter expressions.
    # NOTE: eval() executes arbitrary code; this is acceptable only because
    # the expressions come from the user's own command line.
    context = MATH.copy()
    context['np'] = np
    context.update(pars)
    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
    for k, v in presets.items():
        if not isinstance(v, float) and not k.endswith('_type'):
            presets[k] = eval(v, context)
    context.update(presets)
    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
    for k, v in presets2.items():
        if not isinstance(v, float) and not k.endswith('_type'):
            presets2[k] = eval(v, context)

    # update parameters with presets
    pars.update(presets)  # set value after random to control value
    pars2.update(presets2)  # set value after random to control value

    if opts['show_pars']:
        if model_info.name != model_info2.name or pars != pars2:
            print("==== %s ====="%model_info.name)
            print(str(parlist(model_info, pars, opts['is2d'])))
            print("==== %s ====="%model_info2.name)
            print(str(parlist(model_info2, pars2, opts['is2d'])))
        else:
            print(str(parlist(model_info, pars, opts['is2d'])))

    return pars, pars2
[ec7e360]1423
def show_docs(opts):
    # type: (Dict[str, Any]) -> None
    """
    Show the html docs for the base model in a Qt viewer.

    The file:// url of the directory holding the model source is passed
    as the base url, presumably so that relative links in the docs can be
    resolved.
    """
    from .generate import make_html
    from . import rst2html

    info = opts['info'][0]
    html = make_html(info)
    path = os.path.abspath(os.path.dirname(info.filename))
    # Strip only a windows drive letter ("C:").  POSIX paths must keep
    # their leading "/", or the url would start with a host name
    # (e.g. file://sr/lib/... for /usr/lib/...).
    drive, tail = os.path.splitdrive(path)
    if drive.endswith(':'):
        path = tail
    url = "file://" + path.replace("\\", "/") + "/"
    rst2html.view_html_qtapp(html, url)
[234c532]1438
def explore(opts):
    # type: (Dict[str, Any]) -> None
    """
    Explore the model interactively using the bumps gui.

    The models in *opts* are wrapped in an Explore fit problem and shown
    in a bumps AppFrame.  The toolbar tool at position 1 is bound to
    restore the parameter values via model.revert_values().
    """
    import wx  # type: ignore
    from bumps.names import FitProblem  # type: ignore
    from bumps.gui.app_frame import AppFrame  # type: ignore
    from bumps.gui import signal

    # On the mac (cocoa) the frame is shown after it is laid out;
    # elsewhere it is shown before the model is attached.
    show_early = "cocoa" not in wx.version()
    # Reuse the running wx application if embedded; otherwise create one.
    own_app = None if wx.GetApp() is not None else wx.App()
    model = Explore(opts)
    problem = FitProblem(model)
    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
    if show_early:
        frame.Show()
    panel = frame.panel
    panel.set_model(model=problem)
    panel.Layout()
    panel.aui.Split(0, wx.TOP)

    def on_reset(event):
        model.revert_values()
        signal.update_parameters(problem)
    frame.Bind(wx.EVT_TOOL, on_reset, frame.ToolBar.GetToolByPos(1))

    if not show_early:
        frame.Show()
    # Only run the event loop if this function created the application.
    if own_app:
        own_app.MainLoop()
[ec7e360]1468
class Explore(object):
    """
    Bumps wrapper for a SAS model comparison.

    The resulting object can be used as a Bumps fit problem so that
    parameters can be adjusted in the GUI, with plots updated on the fly.
    """
    def __init__(self, opts):
        # type: (Dict[str, Any]) -> None
        from bumps.cli import config_matplotlib  # type: ignore
        from . import bumps_model
        config_matplotlib()
        self.opts = opts
        # Convert to a list so that plot() can replace the parameter sets
        # in place.
        opts['pars'] = list(opts['pars'])
        p1, p2 = opts['pars']
        m1, m2 = opts['info']
        # When the two sides differ (different model or different parameter
        # values) plot() leaves the second parameter set untouched; otherwise
        # both sides are driven by the parameters adjusted in the gui.
        self.fix_p2 = m1 != m2 or p1 != p2
        model_info = m1
        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
        if not opts['is2d']:
            for p in model_info.parameters.user_parameters({}, is2d=False):
                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
                    k = p.name+ext
                    v = pars.get(k, None)
                    if v is not None:
                        v.range(*parameter_range(k, v.value))
        else:
            for k, v in pars.items():
                v.range(*parameter_range(k, v.value))

        self.pars = pars
        # Snapshot used by revert_values().
        self.starting_values = {k: v.value for k, v in pars.items()}
        self.pd_types = pd_types
        # Plot limits; filled in on the first call to plot().
        self.limits = None

    def revert_values(self):
        # type: () -> None
        """
        Restore every parameter to the value it had when this object
        was constructed.
        """
        for k, v in self.starting_values.items():
            self.pars[k].value = v

    def model_update(self):
        # type: () -> None
        """
        Bumps fit-problem hook; nothing to do here because plot()
        recomputes the models from the current parameter values.
        """
        pass

    def numpoints(self):
        # type: () -> int
        """
        Return the number of points.
        """
        return len(self.pars) + 1  # so dof is 1

    def parameters(self):
        # type: () -> Any   # Dict/List hierarchy of parameters
        """
        Return a dictionary of parameters.
        """
        return self.pars

    def nllf(self):
        # type: () -> float
        """
        Return cost.
        """
        # pylint: disable=no-self-use
        return 0.  # No nllf

    def plot(self, view='log'):
        # type: (str) -> None
        """
        Plot the data and residuals.

        *view* is ignored.
        """
        pars = {k: v.value for k, v in self.pars.items()}
        pars.update(self.pd_types)
        self.opts['pars'][0] = pars
        if not self.fix_p2:
            self.opts['pars'][1] = pars
        result = run_models(self.opts)
        limits = plot_models(self.opts, result, limits=self.limits)
        if self.limits is None:
            # First plot: freeze the display range based on the initial
            # maximum so subsequent plots reuse the same limits, then redraw
            # this one with those limits.
            _, vmax = limits
            self.limits = vmax*1e-7, 1.3*vmax
            import pylab
            pylab.clf()
            plot_models(self.opts, result, limits=self.limits)
[87985ca]1551
1552
def main(*argv):
    # type: (*str) -> None
    """
    Main program.

    Parses the command line arguments in *argv*, optionally seeds the numpy
    random generator, then either shows the html docs, opens the interactive
    explorer, or runs the comparison.
    """
    opts = parse_opts(argv)
    if opts is None:
        # Nothing to do (parse_opts returned no options).
        return

    seed = opts['seed']
    if seed > -1:
        print("Randomize using -random=%i"%seed)
        np.random.seed(seed)

    if opts['html']:
        show_docs(opts)
    elif opts['explore']:
        pars = parse_pars(opts)
        opts['pars'] = pars
        if pars is None:
            return
        explore(opts)
    else:
        compare(opts)


if __name__ == "__main__":
    # Forward the command line, minus the program name.
    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.