source: sasmodels/sasmodels/compare.py @ 109d963

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 109d963 was 109d963, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

set up random polydispersity; usually sets about 15% pd on a single parameter, can could add pd to 0, 1, 2, or 3 parameters

  • Property mode set to 100755
File size: 47.6 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: compare.py model N1 N2 [options...] [key=val]
59
60Compare the speed and value for a model between the SasView original and the
61sasmodels rewrite.
62
63model or model1,model2 are the names of the models to compare (see below).
64N1 is the number of times to run sasmodels (default=1).
65N2 is the number times to run sasview (default=1).
66
67Options (* for default):
68
69    -plot*/-noplot plots or suppress the plot of the model
70    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
71    -nq=128 sets the number of Q points in the data set
72    -zero indicates that q=0 should be included
73    -1d*/-2d computes 1d or 2d data
74    -preset*/-random[=seed] preset or random parameters
75    -mono*/-poly force monodisperse or allow polydisperse demo parameters
76    -magnetic/-nonmagnetic* suppress magnetism
77    -cutoff=1e-5* cutoff value for including a point in polydispersity
78    -pars/-nopars* prints the parameter set or not
79    -abs/-rel* plot relative or absolute error
80    -linear/-log*/-q4 intensity scaling
81    -hist/-nohist* plot histogram of relative error
82    -res=0 sets the resolution width dQ/Q if calculating with resolution
83    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
84    -edit starts the parameter explorer
85    -default/-demo* use demo vs default parameters
86    -help/-html shows the model docs instead of running the model
87    -title="note" adds note to the plot title, after the model name
88    -data="path" uses q, dq from the data file
89
90Any two calculation engines can be selected for comparison:
91
92    -single/-double/-half/-fast sets an OpenCL calculation engine
93    -single!/-double!/-quad! sets an OpenMP calculation engine
94    -sasview sets the sasview calculation engine
95
96The default is -single -double.  Note that the interpretation of quad
97precision depends on architecture, and may vary from 64-bit to 128-bit,
98with 80-bit floats being common (1e-19 precision).
99
100Key=value pairs allow you to set specific values for the model parameters.
101Key=value1,value2 to compare different values of the same parameter.
102value can be an expression including other parameters
103"""
104
105# Update docs with command line usage string.   This is separate from the usual
106# doc string so that we can display it at run time if there is an error.
107# lin
108__doc__ = (__doc__  # pylint: disable=redefined-builtin
109           + """
110Program description
111-------------------
112
113"""
114           + USAGE)
115
116kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
117
118# list of math functions for use in evaluating parameters
119MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
120
121# CRUFT python 2.6
122if not hasattr(datetime.timedelta, 'total_seconds'):
123    def delay(dt):
124        """Return number date-time delta as number seconds"""
125        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
126else:
127    def delay(dt):
128        """Return number date-time delta as number seconds"""
129        return dt.total_seconds()
130
131
132class push_seed(object):
133    """
134    Set the seed value for the random number generator.
135
136    When used in a with statement, the random number generator state is
137    restored after the with statement is complete.
138
139    :Parameters:
140
141    *seed* : int or array_like, optional
142        Seed for RandomState
143
144    :Example:
145
146    Seed can be used directly to set the seed::
147
148        >>> from numpy.random import randint
149        >>> push_seed(24)
150        <...push_seed object at...>
151        >>> print(randint(0,1000000,3))
152        [242082    899 211136]
153
154    Seed can also be used in a with statement, which sets the random
155    number generator state for the enclosed computations and restores
156    it to the previous state on completion::
157
158        >>> with push_seed(24):
159        ...    print(randint(0,1000000,3))
160        [242082    899 211136]
161
162    Using nested contexts, we can demonstrate that state is indeed
163    restored after the block completes::
164
165        >>> with push_seed(24):
166        ...    print(randint(0,1000000))
167        ...    with push_seed(24):
168        ...        print(randint(0,1000000,3))
169        ...    print(randint(0,1000000))
170        242082
171        [242082    899 211136]
172        899
173
174    The restore step is protected against exceptions in the block::
175
176        >>> with push_seed(24):
177        ...    print(randint(0,1000000))
178        ...    try:
179        ...        with push_seed(24):
180        ...            print(randint(0,1000000,3))
181        ...            raise Exception()
182        ...    except Exception:
183        ...        print("Exception raised")
184        ...    print(randint(0,1000000))
185        242082
186        [242082    899 211136]
187        Exception raised
188        899
189    """
190    def __init__(self, seed=None):
191        # type: (Optional[int]) -> None
192        self._state = np.random.get_state()
193        np.random.seed(seed)
194
195    def __enter__(self):
196        # type: () -> None
197        pass
198
199    def __exit__(self, exc_type, exc_value, traceback):
200        # type: (Any, BaseException, Any) -> None
201        # TODO: better typing for __exit__ method
202        np.random.set_state(self._state)
203
204def tic():
205    # type: () -> Callable[[], float]
206    """
207    Timer function.
208
209    Use "toc=tic()" to start the clock and "toc()" to measure
210    a time interval.
211    """
212    then = datetime.datetime.now()
213    return lambda: delay(datetime.datetime.now() - then)
214
215
216def set_beam_stop(data, radius, outer=None):
217    # type: (Data, float, float) -> None
218    """
219    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
220
221    Note: this function does not require sasview
222    """
223    if hasattr(data, 'qx_data'):
224        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
225        data.mask = (q < radius)
226        if outer is not None:
227            data.mask |= (q >= outer)
228    else:
229        data.mask = (data.x < radius)
230        if outer is not None:
231            data.mask |= (data.x >= outer)
232
233
234def parameter_range(p, v):
235    # type: (str, float) -> Tuple[float, float]
236    """
237    Choose a parameter range based on parameter name and initial value.
238    """
239    # process the polydispersity options
240    if p.endswith('_pd_n'):
241        return 0., 100.
242    elif p.endswith('_pd_nsigma'):
243        return 0., 5.
244    elif p.endswith('_pd_type'):
245        raise ValueError("Cannot return a range for a string value")
246    elif any(s in p for s in ('theta', 'phi', 'psi')):
247        # orientation in [-180,180], orientation pd in [0,45]
248        if p.endswith('_pd'):
249            return 0., 45.
250        else:
251            return -180., 180.
252    elif p.endswith('_pd'):
253        return 0., 1.
254    elif 'sld' in p:
255        return -0.5, 10.
256    elif p == 'background':
257        return 0., 10.
258    elif p == 'scale':
259        return 0., 1.e3
260    elif v < 0.:
261        return 2.*v, -2.*v
262    else:
263        return 0., (2.*v if v > 0. else 1.)
264
265
266def _randomize_one(model_info, name, value):
267    # type: (ModelInfo, str, float) -> float
268    # type: (ModelInfo, str, str) -> str
269    """
270    Randomize a single parameter.
271    """
272    if name.endswith('_pd'):
273        par = model_info.parameters[name[:-3]]
274        if par.type == 'orientation':
275            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
276            return 180*np.random.beta(2.5, 20)
277        else:
278            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
279            return np.random.beta(1.5, 7)
280
281    if name.endswith('_pd_n'):
282        # let pd be selected globally rather than per parameter
283        return 0
284
285    if name.endswith('_pd_type'):
286        # Don't mess with distribution type for now
287        return 'gaussian'
288
289    if name.endswith('_pd_nsigma'):
290        # type-dependent value; for gaussian use 3.
291        return 3.
292
293    if name == 'background':
294        return np.random.uniform(0, 1)
295
296    if name == 'scale':
297        return 10**np.random.uniform(-5,0)
298
299    par = model_info.parameters[name]
300    if len(par.limits) > 2:  # choice list
301        return np.random.randint(len(par.limits))
302
303    if np.isfinite(par.limits).all():
304        return np.random.uniform(*par.limits)
305
306    if par.type == 'sld':
307        # Range of neutron SLDs
308        return np.random.uniform(-0.5, 12)
309
310    if par.type == 'volume':
311        if ('length' in par.name or
312                'radius' in par.name or
313                'thick' in par.name):
314            return 10**np.random.uniform(2,4)
315
316    low, high = parameter_range(par.name, value)
317    limits = (max(par.limits[0], low), min(par.limits[1], high))
318    return np.random.uniform(*limits)
319
320def _random_pd(model_info, pars):
321    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
322    pd_volume = []
323    pd_oriented = []
324    for p in pd:
325        if p.type == 'orientation':
326            pd_oriented.append(p.name)
327        elif p.length_control is not None:
328            n = pars.get(p.length_control, 1)
329            pd_volume.extend(p.name+str(k+1) for k in range(n))
330        elif p.length > 1:
331            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
332        else:
333            pd_volume.append(p.name)
334    u = np.random.rand()
335    n = len(pd_volume)
336    if u < 0.01 or n < 1:
337        pass  # 1% chance of no polydispersity
338    elif u < 0.86 or n < 2:
339        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
340    elif u < 0.99 or n < 3:
341        choices = np.random.choice(len(pd_volume), size=2)
342        pars[pd_volume[choices[0]]+"_pd_n"] = 25
343        pars[pd_volume[choices[1]]+"_pd_n"] = 10
344    else:
345        choices = np.random.choice(len(pd_volume), size=3)
346        pars[pd_volume[choices[0]]+"_pd_n"] = 25
347        pars[pd_volume[choices[1]]+"_pd_n"] = 10
348        pars[pd_volume[choices[2]]+"_pd_n"] = 5
349    if pd_oriented:
350        pars['theta_pd_n'] = 20
351        if np.random.rand() < 0.1:
352            pars['phi_pd_n'] = 5
353        if np.random.rand() < 0.1:
354            pars['psi_pd_n'] = 5
355
356    ## Show selected polydispersity
357    #for name, value in pars.items():
358    #    if name.endswith('_pd_n') and value > 0:
359    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
360
361
362def randomize_pars(model_info, pars):
363    # type: (ModelInfo, ParameterSet) -> ParameterSet
364    """
365    Generate random values for all of the parameters.
366
367    Valid ranges for the random number generator are guessed from the name of
368    the parameter; this will not account for constraints such as cap radius
369    greater than cylinder radius in the capped_cylinder model, so
370    :func:`constrain_pars` needs to be called afterward..
371    """
372    # Note: the sort guarantees order of calls to random number generator
373    random_pars = dict((p, _randomize_one(model_info, p, v))
374                       for p, v in sorted(pars.items()))
375    if model_info.random is not None:
376        random_pars.update(model_info.random())
377    _random_pd(model_info, random_pars)
378    return random_pars
379
380
381def constrain_pars(model_info, pars):
382    # type: (ModelInfo, ParameterSet) -> None
383    """
384    Restrict parameters to valid values.
385
386    This includes model specific code for models such as capped_cylinder
387    which need to support within model constraints (cap radius more than
388    cylinder radius in this case).
389
390    Warning: this updates the *pars* dictionary in place.
391    """
392    # TODO: move the model specific code to the individual models
393    name = model_info.id
394    # if it is a product model, then just look at the form factor since
395    # none of the structure factors need any constraints.
396    if '*' in name:
397        name = name.split('*')[0]
398
399    # Suppress magnetism for python models (not yet implemented)
400    if callable(model_info.Iq):
401        pars.update(suppress_magnetism(pars))
402
403    if name == 'barbell':
404        if pars['radius_bell'] < pars['radius']:
405            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
406
407    elif name == 'capped_cylinder':
408        if pars['radius_cap'] < pars['radius']:
409            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
410
411    elif name == 'guinier':
412        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
413        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
414        # => ln A - (Rg^2 q^2/3) > -30 ln 10
415        # => Rg^2 q^2/3 < 30 ln 10 + ln A
416        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
417        #q_max = 0.2  # mid q maximum
418        q_max = 1.0  # high q maximum
419        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
420        pars['rg'] = min(pars['rg'], rg_max)
421
422    elif name == 'pearl_necklace':
423        if pars['radius'] < pars['thick_string']:
424            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
425        pass
426
427    elif name == 'rpa':
428        # Make sure phi sums to 1.0
429        if pars['case_num'] < 2:
430            pars['Phi1'] = 0.
431            pars['Phi2'] = 0.
432        elif pars['case_num'] < 5:
433            pars['Phi1'] = 0.
434        total = sum(pars['Phi'+c] for c in '1234')
435        for c in '1234':
436            pars['Phi'+c] /= total
437
438def parlist(model_info, pars, is2d):
439    # type: (ModelInfo, ParameterSet, bool) -> str
440    """
441    Format the parameter list for printing.
442    """
443    lines = []
444    parameters = model_info.parameters
445    magnetic = False
446    for p in parameters.user_parameters(pars, is2d):
447        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
448            continue
449        if p.id.startswith('up:') and not magnetic:
450            continue
451        fields = dict(
452            value=pars.get(p.id, p.default),
453            pd=pars.get(p.id+"_pd", 0.),
454            n=int(pars.get(p.id+"_pd_n", 0)),
455            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
456            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
457            relative_pd=p.relative_pd,
458            M0=pars.get('M0:'+p.id, 0.),
459            mphi=pars.get('mphi:'+p.id, 0.),
460            mtheta=pars.get('mtheta:'+p.id, 0.),
461        )
462        lines.append(_format_par(p.name, **fields))
463        magnetic = magnetic or fields['M0'] != 0.
464    return "\n".join(lines)
465
466    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
467
468def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
469                relative_pd=False, M0=0., mphi=0., mtheta=0.):
470    # type: (str, float, float, int, float, str) -> str
471    line = "%s: %g"%(name, value)
472    if pd != 0.  and n != 0:
473        if relative_pd:
474            pd *= value
475        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
476                % (pd, n, nsigma, nsigma, pdtype)
477    if M0 != 0.:
478        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
479    return line
480
481def suppress_pd(pars):
482    # type: (ParameterSet) -> ParameterSet
483    """
484    Suppress theta_pd for now until the normalization is resolved.
485
486    May also suppress complete polydispersity of the model to test
487    models more quickly.
488    """
489    pars = pars.copy()
490    for p in pars:
491        if p.endswith("_pd_n"):
492            pars[p] = 0
493    return pars
494
495def suppress_magnetism(pars):
496    # type: (ParameterSet) -> ParameterSet
497    """
498    Suppress theta_pd for now until the normalization is resolved.
499
500    May also suppress complete polydispersity of the model to test
501    models more quickly.
502    """
503    pars = pars.copy()
504    for p in pars:
505        if p.startswith("M0:"): pars[p] = 0
506    return pars
507
508def eval_sasview(model_info, data):
509    # type: (Modelinfo, Data) -> Calculator
510    """
511    Return a model calculator using the pre-4.0 SasView models.
512    """
513    # importing sas here so that the error message will be that sas failed to
514    # import rather than the more obscure smear_selection not imported error
515    import sas
516    import sas.models
517    from sas.models.qsmearing import smear_selection
518    from sas.models.MultiplicationModel import MultiplicationModel
519    from sas.models.dispersion_models import models as dispersers
520
521    def get_model_class(name):
522        # type: (str) -> "sas.models.BaseComponent"
523        #print("new",sorted(_pars.items()))
524        __import__('sas.models.' + name)
525        ModelClass = getattr(getattr(sas.models, name, None), name, None)
526        if ModelClass is None:
527            raise ValueError("could not find model %r in sas.models"%name)
528        return ModelClass
529
530    # WARNING: ugly hack when handling model!
531    # Sasview models with multiplicity need to be created with the target
532    # multiplicity, so we cannot create the target model ahead of time for
533    # for multiplicity models.  Instead we store the model in a list and
534    # update the first element of that list with the new multiplicity model
535    # every time we evaluate.
536
537    # grab the sasview model, or create it if it is a product model
538    if model_info.composition:
539        composition_type, parts = model_info.composition
540        if composition_type == 'product':
541            P, S = [get_model_class(revert_name(p))() for p in parts]
542            model = [MultiplicationModel(P, S)]
543        else:
544            raise ValueError("sasview mixture models not supported by compare")
545    else:
546        old_name = revert_name(model_info)
547        if old_name is None:
548            raise ValueError("model %r does not exist in old sasview"
549                            % model_info.id)
550        ModelClass = get_model_class(old_name)
551        model = [ModelClass()]
552    model[0].disperser_handles = {}
553
554    # build a smearer with which to call the model, if necessary
555    smearer = smear_selection(data, model=model)
556    if hasattr(data, 'qx_data'):
557        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
558        index = ((~data.mask) & (~np.isnan(data.data))
559                 & (q >= data.qmin) & (q <= data.qmax))
560        if smearer is not None:
561            smearer.model = model  # because smear_selection has a bug
562            smearer.accuracy = data.accuracy
563            smearer.set_index(index)
564            def _call_smearer():
565                smearer.model = model[0]
566                return smearer.get_value()
567            theory = _call_smearer
568        else:
569            theory = lambda: model[0].evalDistribution([data.qx_data[index],
570                                                        data.qy_data[index]])
571    elif smearer is not None:
572        theory = lambda: smearer(model[0].evalDistribution(data.x))
573    else:
574        theory = lambda: model[0].evalDistribution(data.x)
575
576    def calculator(**pars):
577        # type: (float, ...) -> np.ndarray
578        """
579        Sasview calculator for model.
580        """
581        oldpars = revert_pars(model_info, pars)
582        # For multiplicity models, create a model with the correct multiplicity
583        control = oldpars.pop("CONTROL", None)
584        if control is not None:
585            # sphericalSLD has one fewer multiplicity.  This update should
586            # happen in revert_pars, but it hasn't been called yet.
587            model[0] = ModelClass(control)
588        # paying for parameter conversion each time to keep life simple, if not fast
589        for k, v in oldpars.items():
590            if k.endswith('.type'):
591                par = k[:-5]
592                if v == 'gaussian': continue
593                cls = dispersers[v if v != 'rectangle' else 'rectangula']
594                handle = cls()
595                model[0].disperser_handles[par] = handle
596                try:
597                    model[0].set_dispersion(par, handle)
598                except Exception:
599                    exception.annotate_exception("while setting %s to %r"
600                                                 %(par, v))
601                    raise
602
603
604        #print("sasview pars",oldpars)
605        for k, v in oldpars.items():
606            name_attr = k.split('.')  # polydispersity components
607            if len(name_attr) == 2:
608                par, disp_par = name_attr
609                model[0].dispersion[par][disp_par] = v
610            else:
611                model[0].setParam(k, v)
612        return theory()
613
614    calculator.engine = "sasview"
615    return calculator
616
617DTYPE_MAP = {
618    'half': '16',
619    'fast': 'fast',
620    'single': '32',
621    'double': '64',
622    'quad': '128',
623    'f16': '16',
624    'f32': '32',
625    'f64': '64',
626    'float16': '16',
627    'float32': '32',
628    'float64': '64',
629    'float128': '128',
630    'longdouble': '128',
631}
632def eval_opencl(model_info, data, dtype='single', cutoff=0.):
633    # type: (ModelInfo, Data, str, float) -> Calculator
634    """
635    Return a model calculator using the OpenCL calculation engine.
636    """
637    if not core.HAVE_OPENCL:
638        raise RuntimeError("OpenCL not available")
639    model = core.build_model(model_info, dtype=dtype, platform="ocl")
640    calculator = DirectModel(data, model, cutoff=cutoff)
641    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
642    return calculator
643
644def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
645    # type: (ModelInfo, Data, str, float) -> Calculator
646    """
647    Return a model calculator using the DLL calculation engine.
648    """
649    model = core.build_model(model_info, dtype=dtype, platform="dll")
650    calculator = DirectModel(data, model, cutoff=cutoff)
651    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
652    return calculator
653
654def time_calculation(calculator, pars, evals=1):
655    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
656    """
657    Compute the average calculation time over N evaluations.
658
659    An additional call is generated without polydispersity in order to
660    initialize the calculation engine, and make the average more stable.
661    """
662    # initialize the code so time is more accurate
663    if evals > 1:
664        calculator(**suppress_pd(pars))
665    toc = tic()
666    # make sure there is at least one eval
667    value = calculator(**pars)
668    for _ in range(evals-1):
669        value = calculator(**pars)
670    average_time = toc()*1000. / evals
671    #print("I(q)",value)
672    return value, average_time
673
674def make_data(opts):
675    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
676    """
677    Generate an empty dataset, used with the model to set Q points
678    and resolution.
679
680    *opts* contains the options, with 'qmax', 'nq', 'res',
681    'accuracy', 'is2d' and 'view' parsed from the command line.
682    """
683    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
684    if opts['is2d']:
685        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
686        data = empty_data2D(q, resolution=res)
687        data.accuracy = opts['accuracy']
688        set_beam_stop(data, 0.0004)
689        index = ~data.mask
690    else:
691        if opts['view'] == 'log' and not opts['zero']:
692            qmax = math.log10(qmax)
693            q = np.logspace(qmax-3, qmax, nq)
694        else:
695            q = np.linspace(0.001*qmax, qmax, nq)
696        if opts['zero']:
697            q = np.hstack((0, q))
698        data = empty_data1D(q, resolution=res)
699        index = slice(None, None)
700    return data, index
701
702def make_engine(model_info, data, dtype, cutoff):
703    # type: (ModelInfo, Data, str, float) -> Calculator
704    """
705    Generate the appropriate calculation engine for the given datatype.
706
707    Datatypes with '!' appended are evaluated using external C DLLs rather
708    than OpenCL.
709    """
710    if dtype == 'sasview':
711        return eval_sasview(model_info, data)
712    elif dtype.endswith('!'):
713        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
714    else:
715        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
716
717def _show_invalid(data, theory):
718    # type: (Data, np.ma.ndarray) -> None
719    """
720    Display a list of the non-finite values in theory.
721    """
722    if not theory.mask.any():
723        return
724
725    if hasattr(data, 'x'):
726        bad = zip(data.x[theory.mask], theory[theory.mask])
727        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
728
729
730def compare(opts, limits=None):
731    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
732    """
733    Preform a comparison using options from the command line.
734
735    *limits* are the limits on the values to use, either to set the y-axis
736    for 1D or to set the colormap scale for 2D.  If None, then they are
737    inferred from the data and returned. When exploring using Bumps,
738    the limits are set when the model is initially called, and maintained
739    as the values are adjusted, making it easier to see the effects of the
740    parameters.
741    """
742    limits = np.Inf, -np.Inf
743    for k in range(opts['sets']):
744        opts['pars'] = parse_pars(opts)
745        result = run_models(opts, verbose=True)
746        if opts['plot']:
747            limits = plot_models(opts, result, limits=limits, setnum=k)
748    if opts['plot']:
749        import matplotlib.pyplot as plt
750        plt.show()
751
752def run_models(opts, verbose=False):
753    # type: (Dict[str, Any]) -> Dict[str, Any]
754
755    n_base, n_comp = opts['count']
756    pars, pars2 = opts['pars']
757    data = opts['data']
758
759    # silence the linter
760    base = opts['engines'][0] if n_base else None
761    comp = opts['engines'][1] if n_comp else None
762
763    base_time = comp_time = None
764    base_value = comp_value = resid = relerr = None
765
766    # Base calculation
767    if n_base > 0:
768        try:
769            base_raw, base_time = time_calculation(base, pars, n_base)
770            base_value = np.ma.masked_invalid(base_raw)
771            if verbose:
772                print("%s t=%.2f ms, intensity=%.0f"
773                      % (base.engine, base_time, base_value.sum()))
774            _show_invalid(data, base_value)
775        except ImportError:
776            traceback.print_exc()
777            n_base = 0
778
779    # Comparison calculation
780    if n_comp > 0:
781        try:
782            comp_raw, comp_time = time_calculation(comp, pars2, n_comp)
783            comp_value = np.ma.masked_invalid(comp_raw)
784            if verbose:
785                print("%s t=%.2f ms, intensity=%.0f"
786                      % (comp.engine, comp_time, comp_value.sum()))
787            _show_invalid(data, comp_value)
788        except ImportError:
789            traceback.print_exc()
790            n_comp = 0
791
792    # Compare, but only if computing both forms
793    if n_base > 0 and n_comp > 0:
794        resid = (base_value - comp_value)
795        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
796        if verbose:
797            _print_stats("|%s-%s|"
798                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
799                         resid)
800            _print_stats("|(%s-%s)/%s|"
801                         % (base.engine, comp.engine, comp.engine),
802                         relerr)
803
804    return dict(base_value=base_value, comp_value=comp_value,
805                base_time=base_time, comp_time=comp_time,
806                resid=resid, relerr=relerr)
807
808
809def _print_stats(label, err):
810    # type: (str, np.ma.ndarray) -> None
811    # work with trimmed data, not the full set
812    sorted_err = np.sort(abs(err.compressed()))
813    if len(sorted_err) == 0.:
814        print(label + "  no valid values")
815        return
816
817    p50 = int((len(sorted_err)-1)*0.50)
818    p98 = int((len(sorted_err)-1)*0.98)
819    data = [
820        "max:%.3e"%sorted_err[-1],
821        "median:%.3e"%sorted_err[p50],
822        "98%%:%.3e"%sorted_err[p98],
823        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
824        "zero-offset:%+.3e"%np.mean(sorted_err),
825        ]
826    print(label+"  "+"  ".join(data))
827
828
829def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
830    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
831    base_value, comp_value= result['base_value'], result['comp_value']
832    base_time, comp_time = result['base_time'], result['comp_time']
833    resid, relerr = result['resid'], result['relerr']
834
835    have_base, have_comp = (base_value is not None), (comp_value is not None)
836    base = opts['engines'][0] if have_base else None
837    comp = opts['engines'][1] if have_comp else None
838    data = opts['data']
839    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
840
841    # Plot if requested
842    view = opts['view']
843    import matplotlib.pyplot as plt
844    vmin, vmax = limits
845    if have_base:
846        vmin = min(vmin, base_value.min())
847        vmax = max(vmax, base_value.max())
848    if have_comp:
849        vmin = min(vmin, comp_value.min())
850        vmax = max(vmax, comp_value.max())
851    limits = vmin, vmax
852
853    if have_base:
854        if have_comp: plt.subplot(131)
855        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
856        plt.title("%s t=%.2f ms"%(base.engine, base_time))
857        #cbar_title = "log I"
858    if have_comp:
859        if have_base: plt.subplot(132)
860        if not opts['is2d'] and have_base:
861            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
862        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
863        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
864        #cbar_title = "log I"
865    if have_base and have_comp:
866        plt.subplot(133)
867        if not opts['rel_err']:
868            err, errstr, errview = resid, "abs err", "linear"
869        else:
870            err, errstr, errview = abs(relerr), "rel err", "log"
871        if 0:  # 95% cutoff
872            sorted = np.sort(err.flatten())
873            cutoff = sorted[int(sorted.size*0.95)]
874            err[err>cutoff] = cutoff
875        #err,errstr = base/comp,"ratio"
876        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
877        if view == 'linear':
878            plt.xscale('linear')
879        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
880        #cbar_title = errstr if errview=="linear" else "log "+errstr
881    #if is2D:
882    #    h = plt.colorbar()
883    #    h.ax.set_title(cbar_title)
884    fig = plt.gcf()
885    extra_title = ' '+opts['title'] if opts['title'] else ''
886    fig.suptitle(":".join(opts['name']) + extra_title)
887
888    if have_base and have_comp and opts['show_hist']:
889        plt.figure()
890        v = relerr
891        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
892        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
893        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
894                   % (base.engine, comp.engine, comp.engine))
895        plt.ylabel('P(err)')
896        plt.title('Distribution of relative error between calculation engines')
897
898    return limits
899
900
901# ===========================================================================
902#
903NAME_OPTIONS = set([
904    'plot', 'noplot',
905    'half', 'fast', 'single', 'double',
906    'single!', 'double!', 'quad!', 'sasview',
907    'lowq', 'midq', 'highq', 'exq', 'zero',
908    '2d', '1d',
909    'preset', 'random',
910    'poly', 'mono',
911    'magnetic', 'nonmagnetic',
912    'nopars', 'pars',
913    'rel', 'abs',
914    'linear', 'log', 'q4',
915    'hist', 'nohist',
916    'edit', 'html', 'help',
917    'demo', 'default',
918    ])
919VALUE_OPTIONS = [
920    # Note: random is both a name option and a value option
921    'cutoff', 'random', 'nq', 'res', 'accuracy', 'title', 'data', 'sets'
922    ]
923
924def columnize(items, indent="", width=79):
925    # type: (List[str], str, int) -> str
926    """
927    Format a list of strings into columns.
928
929    Returns a string with carriage returns ready for printing.
930    """
931    column_width = max(len(w) for w in items) + 1
932    num_columns = (width - len(indent)) // column_width
933    num_rows = len(items) // num_columns
934    items = items + [""] * (num_rows * num_columns - len(items))
935    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
936    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
937             for row in zip(*columns)]
938    output = indent + ("\n"+indent).join(lines)
939    return output
940
941
942def get_pars(model_info, use_demo=False):
943    # type: (ModelInfo, bool) -> ParameterSet
944    """
945    Extract demo parameters from the model definition.
946    """
947    # Get the default values for the parameters
948    pars = {}
949    for p in model_info.parameters.call_parameters:
950        parts = [('', p.default)]
951        if p.polydisperse:
952            parts.append(('_pd', 0.0))
953            parts.append(('_pd_n', 0))
954            parts.append(('_pd_nsigma', 3.0))
955            parts.append(('_pd_type', "gaussian"))
956        for ext, val in parts:
957            if p.length > 1:
958                dict(("%s%d%s" % (p.id, k, ext), val)
959                     for k in range(1, p.length+1))
960            else:
961                pars[p.id + ext] = val
962
963    # Plug in values given in demo
964    if use_demo:
965        pars.update(model_info.demo)
966    return pars
967
968INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
969def isnumber(str):
970    match = FLOAT_RE.match(str)
971    isfloat = (match and not str[match.end():])
972    return isfloat or INTEGER_RE.match(str)
973
974# For distinguishing pairs of models for comparison
975# key-value pair separator =
976# shell characters  | & ; <> $ % ' " \ # `
977# model and parameter names _
978# parameter expressions - + * / . ( )
979# path characters including tilde expansion and windows drive ~ / :
980# not sure about brackets [] {}
981# maybe one of the following @ ? ^ ! ,
982MODEL_SPLIT = ','
983def parse_opts(argv):
984    # type: (List[str]) -> Dict[str, Any]
985    """
986    Parse command line options.
987    """
988    MODELS = core.list_models()
989    flags = [arg for arg in argv
990             if arg.startswith('-')]
991    values = [arg for arg in argv
992              if not arg.startswith('-') and '=' in arg]
993    positional_args = [arg for arg in argv
994                       if not arg.startswith('-') and '=' not in arg]
995    models = "\n    ".join("%-15s"%v for v in MODELS)
996    if len(positional_args) == 0:
997        print(USAGE)
998        print("\nAvailable models:")
999        print(columnize(MODELS, indent="  "))
1000        return None
1001    if len(positional_args) > 3:
1002        print("expected parameters: model N1 N2")
1003
1004    invalid = [o[1:] for o in flags
1005               if o[1:] not in NAME_OPTIONS
1006               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1007    if invalid:
1008        print("Invalid options: %s"%(", ".join(invalid)))
1009        return None
1010
1011    name = positional_args[0]
1012    n1 = int(positional_args[1]) if len(positional_args) > 1 else 1
1013    n2 = int(positional_args[2]) if len(positional_args) > 2 else 1
1014
1015    # pylint: disable=bad-whitespace
1016    # Interpret the flags
1017    opts = {
1018        'plot'      : True,
1019        'view'      : 'log',
1020        'is2d'      : False,
1021        'qmax'      : 0.05,
1022        'nq'        : 128,
1023        'res'       : 0.0,
1024        'accuracy'  : 'Low',
1025        'cutoff'    : 0.0,
1026        'seed'      : -1,  # default to preset
1027        'mono'      : True,
1028        # Default to magnetic a magnetic moment is set on the command line
1029        'magnetic'  : False,
1030        'show_pars' : False,
1031        'show_hist' : False,
1032        'rel_err'   : True,
1033        'explore'   : False,
1034        'use_demo'  : True,
1035        'zero'      : False,
1036        'html'      : False,
1037        'title'     : None,
1038        'datafile'  : None,
1039        'sets'      : 1,
1040    }
1041    engines = []
1042    for arg in flags:
1043        if arg == '-noplot':    opts['plot'] = False
1044        elif arg == '-plot':    opts['plot'] = True
1045        elif arg == '-linear':  opts['view'] = 'linear'
1046        elif arg == '-log':     opts['view'] = 'log'
1047        elif arg == '-q4':      opts['view'] = 'q4'
1048        elif arg == '-1d':      opts['is2d'] = False
1049        elif arg == '-2d':      opts['is2d'] = True
1050        elif arg == '-exq':     opts['qmax'] = 10.0
1051        elif arg == '-highq':   opts['qmax'] = 1.0
1052        elif arg == '-midq':    opts['qmax'] = 0.2
1053        elif arg == '-lowq':    opts['qmax'] = 0.05
1054        elif arg == '-zero':    opts['zero'] = True
1055        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1056        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1057        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1058        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1059        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
1060        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1061        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1062        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1063        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1064        elif arg == '-preset':  opts['seed'] = -1
1065        elif arg == '-mono':    opts['mono'] = True
1066        elif arg == '-poly':    opts['mono'] = False
1067        elif arg == '-magnetic':       opts['magnetic'] = True
1068        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1069        elif arg == '-pars':    opts['show_pars'] = True
1070        elif arg == '-nopars':  opts['show_pars'] = False
1071        elif arg == '-hist':    opts['show_hist'] = True
1072        elif arg == '-nohist':  opts['show_hist'] = False
1073        elif arg == '-rel':     opts['rel_err'] = True
1074        elif arg == '-abs':     opts['rel_err'] = False
1075        elif arg == '-half':    engines.append(arg[1:])
1076        elif arg == '-fast':    engines.append(arg[1:])
1077        elif arg == '-single':  engines.append(arg[1:])
1078        elif arg == '-double':  engines.append(arg[1:])
1079        elif arg == '-single!': engines.append(arg[1:])
1080        elif arg == '-double!': engines.append(arg[1:])
1081        elif arg == '-quad!':   engines.append(arg[1:])
1082        elif arg == '-sasview': engines.append(arg[1:])
1083        elif arg == '-edit':    opts['explore'] = True
1084        elif arg == '-demo':    opts['use_demo'] = True
1085        elif arg == '-default':    opts['use_demo'] = False
1086        elif arg == '-html':    opts['html'] = True
1087        elif arg == '-help':    opts['html'] = True
1088    # pylint: enable=bad-whitespace
1089
1090    # Force random if more than one set
1091    if opts['sets'] > 1 and opts['seed'] < 0:
1092        opts['seed'] = np.random.randint(1000000)
1093
1094    if MODEL_SPLIT in name:
1095        name, name2 = name.split(MODEL_SPLIT, 2)
1096    else:
1097        name2 = name
1098    try:
1099        model_info = core.load_model_info(name)
1100        model_info2 = core.load_model_info(name2) if name2 != name else model_info
1101    except ImportError as exc:
1102        print(str(exc))
1103        print("Could not find model; use one of:\n    " + models)
1104        return None
1105
1106    # TODO: check if presets are different when deciding if models are same
1107    same_model = name == name2
1108    if len(engines) == 0:
1109        if same_model:
1110            engines.extend(['single', 'double'])
1111        else:
1112            engines.extend(['single', 'single'])
1113    elif len(engines) == 1:
1114        if not same_model:
1115            engines.append(engines[0])
1116        elif engines[0] == 'double':
1117            engines.append('single')
1118        else:
1119            engines.append('double')
1120    elif len(engines) > 2:
1121        del engines[2:]
1122
1123    # Create the computational engines
1124    if opts['datafile'] is not None:
1125        data = load_data(os.path.expanduser(opts['datafile']))
1126    else:
1127        data, _ = make_data(opts)
1128    if n1:
1129        base = make_engine(model_info, data, engines[0], opts['cutoff'])
1130    else:
1131        base = None
1132    if n2:
1133        comp = make_engine(model_info2, data, engines[1], opts['cutoff'])
1134    else:
1135        comp = None
1136
1137    # pylint: disable=bad-whitespace
1138    # Remember it all
1139    opts.update({
1140        'data'      : data,
1141        'name'      : [name, name2],
1142        'def'       : [model_info, model_info2],
1143        'count'     : [n1, n2],
1144        'engines'   : [base, comp],
1145        'values'    : values,
1146    })
1147    # pylint: enable=bad-whitespace
1148
1149    return opts
1150
1151def parse_pars(opts):
1152    model_info, model_info2 = opts['def']
1153
1154    # Get demo parameters from model definition, or use default parameters
1155    # if model does not define demo parameters
1156    pars = get_pars(model_info, opts['use_demo'])
1157    pars2 = get_pars(model_info2, opts['use_demo'])
1158    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1159    # randomize parameters
1160    #pars.update(set_pars)  # set value before random to control range
1161    if opts['seed'] > -1:
1162        pars = randomize_pars(model_info, pars)
1163        if model_info != model_info2:
1164            pars2 = randomize_pars(model_info2, pars2)
1165            # Share values for parameters with the same name
1166            for k, v in pars.items():
1167                if k in pars2:
1168                    pars2[k] = v
1169        else:
1170            pars2 = pars.copy()
1171        constrain_pars(model_info, pars)
1172        constrain_pars(model_info2, pars2)
1173    if opts['mono']:
1174        pars = suppress_pd(pars)
1175        pars2 = suppress_pd(pars2)
1176    if not opts['magnetic']:
1177        pars = suppress_magnetism(pars)
1178        pars2 = suppress_magnetism(pars2)
1179
1180    # Fill in parameters given on the command line
1181    presets = {}
1182    presets2 = {}
1183    for arg in opts['values']:
1184        k, v = arg.split('=', 1)
1185        if k not in pars and k not in pars2:
1186            # extract base name without polydispersity info
1187            s = set(p.split('_pd')[0] for p in pars)
1188            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1189            return None
1190        v1, v2 = v.split(MODEL_SPLIT, 2) if MODEL_SPLIT in v else (v,v)
1191        if v1 and k in pars:
1192            presets[k] = float(v1) if isnumber(v1) else v1
1193        if v2 and k in pars2:
1194            presets2[k] = float(v2) if isnumber(v2) else v2
1195
1196    # If pd given on the command line, default pd_n to 35
1197    for k, v in list(presets.items()):
1198        if k.endswith('_pd'):
1199            presets.setdefault(k+'_n', 35.)
1200    for k, v in list(presets2.items()):
1201        if k.endswith('_pd'):
1202            presets2.setdefault(k+'_n', 35.)
1203
1204    # Evaluate preset parameter expressions
1205    context = MATH.copy()
1206    context['np'] = np
1207    context.update(pars)
1208    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1209    for k, v in presets.items():
1210        if not isinstance(v, float) and not k.endswith('_type'):
1211            presets[k] = eval(v, context)
1212    context.update(presets)
1213    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1214    for k, v in presets2.items():
1215        if not isinstance(v, float) and not k.endswith('_type'):
1216            presets2[k] = eval(v, context)
1217
1218    # update parameters with presets
1219    pars.update(presets)  # set value after random to control value
1220    pars2.update(presets2)  # set value after random to control value
1221    #import pprint; pprint.pprint(model_info)
1222
1223    if opts['show_pars']:
1224        if model_info.name != model_info2.name or pars != pars2:
1225            print("==== %s ====="%model_info.name)
1226            print(str(parlist(model_info, pars, opts['is2d'])))
1227            print("==== %s ====="%model_info2.name)
1228            print(str(parlist(model_info2, pars2, opts['is2d'])))
1229        else:
1230            print(str(parlist(model_info, pars, opts['is2d'])))
1231
1232    return pars, pars2
1233
1234def show_docs(opts):
1235    # type: (Dict[str, Any]) -> None
1236    """
1237    show html docs for the model
1238    """
1239    import os
1240    from .generate import make_html
1241    from . import rst2html
1242
1243    info = opts['def'][0]
1244    html = make_html(info)
1245    path = os.path.dirname(info.filename)
1246    url = "file://"+path.replace("\\","/")[2:]+"/"
1247    rst2html.view_html_qtapp(html, url)
1248
1249def explore(opts):
1250    # type: (Dict[str, Any]) -> None
1251    """
1252    explore the model using the bumps gui.
1253    """
1254    import wx  # type: ignore
1255    from bumps.names import FitProblem  # type: ignore
1256    from bumps.gui.app_frame import AppFrame  # type: ignore
1257    from bumps.gui import signal
1258
1259    is_mac = "cocoa" in wx.version()
1260    # Create an app if not running embedded
1261    app = wx.App() if wx.GetApp() is None else None
1262    model = Explore(opts)
1263    problem = FitProblem(model)
1264    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1265    if not is_mac:
1266        frame.Show()
1267    frame.panel.set_model(model=problem)
1268    frame.panel.Layout()
1269    frame.panel.aui.Split(0, wx.TOP)
1270    def reset_parameters(event):
1271        model.revert_values()
1272        signal.update_parameters(problem)
1273    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1274    if is_mac: frame.Show()
1275    # If running withing an app, start the main loop
1276    if app:
1277        app.MainLoop()
1278
1279class Explore(object):
1280    """
1281    Bumps wrapper for a SAS model comparison.
1282
1283    The resulting object can be used as a Bumps fit problem so that
1284    parameters can be adjusted in the GUI, with plots updated on the fly.
1285    """
1286    def __init__(self, opts):
1287        # type: (Dict[str, Any]) -> None
1288        from bumps.cli import config_matplotlib  # type: ignore
1289        from . import bumps_model
1290        config_matplotlib()
1291        self.opts = opts
1292        opts['pars'] = list(opts['pars'])
1293        p1, p2 = opts['pars']
1294        m1, m2 = opts['def']
1295        self.fix_p2 = m1 != m2 or p1 != p2
1296        model_info = m1
1297        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1298        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1299        if not opts['is2d']:
1300            for p in model_info.parameters.user_parameters({}, is2d=False):
1301                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1302                    k = p.name+ext
1303                    v = pars.get(k, None)
1304                    if v is not None:
1305                        v.range(*parameter_range(k, v.value))
1306        else:
1307            for k, v in pars.items():
1308                v.range(*parameter_range(k, v.value))
1309
1310        self.pars = pars
1311        self.starting_values = dict((k, v.value) for k, v in pars.items())
1312        self.pd_types = pd_types
1313        self.limits = np.Inf, -np.Inf
1314
1315    def revert_values(self):
1316        for k, v in self.starting_values.items():
1317            self.pars[k].value = v
1318
1319    def model_update(self):
1320        pass
1321
1322    def numpoints(self):
1323        # type: () -> int
1324        """
1325        Return the number of points.
1326        """
1327        return len(self.pars) + 1  # so dof is 1
1328
1329    def parameters(self):
1330        # type: () -> Any   # Dict/List hierarchy of parameters
1331        """
1332        Return a dictionary of parameters.
1333        """
1334        return self.pars
1335
1336    def nllf(self):
1337        # type: () -> float
1338        """
1339        Return cost.
1340        """
1341        # pylint: disable=no-self-use
1342        return 0.  # No nllf
1343
1344    def plot(self, view='log'):
1345        # type: (str) -> None
1346        """
1347        Plot the data and residuals.
1348        """
1349        pars = dict((k, v.value) for k, v in self.pars.items())
1350        pars.update(self.pd_types)
1351        self.opts['pars'][0] = pars
1352        if not self.fix_p2:
1353            self.opts['pars'][1] = pars
1354        result = run_models(self.opts)
1355        limits = plot_models(self.opts, result, limits=self.limits)
1356        if self.limits is None:
1357            vmin, vmax = limits
1358            self.limits = vmax*1e-7, 1.3*vmax
1359            import pylab; pylab.clf()
1360            plot_models(self.opts, result, limits=self.limits)
1361
1362
1363def main(*argv):
1364    # type: (*str) -> None
1365    """
1366    Main program.
1367    """
1368    opts = parse_opts(argv)
1369    if opts is not None:
1370        if opts['seed'] > -1:
1371            print("Randomize using -random=%i"%opts['seed'])
1372            np.random.seed(opts['seed'])
1373        if opts['html']:
1374            show_docs(opts)
1375        elif opts['explore']:
1376            opts['pars'] = parse_pars(opts)
1377            explore(opts)
1378        else:
1379            compare(opts)
1380
1381if __name__ == "__main__":
1382    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.