source: sasmodels/sasmodels/compare.py @ 8f04da4

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 8f04da4 was 8f04da4, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

tuned random model generation for more models

  • Property mode set to 100755
File size: 49.1 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: compare.py model N1 N2 [options...] [key=val]
59
60Compare the speed and value for a model between the SasView original and the
61sasmodels rewrite.
62
63model or model1,model2 are the names of the models to compare (see below).
64N1 is the number of times to run sasmodels (default=1).
65N2 is the number times to run sasview (default=1).
66
67Options (* for default):
68
69    -plot*/-noplot plots or suppress the plot of the model
70    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
71    -nq=128 sets the number of Q points in the data set
72    -zero indicates that q=0 should be included
73    -1d*/-2d computes 1d or 2d data
74    -preset*/-random[=seed] preset or random parameters
75    -mono*/-poly force monodisperse or allow polydisperse demo parameters
76    -magnetic/-nonmagnetic* suppress magnetism
77    -cutoff=1e-5* cutoff value for including a point in polydispersity
78    -pars/-nopars* prints the parameter set or not
79    -abs/-rel* plot relative or absolute error
80    -linear/-log*/-q4 intensity scaling
81    -hist/-nohist* plot histogram of relative error
82    -res=0 sets the resolution width dQ/Q if calculating with resolution
83    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
84    -edit starts the parameter explorer
85    -default/-demo* use demo vs default parameters
86    -help/-html shows the model docs instead of running the model
87    -title="note" adds note to the plot title, after the model name
88    -data="path" uses q, dq from the data file
89
90Any two calculation engines can be selected for comparison:
91
92    -single/-double/-half/-fast sets an OpenCL calculation engine
93    -single!/-double!/-quad! sets an OpenMP calculation engine
94    -sasview sets the sasview calculation engine
95
96The default is -single -double.  Note that the interpretation of quad
97precision depends on architecture, and may vary from 64-bit to 128-bit,
98with 80-bit floats being common (1e-19 precision).
99
100Key=value pairs allow you to set specific values for the model parameters.
101Key=value1,value2 to compare different values of the same parameter.
102value can be an expression including other parameters
103"""
104
105# Update docs with command line usage string.   This is separate from the usual
106# doc string so that we can display it at run time if there is an error.
107# lin
108__doc__ = (__doc__  # pylint: disable=redefined-builtin
109           + """
110Program description
111-------------------
112
113"""
114           + USAGE)
115
116kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
117
118# list of math functions for use in evaluating parameters
119MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
120
121# CRUFT python 2.6
122if not hasattr(datetime.timedelta, 'total_seconds'):
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
126else:
127    def delay(dt):
128        """Return number date-time delta as number seconds"""
129        return dt.total_seconds()
130
131
132class push_seed(object):
133    """
134    Set the seed value for the random number generator.
135
136    When used in a with statement, the random number generator state is
137    restored after the with statement is complete.
138
139    :Parameters:
140
141    *seed* : int or array_like, optional
142        Seed for RandomState
143
144    :Example:
145
146    Seed can be used directly to set the seed::
147
148        >>> from numpy.random import randint
149        >>> push_seed(24)
150        <...push_seed object at...>
151        >>> print(randint(0,1000000,3))
152        [242082    899 211136]
153
154    Seed can also be used in a with statement, which sets the random
155    number generator state for the enclosed computations and restores
156    it to the previous state on completion::
157
158        >>> with push_seed(24):
159        ...    print(randint(0,1000000,3))
160        [242082    899 211136]
161
162    Using nested contexts, we can demonstrate that state is indeed
163    restored after the block completes::
164
165        >>> with push_seed(24):
166        ...    print(randint(0,1000000))
167        ...    with push_seed(24):
168        ...        print(randint(0,1000000,3))
169        ...    print(randint(0,1000000))
170        242082
171        [242082    899 211136]
172        899
173
174    The restore step is protected against exceptions in the block::
175
176        >>> with push_seed(24):
177        ...    print(randint(0,1000000))
178        ...    try:
179        ...        with push_seed(24):
180        ...            print(randint(0,1000000,3))
181        ...            raise Exception()
182        ...    except Exception:
183        ...        print("Exception raised")
184        ...    print(randint(0,1000000))
185        242082
186        [242082    899 211136]
187        Exception raised
188        899
189    """
190    def __init__(self, seed=None):
191        # type: (Optional[int]) -> None
192        self._state = np.random.get_state()
193        np.random.seed(seed)
194
195    def __enter__(self):
196        # type: () -> None
197        pass
198
199    def __exit__(self, exc_type, exc_value, traceback):
200        # type: (Any, BaseException, Any) -> None
201        # TODO: better typing for __exit__ method
202        np.random.set_state(self._state)
203
204def tic():
205    # type: () -> Callable[[], float]
206    """
207    Timer function.
208
209    Use "toc=tic()" to start the clock and "toc()" to measure
210    a time interval.
211    """
212    then = datetime.datetime.now()
213    return lambda: delay(datetime.datetime.now() - then)
214
215
216def set_beam_stop(data, radius, outer=None):
217    # type: (Data, float, float) -> None
218    """
219    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
220
221    Note: this function does not require sasview
222    """
223    if hasattr(data, 'qx_data'):
224        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
225        data.mask = (q < radius)
226        if outer is not None:
227            data.mask |= (q >= outer)
228    else:
229        data.mask = (data.x < radius)
230        if outer is not None:
231            data.mask |= (data.x >= outer)
232
233
234def parameter_range(p, v):
235    # type: (str, float) -> Tuple[float, float]
236    """
237    Choose a parameter range based on parameter name and initial value.
238    """
239    # process the polydispersity options
240    if p.endswith('_pd_n'):
241        return 0., 100.
242    elif p.endswith('_pd_nsigma'):
243        return 0., 5.
244    elif p.endswith('_pd_type'):
245        raise ValueError("Cannot return a range for a string value")
246    elif any(s in p for s in ('theta', 'phi', 'psi')):
247        # orientation in [-180,180], orientation pd in [0,45]
248        if p.endswith('_pd'):
249            return 0., 45.
250        else:
251            return -180., 180.
252    elif p.endswith('_pd'):
253        return 0., 1.
254    elif 'sld' in p:
255        return -0.5, 10.
256    elif p == 'background':
257        return 0., 10.
258    elif p == 'scale':
259        return 0., 1.e3
260    elif v < 0.:
261        return 2.*v, -2.*v
262    else:
263        return 0., (2.*v if v > 0. else 1.)
264
265
266def _randomize_one(model_info, name, value):
267    # type: (ModelInfo, str, float) -> float
268    # type: (ModelInfo, str, str) -> str
269    """
270    Randomize a single parameter.
271    """
272    # Set the amount of polydispersity/angular dispersion, but by default pd_n
273    # is zero so there is no polydispersity.  This allows us to turn on/off
274    # pd by setting pd_n, and still have randomly generated values
275    if name.endswith('_pd'):
276        par = model_info.parameters[name[:-3]]
277        if par.type == 'orientation':
278            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
279            return 180*np.random.beta(2.5, 20)
280        else:
281            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
282            return np.random.beta(1.5, 7)
283
284    # pd is selected globally rather than per parameter, so set to 0 for no pd
285    # In particular, when multiple pd dimensions, want to decrease the number
286    # of points per dimension for faster computation
287    if name.endswith('_pd_n'):
288        return 0
289
290    # Don't mess with distribution type for now
291    if name.endswith('_pd_type'):
292        return 'gaussian'
293
294    # type-dependent value of number of sigmas; for gaussian use 3.
295    if name.endswith('_pd_nsigma'):
296        return 3.
297
298    # background in the range [0.01, 1]
299    if name == 'background':
300        return 10**np.random.uniform(-2, 0)
301
302    # scale defaults to 0.1% to 30% volume fraction
303    if name == 'scale':
304        return 10**np.random.uniform(-3, -0.5)
305
306    # If it is a list of choices, pick one at random with equal probability
307    # In practice, the model specific random generator will override.
308    par = model_info.parameters[name]
309    if len(par.limits) > 2:  # choice list
310        return np.random.randint(len(par.limits))
311
312    # If it is a fixed range, pick from it with equal probability.
313    # For logarithmic ranges, the model will have to override.
314    if np.isfinite(par.limits).all():
315        return np.random.uniform(*par.limits)
316
317    # If the paramter is marked as an sld use the range of neutron slds
318    # Should be doing something with randomly selected contrast matching
319    # but for now use random contrasts.  Since real data is contrast-matched,
320    # this will favour selection of hollow models when they exist.
321    if par.type == 'sld':
322        # Range of neutron SLDs
323        return np.random.uniform(-0.5, 12)
324
325    # Guess at the random length/radius/thickness.  In practice, all models
326    # are going to set their own reasonable ranges.
327    if par.type == 'volume':
328        if ('length' in par.name or
329                'radius' in par.name or
330                'thick' in par.name):
331            return 10**np.random.uniform(2, 4)
332
333    # In the absence of any other info, select a value in [0, 2v], or
334    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
335    # model random parameter generators will override this default.
336    low, high = parameter_range(par.name, value)
337    limits = (max(par.limits[0], low), min(par.limits[1], high))
338    return np.random.uniform(*limits)
339
340def _random_pd(model_info, pars):
341    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
342    pd_volume = []
343    pd_oriented = []
344    for p in pd:
345        if p.type == 'orientation':
346            pd_oriented.append(p.name)
347        elif p.length_control is not None:
348            n = pars.get(p.length_control, 1)
349            pd_volume.extend(p.name+str(k+1) for k in range(n))
350        elif p.length > 1:
351            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
352        else:
353            pd_volume.append(p.name)
354    u = np.random.rand()
355    n = len(pd_volume)
356    if u < 0.01 or n < 1:
357        pass  # 1% chance of no polydispersity
358    elif u < 0.86 or n < 2:
359        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
360    elif u < 0.99 or n < 3:
361        choices = np.random.choice(len(pd_volume), size=2)
362        pars[pd_volume[choices[0]]+"_pd_n"] = 25
363        pars[pd_volume[choices[1]]+"_pd_n"] = 10
364    else:
365        choices = np.random.choice(len(pd_volume), size=3)
366        pars[pd_volume[choices[0]]+"_pd_n"] = 25
367        pars[pd_volume[choices[1]]+"_pd_n"] = 10
368        pars[pd_volume[choices[2]]+"_pd_n"] = 5
369    if pd_oriented:
370        pars['theta_pd_n'] = 20
371        if np.random.rand() < 0.1:
372            pars['phi_pd_n'] = 5
373        if np.random.rand() < 0.1:
374            pars['psi_pd_n'] = 5
375
376    ## Show selected polydispersity
377    #for name, value in pars.items():
378    #    if name.endswith('_pd_n') and value > 0:
379    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
380
381
382def randomize_pars(model_info, pars):
383    # type: (ModelInfo, ParameterSet) -> ParameterSet
384    """
385    Generate random values for all of the parameters.
386
387    Valid ranges for the random number generator are guessed from the name of
388    the parameter; this will not account for constraints such as cap radius
389    greater than cylinder radius in the capped_cylinder model, so
390    :func:`constrain_pars` needs to be called afterward..
391    """
392    # Note: the sort guarantees order of calls to random number generator
393    random_pars = dict((p, _randomize_one(model_info, p, v))
394                       for p, v in sorted(pars.items()))
395    if model_info.random is not None:
396        random_pars.update(model_info.random())
397    _random_pd(model_info, random_pars)
398    return random_pars
399
400
401def constrain_pars(model_info, pars):
402    # type: (ModelInfo, ParameterSet) -> None
403    """
404    Restrict parameters to valid values.
405
406    This includes model specific code for models such as capped_cylinder
407    which need to support within model constraints (cap radius more than
408    cylinder radius in this case).
409
410    Warning: this updates the *pars* dictionary in place.
411    """
412    # TODO: move the model specific code to the individual models
413    name = model_info.id
414    # if it is a product model, then just look at the form factor since
415    # none of the structure factors need any constraints.
416    if '*' in name:
417        name = name.split('*')[0]
418
419    # Suppress magnetism for python models (not yet implemented)
420    if callable(model_info.Iq):
421        pars.update(suppress_magnetism(pars))
422
423    if name == 'barbell':
424        if pars['radius_bell'] < pars['radius']:
425            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
426
427    elif name == 'capped_cylinder':
428        if pars['radius_cap'] < pars['radius']:
429            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
430
431    elif name == 'guinier':
432        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
433        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
434        # => ln A - (Rg^2 q^2/3) > -30 ln 10
435        # => Rg^2 q^2/3 < 30 ln 10 + ln A
436        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
437        #q_max = 0.2  # mid q maximum
438        q_max = 1.0  # high q maximum
439        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
440        pars['rg'] = min(pars['rg'], rg_max)
441
442    elif name == 'pearl_necklace':
443        if pars['radius'] < pars['thick_string']:
444            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
445        pass
446
447    elif name == 'rpa':
448        # Make sure phi sums to 1.0
449        if pars['case_num'] < 2:
450            pars['Phi1'] = 0.
451            pars['Phi2'] = 0.
452        elif pars['case_num'] < 5:
453            pars['Phi1'] = 0.
454        total = sum(pars['Phi'+c] for c in '1234')
455        for c in '1234':
456            pars['Phi'+c] /= total
457
458def parlist(model_info, pars, is2d):
459    # type: (ModelInfo, ParameterSet, bool) -> str
460    """
461    Format the parameter list for printing.
462    """
463    lines = []
464    parameters = model_info.parameters
465    magnetic = False
466    for p in parameters.user_parameters(pars, is2d):
467        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
468            continue
469        if p.id.startswith('up:') and not magnetic:
470            continue
471        fields = dict(
472            value=pars.get(p.id, p.default),
473            pd=pars.get(p.id+"_pd", 0.),
474            n=int(pars.get(p.id+"_pd_n", 0)),
475            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
476            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
477            relative_pd=p.relative_pd,
478            M0=pars.get('M0:'+p.id, 0.),
479            mphi=pars.get('mphi:'+p.id, 0.),
480            mtheta=pars.get('mtheta:'+p.id, 0.),
481        )
482        lines.append(_format_par(p.name, **fields))
483        magnetic = magnetic or fields['M0'] != 0.
484    return "\n".join(lines)
485
486    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
487
488def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
489                relative_pd=False, M0=0., mphi=0., mtheta=0.):
490    # type: (str, float, float, int, float, str) -> str
491    line = "%s: %g"%(name, value)
492    if pd != 0.  and n != 0:
493        if relative_pd:
494            pd *= value
495        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
496                % (pd, n, nsigma, nsigma, pdtype)
497    if M0 != 0.:
498        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
499    return line
500
501def suppress_pd(pars):
502    # type: (ParameterSet) -> ParameterSet
503    """
504    Suppress theta_pd for now until the normalization is resolved.
505
506    May also suppress complete polydispersity of the model to test
507    models more quickly.
508    """
509    pars = pars.copy()
510    for p in pars:
511        if p.endswith("_pd_n"):
512            pars[p] = 0
513    return pars
514
515def suppress_magnetism(pars):
516    # type: (ParameterSet) -> ParameterSet
517    """
518    Suppress theta_pd for now until the normalization is resolved.
519
520    May also suppress complete polydispersity of the model to test
521    models more quickly.
522    """
523    pars = pars.copy()
524    for p in pars:
525        if p.startswith("M0:"): pars[p] = 0
526    return pars
527
528def eval_sasview(model_info, data):
529    # type: (Modelinfo, Data) -> Calculator
530    """
531    Return a model calculator using the pre-4.0 SasView models.
532    """
533    # importing sas here so that the error message will be that sas failed to
534    # import rather than the more obscure smear_selection not imported error
535    import sas
536    import sas.models
537    from sas.models.qsmearing import smear_selection
538    from sas.models.MultiplicationModel import MultiplicationModel
539    from sas.models.dispersion_models import models as dispersers
540
541    def get_model_class(name):
542        # type: (str) -> "sas.models.BaseComponent"
543        #print("new",sorted(_pars.items()))
544        __import__('sas.models.' + name)
545        ModelClass = getattr(getattr(sas.models, name, None), name, None)
546        if ModelClass is None:
547            raise ValueError("could not find model %r in sas.models"%name)
548        return ModelClass
549
550    # WARNING: ugly hack when handling model!
551    # Sasview models with multiplicity need to be created with the target
552    # multiplicity, so we cannot create the target model ahead of time for
553    # for multiplicity models.  Instead we store the model in a list and
554    # update the first element of that list with the new multiplicity model
555    # every time we evaluate.
556
557    # grab the sasview model, or create it if it is a product model
558    if model_info.composition:
559        composition_type, parts = model_info.composition
560        if composition_type == 'product':
561            P, S = [get_model_class(revert_name(p))() for p in parts]
562            model = [MultiplicationModel(P, S)]
563        else:
564            raise ValueError("sasview mixture models not supported by compare")
565    else:
566        old_name = revert_name(model_info)
567        if old_name is None:
568            raise ValueError("model %r does not exist in old sasview"
569                            % model_info.id)
570        ModelClass = get_model_class(old_name)
571        model = [ModelClass()]
572    model[0].disperser_handles = {}
573
574    # build a smearer with which to call the model, if necessary
575    smearer = smear_selection(data, model=model)
576    if hasattr(data, 'qx_data'):
577        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
578        index = ((~data.mask) & (~np.isnan(data.data))
579                 & (q >= data.qmin) & (q <= data.qmax))
580        if smearer is not None:
581            smearer.model = model  # because smear_selection has a bug
582            smearer.accuracy = data.accuracy
583            smearer.set_index(index)
584            def _call_smearer():
585                smearer.model = model[0]
586                return smearer.get_value()
587            theory = _call_smearer
588        else:
589            theory = lambda: model[0].evalDistribution([data.qx_data[index],
590                                                        data.qy_data[index]])
591    elif smearer is not None:
592        theory = lambda: smearer(model[0].evalDistribution(data.x))
593    else:
594        theory = lambda: model[0].evalDistribution(data.x)
595
596    def calculator(**pars):
597        # type: (float, ...) -> np.ndarray
598        """
599        Sasview calculator for model.
600        """
601        oldpars = revert_pars(model_info, pars)
602        # For multiplicity models, create a model with the correct multiplicity
603        control = oldpars.pop("CONTROL", None)
604        if control is not None:
605            # sphericalSLD has one fewer multiplicity.  This update should
606            # happen in revert_pars, but it hasn't been called yet.
607            model[0] = ModelClass(control)
608        # paying for parameter conversion each time to keep life simple, if not fast
609        for k, v in oldpars.items():
610            if k.endswith('.type'):
611                par = k[:-5]
612                if v == 'gaussian': continue
613                cls = dispersers[v if v != 'rectangle' else 'rectangula']
614                handle = cls()
615                model[0].disperser_handles[par] = handle
616                try:
617                    model[0].set_dispersion(par, handle)
618                except Exception:
619                    exception.annotate_exception("while setting %s to %r"
620                                                 %(par, v))
621                    raise
622
623
624        #print("sasview pars",oldpars)
625        for k, v in oldpars.items():
626            name_attr = k.split('.')  # polydispersity components
627            if len(name_attr) == 2:
628                par, disp_par = name_attr
629                model[0].dispersion[par][disp_par] = v
630            else:
631                model[0].setParam(k, v)
632        return theory()
633
634    calculator.engine = "sasview"
635    return calculator
636
637DTYPE_MAP = {
638    'half': '16',
639    'fast': 'fast',
640    'single': '32',
641    'double': '64',
642    'quad': '128',
643    'f16': '16',
644    'f32': '32',
645    'f64': '64',
646    'float16': '16',
647    'float32': '32',
648    'float64': '64',
649    'float128': '128',
650    'longdouble': '128',
651}
652def eval_opencl(model_info, data, dtype='single', cutoff=0.):
653    # type: (ModelInfo, Data, str, float) -> Calculator
654    """
655    Return a model calculator using the OpenCL calculation engine.
656    """
657    if not core.HAVE_OPENCL:
658        raise RuntimeError("OpenCL not available")
659    model = core.build_model(model_info, dtype=dtype, platform="ocl")
660    calculator = DirectModel(data, model, cutoff=cutoff)
661    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
662    return calculator
663
664def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
665    # type: (ModelInfo, Data, str, float) -> Calculator
666    """
667    Return a model calculator using the DLL calculation engine.
668    """
669    model = core.build_model(model_info, dtype=dtype, platform="dll")
670    calculator = DirectModel(data, model, cutoff=cutoff)
671    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
672    return calculator
673
674def time_calculation(calculator, pars, evals=1):
675    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
676    """
677    Compute the average calculation time over N evaluations.
678
679    An additional call is generated without polydispersity in order to
680    initialize the calculation engine, and make the average more stable.
681    """
682    # initialize the code so time is more accurate
683    if evals > 1:
684        calculator(**suppress_pd(pars))
685    toc = tic()
686    # make sure there is at least one eval
687    value = calculator(**pars)
688    for _ in range(evals-1):
689        value = calculator(**pars)
690    average_time = toc()*1000. / evals
691    #print("I(q)",value)
692    return value, average_time
693
694def make_data(opts):
695    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
696    """
697    Generate an empty dataset, used with the model to set Q points
698    and resolution.
699
700    *opts* contains the options, with 'qmax', 'nq', 'res',
701    'accuracy', 'is2d' and 'view' parsed from the command line.
702    """
703    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
704    if opts['is2d']:
705        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
706        data = empty_data2D(q, resolution=res)
707        data.accuracy = opts['accuracy']
708        set_beam_stop(data, 0.0004)
709        index = ~data.mask
710    else:
711        if opts['view'] == 'log' and not opts['zero']:
712            qmax = math.log10(qmax)
713            q = np.logspace(qmax-3, qmax, nq)
714        else:
715            q = np.linspace(0.001*qmax, qmax, nq)
716        if opts['zero']:
717            q = np.hstack((0, q))
718        data = empty_data1D(q, resolution=res)
719        index = slice(None, None)
720    return data, index
721
722def make_engine(model_info, data, dtype, cutoff):
723    # type: (ModelInfo, Data, str, float) -> Calculator
724    """
725    Generate the appropriate calculation engine for the given datatype.
726
727    Datatypes with '!' appended are evaluated using external C DLLs rather
728    than OpenCL.
729    """
730    if dtype == 'sasview':
731        return eval_sasview(model_info, data)
732    elif dtype.endswith('!'):
733        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
734    else:
735        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
736
737def _show_invalid(data, theory):
738    # type: (Data, np.ma.ndarray) -> None
739    """
740    Display a list of the non-finite values in theory.
741    """
742    if not theory.mask.any():
743        return
744
745    if hasattr(data, 'x'):
746        bad = zip(data.x[theory.mask], theory[theory.mask])
747        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
748
749
750def compare(opts, limits=None):
751    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
752    """
753    Preform a comparison using options from the command line.
754
755    *limits* are the limits on the values to use, either to set the y-axis
756    for 1D or to set the colormap scale for 2D.  If None, then they are
757    inferred from the data and returned. When exploring using Bumps,
758    the limits are set when the model is initially called, and maintained
759    as the values are adjusted, making it easier to see the effects of the
760    parameters.
761    """
762    limits = np.Inf, -np.Inf
763    for k in range(opts['sets']):
764        opts['pars'] = parse_pars(opts)
765        if opts['pars'] is None:
766            return
767        result = run_models(opts, verbose=True)
768        if opts['plot']:
769            limits = plot_models(opts, result, limits=limits, setnum=k)
770    if opts['plot']:
771        import matplotlib.pyplot as plt
772        plt.show()
773
774def run_models(opts, verbose=False):
775    # type: (Dict[str, Any]) -> Dict[str, Any]
776
777    n_base, n_comp = opts['count']
778    pars, pars2 = opts['pars']
779    data = opts['data']
780
781    # silence the linter
782    base = opts['engines'][0] if n_base else None
783    comp = opts['engines'][1] if n_comp else None
784
785    base_time = comp_time = None
786    base_value = comp_value = resid = relerr = None
787
788    # Base calculation
789    if n_base > 0:
790        try:
791            base_raw, base_time = time_calculation(base, pars, n_base)
792            base_value = np.ma.masked_invalid(base_raw)
793            if verbose:
794                print("%s t=%.2f ms, intensity=%.0f"
795                      % (base.engine, base_time, base_value.sum()))
796            _show_invalid(data, base_value)
797        except ImportError:
798            traceback.print_exc()
799            n_base = 0
800
801    # Comparison calculation
802    if n_comp > 0:
803        try:
804            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
805            comp_value = np.ma.masked_invalid(comp_raw)
806            if verbose:
807                print("%s t=%.2f ms, intensity=%.0f"
808                      % (comp.engine, comp_time, comp_value.sum()))
809            _show_invalid(data, comp_value)
810        except ImportError:
811            traceback.print_exc()
812            n_comp = 0
813
814    # Compare, but only if computing both forms
815    if n_base > 0 and n_comp > 0:
816        resid = (base_value - comp_value)
817        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
818        if verbose:
819            _print_stats("|%s-%s|"
820                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
821                         resid)
822            _print_stats("|(%s-%s)/%s|"
823                         % (base.engine, comp.engine, comp.engine),
824                         relerr)
825
826    return dict(base_value=base_value, comp_value=comp_value,
827                base_time=base_time, comp_time=comp_time,
828                resid=resid, relerr=relerr)
829
830
831def _print_stats(label, err):
832    # type: (str, np.ma.ndarray) -> None
833    # work with trimmed data, not the full set
834    sorted_err = np.sort(abs(err.compressed()))
835    if len(sorted_err) == 0.:
836        print(label + "  no valid values")
837        return
838
839    p50 = int((len(sorted_err)-1)*0.50)
840    p98 = int((len(sorted_err)-1)*0.98)
841    data = [
842        "max:%.3e"%sorted_err[-1],
843        "median:%.3e"%sorted_err[p50],
844        "98%%:%.3e"%sorted_err[p98],
845        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
846        "zero-offset:%+.3e"%np.mean(sorted_err),
847        ]
848    print(label+"  "+"  ".join(data))
849
850
851def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
852    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
853    base_value, comp_value= result['base_value'], result['comp_value']
854    base_time, comp_time = result['base_time'], result['comp_time']
855    resid, relerr = result['resid'], result['relerr']
856
857    have_base, have_comp = (base_value is not None), (comp_value is not None)
858    base = opts['engines'][0] if have_base else None
859    comp = opts['engines'][1] if have_comp else None
860    data = opts['data']
861    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
862
863    # Plot if requested
864    view = opts['view']
865    import matplotlib.pyplot as plt
866    vmin, vmax = limits
867    if have_base:
868        vmin = min(vmin, base_value.min())
869        vmax = max(vmax, base_value.max())
870    if have_comp:
871        vmin = min(vmin, comp_value.min())
872        vmax = max(vmax, comp_value.max())
873    limits = vmin, vmax
874
875    if have_base:
876        if have_comp: plt.subplot(131)
877        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
878        plt.title("%s t=%.2f ms"%(base.engine, base_time))
879        #cbar_title = "log I"
880    if have_comp:
881        if have_base: plt.subplot(132)
882        if not opts['is2d'] and have_base:
883            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
884        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
885        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
886        #cbar_title = "log I"
887    if have_base and have_comp:
888        plt.subplot(133)
889        if not opts['rel_err']:
890            err, errstr, errview = resid, "abs err", "linear"
891        else:
892            err, errstr, errview = abs(relerr), "rel err", "log"
893        if 0:  # 95% cutoff
894            sorted = np.sort(err.flatten())
895            cutoff = sorted[int(sorted.size*0.95)]
896            err[err>cutoff] = cutoff
897        #err,errstr = base/comp,"ratio"
898        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
899        if view == 'linear':
900            plt.xscale('linear')
901        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
902        #cbar_title = errstr if errview=="linear" else "log "+errstr
903    #if is2D:
904    #    h = plt.colorbar()
905    #    h.ax.set_title(cbar_title)
906    fig = plt.gcf()
907    extra_title = ' '+opts['title'] if opts['title'] else ''
908    fig.suptitle(":".join(opts['name']) + extra_title)
909
910    if have_base and have_comp and opts['show_hist']:
911        plt.figure()
912        v = relerr
913        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
914        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
915        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
916                   % (base.engine, comp.engine, comp.engine))
917        plt.ylabel('P(err)')
918        plt.title('Distribution of relative error between calculation engines')
919
920    return limits
921
922
923# ===========================================================================
924#
925NAME_OPTIONS = set([
926    'plot', 'noplot',
927    'half', 'fast', 'single', 'double',
928    'single!', 'double!', 'quad!', 'sasview',
929    'lowq', 'midq', 'highq', 'exq', 'zero',
930    '2d', '1d',
931    'preset', 'random',
932    'poly', 'mono',
933    'magnetic', 'nonmagnetic',
934    'nopars', 'pars',
935    'rel', 'abs',
936    'linear', 'log', 'q4',
937    'hist', 'nohist',
938    'edit', 'html', 'help',
939    'demo', 'default',
940    ])
941VALUE_OPTIONS = [
942    # Note: random is both a name option and a value option
943    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title', 'data', 'sets'
944    ]
945
946def columnize(items, indent="", width=79):
947    # type: (List[str], str, int) -> str
948    """
949    Format a list of strings into columns.
950
951    Returns a string with carriage returns ready for printing.
952    """
953    column_width = max(len(w) for w in items) + 1
954    num_columns = (width - len(indent)) // column_width
955    num_rows = len(items) // num_columns
956    items = items + [""] * (num_rows * num_columns - len(items))
957    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
958    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
959             for row in zip(*columns)]
960    output = indent + ("\n"+indent).join(lines)
961    return output
962
963
964def get_pars(model_info, use_demo=False):
965    # type: (ModelInfo, bool) -> ParameterSet
966    """
967    Extract demo parameters from the model definition.
968    """
969    # Get the default values for the parameters
970    pars = {}
971    for p in model_info.parameters.call_parameters:
972        parts = [('', p.default)]
973        if p.polydisperse:
974            parts.append(('_pd', 0.0))
975            parts.append(('_pd_n', 0))
976            parts.append(('_pd_nsigma', 3.0))
977            parts.append(('_pd_type', "gaussian"))
978        for ext, val in parts:
979            if p.length > 1:
980                dict(("%s%d%s" % (p.id, k, ext), val)
981                     for k in range(1, p.length+1))
982            else:
983                pars[p.id + ext] = val
984
985    # Plug in values given in demo
986    if use_demo:
987        pars.update(model_info.demo)
988    return pars
989
990INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
991def isnumber(str):
992    match = FLOAT_RE.match(str)
993    isfloat = (match and not str[match.end():])
994    return isfloat or INTEGER_RE.match(str)
995
996# For distinguishing pairs of models for comparison
997# key-value pair separator =
998# shell characters  | & ; <> $ % ' " \ # `
999# model and parameter names _
1000# parameter expressions - + * / . ( )
1001# path characters including tilde expansion and windows drive ~ / :
1002# not sure about brackets [] {}
1003# maybe one of the following @ ? ^ ! ,
1004MODEL_SPLIT = ','
1005def parse_opts(argv):
1006    # type: (List[str]) -> Dict[str, Any]
1007    """
1008    Parse command line options.
1009    """
1010    MODELS = core.list_models()
1011    flags = [arg for arg in argv
1012             if arg.startswith('-')]
1013    values = [arg for arg in argv
1014              if not arg.startswith('-') and '=' in arg]
1015    positional_args = [arg for arg in argv
1016                       if not arg.startswith('-') and '=' not in arg]
1017    models = "\n    ".join("%-15s"%v for v in MODELS)
1018    if len(positional_args) == 0:
1019        print(USAGE)
1020        print("\nAvailable models:")
1021        print(columnize(MODELS, indent="  "))
1022        return None
1023    if len(positional_args) > 3:
1024        print("expected parameters: model N1 N2")
1025
1026    invalid = [o[1:] for o in flags
1027               if o[1:] not in NAME_OPTIONS
1028               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1029    if invalid:
1030        print("Invalid options: %s"%(", ".join(invalid)))
1031        return None
1032
1033    name = positional_args[0]
1034    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
1035    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
1036
1037    # pylint: disable=bad-whitespace
1038    # Interpret the flags
1039    opts = {
1040        'plot'      : True,
1041        'view'      : 'log',
1042        'is2d'      : False,
1043        'qmax'      : 0.05,
1044        'nq'        : 128,
1045        'res'       : 0.0,
1046        'accuracy'  : 'Low',
1047        'cutoff'    : 0.0,
1048        'seed'      : -1,  # default to preset
1049        'mono'      : True,
1050        # Default to magnetic a magnetic moment is set on the command line
1051        'magnetic'  : False,
1052        'show_pars' : False,
1053        'show_hist' : False,
1054        'rel_err'   : True,
1055        'explore'   : False,
1056        'use_demo'  : True,
1057        'zero'      : False,
1058        'html'      : False,
1059        'title'     : None,
1060        'datafile'  : None,
1061        'sets'      : 1,
1062    }
1063    engines = []
1064    for arg in flags:
1065        if arg == '-noplot':    opts['plot'] = False
1066        elif arg == '-plot':    opts['plot'] = True
1067        elif arg == '-linear':  opts['view'] = 'linear'
1068        elif arg == '-log':     opts['view'] = 'log'
1069        elif arg == '-q4':      opts['view'] = 'q4'
1070        elif arg == '-1d':      opts['is2d'] = False
1071        elif arg == '-2d':      opts['is2d'] = True
1072        elif arg == '-exq':     opts['qmax'] = 10.0
1073        elif arg == '-highq':   opts['qmax'] = 1.0
1074        elif arg == '-midq':    opts['qmax'] = 0.2
1075        elif arg == '-lowq':    opts['qmax'] = 0.05
1076        elif arg == '-zero':    opts['zero'] = True
1077        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1078        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1079        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1080        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1081        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
1082        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1083        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1084        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1085        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1086        elif arg == '-preset':  opts['seed'] = -1
1087        elif arg == '-mono':    opts['mono'] = True
1088        elif arg == '-poly':    opts['mono'] = False
1089        elif arg == '-magnetic':       opts['magnetic'] = True
1090        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1091        elif arg == '-pars':    opts['show_pars'] = True
1092        elif arg == '-nopars':  opts['show_pars'] = False
1093        elif arg == '-hist':    opts['show_hist'] = True
1094        elif arg == '-nohist':  opts['show_hist'] = False
1095        elif arg == '-rel':     opts['rel_err'] = True
1096        elif arg == '-abs':     opts['rel_err'] = False
1097        elif arg == '-half':    engines.append(arg[1:])
1098        elif arg == '-fast':    engines.append(arg[1:])
1099        elif arg == '-single':  engines.append(arg[1:])
1100        elif arg == '-double':  engines.append(arg[1:])
1101        elif arg == '-single!': engines.append(arg[1:])
1102        elif arg == '-double!': engines.append(arg[1:])
1103        elif arg == '-quad!':   engines.append(arg[1:])
1104        elif arg == '-sasview': engines.append(arg[1:])
1105        elif arg == '-edit':    opts['explore'] = True
1106        elif arg == '-demo':    opts['use_demo'] = True
1107        elif arg == '-default':    opts['use_demo'] = False
1108        elif arg == '-html':    opts['html'] = True
1109        elif arg == '-help':    opts['html'] = True
1110    # pylint: enable=bad-whitespace
1111
1112    # Force random if more than one set
1113    if opts['sets'] > 1 and opts['seed'] < 0:
1114        opts['seed'] = np.random.randint(1000000)
1115
1116    if MODEL_SPLIT in name:
1117        name, name2 = name.split(MODEL_SPLIT, 2)
1118    else:
1119        name2 = name
1120    try:
1121        model_info = core.load_model_info(name)
1122        model_info2 = core.load_model_info(name2) if name2 != name else model_info
1123    except ImportError as exc:
1124        print(str(exc))
1125        print("Could not find model; use one of:\n    " + models)
1126        return None
1127
1128    # TODO: check if presets are different when deciding if models are same
1129    same_model = name == name2
1130    if len(engines) == 0:
1131        if same_model:
1132            engines.extend(['single', 'double'])
1133        else:
1134            engines.extend(['single', 'single'])
1135    elif len(engines) == 1:
1136        if not same_model:
1137            engines.append(engines[0])
1138        elif engines[0] == 'double':
1139            engines.append('single')
1140        else:
1141            engines.append('double')
1142    elif len(engines) > 2:
1143        del engines[2:]
1144
1145    # Create the computational engines
1146    if opts['datafile'] is not None:
1147        data = load_data(os.path.expanduser(opts['datafile']))
1148    else:
1149        data, _ = make_data(opts)
1150    if n1:
1151        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1152    else:
1153        base = None
1154    if n2:
1155        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1156    else:
1157        comp = None
1158
1159    # pylint: disable=bad-whitespace
1160    # Remember it all
1161    opts.update({
1162        'data'      : data,
1163        'name'      : [name, name2],
1164        'def'       : [model_info, model_info2],
1165        'count'     : [n1, n2],
1166        'engines'   : [base, comp],
1167        'values'    : values,
1168    })
1169    # pylint: enable=bad-whitespace
1170
1171    return opts
1172
1173def parse_pars(opts):
1174    model_info, model_info2 = opts['def']
1175
1176    # Get demo parameters from model definition, or use default parameters
1177    # if model does not define demo parameters
1178    pars = get_pars(model_info, opts['use_demo'])
1179    pars2 = get_pars(model_info2, opts['use_demo'])
1180    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1181    # randomize parameters
1182    #pars.update(set_pars)  # set value before random to control range
1183    if opts['seed'] > -1:
1184        pars = randomize_pars(model_info, pars)
1185        if model_info != model_info2:
1186            pars2 = randomize_pars(model_info2, pars2)
1187            # Share values for parameters with the same name
1188            for k, v in pars.items():
1189                if k in pars2:
1190                    pars2[k] = v
1191        else:
1192            pars2 = pars.copy()
1193        constrain_pars(model_info, pars)
1194        constrain_pars(model_info2, pars2)
1195    if opts['mono']:
1196        pars = suppress_pd(pars)
1197        pars2 = suppress_pd(pars2)
1198    if not opts['magnetic']:
1199        pars = suppress_magnetism(pars)
1200        pars2 = suppress_magnetism(pars2)
1201
1202    # Fill in parameters given on the command line
1203    presets = {}
1204    presets2 = {}
1205    for arg in opts['values']:
1206        k, v = arg.split('=', 1)
1207        if k not in pars and k not in pars2:
1208            # extract base name without polydispersity info
1209            s = set(p.split('_pd')[0] for p in pars)
1210            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1211            return None
1212        v1, v2 = v.split(MODEL_SPLIT, 2) if MODEL_SPLIT in v else (v,v)
1213        if v1 and k in pars:
1214            presets[k] = float(v1) if isnumber(v1) else v1
1215        if v2 and k in pars2:
1216            presets2[k] = float(v2) if isnumber(v2) else v2
1217
1218    # If pd given on the command line, default pd_n to 35
1219    for k, v in list(presets.items()):
1220        if k.endswith('_pd'):
1221            presets.setdefault(k+'_n', 35.)
1222    for k, v in list(presets2.items()):
1223        if k.endswith('_pd'):
1224            presets2.setdefault(k+'_n', 35.)
1225
1226    # Evaluate preset parameter expressions
1227    context = MATH.copy()
1228    context['np'] = np
1229    context.update(pars)
1230    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1231    for k, v in presets.items():
1232        if not isinstance(v, float) and not k.endswith('_type'):
1233            presets[k] = eval(v, context)
1234    context.update(presets)
1235    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1236    for k, v in presets2.items():
1237        if not isinstance(v, float) and not k.endswith('_type'):
1238            presets2[k] = eval(v, context)
1239
1240    # update parameters with presets
1241    pars.update(presets)  # set value after random to control value
1242    pars2.update(presets2)  # set value after random to control value
1243    #import pprint; pprint.pprint(model_info)
1244
1245    if opts['show_pars']:
1246        if model_info.name != model_info2.name or pars != pars2:
1247            print("==== %s ====="%model_info.name)
1248            print(str(parlist(model_info, pars, opts['is2d'])))
1249            print("==== %s ====="%model_info2.name)
1250            print(str(parlist(model_info2, pars2, opts['is2d'])))
1251        else:
1252            print(str(parlist(model_info, pars, opts['is2d'])))
1253
1254    return pars, pars2
1255
1256def show_docs(opts):
1257    # type: (Dict[str, Any]) -> None
1258    """
1259    show html docs for the model
1260    """
1261    import os
1262    from .generate import make_html
1263    from . import rst2html
1264
1265    info = opts['def'][0]
1266    html = make_html(info)
1267    path = os.path.dirname(info.filename)
1268    url = "file://"+path.replace("\\","/")[2:]+"/"
1269    rst2html.view_html_qtapp(html, url)
1270
1271def explore(opts):
1272    # type: (Dict[str, Any]) -> None
1273    """
1274    explore the model using the bumps gui.
1275    """
1276    import wx  # type: ignore
1277    from bumps.names import FitProblem  # type: ignore
1278    from bumps.gui.app_frame import AppFrame  # type: ignore
1279    from bumps.gui import signal
1280
1281    is_mac = "cocoa" in wx.version()
1282    # Create an app if not running embedded
1283    app = wx.App() if wx.GetApp() is None else None
1284    model = Explore(opts)
1285    problem = FitProblem(model)
1286    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1287    if not is_mac:
1288        frame.Show()
1289    frame.panel.set_model(model=problem)
1290    frame.panel.Layout()
1291    frame.panel.aui.Split(0, wx.TOP)
1292    def reset_parameters(event):
1293        model.revert_values()
1294        signal.update_parameters(problem)
1295    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1296    if is_mac: frame.Show()
1297    # If running withing an app, start the main loop
1298    if app:
1299        app.MainLoop()
1300
1301class Explore(object):
1302    """
1303    Bumps wrapper for a SAS model comparison.
1304
1305    The resulting object can be used as a Bumps fit problem so that
1306    parameters can be adjusted in the GUI, with plots updated on the fly.
1307    """
1308    def __init__(self, opts):
1309        # type: (Dict[str, Any]) -> None
1310        from bumps.cli import config_matplotlib  # type: ignore
1311        from . import bumps_model
1312        config_matplotlib()
1313        self.opts = opts
1314        opts['pars'] = list(opts['pars'])
1315        p1, p2 = opts['pars']
1316        m1, m2 = opts['def']
1317        self.fix_p2 = m1 != m2 or p1 != p2
1318        model_info = m1
1319        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1320        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1321        if not opts['is2d']:
1322            for p in model_info.parameters.user_parameters({}, is2d=False):
1323                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1324                    k = p.name+ext
1325                    v = pars.get(k, None)
1326                    if v is not None:
1327                        v.range(*parameter_range(k, v.value))
1328        else:
1329            for k, v in pars.items():
1330                v.range(*parameter_range(k, v.value))
1331
1332        self.pars = pars
1333        self.starting_values = dict((k, v.value) for k, v in pars.items())
1334        self.pd_types = pd_types
1335        self.limits = np.Inf, -np.Inf
1336
1337    def revert_values(self):
1338        for k, v in self.starting_values.items():
1339            self.pars[k].value = v
1340
1341    def model_update(self):
1342        pass
1343
1344    def numpoints(self):
1345        # type: () -> int
1346        """
1347        Return the number of points.
1348        """
1349        return len(self.pars) + 1  # so dof is 1
1350
1351    def parameters(self):
1352        # type: () -> Any   # Dict/List hierarchy of parameters
1353        """
1354        Return a dictionary of parameters.
1355        """
1356        return self.pars
1357
1358    def nllf(self):
1359        # type: () -> float
1360        """
1361        Return cost.
1362        """
1363        # pylint: disable=no-self-use
1364        return 0.  # No nllf
1365
1366    def plot(self, view='log'):
1367        # type: (str) -> None
1368        """
1369        Plot the data and residuals.
1370        """
1371        pars = dict((k, v.value) for k, v in self.pars.items())
1372        pars.update(self.pd_types)
1373        self.opts['pars'][0] = pars
1374        if not self.fix_p2:
1375            self.opts['pars'][1] = pars
1376        result = run_models(self.opts)
1377        limits = plot_models(self.opts, result, limits=self.limits)
1378        if self.limits is None:
1379            vmin, vmax = limits
1380            self.limits = vmax*1e-7, 1.3*vmax
1381            import pylab; pylab.clf()
1382            plot_models(self.opts, result, limits=self.limits)
1383
1384
1385def main(*argv):
1386    # type: (*str) -> None
1387    """
1388    Main program.
1389    """
1390    opts = parse_opts(argv)
1391    if opts is not None:
1392        if opts['seed'] > -1:
1393            print("Randomize using -random=%i"%opts['seed'])
1394            np.random.seed(opts['seed'])
1395        if opts['html']:
1396            show_docs(opts)
1397        elif opts['explore']:
1398            opts['pars'] = parse_pars(opts)
1399            if opts['pars'] is None:
1400                return
1401            explore(opts)
1402        else:
1403            compare(opts)
1404
1405if __name__ == "__main__":
1406    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.