source: sasmodels/sasmodels/compare.py @ ce8c388

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ce8c388 was d9ec8f9, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

elliptical cylinder axis ratio is 1 or greater

  • Property mode set to 100755
File size: 52.2 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: sascomp model [options...] [key=val]
59
60Generate and compare SAS models.  If a single model is specified it shows
61a plot of that model.  Different models can be compared, or the same model
62with different parameters.  The same model with the same parameters can
63be compared with different calculation engines to see the effects of precision
64on the resultant values.
65
66model or model1,model2 are the names of the models to compare (see below).
67
68Options (* for default):
69
70    === data generation ===
71    -data="path" uses q, dq from the data file
72    -noise=0 sets the measurement error dI/I
73    -res=0 sets the resolution width dQ/Q if calculating with resolution
74    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
75    -nq=128 sets the number of Q points in the data set
76    -1d*/-2d computes 1d or 2d data
77    -zero indicates that q=0 should be included
78
79    === model parameters ===
80    -preset*/-random[=seed] preset or random parameters
81    -sets=n generates n random datasets with the seed given by -random=seed
82    -pars/-nopars* prints the parameter set or not
83    -default/-demo* use demo vs default parameters
84
85    === calculation options ===
86    -mono*/-poly force monodisperse or allow polydisperse demo parameters
87    -cutoff=1e-5* cutoff value for including a point in polydispersity
88    -magnetic/-nonmagnetic* suppress magnetism
89    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
90
91    === precision options ===
92    -calc=default uses the default calcution precision
93    -single/-double/-half/-fast sets an OpenCL calculation engine
94    -single!/-double!/-quad! sets an OpenMP calculation engine
95    -sasview sets the sasview calculation engine
96
97    === plotting ===
98    -plot*/-noplot plots or suppress the plot of the model
99    -linear/-log*/-q4 intensity scaling on plots
100    -hist/-nohist* plot histogram of relative error
101    -abs/-rel* plot relative or absolute error
102    -title="note" adds note to the plot title, after the model name
103
104    === output options ===
105    -edit starts the parameter explorer
106    -help/-html shows the model docs instead of running the model
107
108The interpretation of quad precision depends on architecture, and may
109vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
110On unix and mac you may need single quotes around the DLL computation
111engines, such as -calc='single!,double!' since !, is treated as a history
112expansion request in the shell.
113
114Key=value pairs allow you to set specific values for the model parameters.
115Key=value1,value2 to compare different values of the same parameter. The
116value can be an expression including other parameters.
117
118Items later on the command line override those that appear earlier.
119
120Examples:
121
122    # compare single and double precision calculation for a barbell
123    sascomp barbell -calc=single,double
124
125    # generate 10 random lorentz models, with seed=27
126    sascomp lorentz -sets=10 -seed=27
127
128    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
129    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
130
131    # model timing test requires multiple evals to perform the estimate
132    sascomp pringle -calc=single,double -timing=100,100 -noplot
133"""
134
135# Update docs with command line usage string.   This is separate from the usual
136# doc string so that we can display it at run time if there is an error.
137# lin
138__doc__ = (__doc__  # pylint: disable=redefined-builtin
139           + """
140Program description
141-------------------
142
143""" + USAGE)
144
145kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
146
147# list of math functions for use in evaluating parameters
148MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
149
150# CRUFT python 2.6
151if not hasattr(datetime.timedelta, 'total_seconds'):
152    def delay(dt):
153        """Return number date-time delta as number seconds"""
154        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
155else:
156    def delay(dt):
157        """Return number date-time delta as number seconds"""
158        return dt.total_seconds()
159
160
161class push_seed(object):
162    """
163    Set the seed value for the random number generator.
164
165    When used in a with statement, the random number generator state is
166    restored after the with statement is complete.
167
168    :Parameters:
169
170    *seed* : int or array_like, optional
171        Seed for RandomState
172
173    :Example:
174
175    Seed can be used directly to set the seed::
176
177        >>> from numpy.random import randint
178        >>> push_seed(24)
179        <...push_seed object at...>
180        >>> print(randint(0,1000000,3))
181        [242082    899 211136]
182
183    Seed can also be used in a with statement, which sets the random
184    number generator state for the enclosed computations and restores
185    it to the previous state on completion::
186
187        >>> with push_seed(24):
188        ...    print(randint(0,1000000,3))
189        [242082    899 211136]
190
191    Using nested contexts, we can demonstrate that state is indeed
192    restored after the block completes::
193
194        >>> with push_seed(24):
195        ...    print(randint(0,1000000))
196        ...    with push_seed(24):
197        ...        print(randint(0,1000000,3))
198        ...    print(randint(0,1000000))
199        242082
200        [242082    899 211136]
201        899
202
203    The restore step is protected against exceptions in the block::
204
205        >>> with push_seed(24):
206        ...    print(randint(0,1000000))
207        ...    try:
208        ...        with push_seed(24):
209        ...            print(randint(0,1000000,3))
210        ...            raise Exception()
211        ...    except Exception:
212        ...        print("Exception raised")
213        ...    print(randint(0,1000000))
214        242082
215        [242082    899 211136]
216        Exception raised
217        899
218    """
219    def __init__(self, seed=None):
220        # type: (Optional[int]) -> None
221        self._state = np.random.get_state()
222        np.random.seed(seed)
223
224    def __enter__(self):
225        # type: () -> None
226        pass
227
228    def __exit__(self, exc_type, exc_value, traceback):
229        # type: (Any, BaseException, Any) -> None
230        # TODO: better typing for __exit__ method
231        np.random.set_state(self._state)
232
233def tic():
234    # type: () -> Callable[[], float]
235    """
236    Timer function.
237
238    Use "toc=tic()" to start the clock and "toc()" to measure
239    a time interval.
240    """
241    then = datetime.datetime.now()
242    return lambda: delay(datetime.datetime.now() - then)
243
244
245def set_beam_stop(data, radius, outer=None):
246    # type: (Data, float, float) -> None
247    """
248    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
249
250    Note: this function does not require sasview
251    """
252    if hasattr(data, 'qx_data'):
253        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
254        data.mask = (q < radius)
255        if outer is not None:
256            data.mask |= (q >= outer)
257    else:
258        data.mask = (data.x < radius)
259        if outer is not None:
260            data.mask |= (data.x >= outer)
261
262
263def parameter_range(p, v):
264    # type: (str, float) -> Tuple[float, float]
265    """
266    Choose a parameter range based on parameter name and initial value.
267    """
268    # process the polydispersity options
269    if p.endswith('_pd_n'):
270        return 0., 100.
271    elif p.endswith('_pd_nsigma'):
272        return 0., 5.
273    elif p.endswith('_pd_type'):
274        raise ValueError("Cannot return a range for a string value")
275    elif any(s in p for s in ('theta', 'phi', 'psi')):
276        # orientation in [-180,180], orientation pd in [0,45]
277        if p.endswith('_pd'):
278            return 0., 45.
279        else:
280            return -180., 180.
281    elif p.endswith('_pd'):
282        return 0., 1.
283    elif 'sld' in p:
284        return -0.5, 10.
285    elif p == 'background':
286        return 0., 10.
287    elif p == 'scale':
288        return 0., 1.e3
289    elif v < 0.:
290        return 2.*v, -2.*v
291    else:
292        return 0., (2.*v if v > 0. else 1.)
293
294
295def _randomize_one(model_info, name, value):
296    # type: (ModelInfo, str, float) -> float
297    # type: (ModelInfo, str, str) -> str
298    """
299    Randomize a single parameter.
300    """
301    # Set the amount of polydispersity/angular dispersion, but by default pd_n
302    # is zero so there is no polydispersity.  This allows us to turn on/off
303    # pd by setting pd_n, and still have randomly generated values
304    if name.endswith('_pd'):
305        par = model_info.parameters[name[:-3]]
306        if par.type == 'orientation':
307            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
308            return 180*np.random.beta(2.5, 20)
309        else:
310            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
311            return np.random.beta(1.5, 7)
312
313    # pd is selected globally rather than per parameter, so set to 0 for no pd
314    # In particular, when multiple pd dimensions, want to decrease the number
315    # of points per dimension for faster computation
316    if name.endswith('_pd_n'):
317        return 0
318
319    # Don't mess with distribution type for now
320    if name.endswith('_pd_type'):
321        return 'gaussian'
322
323    # type-dependent value of number of sigmas; for gaussian use 3.
324    if name.endswith('_pd_nsigma'):
325        return 3.
326
327    # background in the range [0.01, 1]
328    if name == 'background':
329        return 10**np.random.uniform(-2, 0)
330
331    # scale defaults to 0.1% to 30% volume fraction
332    if name == 'scale':
333        return 10**np.random.uniform(-3, -0.5)
334
335    # If it is a list of choices, pick one at random with equal probability
336    # In practice, the model specific random generator will override.
337    par = model_info.parameters[name]
338    if len(par.limits) > 2:  # choice list
339        return np.random.randint(len(par.limits))
340
341    # If it is a fixed range, pick from it with equal probability.
342    # For logarithmic ranges, the model will have to override.
343    if np.isfinite(par.limits).all():
344        return np.random.uniform(*par.limits)
345
346    # If the paramter is marked as an sld use the range of neutron slds
347    # TODO: ought to randomly contrast match a pair of SLDs
348    if par.type == 'sld':
349        return np.random.uniform(-0.5, 12)
350
351    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
352    if par.name.startswith('M0:'):
353        return np.random.uniform(0, 5)
354
355    # Guess at the random length/radius/thickness.  In practice, all models
356    # are going to set their own reasonable ranges.
357    if par.type == 'volume':
358        if ('length' in par.name or
359                'radius' in par.name or
360                'thick' in par.name):
361            return 10**np.random.uniform(2, 4)
362
363    # In the absence of any other info, select a value in [0, 2v], or
364    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
365    # model random parameter generators will override this default.
366    low, high = parameter_range(par.name, value)
367    limits = (max(par.limits[0], low), min(par.limits[1], high))
368    return np.random.uniform(*limits)
369
370def _random_pd(model_info, pars):
371    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
372    pd_volume = []
373    pd_oriented = []
374    for p in pd:
375        if p.type == 'orientation':
376            pd_oriented.append(p.name)
377        elif p.length_control is not None:
378            n = int(pars.get(p.length_control, 1) + 0.5)
379            pd_volume.extend(p.name+str(k+1) for k in range(n))
380        elif p.length > 1:
381            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
382        else:
383            pd_volume.append(p.name)
384    u = np.random.rand()
385    n = len(pd_volume)
386    if u < 0.01 or n < 1:
387        pass  # 1% chance of no polydispersity
388    elif u < 0.86 or n < 2:
389        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
390    elif u < 0.99 or n < 3:
391        choices = np.random.choice(len(pd_volume), size=2)
392        pars[pd_volume[choices[0]]+"_pd_n"] = 25
393        pars[pd_volume[choices[1]]+"_pd_n"] = 10
394    else:
395        choices = np.random.choice(len(pd_volume), size=3)
396        pars[pd_volume[choices[0]]+"_pd_n"] = 25
397        pars[pd_volume[choices[1]]+"_pd_n"] = 10
398        pars[pd_volume[choices[2]]+"_pd_n"] = 5
399    if pd_oriented:
400        pars['theta_pd_n'] = 20
401        if np.random.rand() < 0.1:
402            pars['phi_pd_n'] = 5
403        if np.random.rand() < 0.1:
404            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
405                #print("generating psi_pd_n")
406                pars['psi_pd_n'] = 5
407
408    ## Show selected polydispersity
409    #for name, value in pars.items():
410    #    if name.endswith('_pd_n') and value > 0:
411    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
412
413
414def randomize_pars(model_info, pars):
415    # type: (ModelInfo, ParameterSet) -> ParameterSet
416    """
417    Generate random values for all of the parameters.
418
419    Valid ranges for the random number generator are guessed from the name of
420    the parameter; this will not account for constraints such as cap radius
421    greater than cylinder radius in the capped_cylinder model, so
422    :func:`constrain_pars` needs to be called afterward..
423    """
424    # Note: the sort guarantees order of calls to random number generator
425    random_pars = dict((p, _randomize_one(model_info, p, v))
426                       for p, v in sorted(pars.items()))
427    if model_info.random is not None:
428        random_pars.update(model_info.random())
429    _random_pd(model_info, random_pars)
430    return random_pars
431
432
433def constrain_pars(model_info, pars):
434    # type: (ModelInfo, ParameterSet) -> None
435    """
436    Restrict parameters to valid values.
437
438    This includes model specific code for models such as capped_cylinder
439    which need to support within model constraints (cap radius more than
440    cylinder radius in this case).
441
442    Warning: this updates the *pars* dictionary in place.
443    """
444    # TODO: move the model specific code to the individual models
445    name = model_info.id
446    # if it is a product model, then just look at the form factor since
447    # none of the structure factors need any constraints.
448    if '*' in name:
449        name = name.split('*')[0]
450
451    # Suppress magnetism for python models (not yet implemented)
452    if callable(model_info.Iq):
453        pars.update(suppress_magnetism(pars))
454
455    if name == 'barbell':
456        if pars['radius_bell'] < pars['radius']:
457            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
458
459    elif name == 'capped_cylinder':
460        if pars['radius_cap'] < pars['radius']:
461            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
462
463    elif name == 'guinier':
464        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
465        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
466        # => ln A - (Rg^2 q^2/3) > -30 ln 10
467        # => Rg^2 q^2/3 < 30 ln 10 + ln A
468        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
469        #q_max = 0.2  # mid q maximum
470        q_max = 1.0  # high q maximum
471        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
472        pars['rg'] = min(pars['rg'], rg_max)
473
474    elif name == 'pearl_necklace':
475        if pars['radius'] < pars['thick_string']:
476            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
477        pass
478
479    elif name == 'rpa':
480        # Make sure phi sums to 1.0
481        if pars['case_num'] < 2:
482            pars['Phi1'] = 0.
483            pars['Phi2'] = 0.
484        elif pars['case_num'] < 5:
485            pars['Phi1'] = 0.
486        total = sum(pars['Phi'+c] for c in '1234')
487        for c in '1234':
488            pars['Phi'+c] /= total
489
490def parlist(model_info, pars, is2d):
491    # type: (ModelInfo, ParameterSet, bool) -> str
492    """
493    Format the parameter list for printing.
494    """
495    lines = []
496    parameters = model_info.parameters
497    magnetic = False
498    magnetic_pars = []
499    for p in parameters.user_parameters(pars, is2d):
500        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
501            continue
502        if p.id.startswith('up:'):
503            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
504            continue
505        fields = dict(
506            value=pars.get(p.id, p.default),
507            pd=pars.get(p.id+"_pd", 0.),
508            n=int(pars.get(p.id+"_pd_n", 0)),
509            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
510            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
511            relative_pd=p.relative_pd,
512            M0=pars.get('M0:'+p.id, 0.),
513            mphi=pars.get('mphi:'+p.id, 0.),
514            mtheta=pars.get('mtheta:'+p.id, 0.),
515        )
516        lines.append(_format_par(p.name, **fields))
517        magnetic = magnetic or fields['M0'] != 0.
518    if magnetic and magnetic_pars:
519        lines.append(" ".join(magnetic_pars))
520    return "\n".join(lines)
521
522    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
523
524def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
525                relative_pd=False, M0=0., mphi=0., mtheta=0.):
526    # type: (str, float, float, int, float, str) -> str
527    line = "%s: %g"%(name, value)
528    if pd != 0.  and n != 0:
529        if relative_pd:
530            pd *= value
531        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
532                % (pd, n, nsigma, nsigma, pdtype)
533    if M0 != 0.:
534        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
535    return line
536
537def suppress_pd(pars, suppress=True):
538    # type: (ParameterSet) -> ParameterSet
539    """
540    If suppress is True complete eliminate polydispersity of the model to test
541    models more quickly.  If suppress is False, make sure at least one
542    parameter is polydisperse, setting the first polydispersity parameter to
543    15% if no polydispersity is given (with no explicit demo parameters given
544    in the model, there will be no default polydispersity).
545    """
546    pars = pars.copy()
547    #print("pars=", pars)
548    if suppress:
549        for p in pars:
550            if p.endswith("_pd_n"):
551                pars[p] = 0
552    else:
553        any_pd = False
554        first_pd = None
555        for p in pars:
556            if p.endswith("_pd_n"):
557                pd = pars.get(p[:-2], 0.)
558                any_pd |= (pars[p] != 0 and pd != 0.)
559                if first_pd is None:
560                    first_pd = p
561        if not any_pd and first_pd is not None:
562            if pars[first_pd] == 0:
563                pars[first_pd] = 35
564            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
565                pars[first_pd[:-2]] = 0.15
566    return pars
567
568def suppress_magnetism(pars, suppress=True):
569    # type: (ParameterSet) -> ParameterSet
570    """
571    If suppress is True complete eliminate magnetism of the model to test
572    models more quickly.  If suppress is False, make sure at least one sld
573    parameter is magnetic, setting the first parameter to have a strong
574    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
575    in the model, there will be no default magnetism).
576    """
577    pars = pars.copy()
578    if suppress:
579        for p in pars:
580            if p.startswith("M0:"):
581                pars[p] = 0
582    else:
583        any_mag = False
584        first_mag = None
585        for p in pars:
586            if p.startswith("M0:"):
587                any_mag |= (pars[p] != 0)
588                if first_mag is None:
589                    first_mag = p
590        if not any_mag and first_mag is not None:
591            pars[first_mag] = 8.
592    return pars
593
594def eval_sasview(model_info, data):
595    # type: (Modelinfo, Data) -> Calculator
596    """
597    Return a model calculator using the pre-4.0 SasView models.
598    """
599    # importing sas here so that the error message will be that sas failed to
600    # import rather than the more obscure smear_selection not imported error
601    import sas
602    import sas.models
603    from sas.models.qsmearing import smear_selection
604    from sas.models.MultiplicationModel import MultiplicationModel
605    from sas.models.dispersion_models import models as dispersers
606
607    def get_model_class(name):
608        # type: (str) -> "sas.models.BaseComponent"
609        #print("new",sorted(_pars.items()))
610        __import__('sas.models.' + name)
611        ModelClass = getattr(getattr(sas.models, name, None), name, None)
612        if ModelClass is None:
613            raise ValueError("could not find model %r in sas.models"%name)
614        return ModelClass
615
616    # WARNING: ugly hack when handling model!
617    # Sasview models with multiplicity need to be created with the target
618    # multiplicity, so we cannot create the target model ahead of time for
619    # for multiplicity models.  Instead we store the model in a list and
620    # update the first element of that list with the new multiplicity model
621    # every time we evaluate.
622
623    # grab the sasview model, or create it if it is a product model
624    if model_info.composition:
625        composition_type, parts = model_info.composition
626        if composition_type == 'product':
627            P, S = [get_model_class(revert_name(p))() for p in parts]
628            model = [MultiplicationModel(P, S)]
629        else:
630            raise ValueError("sasview mixture models not supported by compare")
631    else:
632        old_name = revert_name(model_info)
633        if old_name is None:
634            raise ValueError("model %r does not exist in old sasview"
635                            % model_info.id)
636        ModelClass = get_model_class(old_name)
637        model = [ModelClass()]
638    model[0].disperser_handles = {}
639
640    # build a smearer with which to call the model, if necessary
641    smearer = smear_selection(data, model=model)
642    if hasattr(data, 'qx_data'):
643        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
644        index = ((~data.mask) & (~np.isnan(data.data))
645                 & (q >= data.qmin) & (q <= data.qmax))
646        if smearer is not None:
647            smearer.model = model  # because smear_selection has a bug
648            smearer.accuracy = data.accuracy
649            smearer.set_index(index)
650            def _call_smearer():
651                smearer.model = model[0]
652                return smearer.get_value()
653            theory = _call_smearer
654        else:
655            theory = lambda: model[0].evalDistribution([data.qx_data[index],
656                                                        data.qy_data[index]])
657    elif smearer is not None:
658        theory = lambda: smearer(model[0].evalDistribution(data.x))
659    else:
660        theory = lambda: model[0].evalDistribution(data.x)
661
662    def calculator(**pars):
663        # type: (float, ...) -> np.ndarray
664        """
665        Sasview calculator for model.
666        """
667        oldpars = revert_pars(model_info, pars)
668        # For multiplicity models, create a model with the correct multiplicity
669        control = oldpars.pop("CONTROL", None)
670        if control is not None:
671            # sphericalSLD has one fewer multiplicity.  This update should
672            # happen in revert_pars, but it hasn't been called yet.
673            model[0] = ModelClass(control)
674        # paying for parameter conversion each time to keep life simple, if not fast
675        for k, v in oldpars.items():
676            if k.endswith('.type'):
677                par = k[:-5]
678                if v == 'gaussian': continue
679                cls = dispersers[v if v != 'rectangle' else 'rectangula']
680                handle = cls()
681                model[0].disperser_handles[par] = handle
682                try:
683                    model[0].set_dispersion(par, handle)
684                except Exception:
685                    exception.annotate_exception("while setting %s to %r"
686                                                 %(par, v))
687                    raise
688
689
690        #print("sasview pars",oldpars)
691        for k, v in oldpars.items():
692            name_attr = k.split('.')  # polydispersity components
693            if len(name_attr) == 2:
694                par, disp_par = name_attr
695                model[0].dispersion[par][disp_par] = v
696            else:
697                model[0].setParam(k, v)
698        return theory()
699
700    calculator.engine = "sasview"
701    return calculator
702
703DTYPE_MAP = {
704    'half': '16',
705    'fast': 'fast',
706    'single': '32',
707    'double': '64',
708    'quad': '128',
709    'f16': '16',
710    'f32': '32',
711    'f64': '64',
712    'float16': '16',
713    'float32': '32',
714    'float64': '64',
715    'float128': '128',
716    'longdouble': '128',
717}
718def eval_opencl(model_info, data, dtype='single', cutoff=0.):
719    # type: (ModelInfo, Data, str, float) -> Calculator
720    """
721    Return a model calculator using the OpenCL calculation engine.
722    """
723    if not core.HAVE_OPENCL:
724        raise RuntimeError("OpenCL not available")
725    model = core.build_model(model_info, dtype=dtype, platform="ocl")
726    calculator = DirectModel(data, model, cutoff=cutoff)
727    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
728    return calculator
729
730def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
731    # type: (ModelInfo, Data, str, float) -> Calculator
732    """
733    Return a model calculator using the DLL calculation engine.
734    """
735    model = core.build_model(model_info, dtype=dtype, platform="dll")
736    calculator = DirectModel(data, model, cutoff=cutoff)
737    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
738    return calculator
739
740def time_calculation(calculator, pars, evals=1):
741    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
742    """
743    Compute the average calculation time over N evaluations.
744
745    An additional call is generated without polydispersity in order to
746    initialize the calculation engine, and make the average more stable.
747    """
748    # initialize the code so time is more accurate
749    if evals > 1:
750        calculator(**suppress_pd(pars))
751    toc = tic()
752    # make sure there is at least one eval
753    value = calculator(**pars)
754    for _ in range(evals-1):
755        value = calculator(**pars)
756    average_time = toc()*1000. / evals
757    #print("I(q)",value)
758    return value, average_time
759
760def make_data(opts):
761    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
762    """
763    Generate an empty dataset, used with the model to set Q points
764    and resolution.
765
766    *opts* contains the options, with 'qmax', 'nq', 'res',
767    'accuracy', 'is2d' and 'view' parsed from the command line.
768    """
769    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
770    if opts['is2d']:
771        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
772        data = empty_data2D(q, resolution=res)
773        data.accuracy = opts['accuracy']
774        set_beam_stop(data, 0.0004)
775        index = ~data.mask
776    else:
777        if opts['view'] == 'log' and not opts['zero']:
778            qmax = math.log10(qmax)
779            q = np.logspace(qmax-3, qmax, nq)
780        else:
781            q = np.linspace(0.001*qmax, qmax, nq)
782        if opts['zero']:
783            q = np.hstack((0, q))
784        data = empty_data1D(q, resolution=res)
785        index = slice(None, None)
786    return data, index
787
788def make_engine(model_info, data, dtype, cutoff):
789    # type: (ModelInfo, Data, str, float) -> Calculator
790    """
791    Generate the appropriate calculation engine for the given datatype.
792
793    Datatypes with '!' appended are evaluated using external C DLLs rather
794    than OpenCL.
795    """
796    if dtype == 'sasview':
797        return eval_sasview(model_info, data)
798    elif dtype is None or not dtype.endswith('!'):
799        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
800    else:
801        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
802
803def _show_invalid(data, theory):
804    # type: (Data, np.ma.ndarray) -> None
805    """
806    Display a list of the non-finite values in theory.
807    """
808    if not theory.mask.any():
809        return
810
811    if hasattr(data, 'x'):
812        bad = zip(data.x[theory.mask], theory[theory.mask])
813        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
814
815
816def compare(opts, limits=None):
817    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
818    """
819    Preform a comparison using options from the command line.
820
821    *limits* are the limits on the values to use, either to set the y-axis
822    for 1D or to set the colormap scale for 2D.  If None, then they are
823    inferred from the data and returned. When exploring using Bumps,
824    the limits are set when the model is initially called, and maintained
825    as the values are adjusted, making it easier to see the effects of the
826    parameters.
827    """
828    limits = np.Inf, -np.Inf
829    for k in range(opts['sets']):
830        opts['pars'] = parse_pars(opts)
831        if opts['pars'] is None:
832            return
833        result = run_models(opts, verbose=True)
834        if opts['plot']:
835            limits = plot_models(opts, result, limits=limits, setnum=k)
836    if opts['plot']:
837        import matplotlib.pyplot as plt
838        plt.show()
839
840def run_models(opts, verbose=False):
841    # type: (Dict[str, Any]) -> Dict[str, Any]
842
843    base, comp = opts['engines']
844    base_n, comp_n = opts['count']
845    base_pars, comp_pars = opts['pars']
846    data = opts['data']
847
848    comparison = comp is not None
849
850    base_time = comp_time = None
851    base_value = comp_value = resid = relerr = None
852
853    # Base calculation
854    try:
855        base_raw, base_time = time_calculation(base, base_pars, base_n)
856        base_value = np.ma.masked_invalid(base_raw)
857        if verbose:
858            print("%s t=%.2f ms, intensity=%.0f"
859                  % (base.engine, base_time, base_value.sum()))
860        _show_invalid(data, base_value)
861    except ImportError:
862        traceback.print_exc()
863
864    # Comparison calculation
865    if comparison:
866        try:
867            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
868            comp_value = np.ma.masked_invalid(comp_raw)
869            if verbose:
870                print("%s t=%.2f ms, intensity=%.0f"
871                      % (comp.engine, comp_time, comp_value.sum()))
872            _show_invalid(data, comp_value)
873        except ImportError:
874            traceback.print_exc()
875
876    # Compare, but only if computing both forms
877    if comparison:
878        resid = (base_value - comp_value)
879        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
880        if verbose:
881            _print_stats("|%s-%s|"
882                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
883                         resid)
884            _print_stats("|(%s-%s)/%s|"
885                         % (base.engine, comp.engine, comp.engine),
886                         relerr)
887
888    return dict(base_value=base_value, comp_value=comp_value,
889                base_time=base_time, comp_time=comp_time,
890                resid=resid, relerr=relerr)
891
892
893def _print_stats(label, err):
894    # type: (str, np.ma.ndarray) -> None
895    # work with trimmed data, not the full set
896    sorted_err = np.sort(abs(err.compressed()))
897    if len(sorted_err) == 0.:
898        print(label + "  no valid values")
899        return
900
901    p50 = int((len(sorted_err)-1)*0.50)
902    p98 = int((len(sorted_err)-1)*0.98)
903    data = [
904        "max:%.3e"%sorted_err[-1],
905        "median:%.3e"%sorted_err[p50],
906        "98%%:%.3e"%sorted_err[p98],
907        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
908        "zero-offset:%+.3e"%np.mean(sorted_err),
909        ]
910    print(label+"  "+"  ".join(data))
911
912
913def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
914    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
915    base_value, comp_value = result['base_value'], result['comp_value']
916    base_time, comp_time = result['base_time'], result['comp_time']
917    resid, relerr = result['resid'], result['relerr']
918
919    have_base, have_comp = (base_value is not None), (comp_value is not None)
920    base, comp = opts['engines']
921    data = opts['data']
922    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
923
924    # Plot if requested
925    view = opts['view']
926    import matplotlib.pyplot as plt
927    vmin, vmax = limits
928    if have_base:
929        vmin = min(vmin, base_value.min())
930        vmax = max(vmax, base_value.max())
931    if have_comp:
932        vmin = min(vmin, comp_value.min())
933        vmax = max(vmax, comp_value.max())
934    limits = vmin, vmax
935
936    if have_base:
937        if have_comp:
938            plt.subplot(131)
939        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
940        plt.title("%s t=%.2f ms"%(base.engine, base_time))
941        #cbar_title = "log I"
942    if have_comp:
943        if have_base:
944            plt.subplot(132)
945        if not opts['is2d'] and have_base:
946            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
947        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
948        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
949        #cbar_title = "log I"
950    if have_base and have_comp:
951        plt.subplot(133)
952        if not opts['rel_err']:
953            err, errstr, errview = resid, "abs err", "linear"
954        else:
955            err, errstr, errview = abs(relerr), "rel err", "log"
956        if 0:  # 95% cutoff
957            sorted = np.sort(err.flatten())
958            cutoff = sorted[int(sorted.size*0.95)]
959            err[err > cutoff] = cutoff
960        #err,errstr = base/comp,"ratio"
961        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
962        if view == 'linear':
963            plt.xscale('linear')
964        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
965        #cbar_title = errstr if errview=="linear" else "log "+errstr
966    #if is2D:
967    #    h = plt.colorbar()
968    #    h.ax.set_title(cbar_title)
969    fig = plt.gcf()
970    extra_title = ' '+opts['title'] if opts['title'] else ''
971    fig.suptitle(":".join(opts['name']) + extra_title)
972
973    if have_base and have_comp and opts['show_hist']:
974        plt.figure()
975        v = relerr
976        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
977        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
978        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
979                   % (base.engine, comp.engine, comp.engine))
980        plt.ylabel('P(err)')
981        plt.title('Distribution of relative error between calculation engines')
982
983    return limits
984
985
986# ===========================================================================
987#
988
989# Set of command line options.
990# Normal options such as -plot/-noplot are specified as 'name'.
991# For options such as -nq=500 which require a value use 'name='.
992#
993OPTIONS = [
994    # Plotting
995    'plot', 'noplot',
996    'linear', 'log', 'q4',
997    'rel', 'abs',
998    'hist', 'nohist',
999    'title=',
1000
1001    # Data generation
1002    'data=', 'noise=', 'res=',
1003    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
1004    '2d', '1d',
1005
1006    # Parameter set
1007    'preset', 'random', 'random=', 'sets=',
1008    'demo', 'default',  # TODO: remove demo/default
1009    'nopars', 'pars',
1010
1011    # Calculation options
1012    'poly', 'mono', 'cutoff=',
1013    'magnetic', 'nonmagnetic',
1014    'accuracy=',
1015
1016    # Precision options
1017    'calc=',
1018    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1019    'sasview',  # TODO: remove sasview 3.x support
1020    'timing=',
1021
1022    # Output options
1023    'help', 'html', 'edit',
1024    ]
1025
1026NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1027VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1028
1029
1030def columnize(items, indent="", width=79):
1031    # type: (List[str], str, int) -> str
1032    """
1033    Format a list of strings into columns.
1034
1035    Returns a string with carriage returns ready for printing.
1036    """
1037    column_width = max(len(w) for w in items) + 1
1038    num_columns = (width - len(indent)) // column_width
1039    num_rows = len(items) // num_columns
1040    items = items + [""] * (num_rows * num_columns - len(items))
1041    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1042    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1043             for row in zip(*columns)]
1044    output = indent + ("\n"+indent).join(lines)
1045    return output
1046
1047
1048def get_pars(model_info, use_demo=False):
1049    # type: (ModelInfo, bool) -> ParameterSet
1050    """
1051    Extract demo parameters from the model definition.
1052    """
1053    # Get the default values for the parameters
1054    pars = {}
1055    for p in model_info.parameters.call_parameters:
1056        parts = [('', p.default)]
1057        if p.polydisperse:
1058            parts.append(('_pd', 0.0))
1059            parts.append(('_pd_n', 0))
1060            parts.append(('_pd_nsigma', 3.0))
1061            parts.append(('_pd_type', "gaussian"))
1062        for ext, val in parts:
1063            if p.length > 1:
1064                dict(("%s%d%s" % (p.id, k, ext), val)
1065                     for k in range(1, p.length+1))
1066            else:
1067                pars[p.id + ext] = val
1068
1069    # Plug in values given in demo
1070    if use_demo:
1071        pars.update(model_info.demo)
1072    return pars
1073
1074INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1075def isnumber(str):
1076    match = FLOAT_RE.match(str)
1077    isfloat = (match and not str[match.end():])
1078    return isfloat or INTEGER_RE.match(str)
1079
1080# For distinguishing pairs of models for comparison
1081# key-value pair separator =
1082# shell characters  | & ; <> $ % ' " \ # `
1083# model and parameter names _
1084# parameter expressions - + * / . ( )
1085# path characters including tilde expansion and windows drive ~ / :
1086# not sure about brackets [] {}
1087# maybe one of the following @ ? ^ ! ,
1088PAR_SPLIT = ','
1089def parse_opts(argv):
1090    # type: (List[str]) -> Dict[str, Any]
1091    """
1092    Parse command line options.
1093    """
1094    MODELS = core.list_models()
1095    flags = [arg for arg in argv
1096             if arg.startswith('-')]
1097    values = [arg for arg in argv
1098              if not arg.startswith('-') and '=' in arg]
1099    positional_args = [arg for arg in argv
1100                       if not arg.startswith('-') and '=' not in arg]
1101    models = "\n    ".join("%-15s"%v for v in MODELS)
1102    if len(positional_args) == 0:
1103        print(USAGE)
1104        print("\nAvailable models:")
1105        print(columnize(MODELS, indent="  "))
1106        return None
1107
1108    invalid = [o[1:] for o in flags
1109               if o[1:] not in NAME_OPTIONS
1110               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1111    if invalid:
1112        print("Invalid options: %s"%(", ".join(invalid)))
1113        return None
1114
1115    name = positional_args[-1]
1116
1117    # pylint: disable=bad-whitespace
1118    # Interpret the flags
1119    opts = {
1120        'plot'      : True,
1121        'view'      : 'log',
1122        'is2d'      : False,
1123        'qmax'      : 0.05,
1124        'nq'        : 128,
1125        'res'       : 0.0,
1126        'noise'     : 0.0,
1127        'accuracy'  : 'Low',
1128        'cutoff'    : '0.0',
1129        'seed'      : -1,  # default to preset
1130        'mono'      : True,
1131        # Default to magnetic a magnetic moment is set on the command line
1132        'magnetic'  : False,
1133        'show_pars' : False,
1134        'show_hist' : False,
1135        'rel_err'   : True,
1136        'explore'   : False,
1137        'use_demo'  : True,
1138        'zero'      : False,
1139        'html'      : False,
1140        'title'     : None,
1141        'datafile'  : None,
1142        'sets'      : 0,
1143        'engine'    : 'default',
1144        'evals'     : '1',
1145    }
1146    for arg in flags:
1147        if arg == '-noplot':    opts['plot'] = False
1148        elif arg == '-plot':    opts['plot'] = True
1149        elif arg == '-linear':  opts['view'] = 'linear'
1150        elif arg == '-log':     opts['view'] = 'log'
1151        elif arg == '-q4':      opts['view'] = 'q4'
1152        elif arg == '-1d':      opts['is2d'] = False
1153        elif arg == '-2d':      opts['is2d'] = True
1154        elif arg == '-exq':     opts['qmax'] = 10.0
1155        elif arg == '-highq':   opts['qmax'] = 1.0
1156        elif arg == '-midq':    opts['qmax'] = 0.2
1157        elif arg == '-lowq':    opts['qmax'] = 0.05
1158        elif arg == '-zero':    opts['zero'] = True
1159        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1160        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1161        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1162        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1163        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1164        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1165        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1166        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1167        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1168        elif arg.startswith('-calc='):     opts['engine'] = arg[6:]
1169        elif arg.startswith('-neval='):    opts['evals'] = arg[7:]
1170        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1171        elif arg == '-preset':  opts['seed'] = -1
1172        elif arg == '-mono':    opts['mono'] = True
1173        elif arg == '-poly':    opts['mono'] = False
1174        elif arg == '-magnetic':       opts['magnetic'] = True
1175        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1176        elif arg == '-pars':    opts['show_pars'] = True
1177        elif arg == '-nopars':  opts['show_pars'] = False
1178        elif arg == '-hist':    opts['show_hist'] = True
1179        elif arg == '-nohist':  opts['show_hist'] = False
1180        elif arg == '-rel':     opts['rel_err'] = True
1181        elif arg == '-abs':     opts['rel_err'] = False
1182        elif arg == '-half':    opts['engine'] = 'half'
1183        elif arg == '-fast':    opts['engine'] = 'fast'
1184        elif arg == '-single':  opts['engine'] = 'single'
1185        elif arg == '-double':  opts['engine'] = 'double'
1186        elif arg == '-single!': opts['engine'] = 'single!'
1187        elif arg == '-double!': opts['engine'] = 'double!'
1188        elif arg == '-quad!':   opts['engine'] = 'quad!'
1189        elif arg == '-sasview': opts['engine'] = 'sasview'
1190        elif arg == '-edit':    opts['explore'] = True
1191        elif arg == '-demo':    opts['use_demo'] = True
1192        elif arg == '-default': opts['use_demo'] = False
1193        elif arg == '-html':    opts['html'] = True
1194        elif arg == '-help':    opts['html'] = True
1195    # pylint: enable=bad-whitespace
1196
1197    # Magnetism forces 2D for now
1198    if opts['magnetic']:
1199        opts['is2d'] = True
1200
1201    # Force random if sets is used
1202    if opts['sets'] >= 1 and opts['seed'] < 0:
1203        opts['seed'] = np.random.randint(1000000)
1204    if opts['sets'] == 0:
1205        opts['sets'] = 1
1206
1207    # Create the computational engines
1208    if opts['datafile'] is not None:
1209        data = load_data(os.path.expanduser(opts['datafile']))
1210    else:
1211        data, _ = make_data(opts)
1212
1213    comparison = any(PAR_SPLIT in v for v in values)
1214    if PAR_SPLIT in name:
1215        names = name.split(PAR_SPLIT, 2)
1216        comparison = True
1217    else:
1218        names = [name]*2
1219    try:
1220        model_info = [core.load_model_info(k) for k in names]
1221    except ImportError as exc:
1222        print(str(exc))
1223        print("Could not find model; use one of:\n    " + models)
1224        return None
1225
1226    if PAR_SPLIT in opts['engine']:
1227        engine_types = opts['engine'].split(PAR_SPLIT, 2)
1228        comparison = True
1229    else:
1230        engine_types = [opts['engine']]*2
1231
1232    if PAR_SPLIT in opts['evals']:
1233        evals = [int(k) for k in opts['evals'].split(PAR_SPLIT, 2)]
1234        comparison = True
1235    else:
1236        evals = [int(opts['evals'])]*2
1237
1238    if PAR_SPLIT in opts['cutoff']:
1239        cutoff = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1240        comparison = True
1241    else:
1242        cutoff = [float(opts['cutoff'])]*2
1243
1244    base = make_engine(model_info[0], data, engine_types[0], cutoff[0])
1245    if comparison:
1246        comp = make_engine(model_info[1], data, engine_types[1], cutoff[1])
1247    else:
1248        comp = None
1249
1250    # pylint: disable=bad-whitespace
1251    # Remember it all
1252    opts.update({
1253        'data'      : data,
1254        'name'      : names,
1255        'def'       : model_info,
1256        'count'     : evals,
1257        'engines'   : [base, comp],
1258        'values'    : values,
1259    })
1260    # pylint: enable=bad-whitespace
1261
1262    return opts
1263
1264def parse_pars(opts):
1265    model_info, model_info2 = opts['def']
1266
1267    # Get demo parameters from model definition, or use default parameters
1268    # if model does not define demo parameters
1269    pars = get_pars(model_info, opts['use_demo'])
1270    pars2 = get_pars(model_info2, opts['use_demo'])
1271    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1272    # randomize parameters
1273    #pars.update(set_pars)  # set value before random to control range
1274    if opts['seed'] > -1:
1275        pars = randomize_pars(model_info, pars)
1276        if model_info != model_info2:
1277            pars2 = randomize_pars(model_info2, pars2)
1278            # Share values for parameters with the same name
1279            for k, v in pars.items():
1280                if k in pars2:
1281                    pars2[k] = v
1282        else:
1283            pars2 = pars.copy()
1284        constrain_pars(model_info, pars)
1285        constrain_pars(model_info2, pars2)
1286    pars = suppress_pd(pars, opts['mono'])
1287    pars2 = suppress_pd(pars2, opts['mono'])
1288    pars = suppress_magnetism(pars, not opts['magnetic'])
1289    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1290
1291    # Fill in parameters given on the command line
1292    presets = {}
1293    presets2 = {}
1294    for arg in opts['values']:
1295        k, v = arg.split('=', 1)
1296        if k not in pars and k not in pars2:
1297            # extract base name without polydispersity info
1298            s = set(p.split('_pd')[0] for p in pars)
1299            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1300            return None
1301        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1302        if v1 and k in pars:
1303            presets[k] = float(v1) if isnumber(v1) else v1
1304        if v2 and k in pars2:
1305            presets2[k] = float(v2) if isnumber(v2) else v2
1306
1307    # If pd given on the command line, default pd_n to 35
1308    for k, v in list(presets.items()):
1309        if k.endswith('_pd'):
1310            presets.setdefault(k+'_n', 35.)
1311    for k, v in list(presets2.items()):
1312        if k.endswith('_pd'):
1313            presets2.setdefault(k+'_n', 35.)
1314
1315    # Evaluate preset parameter expressions
1316    context = MATH.copy()
1317    context['np'] = np
1318    context.update(pars)
1319    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1320    for k, v in presets.items():
1321        if not isinstance(v, float) and not k.endswith('_type'):
1322            presets[k] = eval(v, context)
1323    context.update(presets)
1324    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1325    for k, v in presets2.items():
1326        if not isinstance(v, float) and not k.endswith('_type'):
1327            presets2[k] = eval(v, context)
1328
1329    # update parameters with presets
1330    pars.update(presets)  # set value after random to control value
1331    pars2.update(presets2)  # set value after random to control value
1332    #import pprint; pprint.pprint(model_info)
1333
1334    if opts['show_pars']:
1335        if model_info.name != model_info2.name or pars != pars2:
1336            print("==== %s ====="%model_info.name)
1337            print(str(parlist(model_info, pars, opts['is2d'])))
1338            print("==== %s ====="%model_info2.name)
1339            print(str(parlist(model_info2, pars2, opts['is2d'])))
1340        else:
1341            print(str(parlist(model_info, pars, opts['is2d'])))
1342
1343    return pars, pars2
1344
1345def show_docs(opts):
1346    # type: (Dict[str, Any]) -> None
1347    """
1348    show html docs for the model
1349    """
1350    import os
1351    from .generate import make_html
1352    from . import rst2html
1353
1354    info = opts['def'][0]
1355    html = make_html(info)
1356    path = os.path.dirname(info.filename)
1357    url = "file://"+path.replace("\\","/")[2:]+"/"
1358    rst2html.view_html_qtapp(html, url)
1359
1360def explore(opts):
1361    # type: (Dict[str, Any]) -> None
1362    """
1363    explore the model using the bumps gui.
1364    """
1365    import wx  # type: ignore
1366    from bumps.names import FitProblem  # type: ignore
1367    from bumps.gui.app_frame import AppFrame  # type: ignore
1368    from bumps.gui import signal
1369
1370    is_mac = "cocoa" in wx.version()
1371    # Create an app if not running embedded
1372    app = wx.App() if wx.GetApp() is None else None
1373    model = Explore(opts)
1374    problem = FitProblem(model)
1375    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1376    if not is_mac:
1377        frame.Show()
1378    frame.panel.set_model(model=problem)
1379    frame.panel.Layout()
1380    frame.panel.aui.Split(0, wx.TOP)
1381    def reset_parameters(event):
1382        model.revert_values()
1383        signal.update_parameters(problem)
1384    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1385    if is_mac: frame.Show()
1386    # If running withing an app, start the main loop
1387    if app:
1388        app.MainLoop()
1389
1390class Explore(object):
1391    """
1392    Bumps wrapper for a SAS model comparison.
1393
1394    The resulting object can be used as a Bumps fit problem so that
1395    parameters can be adjusted in the GUI, with plots updated on the fly.
1396    """
1397    def __init__(self, opts):
1398        # type: (Dict[str, Any]) -> None
1399        from bumps.cli import config_matplotlib  # type: ignore
1400        from . import bumps_model
1401        config_matplotlib()
1402        self.opts = opts
1403        opts['pars'] = list(opts['pars'])
1404        p1, p2 = opts['pars']
1405        m1, m2 = opts['def']
1406        self.fix_p2 = m1 != m2 or p1 != p2
1407        model_info = m1
1408        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1409        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1410        if not opts['is2d']:
1411            for p in model_info.parameters.user_parameters({}, is2d=False):
1412                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1413                    k = p.name+ext
1414                    v = pars.get(k, None)
1415                    if v is not None:
1416                        v.range(*parameter_range(k, v.value))
1417        else:
1418            for k, v in pars.items():
1419                v.range(*parameter_range(k, v.value))
1420
1421        self.pars = pars
1422        self.starting_values = dict((k, v.value) for k, v in pars.items())
1423        self.pd_types = pd_types
1424        self.limits = np.Inf, -np.Inf
1425
1426    def revert_values(self):
1427        for k, v in self.starting_values.items():
1428            self.pars[k].value = v
1429
1430    def model_update(self):
1431        pass
1432
1433    def numpoints(self):
1434        # type: () -> int
1435        """
1436        Return the number of points.
1437        """
1438        return len(self.pars) + 1  # so dof is 1
1439
1440    def parameters(self):
1441        # type: () -> Any   # Dict/List hierarchy of parameters
1442        """
1443        Return a dictionary of parameters.
1444        """
1445        return self.pars
1446
1447    def nllf(self):
1448        # type: () -> float
1449        """
1450        Return cost.
1451        """
1452        # pylint: disable=no-self-use
1453        return 0.  # No nllf
1454
1455    def plot(self, view='log'):
1456        # type: (str) -> None
1457        """
1458        Plot the data and residuals.
1459        """
1460        pars = dict((k, v.value) for k, v in self.pars.items())
1461        pars.update(self.pd_types)
1462        self.opts['pars'][0] = pars
1463        if not self.fix_p2:
1464            self.opts['pars'][1] = pars
1465        result = run_models(self.opts)
1466        limits = plot_models(self.opts, result, limits=self.limits)
1467        if self.limits is None:
1468            vmin, vmax = limits
1469            self.limits = vmax*1e-7, 1.3*vmax
1470            import pylab; pylab.clf()
1471            plot_models(self.opts, result, limits=self.limits)
1472
1473
1474def main(*argv):
1475    # type: (*str) -> None
1476    """
1477    Main program.
1478    """
1479    opts = parse_opts(argv)
1480    if opts is not None:
1481        if opts['seed'] > -1:
1482            print("Randomize using -random=%i"%opts['seed'])
1483            np.random.seed(opts['seed'])
1484        if opts['html']:
1485            show_docs(opts)
1486        elif opts['explore']:
1487            opts['pars'] = parse_pars(opts)
1488            if opts['pars'] is None:
1489                return
1490            explore(opts)
1491        else:
1492            compare(opts)
1493
1494if __name__ == "__main__":
1495    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.