source: sasmodels/sasmodels/compare.py @ ced5bd2

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ced5bd2 was ced5bd2, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

allow -q=min:max on the sascomp command line

  • Property mode set to 100755
File size: 52.6 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: sascomp model [options...] [key=val]
59
60Generate and compare SAS models.  If a single model is specified it shows
61a plot of that model.  Different models can be compared, or the same model
62with different parameters.  The same model with the same parameters can
63be compared with different calculation engines to see the effects of precision
64on the resultant values.
65
66model or model1,model2 are the names of the models to compare (see below).
67
68Options (* for default):
69
70    === data generation ===
71    -data="path" uses q, dq from the data file
72    -noise=0 sets the measurement error dI/I
73    -res=0 sets the resolution width dQ/Q if calculating with resolution
74    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
75    -q=min:max alternative specification of qrange
76    -nq=128 sets the number of Q points in the data set
77    -1d*/-2d computes 1d or 2d data
78    -zero indicates that q=0 should be included
79
80    === model parameters ===
81    -preset*/-random[=seed] preset or random parameters
82    -sets=n generates n random datasets with the seed given by -random=seed
83    -pars/-nopars* prints the parameter set or not
84    -default/-demo* use demo vs default parameters
85
86    === calculation options ===
87    -mono*/-poly force monodisperse or allow polydisperse demo parameters
88    -cutoff=1e-5* cutoff value for including a point in polydispersity
89    -magnetic/-nonmagnetic* suppress magnetism
90    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
91    -neval=1 sets the number of evals for more accurate timing
92
93    === precision options ===
94    -calc=default uses the default calcution precision
95    -single/-double/-half/-fast sets an OpenCL calculation engine
96    -single!/-double!/-quad! sets an OpenMP calculation engine
97    -sasview sets the sasview calculation engine
98
99    === plotting ===
100    -plot*/-noplot plots or suppress the plot of the model
101    -linear/-log*/-q4 intensity scaling on plots
102    -hist/-nohist* plot histogram of relative error
103    -abs/-rel* plot relative or absolute error
104    -title="note" adds note to the plot title, after the model name
105
106    === output options ===
107    -edit starts the parameter explorer
108    -help/-html shows the model docs instead of running the model
109
110The interpretation of quad precision depends on architecture, and may
111vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
112On unix and mac you may need single quotes around the DLL computation
113engines, such as -calc='single!,double!' since !, is treated as a history
114expansion request in the shell.
115
116Key=value pairs allow you to set specific values for the model parameters.
117Key=value1,value2 to compare different values of the same parameter. The
118value can be an expression including other parameters.
119
120Items later on the command line override those that appear earlier.
121
122Examples:
123
124    # compare single and double precision calculation for a barbell
125    sascomp barbell -calc=single,double
126
127    # generate 10 random lorentz models, with seed=27
128    sascomp lorentz -sets=10 -seed=27
129
130    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
131    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
132
133    # model timing test requires multiple evals to perform the estimate
134    sascomp pringle -calc=single,double -timing=100,100 -noplot
135"""
136
137# Update docs with command line usage string.   This is separate from the usual
138# doc string so that we can display it at run time if there is an error.
139# lin
140__doc__ = (__doc__  # pylint: disable=redefined-builtin
141           + """
142Program description
143-------------------
144
145""" + USAGE)
146
147kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
148
149# list of math functions for use in evaluating parameters
150MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
151
152# CRUFT python 2.6
153if not hasattr(datetime.timedelta, 'total_seconds'):
154    def delay(dt):
155        """Return number date-time delta as number seconds"""
156        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
157else:
158    def delay(dt):
159        """Return number date-time delta as number seconds"""
160        return dt.total_seconds()
161
162
163class push_seed(object):
164    """
165    Set the seed value for the random number generator.
166
167    When used in a with statement, the random number generator state is
168    restored after the with statement is complete.
169
170    :Parameters:
171
172    *seed* : int or array_like, optional
173        Seed for RandomState
174
175    :Example:
176
177    Seed can be used directly to set the seed::
178
179        >>> from numpy.random import randint
180        >>> push_seed(24)
181        <...push_seed object at...>
182        >>> print(randint(0,1000000,3))
183        [242082    899 211136]
184
185    Seed can also be used in a with statement, which sets the random
186    number generator state for the enclosed computations and restores
187    it to the previous state on completion::
188
189        >>> with push_seed(24):
190        ...    print(randint(0,1000000,3))
191        [242082    899 211136]
192
193    Using nested contexts, we can demonstrate that state is indeed
194    restored after the block completes::
195
196        >>> with push_seed(24):
197        ...    print(randint(0,1000000))
198        ...    with push_seed(24):
199        ...        print(randint(0,1000000,3))
200        ...    print(randint(0,1000000))
201        242082
202        [242082    899 211136]
203        899
204
205    The restore step is protected against exceptions in the block::
206
207        >>> with push_seed(24):
208        ...    print(randint(0,1000000))
209        ...    try:
210        ...        with push_seed(24):
211        ...            print(randint(0,1000000,3))
212        ...            raise Exception()
213        ...    except Exception:
214        ...        print("Exception raised")
215        ...    print(randint(0,1000000))
216        242082
217        [242082    899 211136]
218        Exception raised
219        899
220    """
221    def __init__(self, seed=None):
222        # type: (Optional[int]) -> None
223        self._state = np.random.get_state()
224        np.random.seed(seed)
225
226    def __enter__(self):
227        # type: () -> None
228        pass
229
230    def __exit__(self, exc_type, exc_value, traceback):
231        # type: (Any, BaseException, Any) -> None
232        # TODO: better typing for __exit__ method
233        np.random.set_state(self._state)
234
235def tic():
236    # type: () -> Callable[[], float]
237    """
238    Timer function.
239
240    Use "toc=tic()" to start the clock and "toc()" to measure
241    a time interval.
242    """
243    then = datetime.datetime.now()
244    return lambda: delay(datetime.datetime.now() - then)
245
246
247def set_beam_stop(data, radius, outer=None):
248    # type: (Data, float, float) -> None
249    """
250    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
251
252    Note: this function does not require sasview
253    """
254    if hasattr(data, 'qx_data'):
255        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
256        data.mask = (q < radius)
257        if outer is not None:
258            data.mask |= (q >= outer)
259    else:
260        data.mask = (data.x < radius)
261        if outer is not None:
262            data.mask |= (data.x >= outer)
263
264
265def parameter_range(p, v):
266    # type: (str, float) -> Tuple[float, float]
267    """
268    Choose a parameter range based on parameter name and initial value.
269    """
270    # process the polydispersity options
271    if p.endswith('_pd_n'):
272        return 0., 100.
273    elif p.endswith('_pd_nsigma'):
274        return 0., 5.
275    elif p.endswith('_pd_type'):
276        raise ValueError("Cannot return a range for a string value")
277    elif any(s in p for s in ('theta', 'phi', 'psi')):
278        # orientation in [-180,180], orientation pd in [0,45]
279        if p.endswith('_pd'):
280            return 0., 45.
281        else:
282            return -180., 180.
283    elif p.endswith('_pd'):
284        return 0., 1.
285    elif 'sld' in p:
286        return -0.5, 10.
287    elif p == 'background':
288        return 0., 10.
289    elif p == 'scale':
290        return 0., 1.e3
291    elif v < 0.:
292        return 2.*v, -2.*v
293    else:
294        return 0., (2.*v if v > 0. else 1.)
295
296
297def _randomize_one(model_info, name, value):
298    # type: (ModelInfo, str, float) -> float
299    # type: (ModelInfo, str, str) -> str
300    """
301    Randomize a single parameter.
302    """
303    # Set the amount of polydispersity/angular dispersion, but by default pd_n
304    # is zero so there is no polydispersity.  This allows us to turn on/off
305    # pd by setting pd_n, and still have randomly generated values
306    if name.endswith('_pd'):
307        par = model_info.parameters[name[:-3]]
308        if par.type == 'orientation':
309            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
310            return 180*np.random.beta(2.5, 20)
311        else:
312            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
313            return np.random.beta(1.5, 7)
314
315    # pd is selected globally rather than per parameter, so set to 0 for no pd
316    # In particular, when multiple pd dimensions, want to decrease the number
317    # of points per dimension for faster computation
318    if name.endswith('_pd_n'):
319        return 0
320
321    # Don't mess with distribution type for now
322    if name.endswith('_pd_type'):
323        return 'gaussian'
324
325    # type-dependent value of number of sigmas; for gaussian use 3.
326    if name.endswith('_pd_nsigma'):
327        return 3.
328
329    # background in the range [0.01, 1]
330    if name == 'background':
331        return 10**np.random.uniform(-2, 0)
332
333    # scale defaults to 0.1% to 30% volume fraction
334    if name == 'scale':
335        return 10**np.random.uniform(-3, -0.5)
336
337    # If it is a list of choices, pick one at random with equal probability
338    # In practice, the model specific random generator will override.
339    par = model_info.parameters[name]
340    if len(par.limits) > 2:  # choice list
341        return np.random.randint(len(par.limits))
342
343    # If it is a fixed range, pick from it with equal probability.
344    # For logarithmic ranges, the model will have to override.
345    if np.isfinite(par.limits).all():
346        return np.random.uniform(*par.limits)
347
348    # If the paramter is marked as an sld use the range of neutron slds
349    # TODO: ought to randomly contrast match a pair of SLDs
350    if par.type == 'sld':
351        return np.random.uniform(-0.5, 12)
352
353    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
354    if par.name.startswith('M0:'):
355        return np.random.uniform(0, 5)
356
357    # Guess at the random length/radius/thickness.  In practice, all models
358    # are going to set their own reasonable ranges.
359    if par.type == 'volume':
360        if ('length' in par.name or
361                'radius' in par.name or
362                'thick' in par.name):
363            return 10**np.random.uniform(2, 4)
364
365    # In the absence of any other info, select a value in [0, 2v], or
366    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
367    # model random parameter generators will override this default.
368    low, high = parameter_range(par.name, value)
369    limits = (max(par.limits[0], low), min(par.limits[1], high))
370    return np.random.uniform(*limits)
371
372def _random_pd(model_info, pars):
373    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
374    pd_volume = []
375    pd_oriented = []
376    for p in pd:
377        if p.type == 'orientation':
378            pd_oriented.append(p.name)
379        elif p.length_control is not None:
380            n = int(pars.get(p.length_control, 1) + 0.5)
381            pd_volume.extend(p.name+str(k+1) for k in range(n))
382        elif p.length > 1:
383            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
384        else:
385            pd_volume.append(p.name)
386    u = np.random.rand()
387    n = len(pd_volume)
388    if u < 0.01 or n < 1:
389        pass  # 1% chance of no polydispersity
390    elif u < 0.86 or n < 2:
391        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
392    elif u < 0.99 or n < 3:
393        choices = np.random.choice(len(pd_volume), size=2)
394        pars[pd_volume[choices[0]]+"_pd_n"] = 25
395        pars[pd_volume[choices[1]]+"_pd_n"] = 10
396    else:
397        choices = np.random.choice(len(pd_volume), size=3)
398        pars[pd_volume[choices[0]]+"_pd_n"] = 25
399        pars[pd_volume[choices[1]]+"_pd_n"] = 10
400        pars[pd_volume[choices[2]]+"_pd_n"] = 5
401    if pd_oriented:
402        pars['theta_pd_n'] = 20
403        if np.random.rand() < 0.1:
404            pars['phi_pd_n'] = 5
405        if np.random.rand() < 0.1:
406            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
407                #print("generating psi_pd_n")
408                pars['psi_pd_n'] = 5
409
410    ## Show selected polydispersity
411    #for name, value in pars.items():
412    #    if name.endswith('_pd_n') and value > 0:
413    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
414
415
416def randomize_pars(model_info, pars):
417    # type: (ModelInfo, ParameterSet) -> ParameterSet
418    """
419    Generate random values for all of the parameters.
420
421    Valid ranges for the random number generator are guessed from the name of
422    the parameter; this will not account for constraints such as cap radius
423    greater than cylinder radius in the capped_cylinder model, so
424    :func:`constrain_pars` needs to be called afterward..
425    """
426    # Note: the sort guarantees order of calls to random number generator
427    random_pars = dict((p, _randomize_one(model_info, p, v))
428                       for p, v in sorted(pars.items()))
429    if model_info.random is not None:
430        random_pars.update(model_info.random())
431    _random_pd(model_info, random_pars)
432    return random_pars
433
434
435def constrain_pars(model_info, pars):
436    # type: (ModelInfo, ParameterSet) -> None
437    """
438    Restrict parameters to valid values.
439
440    This includes model specific code for models such as capped_cylinder
441    which need to support within model constraints (cap radius more than
442    cylinder radius in this case).
443
444    Warning: this updates the *pars* dictionary in place.
445    """
446    # TODO: move the model specific code to the individual models
447    name = model_info.id
448    # if it is a product model, then just look at the form factor since
449    # none of the structure factors need any constraints.
450    if '*' in name:
451        name = name.split('*')[0]
452
453    # Suppress magnetism for python models (not yet implemented)
454    if callable(model_info.Iq):
455        pars.update(suppress_magnetism(pars))
456
457    if name == 'barbell':
458        if pars['radius_bell'] < pars['radius']:
459            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
460
461    elif name == 'capped_cylinder':
462        if pars['radius_cap'] < pars['radius']:
463            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
464
465    elif name == 'guinier':
466        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
467        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
468        # => ln A - (Rg^2 q^2/3) > -30 ln 10
469        # => Rg^2 q^2/3 < 30 ln 10 + ln A
470        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
471        #q_max = 0.2  # mid q maximum
472        q_max = 1.0  # high q maximum
473        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
474        pars['rg'] = min(pars['rg'], rg_max)
475
476    elif name == 'pearl_necklace':
477        if pars['radius'] < pars['thick_string']:
478            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
479        pass
480
481    elif name == 'rpa':
482        # Make sure phi sums to 1.0
483        if pars['case_num'] < 2:
484            pars['Phi1'] = 0.
485            pars['Phi2'] = 0.
486        elif pars['case_num'] < 5:
487            pars['Phi1'] = 0.
488        total = sum(pars['Phi'+c] for c in '1234')
489        for c in '1234':
490            pars['Phi'+c] /= total
491
492def parlist(model_info, pars, is2d):
493    # type: (ModelInfo, ParameterSet, bool) -> str
494    """
495    Format the parameter list for printing.
496    """
497    lines = []
498    parameters = model_info.parameters
499    magnetic = False
500    magnetic_pars = []
501    for p in parameters.user_parameters(pars, is2d):
502        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
503            continue
504        if p.id.startswith('up:'):
505            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
506            continue
507        fields = dict(
508            value=pars.get(p.id, p.default),
509            pd=pars.get(p.id+"_pd", 0.),
510            n=int(pars.get(p.id+"_pd_n", 0)),
511            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
512            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
513            relative_pd=p.relative_pd,
514            M0=pars.get('M0:'+p.id, 0.),
515            mphi=pars.get('mphi:'+p.id, 0.),
516            mtheta=pars.get('mtheta:'+p.id, 0.),
517        )
518        lines.append(_format_par(p.name, **fields))
519        magnetic = magnetic or fields['M0'] != 0.
520    if magnetic and magnetic_pars:
521        lines.append(" ".join(magnetic_pars))
522    return "\n".join(lines)
523
524    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
525
526def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
527                relative_pd=False, M0=0., mphi=0., mtheta=0.):
528    # type: (str, float, float, int, float, str) -> str
529    line = "%s: %g"%(name, value)
530    if pd != 0.  and n != 0:
531        if relative_pd:
532            pd *= value
533        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
534                % (pd, n, nsigma, nsigma, pdtype)
535    if M0 != 0.:
536        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
537    return line
538
539def suppress_pd(pars, suppress=True):
540    # type: (ParameterSet) -> ParameterSet
541    """
542    If suppress is True complete eliminate polydispersity of the model to test
543    models more quickly.  If suppress is False, make sure at least one
544    parameter is polydisperse, setting the first polydispersity parameter to
545    15% if no polydispersity is given (with no explicit demo parameters given
546    in the model, there will be no default polydispersity).
547    """
548    pars = pars.copy()
549    #print("pars=", pars)
550    if suppress:
551        for p in pars:
552            if p.endswith("_pd_n"):
553                pars[p] = 0
554    else:
555        any_pd = False
556        first_pd = None
557        for p in pars:
558            if p.endswith("_pd_n"):
559                pd = pars.get(p[:-2], 0.)
560                any_pd |= (pars[p] != 0 and pd != 0.)
561                if first_pd is None:
562                    first_pd = p
563        if not any_pd and first_pd is not None:
564            if pars[first_pd] == 0:
565                pars[first_pd] = 35
566            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
567                pars[first_pd[:-2]] = 0.15
568    return pars
569
570def suppress_magnetism(pars, suppress=True):
571    # type: (ParameterSet) -> ParameterSet
572    """
573    If suppress is True complete eliminate magnetism of the model to test
574    models more quickly.  If suppress is False, make sure at least one sld
575    parameter is magnetic, setting the first parameter to have a strong
576    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
577    in the model, there will be no default magnetism).
578    """
579    pars = pars.copy()
580    if suppress:
581        for p in pars:
582            if p.startswith("M0:"):
583                pars[p] = 0
584    else:
585        any_mag = False
586        first_mag = None
587        for p in pars:
588            if p.startswith("M0:"):
589                any_mag |= (pars[p] != 0)
590                if first_mag is None:
591                    first_mag = p
592        if not any_mag and first_mag is not None:
593            pars[first_mag] = 8.
594    return pars
595
596def eval_sasview(model_info, data):
597    # type: (Modelinfo, Data) -> Calculator
598    """
599    Return a model calculator using the pre-4.0 SasView models.
600    """
601    # importing sas here so that the error message will be that sas failed to
602    # import rather than the more obscure smear_selection not imported error
603    import sas
604    import sas.models
605    from sas.models.qsmearing import smear_selection
606    from sas.models.MultiplicationModel import MultiplicationModel
607    from sas.models.dispersion_models import models as dispersers
608
609    def get_model_class(name):
610        # type: (str) -> "sas.models.BaseComponent"
611        #print("new",sorted(_pars.items()))
612        __import__('sas.models.' + name)
613        ModelClass = getattr(getattr(sas.models, name, None), name, None)
614        if ModelClass is None:
615            raise ValueError("could not find model %r in sas.models"%name)
616        return ModelClass
617
618    # WARNING: ugly hack when handling model!
619    # Sasview models with multiplicity need to be created with the target
620    # multiplicity, so we cannot create the target model ahead of time for
621    # for multiplicity models.  Instead we store the model in a list and
622    # update the first element of that list with the new multiplicity model
623    # every time we evaluate.
624
625    # grab the sasview model, or create it if it is a product model
626    if model_info.composition:
627        composition_type, parts = model_info.composition
628        if composition_type == 'product':
629            P, S = [get_model_class(revert_name(p))() for p in parts]
630            model = [MultiplicationModel(P, S)]
631        else:
632            raise ValueError("sasview mixture models not supported by compare")
633    else:
634        old_name = revert_name(model_info)
635        if old_name is None:
636            raise ValueError("model %r does not exist in old sasview"
637                            % model_info.id)
638        ModelClass = get_model_class(old_name)
639        model = [ModelClass()]
640    model[0].disperser_handles = {}
641
642    # build a smearer with which to call the model, if necessary
643    smearer = smear_selection(data, model=model)
644    if hasattr(data, 'qx_data'):
645        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
646        index = ((~data.mask) & (~np.isnan(data.data))
647                 & (q >= data.qmin) & (q <= data.qmax))
648        if smearer is not None:
649            smearer.model = model  # because smear_selection has a bug
650            smearer.accuracy = data.accuracy
651            smearer.set_index(index)
652            def _call_smearer():
653                smearer.model = model[0]
654                return smearer.get_value()
655            theory = _call_smearer
656        else:
657            theory = lambda: model[0].evalDistribution([data.qx_data[index],
658                                                        data.qy_data[index]])
659    elif smearer is not None:
660        theory = lambda: smearer(model[0].evalDistribution(data.x))
661    else:
662        theory = lambda: model[0].evalDistribution(data.x)
663
664    def calculator(**pars):
665        # type: (float, ...) -> np.ndarray
666        """
667        Sasview calculator for model.
668        """
669        oldpars = revert_pars(model_info, pars)
670        # For multiplicity models, create a model with the correct multiplicity
671        control = oldpars.pop("CONTROL", None)
672        if control is not None:
673            # sphericalSLD has one fewer multiplicity.  This update should
674            # happen in revert_pars, but it hasn't been called yet.
675            model[0] = ModelClass(control)
676        # paying for parameter conversion each time to keep life simple, if not fast
677        for k, v in oldpars.items():
678            if k.endswith('.type'):
679                par = k[:-5]
680                if v == 'gaussian': continue
681                cls = dispersers[v if v != 'rectangle' else 'rectangula']
682                handle = cls()
683                model[0].disperser_handles[par] = handle
684                try:
685                    model[0].set_dispersion(par, handle)
686                except Exception:
687                    exception.annotate_exception("while setting %s to %r"
688                                                 %(par, v))
689                    raise
690
691
692        #print("sasview pars",oldpars)
693        for k, v in oldpars.items():
694            name_attr = k.split('.')  # polydispersity components
695            if len(name_attr) == 2:
696                par, disp_par = name_attr
697                model[0].dispersion[par][disp_par] = v
698            else:
699                model[0].setParam(k, v)
700        return theory()
701
702    calculator.engine = "sasview"
703    return calculator
704
705DTYPE_MAP = {
706    'half': '16',
707    'fast': 'fast',
708    'single': '32',
709    'double': '64',
710    'quad': '128',
711    'f16': '16',
712    'f32': '32',
713    'f64': '64',
714    'float16': '16',
715    'float32': '32',
716    'float64': '64',
717    'float128': '128',
718    'longdouble': '128',
719}
720def eval_opencl(model_info, data, dtype='single', cutoff=0.):
721    # type: (ModelInfo, Data, str, float) -> Calculator
722    """
723    Return a model calculator using the OpenCL calculation engine.
724    """
725    if not core.HAVE_OPENCL:
726        raise RuntimeError("OpenCL not available")
727    model = core.build_model(model_info, dtype=dtype, platform="ocl")
728    calculator = DirectModel(data, model, cutoff=cutoff)
729    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
730    return calculator
731
732def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
733    # type: (ModelInfo, Data, str, float) -> Calculator
734    """
735    Return a model calculator using the DLL calculation engine.
736    """
737    model = core.build_model(model_info, dtype=dtype, platform="dll")
738    calculator = DirectModel(data, model, cutoff=cutoff)
739    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
740    return calculator
741
742def time_calculation(calculator, pars, evals=1):
743    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
744    """
745    Compute the average calculation time over N evaluations.
746
747    An additional call is generated without polydispersity in order to
748    initialize the calculation engine, and make the average more stable.
749    """
750    # initialize the code so time is more accurate
751    if evals > 1:
752        calculator(**suppress_pd(pars))
753    toc = tic()
754    # make sure there is at least one eval
755    value = calculator(**pars)
756    for _ in range(evals-1):
757        value = calculator(**pars)
758    average_time = toc()*1000. / evals
759    #print("I(q)",value)
760    return value, average_time
761
762def make_data(opts):
763    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
764    """
765    Generate an empty dataset, used with the model to set Q points
766    and resolution.
767
768    *opts* contains the options, with 'qmax', 'nq', 'res',
769    'accuracy', 'is2d' and 'view' parsed from the command line.
770    """
771    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
772    if opts['is2d']:
773        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
774        data = empty_data2D(q, resolution=res)
775        data.accuracy = opts['accuracy']
776        set_beam_stop(data, 0.0004)
777        index = ~data.mask
778    else:
779        if opts['view'] == 'log' and not opts['zero']:
780            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
781        else:
782            q = np.linspace(qmin, qmax, nq)
783        if opts['zero']:
784            q = np.hstack((0, q))
785        data = empty_data1D(q, resolution=res)
786        index = slice(None, None)
787    return data, index
788
789def make_engine(model_info, data, dtype, cutoff):
790    # type: (ModelInfo, Data, str, float) -> Calculator
791    """
792    Generate the appropriate calculation engine for the given datatype.
793
794    Datatypes with '!' appended are evaluated using external C DLLs rather
795    than OpenCL.
796    """
797    if dtype == 'sasview':
798        return eval_sasview(model_info, data)
799    elif dtype is None or not dtype.endswith('!'):
800        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
801    else:
802        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
803
804def _show_invalid(data, theory):
805    # type: (Data, np.ma.ndarray) -> None
806    """
807    Display a list of the non-finite values in theory.
808    """
809    if not theory.mask.any():
810        return
811
812    if hasattr(data, 'x'):
813        bad = zip(data.x[theory.mask], theory[theory.mask])
814        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
815
816
817def compare(opts, limits=None):
818    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
819    """
820    Preform a comparison using options from the command line.
821
822    *limits* are the limits on the values to use, either to set the y-axis
823    for 1D or to set the colormap scale for 2D.  If None, then they are
824    inferred from the data and returned. When exploring using Bumps,
825    the limits are set when the model is initially called, and maintained
826    as the values are adjusted, making it easier to see the effects of the
827    parameters.
828    """
829    limits = np.Inf, -np.Inf
830    for k in range(opts['sets']):
831        opts['pars'] = parse_pars(opts)
832        if opts['pars'] is None:
833            return
834        result = run_models(opts, verbose=True)
835        if opts['plot']:
836            limits = plot_models(opts, result, limits=limits, setnum=k)
837    if opts['plot']:
838        import matplotlib.pyplot as plt
839        plt.show()
840
841def run_models(opts, verbose=False):
842    # type: (Dict[str, Any]) -> Dict[str, Any]
843
844    base, comp = opts['engines']
845    base_n, comp_n = opts['count']
846    base_pars, comp_pars = opts['pars']
847    data = opts['data']
848
849    comparison = comp is not None
850
851    base_time = comp_time = None
852    base_value = comp_value = resid = relerr = None
853
854    # Base calculation
855    try:
856        base_raw, base_time = time_calculation(base, base_pars, base_n)
857        base_value = np.ma.masked_invalid(base_raw)
858        if verbose:
859            print("%s t=%.2f ms, intensity=%.0f"
860                  % (base.engine, base_time, base_value.sum()))
861        _show_invalid(data, base_value)
862    except ImportError:
863        traceback.print_exc()
864
865    # Comparison calculation
866    if comparison:
867        try:
868            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
869            comp_value = np.ma.masked_invalid(comp_raw)
870            if verbose:
871                print("%s t=%.2f ms, intensity=%.0f"
872                      % (comp.engine, comp_time, comp_value.sum()))
873            _show_invalid(data, comp_value)
874        except ImportError:
875            traceback.print_exc()
876
877    # Compare, but only if computing both forms
878    if comparison:
879        resid = (base_value - comp_value)
880        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
881        if verbose:
882            _print_stats("|%s-%s|"
883                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
884                         resid)
885            _print_stats("|(%s-%s)/%s|"
886                         % (base.engine, comp.engine, comp.engine),
887                         relerr)
888
889    return dict(base_value=base_value, comp_value=comp_value,
890                base_time=base_time, comp_time=comp_time,
891                resid=resid, relerr=relerr)
892
893
894def _print_stats(label, err):
895    # type: (str, np.ma.ndarray) -> None
896    # work with trimmed data, not the full set
897    sorted_err = np.sort(abs(err.compressed()))
898    if len(sorted_err) == 0.:
899        print(label + "  no valid values")
900        return
901
902    p50 = int((len(sorted_err)-1)*0.50)
903    p98 = int((len(sorted_err)-1)*0.98)
904    data = [
905        "max:%.3e"%sorted_err[-1],
906        "median:%.3e"%sorted_err[p50],
907        "98%%:%.3e"%sorted_err[p98],
908        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
909        "zero-offset:%+.3e"%np.mean(sorted_err),
910        ]
911    print(label+"  "+"  ".join(data))
912
913
914def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
915    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
916    base_value, comp_value = result['base_value'], result['comp_value']
917    base_time, comp_time = result['base_time'], result['comp_time']
918    resid, relerr = result['resid'], result['relerr']
919
920    have_base, have_comp = (base_value is not None), (comp_value is not None)
921    base, comp = opts['engines']
922    data = opts['data']
923    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
924
925    # Plot if requested
926    view = opts['view']
927    import matplotlib.pyplot as plt
928    vmin, vmax = limits
929    if have_base:
930        vmin = min(vmin, base_value.min())
931        vmax = max(vmax, base_value.max())
932    if have_comp:
933        vmin = min(vmin, comp_value.min())
934        vmax = max(vmax, comp_value.max())
935    limits = vmin, vmax
936
937    if have_base:
938        if have_comp:
939            plt.subplot(131)
940        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
941        plt.title("%s t=%.2f ms"%(base.engine, base_time))
942        #cbar_title = "log I"
943    if have_comp:
944        if have_base:
945            plt.subplot(132)
946        if not opts['is2d'] and have_base:
947            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
948        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
949        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
950        #cbar_title = "log I"
951    if have_base and have_comp:
952        plt.subplot(133)
953        if not opts['rel_err']:
954            err, errstr, errview = resid, "abs err", "linear"
955        else:
956            err, errstr, errview = abs(relerr), "rel err", "log"
957            if (err == 0.).all():
958                errview = 'linear'
959        if 0:  # 95% cutoff
960            sorted = np.sort(err.flatten())
961            cutoff = sorted[int(sorted.size*0.95)]
962            err[err > cutoff] = cutoff
963        #err,errstr = base/comp,"ratio"
964        plot_theory(data, None, resid=err, view=view, use_data=use_data)
965        plt.yscale(errview)
966        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
967        #cbar_title = errstr if errview=="linear" else "log "+errstr
968    #if is2D:
969    #    h = plt.colorbar()
970    #    h.ax.set_title(cbar_title)
971    fig = plt.gcf()
972    extra_title = ' '+opts['title'] if opts['title'] else ''
973    fig.suptitle(":".join(opts['name']) + extra_title)
974
975    if have_base and have_comp and opts['show_hist']:
976        plt.figure()
977        v = relerr
978        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
979        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
980        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
981                   % (base.engine, comp.engine, comp.engine))
982        plt.ylabel('P(err)')
983        plt.title('Distribution of relative error between calculation engines')
984
985    return limits
986
987
988# ===========================================================================
989#
990
991# Set of command line options.
992# Normal options such as -plot/-noplot are specified as 'name'.
993# For options such as -nq=500 which require a value use 'name='.
994#
995OPTIONS = [
996    # Plotting
997    'plot', 'noplot',
998    'linear', 'log', 'q4',
999    'rel', 'abs',
1000    'hist', 'nohist',
1001    'title=',
1002
1003    # Data generation
1004    'data=', 'noise=', 'res=', 'nq=', 'q=',
1005    'lowq', 'midq', 'highq', 'exq', 'zero',
1006    '2d', '1d',
1007
1008    # Parameter set
1009    'preset', 'random', 'random=', 'sets=',
1010    'demo', 'default',  # TODO: remove demo/default
1011    'nopars', 'pars',
1012
1013    # Calculation options
1014    'poly', 'mono', 'cutoff=',
1015    'magnetic', 'nonmagnetic',
1016    'accuracy=',
1017    'neval=',  # for timing...
1018
1019    # Precision options
1020    'calc=',
1021    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1022    'sasview',  # TODO: remove sasview 3.x support
1023
1024    # Output options
1025    'help', 'html', 'edit',
1026    ]
1027
1028NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1029VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1030
1031
1032def columnize(items, indent="", width=79):
1033    # type: (List[str], str, int) -> str
1034    """
1035    Format a list of strings into columns.
1036
1037    Returns a string with carriage returns ready for printing.
1038    """
1039    column_width = max(len(w) for w in items) + 1
1040    num_columns = (width - len(indent)) // column_width
1041    num_rows = len(items) // num_columns
1042    items = items + [""] * (num_rows * num_columns - len(items))
1043    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1044    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1045             for row in zip(*columns)]
1046    output = indent + ("\n"+indent).join(lines)
1047    return output
1048
1049
1050def get_pars(model_info, use_demo=False):
1051    # type: (ModelInfo, bool) -> ParameterSet
1052    """
1053    Extract demo parameters from the model definition.
1054    """
1055    # Get the default values for the parameters
1056    pars = {}
1057    for p in model_info.parameters.call_parameters:
1058        parts = [('', p.default)]
1059        if p.polydisperse:
1060            parts.append(('_pd', 0.0))
1061            parts.append(('_pd_n', 0))
1062            parts.append(('_pd_nsigma', 3.0))
1063            parts.append(('_pd_type', "gaussian"))
1064        for ext, val in parts:
1065            if p.length > 1:
1066                dict(("%s%d%s" % (p.id, k, ext), val)
1067                     for k in range(1, p.length+1))
1068            else:
1069                pars[p.id + ext] = val
1070
1071    # Plug in values given in demo
1072    if use_demo and model_info.demo:
1073        pars.update(model_info.demo)
1074    return pars
1075
1076INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1077def isnumber(str):
1078    match = FLOAT_RE.match(str)
1079    isfloat = (match and not str[match.end():])
1080    return isfloat or INTEGER_RE.match(str)
1081
1082# For distinguishing pairs of models for comparison
1083# key-value pair separator =
1084# shell characters  | & ; <> $ % ' " \ # `
1085# model and parameter names _
1086# parameter expressions - + * / . ( )
1087# path characters including tilde expansion and windows drive ~ / :
1088# not sure about brackets [] {}
1089# maybe one of the following @ ? ^ ! ,
1090PAR_SPLIT = ','
1091def parse_opts(argv):
1092    # type: (List[str]) -> Dict[str, Any]
1093    """
1094    Parse command line options.
1095    """
1096    MODELS = core.list_models()
1097    flags = [arg for arg in argv
1098             if arg.startswith('-')]
1099    values = [arg for arg in argv
1100              if not arg.startswith('-') and '=' in arg]
1101    positional_args = [arg for arg in argv
1102                       if not arg.startswith('-') and '=' not in arg]
1103    models = "\n    ".join("%-15s"%v for v in MODELS)
1104    if len(positional_args) == 0:
1105        print(USAGE)
1106        print("\nAvailable models:")
1107        print(columnize(MODELS, indent="  "))
1108        return None
1109
1110    invalid = [o[1:] for o in flags
1111               if o[1:] not in NAME_OPTIONS
1112               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1113    if invalid:
1114        print("Invalid options: %s"%(", ".join(invalid)))
1115        return None
1116
1117    name = positional_args[-1]
1118
1119    # pylint: disable=bad-whitespace
1120    # Interpret the flags
1121    opts = {
1122        'plot'      : True,
1123        'view'      : 'log',
1124        'is2d'      : False,
1125        'qmin'      : None,
1126        'qmax'      : 0.05,
1127        'nq'        : 128,
1128        'res'       : 0.0,
1129        'noise'     : 0.0,
1130        'accuracy'  : 'Low',
1131        'cutoff'    : '0.0',
1132        'seed'      : -1,  # default to preset
1133        'mono'      : True,
1134        # Default to magnetic a magnetic moment is set on the command line
1135        'magnetic'  : False,
1136        'show_pars' : False,
1137        'show_hist' : False,
1138        'rel_err'   : True,
1139        'explore'   : False,
1140        'use_demo'  : True,
1141        'zero'      : False,
1142        'html'      : False,
1143        'title'     : None,
1144        'datafile'  : None,
1145        'sets'      : 0,
1146        'engine'    : 'default',
1147        'evals'     : '1',
1148    }
1149    for arg in flags:
1150        if arg == '-noplot':    opts['plot'] = False
1151        elif arg == '-plot':    opts['plot'] = True
1152        elif arg == '-linear':  opts['view'] = 'linear'
1153        elif arg == '-log':     opts['view'] = 'log'
1154        elif arg == '-q4':      opts['view'] = 'q4'
1155        elif arg == '-1d':      opts['is2d'] = False
1156        elif arg == '-2d':      opts['is2d'] = True
1157        elif arg == '-exq':     opts['qmax'] = 10.0
1158        elif arg == '-highq':   opts['qmax'] = 1.0
1159        elif arg == '-midq':    opts['qmax'] = 0.2
1160        elif arg == '-lowq':    opts['qmax'] = 0.05
1161        elif arg == '-zero':    opts['zero'] = True
1162        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1163        elif arg.startswith('-q='):
1164            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1165        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1166        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1167        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1168        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1169        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1170        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1171        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1172        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1173        elif arg.startswith('-calc='):     opts['engine'] = arg[6:]
1174        elif arg.startswith('-neval='):    opts['evals'] = arg[7:]
1175        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1176        elif arg == '-preset':  opts['seed'] = -1
1177        elif arg == '-mono':    opts['mono'] = True
1178        elif arg == '-poly':    opts['mono'] = False
1179        elif arg == '-magnetic':       opts['magnetic'] = True
1180        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1181        elif arg == '-pars':    opts['show_pars'] = True
1182        elif arg == '-nopars':  opts['show_pars'] = False
1183        elif arg == '-hist':    opts['show_hist'] = True
1184        elif arg == '-nohist':  opts['show_hist'] = False
1185        elif arg == '-rel':     opts['rel_err'] = True
1186        elif arg == '-abs':     opts['rel_err'] = False
1187        elif arg == '-half':    opts['engine'] = 'half'
1188        elif arg == '-fast':    opts['engine'] = 'fast'
1189        elif arg == '-single':  opts['engine'] = 'single'
1190        elif arg == '-double':  opts['engine'] = 'double'
1191        elif arg == '-single!': opts['engine'] = 'single!'
1192        elif arg == '-double!': opts['engine'] = 'double!'
1193        elif arg == '-quad!':   opts['engine'] = 'quad!'
1194        elif arg == '-sasview': opts['engine'] = 'sasview'
1195        elif arg == '-edit':    opts['explore'] = True
1196        elif arg == '-demo':    opts['use_demo'] = True
1197        elif arg == '-default': opts['use_demo'] = False
1198        elif arg == '-html':    opts['html'] = True
1199        elif arg == '-help':    opts['html'] = True
1200    # pylint: enable=bad-whitespace
1201
1202    # Magnetism forces 2D for now
1203    if opts['magnetic']:
1204        opts['is2d'] = True
1205
1206    # Force random if sets is used
1207    if opts['sets'] >= 1 and opts['seed'] < 0:
1208        opts['seed'] = np.random.randint(1000000)
1209    if opts['sets'] == 0:
1210        opts['sets'] = 1
1211
1212    # Create the computational engines
1213    if opts['qmin'] is None:
1214        opts['qmin'] = 0.001*opts['qmax']
1215    if opts['datafile'] is not None:
1216        data = load_data(os.path.expanduser(opts['datafile']))
1217    else:
1218        data, _ = make_data(opts)
1219
1220    comparison = any(PAR_SPLIT in v for v in values)
1221    if PAR_SPLIT in name:
1222        names = name.split(PAR_SPLIT, 2)
1223        comparison = True
1224    else:
1225        names = [name]*2
1226    try:
1227        model_info = [core.load_model_info(k) for k in names]
1228    except ImportError as exc:
1229        print(str(exc))
1230        print("Could not find model; use one of:\n    " + models)
1231        return None
1232
1233    if PAR_SPLIT in opts['engine']:
1234        engine_types = opts['engine'].split(PAR_SPLIT, 2)
1235        comparison = True
1236    else:
1237        engine_types = [opts['engine']]*2
1238
1239    if PAR_SPLIT in opts['evals']:
1240        evals = [int(k) for k in opts['evals'].split(PAR_SPLIT, 2)]
1241        comparison = True
1242    else:
1243        evals = [int(opts['evals'])]*2
1244
1245    if PAR_SPLIT in opts['cutoff']:
1246        cutoff = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1247        comparison = True
1248    else:
1249        cutoff = [float(opts['cutoff'])]*2
1250
1251    base = make_engine(model_info[0], data, engine_types[0], cutoff[0])
1252    if comparison:
1253        comp = make_engine(model_info[1], data, engine_types[1], cutoff[1])
1254    else:
1255        comp = None
1256
1257    # pylint: disable=bad-whitespace
1258    # Remember it all
1259    opts.update({
1260        'data'      : data,
1261        'name'      : names,
1262        'def'       : model_info,
1263        'count'     : evals,
1264        'engines'   : [base, comp],
1265        'values'    : values,
1266    })
1267    # pylint: enable=bad-whitespace
1268
1269    return opts
1270
1271def parse_pars(opts):
1272    model_info, model_info2 = opts['def']
1273
1274    # Get demo parameters from model definition, or use default parameters
1275    # if model does not define demo parameters
1276    pars = get_pars(model_info, opts['use_demo'])
1277    pars2 = get_pars(model_info2, opts['use_demo'])
1278    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1279    # randomize parameters
1280    #pars.update(set_pars)  # set value before random to control range
1281    if opts['seed'] > -1:
1282        pars = randomize_pars(model_info, pars)
1283        if model_info != model_info2:
1284            pars2 = randomize_pars(model_info2, pars2)
1285            # Share values for parameters with the same name
1286            for k, v in pars.items():
1287                if k in pars2:
1288                    pars2[k] = v
1289        else:
1290            pars2 = pars.copy()
1291        constrain_pars(model_info, pars)
1292        constrain_pars(model_info2, pars2)
1293    pars = suppress_pd(pars, opts['mono'])
1294    pars2 = suppress_pd(pars2, opts['mono'])
1295    pars = suppress_magnetism(pars, not opts['magnetic'])
1296    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1297
1298    # Fill in parameters given on the command line
1299    presets = {}
1300    presets2 = {}
1301    for arg in opts['values']:
1302        k, v = arg.split('=', 1)
1303        if k not in pars and k not in pars2:
1304            # extract base name without polydispersity info
1305            s = set(p.split('_pd')[0] for p in pars)
1306            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1307            return None
1308        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1309        if v1 and k in pars:
1310            presets[k] = float(v1) if isnumber(v1) else v1
1311        if v2 and k in pars2:
1312            presets2[k] = float(v2) if isnumber(v2) else v2
1313
1314    # If pd given on the command line, default pd_n to 35
1315    for k, v in list(presets.items()):
1316        if k.endswith('_pd'):
1317            presets.setdefault(k+'_n', 35.)
1318    for k, v in list(presets2.items()):
1319        if k.endswith('_pd'):
1320            presets2.setdefault(k+'_n', 35.)
1321
1322    # Evaluate preset parameter expressions
1323    context = MATH.copy()
1324    context['np'] = np
1325    context.update(pars)
1326    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1327    for k, v in presets.items():
1328        if not isinstance(v, float) and not k.endswith('_type'):
1329            presets[k] = eval(v, context)
1330    context.update(presets)
1331    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1332    for k, v in presets2.items():
1333        if not isinstance(v, float) and not k.endswith('_type'):
1334            presets2[k] = eval(v, context)
1335
1336    # update parameters with presets
1337    pars.update(presets)  # set value after random to control value
1338    pars2.update(presets2)  # set value after random to control value
1339    #import pprint; pprint.pprint(model_info)
1340
1341    if opts['show_pars']:
1342        if model_info.name != model_info2.name or pars != pars2:
1343            print("==== %s ====="%model_info.name)
1344            print(str(parlist(model_info, pars, opts['is2d'])))
1345            print("==== %s ====="%model_info2.name)
1346            print(str(parlist(model_info2, pars2, opts['is2d'])))
1347        else:
1348            print(str(parlist(model_info, pars, opts['is2d'])))
1349
1350    return pars, pars2
1351
1352def show_docs(opts):
1353    # type: (Dict[str, Any]) -> None
1354    """
1355    show html docs for the model
1356    """
1357    import os
1358    from .generate import make_html
1359    from . import rst2html
1360
1361    info = opts['def'][0]
1362    html = make_html(info)
1363    path = os.path.dirname(info.filename)
1364    url = "file://"+path.replace("\\","/")[2:]+"/"
1365    rst2html.view_html_qtapp(html, url)
1366
1367def explore(opts):
1368    # type: (Dict[str, Any]) -> None
1369    """
1370    explore the model using the bumps gui.
1371    """
1372    import wx  # type: ignore
1373    from bumps.names import FitProblem  # type: ignore
1374    from bumps.gui.app_frame import AppFrame  # type: ignore
1375    from bumps.gui import signal
1376
1377    is_mac = "cocoa" in wx.version()
1378    # Create an app if not running embedded
1379    app = wx.App() if wx.GetApp() is None else None
1380    model = Explore(opts)
1381    problem = FitProblem(model)
1382    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1383    if not is_mac:
1384        frame.Show()
1385    frame.panel.set_model(model=problem)
1386    frame.panel.Layout()
1387    frame.panel.aui.Split(0, wx.TOP)
1388    def reset_parameters(event):
1389        model.revert_values()
1390        signal.update_parameters(problem)
1391    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1392    if is_mac: frame.Show()
1393    # If running withing an app, start the main loop
1394    if app:
1395        app.MainLoop()
1396
1397class Explore(object):
1398    """
1399    Bumps wrapper for a SAS model comparison.
1400
1401    The resulting object can be used as a Bumps fit problem so that
1402    parameters can be adjusted in the GUI, with plots updated on the fly.
1403    """
1404    def __init__(self, opts):
1405        # type: (Dict[str, Any]) -> None
1406        from bumps.cli import config_matplotlib  # type: ignore
1407        from . import bumps_model
1408        config_matplotlib()
1409        self.opts = opts
1410        opts['pars'] = list(opts['pars'])
1411        p1, p2 = opts['pars']
1412        m1, m2 = opts['def']
1413        self.fix_p2 = m1 != m2 or p1 != p2
1414        model_info = m1
1415        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1416        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1417        if not opts['is2d']:
1418            for p in model_info.parameters.user_parameters({}, is2d=False):
1419                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1420                    k = p.name+ext
1421                    v = pars.get(k, None)
1422                    if v is not None:
1423                        v.range(*parameter_range(k, v.value))
1424        else:
1425            for k, v in pars.items():
1426                v.range(*parameter_range(k, v.value))
1427
1428        self.pars = pars
1429        self.starting_values = dict((k, v.value) for k, v in pars.items())
1430        self.pd_types = pd_types
1431        self.limits = np.Inf, -np.Inf
1432
1433    def revert_values(self):
1434        for k, v in self.starting_values.items():
1435            self.pars[k].value = v
1436
1437    def model_update(self):
1438        pass
1439
1440    def numpoints(self):
1441        # type: () -> int
1442        """
1443        Return the number of points.
1444        """
1445        return len(self.pars) + 1  # so dof is 1
1446
1447    def parameters(self):
1448        # type: () -> Any   # Dict/List hierarchy of parameters
1449        """
1450        Return a dictionary of parameters.
1451        """
1452        return self.pars
1453
1454    def nllf(self):
1455        # type: () -> float
1456        """
1457        Return cost.
1458        """
1459        # pylint: disable=no-self-use
1460        return 0.  # No nllf
1461
1462    def plot(self, view='log'):
1463        # type: (str) -> None
1464        """
1465        Plot the data and residuals.
1466        """
1467        pars = dict((k, v.value) for k, v in self.pars.items())
1468        pars.update(self.pd_types)
1469        self.opts['pars'][0] = pars
1470        if not self.fix_p2:
1471            self.opts['pars'][1] = pars
1472        result = run_models(self.opts)
1473        limits = plot_models(self.opts, result, limits=self.limits)
1474        if self.limits is None:
1475            vmin, vmax = limits
1476            self.limits = vmax*1e-7, 1.3*vmax
1477            import pylab; pylab.clf()
1478            plot_models(self.opts, result, limits=self.limits)
1479
1480
1481def main(*argv):
1482    # type: (*str) -> None
1483    """
1484    Main program.
1485    """
1486    opts = parse_opts(argv)
1487    if opts is not None:
1488        if opts['seed'] > -1:
1489            print("Randomize using -random=%i"%opts['seed'])
1490            np.random.seed(opts['seed'])
1491        if opts['html']:
1492            show_docs(opts)
1493        elif opts['explore']:
1494            opts['pars'] = parse_pars(opts)
1495            if opts['pars'] is None:
1496                return
1497            explore(opts)
1498        else:
1499            compare(opts)
1500
1501if __name__ == "__main__":
1502    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.