source: sasmodels/sasmodels/compare.py @ e3571cb

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since e3571cb was e3571cb, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

allow comparison of 1D with integrated 2D

  • Property mode set to 100755
File size: 54.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel, get_mesh
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47from .weights import plot_weights
48
49try:
50    from typing import Optional, Dict, Any, Callable, Tuple
51except Exception:
52    pass
53else:
54    from .modelinfo import ModelInfo, Parameter, ParameterSet
55    from .data import Data
56    Calculator = Callable[[float], np.ndarray]
57
58USAGE = """
59usage: sascomp model [options...] [key=val]
60
61Generate and compare SAS models.  If a single model is specified it shows
62a plot of that model.  Different models can be compared, or the same model
63with different parameters.  The same model with the same parameters can
64be compared with different calculation engines to see the effects of precision
65on the resultant values.
66
67model or model1,model2 are the names of the models to compare (see below).
68
69Options (* for default):
70
71    === data generation ===
72    -data="path" uses q, dq from the data file
73    -noise=0 sets the measurement error dI/I
74    -res=0 sets the resolution width dQ/Q if calculating with resolution
75    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
76    -nq=128 sets the number of Q points in the data set
77    -1d*/-2d computes 1d or 2d data
78    -zero indicates that q=0 should be included
79
80    === model parameters ===
81    -preset*/-random[=seed] preset or random parameters
82    -sets=n generates n random datasets with the seed given by -random=seed
83    -pars/-nopars* prints the parameter set or not
84    -default/-demo* use demo vs default parameters
85    -sphere[=150] set up spherical integration over theta/phi using n points
86
87    === calculation options ===
88    -mono*/-poly force monodisperse or allow polydisperse random parameters
89    -cutoff=1e-5* cutoff value for including a point in polydispersity
90    -magnetic/-nonmagnetic* suppress magnetism
91    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
92    -neval=1 sets the number of evals for more accurate timing
93
94    === precision options ===
95    -engine=default uses the default calcution precision
96    -single/-double/-half/-fast sets an OpenCL calculation engine
97    -single!/-double!/-quad! sets an OpenMP calculation engine
98    -sasview sets the sasview calculation engine
99
100    === plotting ===
101    -plot*/-noplot plots or suppress the plot of the model
102    -linear/-log*/-q4 intensity scaling on plots
103    -hist/-nohist* plot histogram of relative error
104    -abs/-rel* plot relative or absolute error
105    -title="note" adds note to the plot title, after the model name
106    -weights shows weights plots for the polydisperse parameters
107
108    === output options ===
109    -edit starts the parameter explorer
110    -help/-html shows the model docs instead of running the model
111
112The interpretation of quad precision depends on architecture, and may
113vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
114On unix and mac you may need single quotes around the DLL computation
115engines, such as -engine='single!,double!' since !, is treated as a history
116expansion request in the shell.
117
118Key=value pairs allow you to set specific values for the model parameters.
119Key=value1,value2 to compare different values of the same parameter. The
120value can be an expression including other parameters.
121
122Items later on the command line override those that appear earlier.
123
124Examples:
125
126    # compare single and double precision calculation for a barbell
127    sascomp barbell -engine=single,double
128
129    # generate 10 random lorentz models, with seed=27
130    sascomp lorentz -sets=10 -seed=27
131
132    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
133    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
134
135    # model timing test requires multiple evals to perform the estimate
136    sascomp pringle -engine=single,double -timing=100,100 -noplot
137"""
138
139# Update docs with command line usage string.   This is separate from the usual
140# doc string so that we can display it at run time if there is an error.
141# lin
142__doc__ = (__doc__  # pylint: disable=redefined-builtin
143           + """
144Program description
145-------------------
146
147""" + USAGE)
148
149kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
150
151# list of math functions for use in evaluating parameters
152MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
153
154# CRUFT python 2.6
155if not hasattr(datetime.timedelta, 'total_seconds'):
156    def delay(dt):
157        """Return number date-time delta as number seconds"""
158        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
159else:
160    def delay(dt):
161        """Return number date-time delta as number seconds"""
162        return dt.total_seconds()
163
164
165class push_seed(object):
166    """
167    Set the seed value for the random number generator.
168
169    When used in a with statement, the random number generator state is
170    restored after the with statement is complete.
171
172    :Parameters:
173
174    *seed* : int or array_like, optional
175        Seed for RandomState
176
177    :Example:
178
179    Seed can be used directly to set the seed::
180
181        >>> from numpy.random import randint
182        >>> push_seed(24)
183        <...push_seed object at...>
184        >>> print(randint(0,1000000,3))
185        [242082    899 211136]
186
187    Seed can also be used in a with statement, which sets the random
188    number generator state for the enclosed computations and restores
189    it to the previous state on completion::
190
191        >>> with push_seed(24):
192        ...    print(randint(0,1000000,3))
193        [242082    899 211136]
194
195    Using nested contexts, we can demonstrate that state is indeed
196    restored after the block completes::
197
198        >>> with push_seed(24):
199        ...    print(randint(0,1000000))
200        ...    with push_seed(24):
201        ...        print(randint(0,1000000,3))
202        ...    print(randint(0,1000000))
203        242082
204        [242082    899 211136]
205        899
206
207    The restore step is protected against exceptions in the block::
208
209        >>> with push_seed(24):
210        ...    print(randint(0,1000000))
211        ...    try:
212        ...        with push_seed(24):
213        ...            print(randint(0,1000000,3))
214        ...            raise Exception()
215        ...    except Exception:
216        ...        print("Exception raised")
217        ...    print(randint(0,1000000))
218        242082
219        [242082    899 211136]
220        Exception raised
221        899
222    """
223    def __init__(self, seed=None):
224        # type: (Optional[int]) -> None
225        self._state = np.random.get_state()
226        np.random.seed(seed)
227
228    def __enter__(self):
229        # type: () -> None
230        pass
231
232    def __exit__(self, exc_type, exc_value, traceback):
233        # type: (Any, BaseException, Any) -> None
234        # TODO: better typing for __exit__ method
235        np.random.set_state(self._state)
236
237def tic():
238    # type: () -> Callable[[], float]
239    """
240    Timer function.
241
242    Use "toc=tic()" to start the clock and "toc()" to measure
243    a time interval.
244    """
245    then = datetime.datetime.now()
246    return lambda: delay(datetime.datetime.now() - then)
247
248
249def set_beam_stop(data, radius, outer=None):
250    # type: (Data, float, float) -> None
251    """
252    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
253
254    Note: this function does not require sasview
255    """
256    if hasattr(data, 'qx_data'):
257        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
258        data.mask = (q < radius)
259        if outer is not None:
260            data.mask |= (q >= outer)
261    else:
262        data.mask = (data.x < radius)
263        if outer is not None:
264            data.mask |= (data.x >= outer)
265
266
267def parameter_range(p, v):
268    # type: (str, float) -> Tuple[float, float]
269    """
270    Choose a parameter range based on parameter name and initial value.
271    """
272    # process the polydispersity options
273    if p.endswith('_pd_n'):
274        return 0., 100.
275    elif p.endswith('_pd_nsigma'):
276        return 0., 5.
277    elif p.endswith('_pd_type'):
278        raise ValueError("Cannot return a range for a string value")
279    elif any(s in p for s in ('theta', 'phi', 'psi')):
280        # orientation in [-180,180], orientation pd in [0,45]
281        if p.endswith('_pd'):
282            return 0., 180.
283        else:
284            return -180., 180.
285    elif p.endswith('_pd'):
286        return 0., 1.
287    elif 'sld' in p:
288        return -0.5, 10.
289    elif p == 'background':
290        return 0., 10.
291    elif p == 'scale':
292        return 0., 1.e3
293    elif v < 0.:
294        return 2.*v, -2.*v
295    else:
296        return 0., (2.*v if v > 0. else 1.)
297
298
299def _randomize_one(model_info, name, value):
300    # type: (ModelInfo, str, float) -> float
301    # type: (ModelInfo, str, str) -> str
302    """
303    Randomize a single parameter.
304    """
305    # Set the amount of polydispersity/angular dispersion, but by default pd_n
306    # is zero so there is no polydispersity.  This allows us to turn on/off
307    # pd by setting pd_n, and still have randomly generated values
308    if name.endswith('_pd'):
309        par = model_info.parameters[name[:-3]]
310        if par.type == 'orientation':
311            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
312            return 180*np.random.beta(2.5, 20)
313        else:
314            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
315            return np.random.beta(1.5, 7)
316
317    # pd is selected globally rather than per parameter, so set to 0 for no pd
318    # In particular, when multiple pd dimensions, want to decrease the number
319    # of points per dimension for faster computation
320    if name.endswith('_pd_n'):
321        return 0
322
323    # Don't mess with distribution type for now
324    if name.endswith('_pd_type'):
325        return 'gaussian'
326
327    # type-dependent value of number of sigmas; for gaussian use 3.
328    if name.endswith('_pd_nsigma'):
329        return 3.
330
331    # background in the range [0.01, 1]
332    if name == 'background':
333        return 10**np.random.uniform(-2, 0)
334
335    # scale defaults to 0.1% to 30% volume fraction
336    if name == 'scale':
337        return 10**np.random.uniform(-3, -0.5)
338
339    # If it is a list of choices, pick one at random with equal probability
340    # In practice, the model specific random generator will override.
341    par = model_info.parameters[name]
342    if len(par.limits) > 2:  # choice list
343        return np.random.randint(len(par.limits))
344
345    # If it is a fixed range, pick from it with equal probability.
346    # For logarithmic ranges, the model will have to override.
347    if np.isfinite(par.limits).all():
348        return np.random.uniform(*par.limits)
349
350    # If the paramter is marked as an sld use the range of neutron slds
351    # TODO: ought to randomly contrast match a pair of SLDs
352    if par.type == 'sld':
353        return np.random.uniform(-0.5, 12)
354
355    # Limit magnetic SLDs to a smaller range, from zero to iron=5/A^2
356    if par.name.startswith('M0:'):
357        return np.random.uniform(0, 5)
358
359    # Guess at the random length/radius/thickness.  In practice, all models
360    # are going to set their own reasonable ranges.
361    if par.type == 'volume':
362        if ('length' in par.name or
363                'radius' in par.name or
364                'thick' in par.name):
365            return 10**np.random.uniform(2, 4)
366
367    # In the absence of any other info, select a value in [0, 2v], or
368    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
369    # model random parameter generators will override this default.
370    low, high = parameter_range(par.name, value)
371    limits = (max(par.limits[0], low), min(par.limits[1], high))
372    return np.random.uniform(*limits)
373
374def _random_pd(model_info, pars):
375    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
376    pd_volume = []
377    pd_oriented = []
378    for p in pd:
379        if p.type == 'orientation':
380            pd_oriented.append(p.name)
381        elif p.length_control is not None:
382            n = int(pars.get(p.length_control, 1) + 0.5)
383            pd_volume.extend(p.name+str(k+1) for k in range(n))
384        elif p.length > 1:
385            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
386        else:
387            pd_volume.append(p.name)
388    u = np.random.rand()
389    n = len(pd_volume)
390    if u < 0.01 or n < 1:
391        pass  # 1% chance of no polydispersity
392    elif u < 0.86 or n < 2:
393        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
394    elif u < 0.99 or n < 3:
395        choices = np.random.choice(len(pd_volume), size=2)
396        pars[pd_volume[choices[0]]+"_pd_n"] = 25
397        pars[pd_volume[choices[1]]+"_pd_n"] = 10
398    else:
399        choices = np.random.choice(len(pd_volume), size=3)
400        pars[pd_volume[choices[0]]+"_pd_n"] = 25
401        pars[pd_volume[choices[1]]+"_pd_n"] = 10
402        pars[pd_volume[choices[2]]+"_pd_n"] = 5
403    if pd_oriented:
404        pars['theta_pd_n'] = 20
405        if np.random.rand() < 0.1:
406            pars['phi_pd_n'] = 5
407        if np.random.rand() < 0.1:
408            if any(p.name == 'psi' for p in model_info.parameters.kernel_parameters):
409                #print("generating psi_pd_n")
410                pars['psi_pd_n'] = 5
411
412    ## Show selected polydispersity
413    #for name, value in pars.items():
414    #    if name.endswith('_pd_n') and value > 0:
415    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
416
417
418def randomize_pars(model_info, pars):
419    # type: (ModelInfo, ParameterSet) -> ParameterSet
420    """
421    Generate random values for all of the parameters.
422
423    Valid ranges for the random number generator are guessed from the name of
424    the parameter; this will not account for constraints such as cap radius
425    greater than cylinder radius in the capped_cylinder model, so
426    :func:`constrain_pars` needs to be called afterward..
427    """
428    # Note: the sort guarantees order of calls to random number generator
429    random_pars = dict((p, _randomize_one(model_info, p, v))
430                       for p, v in sorted(pars.items()))
431    if model_info.random is not None:
432        random_pars.update(model_info.random())
433    _random_pd(model_info, random_pars)
434    return random_pars
435
436
437def limit_dimensions(model_info, pars, maxdim):
438    # type: (ModelInfo, ParameterSet, float) -> None
439    """
440    Limit parameters of units of Ang to maxdim.
441    """
442    for p in model_info.parameters.call_parameters:
443        value = pars[p.name]
444        if p.units == 'Ang' and value > maxdim:
445            pars[p.name] = maxdim*10**np.random.uniform(-3,0)
446
447def constrain_pars(model_info, pars):
448    # type: (ModelInfo, ParameterSet) -> None
449    """
450    Restrict parameters to valid values.
451
452    This includes model specific code for models such as capped_cylinder
453    which need to support within model constraints (cap radius more than
454    cylinder radius in this case).
455
456    Warning: this updates the *pars* dictionary in place.
457    """
458    # TODO: move the model specific code to the individual models
459    name = model_info.id
460    # if it is a product model, then just look at the form factor since
461    # none of the structure factors need any constraints.
462    if '*' in name:
463        name = name.split('*')[0]
464
465    # Suppress magnetism for python models (not yet implemented)
466    if callable(model_info.Iq):
467        pars.update(suppress_magnetism(pars))
468
469    if name == 'barbell':
470        if pars['radius_bell'] < pars['radius']:
471            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
472
473    elif name == 'capped_cylinder':
474        if pars['radius_cap'] < pars['radius']:
475            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
476
477    elif name == 'guinier':
478        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
479        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
480        # => ln A - (Rg^2 q^2/3) > -30 ln 10
481        # => Rg^2 q^2/3 < 30 ln 10 + ln A
482        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
483        #q_max = 0.2  # mid q maximum
484        q_max = 1.0  # high q maximum
485        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
486        pars['rg'] = min(pars['rg'], rg_max)
487
488    elif name == 'pearl_necklace':
489        if pars['radius'] < pars['thick_string']:
490            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
491        pass
492
493    elif name == 'rpa':
494        # Make sure phi sums to 1.0
495        if pars['case_num'] < 2:
496            pars['Phi1'] = 0.
497            pars['Phi2'] = 0.
498        elif pars['case_num'] < 5:
499            pars['Phi1'] = 0.
500        total = sum(pars['Phi'+c] for c in '1234')
501        for c in '1234':
502            pars['Phi'+c] /= total
503
504def parlist(model_info, pars, is2d):
505    # type: (ModelInfo, ParameterSet, bool) -> str
506    """
507    Format the parameter list for printing.
508    """
509    is2d = True
510    lines = []
511    parameters = model_info.parameters
512    magnetic = False
513    magnetic_pars = []
514    for p in parameters.user_parameters(pars, is2d):
515        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
516            continue
517        if p.id.startswith('up:'):
518            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
519            continue
520        fields = dict(
521            value=pars.get(p.id, p.default),
522            pd=pars.get(p.id+"_pd", 0.),
523            n=int(pars.get(p.id+"_pd_n", 0)),
524            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
525            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
526            relative_pd=p.relative_pd,
527            M0=pars.get('M0:'+p.id, 0.),
528            mphi=pars.get('mphi:'+p.id, 0.),
529            mtheta=pars.get('mtheta:'+p.id, 0.),
530        )
531        lines.append(_format_par(p.name, **fields))
532        magnetic = magnetic or fields['M0'] != 0.
533    if magnetic and magnetic_pars:
534        lines.append(" ".join(magnetic_pars))
535    return "\n".join(lines)
536
537    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
538
539def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
540                relative_pd=False, M0=0., mphi=0., mtheta=0.):
541    # type: (str, float, float, int, float, str) -> str
542    line = "%s: %g"%(name, value)
543    if pd != 0.  and n != 0:
544        if relative_pd:
545            pd *= value
546        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
547                % (pd, n, nsigma, nsigma, pdtype)
548    if M0 != 0.:
549        line += "  M0:%.3f  mtheta:%.1f  mphi:%.1f" % (M0, mtheta, mphi)
550    return line
551
552def suppress_pd(pars, suppress=True):
553    # type: (ParameterSet) -> ParameterSet
554    """
555    If suppress is True complete eliminate polydispersity of the model to test
556    models more quickly.  If suppress is False, make sure at least one
557    parameter is polydisperse, setting the first polydispersity parameter to
558    15% if no polydispersity is given (with no explicit demo parameters given
559    in the model, there will be no default polydispersity).
560    """
561    pars = pars.copy()
562    #print("pars=", pars)
563    if suppress:
564        for p in pars:
565            if p.endswith("_pd_n"):
566                pars[p] = 0
567    else:
568        any_pd = False
569        first_pd = None
570        for p in pars:
571            if p.endswith("_pd_n"):
572                pd = pars.get(p[:-2], 0.)
573                any_pd |= (pars[p] != 0 and pd != 0.)
574                if first_pd is None:
575                    first_pd = p
576        if not any_pd and first_pd is not None:
577            if pars[first_pd] == 0:
578                pars[first_pd] = 35
579            if first_pd[:-2] not in pars or pars[first_pd[:-2]] == 0:
580                pars[first_pd[:-2]] = 0.15
581    return pars
582
583def suppress_magnetism(pars, suppress=True):
584    # type: (ParameterSet) -> ParameterSet
585    """
586    If suppress is True complete eliminate magnetism of the model to test
587    models more quickly.  If suppress is False, make sure at least one sld
588    parameter is magnetic, setting the first parameter to have a strong
589    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
590    in the model, there will be no default magnetism).
591    """
592    pars = pars.copy()
593    if suppress:
594        for p in pars:
595            if p.startswith("M0:"):
596                pars[p] = 0
597    else:
598        any_mag = False
599        first_mag = None
600        for p in pars:
601            if p.startswith("M0:"):
602                any_mag |= (pars[p] != 0)
603                if first_mag is None:
604                    first_mag = p
605        if not any_mag and first_mag is not None:
606            pars[first_mag] = 8.
607    return pars
608
609def eval_sasview(model_info, data):
610    # type: (Modelinfo, Data) -> Calculator
611    """
612    Return a model calculator using the pre-4.0 SasView models.
613    """
614    # importing sas here so that the error message will be that sas failed to
615    # import rather than the more obscure smear_selection not imported error
616    import sas
617    import sas.models
618    from sas.models.qsmearing import smear_selection
619    from sas.models.MultiplicationModel import MultiplicationModel
620    from sas.models.dispersion_models import models as dispersers
621
622    def get_model_class(name):
623        # type: (str) -> "sas.models.BaseComponent"
624        #print("new",sorted(_pars.items()))
625        __import__('sas.models.' + name)
626        ModelClass = getattr(getattr(sas.models, name, None), name, None)
627        if ModelClass is None:
628            raise ValueError("could not find model %r in sas.models"%name)
629        return ModelClass
630
631    # WARNING: ugly hack when handling model!
632    # Sasview models with multiplicity need to be created with the target
633    # multiplicity, so we cannot create the target model ahead of time for
634    # for multiplicity models.  Instead we store the model in a list and
635    # update the first element of that list with the new multiplicity model
636    # every time we evaluate.
637
638    # grab the sasview model, or create it if it is a product model
639    if model_info.composition:
640        composition_type, parts = model_info.composition
641        if composition_type == 'product':
642            P, S = [get_model_class(revert_name(p))() for p in parts]
643            model = [MultiplicationModel(P, S)]
644        else:
645            raise ValueError("sasview mixture models not supported by compare")
646    else:
647        old_name = revert_name(model_info)
648        if old_name is None:
649            raise ValueError("model %r does not exist in old sasview"
650                            % model_info.id)
651        ModelClass = get_model_class(old_name)
652        model = [ModelClass()]
653    model[0].disperser_handles = {}
654
655    # build a smearer with which to call the model, if necessary
656    smearer = smear_selection(data, model=model)
657    if hasattr(data, 'qx_data'):
658        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
659        index = ((~data.mask) & (~np.isnan(data.data))
660                 & (q >= data.qmin) & (q <= data.qmax))
661        if smearer is not None:
662            smearer.model = model  # because smear_selection has a bug
663            smearer.accuracy = data.accuracy
664            smearer.set_index(index)
665            def _call_smearer():
666                smearer.model = model[0]
667                return smearer.get_value()
668            theory = _call_smearer
669        else:
670            theory = lambda: model[0].evalDistribution([data.qx_data[index],
671                                                        data.qy_data[index]])
672    elif smearer is not None:
673        theory = lambda: smearer(model[0].evalDistribution(data.x))
674    else:
675        theory = lambda: model[0].evalDistribution(data.x)
676
677    def calculator(**pars):
678        # type: (float, ...) -> np.ndarray
679        """
680        Sasview calculator for model.
681        """
682        oldpars = revert_pars(model_info, pars)
683        # For multiplicity models, create a model with the correct multiplicity
684        control = oldpars.pop("CONTROL", None)
685        if control is not None:
686            # sphericalSLD has one fewer multiplicity.  This update should
687            # happen in revert_pars, but it hasn't been called yet.
688            model[0] = ModelClass(control)
689        # paying for parameter conversion each time to keep life simple, if not fast
690        for k, v in oldpars.items():
691            if k.endswith('.type'):
692                par = k[:-5]
693                if v == 'gaussian': continue
694                cls = dispersers[v if v != 'rectangle' else 'rectangula']
695                handle = cls()
696                model[0].disperser_handles[par] = handle
697                try:
698                    model[0].set_dispersion(par, handle)
699                except Exception:
700                    exception.annotate_exception("while setting %s to %r"
701                                                 %(par, v))
702                    raise
703
704
705        #print("sasview pars",oldpars)
706        for k, v in oldpars.items():
707            name_attr = k.split('.')  # polydispersity components
708            if len(name_attr) == 2:
709                par, disp_par = name_attr
710                model[0].dispersion[par][disp_par] = v
711            else:
712                model[0].setParam(k, v)
713        return theory()
714
715    calculator.engine = "sasview"
716    return calculator
717
718DTYPE_MAP = {
719    'half': '16',
720    'fast': 'fast',
721    'single': '32',
722    'double': '64',
723    'quad': '128',
724    'f16': '16',
725    'f32': '32',
726    'f64': '64',
727    'float16': '16',
728    'float32': '32',
729    'float64': '64',
730    'float128': '128',
731    'longdouble': '128',
732}
733def eval_opencl(model_info, data, dtype='single', cutoff=0.):
734    # type: (ModelInfo, Data, str, float) -> Calculator
735    """
736    Return a model calculator using the OpenCL calculation engine.
737    """
738    if not core.HAVE_OPENCL:
739        raise RuntimeError("OpenCL not available")
740    model = core.build_model(model_info, dtype=dtype, platform="ocl")
741    calculator = DirectModel(data, model, cutoff=cutoff)
742    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
743    return calculator
744
745def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
746    # type: (ModelInfo, Data, str, float) -> Calculator
747    """
748    Return a model calculator using the DLL calculation engine.
749    """
750    model = core.build_model(model_info, dtype=dtype, platform="dll")
751    calculator = DirectModel(data, model, cutoff=cutoff)
752    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
753    return calculator
754
755def time_calculation(calculator, pars, evals=1):
756    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
757    """
758    Compute the average calculation time over N evaluations.
759
760    An additional call is generated without polydispersity in order to
761    initialize the calculation engine, and make the average more stable.
762    """
763    # initialize the code so time is more accurate
764    if evals > 1:
765        calculator(**suppress_pd(pars))
766    toc = tic()
767    # make sure there is at least one eval
768    value = calculator(**pars)
769    for _ in range(evals-1):
770        value = calculator(**pars)
771    average_time = toc()*1000. / evals
772    #print("I(q)",value)
773    return value, average_time
774
775def make_data(opts):
776    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
777    """
778    Generate an empty dataset, used with the model to set Q points
779    and resolution.
780
781    *opts* contains the options, with 'qmax', 'nq', 'res',
782    'accuracy', 'is2d' and 'view' parsed from the command line.
783    """
784    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
785    if opts['is2d']:
786        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
787        data = empty_data2D(q, resolution=res)
788        data.accuracy = opts['accuracy']
789        set_beam_stop(data, 0.0004)
790        index = ~data.mask
791    else:
792        if opts['view'] == 'log' and not opts['zero']:
793            qmax = math.log10(qmax)
794            q = np.logspace(qmax-3, qmax, nq)
795        else:
796            q = np.linspace(0.001*qmax, qmax, nq)
797        if opts['zero']:
798            q = np.hstack((0, q))
799        data = empty_data1D(q, resolution=res)
800        index = slice(None, None)
801    return data, index
802
803def make_engine(model_info, data, dtype, cutoff):
804    # type: (ModelInfo, Data, str, float) -> Calculator
805    """
806    Generate the appropriate calculation engine for the given datatype.
807
808    Datatypes with '!' appended are evaluated using external C DLLs rather
809    than OpenCL.
810    """
811    if dtype == 'sasview':
812        return eval_sasview(model_info, data)
813    elif dtype is None or not dtype.endswith('!'):
814        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
815    else:
816        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
817
818def _show_invalid(data, theory):
819    # type: (Data, np.ma.ndarray) -> None
820    """
821    Display a list of the non-finite values in theory.
822    """
823    if not theory.mask.any():
824        return
825
826    if hasattr(data, 'x'):
827        bad = zip(data.x[theory.mask], theory[theory.mask])
828        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
829
830
831def compare(opts, limits=None, maxdim=np.inf):
832    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
833    """
834    Preform a comparison using options from the command line.
835
836    *limits* are the limits on the values to use, either to set the y-axis
837    for 1D or to set the colormap scale for 2D.  If None, then they are
838    inferred from the data and returned. When exploring using Bumps,
839    the limits are set when the model is initially called, and maintained
840    as the values are adjusted, making it easier to see the effects of the
841    parameters.
842
843    *maxdim* is the maximum value for any parameter with units of Angstrom.
844    """
845    for k in range(opts['sets']):
846        if k > 1:
847            # print a separate seed for each dataset for better reproducibility
848            new_seed = np.random.randint(1000000)
849            print("Set %d uses -random=%i"%(k+1,new_seed))
850            np.random.seed(new_seed)
851        opts['pars'] = parse_pars(opts, maxdim=maxdim)
852        if opts['pars'] is None:
853            return
854        result = run_models(opts, verbose=True)
855        if opts['plot']:
856            limits = plot_models(opts, result, limits=limits, setnum=k)
857        if opts['show_weights']:
858            base, _ = opts['engines']
859            base_pars, _ = opts['pars']
860            model_info = base._kernel.info
861            dim = base._kernel.dim
862            plot_weights(model_info, get_mesh(model_info, base_pars, dim=dim))
863    if opts['plot']:
864        import matplotlib.pyplot as plt
865        plt.show()
866    return limits
867
868def run_models(opts, verbose=False):
869    # type: (Dict[str, Any]) -> Dict[str, Any]
870
871    base, comp = opts['engines']
872    base_n, comp_n = opts['count']
873    base_pars, comp_pars = opts['pars']
874    data = opts['data']
875
876    comparison = comp is not None
877
878    base_time = comp_time = None
879    base_value = comp_value = resid = relerr = None
880
881    # Base calculation
882    try:
883        base_raw, base_time = time_calculation(base, base_pars, base_n)
884        base_value = np.ma.masked_invalid(base_raw)
885        if verbose:
886            print("%s t=%.2f ms, intensity=%.0f"
887                  % (base.engine, base_time, base_value.sum()))
888        _show_invalid(data, base_value)
889    except ImportError:
890        traceback.print_exc()
891
892    # Comparison calculation
893    if comparison:
894        try:
895            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
896            comp_value = np.ma.masked_invalid(comp_raw)
897            if verbose:
898                print("%s t=%.2f ms, intensity=%.0f"
899                      % (comp.engine, comp_time, comp_value.sum()))
900            _show_invalid(data, comp_value)
901        except ImportError:
902            traceback.print_exc()
903
904    # Compare, but only if computing both forms
905    if comparison:
906        resid = (base_value - comp_value)
907        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
908        if verbose:
909            _print_stats("|%s-%s|"
910                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
911                         resid)
912            _print_stats("|(%s-%s)/%s|"
913                         % (base.engine, comp.engine, comp.engine),
914                         relerr)
915
916    return dict(base_value=base_value, comp_value=comp_value,
917                base_time=base_time, comp_time=comp_time,
918                resid=resid, relerr=relerr)
919
920
921def _print_stats(label, err):
922    # type: (str, np.ma.ndarray) -> None
923    # work with trimmed data, not the full set
924    sorted_err = np.sort(abs(err.compressed()))
925    if len(sorted_err) == 0.:
926        print(label + "  no valid values")
927        return
928
929    p50 = int((len(sorted_err)-1)*0.50)
930    p98 = int((len(sorted_err)-1)*0.98)
931    data = [
932        "max:%.3e"%sorted_err[-1],
933        "median:%.3e"%sorted_err[p50],
934        "98%%:%.3e"%sorted_err[p98],
935        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
936        "zero-offset:%+.3e"%np.mean(sorted_err),
937        ]
938    print(label+"  "+"  ".join(data))
939
940
941def plot_models(opts, result, limits=None, setnum=0):
942    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
943    import matplotlib.pyplot as plt
944
945    base_value, comp_value = result['base_value'], result['comp_value']
946    base_time, comp_time = result['base_time'], result['comp_time']
947    resid, relerr = result['resid'], result['relerr']
948
949    have_base, have_comp = (base_value is not None), (comp_value is not None)
950    base, comp = opts['engines']
951    data = opts['data']
952    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
953
954    # Plot if requested
955    view = opts['view']
956    if limits is None:
957        vmin, vmax = np.inf, -np.inf
958        if have_base:
959            vmin = min(vmin, base_value.min())
960            vmax = max(vmax, base_value.max())
961        if have_comp:
962            vmin = min(vmin, comp_value.min())
963            vmax = max(vmax, comp_value.max())
964        limits = vmin, vmax
965
966    if have_base:
967        if have_comp:
968            plt.subplot(131)
969        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
970        plt.title("%s t=%.2f ms"%(base.engine, base_time))
971        #cbar_title = "log I"
972    if have_comp:
973        if have_base:
974            plt.subplot(132)
975        if not opts['is2d'] and have_base:
976            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
977        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
978        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
979        #cbar_title = "log I"
980    if have_base and have_comp:
981        plt.subplot(133)
982        if not opts['rel_err']:
983            err, errstr, errview = resid, "abs err", "linear"
984        else:
985            err, errstr, errview = abs(relerr), "rel err", "log"
986        if 0:  # 95% cutoff
987            sorted = np.sort(err.flatten())
988            cutoff = sorted[int(sorted.size*0.95)]
989            err[err > cutoff] = cutoff
990        #err,errstr = base/comp,"ratio"
991        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
992        plt.xscale('log' if view == 'log' else linear)
993        plt.legend(['P%d'%(k+1) for k in range(setnum+1)], loc='best')
994        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
995        #cbar_title = errstr if errview=="linear" else "log "+errstr
996    #if is2D:
997    #    h = plt.colorbar()
998    #    h.ax.set_title(cbar_title)
999    fig = plt.gcf()
1000    extra_title = ' '+opts['title'] if opts['title'] else ''
1001    fig.suptitle(":".join(opts['name']) + extra_title)
1002
1003    if have_base and have_comp and opts['show_hist']:
1004        plt.figure()
1005        v = relerr
1006        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
1007        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
1008        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
1009                   % (base.engine, comp.engine, comp.engine))
1010        plt.ylabel('P(err)')
1011        plt.title('Distribution of relative error between calculation engines')
1012
1013    return limits
1014
1015
1016# ===========================================================================
1017#
1018
1019# Set of command line options.
1020# Normal options such as -plot/-noplot are specified as 'name'.
1021# For options such as -nq=500 which require a value use 'name='.
1022#
1023OPTIONS = [
1024    # Plotting
1025    'plot', 'noplot', 'weights',
1026    'linear', 'log', 'q4',
1027    'rel', 'abs',
1028    'hist', 'nohist',
1029    'title=',
1030
1031    # Data generation
1032    'data=', 'noise=', 'res=',
1033    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
1034    '2d', '1d',
1035
1036    # Parameter set
1037    'preset', 'random', 'random=', 'sets=',
1038    'demo', 'default',  # TODO: remove demo/default
1039    'nopars', 'pars',
1040    'sphere', 'sphere=', # integrate over a sphere in 2d with n points
1041
1042    # Calculation options
1043    'poly', 'mono', 'cutoff=',
1044    'magnetic', 'nonmagnetic',
1045    'accuracy=',
1046    'neval=',  # for timing...
1047
1048    # Precision options
1049    'engine=',
1050    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1051    'sasview',  # TODO: remove sasview 3.x support
1052
1053    # Output options
1054    'help', 'html', 'edit',
1055    ]
1056
1057NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1058VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1059
1060
1061def columnize(items, indent="", width=79):
1062    # type: (List[str], str, int) -> str
1063    """
1064    Format a list of strings into columns.
1065
1066    Returns a string with carriage returns ready for printing.
1067    """
1068    column_width = max(len(w) for w in items) + 1
1069    num_columns = (width - len(indent)) // column_width
1070    num_rows = len(items) // num_columns
1071    items = items + [""] * (num_rows * num_columns - len(items))
1072    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1073    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1074             for row in zip(*columns)]
1075    output = indent + ("\n"+indent).join(lines)
1076    return output
1077
1078
1079def get_pars(model_info, use_demo=False):
1080    # type: (ModelInfo, bool) -> ParameterSet
1081    """
1082    Extract demo parameters from the model definition.
1083    """
1084    # Get the default values for the parameters
1085    pars = {}
1086    for p in model_info.parameters.call_parameters:
1087        parts = [('', p.default)]
1088        if p.polydisperse:
1089            parts.append(('_pd', 0.0))
1090            parts.append(('_pd_n', 0))
1091            parts.append(('_pd_nsigma', 3.0))
1092            parts.append(('_pd_type', "gaussian"))
1093        for ext, val in parts:
1094            if p.length > 1:
1095                dict(("%s%d%s" % (p.id, k, ext), val)
1096                     for k in range(1, p.length+1))
1097            else:
1098                pars[p.id + ext] = val
1099
1100    # Plug in values given in demo
1101    if use_demo and model_info.demo:
1102        pars.update(model_info.demo)
1103    return pars
1104
1105INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1106def isnumber(str):
1107    match = FLOAT_RE.match(str)
1108    isfloat = (match and not str[match.end():])
1109    return isfloat or INTEGER_RE.match(str)
1110
1111# For distinguishing pairs of models for comparison
1112# key-value pair separator =
1113# shell characters  | & ; <> $ % ' " \ # `
1114# model and parameter names _
1115# parameter expressions - + * / . ( )
1116# path characters including tilde expansion and windows drive ~ / :
1117# not sure about brackets [] {}
1118# maybe one of the following @ ? ^ ! ,
1119PAR_SPLIT = ','
1120def parse_opts(argv):
1121    # type: (List[str]) -> Dict[str, Any]
1122    """
1123    Parse command line options.
1124    """
1125    MODELS = core.list_models()
1126    flags = [arg for arg in argv
1127             if arg.startswith('-')]
1128    values = [arg for arg in argv
1129              if not arg.startswith('-') and '=' in arg]
1130    positional_args = [arg for arg in argv
1131                       if not arg.startswith('-') and '=' not in arg]
1132    models = "\n    ".join("%-15s"%v for v in MODELS)
1133    if len(positional_args) == 0:
1134        print(USAGE)
1135        print("\nAvailable models:")
1136        print(columnize(MODELS, indent="  "))
1137        return None
1138
1139    invalid = [o[1:] for o in flags
1140               if o[1:] not in NAME_OPTIONS
1141               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1142    if invalid:
1143        print("Invalid options: %s"%(", ".join(invalid)))
1144        return None
1145
1146    name = positional_args[-1]
1147
1148    # pylint: disable=bad-whitespace
1149    # Interpret the flags
1150    opts = {
1151        'plot'      : True,
1152        'view'      : 'log',
1153        'is2d'      : False,
1154        'qmax'      : 0.05,
1155        'nq'        : 128,
1156        'res'       : 0.0,
1157        'noise'     : 0.0,
1158        'accuracy'  : 'Low',
1159        'cutoff'    : '0.0',
1160        'seed'      : -1,  # default to preset
1161        'mono'      : True,
1162        # Default to magnetic a magnetic moment is set on the command line
1163        'magnetic'  : False,
1164        'show_pars' : False,
1165        'show_hist' : False,
1166        'rel_err'   : True,
1167        'explore'   : False,
1168        'use_demo'  : True,
1169        'zero'      : False,
1170        'html'      : False,
1171        'title'     : None,
1172        'datafile'  : None,
1173        'sets'      : 0,
1174        'engine'    : 'default',
1175        'count'     : '1',
1176        'show_weights' : False,
1177        'sphere'    : 0,
1178    }
1179    for arg in flags:
1180        if arg == '-noplot':    opts['plot'] = False
1181        elif arg == '-plot':    opts['plot'] = True
1182        elif arg == '-linear':  opts['view'] = 'linear'
1183        elif arg == '-log':     opts['view'] = 'log'
1184        elif arg == '-q4':      opts['view'] = 'q4'
1185        elif arg == '-1d':      opts['is2d'] = False
1186        elif arg == '-2d':      opts['is2d'] = True
1187        elif arg == '-exq':     opts['qmax'] = 10.0
1188        elif arg == '-highq':   opts['qmax'] = 1.0
1189        elif arg == '-midq':    opts['qmax'] = 0.2
1190        elif arg == '-lowq':    opts['qmax'] = 0.05
1191        elif arg == '-zero':    opts['zero'] = True
1192        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1193        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1194        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1195        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1196        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1197        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1198        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1199        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1200        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1201        elif arg.startswith('-engine='):   opts['engine'] = arg[8:]
1202        elif arg.startswith('-neval='):    opts['count'] = arg[7:]
1203        elif arg.startswith('-sphere'):
1204            opts['sphere'] = int(arg[8:]) if len(arg) > 7 else 150
1205            opts['is2d'] = True
1206        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1207        elif arg == '-preset':  opts['seed'] = -1
1208        elif arg == '-mono':    opts['mono'] = True
1209        elif arg == '-poly':    opts['mono'] = False
1210        elif arg == '-magnetic':       opts['magnetic'] = True
1211        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1212        elif arg == '-pars':    opts['show_pars'] = True
1213        elif arg == '-nopars':  opts['show_pars'] = False
1214        elif arg == '-hist':    opts['show_hist'] = True
1215        elif arg == '-nohist':  opts['show_hist'] = False
1216        elif arg == '-rel':     opts['rel_err'] = True
1217        elif arg == '-abs':     opts['rel_err'] = False
1218        elif arg == '-half':    opts['engine'] = 'half'
1219        elif arg == '-fast':    opts['engine'] = 'fast'
1220        elif arg == '-single':  opts['engine'] = 'single'
1221        elif arg == '-double':  opts['engine'] = 'double'
1222        elif arg == '-single!': opts['engine'] = 'single!'
1223        elif arg == '-double!': opts['engine'] = 'double!'
1224        elif arg == '-quad!':   opts['engine'] = 'quad!'
1225        elif arg == '-sasview': opts['engine'] = 'sasview'
1226        elif arg == '-edit':    opts['explore'] = True
1227        elif arg == '-demo':    opts['use_demo'] = True
1228        elif arg == '-default': opts['use_demo'] = False
1229        elif arg == '-weights': opts['show_weights'] = True
1230        elif arg == '-html':    opts['html'] = True
1231        elif arg == '-help':    opts['html'] = True
1232    # pylint: enable=bad-whitespace
1233
1234    # Magnetism forces 2D for now
1235    if opts['magnetic']:
1236        opts['is2d'] = True
1237
1238    # Force random if sets is used
1239    if opts['sets'] >= 1 and opts['seed'] < 0:
1240        opts['seed'] = np.random.randint(1000000)
1241    if opts['sets'] == 0:
1242        opts['sets'] = 1
1243
1244    # Create the computational engines
1245    if opts['datafile'] is not None:
1246        data = load_data(os.path.expanduser(opts['datafile']))
1247    else:
1248        data, _ = make_data(opts)
1249
1250    comparison = any(PAR_SPLIT in v for v in values)
1251    if PAR_SPLIT in name:
1252        names = name.split(PAR_SPLIT, 2)
1253        comparison = True
1254    else:
1255        names = [name]*2
1256    try:
1257        model_info = [core.load_model_info(k) for k in names]
1258    except ImportError as exc:
1259        print(str(exc))
1260        print("Could not find model; use one of:\n    " + models)
1261        return None
1262
1263    if PAR_SPLIT in opts['engine']:
1264        opts['engine'] = opts['engine'].split(PAR_SPLIT, 2)
1265        comparison = True
1266    else:
1267        opts['engine'] = [opts['engine']]*2
1268
1269    if PAR_SPLIT in opts['count']:
1270        opts['count'] = [int(k) for k in opts['count'].split(PAR_SPLIT, 2)]
1271        comparison = True
1272    else:
1273        opts['count'] = [int(opts['count'])]*2
1274
1275    if PAR_SPLIT in opts['cutoff']:
1276        opts['cutoff'] = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1277        comparison = True
1278    else:
1279        opts['cutoff'] = [float(opts['cutoff'])]*2
1280
1281    base = make_engine(model_info[0], data, opts['engine'][0], opts['cutoff'][0])
1282    if comparison:
1283        comp = make_engine(model_info[1], data, opts['engine'][1], opts['cutoff'][1])
1284    else:
1285        comp = None
1286
1287    # pylint: disable=bad-whitespace
1288    # Remember it all
1289    opts.update({
1290        'data'      : data,
1291        'name'      : names,
1292        'info'      : model_info,
1293        'engines'   : [base, comp],
1294        'values'    : values,
1295    })
1296    # pylint: enable=bad-whitespace
1297
1298    # Set the integration parameters to the half sphere
1299    if opts['sphere'] > 0:
1300        set_spherical_integration_parameters(opts, opts['sphere'])
1301
1302    return opts
1303
1304def set_spherical_integration_parameters(opts, steps):
1305    """
1306    Set integration parameters for spherical integration over the entire
1307    surface in theta-phi coordinates.
1308    """
1309    # Set the integration parameters to the half sphere
1310    opts['values'].extend([
1311        'theta=90',
1312        'theta_pd=%g'%(90/np.sqrt(3)),
1313        'theta_pd_n=%d'%steps,
1314        'theta_pd_type=rectangle',
1315        'phi=0',
1316        'phi_pd=%g'%(180/np.sqrt(3)),
1317        'phi_pd_n=%d'%(2*steps),
1318        'phi_pd_type=rectangle',
1319        #'background=0',
1320    ])
1321    if 'psi' in opts['info'][0].parameters:
1322        opts['values'].append('psi=0')
1323
1324def parse_pars(opts, maxdim=np.inf):
1325    model_info, model_info2 = opts['info']
1326
1327    # Get demo parameters from model definition, or use default parameters
1328    # if model does not define demo parameters
1329    pars = get_pars(model_info, opts['use_demo'])
1330    pars2 = get_pars(model_info2, opts['use_demo'])
1331    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1332    # randomize parameters
1333    #pars.update(set_pars)  # set value before random to control range
1334    if opts['seed'] > -1:
1335        pars = randomize_pars(model_info, pars)
1336        limit_dimensions(model_info, pars, maxdim)
1337        if model_info != model_info2:
1338            pars2 = randomize_pars(model_info2, pars2)
1339            limit_dimensions(model_info, pars2, maxdim)
1340            # Share values for parameters with the same name
1341            for k, v in pars.items():
1342                if k in pars2:
1343                    pars2[k] = v
1344        else:
1345            pars2 = pars.copy()
1346        constrain_pars(model_info, pars)
1347        constrain_pars(model_info2, pars2)
1348    pars = suppress_pd(pars, opts['mono'])
1349    pars2 = suppress_pd(pars2, opts['mono'])
1350    pars = suppress_magnetism(pars, not opts['magnetic'])
1351    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1352
1353    # Fill in parameters given on the command line
1354    presets = {}
1355    presets2 = {}
1356    for arg in opts['values']:
1357        k, v = arg.split('=', 1)
1358        if k not in pars and k not in pars2:
1359            # extract base name without polydispersity info
1360            s = set(p.split('_pd')[0] for p in pars)
1361            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1362            return None
1363        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1364        if v1 and k in pars:
1365            presets[k] = float(v1) if isnumber(v1) else v1
1366        if v2 and k in pars2:
1367            presets2[k] = float(v2) if isnumber(v2) else v2
1368
1369    # If pd given on the command line, default pd_n to 35
1370    for k, v in list(presets.items()):
1371        if k.endswith('_pd'):
1372            presets.setdefault(k+'_n', 35.)
1373    for k, v in list(presets2.items()):
1374        if k.endswith('_pd'):
1375            presets2.setdefault(k+'_n', 35.)
1376
1377    # Evaluate preset parameter expressions
1378    context = MATH.copy()
1379    context['np'] = np
1380    context.update(pars)
1381    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1382    for k, v in presets.items():
1383        if not isinstance(v, float) and not k.endswith('_type'):
1384            presets[k] = eval(v, context)
1385    context.update(presets)
1386    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1387    for k, v in presets2.items():
1388        if not isinstance(v, float) and not k.endswith('_type'):
1389            presets2[k] = eval(v, context)
1390
1391    # update parameters with presets
1392    pars.update(presets)  # set value after random to control value
1393    pars2.update(presets2)  # set value after random to control value
1394    #import pprint; pprint.pprint(model_info)
1395
1396    if opts['show_pars']:
1397        if model_info.name != model_info2.name or pars != pars2:
1398            print("==== %s ====="%model_info.name)
1399            print(str(parlist(model_info, pars, opts['is2d'])))
1400            print("==== %s ====="%model_info2.name)
1401            print(str(parlist(model_info2, pars2, opts['is2d'])))
1402        else:
1403            print(str(parlist(model_info, pars, opts['is2d'])))
1404
1405    return pars, pars2
1406
1407def show_docs(opts):
1408    # type: (Dict[str, Any]) -> None
1409    """
1410    show html docs for the model
1411    """
1412    import os
1413    from .generate import make_html
1414    from . import rst2html
1415
1416    info = opts['info'][0]
1417    html = make_html(info)
1418    path = os.path.dirname(info.filename)
1419    url = "file://"+path.replace("\\","/")[2:]+"/"
1420    rst2html.view_html_qtapp(html, url)
1421
1422def explore(opts):
1423    # type: (Dict[str, Any]) -> None
1424    """
1425    explore the model using the bumps gui.
1426    """
1427    import wx  # type: ignore
1428    from bumps.names import FitProblem  # type: ignore
1429    from bumps.gui.app_frame import AppFrame  # type: ignore
1430    from bumps.gui import signal
1431
1432    is_mac = "cocoa" in wx.version()
1433    # Create an app if not running embedded
1434    app = wx.App() if wx.GetApp() is None else None
1435    model = Explore(opts)
1436    problem = FitProblem(model)
1437    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1438    if not is_mac:
1439        frame.Show()
1440    frame.panel.set_model(model=problem)
1441    frame.panel.Layout()
1442    frame.panel.aui.Split(0, wx.TOP)
1443    def reset_parameters(event):
1444        model.revert_values()
1445        signal.update_parameters(problem)
1446    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1447    if is_mac: frame.Show()
1448    # If running withing an app, start the main loop
1449    if app:
1450        app.MainLoop()
1451
1452class Explore(object):
1453    """
1454    Bumps wrapper for a SAS model comparison.
1455
1456    The resulting object can be used as a Bumps fit problem so that
1457    parameters can be adjusted in the GUI, with plots updated on the fly.
1458    """
1459    def __init__(self, opts):
1460        # type: (Dict[str, Any]) -> None
1461        from bumps.cli import config_matplotlib  # type: ignore
1462        from . import bumps_model
1463        config_matplotlib()
1464        self.opts = opts
1465        opts['pars'] = list(opts['pars'])
1466        p1, p2 = opts['pars']
1467        m1, m2 = opts['info']
1468        self.fix_p2 = m1 != m2 or p1 != p2
1469        model_info = m1
1470        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1471        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1472        if not opts['is2d']:
1473            for p in model_info.parameters.user_parameters({}, is2d=False):
1474                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1475                    k = p.name+ext
1476                    v = pars.get(k, None)
1477                    if v is not None:
1478                        v.range(*parameter_range(k, v.value))
1479        else:
1480            for k, v in pars.items():
1481                v.range(*parameter_range(k, v.value))
1482
1483        self.pars = pars
1484        self.starting_values = dict((k, v.value) for k, v in pars.items())
1485        self.pd_types = pd_types
1486        self.limits = None
1487
1488    def revert_values(self):
1489        for k, v in self.starting_values.items():
1490            self.pars[k].value = v
1491
1492    def model_update(self):
1493        pass
1494
1495    def numpoints(self):
1496        # type: () -> int
1497        """
1498        Return the number of points.
1499        """
1500        return len(self.pars) + 1  # so dof is 1
1501
1502    def parameters(self):
1503        # type: () -> Any   # Dict/List hierarchy of parameters
1504        """
1505        Return a dictionary of parameters.
1506        """
1507        return self.pars
1508
1509    def nllf(self):
1510        # type: () -> float
1511        """
1512        Return cost.
1513        """
1514        # pylint: disable=no-self-use
1515        return 0.  # No nllf
1516
1517    def plot(self, view='log'):
1518        # type: (str) -> None
1519        """
1520        Plot the data and residuals.
1521        """
1522        pars = dict((k, v.value) for k, v in self.pars.items())
1523        pars.update(self.pd_types)
1524        self.opts['pars'][0] = pars
1525        if not self.fix_p2:
1526            self.opts['pars'][1] = pars
1527        result = run_models(self.opts)
1528        limits = plot_models(self.opts, result, limits=self.limits)
1529        if self.limits is None:
1530            vmin, vmax = limits
1531            self.limits = vmax*1e-7, 1.3*vmax
1532            import pylab; pylab.clf()
1533            plot_models(self.opts, result, limits=self.limits)
1534
1535
1536def main(*argv):
1537    # type: (*str) -> None
1538    """
1539    Main program.
1540    """
1541    opts = parse_opts(argv)
1542    if opts is not None:
1543        if opts['seed'] > -1:
1544            print("Randomize using -random=%i"%opts['seed'])
1545            np.random.seed(opts['seed'])
1546        if opts['html']:
1547            show_docs(opts)
1548        elif opts['explore']:
1549            opts['pars'] = parse_pars(opts)
1550            if opts['pars'] is None:
1551                return
1552            explore(opts)
1553        else:
1554            compare(opts)
1555
1556if __name__ == "__main__":
1557    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.