source: sasmodels/sasmodels/compare.py @ 31eea1f

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 31eea1f was 31eea1f, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

explore accuracy of different 1D integration schemes

  • Property mode set to 100755
File size: 54.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47from .weights import plot_weights
48
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except Exception:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57
58USAGE = """
59usage: sascomp model [options...] [key=val]
60
61Generate and compare SAS models.  If a single model is specified it shows
62a plot of that model.  Different models can be compared, or the same model
63with different parameters.  The same model with the same parameters can
64be compared with different calculation engines to see the effects of precision
65on the resultant values.
66
67model or model1,model2 are the names of the models to compare (see below).
68
69Options (* for default):
70
71    === data generation ===
72    -data="path" uses q, dq from the data file
73    -noise=0 sets the measurement error dI/I
74    -res=0 sets the resolution width dQ/Q if calculating with resolution
75    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
76    -nq=128 sets the number of Q points in the data set
77    -1d*/-2d computes 1d or 2d data
78    -zero indicates that q=0 should be included
79
80    === model parameters ===
81    -preset*/-random[=seed] preset or random parameters
82    -sets=n generates n random datasets with the seed given by -random=seed
83    -pars/-nopars* prints the parameter set or not
84    -default/-demo* use demo vs default parameters
85    -sphere[=150] set up spherical integration over theta/phi using n points
86
87    === calculation options ===
88    -mono*/-poly force monodisperse or allow polydisperse random parameters
89    -cutoff=1e-5* cutoff value for including a point in polydispersity
90    -magnetic/-nonmagnetic* suppress magnetism
91    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
92    -neval=1 sets the number of evals for more accurate timing
93
94    === precision options ===
95    -engine=default uses the default calcution precision
96    -single/-double/-half/-fast sets an OpenCL calculation engine
97    -single!/-double!/-quad! sets an OpenMP calculation engine
98    -sasview sets the sasview calculation engine
99
100    === plotting ===
101    -plot*/-noplot plots or suppress the plot of the model
102    -linear/-log*/-q4 intensity scaling on plots
103    -hist/-nohist* plot histogram of relative error
104    -abs/-rel* plot relative or absolute error
105    -title="note" adds note to the plot title, after the model name
106    -weights shows weights plots for the polydisperse parameters
107
108    === output options ===
109    -edit starts the parameter explorer
110    -help/-html shows the model docs instead of running the model
111
112The interpretation of quad precision depends on architecture, and may
113vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
114On unix and mac you may need single quotes around the DLL computation
115engines, such as -engine='single!,double!' since !, is treated as a history
116expansion request in the shell.
117
118Key=value pairs allow you to set specific values for the model parameters.
119Key=value1,value2 to compare different values of the same parameter. The
120value can be an expression including other parameters.
121
122Items later on the command line override those that appear earlier.
123
124Examples:
125
126    # compare single and double precision calculation for a barbell
127    sascomp barbell -engine=single,double
128
129    # generate 10 random lorentz models, with seed=27
130    sascomp lorentz -sets=10 -seed=27
131
132    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
133    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
134
135    # model timing test requires multiple evals to perform the estimate
136    sascomp pringle -engine=single,double -timing=100,100 -noplot
137"""
138
139# Update docs with command line usage string.   This is separate from the usual
140# doc string so that we can display it at run time if there is an error.
141# lin
142__doc__ = (__doc__  # pylint: disable=redefined-builtin
143           + """
144Program description
145-------------------
146
147""" + USAGE)
148
149kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
150
151# list of math functions for use in evaluating parameters
152MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
153
154# CRUFT python 2.6
155if not hasattr(datetime.timedelta, 'total_seconds'):
156    def delay(dt):
157        """Return number date-time delta as number seconds"""
158        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
159else:
160    def delay(dt):
161        """Return number date-time delta as number seconds"""
162        return dt.total_seconds()
163
164
165class push_seed(object):
166    """
167    Set the seed value for the random number generator.
168
169    When used in a with statement, the random number generator state is
170    restored after the with statement is complete.
171
172    :Parameters:
173
174    *seed* : int or array_like, optional
175        Seed for RandomState
176
177    :Example:
178
179    Seed can be used directly to set the seed::
180
181        >>> from numpy.random import randint
182        >>> push_seed(24)
183        <...push_seed object at...>
184        >>> print(randint(0,1000000,3))
185        [242082    899 211136]
186
187    Seed can also be used in a with statement, which sets the random
188    number generator state for the enclosed computations and restores
189    it to the previous state on completion::
190
191        >>> with push_seed(24):
192        ...    print(randint(0,1000000,3))
193        [242082    899 211136]
194
195    Using nested contexts, we can demonstrate that state is indeed
196    restored after the block completes::
197
198        >>> with push_seed(24):
199        ...    print(randint(0,1000000))
200        ...    with push_seed(24):
201        ...        print(randint(0,1000000,3))
202        ...    print(randint(0,1000000))
203        242082
204        [242082    899 211136]
205        899
206
207    The restore step is protected against exceptions in the block::
208
209        >>> with push_seed(24):
210        ...    print(randint(0,1000000))
211        ...    try:
212        ...        with push_seed(24):
213        ...            print(randint(0,1000000,3))
214        ...            raise Exception()
215        ...    except Exception:
216        ...        print("Exception raised")
217        ...    print(randint(0,1000000))
218        242082
219        [242082    899 211136]
220        Exception raised
221        899
222    """
223    def __init__(self, seed=None):
224        # type: (Optional[int]) -> None
225        self._state = np.random.get_state()
226        np.random.seed(seed)
227
228    def __enter__(self):
229        # type: () -> None
230        pass
231
232    def __exit__(self, exc_type, exc_value, traceback):
233        # type: (Any, BaseException, Any) -> None
234        # TODO: better typing for __exit__ method
235        np.random.set_state(self._state)
236
237def tic():
238    # type: () -> Callable[[], float]
239    """
240    Timer function.
241
242    Use "toc=tic()" to start the clock and "toc()" to measure
243    a time interval.
244    """
245    then = datetime.datetime.now()
246    return lambda: delay(datetime.datetime.now() - then)
247
248
249def set_beam_stop(data, radius, outer=None):
250    # type: (Data, float, float) -> None
251    """
252    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
253
254    Note: this function does not require sasview
255    """
256    if hasattr(data, 'qx_data'):
257        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
258        data.mask = (q < radius)
259        if outer is not None:
260            data.mask |= (q >= outer)
261    else:
262        data.mask = (data.x < radius)
263        if outer is not None:
264            data.mask |= (data.x >= outer)
265
266
267def parameter_range(p, v):
268    # type: (str, float) -> Tuple[float, float]
269    """
270    Choose a parameter range based on parameter name and initial value.
271    """
272    # process the polydispersity options
273    if p.endswith('_pd_n'):
274        return 0., 100.
275    elif p.endswith('_pd_nsigma'):
276        return 0., 5.
277    elif p.endswith('_pd_type'):
278        raise ValueError("Cannot return a range for a string value")
279    elif any(s in p for s in ('theta', 'phi', 'psi')):
280        # orientation in [-180,180], orientation pd in [0,45]
281        if p.endswith('_pd'):
282            return 0., 180.
283        else:
284            return -180., 180.
285    elif p.endswith('_pd'):
286        return 0., 1.
287    elif 'sld' in p:
288        return -0.5, 10.
289    elif p == 'background':
290        return 0., 10.
291    elif p == 'scale':
292        return 0., 1.e3
293    elif v < 0.:
294        return 2.*v, -2.*v
295    else:
296        return 0., (2.*v if v > 0. else 1.)
297
298
299def _randomize_one(model_info, name, value):
300    # type: (ModelInfo, str, float) -> float
301    # type: (ModelInfo, str, str) -> str
302    """
303    Randomize a single parameter.
304    """
305    # Set the amount of polydispersity/angular dispersion, but by default pd_n
306    # is zero so there is no polydispersity.  This allows us to turn on/off
307    # pd by setting pd_n, and still have randomly generated values
308    if name.endswith('_pd'):
309        par = model_info.parameters[name[:-3]]
310        if par.type == 'orientation':
311            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
312            return 180*np.random.beta(2.5, 20)
313        else:
314            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
315            return np.random.beta(1.5, 7)
316
317    # pd is selected globally rather than per parameter, so set to 0 for no pd
318    # In particular, when multiple pd dimensions, want to decrease the number
319    # of points per dimension for faster computation
320    if name.endswith('_pd_n'):
321        return 0
322
323    # Don't mess with distribution type for now
324    if name.endswith('_pd_type'):
325        return 'gaussian'
326
327    # type-dependent value of number of sigmas; for gaussian use 3.
328    if name.endswith('_pd_nsigma'):
329        return 3.
330
331    # background in the range [0.01, 1]
332    if name == 'background':
333        return 10**np.random.uniform(-2, 0)
334
335    # scale defaults to 0.1% to 30% volume fraction
336    if name == 'scale':
337        return 10**np.random.uniform(-3, -0.5)
338
339    # If it is a list of choices, pick one at random with equal probability
340    # In practice, the model specific random generator will override.
341    par = model_info.parameters[name]
342    if len(par.limits) > 2:  # choice list
343        return np.random.randint(len(par.limits))
344
345    # If it is a fixed range, pick from it with equal probability.
346    # For logarithmic ranges, the model will have to override.
347    if np.isfinite(par.limits).all():
348        return np.random.uniform(*par.limits)
349
350    # If the paramter is marked as an sld use the range of neutron slds
351    # TODO: ought to randomly contrast match a pair of SLDs
352    if par.type == 'sld':
353        return np.random.uniform(-0.5, 12)
354
355    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
356    if par.name.startswith('M0:'):
357        return np.random.uniform(0, 5)
358
359    # Guess at the random length/radius/thickness.  In practice, all models
360    # are going to set their own reasonable ranges.
361    if par.type == 'volume':
362        if ('length' in par.name or
363                'radius' in par.name or
364                'thick' in par.name):
365            return 10**np.random.uniform(2, 4)
366
367    # In the absence of any other info, select a value in [0, 2v], or
368    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
369    # model random parameter generators will override this default.
370    low, high = parameter_range(par.name, value)
371    limits = (max(par.limits[0], low), min(par.limits[1], high))
372    return np.random.uniform(*limits)
373
374def _random_pd(model_info, pars):
375    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
376    pd_volume = []
377    pd_oriented = []
378    for p in pd:
379        if p.type == 'orientation':
380            pd_oriented.append(p.name)
381        elif p.length_control is not None:
382            n = int(pars.get(p.length_control, 1) + 0.5)
383            pd_volume.extend(p.name+str(k+1) for k in range(n))
384        elif p.length > 1:
385            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
386        else:
387            pd_volume.append(p.name)
388    u = np.random.rand()
389    n = len(pd_volume)
390    if u < 0.01 or n < 1:
391        pass  # 1% chance of no polydispersity
392    elif u < 0.86 or n < 2:
393        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
394    elif u < 0.99 or n < 3:
395        choices = np.random.choice(len(pd_volume), size=2)
396        pars[pd_volume[choices[0]]+"_pd_n"] = 25
397        pars[pd_volume[choices[1]]+"_pd_n"] = 10
398    else:
399        choices = np.random.choice(len(pd_volume), size=3)
400        pars[pd_volume[choices[0]]+"_pd_n"] = 25
401        pars[pd_volume[choices[1]]+"_pd_n"] = 10
402        pars[pd_volume[choices[2]]+"_pd_n"] = 5
403    if pd_oriented:
404        pars['theta_pd_n'] = 20
405        if np.random.rand() < 0.1:
406            pars['phi_pd_n'] = 5
407        if np.random.rand() < 0.1:
408            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
409                #print("generating psi_pd_n")
410                pars['psi_pd_n'] = 5
411
412    ## Show selected polydispersity
413    #for name, value in pars.items():
414    #    if name.endswith('_pd_n') and value > 0:
415    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
416
417
418def randomize_pars(model_info, pars):
419    # type: (ModelInfo, ParameterSet) -> ParameterSet
420    """
421    Generate random values for all of the parameters.
422
423    Valid ranges for the random number generator are guessed from the name of
424    the parameter; this will not account for constraints such as cap radius
425    greater than cylinder radius in the capped_cylinder model, so
426    :func:`constrain_pars` needs to be called afterward..
427    """
428    # Note: the sort guarantees order of calls to random number generator
429    random_pars = dict((p, _randomize_one(model_info, p, v))
430                       for p, v in sorted(pars.items()))
431    if model_info.random is not None:
432        random_pars.update(model_info.random())
433    _random_pd(model_info, random_pars)
434    return random_pars
435
436
437def limit_dimensions(model_info, pars, maxdim):
438    # type: (ModelInfo, ParameterSet, float) -> None
439    """
440    Limit parameters of units of Ang to maxdim.
441    """
442    for p in model_info.parameters.call_parameters:
443        value = pars[p.name]
444        if p.units == 'Ang' and value > maxdim:
445            pars[p.name] = maxdim*10**np.random.uniform(-3,0)
446
447def constrain_pars(model_info, pars):
448    # type: (ModelInfo, ParameterSet) -> None
449    """
450    Restrict parameters to valid values.
451
452    This includes model specific code for models such as capped_cylinder
453    which need to support within model constraints (cap radius more than
454    cylinder radius in this case).
455
456    Warning: this updates the *pars* dictionary in place.
457    """
458    # TODO: move the model specific code to the individual models
459    name = model_info.id
460    # if it is a product model, then just look at the form factor since
461    # none of the structure factors need any constraints.
462    if '*' in name:
463        name = name.split('*')[0]
464
465    # Suppress magnetism for python models (not yet implemented)
466    if callable(model_info.Iq):
467        pars.update(suppress_magnetism(pars))
468
469    if name == 'barbell':
470        if pars['radius_bell'] < pars['radius']:
471            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
472
473    elif name == 'capped_cylinder':
474        if pars['radius_cap'] < pars['radius']:
475            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
476
477    elif name == 'guinier':
478        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
479        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
480        # => ln A - (Rg^2 q^2/3) > -30 ln 10
481        # => Rg^2 q^2/3 < 30 ln 10 + ln A
482        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
483        #q_max = 0.2  # mid q maximum
484        q_max = 1.0  # high q maximum
485        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
486        pars['rg'] = min(pars['rg'], rg_max)
487
488    elif name == 'pearl_necklace':
489        if pars['radius'] < pars['thick_string']:
490            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
491        pass
492
493    elif name == 'rpa':
494        # Make sure phi sums to 1.0
495        if pars['case_num'] < 2:
496            pars['Phi1'] = 0.
497            pars['Phi2'] = 0.
498        elif pars['case_num'] < 5:
499            pars['Phi1'] = 0.
500        total = sum(pars['Phi'+c] for c in '1234')
501        for c in '1234':
502            pars['Phi'+c] /= total
503
504def parlist(model_info, pars, is2d):
505    # type: (ModelInfo, ParameterSet, bool) -> str
506    """
507    Format the parameter list for printing.
508    """
509    is2d = True
510    lines = []
511    parameters = model_info.parameters
512    magnetic = False
513    magnetic_pars = []
514    for p in parameters.user_parameters(pars, is2d):
515        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
516            continue
517        if p.id.startswith('up:'):
518            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
519            continue
520        fields = dict(
521            value=pars.get(p.id, p.default),
522            pd=pars.get(p.id+"_pd", 0.),
523            n=int(pars.get(p.id+"_pd_n", 0)),
524            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
525            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
526            relative_pd=p.relative_pd,
527            M0=pars.get('M0:'+p.id, 0.),
528            mphi=pars.get('mphi:'+p.id, 0.),
529            mtheta=pars.get('mtheta:'+p.id, 0.),
530        )
531        lines.append(_format_par(p.name, **fields))
532        magnetic = magnetic or fields['M0'] != 0.
533    if magnetic and magnetic_pars:
534        lines.append(" ".join(magnetic_pars))
535    return "\n".join(lines)
536
537    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
538
539def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
540                relative_pd=False, M0=0., mphi=0., mtheta=0.):
541    # type: (str, float, float, int, float, str) -> str
542    line = "%s: %g"%(name, value)
543    if pd != 0.  and n != 0:
544        if relative_pd:
545            pd *= value
546        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
547                % (pd, n, nsigma, nsigma, pdtype)
548    if M0 != 0.:
549        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
550    return line
551
552def suppress_pd(pars, suppress=True):
553    # type: (ParameterSet) -> ParameterSet
554    """
555    If suppress is True complete eliminate polydispersity of the model to test
556    models more quickly.  If suppress is False, make sure at least one
557    parameter is polydisperse, setting the first polydispersity parameter to
558    15% if no polydispersity is given (with no explicit demo parameters given
559    in the model, there will be no default polydispersity).
560    """
561    pars = pars.copy()
562    #print("pars=", pars)
563    if suppress:
564        for p in pars:
565            if p.endswith("_pd_n"):
566                pars[p] = 0
567    else:
568        any_pd = False
569        first_pd = None
570        for p in pars:
571            if p.endswith("_pd_n"):
572                pd = pars.get(p[:-2], 0.)
573                any_pd |= (pars[p] != 0 and pd != 0.)
574                if first_pd is None:
575                    first_pd = p
576        if not any_pd and first_pd is not None:
577            if pars[first_pd] == 0:
578                pars[first_pd] = 35
579            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
580                pars[first_pd[:-2]] = 0.15
581    return pars
582
583def suppress_magnetism(pars, suppress=True):
584    # type: (ParameterSet) -> ParameterSet
585    """
586    If suppress is True complete eliminate magnetism of the model to test
587    models more quickly.  If suppress is False, make sure at least one sld
588    parameter is magnetic, setting the first parameter to have a strong
589    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
590    in the model, there will be no default magnetism).
591    """
592    pars = pars.copy()
593    if suppress:
594        for p in pars:
595            if p.startswith("M0:"):
596                pars[p] = 0
597    else:
598        any_mag = False
599        first_mag = None
600        for p in pars:
601            if p.startswith("M0:"):
602                any_mag |= (pars[p] != 0)
603                if first_mag is None:
604                    first_mag = p
605        if not any_mag and first_mag is not None:
606            pars[first_mag] = 8.
607    return pars
608
609def eval_sasview(model_info, data):
610    # type: (Modelinfo, Data) -> Calculator
611    """
612    Return a model calculator using the pre-4.0 SasView models.
613    """
614    # importing sas here so that the error message will be that sas failed to
615    # import rather than the more obscure smear_selection not imported error
616    import sas
617    import sas.models
618    from sas.models.qsmearing import smear_selection
619    from sas.models.MultiplicationModel import MultiplicationModel
620    from sas.models.dispersion_models import models as dispersers
621
622    def get_model_class(name):
623        # type: (str) -> "sas.models.BaseComponent"
624        #print("new",sorted(_pars.items()))
625        __import__('sas.models.' + name)
626        ModelClass = getattr(getattr(sas.models, name, None), name, None)
627        if ModelClass is None:
628            raise ValueError("could not find model %r in sas.models"%name)
629        return ModelClass
630
631    # WARNING: ugly hack when handling model!
632    # Sasview models with multiplicity need to be created with the target
633    # multiplicity, so we cannot create the target model ahead of time for
634    # for multiplicity models.  Instead we store the model in a list and
635    # update the first element of that list with the new multiplicity model
636    # every time we evaluate.
637
638    # grab the sasview model, or create it if it is a product model
639    if model_info.composition:
640        composition_type, parts = model_info.composition
641        if composition_type == 'product':
642            P, S = [get_model_class(revert_name(p))() for p in parts]
643            model = [MultiplicationModel(P, S)]
644        else:
645            raise ValueError("sasview mixture models not supported by compare")
646    else:
647        old_name = revert_name(model_info)
648        if old_name is None:
649            raise ValueError("model %r does not exist in old sasview"
650                            % model_info.id)
651        ModelClass = get_model_class(old_name)
652        model = [ModelClass()]
653    model[0].disperser_handles = {}
654
655    # build a smearer with which to call the model, if necessary
656    smearer = smear_selection(data, model=model)
657    if hasattr(data, 'qx_data'):
658        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
659        index = ((~data.mask) & (~np.isnan(data.data))
660                 & (q >= data.qmin) & (q <= data.qmax))
661        if smearer is not None:
662            smearer.model = model  # because smear_selection has a bug
663            smearer.accuracy = data.accuracy
664            smearer.set_index(index)
665            def _call_smearer():
666                smearer.model = model[0]
667                return smearer.get_value()
668            theory = _call_smearer
669        else:
670            theory = lambda: model[0].evalDistribution([data.qx_data[index],
671                                                        data.qy_data[index]])
672    elif smearer is not None:
673        theory = lambda: smearer(model[0].evalDistribution(data.x))
674    else:
675        theory = lambda: model[0].evalDistribution(data.x)
676
677    def calculator(**pars):
678        # type: (float, ...) -> np.ndarray
679        """
680        Sasview calculator for model.
681        """
682        oldpars = revert_pars(model_info, pars)
683        # For multiplicity models, create a model with the correct multiplicity
684        control = oldpars.pop("CONTROL", None)
685        if control is not None:
686            # sphericalSLD has one fewer multiplicity.  This update should
687            # happen in revert_pars, but it hasn't been called yet.
688            model[0] = ModelClass(control)
689        # paying for parameter conversion each time to keep life simple, if not fast
690        for k, v in oldpars.items():
691            if k.endswith('.type'):
692                par = k[:-5]
693                if v == 'gaussian': continue
694                cls = dispersers[v if v != 'rectangle' else 'rectangula']
695                handle = cls()
696                model[0].disperser_handles[par] = handle
697                try:
698                    model[0].set_dispersion(par, handle)
699                except Exception:
700                    exception.annotate_exception("while setting %s to %r"
701                                                 %(par, v))
702                    raise
703
704
705        #print("sasview pars",oldpars)
706        for k, v in oldpars.items():
707            name_attr = k.split('.')  # polydispersity components
708            if len(name_attr) == 2:
709                par, disp_par = name_attr
710                model[0].dispersion[par][disp_par] = v
711            else:
712                model[0].setParam(k, v)
713        return theory()
714
715    calculator.engine = "sasview"
716    return calculator
717
718DTYPE_MAP = {
719    'half': '16',
720    'fast': 'fast',
721    'single': '32',
722    'double': '64',
723    'quad': '128',
724    'f16': '16',
725    'f32': '32',
726    'f64': '64',
727    'float16': '16',
728    'float32': '32',
729    'float64': '64',
730    'float128': '128',
731    'longdouble': '128',
732}
733def eval_opencl(model_info, data, dtype='single', cutoff=0.):
734    # type: (ModelInfo, Data, str, float) -> Calculator
735    """
736    Return a model calculator using the OpenCL calculation engine.
737    """
738    if not core.HAVE_OPENCL:
739        raise RuntimeError("OpenCL not available")
740    model = core.build_model(model_info, dtype=dtype, platform="ocl")
741    calculator = DirectModel(data, model, cutoff=cutoff)
742    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
743    return calculator
744
745def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
746    # type: (ModelInfo, Data, str, float) -> Calculator
747    """
748    Return a model calculator using the DLL calculation engine.
749    """
750    model = core.build_model(model_info, dtype=dtype, platform="dll")
751    calculator = DirectModel(data, model, cutoff=cutoff)
752    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
753    return calculator
754
755def time_calculation(calculator, pars, evals=1):
756    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
757    """
758    Compute the average calculation time over N evaluations.
759
760    An additional call is generated without polydispersity in order to
761    initialize the calculation engine, and make the average more stable.
762    """
763    # initialize the code so time is more accurate
764    if evals > 1:
765        calculator(**suppress_pd(pars))
766    toc = tic()
767    # make sure there is at least one eval
768    value = calculator(**pars)
769    for _ in range(evals-1):
770        value = calculator(**pars)
771    average_time = toc()*1000. / evals
772    #print("I(q)",value)
773    return value, average_time
774
775def make_data(opts):
776    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
777    """
778    Generate an empty dataset, used with the model to set Q points
779    and resolution.
780
781    *opts* contains the options, with 'qmax', 'nq', 'res',
782    'accuracy', 'is2d' and 'view' parsed from the command line.
783    """
784    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
785    if opts['is2d']:
786        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
787        data = empty_data2D(q, resolution=res)
788        data.accuracy = opts['accuracy']
789        set_beam_stop(data, 0.0004)
790        index = ~data.mask
791    else:
792        if opts['view'] == 'log' and not opts['zero']:
793            qmax = math.log10(qmax)
794            q = np.logspace(qmax-3, qmax, nq)
795        else:
796            q = np.linspace(0.001*qmax, qmax, nq)
797        if opts['zero']:
798            q = np.hstack((0, q))
799        data = empty_data1D(q, resolution=res)
800        index = slice(None, None)
801    return data, index
802
803def make_engine(model_info, data, dtype, cutoff):
804    # type: (ModelInfo, Data, str, float) -> Calculator
805    """
806    Generate the appropriate calculation engine for the given datatype.
807
808    Datatypes with '!' appended are evaluated using external C DLLs rather
809    than OpenCL.
810    """
811    if dtype == 'sasview':
812        return eval_sasview(model_info, data)
813    elif dtype is None or not dtype.endswith('!'):
814        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
815    else:
816        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
817
818def _show_invalid(data, theory):
819    # type: (Data, np.ma.ndarray) -> None
820    """
821    Display a list of the non-finite values in theory.
822    """
823    if not theory.mask.any():
824        return
825
826    if hasattr(data, 'x'):
827        bad = zip(data.x[theory.mask], theory[theory.mask])
828        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
829
830
831def compare(opts, limits=None, maxdim=np.inf):
832    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
833    """
834    Preform a comparison using options from the command line.
835
836    *limits* are the limits on the values to use, either to set the y-axis
837    for 1D or to set the colormap scale for 2D.  If None, then they are
838    inferred from the data and returned. When exploring using Bumps,
839    the limits are set when the model is initially called, and maintained
840    as the values are adjusted, making it easier to see the effects of the
841    parameters.
842
843    *maxdim* is the maximum value for any parameter with units of Angstrom.
844    """
845    for k in range(opts['sets']):
846        if k > 1:
847            # print a separate seed for each dataset for better reproducibility
848            new_seed = np.random.randint(1000000)
849            print("Set %d uses -random=%i"%(k+1,new_seed))
850            np.random.seed(new_seed)
851        opts['pars'] = parse_pars(opts, maxdim=maxdim)
852        if opts['pars'] is None:
853            return
854        result = run_models(opts, verbose=True)
855        if opts['plot']:
856            limits = plot_models(opts, result, limits=limits, setnum=k)
857        if opts['show_weights']:
858            base, _ = opts['engines']
859            base_pars, _ = opts['pars']
860            model_info = base._kernel.info
861            dim = base._kernel.dim
862            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
863    if opts['plot']:
864        import matplotlib.pyplot as plt
865        plt.show()
866    return limits
867
868def run_models(opts, verbose=False):
869    # type: (Dict[str, Any]) -> Dict[str, Any]
870
871    base, comp = opts['engines']
872    base_n, comp_n = opts['count']
873    base_pars, comp_pars = opts['pars']
874    data = opts['data']
875
876    comparison = comp is not None
877
878    base_time = comp_time = None
879    base_value = comp_value = resid = relerr = None
880
881    # Base calculation
882    try:
883        base_raw, base_time = time_calculation(base, base_pars, base_n)
884        base_value = np.ma.masked_invalid(base_raw)
885        if verbose:
886            print("%s t=%.2f ms, intensity=%.0f"
887                  % (base.engine, base_time, base_value.sum()))
888        _show_invalid(data, base_value)
889    except ImportError:
890        traceback.print_exc()
891
892    # Comparison calculation
893    if comparison:
894        try:
895            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
896            comp_value = np.ma.masked_invalid(comp_raw)
897            if verbose:
898                print("%s t=%.2f ms, intensity=%.0f"
899                      % (comp.engine, comp_time, comp_value.sum()))
900            _show_invalid(data, comp_value)
901        except ImportError:
902            traceback.print_exc()
903
904    # Compare, but only if computing both forms
905    if comparison:
906        resid = (base_value - comp_value)
907        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
908        if verbose:
909            _print_stats("|%s-%s|"
910                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
911                         resid)
912            _print_stats("|(%s-%s)/%s|"
913                         % (base.engine, comp.engine, comp.engine),
914                         relerr)
915
916    return dict(base_value=base_value, comp_value=comp_value,
917                base_time=base_time, comp_time=comp_time,
918                resid=resid, relerr=relerr)
919
920
921def _print_stats(label, err):
922    # type: (str, np.ma.ndarray) -> None
923    # work with trimmed data, not the full set
924    sorted_err = np.sort(abs(err.compressed()))
925    if len(sorted_err) == 0.:
926        print(label + "  no valid values")
927        return
928
929    p50 = int((len(sorted_err)-1)*0.50)
930    p98 = int((len(sorted_err)-1)*0.98)
931    data = [
932        "max:%.3e"%sorted_err[-1],
933        "median:%.3e"%sorted_err[p50],
934        "98%%:%.3e"%sorted_err[p98],
935        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
936        "zero-offset:%+.3e"%np.mean(sorted_err),
937        ]
938    print(label+"  "+"  ".join(data))
939
940
941def plot_models(opts, result, limits=None, setnum=0):
942    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
943    import matplotlib.pyplot as plt
944
945    base_value, comp_value = result['base_value'], result['comp_value']
946    base_time, comp_time = result['base_time'], result['comp_time']
947    resid, relerr = result['resid'], result['relerr']
948
949    have_base, have_comp = (base_value is not None), (comp_value is not None)
950    base, comp = opts['engines']
951    data = opts['data']
952    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
953
954    # Plot if requested
955    view = opts['view']
956    if limits is None:
957        vmin, vmax = np.inf, -np.inf
958        if have_base:
959            vmin = min(vmin, base_value.min())
960            vmax = max(vmax, base_value.max())
961        if have_comp:
962            vmin = min(vmin, comp_value.min())
963            vmax = max(vmax, comp_value.max())
964        limits = vmin, vmax
965
966    if have_base:
967        if have_comp:
968            plt.subplot(131)
969        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
970        plt.title("%s t=%.2f ms"%(base.engine, base_time))
971        #cbar_title = "log I"
972    if have_comp:
973        if have_base:
974            plt.subplot(132)
975        if not opts['is2d'] and have_base:
976            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
977        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
978        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
979        #cbar_title = "log I"
980    if have_base and have_comp:
981        plt.subplot(133)
982        if not opts['rel_err']:
983            err, errstr, errview = resid, "abs err", "linear"
984        else:
985            err, errstr, errview = abs(relerr), "rel err", "log"
986        if 0:  # 95% cutoff
987            sorted = np.sort(err.flatten())
988            cutoff = sorted[int(sorted.size*0.95)]
989            err[err > cutoff] = cutoff
990        #err,errstr = base/comp,"ratio"
991        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
992        plt.xscale('log' if view == 'log' else linear)
993        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
994        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
995        #cbar_title = errstr if errview=="linear" else "log "+errstr
996    #if is2D:
997    #    h = plt.colorbar()
998    #    h.ax.set_title(cbar_title)
999    fig = plt.gcf()
1000    extra_title = ' '+opts['title'] if opts['title'] else ''
1001    fig.suptitle(":".join(opts['name']) + extra_title)
1002
1003    if have_base and have_comp and opts['show_hist']:
1004        plt.figure()
1005        v = relerr
1006        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
1007        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
1008        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
1009                   % (base.engine, comp.engine, comp.engine))
1010        plt.ylabel('P(err)')
1011        plt.title('Distribution of relative error between calculation engines')
1012
1013    return limits
1014
1015
1016# ===========================================================================
1017#
1018
1019# Set of command line options.
1020# Normal options such as -plot/-noplot are specified as 'name'.
1021# For options such as -nq=500 which require a value use 'name='.
1022#
1023OPTIONS = [
1024    # Plotting
1025    'plot', 'noplot', 'weights',
1026    'linear', 'log', 'q4',
1027    'rel', 'abs',
1028    'hist', 'nohist',
1029    'title=',
1030
1031    # Data generation
1032    'data=', 'noise=', 'res=',
1033    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
1034    '2d', '1d',
1035
1036    # Parameter set
1037    'preset', 'random', 'random=', 'sets=',
1038    'demo', 'default',  # TODO: remove demo/default
1039    'nopars', 'pars',
1040    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
1041
1042    # Calculation options
1043    'poly', 'mono', 'cutoff=',
1044    'magnetic', 'nonmagnetic',
1045    'accuracy=',
1046    'neval=',  # for timing...
1047
1048    # Precision options
1049    'engine=',
1050    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1051    'sasview',  # TODO: remove sasview 3.x support
1052
1053    # Output options
1054    'help', 'html', 'edit',
1055    ]
1056
1057NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1058VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1059
1060
1061def columnize(items, indent="", width=79):
1062    # type: (List[str], str, int) -> str
1063    """
1064    Format a list of strings into columns.
1065
1066    Returns a string with carriage returns ready for printing.
1067    """
1068    column_width = max(len(w) for w in items) + 1
1069    num_columns = (width - len(indent)) // column_width
1070    num_rows = len(items) // num_columns
1071    items = items + [""] * (num_rows * num_columns - len(items))
1072    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1073    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1074             for row in zip(*columns)]
1075    output = indent + ("\n"+indent).join(lines)
1076    return output
1077
1078
1079def get_pars(model_info, use_demo=False):
1080    # type: (ModelInfo, bool) -> ParameterSet
1081    """
1082    Extract demo parameters from the model definition.
1083    """
1084    # Get the default values for the parameters
1085    pars = {}
1086    for p in model_info.parameters.call_parameters:
1087        parts = [('', p.default)]
1088        if p.polydisperse:
1089            parts.append(('_pd', 0.0))
1090            parts.append(('_pd_n', 0))
1091            parts.append(('_pd_nsigma', 3.0))
1092            parts.append(('_pd_type', "gaussian"))
1093        for ext, val in parts:
1094            if p.length > 1:
1095                dict(("%s%d%s" % (p.id, k, ext), val)
1096                     for k in range(1, p.length+1))
1097            else:
1098                pars[p.id + ext] = val
1099
1100    # Plug in values given in demo
1101    if use_demo and model_info.demo:
1102        pars.update(model_info.demo)
1103    return pars
1104
1105INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1106def isnumber(str):
1107    match = FLOAT_RE.match(str)
1108    isfloat = (match and not str[match.end():])
1109    return isfloat or INTEGER_RE.match(str)
1110
1111# For distinguishing pairs of models for comparison
1112# key-value pair separator =
1113# shell characters  | & ; <> $ % ' " \ # `
1114# model and parameter names _
1115# parameter expressions - + * / . ( )
1116# path characters including tilde expansion and windows drive ~ / :
1117# not sure about brackets [] {}
1118# maybe one of the following @ ? ^ ! ,
1119PAR_SPLIT = ','
1120def parse_opts(argv):
1121    # type: (List[str]) -> Dict[str, Any]
1122    """
1123    Parse command line options.
1124    """
1125    MODELS = core.list_models()
1126    flags = [arg for arg in argv
1127             if arg.startswith('-')]
1128    values = [arg for arg in argv
1129              if not arg.startswith('-') and '=' in arg]
1130    positional_args = [arg for arg in argv
1131                       if not arg.startswith('-') and '=' not in arg]
1132    models = "\n    ".join("%-15s"%v for v in MODELS)
1133    if len(positional_args) == 0:
1134        print(USAGE)
1135        print("\nAvailable models:")
1136        print(columnize(MODELS, indent="  "))
1137        return None
1138
1139    invalid = [o[1:] for o in flags
1140               if o[1:] not in NAME_OPTIONS
1141               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1142    if invalid:
1143        print("Invalid options: %s"%(", ".join(invalid)))
1144        return None
1145
1146    name = positional_args[-1]
1147
1148    # pylint: disable=bad-whitespace
1149    # Interpret the flags
1150    opts = {
1151        'plot'      : True,
1152        'view'      : 'log',
1153        'is2d'      : False,
1154        'qmax'      : 0.05,
1155        'nq'        : 128,
1156        'res'       : 0.0,
1157        'noise'     : 0.0,
1158        'accuracy'  : 'Low',
1159        'cutoff'    : '0.0',
1160        'seed'      : -1,  # default to preset
1161        'mono'      : True,
1162        # Default to magnetic a magnetic moment is set on the command line
1163        'magnetic'  : False,
1164        'show_pars' : False,
1165        'show_hist' : False,
1166        'rel_err'   : True,
1167        'explore'   : False,
1168        'use_demo'  : True,
1169        'zero'      : False,
1170        'html'      : False,
1171        'title'     : None,
1172        'datafile'  : None,
1173        'sets'      : 0,
1174        'engine'    : 'default',
1175        'count'     : '1',
1176        'show_weights' : False,
1177        'sphere'    : 0,
1178    }
1179    for arg in flags:
1180        if arg == '-noplot':    opts['plot'] = False
1181        elif arg == '-plot':    opts['plot'] = True
1182        elif arg == '-linear':  opts['view'] = 'linear'
1183        elif arg == '-log':     opts['view'] = 'log'
1184        elif arg == '-q4':      opts['view'] = 'q4'
1185        elif arg == '-1d':      opts['is2d'] = False
1186        elif arg == '-2d':      opts['is2d'] = True
1187        elif arg == '-exq':     opts['qmax'] = 10.0
1188        elif arg == '-highq':   opts['qmax'] = 1.0
1189        elif arg == '-midq':    opts['qmax'] = 0.2
1190        elif arg == '-lowq':    opts['qmax'] = 0.05
1191        elif arg == '-zero':    opts['zero'] = True
1192        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1193        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1194        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1195        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1196        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1197        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1198        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1199        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1200        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1201        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1202        elif arg.startswith('-random='):
1203            opts['seed'] = int(arg[8:])
1204            opts['sets'] = 0
1205        elif arg == '-random':
1206            opts['seed'] = np.random.randint(1000000)
1207            opts['sets'] = 0
1208        elif arg.startswith('-sphere'):
1209            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1210            opts['is2d'] = True
1211        elif arg == '-preset':  opts['seed'] = -1
1212        elif arg == '-mono':    opts['mono'] = True
1213        elif arg == '-poly':    opts['mono'] = False
1214        elif arg == '-magnetic':       opts['magnetic'] = True
1215        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1216        elif arg == '-pars':    opts['show_pars'] = True
1217        elif arg == '-nopars':  opts['show_pars'] = False
1218        elif arg == '-hist':    opts['show_hist'] = True
1219        elif arg == '-nohist':  opts['show_hist'] = False
1220        elif arg == '-rel':     opts['rel_err'] = True
1221        elif arg == '-abs':     opts['rel_err'] = False
1222        elif arg == '-half':    opts['engine'] = 'half'
1223        elif arg == '-fast':    opts['engine'] = 'fast'
1224        elif arg == '-single':  opts['engine'] = 'single'
1225        elif arg == '-double':  opts['engine'] = 'double'
1226        elif arg == '-single!': opts['engine'] = 'single!'
1227        elif arg == '-double!': opts['engine'] = 'double!'
1228        elif arg == '-quad!':   opts['engine'] = 'quad!'
1229        elif arg == '-sasview': opts['engine'] = 'sasview'
1230        elif arg == '-edit':    opts['explore'] = True
1231        elif arg == '-demo':    opts['use_demo'] = True
1232        elif arg == '-default': opts['use_demo'] = False
1233        elif arg == '-weights': opts['show_weights'] = True
1234        elif arg == '-html':    opts['html'] = True
1235        elif arg == '-help':    opts['html'] = True
1236    # pylint: enable=bad-whitespace
1237
1238    # Magnetism forces 2D for now
1239    if opts['magnetic']:
1240        opts['is2d'] = True
1241
1242    # Force random if sets is used
1243    if opts['sets'] >= 1 and opts['seed'] < 0:
1244        opts['seed'] = np.random.randint(1000000)
1245    if opts['sets'] == 0:
1246        opts['sets'] = 1
1247
1248    # Create the computational engines
1249    if opts['datafile'] is not None:
1250        data = load_data(os.path.expanduser(opts['datafile']))
1251    else:
1252        data, _ = make_data(opts)
1253
1254    comparison = any(PAR_SPLIT in v for v in values)
1255    if PAR_SPLIT in name:
1256        names = name.split(PAR_SPLIT, 2)
1257        comparison = True
1258    else:
1259        names = [name]*2
1260    try:
1261        model_info = [core.load_model_info(k) for k in names]
1262    except ImportError as exc:
1263        print(str(exc))
1264        print("Could not find model; use one of:\n    " + models)
1265        return None
1266
1267    if PAR_SPLIT in opts['engine']:
1268        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1269        comparison = True
1270    else:
1271        opts['engine'] = [opts['engine']]*2
1272
1273    if PAR_SPLIT in opts['count']:
1274        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1275        comparison = True
1276    else:
1277        opts['count'] = [int(opts['count'])]*2
1278
1279    if PAR_SPLIT in opts['cutoff']:
1280        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1281        comparison = True
1282    else:
1283        opts['cutoff'] = [float(opts['cutoff'])]*2
1284
1285    base = make_engine(model_info[0], data, opts['engine'][0], opts['cutoff'][0])
1286    if comparison:
1287        comp = make_engine(model_info[1], data, opts['engine'][1], opts['cutoff'][1])
1288    else:
1289        comp = None
1290
1291    # pylint: disable=bad-whitespace
1292    # Remember it all
1293    opts.update({
1294        'data'      : data,
1295        'name'      : names,
1296        'info'      : model_info,
1297        'engines'   : [base, comp],
1298        'values'    : values,
1299    })
1300    # pylint: enable=bad-whitespace
1301
1302    # Set the integration parameters to the half sphere
1303    if opts['sphere'] > 0:
1304        set_spherical_integration_parameters(opts, opts['sphere'])
1305
1306    return opts
1307
1308def set_spherical_integration_parameters(opts, steps):
1309    """
1310    Set integration parameters for spherical integration over the entire
1311    surface in theta-phi coordinates.
1312    """
1313    # Set the integration parameters to the half sphere
1314    opts['values'].extend([
1315        #'theta=90',
1316        'theta_pd=%g'%(90/np.sqrt(3)),
1317        'theta_pd_n=%d'%steps,
1318        'theta_pd_type=rectangle',
1319        #'phi=0',
1320        'phi_pd=%g'%(180/np.sqrt(3)),
1321        'phi_pd_n=%d'%(2*steps),
1322        'phi_pd_type=rectangle',
1323        #'background=0',
1324    ])
1325    if 'psi' in opts['info'][0].parameters:
1326        #opts['values'].append('psi=0')
1327        pass
1328
1329def parse_pars(opts, maxdim=np.inf):
1330    model_info, model_info2 = opts['info']
1331
1332    # Get demo parameters from model definition, or use default parameters
1333    # if model does not define demo parameters
1334    pars = get_pars(model_info, opts['use_demo'])
1335    pars2 = get_pars(model_info2, opts['use_demo'])
1336    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1337    # randomize parameters
1338    #pars.update(set_pars)  # set value before random to control range
1339    if opts['seed'] > -1:
1340        pars = randomize_pars(model_info, pars)
1341        limit_dimensions(model_info, pars, maxdim)
1342        if model_info != model_info2:
1343            pars2 = randomize_pars(model_info2, pars2)
1344            limit_dimensions(model_info, pars2, maxdim)
1345            # Share values for parameters with the same name
1346            for k, v in pars.items():
1347                if k in pars2:
1348                    pars2[k] = v
1349        else:
1350            pars2 = pars.copy()
1351        constrain_pars(model_info, pars)
1352        constrain_pars(model_info2, pars2)
1353    pars = suppress_pd(pars, opts['mono'])
1354    pars2 = suppress_pd(pars2, opts['mono'])
1355    pars = suppress_magnetism(pars, not opts['magnetic'])
1356    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1357
1358    # Fill in parameters given on the command line
1359    presets = {}
1360    presets2 = {}
1361    for arg in opts['values']:
1362        k, v = arg.split('=', 1)
1363        if k not in pars and k not in pars2:
1364            # extract base name without polydispersity info
1365            s = set(p.split('_pd')[0] for p in pars)
1366            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1367            return None
1368        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1369        if v1 and k in pars:
1370            presets[k] = float(v1) if isnumber(v1) else v1
1371        if v2 and k in pars2:
1372            presets2[k] = float(v2) if isnumber(v2) else v2
1373
1374    # If pd given on the command line, default pd_n to 35
1375    for k, v in list(presets.items()):
1376        if k.endswith('_pd'):
1377            presets.setdefault(k+'_n', 35.)
1378    for k, v in list(presets2.items()):
1379        if k.endswith('_pd'):
1380            presets2.setdefault(k+'_n', 35.)
1381
1382    # Evaluate preset parameter expressions
1383    context = MATH.copy()
1384    context['np'] = np
1385    context.update(pars)
1386    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1387    for k, v in presets.items():
1388        if not isinstance(v, float) and not k.endswith('_type'):
1389            presets[k] = eval(v, context)
1390    context.update(presets)
1391    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1392    for k, v in presets2.items():
1393        if not isinstance(v, float) and not k.endswith('_type'):
1394            presets2[k] = eval(v, context)
1395
1396    # update parameters with presets
1397    pars.update(presets)  # set value after random to control value
1398    pars2.update(presets2)  # set value after random to control value
1399    #import pprint; pprint.pprint(model_info)
1400
1401    if opts['show_pars']:
1402        if model_info.name != model_info2.name or pars != pars2:
1403            print("==== %s ====="%model_info.name)
1404            print(str(parlist(model_info, pars, opts['is2d'])))
1405            print("==== %s ====="%model_info2.name)
1406            print(str(parlist(model_info2, pars2, opts['is2d'])))
1407        else:
1408            print(str(parlist(model_info, pars, opts['is2d'])))
1409
1410    return pars, pars2
1411
1412def show_docs(opts):
1413    # type: (Dict[str, Any]) -> None
1414    """
1415    show html docs for the model
1416    """
1417    import os
1418    from .generate import make_html
1419    from . import rst2html
1420
1421    info = opts['info'][0]
1422    html = make_html(info)
1423    path = os.path.dirname(info.filename)
1424    url = "file://"+path.replace("\\","/")[2:]+"/"
1425    rst2html.view_html_qtapp(html, url)
1426
1427def explore(opts):
1428    # type: (Dict[str, Any]) -> None
1429    """
1430    explore the model using the bumps gui.
1431    """
1432    import wx  # type: ignore
1433    from bumps.names import FitProblem  # type: ignore
1434    from bumps.gui.app_frame import AppFrame  # type: ignore
1435    from bumps.gui import signal
1436
1437    is_mac = "cocoa" in wx.version()
1438    # Create an app if not running embedded
1439    app = wx.App() if wx.GetApp() is None else None
1440    model = Explore(opts)
1441    problem = FitProblem(model)
1442    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1443    if not is_mac:
1444        frame.Show()
1445    frame.panel.set_model(model=problem)
1446    frame.panel.Layout()
1447    frame.panel.aui.Split(0, wx.TOP)
1448    def reset_parameters(event):
1449        model.revert_values()
1450        signal.update_parameters(problem)
1451    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1452    if is_mac: frame.Show()
1453    # If running withing an app, start the main loop
1454    if app:
1455        app.MainLoop()
1456
1457class Explore(object):
1458    """
1459    Bumps wrapper for a SAS model comparison.
1460
1461    The resulting object can be used as a Bumps fit problem so that
1462    parameters can be adjusted in the GUI, with plots updated on the fly.
1463    """
1464    def __init__(self, opts):
1465        # type: (Dict[str, Any]) -> None
1466        from bumps.cli import config_matplotlib  # type: ignore
1467        from . import bumps_model
1468        config_matplotlib()
1469        self.opts = opts
1470        opts['pars'] = list(opts['pars'])
1471        p1, p2 = opts['pars']
1472        m1, m2 = opts['info']
1473        self.fix_p2 = m1 != m2 or p1 != p2
1474        model_info = m1
1475        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1476        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1477        if not opts['is2d']:
1478            for p in model_info.parameters.user_parameters({}, is2d=False):
1479                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1480                    k = p.name+ext
1481                    v = pars.get(k, None)
1482                    if v is not None:
1483                        v.range(*parameter_range(k, v.value))
1484        else:
1485            for k, v in pars.items():
1486                v.range(*parameter_range(k, v.value))
1487
1488        self.pars = pars
1489        self.starting_values = dict((k, v.value) for k, v in pars.items())
1490        self.pd_types = pd_types
1491        self.limits = None
1492
1493    def revert_values(self):
1494        for k, v in self.starting_values.items():
1495            self.pars[k].value = v
1496
1497    def model_update(self):
1498        pass
1499
1500    def numpoints(self):
1501        # type: () -> int
1502        """
1503        Return the number of points.
1504        """
1505        return len(self.pars) + 1  # so dof is 1
1506
1507    def parameters(self):
1508        # type: () -> Any   # Dict/List hierarchy of parameters
1509        """
1510        Return a dictionary of parameters.
1511        """
1512        return self.pars
1513
1514    def nllf(self):
1515        # type: () -> float
1516        """
1517        Return cost.
1518        """
1519        # pylint: disable=no-self-use
1520        return 0.  # No nllf
1521
1522    def plot(self, view='log'):
1523        # type: (str) -> None
1524        """
1525        Plot the data and residuals.
1526        """
1527        pars = dict((k, v.value) for k, v in self.pars.items())
1528        pars.update(self.pd_types)
1529        self.opts['pars'][0] = pars
1530        if not self.fix_p2:
1531            self.opts['pars'][1] = pars
1532        result = run_models(self.opts)
1533        limits = plot_models(self.opts, result, limits=self.limits)
1534        if self.limits is None:
1535            vmin, vmax = limits
1536            self.limits = vmax*1e-7, 1.3*vmax
1537            import pylab; pylab.clf()
1538            plot_models(self.opts, result, limits=self.limits)
1539
1540
1541def main(*argv):
1542    # type: (*str) -> None
1543    """
1544    Main program.
1545    """
1546    opts = parse_opts(argv)
1547    if opts is not None:
1548        if opts['seed'] > -1:
1549            print("Randomize using -random=%i"%opts['seed'])
1550            np.random.seed(opts['seed'])
1551        if opts['html']:
1552            show_docs(opts)
1553        elif opts['explore']:
1554            opts['pars'] = parse_pars(opts)
1555            if opts['pars'] is None:
1556                return
1557            explore(opts)
1558        else:
1559            compare(opts)
1560
1561if __name__ == "__main__":
1562    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.