source: sasmodels/sasmodels/compare.py @ ff31782

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ff31782 was ff31782, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

add -ngauss option to compare in order to set the number of integration points

  • Property mode set to 100755
File size: 55.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE, set_integration_size
47from .weights import plot_weights
48
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except Exception:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57
58USAGE = """
59usage: sascomp model [options...] [key=val]
60
61Generate and compare SAS models.  If a single model is specified it shows
62a plot of that model.  Different models can be compared, or the same model
63with different parameters.  The same model with the same parameters can
64be compared with different calculation engines to see the effects of precision
65on the resultant values.
66
67model or model1,model2 are the names of the models to compare (see below).
68
69Options (* for default):
70
71    === data generation ===
72    -data="path" uses q, dq from the data file
73    -noise=0 sets the measurement error dI/I
74    -res=0 sets the resolution width dQ/Q if calculating with resolution
75    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
76    -q=min:max alternative specification of qrange
77    -nq=128 sets the number of Q points in the data set
78    -1d*/-2d computes 1d or 2d data
79    -zero indicates that q=0 should be included
80
81    === model parameters ===
82    -preset*/-random[=seed] preset or random parameters
83    -sets=n generates n random datasets with the seed given by -random=seed
84    -pars/-nopars* prints the parameter set or not
85    -default/-demo* use demo vs default parameters
86    -sphere[=150] set up spherical integration over theta/phi using n points
87
88    === calculation options ===
89    -mono*/-poly force monodisperse or allow polydisperse random parameters
90    -cutoff=1e-5* cutoff value for including a point in polydispersity
91    -magnetic/-nonmagnetic* suppress magnetism
92    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
93    -neval=1 sets the number of evals for more accurate timing
94
95    === precision options ===
96    -engine=default uses the default calcution precision
97    -single/-double/-half/-fast sets an OpenCL calculation engine
98    -single!/-double!/-quad! sets an OpenMP calculation engine
99    -sasview sets the sasview calculation engine
100
101    === plotting ===
102    -plot*/-noplot plots or suppress the plot of the model
103    -linear/-log*/-q4 intensity scaling on plots
104    -hist/-nohist* plot histogram of relative error
105    -abs/-rel* plot relative or absolute error
106    -title="note" adds note to the plot title, after the model name
107    -weights shows weights plots for the polydisperse parameters
108
109    === output options ===
110    -edit starts the parameter explorer
111    -help/-html shows the model docs instead of running the model
112
113The interpretation of quad precision depends on architecture, and may
114vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
115On unix and mac you may need single quotes around the DLL computation
116engines, such as -engine='single!,double!' since !, is treated as a history
117expansion request in the shell.
118
119Key=value pairs allow you to set specific values for the model parameters.
120Key=value1,value2 to compare different values of the same parameter. The
121value can be an expression including other parameters.
122
123Items later on the command line override those that appear earlier.
124
125Examples:
126
127    # compare single and double precision calculation for a barbell
128    sascomp barbell -engine=single,double
129
130    # generate 10 random lorentz models, with seed=27
131    sascomp lorentz -sets=10 -seed=27
132
133    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
134    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
135
136    # model timing test requires multiple evals to perform the estimate
137    sascomp pringle -engine=single,double -timing=100,100 -noplot
138"""
139
140# Update docs with command line usage string.   This is separate from the usual
141# doc string so that we can display it at run time if there is an error.
142# lin
143__doc__ = (__doc__  # pylint: disable=redefined-builtin
144           + """
145Program description
146-------------------
147
148""" + USAGE)
149
150kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
151
152# list of math functions for use in evaluating parameters
153MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
154
155# CRUFT python 2.6
156if not hasattr(datetime.timedelta, 'total_seconds'):
157    def delay(dt):
158        """Return number date-time delta as number seconds"""
159        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
160else:
161    def delay(dt):
162        """Return number date-time delta as number seconds"""
163        return dt.total_seconds()
164
165
166class push_seed(object):
167    """
168    Set the seed value for the random number generator.
169
170    When used in a with statement, the random number generator state is
171    restored after the with statement is complete.
172
173    :Parameters:
174
175    *seed* : int or array_like, optional
176        Seed for RandomState
177
178    :Example:
179
180    Seed can be used directly to set the seed::
181
182        >>> from numpy.random import randint
183        >>> push_seed(24)
184        <...push_seed object at...>
185        >>> print(randint(0,1000000,3))
186        [242082    899 211136]
187
188    Seed can also be used in a with statement, which sets the random
189    number generator state for the enclosed computations and restores
190    it to the previous state on completion::
191
192        >>> with push_seed(24):
193        ...    print(randint(0,1000000,3))
194        [242082    899 211136]
195
196    Using nested contexts, we can demonstrate that state is indeed
197    restored after the block completes::
198
199        >>> with push_seed(24):
200        ...    print(randint(0,1000000))
201        ...    with push_seed(24):
202        ...        print(randint(0,1000000,3))
203        ...    print(randint(0,1000000))
204        242082
205        [242082    899 211136]
206        899
207
208    The restore step is protected against exceptions in the block::
209
210        >>> with push_seed(24):
211        ...    print(randint(0,1000000))
212        ...    try:
213        ...        with push_seed(24):
214        ...            print(randint(0,1000000,3))
215        ...            raise Exception()
216        ...    except Exception:
217        ...        print("Exception raised")
218        ...    print(randint(0,1000000))
219        242082
220        [242082    899 211136]
221        Exception raised
222        899
223    """
224    def __init__(self, seed=None):
225        # type: (Optional[int]) -> None
226        self._state = np.random.get_state()
227        np.random.seed(seed)
228
229    def __enter__(self):
230        # type: () -> None
231        pass
232
233    def __exit__(self, exc_type, exc_value, traceback):
234        # type: (Any, BaseException, Any) -> None
235        # TODO: better typing for __exit__ method
236        np.random.set_state(self._state)
237
238def tic():
239    # type: () -> Callable[[], float]
240    """
241    Timer function.
242
243    Use "toc=tic()" to start the clock and "toc()" to measure
244    a time interval.
245    """
246    then = datetime.datetime.now()
247    return lambda: delay(datetime.datetime.now() - then)
248
249
250def set_beam_stop(data, radius, outer=None):
251    # type: (Data, float, float) -> None
252    """
253    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
254
255    Note: this function does not require sasview
256    """
257    if hasattr(data, 'qx_data'):
258        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
259        data.mask = (q < radius)
260        if outer is not None:
261            data.mask |= (q >= outer)
262    else:
263        data.mask = (data.x < radius)
264        if outer is not None:
265            data.mask |= (data.x >= outer)
266
267
268def parameter_range(p, v):
269    # type: (str, float) -> Tuple[float, float]
270    """
271    Choose a parameter range based on parameter name and initial value.
272    """
273    # process the polydispersity options
274    if p.endswith('_pd_n'):
275        return 0., 100.
276    elif p.endswith('_pd_nsigma'):
277        return 0., 5.
278    elif p.endswith('_pd_type'):
279        raise ValueError("Cannot return a range for a string value")
280    elif any(s in p for s in ('theta', 'phi', 'psi')):
281        # orientation in [-180,180], orientation pd in [0,45]
282        if p.endswith('_pd'):
283            return 0., 180.
284        else:
285            return -180., 180.
286    elif p.endswith('_pd'):
287        return 0., 1.
288    elif 'sld' in p:
289        return -0.5, 10.
290    elif p == 'background':
291        return 0., 10.
292    elif p == 'scale':
293        return 0., 1.e3
294    elif v < 0.:
295        return 2.*v, -2.*v
296    else:
297        return 0., (2.*v if v > 0. else 1.)
298
299
300def _randomize_one(model_info, name, value):
301    # type: (ModelInfo, str, float) -> float
302    # type: (ModelInfo, str, str) -> str
303    """
304    Randomize a single parameter.
305    """
306    # Set the amount of polydispersity/angular dispersion, but by default pd_n
307    # is zero so there is no polydispersity.  This allows us to turn on/off
308    # pd by setting pd_n, and still have randomly generated values
309    if name.endswith('_pd'):
310        par = model_info.parameters[name[:-3]]
311        if par.type == 'orientation':
312            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
313            return 180*np.random.beta(2.5, 20)
314        else:
315            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
316            return np.random.beta(1.5, 7)
317
318    # pd is selected globally rather than per parameter, so set to 0 for no pd
319    # In particular, when multiple pd dimensions, want to decrease the number
320    # of points per dimension for faster computation
321    if name.endswith('_pd_n'):
322        return 0
323
324    # Don't mess with distribution type for now
325    if name.endswith('_pd_type'):
326        return 'gaussian'
327
328    # type-dependent value of number of sigmas; for gaussian use 3.
329    if name.endswith('_pd_nsigma'):
330        return 3.
331
332    # background in the range [0.01, 1]
333    if name == 'background':
334        return 10**np.random.uniform(-2, 0)
335
336    # scale defaults to 0.1% to 30% volume fraction
337    if name == 'scale':
338        return 10**np.random.uniform(-3, -0.5)
339
340    # If it is a list of choices, pick one at random with equal probability
341    # In practice, the model specific random generator will override.
342    par = model_info.parameters[name]
343    if len(par.limits) > 2:  # choice list
344        return np.random.randint(len(par.limits))
345
346    # If it is a fixed range, pick from it with equal probability.
347    # For logarithmic ranges, the model will have to override.
348    if np.isfinite(par.limits).all():
349        return np.random.uniform(*par.limits)
350
351    # If the paramter is marked as an sld use the range of neutron slds
352    # TODO: ought to randomly contrast match a pair of SLDs
353    if par.type == 'sld':
354        return np.random.uniform(-0.5, 12)
355
356    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
357    if par.name.startswith('M0:'):
358        return np.random.uniform(0, 5)
359
360    # Guess at the random length/radius/thickness.  In practice, all models
361    # are going to set their own reasonable ranges.
362    if par.type == 'volume':
363        if ('length' in par.name or
364                'radius' in par.name or
365                'thick' in par.name):
366            return 10**np.random.uniform(2, 4)
367
368    # In the absence of any other info, select a value in [0, 2v], or
369    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
370    # model random parameter generators will override this default.
371    low, high = parameter_range(par.name, value)
372    limits = (max(par.limits[0], low), min(par.limits[1], high))
373    return np.random.uniform(*limits)
374
375def _random_pd(model_info, pars):
376    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
377    pd_volume = []
378    pd_oriented = []
379    for p in pd:
380        if p.type == 'orientation':
381            pd_oriented.append(p.name)
382        elif p.length_control is not None:
383            n = int(pars.get(p.length_control, 1) + 0.5)
384            pd_volume.extend(p.name+str(k+1) for k in range(n))
385        elif p.length > 1:
386            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
387        else:
388            pd_volume.append(p.name)
389    u = np.random.rand()
390    n = len(pd_volume)
391    if u < 0.01 or n < 1:
392        pass  # 1% chance of no polydispersity
393    elif u < 0.86 or n < 2:
394        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
395    elif u < 0.99 or n < 3:
396        choices = np.random.choice(len(pd_volume), size=2)
397        pars[pd_volume[choices[0]]+"_pd_n"] = 25
398        pars[pd_volume[choices[1]]+"_pd_n"] = 10
399    else:
400        choices = np.random.choice(len(pd_volume), size=3)
401        pars[pd_volume[choices[0]]+"_pd_n"] = 25
402        pars[pd_volume[choices[1]]+"_pd_n"] = 10
403        pars[pd_volume[choices[2]]+"_pd_n"] = 5
404    if pd_oriented:
405        pars['theta_pd_n'] = 20
406        if np.random.rand() < 0.1:
407            pars['phi_pd_n'] = 5
408        if np.random.rand() < 0.1:
409            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
410                #print("generating psi_pd_n")
411                pars['psi_pd_n'] = 5
412
413    ## Show selected polydispersity
414    #for name, value in pars.items():
415    #    if name.endswith('_pd_n') and value > 0:
416    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
417
418
419def randomize_pars(model_info, pars):
420    # type: (ModelInfo, ParameterSet) -> ParameterSet
421    """
422    Generate random values for all of the parameters.
423
424    Valid ranges for the random number generator are guessed from the name of
425    the parameter; this will not account for constraints such as cap radius
426    greater than cylinder radius in the capped_cylinder model, so
427    :func:`constrain_pars` needs to be called afterward..
428    """
429    # Note: the sort guarantees order of calls to random number generator
430    random_pars = dict((p, _randomize_one(model_info, p, v))
431                       for p, v in sorted(pars.items()))
432    if model_info.random is not None:
433        random_pars.update(model_info.random())
434    _random_pd(model_info, random_pars)
435    return random_pars
436
437
438def limit_dimensions(model_info, pars, maxdim):
439    # type: (ModelInfo, ParameterSet, float) -> None
440    """
441    Limit parameters of units of Ang to maxdim.
442    """
443    for p in model_info.parameters.call_parameters:
444        value = pars[p.name]
445        if p.units == 'Ang' and value > maxdim:
446            pars[p.name] = maxdim*10**np.random.uniform(-3,0)
447
448def constrain_pars(model_info, pars):
449    # type: (ModelInfo, ParameterSet) -> None
450    """
451    Restrict parameters to valid values.
452
453    This includes model specific code for models such as capped_cylinder
454    which need to support within model constraints (cap radius more than
455    cylinder radius in this case).
456
457    Warning: this updates the *pars* dictionary in place.
458    """
459    # TODO: move the model specific code to the individual models
460    name = model_info.id
461    # if it is a product model, then just look at the form factor since
462    # none of the structure factors need any constraints.
463    if '*' in name:
464        name = name.split('*')[0]
465
466    # Suppress magnetism for python models (not yet implemented)
467    if callable(model_info.Iq):
468        pars.update(suppress_magnetism(pars))
469
470    if name == 'barbell':
471        if pars['radius_bell'] < pars['radius']:
472            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
473
474    elif name == 'capped_cylinder':
475        if pars['radius_cap'] < pars['radius']:
476            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
477
478    elif name == 'guinier':
479        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
480        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
481        # => ln A - (Rg^2 q^2/3) > -30 ln 10
482        # => Rg^2 q^2/3 < 30 ln 10 + ln A
483        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
484        #q_max = 0.2  # mid q maximum
485        q_max = 1.0  # high q maximum
486        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
487        pars['rg'] = min(pars['rg'], rg_max)
488
489    elif name == 'pearl_necklace':
490        if pars['radius'] < pars['thick_string']:
491            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
492        pass
493
494    elif name == 'rpa':
495        # Make sure phi sums to 1.0
496        if pars['case_num'] < 2:
497            pars['Phi1'] = 0.
498            pars['Phi2'] = 0.
499        elif pars['case_num'] < 5:
500            pars['Phi1'] = 0.
501        total = sum(pars['Phi'+c] for c in '1234')
502        for c in '1234':
503            pars['Phi'+c] /= total
504
505def parlist(model_info, pars, is2d):
506    # type: (ModelInfo, ParameterSet, bool) -> str
507    """
508    Format the parameter list for printing.
509    """
510    is2d = True
511    lines = []
512    parameters = model_info.parameters
513    magnetic = False
514    magnetic_pars = []
515    for p in parameters.user_parameters(pars, is2d):
516        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
517            continue
518        if p.id.startswith('up:'):
519            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
520            continue
521        fields = dict(
522            value=pars.get(p.id, p.default),
523            pd=pars.get(p.id+"_pd", 0.),
524            n=int(pars.get(p.id+"_pd_n", 0)),
525            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
526            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
527            relative_pd=p.relative_pd,
528            M0=pars.get('M0:'+p.id, 0.),
529            mphi=pars.get('mphi:'+p.id, 0.),
530            mtheta=pars.get('mtheta:'+p.id, 0.),
531        )
532        lines.append(_format_par(p.name, **fields))
533        magnetic = magnetic or fields['M0'] != 0.
534    if magnetic and magnetic_pars:
535        lines.append(" ".join(magnetic_pars))
536    return "\n".join(lines)
537
538    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
539
540def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
541                relative_pd=False, M0=0., mphi=0., mtheta=0.):
542    # type: (str, float, float, int, float, str) -> str
543    line = "%s: %g"%(name, value)
544    if pd != 0.  and n != 0:
545        if relative_pd:
546            pd *= value
547        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
548                % (pd, n, nsigma, nsigma, pdtype)
549    if M0 != 0.:
550        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
551    return line
552
553def suppress_pd(pars, suppress=True):
554    # type: (ParameterSet) -> ParameterSet
555    """
556    If suppress is True complete eliminate polydispersity of the model to test
557    models more quickly.  If suppress is False, make sure at least one
558    parameter is polydisperse, setting the first polydispersity parameter to
559    15% if no polydispersity is given (with no explicit demo parameters given
560    in the model, there will be no default polydispersity).
561    """
562    pars = pars.copy()
563    #print("pars=", pars)
564    if suppress:
565        for p in pars:
566            if p.endswith("_pd_n"):
567                pars[p] = 0
568    else:
569        any_pd = False
570        first_pd = None
571        for p in pars:
572            if p.endswith("_pd_n"):
573                pd = pars.get(p[:-2], 0.)
574                any_pd |= (pars[p] != 0 and pd != 0.)
575                if first_pd is None:
576                    first_pd = p
577        if not any_pd and first_pd is not None:
578            if pars[first_pd] == 0:
579                pars[first_pd] = 35
580            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
581                pars[first_pd[:-2]] = 0.15
582    return pars
583
584def suppress_magnetism(pars, suppress=True):
585    # type: (ParameterSet) -> ParameterSet
586    """
587    If suppress is True complete eliminate magnetism of the model to test
588    models more quickly.  If suppress is False, make sure at least one sld
589    parameter is magnetic, setting the first parameter to have a strong
590    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
591    in the model, there will be no default magnetism).
592    """
593    pars = pars.copy()
594    if suppress:
595        for p in pars:
596            if p.startswith("M0:"):
597                pars[p] = 0
598    else:
599        any_mag = False
600        first_mag = None
601        for p in pars:
602            if p.startswith("M0:"):
603                any_mag |= (pars[p] != 0)
604                if first_mag is None:
605                    first_mag = p
606        if not any_mag and first_mag is not None:
607            pars[first_mag] = 8.
608    return pars
609
610def eval_sasview(model_info, data):
611    # type: (Modelinfo, Data) -> Calculator
612    """
613    Return a model calculator using the pre-4.0 SasView models.
614    """
615    # importing sas here so that the error message will be that sas failed to
616    # import rather than the more obscure smear_selection not imported error
617    import sas
618    import sas.models
619    from sas.models.qsmearing import smear_selection
620    from sas.models.MultiplicationModel import MultiplicationModel
621    from sas.models.dispersion_models import models as dispersers
622
623    def get_model_class(name):
624        # type: (str) -> "sas.models.BaseComponent"
625        #print("new",sorted(_pars.items()))
626        __import__('sas.models.' + name)
627        ModelClass = getattr(getattr(sas.models, name, None), name, None)
628        if ModelClass is None:
629            raise ValueError("could not find model %r in sas.models"%name)
630        return ModelClass
631
632    # WARNING: ugly hack when handling model!
633    # Sasview models with multiplicity need to be created with the target
634    # multiplicity, so we cannot create the target model ahead of time for
635    # for multiplicity models.  Instead we store the model in a list and
636    # update the first element of that list with the new multiplicity model
637    # every time we evaluate.
638
639    # grab the sasview model, or create it if it is a product model
640    if model_info.composition:
641        composition_type, parts = model_info.composition
642        if composition_type == 'product':
643            P, S = [get_model_class(revert_name(p))() for p in parts]
644            model = [MultiplicationModel(P, S)]
645        else:
646            raise ValueError("sasview mixture models not supported by compare")
647    else:
648        old_name = revert_name(model_info)
649        if old_name is None:
650            raise ValueError("model %r does not exist in old sasview"
651                            % model_info.id)
652        ModelClass = get_model_class(old_name)
653        model = [ModelClass()]
654    model[0].disperser_handles = {}
655
656    # build a smearer with which to call the model, if necessary
657    smearer = smear_selection(data, model=model)
658    if hasattr(data, 'qx_data'):
659        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
660        index = ((~data.mask) & (~np.isnan(data.data))
661                 & (q >= data.qmin) & (q <= data.qmax))
662        if smearer is not None:
663            smearer.model = model  # because smear_selection has a bug
664            smearer.accuracy = data.accuracy
665            smearer.set_index(index)
666            def _call_smearer():
667                smearer.model = model[0]
668                return smearer.get_value()
669            theory = _call_smearer
670        else:
671            theory = lambda: model[0].evalDistribution([data.qx_data[index],
672                                                        data.qy_data[index]])
673    elif smearer is not None:
674        theory = lambda: smearer(model[0].evalDistribution(data.x))
675    else:
676        theory = lambda: model[0].evalDistribution(data.x)
677
678    def calculator(**pars):
679        # type: (float, ...) -> np.ndarray
680        """
681        Sasview calculator for model.
682        """
683        oldpars = revert_pars(model_info, pars)
684        # For multiplicity models, create a model with the correct multiplicity
685        control = oldpars.pop("CONTROL", None)
686        if control is not None:
687            # sphericalSLD has one fewer multiplicity.  This update should
688            # happen in revert_pars, but it hasn't been called yet.
689            model[0] = ModelClass(control)
690        # paying for parameter conversion each time to keep life simple, if not fast
691        for k, v in oldpars.items():
692            if k.endswith('.type'):
693                par = k[:-5]
694                if v == 'gaussian': continue
695                cls = dispersers[v if v != 'rectangle' else 'rectangula']
696                handle = cls()
697                model[0].disperser_handles[par] = handle
698                try:
699                    model[0].set_dispersion(par, handle)
700                except Exception:
701                    exception.annotate_exception("while setting %s to %r"
702                                                 %(par, v))
703                    raise
704
705
706        #print("sasview pars",oldpars)
707        for k, v in oldpars.items():
708            name_attr = k.split('.')  # polydispersity components
709            if len(name_attr) == 2:
710                par, disp_par = name_attr
711                model[0].dispersion[par][disp_par] = v
712            else:
713                model[0].setParam(k, v)
714        return theory()
715
716    calculator.engine = "sasview"
717    return calculator
718
719DTYPE_MAP = {
720    'half': '16',
721    'fast': 'fast',
722    'single': '32',
723    'double': '64',
724    'quad': '128',
725    'f16': '16',
726    'f32': '32',
727    'f64': '64',
728    'float16': '16',
729    'float32': '32',
730    'float64': '64',
731    'float128': '128',
732    'longdouble': '128',
733}
734def eval_opencl(model_info, data, dtype='single', cutoff=0.):
735    # type: (ModelInfo, Data, str, float) -> Calculator
736    """
737    Return a model calculator using the OpenCL calculation engine.
738    """
739    if not core.HAVE_OPENCL:
740        raise RuntimeError("OpenCL not available")
741    model = core.build_model(model_info, dtype=dtype, platform="ocl")
742    calculator = DirectModel(data, model, cutoff=cutoff)
743    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
744    return calculator
745
746def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
747    # type: (ModelInfo, Data, str, float) -> Calculator
748    """
749    Return a model calculator using the DLL calculation engine.
750    """
751    model = core.build_model(model_info, dtype=dtype, platform="dll")
752    calculator = DirectModel(data, model, cutoff=cutoff)
753    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
754    return calculator
755
756def time_calculation(calculator, pars, evals=1):
757    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
758    """
759    Compute the average calculation time over N evaluations.
760
761    An additional call is generated without polydispersity in order to
762    initialize the calculation engine, and make the average more stable.
763    """
764    # initialize the code so time is more accurate
765    if evals > 1:
766        calculator(**suppress_pd(pars))
767    toc = tic()
768    # make sure there is at least one eval
769    value = calculator(**pars)
770    for _ in range(evals-1):
771        value = calculator(**pars)
772    average_time = toc()*1000. / evals
773    #print("I(q)",value)
774    return value, average_time
775
776def make_data(opts):
777    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
778    """
779    Generate an empty dataset, used with the model to set Q points
780    and resolution.
781
782    *opts* contains the options, with 'qmax', 'nq', 'res',
783    'accuracy', 'is2d' and 'view' parsed from the command line.
784    """
785    qmin, qmax, nq, res = opts['qmin'], opts['qmax'], opts['nq'], opts['res']
786    if opts['is2d']:
787        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
788        data = empty_data2D(q, resolution=res)
789        data.accuracy = opts['accuracy']
790        set_beam_stop(data, qmin)
791        index = ~data.mask
792    else:
793        if opts['view'] == 'log' and not opts['zero']:
794            q = np.logspace(math.log10(qmin), math.log10(qmax), nq)
795        else:
796            q = np.linspace(qmin, qmax, nq)
797        if opts['zero']:
798            q = np.hstack((0, q))
799        data = empty_data1D(q, resolution=res)
800        index = slice(None, None)
801    return data, index
802
803def make_engine(model_info, data, dtype, cutoff, ngauss=0):
804    # type: (ModelInfo, Data, str, float) -> Calculator
805    """
806    Generate the appropriate calculation engine for the given datatype.
807
808    Datatypes with '!' appended are evaluated using external C DLLs rather
809    than OpenCL.
810    """
811    if ngauss:
812        set_integration_size(model_info, ngauss)
813
814    if dtype == 'sasview':
815        return eval_sasview(model_info, data)
816    elif dtype is None or not dtype.endswith('!'):
817        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
818    else:
819        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
820
821def _show_invalid(data, theory):
822    # type: (Data, np.ma.ndarray) -> None
823    """
824    Display a list of the non-finite values in theory.
825    """
826    if not theory.mask.any():
827        return
828
829    if hasattr(data, 'x'):
830        bad = zip(data.x[theory.mask], theory[theory.mask])
831        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
832
833
834def compare(opts, limits=None, maxdim=np.inf):
835    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
836    """
837    Preform a comparison using options from the command line.
838
839    *limits* are the limits on the values to use, either to set the y-axis
840    for 1D or to set the colormap scale for 2D.  If None, then they are
841    inferred from the data and returned. When exploring using Bumps,
842    the limits are set when the model is initially called, and maintained
843    as the values are adjusted, making it easier to see the effects of the
844    parameters.
845
846    *maxdim* is the maximum value for any parameter with units of Angstrom.
847    """
848    for k in range(opts['sets']):
849        if k > 1:
850            # print a separate seed for each dataset for better reproducibility
851            new_seed = np.random.randint(1000000)
852            print("Set %d uses -random=%i"%(k+1,new_seed))
853            np.random.seed(new_seed)
854        opts['pars'] = parse_pars(opts, maxdim=maxdim)
855        if opts['pars'] is None:
856            return
857        result = run_models(opts, verbose=True)
858        if opts['plot']:
859            limits = plot_models(opts, result, limits=limits, setnum=k)
860        if opts['show_weights']:
861            base, _ = opts['engines']
862            base_pars, _ = opts['pars']
863            model_info = base._kernel.info
864            dim = base._kernel.dim
865            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
866    if opts['plot']:
867        import matplotlib.pyplot as plt
868        plt.show()
869    return limits
870
871def run_models(opts, verbose=False):
872    # type: (Dict[str, Any]) -> Dict[str, Any]
873
874    base, comp = opts['engines']
875    base_n, comp_n = opts['count']
876    base_pars, comp_pars = opts['pars']
877    data = opts['data']
878
879    comparison = comp is not None
880
881    base_time = comp_time = None
882    base_value = comp_value = resid = relerr = None
883
884    # Base calculation
885    try:
886        base_raw, base_time = time_calculation(base, base_pars, base_n)
887        base_value = np.ma.masked_invalid(base_raw)
888        if verbose:
889            print("%s t=%.2f ms, intensity=%.0f"
890                  % (base.engine, base_time, base_value.sum()))
891        _show_invalid(data, base_value)
892    except ImportError:
893        traceback.print_exc()
894
895    # Comparison calculation
896    if comparison:
897        try:
898            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
899            comp_value = np.ma.masked_invalid(comp_raw)
900            if verbose:
901                print("%s t=%.2f ms, intensity=%.0f"
902                      % (comp.engine, comp_time, comp_value.sum()))
903            _show_invalid(data, comp_value)
904        except ImportError:
905            traceback.print_exc()
906
907    # Compare, but only if computing both forms
908    if comparison:
909        resid = (base_value - comp_value)
910        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
911        if verbose:
912            _print_stats("|%s-%s|"
913                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
914                         resid)
915            _print_stats("|(%s-%s)/%s|"
916                         % (base.engine, comp.engine, comp.engine),
917                         relerr)
918
919    return dict(base_value=base_value, comp_value=comp_value,
920                base_time=base_time, comp_time=comp_time,
921                resid=resid, relerr=relerr)
922
923
924def _print_stats(label, err):
925    # type: (str, np.ma.ndarray) -> None
926    # work with trimmed data, not the full set
927    sorted_err = np.sort(abs(err.compressed()))
928    if len(sorted_err) == 0.:
929        print(label + "  no valid values")
930        return
931
932    p50 = int((len(sorted_err)-1)*0.50)
933    p98 = int((len(sorted_err)-1)*0.98)
934    data = [
935        "max:%.3e"%sorted_err[-1],
936        "median:%.3e"%sorted_err[p50],
937        "98%%:%.3e"%sorted_err[p98],
938        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
939        "zero-offset:%+.3e"%np.mean(sorted_err),
940        ]
941    print(label+"  "+"  ".join(data))
942
943
944def plot_models(opts, result, limits=None, setnum=0):
945    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
946    import matplotlib.pyplot as plt
947
948    base_value, comp_value = result['base_value'], result['comp_value']
949    base_time, comp_time = result['base_time'], result['comp_time']
950    resid, relerr = result['resid'], result['relerr']
951
952    have_base, have_comp = (base_value is not None), (comp_value is not None)
953    base, comp = opts['engines']
954    data = opts['data']
955    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
956
957    # Plot if requested
958    view = opts['view']
959    if limits is None:
960        vmin, vmax = np.inf, -np.inf
961        if have_base:
962            vmin = min(vmin, base_value.min())
963            vmax = max(vmax, base_value.max())
964        if have_comp:
965            vmin = min(vmin, comp_value.min())
966            vmax = max(vmax, comp_value.max())
967        limits = vmin, vmax
968
969    if have_base:
970        if have_comp:
971            plt.subplot(131)
972        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
973        plt.title("%s t=%.2f ms"%(base.engine, base_time))
974        #cbar_title = "log I"
975    if have_comp:
976        if have_base:
977            plt.subplot(132)
978        if not opts['is2d'] and have_base:
979            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
980        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
981        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
982        #cbar_title = "log I"
983    if have_base and have_comp:
984        plt.subplot(133)
985        if not opts['rel_err']:
986            err, errstr, errview = resid, "abs err", "linear"
987        else:
988            err, errstr, errview = abs(relerr), "rel err", "log"
989            if (err == 0.).all():
990                errview = 'linear'
991        if 0:  # 95% cutoff
992            sorted = np.sort(err.flatten())
993            cutoff = sorted[int(sorted.size*0.95)]
994            err[err > cutoff] = cutoff
995        #err,errstr = base/comp,"ratio"
996        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
997        plt.xscale('log' if view == 'log' and not opts['is2d'] else 'linear')
998        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
999        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
1000        #cbar_title = errstr if errview=="linear" else "log "+errstr
1001    #if is2D:
1002    #    h = plt.colorbar()
1003    #    h.ax.set_title(cbar_title)
1004    fig = plt.gcf()
1005    extra_title = ' '+opts['title'] if opts['title'] else ''
1006    fig.suptitle(":".join(opts['name']) + extra_title)
1007
1008    if have_base and have_comp and opts['show_hist']:
1009        plt.figure()
1010        v = relerr
1011        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
1012        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
1013        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
1014                   % (base.engine, comp.engine, comp.engine))
1015        plt.ylabel('P(err)')
1016        plt.title('Distribution of relative error between calculation engines')
1017
1018    return limits
1019
1020
1021# ===========================================================================
1022#
1023
1024# Set of command line options.
1025# Normal options such as -plot/-noplot are specified as 'name'.
1026# For options such as -nq=500 which require a value use 'name='.
1027#
1028OPTIONS = [
1029    # Plotting
1030    'plot', 'noplot', 'weights',
1031    'linear', 'log', 'q4',
1032    'rel', 'abs',
1033    'hist', 'nohist',
1034    'title=',
1035
1036    # Data generation
1037    'data=', 'noise=', 'res=', 'nq=', 'q=',
1038    'lowq', 'midq', 'highq', 'exq', 'zero',
1039    '2d', '1d',
1040
1041    # Parameter set
1042    'preset', 'random', 'random=', 'sets=',
1043    'demo', 'default',  # TODO: remove demo/default
1044    'nopars', 'pars',
1045    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
1046
1047    # Calculation options
1048    'poly', 'mono', 'cutoff=',
1049    'magnetic', 'nonmagnetic',
1050    'accuracy=', 'ngauss=',
1051    'neval=',  # for timing...
1052
1053    # Precision options
1054    'engine=',
1055    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1056    'sasview',  # TODO: remove sasview 3.x support
1057
1058    # Output options
1059    'help', 'html', 'edit',
1060    ]
1061
1062NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1063VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1064
1065
1066def columnize(items, indent="", width=79):
1067    # type: (List[str], str, int) -> str
1068    """
1069    Format a list of strings into columns.
1070
1071    Returns a string with carriage returns ready for printing.
1072    """
1073    column_width = max(len(w) for w in items) + 1
1074    num_columns = (width - len(indent)) // column_width
1075    num_rows = len(items) // num_columns
1076    items = items + [""] * (num_rows * num_columns - len(items))
1077    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1078    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1079             for row in zip(*columns)]
1080    output = indent + ("\n"+indent).join(lines)
1081    return output
1082
1083
1084def get_pars(model_info, use_demo=False):
1085    # type: (ModelInfo, bool) -> ParameterSet
1086    """
1087    Extract demo parameters from the model definition.
1088    """
1089    # Get the default values for the parameters
1090    pars = {}
1091    for p in model_info.parameters.call_parameters:
1092        parts = [('', p.default)]
1093        if p.polydisperse:
1094            parts.append(('_pd', 0.0))
1095            parts.append(('_pd_n', 0))
1096            parts.append(('_pd_nsigma', 3.0))
1097            parts.append(('_pd_type', "gaussian"))
1098        for ext, val in parts:
1099            if p.length > 1:
1100                dict(("%s%d%s" % (p.id, k, ext), val)
1101                     for k in range(1, p.length+1))
1102            else:
1103                pars[p.id + ext] = val
1104
1105    # Plug in values given in demo
1106    if use_demo and model_info.demo:
1107        pars.update(model_info.demo)
1108    return pars
1109
1110INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1111def isnumber(str):
1112    match = FLOAT_RE.match(str)
1113    isfloat = (match and not str[match.end():])
1114    return isfloat or INTEGER_RE.match(str)
1115
1116# For distinguishing pairs of models for comparison
1117# key-value pair separator =
1118# shell characters  | & ; <> $ % ' " \ # `
1119# model and parameter names _
1120# parameter expressions - + * / . ( )
1121# path characters including tilde expansion and windows drive ~ / :
1122# not sure about brackets [] {}
1123# maybe one of the following @ ? ^ ! ,
1124PAR_SPLIT = ','
1125def parse_opts(argv):
1126    # type: (List[str]) -> Dict[str, Any]
1127    """
1128    Parse command line options.
1129    """
1130    MODELS = core.list_models()
1131    flags = [arg for arg in argv
1132             if arg.startswith('-')]
1133    values = [arg for arg in argv
1134              if not arg.startswith('-') and '=' in arg]
1135    positional_args = [arg for arg in argv
1136                       if not arg.startswith('-') and '=' not in arg]
1137    models = "\n    ".join("%-15s"%v for v in MODELS)
1138    if len(positional_args) == 0:
1139        print(USAGE)
1140        print("\nAvailable models:")
1141        print(columnize(MODELS, indent="  "))
1142        return None
1143
1144    invalid = [o[1:] for o in flags
1145               if o[1:] not in NAME_OPTIONS
1146               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1147    if invalid:
1148        print("Invalid options: %s"%(", ".join(invalid)))
1149        return None
1150
1151    name = positional_args[-1]
1152
1153    # pylint: disable=bad-whitespace
1154    # Interpret the flags
1155    opts = {
1156        'plot'      : True,
1157        'view'      : 'log',
1158        'is2d'      : False,
1159        'qmin'      : None,
1160        'qmax'      : 0.05,
1161        'nq'        : 128,
1162        'res'       : 0.0,
1163        'noise'     : 0.0,
1164        'accuracy'  : 'Low',
1165        'cutoff'    : '0.0',
1166        'seed'      : -1,  # default to preset
1167        'mono'      : True,
1168        # Default to magnetic a magnetic moment is set on the command line
1169        'magnetic'  : False,
1170        'show_pars' : False,
1171        'show_hist' : False,
1172        'rel_err'   : True,
1173        'explore'   : False,
1174        'use_demo'  : True,
1175        'zero'      : False,
1176        'html'      : False,
1177        'title'     : None,
1178        'datafile'  : None,
1179        'sets'      : 0,
1180        'engine'    : 'default',
1181        'count'     : '1',
1182        'show_weights' : False,
1183        'sphere'    : 0,
1184        'ngauss'    : '0',
1185    }
1186    for arg in flags:
1187        if arg == '-noplot':    opts['plot'] = False
1188        elif arg == '-plot':    opts['plot'] = True
1189        elif arg == '-linear':  opts['view'] = 'linear'
1190        elif arg == '-log':     opts['view'] = 'log'
1191        elif arg == '-q4':      opts['view'] = 'q4'
1192        elif arg == '-1d':      opts['is2d'] = False
1193        elif arg == '-2d':      opts['is2d'] = True
1194        elif arg == '-exq':     opts['qmax'] = 10.0
1195        elif arg == '-highq':   opts['qmax'] = 1.0
1196        elif arg == '-midq':    opts['qmax'] = 0.2
1197        elif arg == '-lowq':    opts['qmax'] = 0.05
1198        elif arg == '-zero':    opts['zero'] = True
1199        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1200        elif arg.startswith('-q='):
1201            opts['qmin'], opts['qmax'] = [float(v) for v in arg[3:].split(':')]
1202        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1203        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1204        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1205        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1206        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1207        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1208        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1209        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1210        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1211        elif arg.startswith('-ngauss='):   opts['ngauss'] = arg[8:]
1212        elif arg.startswith('-random='):
1213            opts['seed'] = int(arg[8:])
1214            opts['sets'] = 0
1215        elif arg == '-random':
1216            opts['seed'] = np.random.randint(1000000)
1217            opts['sets'] = 0
1218        elif arg.startswith('-sphere'):
1219            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1220            opts['is2d'] = True
1221        elif arg == '-preset':  opts['seed'] = -1
1222        elif arg == '-mono':    opts['mono'] = True
1223        elif arg == '-poly':    opts['mono'] = False
1224        elif arg == '-magnetic':       opts['magnetic'] = True
1225        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1226        elif arg == '-pars':    opts['show_pars'] = True
1227        elif arg == '-nopars':  opts['show_pars'] = False
1228        elif arg == '-hist':    opts['show_hist'] = True
1229        elif arg == '-nohist':  opts['show_hist'] = False
1230        elif arg == '-rel':     opts['rel_err'] = True
1231        elif arg == '-abs':     opts['rel_err'] = False
1232        elif arg == '-half':    opts['engine'] = 'half'
1233        elif arg == '-fast':    opts['engine'] = 'fast'
1234        elif arg == '-single':  opts['engine'] = 'single'
1235        elif arg == '-double':  opts['engine'] = 'double'
1236        elif arg == '-single!': opts['engine'] = 'single!'
1237        elif arg == '-double!': opts['engine'] = 'double!'
1238        elif arg == '-quad!':   opts['engine'] = 'quad!'
1239        elif arg == '-sasview': opts['engine'] = 'sasview'
1240        elif arg == '-edit':    opts['explore'] = True
1241        elif arg == '-demo':    opts['use_demo'] = True
1242        elif arg == '-default': opts['use_demo'] = False
1243        elif arg == '-weights': opts['show_weights'] = True
1244        elif arg == '-html':    opts['html'] = True
1245        elif arg == '-help':    opts['html'] = True
1246    # pylint: enable=bad-whitespace
1247
1248    # Magnetism forces 2D for now
1249    if opts['magnetic']:
1250        opts['is2d'] = True
1251
1252    # Force random if sets is used
1253    if opts['sets'] >= 1 and opts['seed'] < 0:
1254        opts['seed'] = np.random.randint(1000000)
1255    if opts['sets'] == 0:
1256        opts['sets'] = 1
1257
1258    # Create the computational engines
1259    if opts['qmin'] is None:
1260        opts['qmin'] = 0.001*opts['qmax']
1261    if opts['datafile'] is not None:
1262        data = load_data(os.path.expanduser(opts['datafile']))
1263    else:
1264        data, _ = make_data(opts)
1265
1266    comparison = any(PAR_SPLIT in v for v in values)
1267
1268    if PAR_SPLIT in name:
1269        names = name.split(PAR_SPLIT, 2)
1270        comparison = True
1271    else:
1272        names = [name]*2
1273    try:
1274        model_info = [core.load_model_info(k) for k in names]
1275    except ImportError as exc:
1276        print(str(exc))
1277        print("Could not find model; use one of:\n    " + models)
1278        return None
1279
1280    if PAR_SPLIT in opts['ngauss']:
1281        opts['ngauss'] = [int(k) for k in opts['ngauss'].split(PAR_SPLIT, 2)]
1282        comparison = True
1283    else:
1284        opts['ngauss'] = [int(opts['ngauss'])]*2
1285
1286    if PAR_SPLIT in opts['engine']:
1287        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1288        comparison = True
1289    else:
1290        opts['engine'] = [opts['engine']]*2
1291
1292    if PAR_SPLIT in opts['count']:
1293        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1294        comparison = True
1295    else:
1296        opts['count'] = [int(opts['count'])]*2
1297
1298    if PAR_SPLIT in opts['cutoff']:
1299        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1300        comparison = True
1301    else:
1302        opts['cutoff'] = [float(opts['cutoff'])]*2
1303
1304    base = make_engine(model_info[0], data, opts['engine'][0],
1305                       opts['cutoff'][0], opts['ngauss'][0])
1306    if comparison:
1307        comp = make_engine(model_info[1], data, opts['engine'][1],
1308                           opts['cutoff'][1], opts['ngauss'][1])
1309    else:
1310        comp = None
1311
1312    # pylint: disable=bad-whitespace
1313    # Remember it all
1314    opts.update({
1315        'data'      : data,
1316        'name'      : names,
1317        'info'      : model_info,
1318        'engines'   : [base, comp],
1319        'values'    : values,
1320    })
1321    # pylint: enable=bad-whitespace
1322
1323    # Set the integration parameters to the half sphere
1324    if opts['sphere'] > 0:
1325        set_spherical_integration_parameters(opts, opts['sphere'])
1326
1327    return opts
1328
1329def set_spherical_integration_parameters(opts, steps):
1330    """
1331    Set integration parameters for spherical integration over the entire
1332    surface in theta-phi coordinates.
1333    """
1334    # Set the integration parameters to the half sphere
1335    opts['values'].extend([
1336        #'theta=90',
1337        'theta_pd=%g'%(90/np.sqrt(3)),
1338        'theta_pd_n=%d'%steps,
1339        'theta_pd_type=rectangle',
1340        #'phi=0',
1341        'phi_pd=%g'%(180/np.sqrt(3)),
1342        'phi_pd_n=%d'%(2*steps),
1343        'phi_pd_type=rectangle',
1344        #'background=0',
1345    ])
1346    if 'psi' in opts['info'][0].parameters:
1347        opts['values'].extend([
1348            #'psi=0',
1349            'psi_pd=%g'%(180/np.sqrt(3)),
1350            'psi_pd_n=%d'%(2*steps),
1351            'psi_pd_type=rectangle',
1352        ])
1353        pass
1354
1355def parse_pars(opts, maxdim=np.inf):
1356    model_info, model_info2 = opts['info']
1357
1358    # Get demo parameters from model definition, or use default parameters
1359    # if model does not define demo parameters
1360    pars = get_pars(model_info, opts['use_demo'])
1361    pars2 = get_pars(model_info2, opts['use_demo'])
1362    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1363    # randomize parameters
1364    #pars.update(set_pars)  # set value before random to control range
1365    if opts['seed'] > -1:
1366        pars = randomize_pars(model_info, pars)
1367        limit_dimensions(model_info, pars, maxdim)
1368        if model_info != model_info2:
1369            pars2 = randomize_pars(model_info2, pars2)
1370            limit_dimensions(model_info2, pars2, maxdim)
1371            # Share values for parameters with the same name
1372            for k, v in pars.items():
1373                if k in pars2:
1374                    pars2[k] = v
1375        else:
1376            pars2 = pars.copy()
1377        constrain_pars(model_info, pars)
1378        constrain_pars(model_info2, pars2)
1379    pars = suppress_pd(pars, opts['mono'])
1380    pars2 = suppress_pd(pars2, opts['mono'])
1381    pars = suppress_magnetism(pars, not opts['magnetic'])
1382    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1383
1384    # Fill in parameters given on the command line
1385    presets = {}
1386    presets2 = {}
1387    for arg in opts['values']:
1388        k, v = arg.split('=', 1)
1389        if k not in pars and k not in pars2:
1390            # extract base name without polydispersity info
1391            s = set(p.split('_pd')[0] for p in pars)
1392            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1393            return None
1394        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1395        if v1 and k in pars:
1396            presets[k] = float(v1) if isnumber(v1) else v1
1397        if v2 and k in pars2:
1398            presets2[k] = float(v2) if isnumber(v2) else v2
1399
1400    # If pd given on the command line, default pd_n to 35
1401    for k, v in list(presets.items()):
1402        if k.endswith('_pd'):
1403            presets.setdefault(k+'_n', 35.)
1404    for k, v in list(presets2.items()):
1405        if k.endswith('_pd'):
1406            presets2.setdefault(k+'_n', 35.)
1407
1408    # Evaluate preset parameter expressions
1409    context = MATH.copy()
1410    context['np'] = np
1411    context.update(pars)
1412    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1413    for k, v in presets.items():
1414        if not isinstance(v, float) and not k.endswith('_type'):
1415            presets[k] = eval(v, context)
1416    context.update(presets)
1417    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1418    for k, v in presets2.items():
1419        if not isinstance(v, float) and not k.endswith('_type'):
1420            presets2[k] = eval(v, context)
1421
1422    # update parameters with presets
1423    pars.update(presets)  # set value after random to control value
1424    pars2.update(presets2)  # set value after random to control value
1425    #import pprint; pprint.pprint(model_info)
1426
1427    if opts['show_pars']:
1428        if model_info.name != model_info2.name or pars != pars2:
1429            print("==== %s ====="%model_info.name)
1430            print(str(parlist(model_info, pars, opts['is2d'])))
1431            print("==== %s ====="%model_info2.name)
1432            print(str(parlist(model_info2, pars2, opts['is2d'])))
1433        else:
1434            print(str(parlist(model_info, pars, opts['is2d'])))
1435
1436    return pars, pars2
1437
1438def show_docs(opts):
1439    # type: (Dict[str, Any]) -> None
1440    """
1441    show html docs for the model
1442    """
1443    import os
1444    from .generate import make_html
1445    from . import rst2html
1446
1447    info = opts['info'][0]
1448    html = make_html(info)
1449    path = os.path.dirname(info.filename)
1450    url = "file://"+path.replace("\\","/")[2:]+"/"
1451    rst2html.view_html_qtapp(html, url)
1452
1453def explore(opts):
1454    # type: (Dict[str, Any]) -> None
1455    """
1456    explore the model using the bumps gui.
1457    """
1458    import wx  # type: ignore
1459    from bumps.names import FitProblem  # type: ignore
1460    from bumps.gui.app_frame import AppFrame  # type: ignore
1461    from bumps.gui import signal
1462
1463    is_mac = "cocoa" in wx.version()
1464    # Create an app if not running embedded
1465    app = wx.App() if wx.GetApp() is None else None
1466    model = Explore(opts)
1467    problem = FitProblem(model)
1468    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1469    if not is_mac:
1470        frame.Show()
1471    frame.panel.set_model(model=problem)
1472    frame.panel.Layout()
1473    frame.panel.aui.Split(0, wx.TOP)
1474    def reset_parameters(event):
1475        model.revert_values()
1476        signal.update_parameters(problem)
1477    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1478    if is_mac: frame.Show()
1479    # If running withing an app, start the main loop
1480    if app:
1481        app.MainLoop()
1482
1483class Explore(object):
1484    """
1485    Bumps wrapper for a SAS model comparison.
1486
1487    The resulting object can be used as a Bumps fit problem so that
1488    parameters can be adjusted in the GUI, with plots updated on the fly.
1489    """
1490    def __init__(self, opts):
1491        # type: (Dict[str, Any]) -> None
1492        from bumps.cli import config_matplotlib  # type: ignore
1493        from . import bumps_model
1494        config_matplotlib()
1495        self.opts = opts
1496        opts['pars'] = list(opts['pars'])
1497        p1, p2 = opts['pars']
1498        m1, m2 = opts['info']
1499        self.fix_p2 = m1 != m2 or p1 != p2
1500        model_info = m1
1501        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1502        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1503        if not opts['is2d']:
1504            for p in model_info.parameters.user_parameters({}, is2d=False):
1505                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1506                    k = p.name+ext
1507                    v = pars.get(k, None)
1508                    if v is not None:
1509                        v.range(*parameter_range(k, v.value))
1510        else:
1511            for k, v in pars.items():
1512                v.range(*parameter_range(k, v.value))
1513
1514        self.pars = pars
1515        self.starting_values = dict((k, v.value) for k, v in pars.items())
1516        self.pd_types = pd_types
1517        self.limits = None
1518
1519    def revert_values(self):
1520        for k, v in self.starting_values.items():
1521            self.pars[k].value = v
1522
1523    def model_update(self):
1524        pass
1525
1526    def numpoints(self):
1527        # type: () -> int
1528        """
1529        Return the number of points.
1530        """
1531        return len(self.pars) + 1  # so dof is 1
1532
1533    def parameters(self):
1534        # type: () -> Any   # Dict/List hierarchy of parameters
1535        """
1536        Return a dictionary of parameters.
1537        """
1538        return self.pars
1539
1540    def nllf(self):
1541        # type: () -> float
1542        """
1543        Return cost.
1544        """
1545        # pylint: disable=no-self-use
1546        return 0.  # No nllf
1547
1548    def plot(self, view='log'):
1549        # type: (str) -> None
1550        """
1551        Plot the data and residuals.
1552        """
1553        pars = dict((k, v.value) for k, v in self.pars.items())
1554        pars.update(self.pd_types)
1555        self.opts['pars'][0] = pars
1556        if not self.fix_p2:
1557            self.opts['pars'][1] = pars
1558        result = run_models(self.opts)
1559        limits = plot_models(self.opts, result, limits=self.limits)
1560        if self.limits is None:
1561            vmin, vmax = limits
1562            self.limits = vmax*1e-7, 1.3*vmax
1563            import pylab; pylab.clf()
1564            plot_models(self.opts, result, limits=self.limits)
1565
1566
1567def main(*argv):
1568    # type: (*str) -> None
1569    """
1570    Main program.
1571    """
1572    opts = parse_opts(argv)
1573    if opts is not None:
1574        if opts['seed'] > -1:
1575            print("Randomize using -random=%i"%opts['seed'])
1576            np.random.seed(opts['seed'])
1577        if opts['html']:
1578            show_docs(opts)
1579        elif opts['explore']:
1580            opts['pars'] = parse_pars(opts)
1581            if opts['pars'] is None:
1582                return
1583            explore(opts)
1584        else:
1585            compare(opts)
1586
1587if __name__ == "__main__":
1588    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.