source: sasmodels/sasmodels/compare.py @ bb39b4a

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since bb39b4a was bb39b4a, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

make the sascomp command line more regular

  • Property mode set to 100755
File size: 50.5 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
trig, etc., and may not be completely IEEE 754 compliant.  This lets you
make sure that the model calculations are stable, or determine whether you
need to tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: sascomp model [options...] [key=val]
59
60Generate and compare SAS models.  If a single model is specified it shows
61a plot of that model.  Different models can be compared, or the same model
62with different parameters.  The same model with the same parameters can
63be compared with different calculation engines to see the effects of precision
64on the resultant values.
65
66model or model1,model2 are the names of the models to compare (see below).
67
68Options (* for default):
69
70    === data generation ===
71    -data="path" uses q, dq from the data file
72    -noise=0 sets the measurement error dI/I
73    -res=0 sets the resolution width dQ/Q if calculating with resolution
74    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
75    -nq=128 sets the number of Q points in the data set
76    -1d*/-2d computes 1d or 2d data
77    -zero indicates that q=0 should be included
78
79    === model parameters ===
80    -preset*/-random[=seed] preset or random parameters
81    -sets=n generates n random datasets, with the seed given by -random=seed
82    -pars/-nopars* prints the parameter set or not
83    -default/-demo* use demo vs default parameters
84
85    === calculation options ===
86    -mono*/-poly force monodisperse or allow polydisperse demo parameters
87    -cutoff=1e-5* cutoff value for including a point in polydispersity
88    -magnetic/-nonmagnetic* suppress magnetism
89    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
90
91    === precision options ===
92    -calc=default uses the default calcution precision
93    -single/-double/-half/-fast sets an OpenCL calculation engine
94    -single!/-double!/-quad! sets an OpenMP calculation engine
95    -sasview sets the sasview calculation engine
96
97    === plotting ===
98    -plot*/-noplot plots or suppress the plot of the model
99    -linear/-log*/-q4 intensity scaling on plots
100    -hist/-nohist* plot histogram of relative error
101    -abs/-rel* plot relative or absolute error
102    -title="note" adds note to the plot title, after the model name
103
104    === output options ===
105    -edit starts the parameter explorer
106    -help/-html shows the model docs instead of running the model
107
108The interpretation of quad precision depends on architecture, and may
109vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
110On unix and mac you may need single quotes around the DLL computation
111engines, such as -calc='single!,double!' since !, is treated as a history
112expansion request in the shell.
113
114Key=value pairs allow you to set specific values for the model parameters.
115Key=value1,value2 to compare different values of the same parameter. The
116value can be an expression including other parameters.
117
118Items later on the command line override those that appear earlier.
119
120Examples:
121
122    # compare single and double precision calculation for a barbell
123    sascomp barbell -calc=single,double
124
125    # generate 10 random lorentz models, with seed=27
126    sascomp lorentz -sets=10 -seed=27
127
128    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
129    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
130
131    # model timing test requires multiple evals to perform the estimate
132    sascomp pringle -calc=single,double -timing=100,100 -noplot
133"""
134
# Append the command line usage to the module docstring.  USAGE is kept as a
# separate string so that it can also be displayed at run time on error.
__doc__ = "".join((  # pylint: disable=redefined-builtin
    __doc__,
    "\nProgram description\n-------------------\n\n",
    USAGE,
))

# Let the DLL engine build single precision kernels as well, presumably so
# that the "-single!" engine listed in USAGE can be used for comparisons.
kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
146
# Every public name in the math module (functions and constants), made
# available when evaluating parameter expressions.
MATH = dict((name, getattr(math, name))
            for name in dir(math)
            if name[:1] != '_')
149
# CRUFT python 2.6: timedelta.total_seconds() is not available there.
if hasattr(datetime.timedelta, 'total_seconds'):
    def delay(dt):
        """Return the length of the timedelta *dt* in seconds."""
        return dt.total_seconds()
else:
    def delay(dt):
        """Return the length of the timedelta *dt* in seconds."""
        whole_seconds = dt.days * 86400 + dt.seconds
        return whole_seconds + 1e-6 * dt.microseconds
159
160
class push_seed(object):
    """
    Seed numpy's global random number generator, remembering its old state.

    The current generator state is captured and the new seed is applied as
    soon as the object is created.  When used in a with statement, the
    captured state is put back when the block exits, whether it exits
    normally or by raising an exception (which still propagates).

    :Parameters:

    *seed* : int or array_like, optional
        Seed for RandomState

    :Example:

    Seed can be used directly to set the seed::

        >>> from numpy.random import randint
        >>> push_seed(24)
        <...push_seed object at...>
        >>> print(randint(0,1000000,3))
        [242082    899 211136]

    Seed can also be used in a with statement, which sets the random
    number generator state for the enclosed computations and restores
    it to the previous state on completion::

        >>> with push_seed(24):
        ...    print(randint(0,1000000,3))
        [242082    899 211136]

    Using nested contexts, we can demonstrate that state is indeed
    restored after the block completes::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    with push_seed(24):
        ...        print(randint(0,1000000,3))
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        899

    The restore step is protected against exceptions in the block::

        >>> with push_seed(24):
        ...    print(randint(0,1000000))
        ...    try:
        ...        with push_seed(24):
        ...            print(randint(0,1000000,3))
        ...            raise Exception()
        ...    except Exception:
        ...        print("Exception raised")
        ...    print(randint(0,1000000))
        242082
        [242082    899 211136]
        Exception raised
        899
    """
    def __init__(self, seed=None):
        # type: (Optional[int]) -> None
        # Snapshot first, then reseed, so __exit__ can undo the reseeding.
        saved_state = np.random.get_state()
        self._state = saved_state
        np.random.seed(seed)

    def __enter__(self):
        # type: () -> None
        # The seed was already applied in __init__; nothing to bind.
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Any, BaseException, Any) -> None
        # TODO: better typing for __exit__ method
        # Implicitly returns None, so exceptions from the block propagate.
        np.random.set_state(self._state)
222        np.random.seed(seed)
223
224    def __enter__(self):
225        # type: () -> None
226        pass
227
228    def __exit__(self, exc_type, exc_value, traceback):
229        # type: (Any, BaseException, Any) -> None
230        # TODO: better typing for __exit__ method
231        np.random.set_state(self._state)
232
233def tic():
234    # type: () -> Callable[[], float]
235    """
236    Timer function.
237
238    Use "toc=tic()" to start the clock and "toc()" to measure
239    a time interval.
240    """
241    then = datetime.datetime.now()
242    return lambda: delay(datetime.datetime.now() - then)
243
244
245def set_beam_stop(data, radius, outer=None):
246    # type: (Data, float, float) -> None
247    """
248    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
249
250    Note: this function does not require sasview
251    """
252    if hasattr(data, 'qx_data'):
253        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
254        data.mask = (q < radius)
255        if outer is not None:
256            data.mask |= (q >= outer)
257    else:
258        data.mask = (data.x < radius)
259        if outer is not None:
260            data.mask |= (data.x >= outer)
261
262
263def parameter_range(p, v):
264    # type: (str, float) -> Tuple[float, float]
265    """
266    Choose a parameter range based on parameter name and initial value.
267    """
268    # process the polydispersity options
269    if p.endswith('_pd_n'):
270        return 0., 100.
271    elif p.endswith('_pd_nsigma'):
272        return 0., 5.
273    elif p.endswith('_pd_type'):
274        raise ValueError("Cannot return a range for a string value")
275    elif any(s in p for s in ('theta', 'phi', 'psi')):
276        # orientation in [-180,180], orientation pd in [0,45]
277        if p.endswith('_pd'):
278            return 0., 45.
279        else:
280            return -180., 180.
281    elif p.endswith('_pd'):
282        return 0., 1.
283    elif 'sld' in p:
284        return -0.5, 10.
285    elif p == 'background':
286        return 0., 10.
287    elif p == 'scale':
288        return 0., 1.e3
289    elif v < 0.:
290        return 2.*v, -2.*v
291    else:
292        return 0., (2.*v if v > 0. else 1.)
293
294
295def _randomize_one(model_info, name, value):
296    # type: (ModelInfo, str, float) -> float
297    # type: (ModelInfo, str, str) -> str
298    """
299    Randomize a single parameter.
300    """
301    # Set the amount of polydispersity/angular dispersion, but by default pd_n
302    # is zero so there is no polydispersity.  This allows us to turn on/off
303    # pd by setting pd_n, and still have randomly generated values
304    if name.endswith('_pd'):
305        par = model_info.parameters[name[:-3]]
306        if par.type == 'orientation':
307            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
308            return 180*np.random.beta(2.5, 20)
309        else:
310            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
311            return np.random.beta(1.5, 7)
312
313    # pd is selected globally rather than per parameter, so set to 0 for no pd
314    # In particular, when multiple pd dimensions, want to decrease the number
315    # of points per dimension for faster computation
316    if name.endswith('_pd_n'):
317        return 0
318
319    # Don't mess with distribution type for now
320    if name.endswith('_pd_type'):
321        return 'gaussian'
322
323    # type-dependent value of number of sigmas; for gaussian use 3.
324    if name.endswith('_pd_nsigma'):
325        return 3.
326
327    # background in the range [0.01, 1]
328    if name == 'background':
329        return 10**np.random.uniform(-2, 0)
330
331    # scale defaults to 0.1% to 30% volume fraction
332    if name == 'scale':
333        return 10**np.random.uniform(-3, -0.5)
334
335    # If it is a list of choices, pick one at random with equal probability
336    # In practice, the model specific random generator will override.
337    par = model_info.parameters[name]
338    if len(par.limits) > 2:  # choice list
339        return np.random.randint(len(par.limits))
340
341    # If it is a fixed range, pick from it with equal probability.
342    # For logarithmic ranges, the model will have to override.
343    if np.isfinite(par.limits).all():
344        return np.random.uniform(*par.limits)
345
346    # If the paramter is marked as an sld use the range of neutron slds
347    # Should be doing something with randomly selected contrast matching
348    # but for now use random contrasts.  Since real data is contrast-matched,
349    # this will favour selection of hollow models when they exist.
350    if par.type == 'sld':
351        # Range of neutron SLDs
352        return np.random.uniform(-0.5, 12)
353
354    # Guess at the random length/radius/thickness.  In practice, all models
355    # are going to set their own reasonable ranges.
356    if par.type == 'volume':
357        if ('length' in par.name or
358                'radius' in par.name or
359                'thick' in par.name):
360            return 10**np.random.uniform(2, 4)
361
362    # In the absence of any other info, select a value in [0, 2v], or
363    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
364    # model random parameter generators will override this default.
365    low, high = parameter_range(par.name, value)
366    limits = (max(par.limits[0], low), min(par.limits[1], high))
367    return np.random.uniform(*limits)
368
369def _random_pd(model_info, pars):
370    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
371    pd_volume = []
372    pd_oriented = []
373    for p in pd:
374        if p.type == 'orientation':
375            pd_oriented.append(p.name)
376        elif p.length_control is not None:
377            n = int(pars.get(p.length_control, 1) + 0.5)
378            pd_volume.extend(p.name+str(k+1) for k in range(n))
379        elif p.length > 1:
380            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
381        else:
382            pd_volume.append(p.name)
383    u = np.random.rand()
384    n = len(pd_volume)
385    if u < 0.01 or n < 1:
386        pass  # 1% chance of no polydispersity
387    elif u < 0.86 or n < 2:
388        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
389    elif u < 0.99 or n < 3:
390        choices = np.random.choice(len(pd_volume), size=2)
391        pars[pd_volume[choices[0]]+"_pd_n"] = 25
392        pars[pd_volume[choices[1]]+"_pd_n"] = 10
393    else:
394        choices = np.random.choice(len(pd_volume), size=3)
395        pars[pd_volume[choices[0]]+"_pd_n"] = 25
396        pars[pd_volume[choices[1]]+"_pd_n"] = 10
397        pars[pd_volume[choices[2]]+"_pd_n"] = 5
398    if pd_oriented:
399        pars['theta_pd_n'] = 20
400        if np.random.rand() < 0.1:
401            pars['phi_pd_n'] = 5
402        if np.random.rand() < 0.1:
403            pars['psi_pd_n'] = 5
404
405    ## Show selected polydispersity
406    #for name, value in pars.items():
407    #    if name.endswith('_pd_n') and value > 0:
408    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
409
410
411def randomize_pars(model_info, pars):
412    # type: (ModelInfo, ParameterSet) -> ParameterSet
413    """
414    Generate random values for all of the parameters.
415
416    Valid ranges for the random number generator are guessed from the name of
417    the parameter; this will not account for constraints such as cap radius
418    greater than cylinder radius in the capped_cylinder model, so
419    :func:`constrain_pars` needs to be called afterward..
420    """
421    # Note: the sort guarantees order of calls to random number generator
422    random_pars = dict((p, _randomize_one(model_info, p, v))
423                       for p, v in sorted(pars.items()))
424    if model_info.random is not None:
425        random_pars.update(model_info.random())
426    _random_pd(model_info, random_pars)
427    return random_pars
428
429
430def constrain_pars(model_info, pars):
431    # type: (ModelInfo, ParameterSet) -> None
432    """
433    Restrict parameters to valid values.
434
435    This includes model specific code for models such as capped_cylinder
436    which need to support within model constraints (cap radius more than
437    cylinder radius in this case).
438
439    Warning: this updates the *pars* dictionary in place.
440    """
441    # TODO: move the model specific code to the individual models
442    name = model_info.id
443    # if it is a product model, then just look at the form factor since
444    # none of the structure factors need any constraints.
445    if '*' in name:
446        name = name.split('*')[0]
447
448    # Suppress magnetism for python models (not yet implemented)
449    if callable(model_info.Iq):
450        pars.update(suppress_magnetism(pars))
451
452    if name == 'barbell':
453        if pars['radius_bell'] < pars['radius']:
454            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
455
456    elif name == 'capped_cylinder':
457        if pars['radius_cap'] < pars['radius']:
458            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
459
460    elif name == 'guinier':
461        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
462        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
463        # => ln A - (Rg^2 q^2/3) > -30 ln 10
464        # => Rg^2 q^2/3 < 30 ln 10 + ln A
465        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
466        #q_max = 0.2  # mid q maximum
467        q_max = 1.0  # high q maximum
468        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
469        pars['rg'] = min(pars['rg'], rg_max)
470
471    elif name == 'pearl_necklace':
472        if pars['radius'] < pars['thick_string']:
473            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
474        pass
475
476    elif name == 'rpa':
477        # Make sure phi sums to 1.0
478        if pars['case_num'] < 2:
479            pars['Phi1'] = 0.
480            pars['Phi2'] = 0.
481        elif pars['case_num'] < 5:
482            pars['Phi1'] = 0.
483        total = sum(pars['Phi'+c] for c in '1234')
484        for c in '1234':
485            pars['Phi'+c] /= total
486
487def parlist(model_info, pars, is2d):
488    # type: (ModelInfo, ParameterSet, bool) -> str
489    """
490    Format the parameter list for printing.
491    """
492    lines = []
493    parameters = model_info.parameters
494    magnetic = False
495    for p in parameters.user_parameters(pars, is2d):
496        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
497            continue
498        if p.id.startswith('up:') and not magnetic:
499            continue
500        fields = dict(
501            value=pars.get(p.id, p.default),
502            pd=pars.get(p.id+"_pd", 0.),
503            n=int(pars.get(p.id+"_pd_n", 0)),
504            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
505            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
506            relative_pd=p.relative_pd,
507            M0=pars.get('M0:'+p.id, 0.),
508            mphi=pars.get('mphi:'+p.id, 0.),
509            mtheta=pars.get('mtheta:'+p.id, 0.),
510        )
511        lines.append(_format_par(p.name, **fields))
512        magnetic = magnetic or fields['M0'] != 0.
513    return "\n".join(lines)
514
515    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
516
517def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
518                relative_pd=False, M0=0., mphi=0., mtheta=0.):
519    # type: (str, float, float, int, float, str) -> str
520    line = "%s: %g"%(name, value)
521    if pd != 0.  and n != 0:
522        if relative_pd:
523            pd *= value
524        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
525                % (pd, n, nsigma, nsigma, pdtype)
526    if M0 != 0.:
527        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
528    return line
529
530def suppress_pd(pars):
531    # type: (ParameterSet) -> ParameterSet
532    """
533    Suppress theta_pd for now until the normalization is resolved.
534
535    May also suppress complete polydispersity of the model to test
536    models more quickly.
537    """
538    pars = pars.copy()
539    for p in pars:
540        if p.endswith("_pd_n"):
541            pars[p] = 0
542    return pars
543
544def suppress_magnetism(pars):
545    # type: (ParameterSet) -> ParameterSet
546    """
547    Suppress theta_pd for now until the normalization is resolved.
548
549    May also suppress complete polydispersity of the model to test
550    models more quickly.
551    """
552    pars = pars.copy()
553    for p in pars:
554        if p.startswith("M0:"): pars[p] = 0
555    return pars
556
557def eval_sasview(model_info, data):
558    # type: (Modelinfo, Data) -> Calculator
559    """
560    Return a model calculator using the pre-4.0 SasView models.
561    """
562    # importing sas here so that the error message will be that sas failed to
563    # import rather than the more obscure smear_selection not imported error
564    import sas
565    import sas.models
566    from sas.models.qsmearing import smear_selection
567    from sas.models.MultiplicationModel import MultiplicationModel
568    from sas.models.dispersion_models import models as dispersers
569
570    def get_model_class(name):
571        # type: (str) -> "sas.models.BaseComponent"
572        #print("new",sorted(_pars.items()))
573        __import__('sas.models.' + name)
574        ModelClass = getattr(getattr(sas.models, name, None), name, None)
575        if ModelClass is None:
576            raise ValueError("could not find model %r in sas.models"%name)
577        return ModelClass
578
579    # WARNING: ugly hack when handling model!
580    # Sasview models with multiplicity need to be created with the target
581    # multiplicity, so we cannot create the target model ahead of time for
582    # for multiplicity models.  Instead we store the model in a list and
583    # update the first element of that list with the new multiplicity model
584    # every time we evaluate.
585
586    # grab the sasview model, or create it if it is a product model
587    if model_info.composition:
588        composition_type, parts = model_info.composition
589        if composition_type == 'product':
590            P, S = [get_model_class(revert_name(p))() for p in parts]
591            model = [MultiplicationModel(P, S)]
592        else:
593            raise ValueError("sasview mixture models not supported by compare")
594    else:
595        old_name = revert_name(model_info)
596        if old_name is None:
597            raise ValueError("model %r does not exist in old sasview"
598                            % model_info.id)
599        ModelClass = get_model_class(old_name)
600        model = [ModelClass()]
601    model[0].disperser_handles = {}
602
603    # build a smearer with which to call the model, if necessary
604    smearer = smear_selection(data, model=model)
605    if hasattr(data, 'qx_data'):
606        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
607        index = ((~data.mask) & (~np.isnan(data.data))
608                 & (q >= data.qmin) & (q <= data.qmax))
609        if smearer is not None:
610            smearer.model = model  # because smear_selection has a bug
611            smearer.accuracy = data.accuracy
612            smearer.set_index(index)
613            def _call_smearer():
614                smearer.model = model[0]
615                return smearer.get_value()
616            theory = _call_smearer
617        else:
618            theory = lambda: model[0].evalDistribution([data.qx_data[index],
619                                                        data.qy_data[index]])
620    elif smearer is not None:
621        theory = lambda: smearer(model[0].evalDistribution(data.x))
622    else:
623        theory = lambda: model[0].evalDistribution(data.x)
624
625    def calculator(**pars):
626        # type: (float, ...) -> np.ndarray
627        """
628        Sasview calculator for model.
629        """
630        oldpars = revert_pars(model_info, pars)
631        # For multiplicity models, create a model with the correct multiplicity
632        control = oldpars.pop("CONTROL", None)
633        if control is not None:
634            # sphericalSLD has one fewer multiplicity.  This update should
635            # happen in revert_pars, but it hasn't been called yet.
636            model[0] = ModelClass(control)
637        # paying for parameter conversion each time to keep life simple, if not fast
638        for k, v in oldpars.items():
639            if k.endswith('.type'):
640                par = k[:-5]
641                if v == 'gaussian': continue
642                cls = dispersers[v if v != 'rectangle' else 'rectangula']
643                handle = cls()
644                model[0].disperser_handles[par] = handle
645                try:
646                    model[0].set_dispersion(par, handle)
647                except Exception:
648                    exception.annotate_exception("while setting %s to %r"
649                                                 %(par, v))
650                    raise
651
652
653        #print("sasview pars",oldpars)
654        for k, v in oldpars.items():
655            name_attr = k.split('.')  # polydispersity components
656            if len(name_attr) == 2:
657                par, disp_par = name_attr
658                model[0].dispersion[par][disp_par] = v
659            else:
660                model[0].setParam(k, v)
661        return theory()
662
663    calculator.engine = "sasview"
664    return calculator
665
666DTYPE_MAP = {
667    'half': '16',
668    'fast': 'fast',
669    'single': '32',
670    'double': '64',
671    'quad': '128',
672    'f16': '16',
673    'f32': '32',
674    'f64': '64',
675    'float16': '16',
676    'float32': '32',
677    'float64': '64',
678    'float128': '128',
679    'longdouble': '128',
680}
681def eval_opencl(model_info, data, dtype='single', cutoff=0.):
682    # type: (ModelInfo, Data, str, float) -> Calculator
683    """
684    Return a model calculator using the OpenCL calculation engine.
685    """
686    if not core.HAVE_OPENCL:
687        raise RuntimeError("OpenCL not available")
688    model = core.build_model(model_info, dtype=dtype, platform="ocl")
689    calculator = DirectModel(data, model, cutoff=cutoff)
690    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
691    return calculator
692
693def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
694    # type: (ModelInfo, Data, str, float) -> Calculator
695    """
696    Return a model calculator using the DLL calculation engine.
697    """
698    model = core.build_model(model_info, dtype=dtype, platform="dll")
699    calculator = DirectModel(data, model, cutoff=cutoff)
700    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
701    return calculator
702
703def time_calculation(calculator, pars, evals=1):
704    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
705    """
706    Compute the average calculation time over N evaluations.
707
708    An additional call is generated without polydispersity in order to
709    initialize the calculation engine, and make the average more stable.
710    """
711    # initialize the code so time is more accurate
712    if evals > 1:
713        calculator(**suppress_pd(pars))
714    toc = tic()
715    # make sure there is at least one eval
716    value = calculator(**pars)
717    for _ in range(evals-1):
718        value = calculator(**pars)
719    average_time = toc()*1000. / evals
720    #print("I(q)",value)
721    return value, average_time
722
723def make_data(opts):
724    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
725    """
726    Generate an empty dataset, used with the model to set Q points
727    and resolution.
728
729    *opts* contains the options, with 'qmax', 'nq', 'res',
730    'accuracy', 'is2d' and 'view' parsed from the command line.
731    """
732    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
733    if opts['is2d']:
734        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
735        data = empty_data2D(q, resolution=res)
736        data.accuracy = opts['accuracy']
737        set_beam_stop(data, 0.0004)
738        index = ~data.mask
739    else:
740        if opts['view'] == 'log' and not opts['zero']:
741            qmax = math.log10(qmax)
742            q = np.logspace(qmax-3, qmax, nq)
743        else:
744            q = np.linspace(0.001*qmax, qmax, nq)
745        if opts['zero']:
746            q = np.hstack((0, q))
747        data = empty_data1D(q, resolution=res)
748        index = slice(None, None)
749    return data, index
750
751def make_engine(model_info, data, dtype, cutoff):
752    # type: (ModelInfo, Data, str, float) -> Calculator
753    """
754    Generate the appropriate calculation engine for the given datatype.
755
756    Datatypes with '!' appended are evaluated using external C DLLs rather
757    than OpenCL.
758    """
759    if dtype == 'sasview':
760        return eval_sasview(model_info, data)
761    elif dtype is None or not dtype.endswith('!'):
762        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
763    else:
764        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
765
766def _show_invalid(data, theory):
767    # type: (Data, np.ma.ndarray) -> None
768    """
769    Display a list of the non-finite values in theory.
770    """
771    if not theory.mask.any():
772        return
773
774    if hasattr(data, 'x'):
775        bad = zip(data.x[theory.mask], theory[theory.mask])
776        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
777
778
779def compare(opts, limits=None):
780    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
781    """
782    Preform a comparison using options from the command line.
783
784    *limits* are the limits on the values to use, either to set the y-axis
785    for 1D or to set the colormap scale for 2D.  If None, then they are
786    inferred from the data and returned. When exploring using Bumps,
787    the limits are set when the model is initially called, and maintained
788    as the values are adjusted, making it easier to see the effects of the
789    parameters.
790    """
791    limits = np.Inf, -np.Inf
792    for k in range(opts['sets']):
793        opts['pars'] = parse_pars(opts)
794        if opts['pars'] is None:
795            return
796        result = run_models(opts, verbose=True)
797        if opts['plot']:
798            limits = plot_models(opts, result, limits=limits, setnum=k)
799    if opts['plot']:
800        import matplotlib.pyplot as plt
801        plt.show()
802
803def run_models(opts, verbose=False):
804    # type: (Dict[str, Any]) -> Dict[str, Any]
805
806    base, comp = opts['engines']
807    base_n, comp_n = opts['count']
808    base_pars, comp_pars = opts['pars']
809    data = opts['data']
810
811    comparison = comp is not None
812
813    base_time = comp_time = None
814    base_value = comp_value = resid = relerr = None
815
816    # Base calculation
817    try:
818        base_raw, base_time = time_calculation(base, base_pars, base_n)
819        base_value = np.ma.masked_invalid(base_raw)
820        if verbose:
821            print("%s t=%.2f ms, intensity=%.0f"
822                  % (base.engine, base_time, base_value.sum()))
823        _show_invalid(data, base_value)
824    except ImportError:
825        traceback.print_exc()
826
827    # Comparison calculation
828    if comparison:
829        try:
830            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
831            comp_value = np.ma.masked_invalid(comp_raw)
832            if verbose:
833                print("%s t=%.2f ms, intensity=%.0f"
834                      % (comp.engine, comp_time, comp_value.sum()))
835            _show_invalid(data, comp_value)
836        except ImportError:
837            traceback.print_exc()
838
839    # Compare, but only if computing both forms
840    if comparison:
841        resid = (base_value - comp_value)
842        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
843        if verbose:
844            _print_stats("|%s-%s|"
845                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
846                         resid)
847            _print_stats("|(%s-%s)/%s|"
848                         % (base.engine, comp.engine, comp.engine),
849                         relerr)
850
851    return dict(base_value=base_value, comp_value=comp_value,
852                base_time=base_time, comp_time=comp_time,
853                resid=resid, relerr=relerr)
854
855
856def _print_stats(label, err):
857    # type: (str, np.ma.ndarray) -> None
858    # work with trimmed data, not the full set
859    sorted_err = np.sort(abs(err.compressed()))
860    if len(sorted_err) == 0.:
861        print(label + "  no valid values")
862        return
863
864    p50 = int((len(sorted_err)-1)*0.50)
865    p98 = int((len(sorted_err)-1)*0.98)
866    data = [
867        "max:%.3e"%sorted_err[-1],
868        "median:%.3e"%sorted_err[p50],
869        "98%%:%.3e"%sorted_err[p98],
870        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
871        "zero-offset:%+.3e"%np.mean(sorted_err),
872        ]
873    print(label+"  "+"  ".join(data))
874
875
876def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
877    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
878    base_value, comp_value= result['base_value'], result['comp_value']
879    base_time, comp_time = result['base_time'], result['comp_time']
880    resid, relerr = result['resid'], result['relerr']
881
882    have_base, have_comp = (base_value is not None), (comp_value is not None)
883    base, comp = opts['engines']
884    data = opts['data']
885    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
886
887    # Plot if requested
888    view = opts['view']
889    import matplotlib.pyplot as plt
890    vmin, vmax = limits
891    if have_base:
892        vmin = min(vmin, base_value.min())
893        vmax = max(vmax, base_value.max())
894    if have_comp:
895        vmin = min(vmin, comp_value.min())
896        vmax = max(vmax, comp_value.max())
897    limits = vmin, vmax
898
899    if have_base:
900        if have_comp:
901            plt.subplot(131)
902        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
903        plt.title("%s t=%.2f ms"%(base.engine, base_time))
904        #cbar_title = "log I"
905    if have_comp:
906        if have_base:
907            plt.subplot(132)
908        if not opts['is2d'] and have_base:
909            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
910        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
911        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
912        #cbar_title = "log I"
913    if have_base and have_comp:
914        plt.subplot(133)
915        if not opts['rel_err']:
916            err, errstr, errview = resid, "abs err", "linear"
917        else:
918            err, errstr, errview = abs(relerr), "rel err", "log"
919        if 0:  # 95% cutoff
920            sorted = np.sort(err.flatten())
921            cutoff = sorted[int(sorted.size*0.95)]
922            err[err > cutoff] = cutoff
923        #err,errstr = base/comp,"ratio"
924        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
925        if view == 'linear':
926            plt.xscale('linear')
927        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
928        #cbar_title = errstr if errview=="linear" else "log "+errstr
929    #if is2D:
930    #    h = plt.colorbar()
931    #    h.ax.set_title(cbar_title)
932    fig = plt.gcf()
933    extra_title = ' '+opts['title'] if opts['title'] else ''
934    fig.suptitle(":".join(opts['name']) + extra_title)
935
936    if have_base and have_comp and opts['show_hist']:
937        plt.figure()
938        v = relerr
939        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
940        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
941        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
942                   % (base.engine, comp.engine, comp.engine))
943        plt.ylabel('P(err)')
944        plt.title('Distribution of relative error between calculation engines')
945
946    return limits
947
948
949# ===========================================================================
950#
951
952# Set of command line options.
953# Normal options such as -plot/-noplot are specified as 'name'.
954# For options such as -nq=500 which require a value use 'name='.
955#
956OPTIONS = [
957    # Plotting
958    'plot', 'noplot',
959    'linear', 'log', 'q4',
960    'rel', 'abs',
961    'hist', 'nohist',
962    'title=',
963
964    # Data generation
965    'data=', 'noise=', 'res=',
966    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
967    '2d', '1d',
968
969    # Parameter set
970    'preset', 'random', 'random=', 'sets=',
971    'demo', 'default',  # TODO: remove demo/default
972    'nopars', 'pars',
973
974    # Calculation options
975    'poly', 'mono', 'cutoff=',
976    'magnetic', 'nonmagnetic',
977    'accuracy=',
978
979    # Precision options
980    'calc=',
981    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
982    'sasview',  # TODO: remove sasview 3.x support
983    'timing=',
984
985    # Output options
986    'help', 'html', 'edit',
987    ]
988
989NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
990VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
991
992
993def columnize(items, indent="", width=79):
994    # type: (List[str], str, int) -> str
995    """
996    Format a list of strings into columns.
997
998    Returns a string with carriage returns ready for printing.
999    """
1000    column_width = max(len(w) for w in items) + 1
1001    num_columns = (width - len(indent)) // column_width
1002    num_rows = len(items) // num_columns
1003    items = items + [""] * (num_rows * num_columns - len(items))
1004    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1005    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1006             for row in zip(*columns)]
1007    output = indent + ("\n"+indent).join(lines)
1008    return output
1009
1010
1011def get_pars(model_info, use_demo=False):
1012    # type: (ModelInfo, bool) -> ParameterSet
1013    """
1014    Extract demo parameters from the model definition.
1015    """
1016    # Get the default values for the parameters
1017    pars = {}
1018    for p in model_info.parameters.call_parameters:
1019        parts = [('', p.default)]
1020        if p.polydisperse:
1021            parts.append(('_pd', 0.0))
1022            parts.append(('_pd_n', 0))
1023            parts.append(('_pd_nsigma', 3.0))
1024            parts.append(('_pd_type', "gaussian"))
1025        for ext, val in parts:
1026            if p.length > 1:
1027                dict(("%s%d%s" % (p.id, k, ext), val)
1028                     for k in range(1, p.length+1))
1029            else:
1030                pars[p.id + ext] = val
1031
1032    # Plug in values given in demo
1033    if use_demo:
1034        pars.update(model_info.demo)
1035    return pars
1036
1037INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1038def isnumber(str):
1039    match = FLOAT_RE.match(str)
1040    isfloat = (match and not str[match.end():])
1041    return isfloat or INTEGER_RE.match(str)
1042
1043# For distinguishing pairs of models for comparison
1044# key-value pair separator =
1045# shell characters  | & ; <> $ % ' " \ # `
1046# model and parameter names _
1047# parameter expressions - + * / . ( )
1048# path characters including tilde expansion and windows drive ~ / :
1049# not sure about brackets [] {}
1050# maybe one of the following @ ? ^ ! ,
1051PAR_SPLIT = ','
1052def parse_opts(argv):
1053    # type: (List[str]) -> Dict[str, Any]
1054    """
1055    Parse command line options.
1056    """
1057    MODELS = core.list_models()
1058    flags = [arg for arg in argv
1059             if arg.startswith('-')]
1060    values = [arg for arg in argv
1061              if not arg.startswith('-') and '=' in arg]
1062    positional_args = [arg for arg in argv
1063                       if not arg.startswith('-') and '=' not in arg]
1064    models = "\n    ".join("%-15s"%v for v in MODELS)
1065    if len(positional_args) == 0:
1066        print(USAGE)
1067        print("\nAvailable models:")
1068        print(columnize(MODELS, indent="  "))
1069        return None
1070
1071    invalid = [o[1:] for o in flags
1072               if o[1:] not in NAME_OPTIONS
1073               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1074    if invalid:
1075        print("Invalid options: %s"%(", ".join(invalid)))
1076        return None
1077
1078    name = positional_args[-1]
1079
1080    # pylint: disable=bad-whitespace
1081    # Interpret the flags
1082    opts = {
1083        'plot'      : True,
1084        'view'      : 'log',
1085        'is2d'      : False,
1086        'qmax'      : 0.05,
1087        'nq'        : 128,
1088        'res'       : 0.0,
1089        'noise'     : 0.0,
1090        'accuracy'  : 'Low',
1091        'cutoff'    : '0.0',
1092        'seed'      : -1,  # default to preset
1093        'mono'      : True,
1094        # Default to magnetic a magnetic moment is set on the command line
1095        'magnetic'  : False,
1096        'show_pars' : False,
1097        'show_hist' : False,
1098        'rel_err'   : True,
1099        'explore'   : False,
1100        'use_demo'  : True,
1101        'zero'      : False,
1102        'html'      : False,
1103        'title'     : None,
1104        'datafile'  : None,
1105        'sets'      : 1,
1106        'engine'    : 'default',
1107        'evals'     : '1',
1108    }
1109    for arg in flags:
1110        if arg == '-noplot':    opts['plot'] = False
1111        elif arg == '-plot':    opts['plot'] = True
1112        elif arg == '-linear':  opts['view'] = 'linear'
1113        elif arg == '-log':     opts['view'] = 'log'
1114        elif arg == '-q4':      opts['view'] = 'q4'
1115        elif arg == '-1d':      opts['is2d'] = False
1116        elif arg == '-2d':      opts['is2d'] = True
1117        elif arg == '-exq':     opts['qmax'] = 10.0
1118        elif arg == '-highq':   opts['qmax'] = 1.0
1119        elif arg == '-midq':    opts['qmax'] = 0.2
1120        elif arg == '-lowq':    opts['qmax'] = 0.05
1121        elif arg == '-zero':    opts['zero'] = True
1122        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1123        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1124        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1125        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1126        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1127        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1128        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1129        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1130        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1131        elif arg.startswith('-calc='):     opts['engine'] = arg[6:]
1132        elif arg.startswith('-neval='):    opts['evals'] = arg[7:]
1133        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1134        elif arg == '-preset':  opts['seed'] = -1
1135        elif arg == '-mono':    opts['mono'] = True
1136        elif arg == '-poly':    opts['mono'] = False
1137        elif arg == '-magnetic':       opts['magnetic'] = True
1138        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1139        elif arg == '-pars':    opts['show_pars'] = True
1140        elif arg == '-nopars':  opts['show_pars'] = False
1141        elif arg == '-hist':    opts['show_hist'] = True
1142        elif arg == '-nohist':  opts['show_hist'] = False
1143        elif arg == '-rel':     opts['rel_err'] = True
1144        elif arg == '-abs':     opts['rel_err'] = False
1145        elif arg == '-half':    opts['engine'] = 'half'
1146        elif arg == '-fast':    opts['engine'] = 'fast'
1147        elif arg == '-single':  opts['engine'] = 'single'
1148        elif arg == '-double':  opts['engine'] = 'double'
1149        elif arg == '-single!': opts['engine'] = 'single!'
1150        elif arg == '-double!': opts['engine'] = 'double!'
1151        elif arg == '-quad!':   opts['engine'] = 'quad!'
1152        elif arg == '-sasview': opts['engine'] = 'sasview'
1153        elif arg == '-edit':    opts['explore'] = True
1154        elif arg == '-demo':    opts['use_demo'] = True
1155        elif arg == '-default':    opts['use_demo'] = False
1156        elif arg == '-html':    opts['html'] = True
1157        elif arg == '-help':    opts['html'] = True
1158    # pylint: enable=bad-whitespace
1159
1160    # Force random if more than one set
1161    if opts['sets'] > 1 and opts['seed'] < 0:
1162        opts['seed'] = np.random.randint(1000000)
1163
1164    # Create the computational engines
1165    if opts['datafile'] is not None:
1166        data = load_data(os.path.expanduser(opts['datafile']))
1167    else:
1168        data, _ = make_data(opts)
1169
1170    comparison = any(PAR_SPLIT in v for v in values)
1171    if PAR_SPLIT in name:
1172        names = name.split(PAR_SPLIT, 2)
1173        comparison = True
1174    else:
1175        names = [name]*2
1176    try:
1177        model_info = [core.load_model_info(k) for k in names]
1178    except ImportError as exc:
1179        print(str(exc))
1180        print("Could not find model; use one of:\n    " + models)
1181        return None
1182
1183    if PAR_SPLIT in opts['engine']:
1184        engine_types = opts['engine'].split(PAR_SPLIT, 2)
1185        comparison = True
1186    else:
1187        engine_types = [opts['engine']]*2
1188
1189    if PAR_SPLIT in opts['evals']:
1190        evals = [int(k) for k in opts['evals'].split(PAR_SPLIT, 2)]
1191        comparison = True
1192    else:
1193        evals = [int(opts['evals'])]*2
1194
1195    if PAR_SPLIT in opts['cutoff']:
1196        cutoff = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1197        comparison = True
1198    else:
1199        cutoff = [float(opts['cutoff'])]*2
1200
1201    base = make_engine(model_info[0], data, engine_types[0], cutoff[0])
1202    if comparison:
1203        comp = make_engine(model_info[1], data, engine_types[1], cutoff[1])
1204    else:
1205        comp = None
1206
1207    # pylint: disable=bad-whitespace
1208    # Remember it all
1209    opts.update({
1210        'data'      : data,
1211        'name'      : names,
1212        'def'       : model_info,
1213        'count'     : evals,
1214        'engines'   : [base, comp],
1215        'values'    : values,
1216    })
1217    # pylint: enable=bad-whitespace
1218
1219    return opts
1220
1221def parse_pars(opts):
1222    model_info, model_info2 = opts['def']
1223
1224    # Get demo parameters from model definition, or use default parameters
1225    # if model does not define demo parameters
1226    pars = get_pars(model_info, opts['use_demo'])
1227    pars2 = get_pars(model_info2, opts['use_demo'])
1228    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1229    # randomize parameters
1230    #pars.update(set_pars)  # set value before random to control range
1231    if opts['seed'] > -1:
1232        pars = randomize_pars(model_info, pars)
1233        if model_info != model_info2:
1234            pars2 = randomize_pars(model_info2, pars2)
1235            # Share values for parameters with the same name
1236            for k, v in pars.items():
1237                if k in pars2:
1238                    pars2[k] = v
1239        else:
1240            pars2 = pars.copy()
1241        constrain_pars(model_info, pars)
1242        constrain_pars(model_info2, pars2)
1243    if opts['mono']:
1244        pars = suppress_pd(pars)
1245        pars2 = suppress_pd(pars2)
1246    if not opts['magnetic']:
1247        pars = suppress_magnetism(pars)
1248        pars2 = suppress_magnetism(pars2)
1249
1250    # Fill in parameters given on the command line
1251    presets = {}
1252    presets2 = {}
1253    for arg in opts['values']:
1254        k, v = arg.split('=', 1)
1255        if k not in pars and k not in pars2:
1256            # extract base name without polydispersity info
1257            s = set(p.split('_pd')[0] for p in pars)
1258            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1259            return None
1260        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1261        if v1 and k in pars:
1262            presets[k] = float(v1) if isnumber(v1) else v1
1263        if v2 and k in pars2:
1264            presets2[k] = float(v2) if isnumber(v2) else v2
1265
1266    # If pd given on the command line, default pd_n to 35
1267    for k, v in list(presets.items()):
1268        if k.endswith('_pd'):
1269            presets.setdefault(k+'_n', 35.)
1270    for k, v in list(presets2.items()):
1271        if k.endswith('_pd'):
1272            presets2.setdefault(k+'_n', 35.)
1273
1274    # Evaluate preset parameter expressions
1275    context = MATH.copy()
1276    context['np'] = np
1277    context.update(pars)
1278    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1279    for k, v in presets.items():
1280        if not isinstance(v, float) and not k.endswith('_type'):
1281            presets[k] = eval(v, context)
1282    context.update(presets)
1283    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1284    for k, v in presets2.items():
1285        if not isinstance(v, float) and not k.endswith('_type'):
1286            presets2[k] = eval(v, context)
1287
1288    # update parameters with presets
1289    pars.update(presets)  # set value after random to control value
1290    pars2.update(presets2)  # set value after random to control value
1291    #import pprint; pprint.pprint(model_info)
1292
1293    if opts['show_pars']:
1294        if model_info.name != model_info2.name or pars != pars2:
1295            print("==== %s ====="%model_info.name)
1296            print(str(parlist(model_info, pars, opts['is2d'])))
1297            print("==== %s ====="%model_info2.name)
1298            print(str(parlist(model_info2, pars2, opts['is2d'])))
1299        else:
1300            print(str(parlist(model_info, pars, opts['is2d'])))
1301
1302    return pars, pars2
1303
1304def show_docs(opts):
1305    # type: (Dict[str, Any]) -> None
1306    """
1307    show html docs for the model
1308    """
1309    import os
1310    from .generate import make_html
1311    from . import rst2html
1312
1313    info = opts['def'][0]
1314    html = make_html(info)
1315    path = os.path.dirname(info.filename)
1316    url = "file://"+path.replace("\\","/")[2:]+"/"
1317    rst2html.view_html_qtapp(html, url)
1318
1319def explore(opts):
1320    # type: (Dict[str, Any]) -> None
1321    """
1322    explore the model using the bumps gui.
1323    """
1324    import wx  # type: ignore
1325    from bumps.names import FitProblem  # type: ignore
1326    from bumps.gui.app_frame import AppFrame  # type: ignore
1327    from bumps.gui import signal
1328
1329    is_mac = "cocoa" in wx.version()
1330    # Create an app if not running embedded
1331    app = wx.App() if wx.GetApp() is None else None
1332    model = Explore(opts)
1333    problem = FitProblem(model)
1334    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1335    if not is_mac:
1336        frame.Show()
1337    frame.panel.set_model(model=problem)
1338    frame.panel.Layout()
1339    frame.panel.aui.Split(0, wx.TOP)
1340    def reset_parameters(event):
1341        model.revert_values()
1342        signal.update_parameters(problem)
1343    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1344    if is_mac: frame.Show()
1345    # If running withing an app, start the main loop
1346    if app:
1347        app.MainLoop()
1348
1349class Explore(object):
1350    """
1351    Bumps wrapper for a SAS model comparison.
1352
1353    The resulting object can be used as a Bumps fit problem so that
1354    parameters can be adjusted in the GUI, with plots updated on the fly.
1355    """
1356    def __init__(self, opts):
1357        # type: (Dict[str, Any]) -> None
1358        from bumps.cli import config_matplotlib  # type: ignore
1359        from . import bumps_model
1360        config_matplotlib()
1361        self.opts = opts
1362        opts['pars'] = list(opts['pars'])
1363        p1, p2 = opts['pars']
1364        m1, m2 = opts['def']
1365        self.fix_p2 = m1 != m2 or p1 != p2
1366        model_info = m1
1367        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1368        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1369        if not opts['is2d']:
1370            for p in model_info.parameters.user_parameters({}, is2d=False):
1371                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1372                    k = p.name+ext
1373                    v = pars.get(k, None)
1374                    if v is not None:
1375                        v.range(*parameter_range(k, v.value))
1376        else:
1377            for k, v in pars.items():
1378                v.range(*parameter_range(k, v.value))
1379
1380        self.pars = pars
1381        self.starting_values = dict((k, v.value) for k, v in pars.items())
1382        self.pd_types = pd_types
1383        self.limits = np.Inf, -np.Inf
1384
1385    def revert_values(self):
1386        for k, v in self.starting_values.items():
1387            self.pars[k].value = v
1388
1389    def model_update(self):
1390        pass
1391
1392    def numpoints(self):
1393        # type: () -> int
1394        """
1395        Return the number of points.
1396        """
1397        return len(self.pars) + 1  # so dof is 1
1398
1399    def parameters(self):
1400        # type: () -> Any   # Dict/List hierarchy of parameters
1401        """
1402        Return a dictionary of parameters.
1403        """
1404        return self.pars
1405
1406    def nllf(self):
1407        # type: () -> float
1408        """
1409        Return cost.
1410        """
1411        # pylint: disable=no-self-use
1412        return 0.  # No nllf
1413
1414    def plot(self, view='log'):
1415        # type: (str) -> None
1416        """
1417        Plot the data and residuals.
1418        """
1419        pars = dict((k, v.value) for k, v in self.pars.items())
1420        pars.update(self.pd_types)
1421        self.opts['pars'][0] = pars
1422        if not self.fix_p2:
1423            self.opts['pars'][1] = pars
1424        result = run_models(self.opts)
1425        limits = plot_models(self.opts, result, limits=self.limits)
1426        if self.limits is None:
1427            vmin, vmax = limits
1428            self.limits = vmax*1e-7, 1.3*vmax
1429            import pylab; pylab.clf()
1430            plot_models(self.opts, result, limits=self.limits)
1431
1432
1433def main(*argv):
1434    # type: (*str) -> None
1435    """
1436    Main program.
1437    """
1438    opts = parse_opts(argv)
1439    if opts is not None:
1440        if opts['seed'] > -1:
1441            print("Randomize using -random=%i"%opts['seed'])
1442            np.random.seed(opts['seed'])
1443        if opts['html']:
1444            show_docs(opts)
1445        elif opts['explore']:
1446            opts['pars'] = parse_pars(opts)
1447            if opts['pars'] is None:
1448                return
1449            explore(opts)
1450        else:
1451            compare(opts)
1452
1453if __name__ == "__main__":
1454    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.