source: sasmodels/sasmodels/compare.py @ 3c24ccd

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 3c24ccd was 3c24ccd, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

add -weights option to sascomp to show dispersity distribution

  • Property mode set to 100755
File size: 52.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47from .weights import plot_weights
48
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except Exception:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57
58USAGE = """
59usage: sascomp model [options...] [key=val]
60
61Generate and compare SAS models.  If a single model is specified it shows
62a plot of that model.  Different models can be compared, or the same model
63with different parameters.  The same model with the same parameters can
64be compared with different calculation engines to see the effects of precision
65on the resultant values.
66
67model or model1,model2 are the names of the models to compare (see below).
68
69Options (* for default):
70
71    === data generation ===
72    -data="path" uses q, dq from the data file
73    -noise=0 sets the measurement error dI/I
74    -res=0 sets the resolution width dQ/Q if calculating with resolution
75    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
76    -nq=128 sets the number of Q points in the data set
77    -1d*/-2d computes 1d or 2d data
78    -zero indicates that q=0 should be included
79
80    === model parameters ===
81    -preset*/-random[=seed] preset or random parameters
82    -sets=n generates n random datasets with the seed given by -random=seed
83    -pars/-nopars* prints the parameter set or not
84    -default/-demo* use demo vs default parameters
85
86    === calculation options ===
87    -mono*/-poly force monodisperse or allow polydisperse demo parameters
88    -cutoff=1e-5* cutoff value for including a point in polydispersity
89    -magnetic/-nonmagnetic* suppress magnetism
90    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
91    -neval=1 sets the number of evals for more accurate timing
92
93    === precision options ===
94    -engine=default uses the default calcution precision
95    -single/-double/-half/-fast sets an OpenCL calculation engine
96    -single!/-double!/-quad! sets an OpenMP calculation engine
97    -sasview sets the sasview calculation engine
98
99    === plotting ===
100    -plot*/-noplot plots or suppress the plot of the model
101    -linear/-log*/-q4 intensity scaling on plots
102    -hist/-nohist* plot histogram of relative error
103    -abs/-rel* plot relative or absolute error
104    -title="note" adds note to the plot title, after the model name
105    -weights shows weights plots for the polydisperse parameters
106
107    === output options ===
108    -edit starts the parameter explorer
109    -help/-html shows the model docs instead of running the model
110
111The interpretation of quad precision depends on architecture, and may
112vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
113On unix and mac you may need single quotes around the DLL computation
114engines, such as -engine='single!,double!' since !, is treated as a history
115expansion request in the shell.
116
117Key=value pairs allow you to set specific values for the model parameters.
118Key=value1,value2 to compare different values of the same parameter. The
119value can be an expression including other parameters.
120
121Items later on the command line override those that appear earlier.
122
123Examples:
124
125    # compare single and double precision calculation for a barbell
126    sascomp barbell -engine=single,double
127
128    # generate 10 random lorentz models, with seed=27
129    sascomp lorentz -sets=10 -seed=27
130
131    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
132    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
133
134    # model timing test requires multiple evals to perform the estimate
135    sascomp pringle -engine=single,double -timing=100,100 -noplot
136"""
137
138# Update docs with command line usage string.   This is separate from the usual
139# doc string so that we can display it at run time if there is an error.
140# lin
141__doc__ = (__doc__  # pylint: disable=redefined-builtin
142           + """
143Program description
144-------------------
145
146""" + USAGE)
147
148kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
149
150# list of math functions for use in evaluating parameters
151MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
152
153# CRUFT python 2.6
154if not hasattr(datetime.timedelta, 'total_seconds'):
155    def delay(dt):
156        """Return number date-time delta as number seconds"""
157        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
158else:
159    def delay(dt):
160        """Return number date-time delta as number seconds"""
161        return dt.total_seconds()
162
163
164class push_seed(object):
165    """
166    Set the seed value for the random number generator.
167
168    When used in a with statement, the random number generator state is
169    restored after the with statement is complete.
170
171    :Parameters:
172
173    *seed* : int or array_like, optional
174        Seed for RandomState
175
176    :Example:
177
178    Seed can be used directly to set the seed::
179
180        >>> from numpy.random import randint
181        >>> push_seed(24)
182        <...push_seed object at...>
183        >>> print(randint(0,1000000,3))
184        [242082    899 211136]
185
186    Seed can also be used in a with statement, which sets the random
187    number generator state for the enclosed computations and restores
188    it to the previous state on completion::
189
190        >>> with push_seed(24):
191        ...    print(randint(0,1000000,3))
192        [242082    899 211136]
193
194    Using nested contexts, we can demonstrate that state is indeed
195    restored after the block completes::
196
197        >>> with push_seed(24):
198        ...    print(randint(0,1000000))
199        ...    with push_seed(24):
200        ...        print(randint(0,1000000,3))
201        ...    print(randint(0,1000000))
202        242082
203        [242082    899 211136]
204        899
205
206    The restore step is protected against exceptions in the block::
207
208        >>> with push_seed(24):
209        ...    print(randint(0,1000000))
210        ...    try:
211        ...        with push_seed(24):
212        ...            print(randint(0,1000000,3))
213        ...            raise Exception()
214        ...    except Exception:
215        ...        print("Exception raised")
216        ...    print(randint(0,1000000))
217        242082
218        [242082    899 211136]
219        Exception raised
220        899
221    """
222    def __init__(self, seed=None):
223        # type: (Optional[int]) -> None
224        self._state = np.random.get_state()
225        np.random.seed(seed)
226
227    def __enter__(self):
228        # type: () -> None
229        pass
230
231    def __exit__(self, exc_type, exc_value, traceback):
232        # type: (Any, BaseException, Any) -> None
233        # TODO: better typing for __exit__ method
234        np.random.set_state(self._state)
235
236def tic():
237    # type: () -> Callable[[], float]
238    """
239    Timer function.
240
241    Use "toc=tic()" to start the clock and "toc()" to measure
242    a time interval.
243    """
244    then = datetime.datetime.now()
245    return lambda: delay(datetime.datetime.now() - then)
246
247
248def set_beam_stop(data, radius, outer=None):
249    # type: (Data, float, float) -> None
250    """
251    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
252
253    Note: this function does not require sasview
254    """
255    if hasattr(data, 'qx_data'):
256        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
257        data.mask = (q < radius)
258        if outer is not None:
259            data.mask |= (q >= outer)
260    else:
261        data.mask = (data.x < radius)
262        if outer is not None:
263            data.mask |= (data.x >= outer)
264
265
266def parameter_range(p, v):
267    # type: (str, float) -> Tuple[float, float]
268    """
269    Choose a parameter range based on parameter name and initial value.
270    """
271    # process the polydispersity options
272    if p.endswith('_pd_n'):
273        return 0., 100.
274    elif p.endswith('_pd_nsigma'):
275        return 0., 5.
276    elif p.endswith('_pd_type'):
277        raise ValueError("Cannot return a range for a string value")
278    elif any(s in p for s in ('theta', 'phi', 'psi')):
279        # orientation in [-180,180], orientation pd in [0,45]
280        if p.endswith('_pd'):
281            return 0., 45.
282        else:
283            return -180., 180.
284    elif p.endswith('_pd'):
285        return 0., 1.
286    elif 'sld' in p:
287        return -0.5, 10.
288    elif p == 'background':
289        return 0., 10.
290    elif p == 'scale':
291        return 0., 1.e3
292    elif v < 0.:
293        return 2.*v, -2.*v
294    else:
295        return 0., (2.*v if v > 0. else 1.)
296
297
298def _randomize_one(model_info, name, value):
299    # type: (ModelInfo, str, float) -> float
300    # type: (ModelInfo, str, str) -> str
301    """
302    Randomize a single parameter.
303    """
304    # Set the amount of polydispersity/angular dispersion, but by default pd_n
305    # is zero so there is no polydispersity.  This allows us to turn on/off
306    # pd by setting pd_n, and still have randomly generated values
307    if name.endswith('_pd'):
308        par = model_info.parameters[name[:-3]]
309        if par.type == 'orientation':
310            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
311            return 180*np.random.beta(2.5, 20)
312        else:
313            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
314            return np.random.beta(1.5, 7)
315
316    # pd is selected globally rather than per parameter, so set to 0 for no pd
317    # In particular, when multiple pd dimensions, want to decrease the number
318    # of points per dimension for faster computation
319    if name.endswith('_pd_n'):
320        return 0
321
322    # Don't mess with distribution type for now
323    if name.endswith('_pd_type'):
324        return 'gaussian'
325
326    # type-dependent value of number of sigmas; for gaussian use 3.
327    if name.endswith('_pd_nsigma'):
328        return 3.
329
330    # background in the range [0.01, 1]
331    if name == 'background':
332        return 10**np.random.uniform(-2, 0)
333
334    # scale defaults to 0.1% to 30% volume fraction
335    if name == 'scale':
336        return 10**np.random.uniform(-3, -0.5)
337
338    # If it is a list of choices, pick one at random with equal probability
339    # In practice, the model specific random generator will override.
340    par = model_info.parameters[name]
341    if len(par.limits) > 2:  # choice list
342        return np.random.randint(len(par.limits))
343
344    # If it is a fixed range, pick from it with equal probability.
345    # For logarithmic ranges, the model will have to override.
346    if np.isfinite(par.limits).all():
347        return np.random.uniform(*par.limits)
348
349    # If the paramter is marked as an sld use the range of neutron slds
350    # TODO: ought to randomly contrast match a pair of SLDs
351    if par.type == 'sld':
352        return np.random.uniform(-0.5, 12)
353
354    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
355    if par.name.startswith('M0:'):
356        return np.random.uniform(0, 5)
357
358    # Guess at the random length/radius/thickness.  In practice, all models
359    # are going to set their own reasonable ranges.
360    if par.type == 'volume':
361        if ('length' in par.name or
362                'radius' in par.name or
363                'thick' in par.name):
364            return 10**np.random.uniform(2, 4)
365
366    # In the absence of any other info, select a value in [0, 2v], or
367    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
368    # model random parameter generators will override this default.
369    low, high = parameter_range(par.name, value)
370    limits = (max(par.limits[0], low), min(par.limits[1], high))
371    return np.random.uniform(*limits)
372
373def _random_pd(model_info, pars):
374    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
375    pd_volume = []
376    pd_oriented = []
377    for p in pd:
378        if p.type == 'orientation':
379            pd_oriented.append(p.name)
380        elif p.length_control is not None:
381            n = int(pars.get(p.length_control, 1) + 0.5)
382            pd_volume.extend(p.name+str(k+1) for k in range(n))
383        elif p.length > 1:
384            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
385        else:
386            pd_volume.append(p.name)
387    u = np.random.rand()
388    n = len(pd_volume)
389    if u < 0.01 or n < 1:
390        pass  # 1% chance of no polydispersity
391    elif u < 0.86 or n < 2:
392        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
393    elif u < 0.99 or n < 3:
394        choices = np.random.choice(len(pd_volume), size=2)
395        pars[pd_volume[choices[0]]+"_pd_n"] = 25
396        pars[pd_volume[choices[1]]+"_pd_n"] = 10
397    else:
398        choices = np.random.choice(len(pd_volume), size=3)
399        pars[pd_volume[choices[0]]+"_pd_n"] = 25
400        pars[pd_volume[choices[1]]+"_pd_n"] = 10
401        pars[pd_volume[choices[2]]+"_pd_n"] = 5
402    if pd_oriented:
403        pars['theta_pd_n'] = 20
404        if np.random.rand() < 0.1:
405            pars['phi_pd_n'] = 5
406        if np.random.rand() < 0.1:
407            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
408                #print("generating psi_pd_n")
409                pars['psi_pd_n'] = 5
410
411    ## Show selected polydispersity
412    #for name, value in pars.items():
413    #    if name.endswith('_pd_n') and value > 0:
414    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
415
416
417def randomize_pars(model_info, pars):
418    # type: (ModelInfo, ParameterSet) -> ParameterSet
419    """
420    Generate random values for all of the parameters.
421
422    Valid ranges for the random number generator are guessed from the name of
423    the parameter; this will not account for constraints such as cap radius
424    greater than cylinder radius in the capped_cylinder model, so
425    :func:`constrain_pars` needs to be called afterward..
426    """
427    # Note: the sort guarantees order of calls to random number generator
428    random_pars = dict((p, _randomize_one(model_info, p, v))
429                       for p, v in sorted(pars.items()))
430    if model_info.random is not None:
431        random_pars.update(model_info.random())
432    _random_pd(model_info, random_pars)
433    return random_pars
434
435
436def constrain_pars(model_info, pars):
437    # type: (ModelInfo, ParameterSet) -> None
438    """
439    Restrict parameters to valid values.
440
441    This includes model specific code for models such as capped_cylinder
442    which need to support within model constraints (cap radius more than
443    cylinder radius in this case).
444
445    Warning: this updates the *pars* dictionary in place.
446    """
447    # TODO: move the model specific code to the individual models
448    name = model_info.id
449    # if it is a product model, then just look at the form factor since
450    # none of the structure factors need any constraints.
451    if '*' in name:
452        name = name.split('*')[0]
453
454    # Suppress magnetism for python models (not yet implemented)
455    if callable(model_info.Iq):
456        pars.update(suppress_magnetism(pars))
457
458    if name == 'barbell':
459        if pars['radius_bell'] < pars['radius']:
460            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
461
462    elif name == 'capped_cylinder':
463        if pars['radius_cap'] < pars['radius']:
464            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
465
466    elif name == 'guinier':
467        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
468        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
469        # => ln A - (Rg^2 q^2/3) > -30 ln 10
470        # => Rg^2 q^2/3 < 30 ln 10 + ln A
471        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
472        #q_max = 0.2  # mid q maximum
473        q_max = 1.0  # high q maximum
474        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
475        pars['rg'] = min(pars['rg'], rg_max)
476
477    elif name == 'pearl_necklace':
478        if pars['radius'] < pars['thick_string']:
479            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
480        pass
481
482    elif name == 'rpa':
483        # Make sure phi sums to 1.0
484        if pars['case_num'] < 2:
485            pars['Phi1'] = 0.
486            pars['Phi2'] = 0.
487        elif pars['case_num'] < 5:
488            pars['Phi1'] = 0.
489        total = sum(pars['Phi'+c] for c in '1234')
490        for c in '1234':
491            pars['Phi'+c] /= total
492
493def parlist(model_info, pars, is2d):
494    # type: (ModelInfo, ParameterSet, bool) -> str
495    """
496    Format the parameter list for printing.
497    """
498    lines = []
499    parameters = model_info.parameters
500    magnetic = False
501    magnetic_pars = []
502    for p in parameters.user_parameters(pars, is2d):
503        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
504            continue
505        if p.id.startswith('up:'):
506            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
507            continue
508        fields = dict(
509            value=pars.get(p.id, p.default),
510            pd=pars.get(p.id+"_pd", 0.),
511            n=int(pars.get(p.id+"_pd_n", 0)),
512            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
513            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
514            relative_pd=p.relative_pd,
515            M0=pars.get('M0:'+p.id, 0.),
516            mphi=pars.get('mphi:'+p.id, 0.),
517            mtheta=pars.get('mtheta:'+p.id, 0.),
518        )
519        lines.append(_format_par(p.name, **fields))
520        magnetic = magnetic or fields['M0'] != 0.
521    if magnetic and magnetic_pars:
522        lines.append(" ".join(magnetic_pars))
523    return "\n".join(lines)
524
525    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
526
527def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
528                relative_pd=False, M0=0., mphi=0., mtheta=0.):
529    # type: (str, float, float, int, float, str) -> str
530    line = "%s: %g"%(name, value)
531    if pd != 0.  and n != 0:
532        if relative_pd:
533            pd *= value
534        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
535                % (pd, n, nsigma, nsigma, pdtype)
536    if M0 != 0.:
537        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
538    return line
539
540def suppress_pd(pars, suppress=True):
541    # type: (ParameterSet) -> ParameterSet
542    """
543    If suppress is True complete eliminate polydispersity of the model to test
544    models more quickly.  If suppress is False, make sure at least one
545    parameter is polydisperse, setting the first polydispersity parameter to
546    15% if no polydispersity is given (with no explicit demo parameters given
547    in the model, there will be no default polydispersity).
548    """
549    pars = pars.copy()
550    #print("pars=", pars)
551    if suppress:
552        for p in pars:
553            if p.endswith("_pd_n"):
554                pars[p] = 0
555    else:
556        any_pd = False
557        first_pd = None
558        for p in pars:
559            if p.endswith("_pd_n"):
560                pd = pars.get(p[:-2], 0.)
561                any_pd |= (pars[p] != 0 and pd != 0.)
562                if first_pd is None:
563                    first_pd = p
564        if not any_pd and first_pd is not None:
565            if pars[first_pd] == 0:
566                pars[first_pd] = 35
567            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
568                pars[first_pd[:-2]] = 0.15
569    return pars
570
571def suppress_magnetism(pars, suppress=True):
572    # type: (ParameterSet) -> ParameterSet
573    """
574    If suppress is True complete eliminate magnetism of the model to test
575    models more quickly.  If suppress is False, make sure at least one sld
576    parameter is magnetic, setting the first parameter to have a strong
577    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
578    in the model, there will be no default magnetism).
579    """
580    pars = pars.copy()
581    if suppress:
582        for p in pars:
583            if p.startswith("M0:"):
584                pars[p] = 0
585    else:
586        any_mag = False
587        first_mag = None
588        for p in pars:
589            if p.startswith("M0:"):
590                any_mag |= (pars[p] != 0)
591                if first_mag is None:
592                    first_mag = p
593        if not any_mag and first_mag is not None:
594            pars[first_mag] = 8.
595    return pars
596
597def eval_sasview(model_info, data):
598    # type: (Modelinfo, Data) -> Calculator
599    """
600    Return a model calculator using the pre-4.0 SasView models.
601    """
602    # importing sas here so that the error message will be that sas failed to
603    # import rather than the more obscure smear_selection not imported error
604    import sas
605    import sas.models
606    from sas.models.qsmearing import smear_selection
607    from sas.models.MultiplicationModel import MultiplicationModel
608    from sas.models.dispersion_models import models as dispersers
609
610    def get_model_class(name):
611        # type: (str) -> "sas.models.BaseComponent"
612        #print("new",sorted(_pars.items()))
613        __import__('sas.models.' + name)
614        ModelClass = getattr(getattr(sas.models, name, None), name, None)
615        if ModelClass is None:
616            raise ValueError("could not find model %r in sas.models"%name)
617        return ModelClass
618
619    # WARNING: ugly hack when handling model!
620    # Sasview models with multiplicity need to be created with the target
621    # multiplicity, so we cannot create the target model ahead of time for
622    # for multiplicity models.  Instead we store the model in a list and
623    # update the first element of that list with the new multiplicity model
624    # every time we evaluate.
625
626    # grab the sasview model, or create it if it is a product model
627    if model_info.composition:
628        composition_type, parts = model_info.composition
629        if composition_type == 'product':
630            P, S = [get_model_class(revert_name(p))() for p in parts]
631            model = [MultiplicationModel(P, S)]
632        else:
633            raise ValueError("sasview mixture models not supported by compare")
634    else:
635        old_name = revert_name(model_info)
636        if old_name is None:
637            raise ValueError("model %r does not exist in old sasview"
638                            % model_info.id)
639        ModelClass = get_model_class(old_name)
640        model = [ModelClass()]
641    model[0].disperser_handles = {}
642
643    # build a smearer with which to call the model, if necessary
644    smearer = smear_selection(data, model=model)
645    if hasattr(data, 'qx_data'):
646        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
647        index = ((~data.mask) & (~np.isnan(data.data))
648                 & (q >= data.qmin) & (q <= data.qmax))
649        if smearer is not None:
650            smearer.model = model  # because smear_selection has a bug
651            smearer.accuracy = data.accuracy
652            smearer.set_index(index)
653            def _call_smearer():
654                smearer.model = model[0]
655                return smearer.get_value()
656            theory = _call_smearer
657        else:
658            theory = lambda: model[0].evalDistribution([data.qx_data[index],
659                                                        data.qy_data[index]])
660    elif smearer is not None:
661        theory = lambda: smearer(model[0].evalDistribution(data.x))
662    else:
663        theory = lambda: model[0].evalDistribution(data.x)
664
665    def calculator(**pars):
666        # type: (float, ...) -> np.ndarray
667        """
668        Sasview calculator for model.
669        """
670        oldpars = revert_pars(model_info, pars)
671        # For multiplicity models, create a model with the correct multiplicity
672        control = oldpars.pop("CONTROL", None)
673        if control is not None:
674            # sphericalSLD has one fewer multiplicity.  This update should
675            # happen in revert_pars, but it hasn't been called yet.
676            model[0] = ModelClass(control)
677        # paying for parameter conversion each time to keep life simple, if not fast
678        for k, v in oldpars.items():
679            if k.endswith('.type'):
680                par = k[:-5]
681                if v == 'gaussian': continue
682                cls = dispersers[v if v != 'rectangle' else 'rectangula']
683                handle = cls()
684                model[0].disperser_handles[par] = handle
685                try:
686                    model[0].set_dispersion(par, handle)
687                except Exception:
688                    exception.annotate_exception("while setting %s to %r"
689                                                 %(par, v))
690                    raise
691
692
693        #print("sasview pars",oldpars)
694        for k, v in oldpars.items():
695            name_attr = k.split('.')  # polydispersity components
696            if len(name_attr) == 2:
697                par, disp_par = name_attr
698                model[0].dispersion[par][disp_par] = v
699            else:
700                model[0].setParam(k, v)
701        return theory()
702
703    calculator.engine = "sasview"
704    return calculator
705
706DTYPE_MAP = {
707    'half': '16',
708    'fast': 'fast',
709    'single': '32',
710    'double': '64',
711    'quad': '128',
712    'f16': '16',
713    'f32': '32',
714    'f64': '64',
715    'float16': '16',
716    'float32': '32',
717    'float64': '64',
718    'float128': '128',
719    'longdouble': '128',
720}
721def eval_opencl(model_info, data, dtype='single', cutoff=0.):
722    # type: (ModelInfo, Data, str, float) -> Calculator
723    """
724    Return a model calculator using the OpenCL calculation engine.
725    """
726    if not core.HAVE_OPENCL:
727        raise RuntimeError("OpenCL not available")
728    model = core.build_model(model_info, dtype=dtype, platform="ocl")
729    calculator = DirectModel(data, model, cutoff=cutoff)
730    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
731    return calculator
732
733def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
734    # type: (ModelInfo, Data, str, float) -> Calculator
735    """
736    Return a model calculator using the DLL calculation engine.
737    """
738    model = core.build_model(model_info, dtype=dtype, platform="dll")
739    calculator = DirectModel(data, model, cutoff=cutoff)
740    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
741    return calculator
742
743def time_calculation(calculator, pars, evals=1):
744    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
745    """
746    Compute the average calculation time over N evaluations.
747
748    An additional call is generated without polydispersity in order to
749    initialize the calculation engine, and make the average more stable.
750    """
751    # initialize the code so time is more accurate
752    if evals > 1:
753        calculator(**suppress_pd(pars))
754    toc = tic()
755    # make sure there is at least one eval
756    value = calculator(**pars)
757    for _ in range(evals-1):
758        value = calculator(**pars)
759    average_time = toc()*1000. / evals
760    #print("I(q)",value)
761    return value, average_time
762
763def make_data(opts):
764    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
765    """
766    Generate an empty dataset, used with the model to set Q points
767    and resolution.
768
769    *opts* contains the options, with 'qmax', 'nq', 'res',
770    'accuracy', 'is2d' and 'view' parsed from the command line.
771    """
772    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
773    if opts['is2d']:
774        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
775        data = empty_data2D(q, resolution=res)
776        data.accuracy = opts['accuracy']
777        set_beam_stop(data, 0.0004)
778        index = ~data.mask
779    else:
780        if opts['view'] == 'log' and not opts['zero']:
781            qmax = math.log10(qmax)
782            q = np.logspace(qmax-3, qmax, nq)
783        else:
784            q = np.linspace(0.001*qmax, qmax, nq)
785        if opts['zero']:
786            q = np.hstack((0, q))
787        data = empty_data1D(q, resolution=res)
788        index = slice(None, None)
789    return data, index
790
791def make_engine(model_info, data, dtype, cutoff):
792    # type: (ModelInfo, Data, str, float) -> Calculator
793    """
794    Generate the appropriate calculation engine for the given datatype.
795
796    Datatypes with '!' appended are evaluated using external C DLLs rather
797    than OpenCL.
798    """
799    if dtype == 'sasview':
800        return eval_sasview(model_info, data)
801    elif dtype is None or not dtype.endswith('!'):
802        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
803    else:
804        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
805
806def _show_invalid(data, theory):
807    # type: (Data, np.ma.ndarray) -> None
808    """
809    Display a list of the non-finite values in theory.
810    """
811    if not theory.mask.any():
812        return
813
814    if hasattr(data, 'x'):
815        bad = zip(data.x[theory.mask], theory[theory.mask])
816        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
817
818
819def compare(opts, limits=None):
820    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
821    """
822    Preform a comparison using options from the command line.
823
824    *limits* are the limits on the values to use, either to set the y-axis
825    for 1D or to set the colormap scale for 2D.  If None, then they are
826    inferred from the data and returned. When exploring using Bumps,
827    the limits are set when the model is initially called, and maintained
828    as the values are adjusted, making it easier to see the effects of the
829    parameters.
830    """
831    for k in range(opts['sets']):
832        opts['pars'] = parse_pars(opts)
833        if opts['pars'] is None:
834            return
835        result = run_models(opts, verbose=True)
836        if opts['plot']:
837            limits = plot_models(opts, result, limits=limits, setnum=k)
838        if opts['show_weights']:
839            base, _ = opts['engines']
840            base_pars, _ = opts['pars']
841            model_info = base._kernel.info
842            dim = base._kernel.dim
843            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
844    if opts['plot']:
845        import matplotlib.pyplot as plt
846        plt.show()
847    return limits
848
849def run_models(opts, verbose=False):
850    # type: (Dict[str, Any]) -> Dict[str, Any]
851
852    base, comp = opts['engines']
853    base_n, comp_n = opts['count']
854    base_pars, comp_pars = opts['pars']
855    data = opts['data']
856
857    comparison = comp is not None
858
859    base_time = comp_time = None
860    base_value = comp_value = resid = relerr = None
861
862    # Base calculation
863    try:
864        base_raw, base_time = time_calculation(base, base_pars, base_n)
865        base_value = np.ma.masked_invalid(base_raw)
866        if verbose:
867            print("%s t=%.2f ms, intensity=%.0f"
868                  % (base.engine, base_time, base_value.sum()))
869        _show_invalid(data, base_value)
870    except ImportError:
871        traceback.print_exc()
872
873    # Comparison calculation
874    if comparison:
875        try:
876            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
877            comp_value = np.ma.masked_invalid(comp_raw)
878            if verbose:
879                print("%s t=%.2f ms, intensity=%.0f"
880                      % (comp.engine, comp_time, comp_value.sum()))
881            _show_invalid(data, comp_value)
882        except ImportError:
883            traceback.print_exc()
884
885    # Compare, but only if computing both forms
886    if comparison:
887        resid = (base_value - comp_value)
888        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
889        if verbose:
890            _print_stats("|%s-%s|"
891                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
892                         resid)
893            _print_stats("|(%s-%s)/%s|"
894                         % (base.engine, comp.engine, comp.engine),
895                         relerr)
896
897    return dict(base_value=base_value, comp_value=comp_value,
898                base_time=base_time, comp_time=comp_time,
899                resid=resid, relerr=relerr)
900
901
902def _print_stats(label, err):
903    # type: (str, np.ma.ndarray) -> None
904    # work with trimmed data, not the full set
905    sorted_err = np.sort(abs(err.compressed()))
906    if len(sorted_err) == 0.:
907        print(label + "  no valid values")
908        return
909
910    p50 = int((len(sorted_err)-1)*0.50)
911    p98 = int((len(sorted_err)-1)*0.98)
912    data = [
913        "max:%.3e"%sorted_err[-1],
914        "median:%.3e"%sorted_err[p50],
915        "98%%:%.3e"%sorted_err[p98],
916        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
917        "zero-offset:%+.3e"%np.mean(sorted_err),
918        ]
919    print(label+"  "+"  ".join(data))
920
921
922def plot_models(opts, result, limits=None, setnum=0):
923    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
924    import matplotlib.pyplot as plt
925
926    base_value, comp_value = result['base_value'], result['comp_value']
927    base_time, comp_time = result['base_time'], result['comp_time']
928    resid, relerr = result['resid'], result['relerr']
929
930    have_base, have_comp = (base_value is not None), (comp_value is not None)
931    base, comp = opts['engines']
932    data = opts['data']
933    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
934
935    # Plot if requested
936    view = opts['view']
937    if limits is None:
938        vmin, vmax = np.inf, -np.inf
939        if have_base:
940            vmin = min(vmin, base_value.min())
941            vmax = max(vmax, base_value.max())
942        if have_comp:
943            vmin = min(vmin, comp_value.min())
944            vmax = max(vmax, comp_value.max())
945        limits = vmin, vmax
946
947    if have_base:
948        if have_comp:
949            plt.subplot(131)
950        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
951        plt.title("%s t=%.2f ms"%(base.engine, base_time))
952        #cbar_title = "log I"
953    if have_comp:
954        if have_base:
955            plt.subplot(132)
956        if not opts['is2d'] and have_base:
957            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
958        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
959        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
960        #cbar_title = "log I"
961    if have_base and have_comp:
962        plt.subplot(133)
963        if not opts['rel_err']:
964            err, errstr, errview = resid, "abs err", "linear"
965        else:
966            err, errstr, errview = abs(relerr), "rel err", "log"
967        if 0:  # 95% cutoff
968            sorted = np.sort(err.flatten())
969            cutoff = sorted[int(sorted.size*0.95)]
970            err[err > cutoff] = cutoff
971        #err,errstr = base/comp,"ratio"
972        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
973        if view == 'linear':
974            plt.xscale('linear')
975        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
976        #cbar_title = errstr if errview=="linear" else "log "+errstr
977    #if is2D:
978    #    h = plt.colorbar()
979    #    h.ax.set_title(cbar_title)
980    fig = plt.gcf()
981    extra_title = ' '+opts['title'] if opts['title'] else ''
982    fig.suptitle(":".join(opts['name']) + extra_title)
983
984    if have_base and have_comp and opts['show_hist']:
985        plt.figure()
986        v = relerr
987        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
988        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
989        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
990                   % (base.engine, comp.engine, comp.engine))
991        plt.ylabel('P(err)')
992        plt.title('Distribution of relative error between calculation engines')
993
994    return limits
995
996
997# ===========================================================================
998#
999
1000# Set of command line options.
1001# Normal options such as -plot/-noplot are specified as 'name'.
1002# For options such as -nq=500 which require a value use 'name='.
1003#
1004OPTIONS = [
1005    # Plotting
1006    'plot', 'noplot', 'weights',
1007    'linear', 'log', 'q4',
1008    'rel', 'abs',
1009    'hist', 'nohist',
1010    'title=',
1011
1012    # Data generation
1013    'data=', 'noise=', 'res=',
1014    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
1015    '2d', '1d',
1016
1017    # Parameter set
1018    'preset', 'random', 'random=', 'sets=',
1019    'demo', 'default',  # TODO: remove demo/default
1020    'nopars', 'pars',
1021
1022    # Calculation options
1023    'poly', 'mono', 'cutoff=',
1024    'magnetic', 'nonmagnetic',
1025    'accuracy=',
1026    'neval=',  # for timing...
1027
1028    # Precision options
1029    'engine=',
1030    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1031    'sasview',  # TODO: remove sasview 3.x support
1032
1033    # Output options
1034    'help', 'html', 'edit',
1035    ]
1036
1037NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1038VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1039
1040
1041def columnize(items, indent="", width=79):
1042    # type: (List[str], str, int) -> str
1043    """
1044    Format a list of strings into columns.
1045
1046    Returns a string with carriage returns ready for printing.
1047    """
1048    column_width = max(len(w) for w in items) + 1
1049    num_columns = (width - len(indent)) // column_width
1050    num_rows = len(items) // num_columns
1051    items = items + [""] * (num_rows * num_columns - len(items))
1052    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1053    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1054             for row in zip(*columns)]
1055    output = indent + ("\n"+indent).join(lines)
1056    return output
1057
1058
1059def get_pars(model_info, use_demo=False):
1060    # type: (ModelInfo, bool) -> ParameterSet
1061    """
1062    Extract demo parameters from the model definition.
1063    """
1064    # Get the default values for the parameters
1065    pars = {}
1066    for p in model_info.parameters.call_parameters:
1067        parts = [('', p.default)]
1068        if p.polydisperse:
1069            parts.append(('_pd', 0.0))
1070            parts.append(('_pd_n', 0))
1071            parts.append(('_pd_nsigma', 3.0))
1072            parts.append(('_pd_type', "gaussian"))
1073        for ext, val in parts:
1074            if p.length > 1:
1075                dict(("%s%d%s" % (p.id, k, ext), val)
1076                     for k in range(1, p.length+1))
1077            else:
1078                pars[p.id + ext] = val
1079
1080    # Plug in values given in demo
1081    if use_demo and model_info.demo:
1082        pars.update(model_info.demo)
1083    return pars
1084
1085INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1086def isnumber(str):
1087    match = FLOAT_RE.match(str)
1088    isfloat = (match and not str[match.end():])
1089    return isfloat or INTEGER_RE.match(str)
1090
1091# For distinguishing pairs of models for comparison
1092# key-value pair separator =
1093# shell characters  | & ; <> $ % ' " \ # `
1094# model and parameter names _
1095# parameter expressions - + * / . ( )
1096# path characters including tilde expansion and windows drive ~ / :
1097# not sure about brackets [] {}
1098# maybe one of the following @ ? ^ ! ,
1099PAR_SPLIT = ','
1100def parse_opts(argv):
1101    # type: (List[str]) -> Dict[str, Any]
1102    """
1103    Parse command line options.
1104    """
1105    MODELS = core.list_models()
1106    flags = [arg for arg in argv
1107             if arg.startswith('-')]
1108    values = [arg for arg in argv
1109              if not arg.startswith('-') and '=' in arg]
1110    positional_args = [arg for arg in argv
1111                       if not arg.startswith('-') and '=' not in arg]
1112    models = "\n    ".join("%-15s"%v for v in MODELS)
1113    if len(positional_args) == 0:
1114        print(USAGE)
1115        print("\nAvailable models:")
1116        print(columnize(MODELS, indent="  "))
1117        return None
1118
1119    invalid = [o[1:] for o in flags
1120               if o[1:] not in NAME_OPTIONS
1121               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1122    if invalid:
1123        print("Invalid options: %s"%(", ".join(invalid)))
1124        return None
1125
1126    name = positional_args[-1]
1127
1128    # pylint: disable=bad-whitespace
1129    # Interpret the flags
1130    opts = {
1131        'plot'      : True,
1132        'view'      : 'log',
1133        'is2d'      : False,
1134        'qmax'      : 0.05,
1135        'nq'        : 128,
1136        'res'       : 0.0,
1137        'noise'     : 0.0,
1138        'accuracy'  : 'Low',
1139        'cutoff'    : '0.0',
1140        'seed'      : -1,  # default to preset
1141        'mono'      : True,
1142        # Default to magnetic a magnetic moment is set on the command line
1143        'magnetic'  : False,
1144        'show_pars' : False,
1145        'show_hist' : False,
1146        'rel_err'   : True,
1147        'explore'   : False,
1148        'use_demo'  : True,
1149        'zero'      : False,
1150        'html'      : False,
1151        'title'     : None,
1152        'datafile'  : None,
1153        'sets'      : 0,
1154        'engine'    : 'default',
1155        'evals'     : '1',
1156        'show_weights' : False,
1157    }
1158    for arg in flags:
1159        if arg == '-noplot':    opts['plot'] = False
1160        elif arg == '-plot':    opts['plot'] = True
1161        elif arg == '-linear':  opts['view'] = 'linear'
1162        elif arg == '-log':     opts['view'] = 'log'
1163        elif arg == '-q4':      opts['view'] = 'q4'
1164        elif arg == '-1d':      opts['is2d'] = False
1165        elif arg == '-2d':      opts['is2d'] = True
1166        elif arg == '-exq':     opts['qmax'] = 10.0
1167        elif arg == '-highq':   opts['qmax'] = 1.0
1168        elif arg == '-midq':    opts['qmax'] = 0.2
1169        elif arg == '-lowq':    opts['qmax'] = 0.05
1170        elif arg == '-zero':    opts['zero'] = True
1171        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1172        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1173        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1174        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1175        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1176        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1177        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1178        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1179        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1180        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1181        elif arg.startswith('-neval='):    opts['evals'] = arg[7:]
1182        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1183        elif arg == '-preset':  opts['seed'] = -1
1184        elif arg == '-mono':    opts['mono'] = True
1185        elif arg == '-poly':    opts['mono'] = False
1186        elif arg == '-magnetic':       opts['magnetic'] = True
1187        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1188        elif arg == '-pars':    opts['show_pars'] = True
1189        elif arg == '-nopars':  opts['show_pars'] = False
1190        elif arg == '-hist':    opts['show_hist'] = True
1191        elif arg == '-nohist':  opts['show_hist'] = False
1192        elif arg == '-rel':     opts['rel_err'] = True
1193        elif arg == '-abs':     opts['rel_err'] = False
1194        elif arg == '-half':    opts['engine'] = 'half'
1195        elif arg == '-fast':    opts['engine'] = 'fast'
1196        elif arg == '-single':  opts['engine'] = 'single'
1197        elif arg == '-double':  opts['engine'] = 'double'
1198        elif arg == '-single!': opts['engine'] = 'single!'
1199        elif arg == '-double!': opts['engine'] = 'double!'
1200        elif arg == '-quad!':   opts['engine'] = 'quad!'
1201        elif arg == '-sasview': opts['engine'] = 'sasview'
1202        elif arg == '-edit':    opts['explore'] = True
1203        elif arg == '-demo':    opts['use_demo'] = True
1204        elif arg == '-default': opts['use_demo'] = False
1205        elif arg == '-weights': opts['show_weights'] = True
1206        elif arg == '-html':    opts['html'] = True
1207        elif arg == '-help':    opts['html'] = True
1208    # pylint: enable=bad-whitespace
1209
1210    # Magnetism forces 2D for now
1211    if opts['magnetic']:
1212        opts['is2d'] = True
1213
1214    # Force random if sets is used
1215    if opts['sets'] >= 1 and opts['seed'] < 0:
1216        opts['seed'] = np.random.randint(1000000)
1217    if opts['sets'] == 0:
1218        opts['sets'] = 1
1219
1220    # Create the computational engines
1221    if opts['datafile'] is not None:
1222        data = load_data(os.path.expanduser(opts['datafile']))
1223    else:
1224        data, _ = make_data(opts)
1225
1226    comparison = any(PAR_SPLIT in v for v in values)
1227    if PAR_SPLIT in name:
1228        names = name.split(PAR_SPLIT, 2)
1229        comparison = True
1230    else:
1231        names = [name]*2
1232    try:
1233        model_info = [core.load_model_info(k) for k in names]
1234    except ImportError as exc:
1235        print(str(exc))
1236        print("Could not find model; use one of:\n    " + models)
1237        return None
1238
1239    if PAR_SPLIT in opts['engine']:
1240        engine_types = opts['engine'].split(PAR_SPLIT, 2)
1241        comparison = True
1242    else:
1243        engine_types = [opts['engine']]*2
1244
1245    if PAR_SPLIT in opts['evals']:
1246        evals = [int(k) for k in opts['evals'].split(PAR_SPLIT, 2)]
1247        comparison = True
1248    else:
1249        evals = [int(opts['evals'])]*2
1250
1251    if PAR_SPLIT in opts['cutoff']:
1252        cutoff = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1253        comparison = True
1254    else:
1255        cutoff = [float(opts['cutoff'])]*2
1256
1257    base = make_engine(model_info[0], data, engine_types[0], cutoff[0])
1258    if comparison:
1259        comp = make_engine(model_info[1], data, engine_types[1], cutoff[1])
1260    else:
1261        comp = None
1262
1263    # pylint: disable=bad-whitespace
1264    # Remember it all
1265    opts.update({
1266        'data'      : data,
1267        'name'      : names,
1268        'def'       : model_info,
1269        'count'     : evals,
1270        'engines'   : [base, comp],
1271        'values'    : values,
1272    })
1273    # pylint: enable=bad-whitespace
1274
1275    return opts
1276
1277def parse_pars(opts):
1278    model_info, model_info2 = opts['def']
1279
1280    # Get demo parameters from model definition, or use default parameters
1281    # if model does not define demo parameters
1282    pars = get_pars(model_info, opts['use_demo'])
1283    pars2 = get_pars(model_info2, opts['use_demo'])
1284    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1285    # randomize parameters
1286    #pars.update(set_pars)  # set value before random to control range
1287    if opts['seed'] > -1:
1288        pars = randomize_pars(model_info, pars)
1289        if model_info != model_info2:
1290            pars2 = randomize_pars(model_info2, pars2)
1291            # Share values for parameters with the same name
1292            for k, v in pars.items():
1293                if k in pars2:
1294                    pars2[k] = v
1295        else:
1296            pars2 = pars.copy()
1297        constrain_pars(model_info, pars)
1298        constrain_pars(model_info2, pars2)
1299    pars = suppress_pd(pars, opts['mono'])
1300    pars2 = suppress_pd(pars2, opts['mono'])
1301    pars = suppress_magnetism(pars, not opts['magnetic'])
1302    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1303
1304    # Fill in parameters given on the command line
1305    presets = {}
1306    presets2 = {}
1307    for arg in opts['values']:
1308        k, v = arg.split('=', 1)
1309        if k not in pars and k not in pars2:
1310            # extract base name without polydispersity info
1311            s = set(p.split('_pd')[0] for p in pars)
1312            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1313            return None
1314        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1315        if v1 and k in pars:
1316            presets[k] = float(v1) if isnumber(v1) else v1
1317        if v2 and k in pars2:
1318            presets2[k] = float(v2) if isnumber(v2) else v2
1319
1320    # If pd given on the command line, default pd_n to 35
1321    for k, v in list(presets.items()):
1322        if k.endswith('_pd'):
1323            presets.setdefault(k+'_n', 35.)
1324    for k, v in list(presets2.items()):
1325        if k.endswith('_pd'):
1326            presets2.setdefault(k+'_n', 35.)
1327
1328    # Evaluate preset parameter expressions
1329    context = MATH.copy()
1330    context['np'] = np
1331    context.update(pars)
1332    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1333    for k, v in presets.items():
1334        if not isinstance(v, float) and not k.endswith('_type'):
1335            presets[k] = eval(v, context)
1336    context.update(presets)
1337    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1338    for k, v in presets2.items():
1339        if not isinstance(v, float) and not k.endswith('_type'):
1340            presets2[k] = eval(v, context)
1341
1342    # update parameters with presets
1343    pars.update(presets)  # set value after random to control value
1344    pars2.update(presets2)  # set value after random to control value
1345    #import pprint; pprint.pprint(model_info)
1346
1347    if opts['show_pars']:
1348        if model_info.name != model_info2.name or pars != pars2:
1349            print("==== %s ====="%model_info.name)
1350            print(str(parlist(model_info, pars, opts['is2d'])))
1351            print("==== %s ====="%model_info2.name)
1352            print(str(parlist(model_info2, pars2, opts['is2d'])))
1353        else:
1354            print(str(parlist(model_info, pars, opts['is2d'])))
1355
1356    return pars, pars2
1357
1358def show_docs(opts):
1359    # type: (Dict[str, Any]) -> None
1360    """
1361    show html docs for the model
1362    """
1363    import os
1364    from .generate import make_html
1365    from . import rst2html
1366
1367    info = opts['def'][0]
1368    html = make_html(info)
1369    path = os.path.dirname(info.filename)
1370    url = "file://"+path.replace("\\","/")[2:]+"/"
1371    rst2html.view_html_qtapp(html, url)
1372
1373def explore(opts):
1374    # type: (Dict[str, Any]) -> None
1375    """
1376    explore the model using the bumps gui.
1377    """
1378    import wx  # type: ignore
1379    from bumps.names import FitProblem  # type: ignore
1380    from bumps.gui.app_frame import AppFrame  # type: ignore
1381    from bumps.gui import signal
1382
1383    is_mac = "cocoa" in wx.version()
1384    # Create an app if not running embedded
1385    app = wx.App() if wx.GetApp() is None else None
1386    model = Explore(opts)
1387    problem = FitProblem(model)
1388    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1389    if not is_mac:
1390        frame.Show()
1391    frame.panel.set_model(model=problem)
1392    frame.panel.Layout()
1393    frame.panel.aui.Split(0, wx.TOP)
1394    def reset_parameters(event):
1395        model.revert_values()
1396        signal.update_parameters(problem)
1397    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1398    if is_mac: frame.Show()
1399    # If running withing an app, start the main loop
1400    if app:
1401        app.MainLoop()
1402
1403class Explore(object):
1404    """
1405    Bumps wrapper for a SAS model comparison.
1406
1407    The resulting object can be used as a Bumps fit problem so that
1408    parameters can be adjusted in the GUI, with plots updated on the fly.
1409    """
1410    def __init__(self, opts):
1411        # type: (Dict[str, Any]) -> None
1412        from bumps.cli import config_matplotlib  # type: ignore
1413        from . import bumps_model
1414        config_matplotlib()
1415        self.opts = opts
1416        opts['pars'] = list(opts['pars'])
1417        p1, p2 = opts['pars']
1418        m1, m2 = opts['def']
1419        self.fix_p2 = m1 != m2 or p1 != p2
1420        model_info = m1
1421        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1422        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1423        if not opts['is2d']:
1424            for p in model_info.parameters.user_parameters({}, is2d=False):
1425                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1426                    k = p.name+ext
1427                    v = pars.get(k, None)
1428                    if v is not None:
1429                        v.range(*parameter_range(k, v.value))
1430        else:
1431            for k, v in pars.items():
1432                v.range(*parameter_range(k, v.value))
1433
1434        self.pars = pars
1435        self.starting_values = dict((k, v.value) for k, v in pars.items())
1436        self.pd_types = pd_types
1437        self.limits = None
1438
1439    def revert_values(self):
1440        for k, v in self.starting_values.items():
1441            self.pars[k].value = v
1442
1443    def model_update(self):
1444        pass
1445
1446    def numpoints(self):
1447        # type: () -> int
1448        """
1449        Return the number of points.
1450        """
1451        return len(self.pars) + 1  # so dof is 1
1452
1453    def parameters(self):
1454        # type: () -> Any   # Dict/List hierarchy of parameters
1455        """
1456        Return a dictionary of parameters.
1457        """
1458        return self.pars
1459
1460    def nllf(self):
1461        # type: () -> float
1462        """
1463        Return cost.
1464        """
1465        # pylint: disable=no-self-use
1466        return 0.  # No nllf
1467
1468    def plot(self, view='log'):
1469        # type: (str) -> None
1470        """
1471        Plot the data and residuals.
1472        """
1473        pars = dict((k, v.value) for k, v in self.pars.items())
1474        pars.update(self.pd_types)
1475        self.opts['pars'][0] = pars
1476        if not self.fix_p2:
1477            self.opts['pars'][1] = pars
1478        result = run_models(self.opts)
1479        limits = plot_models(self.opts, result, limits=self.limits)
1480        if self.limits is None:
1481            vmin, vmax = limits
1482            self.limits = vmax*1e-7, 1.3*vmax
1483            import pylab; pylab.clf()
1484            plot_models(self.opts, result, limits=self.limits)
1485
1486
1487def main(*argv):
1488    # type: (*str) -> None
1489    """
1490    Main program.
1491    """
1492    opts = parse_opts(argv)
1493    if opts is not None:
1494        if opts['seed'] > -1:
1495            print("Randomize using -random=%i"%opts['seed'])
1496            np.random.seed(opts['seed'])
1497        if opts['html']:
1498            show_docs(opts)
1499        elif opts['explore']:
1500            opts['pars'] = parse_pars(opts)
1501            if opts['pars'] is None:
1502                return
1503            explore(opts)
1504        else:
1505            compare(opts)
1506
1507if __name__ == "__main__":
1508    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.