source: sasmodels/sasmodels/compare.py @ 0f6c41c

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 0f6c41c was 0f6c41c, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

set range of randomly generated magnetic SLDs

  • Property mode set to 100755
File size: 51.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: sascomp model [options...] [key=val]
59
60Generate and compare SAS models.  If a single model is specified it shows
61a plot of that model.  Different models can be compared, or the same model
62with different parameters.  The same model with the same parameters can
63be compared with different calculation engines to see the effects of precision
64on the resultant values.
65
66model or model1,model2 are the names of the models to compare (see below).
67
68Options (* for default):
69
70    === data generation ===
71    -data="path" uses q, dq from the data file
72    -noise=0 sets the measurement error dI/I
73    -res=0 sets the resolution width dQ/Q if calculating with resolution
74    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
75    -nq=128 sets the number of Q points in the data set
76    -1d*/-2d computes 1d or 2d data
77    -zero indicates that q=0 should be included
78
79    === model parameters ===
80    -preset*/-random[=seed] preset or random parameters
81    -sets=n generates n random datasets, with the seed given by -random=seed
82    -pars/-nopars* prints the parameter set or not
83    -default/-demo* use demo vs default parameters
84
85    === calculation options ===
86    -mono*/-poly force monodisperse or allow polydisperse demo parameters
87    -cutoff=1e-5* cutoff value for including a point in polydispersity
88    -magnetic/-nonmagnetic* suppress magnetism
89    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
90
91    === precision options ===
92    -calc=default uses the default calcution precision
93    -single/-double/-half/-fast sets an OpenCL calculation engine
94    -single!/-double!/-quad! sets an OpenMP calculation engine
95    -sasview sets the sasview calculation engine
96
97    === plotting ===
98    -plot*/-noplot plots or suppress the plot of the model
99    -linear/-log*/-q4 intensity scaling on plots
100    -hist/-nohist* plot histogram of relative error
101    -abs/-rel* plot relative or absolute error
102    -title="note" adds note to the plot title, after the model name
103
104    === output options ===
105    -edit starts the parameter explorer
106    -help/-html shows the model docs instead of running the model
107
108The interpretation of quad precision depends on architecture, and may
109vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
110On unix and mac you may need single quotes around the DLL computation
111engines, such as -calc='single!,double!' since !, is treated as a history
112expansion request in the shell.
113
114Key=value pairs allow you to set specific values for the model parameters.
115Key=value1,value2 to compare different values of the same parameter. The
116value can be an expression including other parameters.
117
118Items later on the command line override those that appear earlier.
119
120Examples:
121
122    # compare single and double precision calculation for a barbell
123    sascomp barbell -calc=single,double
124
125    # generate 10 random lorentz models, with seed=27
126    sascomp lorentz -sets=10 -seed=27
127
128    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
129    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
130
131    # model timing test requires multiple evals to perform the estimate
132    sascomp pringle -calc=single,double -timing=100,100 -noplot
133"""
134
135# Update docs with command line usage string.   This is separate from the usual
136# doc string so that we can display it at run time if there is an error.
137# lin
138__doc__ = (__doc__  # pylint: disable=redefined-builtin
139           + """
140Program description
141-------------------
142
143""" + USAGE)
144
145kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
146
147# list of math functions for use in evaluating parameters
148MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
149
150# CRUFT python 2.6
151if not hasattr(datetime.timedelta, 'total_seconds'):
152    def delay(dt):
153        """Return number date-time delta as number seconds"""
154        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
155else:
156    def delay(dt):
157        """Return number date-time delta as number seconds"""
158        return dt.total_seconds()
159
160
161class push_seed(object):
162    """
163    Set the seed value for the random number generator.
164
165    When used in a with statement, the random number generator state is
166    restored after the with statement is complete.
167
168    :Parameters:
169
170    *seed* : int or array_like, optional
171        Seed for RandomState
172
173    :Example:
174
175    Seed can be used directly to set the seed::
176
177        >>> from numpy.random import randint
178        >>> push_seed(24)
179        <...push_seed object at...>
180        >>> print(randint(0,1000000,3))
181        [242082    899 211136]
182
183    Seed can also be used in a with statement, which sets the random
184    number generator state for the enclosed computations and restores
185    it to the previous state on completion::
186
187        >>> with push_seed(24):
188        ...    print(randint(0,1000000,3))
189        [242082    899 211136]
190
191    Using nested contexts, we can demonstrate that state is indeed
192    restored after the block completes::
193
194        >>> with push_seed(24):
195        ...    print(randint(0,1000000))
196        ...    with push_seed(24):
197        ...        print(randint(0,1000000,3))
198        ...    print(randint(0,1000000))
199        242082
200        [242082    899 211136]
201        899
202
203    The restore step is protected against exceptions in the block::
204
205        >>> with push_seed(24):
206        ...    print(randint(0,1000000))
207        ...    try:
208        ...        with push_seed(24):
209        ...            print(randint(0,1000000,3))
210        ...            raise Exception()
211        ...    except Exception:
212        ...        print("Exception raised")
213        ...    print(randint(0,1000000))
214        242082
215        [242082    899 211136]
216        Exception raised
217        899
218    """
219    def __init__(self, seed=None):
220        # type: (Optional[int]) -> None
221        self._state = np.random.get_state()
222        np.random.seed(seed)
223
224    def __enter__(self):
225        # type: () -> None
226        pass
227
228    def __exit__(self, exc_type, exc_value, traceback):
229        # type: (Any, BaseException, Any) -> None
230        # TODO: better typing for __exit__ method
231        np.random.set_state(self._state)
232
233def tic():
234    # type: () -> Callable[[], float]
235    """
236    Timer function.
237
238    Use "toc=tic()" to start the clock and "toc()" to measure
239    a time interval.
240    """
241    then = datetime.datetime.now()
242    return lambda: delay(datetime.datetime.now() - then)
243
244
245def set_beam_stop(data, radius, outer=None):
246    # type: (Data, float, float) -> None
247    """
248    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
249
250    Note: this function does not require sasview
251    """
252    if hasattr(data, 'qx_data'):
253        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
254        data.mask = (q < radius)
255        if outer is not None:
256            data.mask |= (q >= outer)
257    else:
258        data.mask = (data.x < radius)
259        if outer is not None:
260            data.mask |= (data.x >= outer)
261
262
263def parameter_range(p, v):
264    # type: (str, float) -> Tuple[float, float]
265    """
266    Choose a parameter range based on parameter name and initial value.
267    """
268    # process the polydispersity options
269    if p.endswith('_pd_n'):
270        return 0., 100.
271    elif p.endswith('_pd_nsigma'):
272        return 0., 5.
273    elif p.endswith('_pd_type'):
274        raise ValueError("Cannot return a range for a string value")
275    elif any(s in p for s in ('theta', 'phi', 'psi')):
276        # orientation in [-180,180], orientation pd in [0,45]
277        if p.endswith('_pd'):
278            return 0., 45.
279        else:
280            return -180., 180.
281    elif p.endswith('_pd'):
282        return 0., 1.
283    elif 'sld' in p:
284        return -0.5, 10.
285    elif p == 'background':
286        return 0., 10.
287    elif p == 'scale':
288        return 0., 1.e3
289    elif v < 0.:
290        return 2.*v, -2.*v
291    else:
292        return 0., (2.*v if v > 0. else 1.)
293
294
295def _randomize_one(model_info, name, value):
296    # type: (ModelInfo, str, float) -> float
297    # type: (ModelInfo, str, str) -> str
298    """
299    Randomize a single parameter.
300    """
301    # Set the amount of polydispersity/angular dispersion, but by default pd_n
302    # is zero so there is no polydispersity.  This allows us to turn on/off
303    # pd by setting pd_n, and still have randomly generated values
304    if name.endswith('_pd'):
305        par = model_info.parameters[name[:-3]]
306        if par.type == 'orientation':
307            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
308            return 180*np.random.beta(2.5, 20)
309        else:
310            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
311            return np.random.beta(1.5, 7)
312
313    # pd is selected globally rather than per parameter, so set to 0 for no pd
314    # In particular, when multiple pd dimensions, want to decrease the number
315    # of points per dimension for faster computation
316    if name.endswith('_pd_n'):
317        return 0
318
319    # Don't mess with distribution type for now
320    if name.endswith('_pd_type'):
321        return 'gaussian'
322
323    # type-dependent value of number of sigmas; for gaussian use 3.
324    if name.endswith('_pd_nsigma'):
325        return 3.
326
327    # background in the range [0.01, 1]
328    if name == 'background':
329        return 10**np.random.uniform(-2, 0)
330
331    # scale defaults to 0.1% to 30% volume fraction
332    if name == 'scale':
333        return 10**np.random.uniform(-3, -0.5)
334
335    # If it is a list of choices, pick one at random with equal probability
336    # In practice, the model specific random generator will override.
337    par = model_info.parameters[name]
338    if len(par.limits) > 2:  # choice list
339        return np.random.randint(len(par.limits))
340
341    # If it is a fixed range, pick from it with equal probability.
342    # For logarithmic ranges, the model will have to override.
343    if np.isfinite(par.limits).all():
344        return np.random.uniform(*par.limits)
345
346    # If the paramter is marked as an sld use the range of neutron slds
347    # TODO: ought to randomly contrast match a pair of SLDs
348    if par.type == 'sld':
349        return np.random.uniform(-0.5, 12)
350
351    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
352    if par.name.startswith('M0:'):
353        return np.random.uniform(0, 5)
354
355    # Guess at the random length/radius/thickness.  In practice, all models
356    # are going to set their own reasonable ranges.
357    if par.type == 'volume':
358        if ('length' in par.name or
359                'radius' in par.name or
360                'thick' in par.name):
361            return 10**np.random.uniform(2, 4)
362
363    # In the absence of any other info, select a value in [0, 2v], or
364    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
365    # model random parameter generators will override this default.
366    low, high = parameter_range(par.name, value)
367    limits = (max(par.limits[0], low), min(par.limits[1], high))
368    return np.random.uniform(*limits)
369
370def _random_pd(model_info, pars):
371    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
372    pd_volume = []
373    pd_oriented = []
374    for p in pd:
375        if p.type == 'orientation':
376            pd_oriented.append(p.name)
377        elif p.length_control is not None:
378            n = int(pars.get(p.length_control, 1) + 0.5)
379            pd_volume.extend(p.name+str(k+1) for k in range(n))
380        elif p.length > 1:
381            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
382        else:
383            pd_volume.append(p.name)
384    u = np.random.rand()
385    n = len(pd_volume)
386    if u < 0.01 or n < 1:
387        pass  # 1% chance of no polydispersity
388    elif u < 0.86 or n < 2:
389        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
390    elif u < 0.99 or n < 3:
391        choices = np.random.choice(len(pd_volume), size=2)
392        pars[pd_volume[choices[0]]+"_pd_n"] = 25
393        pars[pd_volume[choices[1]]+"_pd_n"] = 10
394    else:
395        choices = np.random.choice(len(pd_volume), size=3)
396        pars[pd_volume[choices[0]]+"_pd_n"] = 25
397        pars[pd_volume[choices[1]]+"_pd_n"] = 10
398        pars[pd_volume[choices[2]]+"_pd_n"] = 5
399    if pd_oriented:
400        pars['theta_pd_n'] = 20
401        if np.random.rand() < 0.1:
402            pars['phi_pd_n'] = 5
403        if np.random.rand() < 0.1:
404            pars['psi_pd_n'] = 5
405
406    ## Show selected polydispersity
407    #for name, value in pars.items():
408    #    if name.endswith('_pd_n') and value > 0:
409    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
410
411
412def randomize_pars(model_info, pars):
413    # type: (ModelInfo, ParameterSet) -> ParameterSet
414    """
415    Generate random values for all of the parameters.
416
417    Valid ranges for the random number generator are guessed from the name of
418    the parameter; this will not account for constraints such as cap radius
419    greater than cylinder radius in the capped_cylinder model, so
420    :func:`constrain_pars` needs to be called afterward..
421    """
422    # Note: the sort guarantees order of calls to random number generator
423    random_pars = dict((p, _randomize_one(model_info, p, v))
424                       for p, v in sorted(pars.items()))
425    if model_info.random is not None:
426        random_pars.update(model_info.random())
427    _random_pd(model_info, random_pars)
428    return random_pars
429
430
431def constrain_pars(model_info, pars):
432    # type: (ModelInfo, ParameterSet) -> None
433    """
434    Restrict parameters to valid values.
435
436    This includes model specific code for models such as capped_cylinder
437    which need to support within model constraints (cap radius more than
438    cylinder radius in this case).
439
440    Warning: this updates the *pars* dictionary in place.
441    """
442    # TODO: move the model specific code to the individual models
443    name = model_info.id
444    # if it is a product model, then just look at the form factor since
445    # none of the structure factors need any constraints.
446    if '*' in name:
447        name = name.split('*')[0]
448
449    # Suppress magnetism for python models (not yet implemented)
450    if callable(model_info.Iq):
451        pars.update(suppress_magnetism(pars))
452
453    if name == 'barbell':
454        if pars['radius_bell'] < pars['radius']:
455            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
456
457    elif name == 'capped_cylinder':
458        if pars['radius_cap'] < pars['radius']:
459            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
460
461    elif name == 'guinier':
462        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
463        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
464        # => ln A - (Rg^2 q^2/3) > -30 ln 10
465        # => Rg^2 q^2/3 < 30 ln 10 + ln A
466        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
467        #q_max = 0.2  # mid q maximum
468        q_max = 1.0  # high q maximum
469        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
470        pars['rg'] = min(pars['rg'], rg_max)
471
472    elif name == 'pearl_necklace':
473        if pars['radius'] < pars['thick_string']:
474            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
475        pass
476
477    elif name == 'rpa':
478        # Make sure phi sums to 1.0
479        if pars['case_num'] < 2:
480            pars['Phi1'] = 0.
481            pars['Phi2'] = 0.
482        elif pars['case_num'] < 5:
483            pars['Phi1'] = 0.
484        total = sum(pars['Phi'+c] for c in '1234')
485        for c in '1234':
486            pars['Phi'+c] /= total
487
488def parlist(model_info, pars, is2d):
489    # type: (ModelInfo, ParameterSet, bool) -> str
490    """
491    Format the parameter list for printing.
492    """
493    lines = []
494    parameters = model_info.parameters
495    magnetic = False
496    magnetic_pars = []
497    for p in parameters.user_parameters(pars, is2d):
498        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
499            continue
500        if p.id.startswith('up:'):
501            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
502            continue
503        fields = dict(
504            value=pars.get(p.id, p.default),
505            pd=pars.get(p.id+"_pd", 0.),
506            n=int(pars.get(p.id+"_pd_n", 0)),
507            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
508            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
509            relative_pd=p.relative_pd,
510            M0=pars.get('M0:'+p.id, 0.),
511            mphi=pars.get('mphi:'+p.id, 0.),
512            mtheta=pars.get('mtheta:'+p.id, 0.),
513        )
514        lines.append(_format_par(p.name, **fields))
515        magnetic = magnetic or fields['M0'] != 0.
516    if magnetic and magnetic_pars:
517        lines.append(" ".join(magnetic_pars))
518    return "\n".join(lines)
519
520    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
521
522def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
523                relative_pd=False, M0=0., mphi=0., mtheta=0.):
524    # type: (str, float, float, int, float, str) -> str
525    line = "%s: %g"%(name, value)
526    if pd != 0.  and n != 0:
527        if relative_pd:
528            pd *= value
529        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
530                % (pd, n, nsigma, nsigma, pdtype)
531    if M0 != 0.:
532        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
533    return line
534
535def suppress_pd(pars, suppress=True):
536    # type: (ParameterSet) -> ParameterSet
537    """
538    If suppress is True complete eliminate polydispersity of the model to test
539    models more quickly.  If suppress is False, make sure at least one
540    parameter is polydisperse, setting the first polydispersity parameter to
541    15% if no polydispersity is given (with no explicit demo parameters given
542    in the model, there will be no default polydispersity).
543    """
544    pars = pars.copy()
545    if suppress:
546        for p in pars:
547            if p.endswith("_pd_n"):
548                pars[p] = 0
549    else:
550        any_pd = False
551        first_pd = None
552        for p in pars:
553            if p.endswith("_pd_n"):
554                any_pd |= (pars[p] != 0 and pars[p[:-2]] != 0.)
555                if first_pd is None:
556                    first_pd = p
557        if not any_pd and first_pd is not None:
558            if pars[first_pd] == 0:
559                pars[first_pd] = 35
560            if pars[first_pd[:-2]] == 0:
561                pars[first_pd[:-2]] = 0.15
562    return pars
563
564def suppress_magnetism(pars, suppress=True):
565    # type: (ParameterSet) -> ParameterSet
566    """
567    If suppress is True complete eliminate magnetism of the model to test
568    models more quickly.  If suppress is False, make sure at least one sld
569    parameter is magnetic, setting the first parameter to have a strong
570    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
571    in the model, there will be no default magnetism).
572    """
573    pars = pars.copy()
574    if suppress:
575        for p in pars:
576            if p.startswith("M0:"):
577                pars[p] = 0
578    else:
579        any_mag = False
580        first_mag = None
581        for p in pars:
582            if p.startswith("M0:"):
583                any_mag |= (pars[p] != 0)
584                if first_mag is None:
585                    first_mag = p
586        if not any_mag and first_mag is not None:
587            pars[first_mag] = 8.
588    return pars
589
590def eval_sasview(model_info, data):
591    # type: (Modelinfo, Data) -> Calculator
592    """
593    Return a model calculator using the pre-4.0 SasView models.
594    """
595    # importing sas here so that the error message will be that sas failed to
596    # import rather than the more obscure smear_selection not imported error
597    import sas
598    import sas.models
599    from sas.models.qsmearing import smear_selection
600    from sas.models.MultiplicationModel import MultiplicationModel
601    from sas.models.dispersion_models import models as dispersers
602
603    def get_model_class(name):
604        # type: (str) -> "sas.models.BaseComponent"
605        #print("new",sorted(_pars.items()))
606        __import__('sas.models.' + name)
607        ModelClass = getattr(getattr(sas.models, name, None), name, None)
608        if ModelClass is None:
609            raise ValueError("could not find model %r in sas.models"%name)
610        return ModelClass
611
612    # WARNING: ugly hack when handling model!
613    # Sasview models with multiplicity need to be created with the target
614    # multiplicity, so we cannot create the target model ahead of time for
615    # for multiplicity models.  Instead we store the model in a list and
616    # update the first element of that list with the new multiplicity model
617    # every time we evaluate.
618
619    # grab the sasview model, or create it if it is a product model
620    if model_info.composition:
621        composition_type, parts = model_info.composition
622        if composition_type == 'product':
623            P, S = [get_model_class(revert_name(p))() for p in parts]
624            model = [MultiplicationModel(P, S)]
625        else:
626            raise ValueError("sasview mixture models not supported by compare")
627    else:
628        old_name = revert_name(model_info)
629        if old_name is None:
630            raise ValueError("model %r does not exist in old sasview"
631                            % model_info.id)
632        ModelClass = get_model_class(old_name)
633        model = [ModelClass()]
634    model[0].disperser_handles = {}
635
636    # build a smearer with which to call the model, if necessary
637    smearer = smear_selection(data, model=model)
638    if hasattr(data, 'qx_data'):
639        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
640        index = ((~data.mask) & (~np.isnan(data.data))
641                 & (q >= data.qmin) & (q <= data.qmax))
642        if smearer is not None:
643            smearer.model = model  # because smear_selection has a bug
644            smearer.accuracy = data.accuracy
645            smearer.set_index(index)
646            def _call_smearer():
647                smearer.model = model[0]
648                return smearer.get_value()
649            theory = _call_smearer
650        else:
651            theory = lambda: model[0].evalDistribution([data.qx_data[index],
652                                                        data.qy_data[index]])
653    elif smearer is not None:
654        theory = lambda: smearer(model[0].evalDistribution(data.x))
655    else:
656        theory = lambda: model[0].evalDistribution(data.x)
657
658    def calculator(**pars):
659        # type: (float, ...) -> np.ndarray
660        """
661        Sasview calculator for model.
662        """
663        oldpars = revert_pars(model_info, pars)
664        # For multiplicity models, create a model with the correct multiplicity
665        control = oldpars.pop("CONTROL", None)
666        if control is not None:
667            # sphericalSLD has one fewer multiplicity.  This update should
668            # happen in revert_pars, but it hasn't been called yet.
669            model[0] = ModelClass(control)
670        # paying for parameter conversion each time to keep life simple, if not fast
671        for k, v in oldpars.items():
672            if k.endswith('.type'):
673                par = k[:-5]
674                if v == 'gaussian': continue
675                cls = dispersers[v if v != 'rectangle' else 'rectangula']
676                handle = cls()
677                model[0].disperser_handles[par] = handle
678                try:
679                    model[0].set_dispersion(par, handle)
680                except Exception:
681                    exception.annotate_exception("while setting %s to %r"
682                                                 %(par, v))
683                    raise
684
685
686        #print("sasview pars",oldpars)
687        for k, v in oldpars.items():
688            name_attr = k.split('.')  # polydispersity components
689            if len(name_attr) == 2:
690                par, disp_par = name_attr
691                model[0].dispersion[par][disp_par] = v
692            else:
693                model[0].setParam(k, v)
694        return theory()
695
696    calculator.engine = "sasview"
697    return calculator
698
699DTYPE_MAP = {
700    'half': '16',
701    'fast': 'fast',
702    'single': '32',
703    'double': '64',
704    'quad': '128',
705    'f16': '16',
706    'f32': '32',
707    'f64': '64',
708    'float16': '16',
709    'float32': '32',
710    'float64': '64',
711    'float128': '128',
712    'longdouble': '128',
713}
714def eval_opencl(model_info, data, dtype='single', cutoff=0.):
715    # type: (ModelInfo, Data, str, float) -> Calculator
716    """
717    Return a model calculator using the OpenCL calculation engine.
718    """
719    if not core.HAVE_OPENCL:
720        raise RuntimeError("OpenCL not available")
721    model = core.build_model(model_info, dtype=dtype, platform="ocl")
722    calculator = DirectModel(data, model, cutoff=cutoff)
723    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
724    return calculator
725
726def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
727    # type: (ModelInfo, Data, str, float) -> Calculator
728    """
729    Return a model calculator using the DLL calculation engine.
730    """
731    model = core.build_model(model_info, dtype=dtype, platform="dll")
732    calculator = DirectModel(data, model, cutoff=cutoff)
733    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
734    return calculator
735
736def time_calculation(calculator, pars, evals=1):
737    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
738    """
739    Compute the average calculation time over N evaluations.
740
741    An additional call is generated without polydispersity in order to
742    initialize the calculation engine, and make the average more stable.
743    """
744    # initialize the code so time is more accurate
745    if evals > 1:
746        calculator(**suppress_pd(pars))
747    toc = tic()
748    # make sure there is at least one eval
749    value = calculator(**pars)
750    for _ in range(evals-1):
751        value = calculator(**pars)
752    average_time = toc()*1000. / evals
753    #print("I(q)",value)
754    return value, average_time
755
756def make_data(opts):
757    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
758    """
759    Generate an empty dataset, used with the model to set Q points
760    and resolution.
761
762    *opts* contains the options, with 'qmax', 'nq', 'res',
763    'accuracy', 'is2d' and 'view' parsed from the command line.
764    """
765    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
766    if opts['is2d']:
767        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
768        data = empty_data2D(q, resolution=res)
769        data.accuracy = opts['accuracy']
770        set_beam_stop(data, 0.0004)
771        index = ~data.mask
772    else:
773        if opts['view'] == 'log' and not opts['zero']:
774            qmax = math.log10(qmax)
775            q = np.logspace(qmax-3, qmax, nq)
776        else:
777            q = np.linspace(0.001*qmax, qmax, nq)
778        if opts['zero']:
779            q = np.hstack((0, q))
780        data = empty_data1D(q, resolution=res)
781        index = slice(None, None)
782    return data, index
783
784def make_engine(model_info, data, dtype, cutoff):
785    # type: (ModelInfo, Data, str, float) -> Calculator
786    """
787    Generate the appropriate calculation engine for the given datatype.
788
789    Datatypes with '!' appended are evaluated using external C DLLs rather
790    than OpenCL.
791    """
792    if dtype == 'sasview':
793        return eval_sasview(model_info, data)
794    elif dtype is None or not dtype.endswith('!'):
795        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
796    else:
797        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
798
799def _show_invalid(data, theory):
800    # type: (Data, np.ma.ndarray) -> None
801    """
802    Display a list of the non-finite values in theory.
803    """
804    if not theory.mask.any():
805        return
806
807    if hasattr(data, 'x'):
808        bad = zip(data.x[theory.mask], theory[theory.mask])
809        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
810
811
812def compare(opts, limits=None):
813    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
814    """
815    Preform a comparison using options from the command line.
816
817    *limits* are the limits on the values to use, either to set the y-axis
818    for 1D or to set the colormap scale for 2D.  If None, then they are
819    inferred from the data and returned. When exploring using Bumps,
820    the limits are set when the model is initially called, and maintained
821    as the values are adjusted, making it easier to see the effects of the
822    parameters.
823    """
824    limits = np.Inf, -np.Inf
825    for k in range(opts['sets']):
826        opts['pars'] = parse_pars(opts)
827        if opts['pars'] is None:
828            return
829        result = run_models(opts, verbose=True)
830        if opts['plot']:
831            limits = plot_models(opts, result, limits=limits, setnum=k)
832    if opts['plot']:
833        import matplotlib.pyplot as plt
834        plt.show()
835
836def run_models(opts, verbose=False):
837    # type: (Dict[str, Any]) -> Dict[str, Any]
838
839    base, comp = opts['engines']
840    base_n, comp_n = opts['count']
841    base_pars, comp_pars = opts['pars']
842    data = opts['data']
843
844    comparison = comp is not None
845
846    base_time = comp_time = None
847    base_value = comp_value = resid = relerr = None
848
849    # Base calculation
850    try:
851        base_raw, base_time = time_calculation(base, base_pars, base_n)
852        base_value = np.ma.masked_invalid(base_raw)
853        if verbose:
854            print("%s t=%.2f ms, intensity=%.0f"
855                  % (base.engine, base_time, base_value.sum()))
856        _show_invalid(data, base_value)
857    except ImportError:
858        traceback.print_exc()
859
860    # Comparison calculation
861    if comparison:
862        try:
863            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
864            comp_value = np.ma.masked_invalid(comp_raw)
865            if verbose:
866                print("%s t=%.2f ms, intensity=%.0f"
867                      % (comp.engine, comp_time, comp_value.sum()))
868            _show_invalid(data, comp_value)
869        except ImportError:
870            traceback.print_exc()
871
872    # Compare, but only if computing both forms
873    if comparison:
874        resid = (base_value - comp_value)
875        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
876        if verbose:
877            _print_stats("|%s-%s|"
878                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
879                         resid)
880            _print_stats("|(%s-%s)/%s|"
881                         % (base.engine, comp.engine, comp.engine),
882                         relerr)
883
884    return dict(base_value=base_value, comp_value=comp_value,
885                base_time=base_time, comp_time=comp_time,
886                resid=resid, relerr=relerr)
887
888
889def _print_stats(label, err):
890    # type: (str, np.ma.ndarray) -> None
891    # work with trimmed data, not the full set
892    sorted_err = np.sort(abs(err.compressed()))
893    if len(sorted_err) == 0.:
894        print(label + "  no valid values")
895        return
896
897    p50 = int((len(sorted_err)-1)*0.50)
898    p98 = int((len(sorted_err)-1)*0.98)
899    data = [
900        "max:%.3e"%sorted_err[-1],
901        "median:%.3e"%sorted_err[p50],
902        "98%%:%.3e"%sorted_err[p98],
903        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
904        "zero-offset:%+.3e"%np.mean(sorted_err),
905        ]
906    print(label+"  "+"  ".join(data))
907
908
909def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
910    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
911    base_value, comp_value = result['base_value'], result['comp_value']
912    base_time, comp_time = result['base_time'], result['comp_time']
913    resid, relerr = result['resid'], result['relerr']
914
915    have_base, have_comp = (base_value is not None), (comp_value is not None)
916    base, comp = opts['engines']
917    data = opts['data']
918    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
919
920    # Plot if requested
921    view = opts['view']
922    import matplotlib.pyplot as plt
923    vmin, vmax = limits
924    if have_base:
925        vmin = min(vmin, base_value.min())
926        vmax = max(vmax, base_value.max())
927    if have_comp:
928        vmin = min(vmin, comp_value.min())
929        vmax = max(vmax, comp_value.max())
930    limits = vmin, vmax
931
932    if have_base:
933        if have_comp:
934            plt.subplot(131)
935        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
936        plt.title("%s t=%.2f ms"%(base.engine, base_time))
937        #cbar_title = "log I"
938    if have_comp:
939        if have_base:
940            plt.subplot(132)
941        if not opts['is2d'] and have_base:
942            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
943        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
944        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
945        #cbar_title = "log I"
946    if have_base and have_comp:
947        plt.subplot(133)
948        if not opts['rel_err']:
949            err, errstr, errview = resid, "abs err", "linear"
950        else:
951            err, errstr, errview = abs(relerr), "rel err", "log"
952        if 0:  # 95% cutoff
953            sorted = np.sort(err.flatten())
954            cutoff = sorted[int(sorted.size*0.95)]
955            err[err > cutoff] = cutoff
956        #err,errstr = base/comp,"ratio"
957        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
958        if view == 'linear':
959            plt.xscale('linear')
960        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
961        #cbar_title = errstr if errview=="linear" else "log "+errstr
962    #if is2D:
963    #    h = plt.colorbar()
964    #    h.ax.set_title(cbar_title)
965    fig = plt.gcf()
966    extra_title = ' '+opts['title'] if opts['title'] else ''
967    fig.suptitle(":".join(opts['name']) + extra_title)
968
969    if have_base and have_comp and opts['show_hist']:
970        plt.figure()
971        v = relerr
972        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
973        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
974        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
975                   % (base.engine, comp.engine, comp.engine))
976        plt.ylabel('P(err)')
977        plt.title('Distribution of relative error between calculation engines')
978
979    return limits
980
981
982# ===========================================================================
983#
984
985# Set of command line options.
986# Normal options such as -plot/-noplot are specified as 'name'.
987# For options such as -nq=500 which require a value use 'name='.
988#
989OPTIONS = [
990    # Plotting
991    'plot', 'noplot',
992    'linear', 'log', 'q4',
993    'rel', 'abs',
994    'hist', 'nohist',
995    'title=',
996
997    # Data generation
998    'data=', 'noise=', 'res=',
999    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
1000    '2d', '1d',
1001
1002    # Parameter set
1003    'preset', 'random', 'random=', 'sets=',
1004    'demo', 'default',  # TODO: remove demo/default
1005    'nopars', 'pars',
1006
1007    # Calculation options
1008    'poly', 'mono', 'cutoff=',
1009    'magnetic', 'nonmagnetic',
1010    'accuracy=',
1011
1012    # Precision options
1013    'calc=',
1014    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1015    'sasview',  # TODO: remove sasview 3.x support
1016    'timing=',
1017
1018    # Output options
1019    'help', 'html', 'edit',
1020    ]
1021
1022NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1023VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1024
1025
1026def columnize(items, indent="", width=79):
1027    # type: (List[str], str, int) -> str
1028    """
1029    Format a list of strings into columns.
1030
1031    Returns a string with carriage returns ready for printing.
1032    """
1033    column_width = max(len(w) for w in items) + 1
1034    num_columns = (width - len(indent)) // column_width
1035    num_rows = len(items) // num_columns
1036    items = items + [""] * (num_rows * num_columns - len(items))
1037    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1038    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1039             for row in zip(*columns)]
1040    output = indent + ("\n"+indent).join(lines)
1041    return output
1042
1043
1044def get_pars(model_info, use_demo=False):
1045    # type: (ModelInfo, bool) -> ParameterSet
1046    """
1047    Extract demo parameters from the model definition.
1048    """
1049    # Get the default values for the parameters
1050    pars = {}
1051    for p in model_info.parameters.call_parameters:
1052        parts = [('', p.default)]
1053        if p.polydisperse:
1054            parts.append(('_pd', 0.0))
1055            parts.append(('_pd_n', 0))
1056            parts.append(('_pd_nsigma', 3.0))
1057            parts.append(('_pd_type', "gaussian"))
1058        for ext, val in parts:
1059            if p.length > 1:
1060                dict(("%s%d%s" % (p.id, k, ext), val)
1061                     for k in range(1, p.length+1))
1062            else:
1063                pars[p.id + ext] = val
1064
1065    # Plug in values given in demo
1066    if use_demo:
1067        pars.update(model_info.demo)
1068    return pars
1069
1070INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1071def isnumber(str):
1072    match = FLOAT_RE.match(str)
1073    isfloat = (match and not str[match.end():])
1074    return isfloat or INTEGER_RE.match(str)
1075
1076# For distinguishing pairs of models for comparison
1077# key-value pair separator =
1078# shell characters  | & ; <> $ % ' " \ # `
1079# model and parameter names _
1080# parameter expressions - + * / . ( )
1081# path characters including tilde expansion and windows drive ~ / :
1082# not sure about brackets [] {}
1083# maybe one of the following @ ? ^ ! ,
1084PAR_SPLIT = ','
1085def parse_opts(argv):
1086    # type: (List[str]) -> Dict[str, Any]
1087    """
1088    Parse command line options.
1089    """
1090    MODELS = core.list_models()
1091    flags = [arg for arg in argv
1092             if arg.startswith('-')]
1093    values = [arg for arg in argv
1094              if not arg.startswith('-') and '=' in arg]
1095    positional_args = [arg for arg in argv
1096                       if not arg.startswith('-') and '=' not in arg]
1097    models = "\n    ".join("%-15s"%v for v in MODELS)
1098    if len(positional_args) == 0:
1099        print(USAGE)
1100        print("\nAvailable models:")
1101        print(columnize(MODELS, indent="  "))
1102        return None
1103
1104    invalid = [o[1:] for o in flags
1105               if o[1:] not in NAME_OPTIONS
1106               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1107    if invalid:
1108        print("Invalid options: %s"%(", ".join(invalid)))
1109        return None
1110
1111    name = positional_args[-1]
1112
1113    # pylint: disable=bad-whitespace
1114    # Interpret the flags
1115    opts = {
1116        'plot'      : True,
1117        'view'      : 'log',
1118        'is2d'      : False,
1119        'qmax'      : 0.05,
1120        'nq'        : 128,
1121        'res'       : 0.0,
1122        'noise'     : 0.0,
1123        'accuracy'  : 'Low',
1124        'cutoff'    : '0.0',
1125        'seed'      : -1,  # default to preset
1126        'mono'      : True,
1127        # Default to magnetic a magnetic moment is set on the command line
1128        'magnetic'  : False,
1129        'show_pars' : False,
1130        'show_hist' : False,
1131        'rel_err'   : True,
1132        'explore'   : False,
1133        'use_demo'  : True,
1134        'zero'      : False,
1135        'html'      : False,
1136        'title'     : None,
1137        'datafile'  : None,
1138        'sets'      : 1,
1139        'engine'    : 'default',
1140        'evals'     : '1',
1141    }
1142    for arg in flags:
1143        if arg == '-noplot':    opts['plot'] = False
1144        elif arg == '-plot':    opts['plot'] = True
1145        elif arg == '-linear':  opts['view'] = 'linear'
1146        elif arg == '-log':     opts['view'] = 'log'
1147        elif arg == '-q4':      opts['view'] = 'q4'
1148        elif arg == '-1d':      opts['is2d'] = False
1149        elif arg == '-2d':      opts['is2d'] = True
1150        elif arg == '-exq':     opts['qmax'] = 10.0
1151        elif arg == '-highq':   opts['qmax'] = 1.0
1152        elif arg == '-midq':    opts['qmax'] = 0.2
1153        elif arg == '-lowq':    opts['qmax'] = 0.05
1154        elif arg == '-zero':    opts['zero'] = True
1155        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1156        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1157        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1158        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1159        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1160        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1161        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1162        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1163        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1164        elif arg.startswith('-calc='):     opts['engine'] = arg[6:]
1165        elif arg.startswith('-neval='):    opts['evals'] = arg[7:]
1166        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1167        elif arg == '-preset':  opts['seed'] = -1
1168        elif arg == '-mono':    opts['mono'] = True
1169        elif arg == '-poly':    opts['mono'] = False
1170        elif arg == '-magnetic':       opts['magnetic'] = True
1171        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1172        elif arg == '-pars':    opts['show_pars'] = True
1173        elif arg == '-nopars':  opts['show_pars'] = False
1174        elif arg == '-hist':    opts['show_hist'] = True
1175        elif arg == '-nohist':  opts['show_hist'] = False
1176        elif arg == '-rel':     opts['rel_err'] = True
1177        elif arg == '-abs':     opts['rel_err'] = False
1178        elif arg == '-half':    opts['engine'] = 'half'
1179        elif arg == '-fast':    opts['engine'] = 'fast'
1180        elif arg == '-single':  opts['engine'] = 'single'
1181        elif arg == '-double':  opts['engine'] = 'double'
1182        elif arg == '-single!': opts['engine'] = 'single!'
1183        elif arg == '-double!': opts['engine'] = 'double!'
1184        elif arg == '-quad!':   opts['engine'] = 'quad!'
1185        elif arg == '-sasview': opts['engine'] = 'sasview'
1186        elif arg == '-edit':    opts['explore'] = True
1187        elif arg == '-demo':    opts['use_demo'] = True
1188        elif arg == '-default': opts['use_demo'] = False
1189        elif arg == '-html':    opts['html'] = True
1190        elif arg == '-help':    opts['html'] = True
1191    # pylint: enable=bad-whitespace
1192
1193    # Magnetism forces 2D for now
1194    if opts['magnetic']:
1195        opts['is2d'] = True
1196
1197    # Force random if more than one set
1198    if opts['sets'] > 1 and opts['seed'] < 0:
1199        opts['seed'] = np.random.randint(1000000)
1200
1201    # Create the computational engines
1202    if opts['datafile'] is not None:
1203        data = load_data(os.path.expanduser(opts['datafile']))
1204    else:
1205        data, _ = make_data(opts)
1206
1207    comparison = any(PAR_SPLIT in v for v in values)
1208    if PAR_SPLIT in name:
1209        names = name.split(PAR_SPLIT, 2)
1210        comparison = True
1211    else:
1212        names = [name]*2
1213    try:
1214        model_info = [core.load_model_info(k) for k in names]
1215    except ImportError as exc:
1216        print(str(exc))
1217        print("Could not find model; use one of:\n    " + models)
1218        return None
1219
1220    if PAR_SPLIT in opts['engine']:
1221        engine_types = opts['engine'].split(PAR_SPLIT, 2)
1222        comparison = True
1223    else:
1224        engine_types = [opts['engine']]*2
1225
1226    if PAR_SPLIT in opts['evals']:
1227        evals = [int(k) for k in opts['evals'].split(PAR_SPLIT, 2)]
1228        comparison = True
1229    else:
1230        evals = [int(opts['evals'])]*2
1231
1232    if PAR_SPLIT in opts['cutoff']:
1233        cutoff = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1234        comparison = True
1235    else:
1236        cutoff = [float(opts['cutoff'])]*2
1237
1238    base = make_engine(model_info[0], data, engine_types[0], cutoff[0])
1239    if comparison:
1240        comp = make_engine(model_info[1], data, engine_types[1], cutoff[1])
1241    else:
1242        comp = None
1243
1244    # pylint: disable=bad-whitespace
1245    # Remember it all
1246    opts.update({
1247        'data'      : data,
1248        'name'      : names,
1249        'def'       : model_info,
1250        'count'     : evals,
1251        'engines'   : [base, comp],
1252        'values'    : values,
1253    })
1254    # pylint: enable=bad-whitespace
1255
1256    return opts
1257
1258def parse_pars(opts):
1259    model_info, model_info2 = opts['def']
1260
1261    # Get demo parameters from model definition, or use default parameters
1262    # if model does not define demo parameters
1263    pars = get_pars(model_info, opts['use_demo'])
1264    pars2 = get_pars(model_info2, opts['use_demo'])
1265    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1266    # randomize parameters
1267    #pars.update(set_pars)  # set value before random to control range
1268    if opts['seed'] > -1:
1269        pars = randomize_pars(model_info, pars)
1270        if model_info != model_info2:
1271            pars2 = randomize_pars(model_info2, pars2)
1272            # Share values for parameters with the same name
1273            for k, v in pars.items():
1274                if k in pars2:
1275                    pars2[k] = v
1276        else:
1277            pars2 = pars.copy()
1278        constrain_pars(model_info, pars)
1279        constrain_pars(model_info2, pars2)
1280    pars = suppress_pd(pars, opts['mono'])
1281    pars2 = suppress_pd(pars2, opts['mono'])
1282    pars = suppress_magnetism(pars, not opts['magnetic'])
1283    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1284
1285    # Fill in parameters given on the command line
1286    presets = {}
1287    presets2 = {}
1288    for arg in opts['values']:
1289        k, v = arg.split('=', 1)
1290        if k not in pars and k not in pars2:
1291            # extract base name without polydispersity info
1292            s = set(p.split('_pd')[0] for p in pars)
1293            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1294            return None
1295        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1296        if v1 and k in pars:
1297            presets[k] = float(v1) if isnumber(v1) else v1
1298        if v2 and k in pars2:
1299            presets2[k] = float(v2) if isnumber(v2) else v2
1300
1301    # If pd given on the command line, default pd_n to 35
1302    for k, v in list(presets.items()):
1303        if k.endswith('_pd'):
1304            presets.setdefault(k+'_n', 35.)
1305    for k, v in list(presets2.items()):
1306        if k.endswith('_pd'):
1307            presets2.setdefault(k+'_n', 35.)
1308
1309    # Evaluate preset parameter expressions
1310    context = MATH.copy()
1311    context['np'] = np
1312    context.update(pars)
1313    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1314    for k, v in presets.items():
1315        if not isinstance(v, float) and not k.endswith('_type'):
1316            presets[k] = eval(v, context)
1317    context.update(presets)
1318    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1319    for k, v in presets2.items():
1320        if not isinstance(v, float) and not k.endswith('_type'):
1321            presets2[k] = eval(v, context)
1322
1323    # update parameters with presets
1324    pars.update(presets)  # set value after random to control value
1325    pars2.update(presets2)  # set value after random to control value
1326    #import pprint; pprint.pprint(model_info)
1327
1328    if opts['show_pars']:
1329        if model_info.name != model_info2.name or pars != pars2:
1330            print("==== %s ====="%model_info.name)
1331            print(str(parlist(model_info, pars, opts['is2d'])))
1332            print("==== %s ====="%model_info2.name)
1333            print(str(parlist(model_info2, pars2, opts['is2d'])))
1334        else:
1335            print(str(parlist(model_info, pars, opts['is2d'])))
1336
1337    return pars, pars2
1338
1339def show_docs(opts):
1340    # type: (Dict[str, Any]) -> None
1341    """
1342    show html docs for the model
1343    """
1344    import os
1345    from .generate import make_html
1346    from . import rst2html
1347
1348    info = opts['def'][0]
1349    html = make_html(info)
1350    path = os.path.dirname(info.filename)
1351    url = "file://"+path.replace("\\","/")[2:]+"/"
1352    rst2html.view_html_qtapp(html, url)
1353
1354def explore(opts):
1355    # type: (Dict[str, Any]) -> None
1356    """
1357    explore the model using the bumps gui.
1358    """
1359    import wx  # type: ignore
1360    from bumps.names import FitProblem  # type: ignore
1361    from bumps.gui.app_frame import AppFrame  # type: ignore
1362    from bumps.gui import signal
1363
1364    is_mac = "cocoa" in wx.version()
1365    # Create an app if not running embedded
1366    app = wx.App() if wx.GetApp() is None else None
1367    model = Explore(opts)
1368    problem = FitProblem(model)
1369    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1370    if not is_mac:
1371        frame.Show()
1372    frame.panel.set_model(model=problem)
1373    frame.panel.Layout()
1374    frame.panel.aui.Split(0, wx.TOP)
1375    def reset_parameters(event):
1376        model.revert_values()
1377        signal.update_parameters(problem)
1378    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1379    if is_mac: frame.Show()
1380    # If running withing an app, start the main loop
1381    if app:
1382        app.MainLoop()
1383
1384class Explore(object):
1385    """
1386    Bumps wrapper for a SAS model comparison.
1387
1388    The resulting object can be used as a Bumps fit problem so that
1389    parameters can be adjusted in the GUI, with plots updated on the fly.
1390    """
1391    def __init__(self, opts):
1392        # type: (Dict[str, Any]) -> None
1393        from bumps.cli import config_matplotlib  # type: ignore
1394        from . import bumps_model
1395        config_matplotlib()
1396        self.opts = opts
1397        opts['pars'] = list(opts['pars'])
1398        p1, p2 = opts['pars']
1399        m1, m2 = opts['def']
1400        self.fix_p2 = m1 != m2 or p1 != p2
1401        model_info = m1
1402        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1403        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1404        if not opts['is2d']:
1405            for p in model_info.parameters.user_parameters({}, is2d=False):
1406                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1407                    k = p.name+ext
1408                    v = pars.get(k, None)
1409                    if v is not None:
1410                        v.range(*parameter_range(k, v.value))
1411        else:
1412            for k, v in pars.items():
1413                v.range(*parameter_range(k, v.value))
1414
1415        self.pars = pars
1416        self.starting_values = dict((k, v.value) for k, v in pars.items())
1417        self.pd_types = pd_types
1418        self.limits = np.Inf, -np.Inf
1419
1420    def revert_values(self):
1421        for k, v in self.starting_values.items():
1422            self.pars[k].value = v
1423
1424    def model_update(self):
1425        pass
1426
1427    def numpoints(self):
1428        # type: () -> int
1429        """
1430        Return the number of points.
1431        """
1432        return len(self.pars) + 1  # so dof is 1
1433
1434    def parameters(self):
1435        # type: () -> Any   # Dict/List hierarchy of parameters
1436        """
1437        Return a dictionary of parameters.
1438        """
1439        return self.pars
1440
1441    def nllf(self):
1442        # type: () -> float
1443        """
1444        Return cost.
1445        """
1446        # pylint: disable=no-self-use
1447        return 0.  # No nllf
1448
1449    def plot(self, view='log'):
1450        # type: (str) -> None
1451        """
1452        Plot the data and residuals.
1453        """
1454        pars = dict((k, v.value) for k, v in self.pars.items())
1455        pars.update(self.pd_types)
1456        self.opts['pars'][0] = pars
1457        if not self.fix_p2:
1458            self.opts['pars'][1] = pars
1459        result = run_models(self.opts)
1460        limits = plot_models(self.opts, result, limits=self.limits)
1461        if self.limits is None:
1462            vmin, vmax = limits
1463            self.limits = vmax*1e-7, 1.3*vmax
1464            import pylab; pylab.clf()
1465            plot_models(self.opts, result, limits=self.limits)
1466
1467
1468def main(*argv):
1469    # type: (*str) -> None
1470    """
1471    Main program.
1472    """
1473    opts = parse_opts(argv)
1474    if opts is not None:
1475        if opts['seed'] > -1:
1476            print("Randomize using -random=%i"%opts['seed'])
1477            np.random.seed(opts['seed'])
1478        if opts['html']:
1479            show_docs(opts)
1480        elif opts['explore']:
1481            opts['pars'] = parse_pars(opts)
1482            if opts['pars'] is None:
1483                return
1484            explore(opts)
1485        else:
1486            compare(opts)
1487
1488if __name__ == "__main__":
1489    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.