source: sasmodels/sasmodels/compare.py @ 765eb0e

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 765eb0e was 765eb0e, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

allow random generation of parameters for product and mixture models

  • Property mode set to 100755
File size: 52.3 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: sascomp model [options...] [key=val]
59
60Generate and compare SAS models.  If a single model is specified it shows
61a plot of that model.  Different models can be compared, or the same model
62with different parameters.  The same model with the same parameters can
63be compared with different calculation engines to see the effects of precision
64on the resultant values.
65
66model or model1,model2 are the names of the models to compare (see below).
67
68Options (* for default):
69
70    === data generation ===
71    -data="path" uses q, dq from the data file
72    -noise=0 sets the measurement error dI/I
73    -res=0 sets the resolution width dQ/Q if calculating with resolution
74    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
75    -nq=128 sets the number of Q points in the data set
76    -1d*/-2d computes 1d or 2d data
77    -zero indicates that q=0 should be included
78
79    === model parameters ===
80    -preset*/-random[=seed] preset or random parameters
81    -sets=n generates n random datasets with the seed given by -random=seed
82    -pars/-nopars* prints the parameter set or not
83    -default/-demo* use demo vs default parameters
84
85    === calculation options ===
86    -mono*/-poly force monodisperse or allow polydisperse demo parameters
87    -cutoff=1e-5* cutoff value for including a point in polydispersity
88    -magnetic/-nonmagnetic* suppress magnetism
89    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
90    -neval=1 sets the number of evals for more accurate timing
91
92    === precision options ===
93    -calc=default uses the default calcution precision
94    -single/-double/-half/-fast sets an OpenCL calculation engine
95    -single!/-double!/-quad! sets an OpenMP calculation engine
96    -sasview sets the sasview calculation engine
97
98    === plotting ===
99    -plot*/-noplot plots or suppress the plot of the model
100    -linear/-log*/-q4 intensity scaling on plots
101    -hist/-nohist* plot histogram of relative error
102    -abs/-rel* plot relative or absolute error
103    -title="note" adds note to the plot title, after the model name
104
105    === output options ===
106    -edit starts the parameter explorer
107    -help/-html shows the model docs instead of running the model
108
109The interpretation of quad precision depends on architecture, and may
110vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
111On unix and mac you may need single quotes around the DLL computation
112engines, such as -calc='single!,double!' since !, is treated as a history
113expansion request in the shell.
114
115Key=value pairs allow you to set specific values for the model parameters.
116Key=value1,value2 to compare different values of the same parameter. The
117value can be an expression including other parameters.
118
119Items later on the command line override those that appear earlier.
120
121Examples:
122
123    # compare single and double precision calculation for a barbell
124    sascomp barbell -calc=single,double
125
126    # generate 10 random lorentz models, with seed=27
127    sascomp lorentz -sets=10 -seed=27
128
129    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
130    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
131
132    # model timing test requires multiple evals to perform the estimate
133    sascomp pringle -calc=single,double -timing=100,100 -noplot
134"""
135
136# Update docs with command line usage string.   This is separate from the usual
137# doc string so that we can display it at run time if there is an error.
138# lin
139__doc__ = (__doc__  # pylint: disable=redefined-builtin
140           + """
141Program description
142-------------------
143
144""" + USAGE)
145
146kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
147
148# list of math functions for use in evaluating parameters
149MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
150
151# CRUFT python 2.6
152if not hasattr(datetime.timedelta, 'total_seconds'):
153    def delay(dt):
154        """Return number date-time delta as number seconds"""
155        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
156else:
157    def delay(dt):
158        """Return number date-time delta as number seconds"""
159        return dt.total_seconds()
160
161
162class push_seed(object):
163    """
164    Set the seed value for the random number generator.
165
166    When used in a with statement, the random number generator state is
167    restored after the with statement is complete.
168
169    :Parameters:
170
171    *seed* : int or array_like, optional
172        Seed for RandomState
173
174    :Example:
175
176    Seed can be used directly to set the seed::
177
178        >>> from numpy.random import randint
179        >>> push_seed(24)
180        <...push_seed object at...>
181        >>> print(randint(0,1000000,3))
182        [242082    899 211136]
183
184    Seed can also be used in a with statement, which sets the random
185    number generator state for the enclosed computations and restores
186    it to the previous state on completion::
187
188        >>> with push_seed(24):
189        ...    print(randint(0,1000000,3))
190        [242082    899 211136]
191
192    Using nested contexts, we can demonstrate that state is indeed
193    restored after the block completes::
194
195        >>> with push_seed(24):
196        ...    print(randint(0,1000000))
197        ...    with push_seed(24):
198        ...        print(randint(0,1000000,3))
199        ...    print(randint(0,1000000))
200        242082
201        [242082    899 211136]
202        899
203
204    The restore step is protected against exceptions in the block::
205
206        >>> with push_seed(24):
207        ...    print(randint(0,1000000))
208        ...    try:
209        ...        with push_seed(24):
210        ...            print(randint(0,1000000,3))
211        ...            raise Exception()
212        ...    except Exception:
213        ...        print("Exception raised")
214        ...    print(randint(0,1000000))
215        242082
216        [242082    899 211136]
217        Exception raised
218        899
219    """
220    def __init__(self, seed=None):
221        # type: (Optional[int]) -> None
222        self._state = np.random.get_state()
223        np.random.seed(seed)
224
225    def __enter__(self):
226        # type: () -> None
227        pass
228
229    def __exit__(self, exc_type, exc_value, traceback):
230        # type: (Any, BaseException, Any) -> None
231        # TODO: better typing for __exit__ method
232        np.random.set_state(self._state)
233
234def tic():
235    # type: () -> Callable[[], float]
236    """
237    Timer function.
238
239    Use "toc=tic()" to start the clock and "toc()" to measure
240    a time interval.
241    """
242    then = datetime.datetime.now()
243    return lambda: delay(datetime.datetime.now() - then)
244
245
246def set_beam_stop(data, radius, outer=None):
247    # type: (Data, float, float) -> None
248    """
249    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
250
251    Note: this function does not require sasview
252    """
253    if hasattr(data, 'qx_data'):
254        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
255        data.mask = (q < radius)
256        if outer is not None:
257            data.mask |= (q >= outer)
258    else:
259        data.mask = (data.x < radius)
260        if outer is not None:
261            data.mask |= (data.x >= outer)
262
263
264def parameter_range(p, v):
265    # type: (str, float) -> Tuple[float, float]
266    """
267    Choose a parameter range based on parameter name and initial value.
268    """
269    # process the polydispersity options
270    if p.endswith('_pd_n'):
271        return 0., 100.
272    elif p.endswith('_pd_nsigma'):
273        return 0., 5.
274    elif p.endswith('_pd_type'):
275        raise ValueError("Cannot return a range for a string value")
276    elif any(s in p for s in ('theta', 'phi', 'psi')):
277        # orientation in [-180,180], orientation pd in [0,45]
278        if p.endswith('_pd'):
279            return 0., 45.
280        else:
281            return -180., 180.
282    elif p.endswith('_pd'):
283        return 0., 1.
284    elif 'sld' in p:
285        return -0.5, 10.
286    elif p == 'background':
287        return 0., 10.
288    elif p == 'scale':
289        return 0., 1.e3
290    elif v < 0.:
291        return 2.*v, -2.*v
292    else:
293        return 0., (2.*v if v > 0. else 1.)
294
295
296def _randomize_one(model_info, name, value):
297    # type: (ModelInfo, str, float) -> float
298    # type: (ModelInfo, str, str) -> str
299    """
300    Randomize a single parameter.
301    """
302    # Set the amount of polydispersity/angular dispersion, but by default pd_n
303    # is zero so there is no polydispersity.  This allows us to turn on/off
304    # pd by setting pd_n, and still have randomly generated values
305    if name.endswith('_pd'):
306        par = model_info.parameters[name[:-3]]
307        if par.type == 'orientation':
308            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
309            return 180*np.random.beta(2.5, 20)
310        else:
311            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
312            return np.random.beta(1.5, 7)
313
314    # pd is selected globally rather than per parameter, so set to 0 for no pd
315    # In particular, when multiple pd dimensions, want to decrease the number
316    # of points per dimension for faster computation
317    if name.endswith('_pd_n'):
318        return 0
319
320    # Don't mess with distribution type for now
321    if name.endswith('_pd_type'):
322        return 'gaussian'
323
324    # type-dependent value of number of sigmas; for gaussian use 3.
325    if name.endswith('_pd_nsigma'):
326        return 3.
327
328    # background in the range [0.01, 1]
329    if name == 'background':
330        return 10**np.random.uniform(-2, 0)
331
332    # scale defaults to 0.1% to 30% volume fraction
333    if name == 'scale':
334        return 10**np.random.uniform(-3, -0.5)
335
336    # If it is a list of choices, pick one at random with equal probability
337    # In practice, the model specific random generator will override.
338    par = model_info.parameters[name]
339    if len(par.limits) > 2:  # choice list
340        return np.random.randint(len(par.limits))
341
342    # If it is a fixed range, pick from it with equal probability.
343    # For logarithmic ranges, the model will have to override.
344    if np.isfinite(par.limits).all():
345        return np.random.uniform(*par.limits)
346
347    # If the paramter is marked as an sld use the range of neutron slds
348    # TODO: ought to randomly contrast match a pair of SLDs
349    if par.type == 'sld':
350        return np.random.uniform(-0.5, 12)
351
352    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
353    if par.name.startswith('M0:'):
354        return np.random.uniform(0, 5)
355
356    # Guess at the random length/radius/thickness.  In practice, all models
357    # are going to set their own reasonable ranges.
358    if par.type == 'volume':
359        if ('length' in par.name or
360                'radius' in par.name or
361                'thick' in par.name):
362            return 10**np.random.uniform(2, 4)
363
364    # In the absence of any other info, select a value in [0, 2v], or
365    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
366    # model random parameter generators will override this default.
367    low, high = parameter_range(par.name, value)
368    limits = (max(par.limits[0], low), min(par.limits[1], high))
369    return np.random.uniform(*limits)
370
371def _random_pd(model_info, pars):
372    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
373    pd_volume = []
374    pd_oriented = []
375    for p in pd:
376        if p.type == 'orientation':
377            pd_oriented.append(p.name)
378        elif p.length_control is not None:
379            n = int(pars.get(p.length_control, 1) + 0.5)
380            pd_volume.extend(p.name+str(k+1) for k in range(n))
381        elif p.length > 1:
382            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
383        else:
384            pd_volume.append(p.name)
385    u = np.random.rand()
386    n = len(pd_volume)
387    if u < 0.01 or n < 1:
388        pass  # 1% chance of no polydispersity
389    elif u < 0.86 or n < 2:
390        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
391    elif u < 0.99 or n < 3:
392        choices = np.random.choice(len(pd_volume), size=2)
393        pars[pd_volume[choices[0]]+"_pd_n"] = 25
394        pars[pd_volume[choices[1]]+"_pd_n"] = 10
395    else:
396        choices = np.random.choice(len(pd_volume), size=3)
397        pars[pd_volume[choices[0]]+"_pd_n"] = 25
398        pars[pd_volume[choices[1]]+"_pd_n"] = 10
399        pars[pd_volume[choices[2]]+"_pd_n"] = 5
400    if pd_oriented:
401        pars['theta_pd_n'] = 20
402        if np.random.rand() < 0.1:
403            pars['phi_pd_n'] = 5
404        if np.random.rand() < 0.1:
405            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
406                #print("generating psi_pd_n")
407                pars['psi_pd_n'] = 5
408
409    ## Show selected polydispersity
410    #for name, value in pars.items():
411    #    if name.endswith('_pd_n') and value > 0:
412    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
413
414
415def randomize_pars(model_info, pars):
416    # type: (ModelInfo, ParameterSet) -> ParameterSet
417    """
418    Generate random values for all of the parameters.
419
420    Valid ranges for the random number generator are guessed from the name of
421    the parameter; this will not account for constraints such as cap radius
422    greater than cylinder radius in the capped_cylinder model, so
423    :func:`constrain_pars` needs to be called afterward..
424    """
425    # Note: the sort guarantees order of calls to random number generator
426    random_pars = dict((p, _randomize_one(model_info, p, v))
427                       for p, v in sorted(pars.items()))
428    if model_info.random is not None:
429        random_pars.update(model_info.random())
430    _random_pd(model_info, random_pars)
431    return random_pars
432
433
434def constrain_pars(model_info, pars):
435    # type: (ModelInfo, ParameterSet) -> None
436    """
437    Restrict parameters to valid values.
438
439    This includes model specific code for models such as capped_cylinder
440    which need to support within model constraints (cap radius more than
441    cylinder radius in this case).
442
443    Warning: this updates the *pars* dictionary in place.
444    """
445    # TODO: move the model specific code to the individual models
446    name = model_info.id
447    # if it is a product model, then just look at the form factor since
448    # none of the structure factors need any constraints.
449    if '*' in name:
450        name = name.split('*')[0]
451
452    # Suppress magnetism for python models (not yet implemented)
453    if callable(model_info.Iq):
454        pars.update(suppress_magnetism(pars))
455
456    if name == 'barbell':
457        if pars['radius_bell'] < pars['radius']:
458            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
459
460    elif name == 'capped_cylinder':
461        if pars['radius_cap'] < pars['radius']:
462            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
463
464    elif name == 'guinier':
465        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
466        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
467        # => ln A - (Rg^2 q^2/3) > -30 ln 10
468        # => Rg^2 q^2/3 < 30 ln 10 + ln A
469        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
470        #q_max = 0.2  # mid q maximum
471        q_max = 1.0  # high q maximum
472        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
473        pars['rg'] = min(pars['rg'], rg_max)
474
475    elif name == 'pearl_necklace':
476        if pars['radius'] < pars['thick_string']:
477            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
478        pass
479
480    elif name == 'rpa':
481        # Make sure phi sums to 1.0
482        if pars['case_num'] < 2:
483            pars['Phi1'] = 0.
484            pars['Phi2'] = 0.
485        elif pars['case_num'] < 5:
486            pars['Phi1'] = 0.
487        total = sum(pars['Phi'+c] for c in '1234')
488        for c in '1234':
489            pars['Phi'+c] /= total
490
491def parlist(model_info, pars, is2d):
492    # type: (ModelInfo, ParameterSet, bool) -> str
493    """
494    Format the parameter list for printing.
495    """
496    lines = []
497    parameters = model_info.parameters
498    magnetic = False
499    magnetic_pars = []
500    for p in parameters.user_parameters(pars, is2d):
501        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
502            continue
503        if p.id.startswith('up:'):
504            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
505            continue
506        fields = dict(
507            value=pars.get(p.id, p.default),
508            pd=pars.get(p.id+"_pd", 0.),
509            n=int(pars.get(p.id+"_pd_n", 0)),
510            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
511            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
512            relative_pd=p.relative_pd,
513            M0=pars.get('M0:'+p.id, 0.),
514            mphi=pars.get('mphi:'+p.id, 0.),
515            mtheta=pars.get('mtheta:'+p.id, 0.),
516        )
517        lines.append(_format_par(p.name, **fields))
518        magnetic = magnetic or fields['M0'] != 0.
519    if magnetic and magnetic_pars:
520        lines.append(" ".join(magnetic_pars))
521    return "\n".join(lines)
522
523    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
524
525def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
526                relative_pd=False, M0=0., mphi=0., mtheta=0.):
527    # type: (str, float, float, int, float, str) -> str
528    line = "%s: %g"%(name, value)
529    if pd != 0.  and n != 0:
530        if relative_pd:
531            pd *= value
532        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
533                % (pd, n, nsigma, nsigma, pdtype)
534    if M0 != 0.:
535        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
536    return line
537
538def suppress_pd(pars, suppress=True):
539    # type: (ParameterSet) -> ParameterSet
540    """
541    If suppress is True complete eliminate polydispersity of the model to test
542    models more quickly.  If suppress is False, make sure at least one
543    parameter is polydisperse, setting the first polydispersity parameter to
544    15% if no polydispersity is given (with no explicit demo parameters given
545    in the model, there will be no default polydispersity).
546    """
547    pars = pars.copy()
548    #print("pars=", pars)
549    if suppress:
550        for p in pars:
551            if p.endswith("_pd_n"):
552                pars[p] = 0
553    else:
554        any_pd = False
555        first_pd = None
556        for p in pars:
557            if p.endswith("_pd_n"):
558                pd = pars.get(p[:-2], 0.)
559                any_pd |= (pars[p] != 0 and pd != 0.)
560                if first_pd is None:
561                    first_pd = p
562        if not any_pd and first_pd is not None:
563            if pars[first_pd] == 0:
564                pars[first_pd] = 35
565            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
566                pars[first_pd[:-2]] = 0.15
567    return pars
568
569def suppress_magnetism(pars, suppress=True):
570    # type: (ParameterSet) -> ParameterSet
571    """
572    If suppress is True complete eliminate magnetism of the model to test
573    models more quickly.  If suppress is False, make sure at least one sld
574    parameter is magnetic, setting the first parameter to have a strong
575    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
576    in the model, there will be no default magnetism).
577    """
578    pars = pars.copy()
579    if suppress:
580        for p in pars:
581            if p.startswith("M0:"):
582                pars[p] = 0
583    else:
584        any_mag = False
585        first_mag = None
586        for p in pars:
587            if p.startswith("M0:"):
588                any_mag |= (pars[p] != 0)
589                if first_mag is None:
590                    first_mag = p
591        if not any_mag and first_mag is not None:
592            pars[first_mag] = 8.
593    return pars
594
595def eval_sasview(model_info, data):
596    # type: (Modelinfo, Data) -> Calculator
597    """
598    Return a model calculator using the pre-4.0 SasView models.
599    """
600    # importing sas here so that the error message will be that sas failed to
601    # import rather than the more obscure smear_selection not imported error
602    import sas
603    import sas.models
604    from sas.models.qsmearing import smear_selection
605    from sas.models.MultiplicationModel import MultiplicationModel
606    from sas.models.dispersion_models import models as dispersers
607
608    def get_model_class(name):
609        # type: (str) -> "sas.models.BaseComponent"
610        #print("new",sorted(_pars.items()))
611        __import__('sas.models.' + name)
612        ModelClass = getattr(getattr(sas.models, name, None), name, None)
613        if ModelClass is None:
614            raise ValueError("could not find model %r in sas.models"%name)
615        return ModelClass
616
617    # WARNING: ugly hack when handling model!
618    # Sasview models with multiplicity need to be created with the target
619    # multiplicity, so we cannot create the target model ahead of time for
620    # for multiplicity models.  Instead we store the model in a list and
621    # update the first element of that list with the new multiplicity model
622    # every time we evaluate.
623
624    # grab the sasview model, or create it if it is a product model
625    if model_info.composition:
626        composition_type, parts = model_info.composition
627        if composition_type == 'product':
628            P, S = [get_model_class(revert_name(p))() for p in parts]
629            model = [MultiplicationModel(P, S)]
630        else:
631            raise ValueError("sasview mixture models not supported by compare")
632    else:
633        old_name = revert_name(model_info)
634        if old_name is None:
635            raise ValueError("model %r does not exist in old sasview"
636                            % model_info.id)
637        ModelClass = get_model_class(old_name)
638        model = [ModelClass()]
639    model[0].disperser_handles = {}
640
641    # build a smearer with which to call the model, if necessary
642    smearer = smear_selection(data, model=model)
643    if hasattr(data, 'qx_data'):
644        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
645        index = ((~data.mask) & (~np.isnan(data.data))
646                 & (q >= data.qmin) & (q <= data.qmax))
647        if smearer is not None:
648            smearer.model = model  # because smear_selection has a bug
649            smearer.accuracy = data.accuracy
650            smearer.set_index(index)
651            def _call_smearer():
652                smearer.model = model[0]
653                return smearer.get_value()
654            theory = _call_smearer
655        else:
656            theory = lambda: model[0].evalDistribution([data.qx_data[index],
657                                                        data.qy_data[index]])
658    elif smearer is not None:
659        theory = lambda: smearer(model[0].evalDistribution(data.x))
660    else:
661        theory = lambda: model[0].evalDistribution(data.x)
662
663    def calculator(**pars):
664        # type: (float, ...) -> np.ndarray
665        """
666        Sasview calculator for model.
667        """
668        oldpars = revert_pars(model_info, pars)
669        # For multiplicity models, create a model with the correct multiplicity
670        control = oldpars.pop("CONTROL", None)
671        if control is not None:
672            # sphericalSLD has one fewer multiplicity.  This update should
673            # happen in revert_pars, but it hasn't been called yet.
674            model[0] = ModelClass(control)
675        # paying for parameter conversion each time to keep life simple, if not fast
676        for k, v in oldpars.items():
677            if k.endswith('.type'):
678                par = k[:-5]
679                if v == 'gaussian': continue
680                cls = dispersers[v if v != 'rectangle' else 'rectangula']
681                handle = cls()
682                model[0].disperser_handles[par] = handle
683                try:
684                    model[0].set_dispersion(par, handle)
685                except Exception:
686                    exception.annotate_exception("while setting %s to %r"
687                                                 %(par, v))
688                    raise
689
690
691        #print("sasview pars",oldpars)
692        for k, v in oldpars.items():
693            name_attr = k.split('.')  # polydispersity components
694            if len(name_attr) == 2:
695                par, disp_par = name_attr
696                model[0].dispersion[par][disp_par] = v
697            else:
698                model[0].setParam(k, v)
699        return theory()
700
701    calculator.engine = "sasview"
702    return calculator
703
704DTYPE_MAP = {
705    'half': '16',
706    'fast': 'fast',
707    'single': '32',
708    'double': '64',
709    'quad': '128',
710    'f16': '16',
711    'f32': '32',
712    'f64': '64',
713    'float16': '16',
714    'float32': '32',
715    'float64': '64',
716    'float128': '128',
717    'longdouble': '128',
718}
719def eval_opencl(model_info, data, dtype='single', cutoff=0.):
720    # type: (ModelInfo, Data, str, float) -> Calculator
721    """
722    Return a model calculator using the OpenCL calculation engine.
723    """
724    if not core.HAVE_OPENCL:
725        raise RuntimeError("OpenCL not available")
726    model = core.build_model(model_info, dtype=dtype, platform="ocl")
727    calculator = DirectModel(data, model, cutoff=cutoff)
728    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
729    return calculator
730
731def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
732    # type: (ModelInfo, Data, str, float) -> Calculator
733    """
734    Return a model calculator using the DLL calculation engine.
735    """
736    model = core.build_model(model_info, dtype=dtype, platform="dll")
737    calculator = DirectModel(data, model, cutoff=cutoff)
738    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
739    return calculator
740
741def time_calculation(calculator, pars, evals=1):
742    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
743    """
744    Compute the average calculation time over N evaluations.
745
746    An additional call is generated without polydispersity in order to
747    initialize the calculation engine, and make the average more stable.
748    """
749    # initialize the code so time is more accurate
750    if evals > 1:
751        calculator(**suppress_pd(pars))
752    toc = tic()
753    # make sure there is at least one eval
754    value = calculator(**pars)
755    for _ in range(evals-1):
756        value = calculator(**pars)
757    average_time = toc()*1000. / evals
758    #print("I(q)",value)
759    return value, average_time
760
761def make_data(opts):
762    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
763    """
764    Generate an empty dataset, used with the model to set Q points
765    and resolution.
766
767    *opts* contains the options, with 'qmax', 'nq', 'res',
768    'accuracy', 'is2d' and 'view' parsed from the command line.
769    """
770    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
771    if opts['is2d']:
772        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
773        data = empty_data2D(q, resolution=res)
774        data.accuracy = opts['accuracy']
775        set_beam_stop(data, 0.0004)
776        index = ~data.mask
777    else:
778        if opts['view'] == 'log' and not opts['zero']:
779            qmax = math.log10(qmax)
780            q = np.logspace(qmax-3, qmax, nq)
781        else:
782            q = np.linspace(0.001*qmax, qmax, nq)
783        if opts['zero']:
784            q = np.hstack((0, q))
785        data = empty_data1D(q, resolution=res)
786        index = slice(None, None)
787    return data, index
788
789def make_engine(model_info, data, dtype, cutoff):
790    # type: (ModelInfo, Data, str, float) -> Calculator
791    """
792    Generate the appropriate calculation engine for the given datatype.
793
794    Datatypes with '!' appended are evaluated using external C DLLs rather
795    than OpenCL.
796    """
797    if dtype == 'sasview':
798        return eval_sasview(model_info, data)
799    elif dtype is None or not dtype.endswith('!'):
800        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
801    else:
802        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
803
804def _show_invalid(data, theory):
805    # type: (Data, np.ma.ndarray) -> None
806    """
807    Display a list of the non-finite values in theory.
808    """
809    if not theory.mask.any():
810        return
811
812    if hasattr(data, 'x'):
813        bad = zip(data.x[theory.mask], theory[theory.mask])
814        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
815
816
817def compare(opts, limits=None):
818    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
819    """
820    Preform a comparison using options from the command line.
821
822    *limits* are the limits on the values to use, either to set the y-axis
823    for 1D or to set the colormap scale for 2D.  If None, then they are
824    inferred from the data and returned. When exploring using Bumps,
825    the limits are set when the model is initially called, and maintained
826    as the values are adjusted, making it easier to see the effects of the
827    parameters.
828    """
829    limits = np.Inf, -np.Inf
830    for k in range(opts['sets']):
831        opts['pars'] = parse_pars(opts)
832        if opts['pars'] is None:
833            return
834        result = run_models(opts, verbose=True)
835        if opts['plot']:
836            limits = plot_models(opts, result, limits=limits, setnum=k)
837    if opts['plot']:
838        import matplotlib.pyplot as plt
839        plt.show()
840
841def run_models(opts, verbose=False):
842    # type: (Dict[str, Any]) -> Dict[str, Any]
843
844    base, comp = opts['engines']
845    base_n, comp_n = opts['count']
846    base_pars, comp_pars = opts['pars']
847    data = opts['data']
848
849    comparison = comp is not None
850
851    base_time = comp_time = None
852    base_value = comp_value = resid = relerr = None
853
854    # Base calculation
855    try:
856        base_raw, base_time = time_calculation(base, base_pars, base_n)
857        base_value = np.ma.masked_invalid(base_raw)
858        if verbose:
859            print("%s t=%.2f ms, intensity=%.0f"
860                  % (base.engine, base_time, base_value.sum()))
861        _show_invalid(data, base_value)
862    except ImportError:
863        traceback.print_exc()
864
865    # Comparison calculation
866    if comparison:
867        try:
868            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
869            comp_value = np.ma.masked_invalid(comp_raw)
870            if verbose:
871                print("%s t=%.2f ms, intensity=%.0f"
872                      % (comp.engine, comp_time, comp_value.sum()))
873            _show_invalid(data, comp_value)
874        except ImportError:
875            traceback.print_exc()
876
877    # Compare, but only if computing both forms
878    if comparison:
879        resid = (base_value - comp_value)
880        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
881        if verbose:
882            _print_stats("|%s-%s|"
883                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
884                         resid)
885            _print_stats("|(%s-%s)/%s|"
886                         % (base.engine, comp.engine, comp.engine),
887                         relerr)
888
889    return dict(base_value=base_value, comp_value=comp_value,
890                base_time=base_time, comp_time=comp_time,
891                resid=resid, relerr=relerr)
892
893
894def _print_stats(label, err):
895    # type: (str, np.ma.ndarray) -> None
896    # work with trimmed data, not the full set
897    sorted_err = np.sort(abs(err.compressed()))
898    if len(sorted_err) == 0.:
899        print(label + "  no valid values")
900        return
901
902    p50 = int((len(sorted_err)-1)*0.50)
903    p98 = int((len(sorted_err)-1)*0.98)
904    data = [
905        "max:%.3e"%sorted_err[-1],
906        "median:%.3e"%sorted_err[p50],
907        "98%%:%.3e"%sorted_err[p98],
908        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
909        "zero-offset:%+.3e"%np.mean(sorted_err),
910        ]
911    print(label+"  "+"  ".join(data))
912
913
914def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
915    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
916    base_value, comp_value = result['base_value'], result['comp_value']
917    base_time, comp_time = result['base_time'], result['comp_time']
918    resid, relerr = result['resid'], result['relerr']
919
920    have_base, have_comp = (base_value is not None), (comp_value is not None)
921    base, comp = opts['engines']
922    data = opts['data']
923    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
924
925    # Plot if requested
926    view = opts['view']
927    import matplotlib.pyplot as plt
928    vmin, vmax = limits
929    if have_base:
930        vmin = min(vmin, base_value.min())
931        vmax = max(vmax, base_value.max())
932    if have_comp:
933        vmin = min(vmin, comp_value.min())
934        vmax = max(vmax, comp_value.max())
935    limits = vmin, vmax
936
937    if have_base:
938        if have_comp:
939            plt.subplot(131)
940        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
941        plt.title("%s t=%.2f ms"%(base.engine, base_time))
942        #cbar_title = "log I"
943    if have_comp:
944        if have_base:
945            plt.subplot(132)
946        if not opts['is2d'] and have_base:
947            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
948        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
949        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
950        #cbar_title = "log I"
951    if have_base and have_comp:
952        plt.subplot(133)
953        if not opts['rel_err']:
954            err, errstr, errview = resid, "abs err", "linear"
955        else:
956            err, errstr, errview = abs(relerr), "rel err", "log"
957        if 0:  # 95% cutoff
958            sorted = np.sort(err.flatten())
959            cutoff = sorted[int(sorted.size*0.95)]
960            err[err > cutoff] = cutoff
961        #err,errstr = base/comp,"ratio"
962        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
963        if view == 'linear':
964            plt.xscale('linear')
965        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
966        #cbar_title = errstr if errview=="linear" else "log "+errstr
967    #if is2D:
968    #    h = plt.colorbar()
969    #    h.ax.set_title(cbar_title)
970    fig = plt.gcf()
971    extra_title = ' '+opts['title'] if opts['title'] else ''
972    fig.suptitle(":".join(opts['name']) + extra_title)
973
974    if have_base and have_comp and opts['show_hist']:
975        plt.figure()
976        v = relerr
977        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
978        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
979        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
980                   % (base.engine, comp.engine, comp.engine))
981        plt.ylabel('P(err)')
982        plt.title('Distribution of relative error between calculation engines')
983
984    return limits
985
986
987# ===========================================================================
988#
989
990# Set of command line options.
991# Normal options such as -plot/-noplot are specified as 'name'.
992# For options such as -nq=500 which require a value use 'name='.
993#
994OPTIONS = [
995    # Plotting
996    'plot', 'noplot',
997    'linear', 'log', 'q4',
998    'rel', 'abs',
999    'hist', 'nohist',
1000    'title=',
1001
1002    # Data generation
1003    'data=', 'noise=', 'res=',
1004    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
1005    '2d', '1d',
1006
1007    # Parameter set
1008    'preset', 'random', 'random=', 'sets=',
1009    'demo', 'default',  # TODO: remove demo/default
1010    'nopars', 'pars',
1011
1012    # Calculation options
1013    'poly', 'mono', 'cutoff=',
1014    'magnetic', 'nonmagnetic',
1015    'accuracy=',
1016    'neval=',  # for timing...
1017
1018    # Precision options
1019    'calc=',
1020    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1021    'sasview',  # TODO: remove sasview 3.x support
1022
1023    # Output options
1024    'help', 'html', 'edit',
1025    ]
1026
1027NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1028VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1029
1030
1031def columnize(items, indent="", width=79):
1032    # type: (List[str], str, int) -> str
1033    """
1034    Format a list of strings into columns.
1035
1036    Returns a string with carriage returns ready for printing.
1037    """
1038    column_width = max(len(w) for w in items) + 1
1039    num_columns = (width - len(indent)) // column_width
1040    num_rows = len(items) // num_columns
1041    items = items + [""] * (num_rows * num_columns - len(items))
1042    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1043    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1044             for row in zip(*columns)]
1045    output = indent + ("\n"+indent).join(lines)
1046    return output
1047
1048
1049def get_pars(model_info, use_demo=False):
1050    # type: (ModelInfo, bool) -> ParameterSet
1051    """
1052    Extract demo parameters from the model definition.
1053    """
1054    # Get the default values for the parameters
1055    pars = {}
1056    for p in model_info.parameters.call_parameters:
1057        parts = [('', p.default)]
1058        if p.polydisperse:
1059            parts.append(('_pd', 0.0))
1060            parts.append(('_pd_n', 0))
1061            parts.append(('_pd_nsigma', 3.0))
1062            parts.append(('_pd_type', "gaussian"))
1063        for ext, val in parts:
1064            if p.length > 1:
1065                dict(("%s%d%s" % (p.id, k, ext), val)
1066                     for k in range(1, p.length+1))
1067            else:
1068                pars[p.id + ext] = val
1069
1070    # Plug in values given in demo
1071    if use_demo and model_info.demo:
1072        pars.update(model_info.demo)
1073    return pars
1074
1075INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1076def isnumber(str):
1077    match = FLOAT_RE.match(str)
1078    isfloat = (match and not str[match.end():])
1079    return isfloat or INTEGER_RE.match(str)
1080
1081# For distinguishing pairs of models for comparison
1082# key-value pair separator =
1083# shell characters  | & ; <> $ % ' " \ # `
1084# model and parameter names _
1085# parameter expressions - + * / . ( )
1086# path characters including tilde expansion and windows drive ~ / :
1087# not sure about brackets [] {}
1088# maybe one of the following @ ? ^ ! ,
1089PAR_SPLIT = ','
1090def parse_opts(argv):
1091    # type: (List[str]) -> Dict[str, Any]
1092    """
1093    Parse command line options.
1094    """
1095    MODELS = core.list_models()
1096    flags = [arg for arg in argv
1097             if arg.startswith('-')]
1098    values = [arg for arg in argv
1099              if not arg.startswith('-') and '=' in arg]
1100    positional_args = [arg for arg in argv
1101                       if not arg.startswith('-') and '=' not in arg]
1102    models = "\n    ".join("%-15s"%v for v in MODELS)
1103    if len(positional_args) == 0:
1104        print(USAGE)
1105        print("\nAvailable models:")
1106        print(columnize(MODELS, indent="  "))
1107        return None
1108
1109    invalid = [o[1:] for o in flags
1110               if o[1:] not in NAME_OPTIONS
1111               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1112    if invalid:
1113        print("Invalid options: %s"%(", ".join(invalid)))
1114        return None
1115
1116    name = positional_args[-1]
1117
1118    # pylint: disable=bad-whitespace
1119    # Interpret the flags
1120    opts = {
1121        'plot'      : True,
1122        'view'      : 'log',
1123        'is2d'      : False,
1124        'qmax'      : 0.05,
1125        'nq'        : 128,
1126        'res'       : 0.0,
1127        'noise'     : 0.0,
1128        'accuracy'  : 'Low',
1129        'cutoff'    : '0.0',
1130        'seed'      : -1,  # default to preset
1131        'mono'      : True,
1132        # Default to magnetic a magnetic moment is set on the command line
1133        'magnetic'  : False,
1134        'show_pars' : False,
1135        'show_hist' : False,
1136        'rel_err'   : True,
1137        'explore'   : False,
1138        'use_demo'  : True,
1139        'zero'      : False,
1140        'html'      : False,
1141        'title'     : None,
1142        'datafile'  : None,
1143        'sets'      : 0,
1144        'engine'    : 'default',
1145        'evals'     : '1',
1146    }
1147    for arg in flags:
1148        if arg == '-noplot':    opts['plot'] = False
1149        elif arg == '-plot':    opts['plot'] = True
1150        elif arg == '-linear':  opts['view'] = 'linear'
1151        elif arg == '-log':     opts['view'] = 'log'
1152        elif arg == '-q4':      opts['view'] = 'q4'
1153        elif arg == '-1d':      opts['is2d'] = False
1154        elif arg == '-2d':      opts['is2d'] = True
1155        elif arg == '-exq':     opts['qmax'] = 10.0
1156        elif arg == '-highq':   opts['qmax'] = 1.0
1157        elif arg == '-midq':    opts['qmax'] = 0.2
1158        elif arg == '-lowq':    opts['qmax'] = 0.05
1159        elif arg == '-zero':    opts['zero'] = True
1160        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1161        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1162        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1163        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1164        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1165        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1166        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1167        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1168        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1169        elif arg.startswith('-calc='):     opts['engine'] = arg[6:]
1170        elif arg.startswith('-neval='):    opts['evals'] = arg[7:]
1171        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1172        elif arg == '-preset':  opts['seed'] = -1
1173        elif arg == '-mono':    opts['mono'] = True
1174        elif arg == '-poly':    opts['mono'] = False
1175        elif arg == '-magnetic':       opts['magnetic'] = True
1176        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1177        elif arg == '-pars':    opts['show_pars'] = True
1178        elif arg == '-nopars':  opts['show_pars'] = False
1179        elif arg == '-hist':    opts['show_hist'] = True
1180        elif arg == '-nohist':  opts['show_hist'] = False
1181        elif arg == '-rel':     opts['rel_err'] = True
1182        elif arg == '-abs':     opts['rel_err'] = False
1183        elif arg == '-half':    opts['engine'] = 'half'
1184        elif arg == '-fast':    opts['engine'] = 'fast'
1185        elif arg == '-single':  opts['engine'] = 'single'
1186        elif arg == '-double':  opts['engine'] = 'double'
1187        elif arg == '-single!': opts['engine'] = 'single!'
1188        elif arg == '-double!': opts['engine'] = 'double!'
1189        elif arg == '-quad!':   opts['engine'] = 'quad!'
1190        elif arg == '-sasview': opts['engine'] = 'sasview'
1191        elif arg == '-edit':    opts['explore'] = True
1192        elif arg == '-demo':    opts['use_demo'] = True
1193        elif arg == '-default': opts['use_demo'] = False
1194        elif arg == '-html':    opts['html'] = True
1195        elif arg == '-help':    opts['html'] = True
1196    # pylint: enable=bad-whitespace
1197
1198    # Magnetism forces 2D for now
1199    if opts['magnetic']:
1200        opts['is2d'] = True
1201
1202    # Force random if sets is used
1203    if opts['sets'] >= 1 and opts['seed'] < 0:
1204        opts['seed'] = np.random.randint(1000000)
1205    if opts['sets'] == 0:
1206        opts['sets'] = 1
1207
1208    # Create the computational engines
1209    if opts['datafile'] is not None:
1210        data = load_data(os.path.expanduser(opts['datafile']))
1211    else:
1212        data, _ = make_data(opts)
1213
1214    comparison = any(PAR_SPLIT in v for v in values)
1215    if PAR_SPLIT in name:
1216        names = name.split(PAR_SPLIT, 2)
1217        comparison = True
1218    else:
1219        names = [name]*2
1220    try:
1221        model_info = [core.load_model_info(k) for k in names]
1222    except ImportError as exc:
1223        print(str(exc))
1224        print("Could not find model; use one of:\n    " + models)
1225        return None
1226
1227    if PAR_SPLIT in opts['engine']:
1228        engine_types = opts['engine'].split(PAR_SPLIT, 2)
1229        comparison = True
1230    else:
1231        engine_types = [opts['engine']]*2
1232
1233    if PAR_SPLIT in opts['evals']:
1234        evals = [int(k) for k in opts['evals'].split(PAR_SPLIT, 2)]
1235        comparison = True
1236    else:
1237        evals = [int(opts['evals'])]*2
1238
1239    if PAR_SPLIT in opts['cutoff']:
1240        cutoff = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1241        comparison = True
1242    else:
1243        cutoff = [float(opts['cutoff'])]*2
1244
1245    base = make_engine(model_info[0], data, engine_types[0], cutoff[0])
1246    if comparison:
1247        comp = make_engine(model_info[1], data, engine_types[1], cutoff[1])
1248    else:
1249        comp = None
1250
1251    # pylint: disable=bad-whitespace
1252    # Remember it all
1253    opts.update({
1254        'data'      : data,
1255        'name'      : names,
1256        'def'       : model_info,
1257        'count'     : evals,
1258        'engines'   : [base, comp],
1259        'values'    : values,
1260    })
1261    # pylint: enable=bad-whitespace
1262
1263    return opts
1264
1265def parse_pars(opts):
1266    model_info, model_info2 = opts['def']
1267
1268    # Get demo parameters from model definition, or use default parameters
1269    # if model does not define demo parameters
1270    pars = get_pars(model_info, opts['use_demo'])
1271    pars2 = get_pars(model_info2, opts['use_demo'])
1272    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1273    # randomize parameters
1274    #pars.update(set_pars)  # set value before random to control range
1275    if opts['seed'] > -1:
1276        pars = randomize_pars(model_info, pars)
1277        if model_info != model_info2:
1278            pars2 = randomize_pars(model_info2, pars2)
1279            # Share values for parameters with the same name
1280            for k, v in pars.items():
1281                if k in pars2:
1282                    pars2[k] = v
1283        else:
1284            pars2 = pars.copy()
1285        constrain_pars(model_info, pars)
1286        constrain_pars(model_info2, pars2)
1287    pars = suppress_pd(pars, opts['mono'])
1288    pars2 = suppress_pd(pars2, opts['mono'])
1289    pars = suppress_magnetism(pars, not opts['magnetic'])
1290    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1291
1292    # Fill in parameters given on the command line
1293    presets = {}
1294    presets2 = {}
1295    for arg in opts['values']:
1296        k, v = arg.split('=', 1)
1297        if k not in pars and k not in pars2:
1298            # extract base name without polydispersity info
1299            s = set(p.split('_pd')[0] for p in pars)
1300            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1301            return None
1302        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1303        if v1 and k in pars:
1304            presets[k] = float(v1) if isnumber(v1) else v1
1305        if v2 and k in pars2:
1306            presets2[k] = float(v2) if isnumber(v2) else v2
1307
1308    # If pd given on the command line, default pd_n to 35
1309    for k, v in list(presets.items()):
1310        if k.endswith('_pd'):
1311            presets.setdefault(k+'_n', 35.)
1312    for k, v in list(presets2.items()):
1313        if k.endswith('_pd'):
1314            presets2.setdefault(k+'_n', 35.)
1315
1316    # Evaluate preset parameter expressions
1317    context = MATH.copy()
1318    context['np'] = np
1319    context.update(pars)
1320    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1321    for k, v in presets.items():
1322        if not isinstance(v, float) and not k.endswith('_type'):
1323            presets[k] = eval(v, context)
1324    context.update(presets)
1325    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1326    for k, v in presets2.items():
1327        if not isinstance(v, float) and not k.endswith('_type'):
1328            presets2[k] = eval(v, context)
1329
1330    # update parameters with presets
1331    pars.update(presets)  # set value after random to control value
1332    pars2.update(presets2)  # set value after random to control value
1333    #import pprint; pprint.pprint(model_info)
1334
1335    if opts['show_pars']:
1336        if model_info.name != model_info2.name or pars != pars2:
1337            print("==== %s ====="%model_info.name)
1338            print(str(parlist(model_info, pars, opts['is2d'])))
1339            print("==== %s ====="%model_info2.name)
1340            print(str(parlist(model_info2, pars2, opts['is2d'])))
1341        else:
1342            print(str(parlist(model_info, pars, opts['is2d'])))
1343
1344    return pars, pars2
1345
1346def show_docs(opts):
1347    # type: (Dict[str, Any]) -> None
1348    """
1349    show html docs for the model
1350    """
1351    import os
1352    from .generate import make_html
1353    from . import rst2html
1354
1355    info = opts['def'][0]
1356    html = make_html(info)
1357    path = os.path.dirname(info.filename)
1358    url = "file://"+path.replace("\\","/")[2:]+"/"
1359    rst2html.view_html_qtapp(html, url)
1360
1361def explore(opts):
1362    # type: (Dict[str, Any]) -> None
1363    """
1364    explore the model using the bumps gui.
1365    """
1366    import wx  # type: ignore
1367    from bumps.names import FitProblem  # type: ignore
1368    from bumps.gui.app_frame import AppFrame  # type: ignore
1369    from bumps.gui import signal
1370
1371    is_mac = "cocoa" in wx.version()
1372    # Create an app if not running embedded
1373    app = wx.App() if wx.GetApp() is None else None
1374    model = Explore(opts)
1375    problem = FitProblem(model)
1376    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1377    if not is_mac:
1378        frame.Show()
1379    frame.panel.set_model(model=problem)
1380    frame.panel.Layout()
1381    frame.panel.aui.Split(0, wx.TOP)
1382    def reset_parameters(event):
1383        model.revert_values()
1384        signal.update_parameters(problem)
1385    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1386    if is_mac: frame.Show()
1387    # If running withing an app, start the main loop
1388    if app:
1389        app.MainLoop()
1390
1391class Explore(object):
1392    """
1393    Bumps wrapper for a SAS model comparison.
1394
1395    The resulting object can be used as a Bumps fit problem so that
1396    parameters can be adjusted in the GUI, with plots updated on the fly.
1397    """
1398    def __init__(self, opts):
1399        # type: (Dict[str, Any]) -> None
1400        from bumps.cli import config_matplotlib  # type: ignore
1401        from . import bumps_model
1402        config_matplotlib()
1403        self.opts = opts
1404        opts['pars'] = list(opts['pars'])
1405        p1, p2 = opts['pars']
1406        m1, m2 = opts['def']
1407        self.fix_p2 = m1 != m2 or p1 != p2
1408        model_info = m1
1409        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1410        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1411        if not opts['is2d']:
1412            for p in model_info.parameters.user_parameters({}, is2d=False):
1413                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1414                    k = p.name+ext
1415                    v = pars.get(k, None)
1416                    if v is not None:
1417                        v.range(*parameter_range(k, v.value))
1418        else:
1419            for k, v in pars.items():
1420                v.range(*parameter_range(k, v.value))
1421
1422        self.pars = pars
1423        self.starting_values = dict((k, v.value) for k, v in pars.items())
1424        self.pd_types = pd_types
1425        self.limits = np.Inf, -np.Inf
1426
1427    def revert_values(self):
1428        for k, v in self.starting_values.items():
1429            self.pars[k].value = v
1430
1431    def model_update(self):
1432        pass
1433
1434    def numpoints(self):
1435        # type: () -> int
1436        """
1437        Return the number of points.
1438        """
1439        return len(self.pars) + 1  # so dof is 1
1440
1441    def parameters(self):
1442        # type: () -> Any   # Dict/List hierarchy of parameters
1443        """
1444        Return a dictionary of parameters.
1445        """
1446        return self.pars
1447
1448    def nllf(self):
1449        # type: () -> float
1450        """
1451        Return cost.
1452        """
1453        # pylint: disable=no-self-use
1454        return 0.  # No nllf
1455
1456    def plot(self, view='log'):
1457        # type: (str) -> None
1458        """
1459        Plot the data and residuals.
1460        """
1461        pars = dict((k, v.value) for k, v in self.pars.items())
1462        pars.update(self.pd_types)
1463        self.opts['pars'][0] = pars
1464        if not self.fix_p2:
1465            self.opts['pars'][1] = pars
1466        result = run_models(self.opts)
1467        limits = plot_models(self.opts, result, limits=self.limits)
1468        if self.limits is None:
1469            vmin, vmax = limits
1470            self.limits = vmax*1e-7, 1.3*vmax
1471            import pylab; pylab.clf()
1472            plot_models(self.opts, result, limits=self.limits)
1473
1474
1475def main(*argv):
1476    # type: (*str) -> None
1477    """
1478    Main program.
1479    """
1480    opts = parse_opts(argv)
1481    if opts is not None:
1482        if opts['seed'] > -1:
1483            print("Randomize using -random=%i"%opts['seed'])
1484            np.random.seed(opts['seed'])
1485        if opts['html']:
1486            show_docs(opts)
1487        elif opts['explore']:
1488            opts['pars'] = parse_pars(opts)
1489            if opts['pars'] is None:
1490                return
1491            explore(opts)
1492        else:
1493            compare(opts)
1494
1495if __name__ == "__main__":
1496    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.