source: sasmodels/sasmodels/compare.py @ 97d89af

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 97d89af was 97d89af, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

provide default polydispersity/magnetism for sascomp if -poly/-magnetic is requested

  • Property mode set to 100755
File size: 52.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import os
33import math
34import datetime
35import traceback
36import re
37
38import numpy as np  # type: ignore
39
40from . import core
41from . import kerneldll
42from . import exception
43from .data import plot_theory, empty_data1D, empty_data2D, load_data
44from .direct_model import DirectModel
45from .convert import revert_name, revert_pars, constrain_new_to_old
46from .generate import FLOAT_RE
47
48try:
49    from typing import Optional, Dict, Any, Callable, Tuple
50except Exception:
51    pass
52else:
53    from .modelinfo import ModelInfo, Parameter, ParameterSet
54    from .data import Data
55    Calculator = Callable[[float], np.ndarray]
56
57USAGE = """
58usage: sascomp model [options...] [key=val]
59
60Generate and compare SAS models.  If a single model is specified it shows
61a plot of that model.  Different models can be compared, or the same model
62with different parameters.  The same model with the same parameters can
63be compared with different calculation engines to see the effects of precision
64on the resultant values.
65
66model or model1,model2 are the names of the models to compare (see below).
67
68Options (* for default):
69
70    === data generation ===
71    -data="path" uses q, dq from the data file
72    -noise=0 sets the measurement error dI/I
73    -res=0 sets the resolution width dQ/Q if calculating with resolution
74    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
75    -nq=128 sets the number of Q points in the data set
76    -1d*/-2d computes 1d or 2d data
77    -zero indicates that q=0 should be included
78
79    === model parameters ===
80    -preset*/-random[=seed] preset or random parameters
81    -sets=n generates n random datasets, with the seed given by -random=seed
82    -pars/-nopars* prints the parameter set or not
83    -default/-demo* use demo vs default parameters
84
85    === calculation options ===
86    -mono*/-poly force monodisperse or allow polydisperse demo parameters
87    -cutoff=1e-5* cutoff value for including a point in polydispersity
88    -magnetic/-nonmagnetic* suppress magnetism
89    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
90
91    === precision options ===
92    -calc=default uses the default calcution precision
93    -single/-double/-half/-fast sets an OpenCL calculation engine
94    -single!/-double!/-quad! sets an OpenMP calculation engine
95    -sasview sets the sasview calculation engine
96
97    === plotting ===
98    -plot*/-noplot plots or suppress the plot of the model
99    -linear/-log*/-q4 intensity scaling on plots
100    -hist/-nohist* plot histogram of relative error
101    -abs/-rel* plot relative or absolute error
102    -title="note" adds note to the plot title, after the model name
103
104    === output options ===
105    -edit starts the parameter explorer
106    -help/-html shows the model docs instead of running the model
107
108The interpretation of quad precision depends on architecture, and may
109vary from 64-bit to 128-bit, with 80-bit floats being common (1e-19 precision).
110On unix and mac you may need single quotes around the DLL computation
111engines, such as -calc='single!,double!' since !, is treated as a history
112expansion request in the shell.
113
114Key=value pairs allow you to set specific values for the model parameters.
115Key=value1,value2 to compare different values of the same parameter. The
116value can be an expression including other parameters.
117
118Items later on the command line override those that appear earlier.
119
120Examples:
121
122    # compare single and double precision calculation for a barbell
123    sascomp barbell -calc=single,double
124
125    # generate 10 random lorentz models, with seed=27
126    sascomp lorentz -sets=10 -seed=27
127
128    # compare ellipsoid with R = R_polar = R_equatorial to sphere of radius R
129    sascomp sphere,ellipsoid radius_polar=radius radius_equatorial=radius
130
131    # model timing test requires multiple evals to perform the estimate
132    sascomp pringle -calc=single,double -timing=100,100 -noplot
133"""
134
135# Update docs with command line usage string.   This is separate from the usual
136# doc string so that we can display it at run time if there is an error.
137# lin
138__doc__ = (__doc__  # pylint: disable=redefined-builtin
139           + """
140Program description
141-------------------
142
143""" + USAGE)
144
145kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
146
147# list of math functions for use in evaluating parameters
148MATH = dict((k,getattr(math, k)) for k in dir(math) if not k.startswith('_'))
149
150# CRUFT python 2.6
151if not hasattr(datetime.timedelta, 'total_seconds'):
152    def delay(dt):
153        """Return number date-time delta as number seconds"""
154        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
155else:
156    def delay(dt):
157        """Return number date-time delta as number seconds"""
158        return dt.total_seconds()
159
160
161class push_seed(object):
162    """
163    Set the seed value for the random number generator.
164
165    When used in a with statement, the random number generator state is
166    restored after the with statement is complete.
167
168    :Parameters:
169
170    *seed* : int or array_like, optional
171        Seed for RandomState
172
173    :Example:
174
175    Seed can be used directly to set the seed::
176
177        >>> from numpy.random import randint
178        >>> push_seed(24)
179        <...push_seed object at...>
180        >>> print(randint(0,1000000,3))
181        [242082    899 211136]
182
183    Seed can also be used in a with statement, which sets the random
184    number generator state for the enclosed computations and restores
185    it to the previous state on completion::
186
187        >>> with push_seed(24):
188        ...    print(randint(0,1000000,3))
189        [242082    899 211136]
190
191    Using nested contexts, we can demonstrate that state is indeed
192    restored after the block completes::
193
194        >>> with push_seed(24):
195        ...    print(randint(0,1000000))
196        ...    with push_seed(24):
197        ...        print(randint(0,1000000,3))
198        ...    print(randint(0,1000000))
199        242082
200        [242082    899 211136]
201        899
202
203    The restore step is protected against exceptions in the block::
204
205        >>> with push_seed(24):
206        ...    print(randint(0,1000000))
207        ...    try:
208        ...        with push_seed(24):
209        ...            print(randint(0,1000000,3))
210        ...            raise Exception()
211        ...    except Exception:
212        ...        print("Exception raised")
213        ...    print(randint(0,1000000))
214        242082
215        [242082    899 211136]
216        Exception raised
217        899
218    """
219    def __init__(self, seed=None):
220        # type: (Optional[int]) -> None
221        self._state = np.random.get_state()
222        np.random.seed(seed)
223
224    def __enter__(self):
225        # type: () -> None
226        pass
227
228    def __exit__(self, exc_type, exc_value, traceback):
229        # type: (Any, BaseException, Any) -> None
230        # TODO: better typing for __exit__ method
231        np.random.set_state(self._state)
232
233def tic():
234    # type: () -> Callable[[], float]
235    """
236    Timer function.
237
238    Use "toc=tic()" to start the clock and "toc()" to measure
239    a time interval.
240    """
241    then = datetime.datetime.now()
242    return lambda: delay(datetime.datetime.now() - then)
243
244
245def set_beam_stop(data, radius, outer=None):
246    # type: (Data, float, float) -> None
247    """
248    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
249
250    Note: this function does not require sasview
251    """
252    if hasattr(data, 'qx_data'):
253        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
254        data.mask = (q < radius)
255        if outer is not None:
256            data.mask |= (q >= outer)
257    else:
258        data.mask = (data.x < radius)
259        if outer is not None:
260            data.mask |= (data.x >= outer)
261
262
263def parameter_range(p, v):
264    # type: (str, float) -> Tuple[float, float]
265    """
266    Choose a parameter range based on parameter name and initial value.
267    """
268    # process the polydispersity options
269    if p.endswith('_pd_n'):
270        return 0., 100.
271    elif p.endswith('_pd_nsigma'):
272        return 0., 5.
273    elif p.endswith('_pd_type'):
274        raise ValueError("Cannot return a range for a string value")
275    elif any(s in p for s in ('theta', 'phi', 'psi')):
276        # orientation in [-180,180], orientation pd in [0,45]
277        if p.endswith('_pd'):
278            return 0., 45.
279        else:
280            return -180., 180.
281    elif p.endswith('_pd'):
282        return 0., 1.
283    elif 'sld' in p:
284        return -0.5, 10.
285    elif p == 'background':
286        return 0., 10.
287    elif p == 'scale':
288        return 0., 1.e3
289    elif v < 0.:
290        return 2.*v, -2.*v
291    else:
292        return 0., (2.*v if v > 0. else 1.)
293
294
295def _randomize_one(model_info, name, value):
296    # type: (ModelInfo, str, float) -> float
297    # type: (ModelInfo, str, str) -> str
298    """
299    Randomize a single parameter.
300    """
301    # Set the amount of polydispersity/angular dispersion, but by default pd_n
302    # is zero so there is no polydispersity.  This allows us to turn on/off
303    # pd by setting pd_n, and still have randomly generated values
304    if name.endswith('_pd'):
305        par = model_info.parameters[name[:-3]]
306        if par.type == 'orientation':
307            # Let oriention variation peak around 13 degrees; 95% < 42 degrees
308            return 180*np.random.beta(2.5, 20)
309        else:
310            # Let polydispersity peak around 15%; 95% < 0.4; max=100%
311            return np.random.beta(1.5, 7)
312
313    # pd is selected globally rather than per parameter, so set to 0 for no pd
314    # In particular, when multiple pd dimensions, want to decrease the number
315    # of points per dimension for faster computation
316    if name.endswith('_pd_n'):
317        return 0
318
319    # Don't mess with distribution type for now
320    if name.endswith('_pd_type'):
321        return 'gaussian'
322
323    # type-dependent value of number of sigmas; for gaussian use 3.
324    if name.endswith('_pd_nsigma'):
325        return 3.
326
327    # background in the range [0.01, 1]
328    if name == 'background':
329        return 10**np.random.uniform(-2, 0)
330
331    # scale defaults to 0.1% to 30% volume fraction
332    if name == 'scale':
333        return 10**np.random.uniform(-3, -0.5)
334
335    # If it is a list of choices, pick one at random with equal probability
336    # In practice, the model specific random generator will override.
337    par = model_info.parameters[name]
338    if len(par.limits) > 2:  # choice list
339        return np.random.randint(len(par.limits))
340
341    # If it is a fixed range, pick from it with equal probability.
342    # For logarithmic ranges, the model will have to override.
343    if np.isfinite(par.limits).all():
344        return np.random.uniform(*par.limits)
345
346    # If the paramter is marked as an sld use the range of neutron slds
347    # Should be doing something with randomly selected contrast matching
348    # but for now use random contrasts.  Since real data is contrast-matched,
349    # this will favour selection of hollow models when they exist.
350    if par.type == 'sld':
351        # Range of neutron SLDs
352        return np.random.uniform(-0.5, 12)
353
354    # Guess at the random length/radius/thickness.  In practice, all models
355    # are going to set their own reasonable ranges.
356    if par.type == 'volume':
357        if ('length' in par.name or
358                'radius' in par.name or
359                'thick' in par.name):
360            return 10**np.random.uniform(2, 4)
361
362    # In the absence of any other info, select a value in [0, 2v], or
363    # [-2|v|, 2|v|] if v is negative, or [0, 1] if v is zero.  Mostly the
364    # model random parameter generators will override this default.
365    low, high = parameter_range(par.name, value)
366    limits = (max(par.limits[0], low), min(par.limits[1], high))
367    return np.random.uniform(*limits)
368
369def _random_pd(model_info, pars):
370    pd = [p for p in model_info.parameters.kernel_parameters if p.polydisperse]
371    pd_volume = []
372    pd_oriented = []
373    for p in pd:
374        if p.type == 'orientation':
375            pd_oriented.append(p.name)
376        elif p.length_control is not None:
377            n = int(pars.get(p.length_control, 1) + 0.5)
378            pd_volume.extend(p.name+str(k+1) for k in range(n))
379        elif p.length > 1:
380            pd_volume.extend(p.name+str(k+1) for k in range(p.length))
381        else:
382            pd_volume.append(p.name)
383    u = np.random.rand()
384    n = len(pd_volume)
385    if u < 0.01 or n < 1:
386        pass  # 1% chance of no polydispersity
387    elif u < 0.86 or n < 2:
388        pars[np.random.choice(pd_volume)+"_pd_n"] = 35
389    elif u < 0.99 or n < 3:
390        choices = np.random.choice(len(pd_volume), size=2)
391        pars[pd_volume[choices[0]]+"_pd_n"] = 25
392        pars[pd_volume[choices[1]]+"_pd_n"] = 10
393    else:
394        choices = np.random.choice(len(pd_volume), size=3)
395        pars[pd_volume[choices[0]]+"_pd_n"] = 25
396        pars[pd_volume[choices[1]]+"_pd_n"] = 10
397        pars[pd_volume[choices[2]]+"_pd_n"] = 5
398    if pd_oriented:
399        pars['theta_pd_n'] = 20
400        if np.random.rand() < 0.1:
401            pars['phi_pd_n'] = 5
402        if np.random.rand() < 0.1:
403            pars['psi_pd_n'] = 5
404
405    ## Show selected polydispersity
406    #for name, value in pars.items():
407    #    if name.endswith('_pd_n') and value > 0:
408    #        print(name, value, pars.get(name[:-5], 0), pars.get(name[:-2], 0))
409
410
411def randomize_pars(model_info, pars):
412    # type: (ModelInfo, ParameterSet) -> ParameterSet
413    """
414    Generate random values for all of the parameters.
415
416    Valid ranges for the random number generator are guessed from the name of
417    the parameter; this will not account for constraints such as cap radius
418    greater than cylinder radius in the capped_cylinder model, so
419    :func:`constrain_pars` needs to be called afterward..
420    """
421    # Note: the sort guarantees order of calls to random number generator
422    random_pars = dict((p, _randomize_one(model_info, p, v))
423                       for p, v in sorted(pars.items()))
424    if model_info.random is not None:
425        random_pars.update(model_info.random())
426    _random_pd(model_info, random_pars)
427    return random_pars
428
429
430def constrain_pars(model_info, pars):
431    # type: (ModelInfo, ParameterSet) -> None
432    """
433    Restrict parameters to valid values.
434
435    This includes model specific code for models such as capped_cylinder
436    which need to support within model constraints (cap radius more than
437    cylinder radius in this case).
438
439    Warning: this updates the *pars* dictionary in place.
440    """
441    # TODO: move the model specific code to the individual models
442    name = model_info.id
443    # if it is a product model, then just look at the form factor since
444    # none of the structure factors need any constraints.
445    if '*' in name:
446        name = name.split('*')[0]
447
448    # Suppress magnetism for python models (not yet implemented)
449    if callable(model_info.Iq):
450        pars.update(suppress_magnetism(pars))
451
452    if name == 'barbell':
453        if pars['radius_bell'] < pars['radius']:
454            pars['radius'], pars['radius_bell'] = pars['radius_bell'], pars['radius']
455
456    elif name == 'capped_cylinder':
457        if pars['radius_cap'] < pars['radius']:
458            pars['radius'], pars['radius_cap'] = pars['radius_cap'], pars['radius']
459
460    elif name == 'guinier':
461        # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
462        # I(q) = A e^-(Rg^2 q^2/3) > e^-(30 ln 10)
463        # => ln A - (Rg^2 q^2/3) > -30 ln 10
464        # => Rg^2 q^2/3 < 30 ln 10 + ln A
465        # => Rg < sqrt(90 ln 10 + 3 ln A)/q
466        #q_max = 0.2  # mid q maximum
467        q_max = 1.0  # high q maximum
468        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
469        pars['rg'] = min(pars['rg'], rg_max)
470
471    elif name == 'pearl_necklace':
472        if pars['radius'] < pars['thick_string']:
473            pars['radius'], pars['thick_string'] = pars['thick_string'], pars['radius']
474        pass
475
476    elif name == 'rpa':
477        # Make sure phi sums to 1.0
478        if pars['case_num'] < 2:
479            pars['Phi1'] = 0.
480            pars['Phi2'] = 0.
481        elif pars['case_num'] < 5:
482            pars['Phi1'] = 0.
483        total = sum(pars['Phi'+c] for c in '1234')
484        for c in '1234':
485            pars['Phi'+c] /= total
486
487def parlist(model_info, pars, is2d):
488    # type: (ModelInfo, ParameterSet, bool) -> str
489    """
490    Format the parameter list for printing.
491    """
492    lines = []
493    parameters = model_info.parameters
494    magnetic = False
495    magnetic_pars = []
496    for p in parameters.user_parameters(pars, is2d):
497        if any(p.id.startswith(x) for x in ('M0:', 'mtheta:', 'mphi:')):
498            continue
499        if p.id.startswith('up:'):
500            magnetic_pars.append("%s=%s"%(p.id, pars.get(p.id, p.default)))
501            continue
502        fields = dict(
503            value=pars.get(p.id, p.default),
504            pd=pars.get(p.id+"_pd", 0.),
505            n=int(pars.get(p.id+"_pd_n", 0)),
506            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
507            pdtype=pars.get(p.id+"_pd_type", 'gaussian'),
508            relative_pd=p.relative_pd,
509            M0=pars.get('M0:'+p.id, 0.),
510            mphi=pars.get('mphi:'+p.id, 0.),
511            mtheta=pars.get('mtheta:'+p.id, 0.),
512        )
513        lines.append(_format_par(p.name, **fields))
514        magnetic = magnetic or fields['M0'] != 0.
515    if magnetic and magnetic_pars:
516        lines.append(" ".join(magnetic_pars))
517    return "\n".join(lines)
518
519    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
520
521def _format_par(name, value=0., pd=0., n=0, nsigma=3., pdtype='gaussian',
522                relative_pd=False, M0=0., mphi=0., mtheta=0.):
523    # type: (str, float, float, int, float, str) -> str
524    line = "%s: %g"%(name, value)
525    if pd != 0.  and n != 0:
526        if relative_pd:
527            pd *= value
528        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
529                % (pd, n, nsigma, nsigma, pdtype)
530    if M0 != 0.:
531        line += "  M0:%.3f  mphi:%.1f  mtheta:%.1f" % (M0, mphi, mtheta)
532    return line
533
534def suppress_pd(pars, suppress=True):
535    # type: (ParameterSet) -> ParameterSet
536    """
537    If suppress is True complete eliminate polydispersity of the model to test
538    models more quickly.  If suppress is False, make sure at least one
539    parameter is polydisperse, setting the first polydispersity parameter to
540    15% if no polydispersity is given (with no explicit demo parameters given
541    in the model, there will be no default polydispersity).
542    """
543    pars = pars.copy()
544    if suppress:
545        for p in pars:
546            if p.endswith("_pd_n"):
547                pars[p] = 0
548    else:
549        any_pd = False
550        first_pd = None
551        for p in pars:
552            if p.endswith("_pd_n"):
553                any_pd |= (pars[p] != 0 and pars[p[:-2]] != 0.)
554                if first_pd is None:
555                    first_pd = p
556        if not any_pd and first_pd is not None:
557            if pars[first_pd] == 0:
558                pars[first_pd] = 35
559            if pars[first_pd[:-2]] == 0:
560                pars[first_pd[:-2]] = 0.15
561    return pars
562
563def suppress_magnetism(pars, suppress=True):
564    # type: (ParameterSet) -> ParameterSet
565    """
566    If suppress is True complete eliminate magnetism of the model to test
567    models more quickly.  If suppress is False, make sure at least one sld
568    parameter is magnetic, setting the first parameter to have a strong
569    magnetic sld (8/A^2) at 60 degrees (with no explicit demo parameters given
570    in the model, there will be no default magnetism).
571    """
572    pars = pars.copy()
573    if suppress:
574        for p in pars:
575            if p.startswith("M0:"):
576                pars[p] = 0
577    else:
578        any_mag = False
579        first_mag = None
580        for p in pars:
581            if p.startswith("M0:"):
582                any_mag |= (pars[p] != 0)
583                if first_mag is None:
584                    first_mag = p
585        if not any_mag and first_mag is not None:
586            pars[first_mag] = 8.
587    return pars
588
589def eval_sasview(model_info, data):
590    # type: (Modelinfo, Data) -> Calculator
591    """
592    Return a model calculator using the pre-4.0 SasView models.
593    """
594    # importing sas here so that the error message will be that sas failed to
595    # import rather than the more obscure smear_selection not imported error
596    import sas
597    import sas.models
598    from sas.models.qsmearing import smear_selection
599    from sas.models.MultiplicationModel import MultiplicationModel
600    from sas.models.dispersion_models import models as dispersers
601
602    def get_model_class(name):
603        # type: (str) -> "sas.models.BaseComponent"
604        #print("new",sorted(_pars.items()))
605        __import__('sas.models.' + name)
606        ModelClass = getattr(getattr(sas.models, name, None), name, None)
607        if ModelClass is None:
608            raise ValueError("could not find model %r in sas.models"%name)
609        return ModelClass
610
611    # WARNING: ugly hack when handling model!
612    # Sasview models with multiplicity need to be created with the target
613    # multiplicity, so we cannot create the target model ahead of time for
614    # for multiplicity models.  Instead we store the model in a list and
615    # update the first element of that list with the new multiplicity model
616    # every time we evaluate.
617
618    # grab the sasview model, or create it if it is a product model
619    if model_info.composition:
620        composition_type, parts = model_info.composition
621        if composition_type == 'product':
622            P, S = [get_model_class(revert_name(p))() for p in parts]
623            model = [MultiplicationModel(P, S)]
624        else:
625            raise ValueError("sasview mixture models not supported by compare")
626    else:
627        old_name = revert_name(model_info)
628        if old_name is None:
629            raise ValueError("model %r does not exist in old sasview"
630                            % model_info.id)
631        ModelClass = get_model_class(old_name)
632        model = [ModelClass()]
633    model[0].disperser_handles = {}
634
635    # build a smearer with which to call the model, if necessary
636    smearer = smear_selection(data, model=model)
637    if hasattr(data, 'qx_data'):
638        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
639        index = ((~data.mask) & (~np.isnan(data.data))
640                 & (q >= data.qmin) & (q <= data.qmax))
641        if smearer is not None:
642            smearer.model = model  # because smear_selection has a bug
643            smearer.accuracy = data.accuracy
644            smearer.set_index(index)
645            def _call_smearer():
646                smearer.model = model[0]
647                return smearer.get_value()
648            theory = _call_smearer
649        else:
650            theory = lambda: model[0].evalDistribution([data.qx_data[index],
651                                                        data.qy_data[index]])
652    elif smearer is not None:
653        theory = lambda: smearer(model[0].evalDistribution(data.x))
654    else:
655        theory = lambda: model[0].evalDistribution(data.x)
656
657    def calculator(**pars):
658        # type: (float, ...) -> np.ndarray
659        """
660        Sasview calculator for model.
661        """
662        oldpars = revert_pars(model_info, pars)
663        # For multiplicity models, create a model with the correct multiplicity
664        control = oldpars.pop("CONTROL", None)
665        if control is not None:
666            # sphericalSLD has one fewer multiplicity.  This update should
667            # happen in revert_pars, but it hasn't been called yet.
668            model[0] = ModelClass(control)
669        # paying for parameter conversion each time to keep life simple, if not fast
670        for k, v in oldpars.items():
671            if k.endswith('.type'):
672                par = k[:-5]
673                if v == 'gaussian': continue
674                cls = dispersers[v if v != 'rectangle' else 'rectangula']
675                handle = cls()
676                model[0].disperser_handles[par] = handle
677                try:
678                    model[0].set_dispersion(par, handle)
679                except Exception:
680                    exception.annotate_exception("while setting %s to %r"
681                                                 %(par, v))
682                    raise
683
684
685        #print("sasview pars",oldpars)
686        for k, v in oldpars.items():
687            name_attr = k.split('.')  # polydispersity components
688            if len(name_attr) == 2:
689                par, disp_par = name_attr
690                model[0].dispersion[par][disp_par] = v
691            else:
692                model[0].setParam(k, v)
693        return theory()
694
695    calculator.engine = "sasview"
696    return calculator
697
698DTYPE_MAP = {
699    'half': '16',
700    'fast': 'fast',
701    'single': '32',
702    'double': '64',
703    'quad': '128',
704    'f16': '16',
705    'f32': '32',
706    'f64': '64',
707    'float16': '16',
708    'float32': '32',
709    'float64': '64',
710    'float128': '128',
711    'longdouble': '128',
712}
713def eval_opencl(model_info, data, dtype='single', cutoff=0.):
714    # type: (ModelInfo, Data, str, float) -> Calculator
715    """
716    Return a model calculator using the OpenCL calculation engine.
717    """
718    if not core.HAVE_OPENCL:
719        raise RuntimeError("OpenCL not available")
720    model = core.build_model(model_info, dtype=dtype, platform="ocl")
721    calculator = DirectModel(data, model, cutoff=cutoff)
722    calculator.engine = "OCL%s"%DTYPE_MAP[str(model.dtype)]
723    return calculator
724
725def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
726    # type: (ModelInfo, Data, str, float) -> Calculator
727    """
728    Return a model calculator using the DLL calculation engine.
729    """
730    model = core.build_model(model_info, dtype=dtype, platform="dll")
731    calculator = DirectModel(data, model, cutoff=cutoff)
732    calculator.engine = "OMP%s"%DTYPE_MAP[str(model.dtype)]
733    return calculator
734
735def time_calculation(calculator, pars, evals=1):
736    # type: (Calculator, ParameterSet, int) -> Tuple[np.ndarray, float]
737    """
738    Compute the average calculation time over N evaluations.
739
740    An additional call is generated without polydispersity in order to
741    initialize the calculation engine, and make the average more stable.
742    """
743    # initialize the code so time is more accurate
744    if evals > 1:
745        calculator(**suppress_pd(pars))
746    toc = tic()
747    # make sure there is at least one eval
748    value = calculator(**pars)
749    for _ in range(evals-1):
750        value = calculator(**pars)
751    average_time = toc()*1000. / evals
752    #print("I(q)",value)
753    return value, average_time
754
755def make_data(opts):
756    # type: (Dict[str, Any]) -> Tuple[Data, np.ndarray]
757    """
758    Generate an empty dataset, used with the model to set Q points
759    and resolution.
760
761    *opts* contains the options, with 'qmax', 'nq', 'res',
762    'accuracy', 'is2d' and 'view' parsed from the command line.
763    """
764    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
765    if opts['is2d']:
766        q = np.linspace(-qmax, qmax, nq)  # type: np.ndarray
767        data = empty_data2D(q, resolution=res)
768        data.accuracy = opts['accuracy']
769        set_beam_stop(data, 0.0004)
770        index = ~data.mask
771    else:
772        if opts['view'] == 'log' and not opts['zero']:
773            qmax = math.log10(qmax)
774            q = np.logspace(qmax-3, qmax, nq)
775        else:
776            q = np.linspace(0.001*qmax, qmax, nq)
777        if opts['zero']:
778            q = np.hstack((0, q))
779        data = empty_data1D(q, resolution=res)
780        index = slice(None, None)
781    return data, index
782
783def make_engine(model_info, data, dtype, cutoff):
784    # type: (ModelInfo, Data, str, float) -> Calculator
785    """
786    Generate the appropriate calculation engine for the given datatype.
787
788    Datatypes with '!' appended are evaluated using external C DLLs rather
789    than OpenCL.
790    """
791    if dtype == 'sasview':
792        return eval_sasview(model_info, data)
793    elif dtype is None or not dtype.endswith('!'):
794        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
795    else:
796        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
797
798def _show_invalid(data, theory):
799    # type: (Data, np.ma.ndarray) -> None
800    """
801    Display a list of the non-finite values in theory.
802    """
803    if not theory.mask.any():
804        return
805
806    if hasattr(data, 'x'):
807        bad = zip(data.x[theory.mask], theory[theory.mask])
808        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x, y in bad))
809
810
811def compare(opts, limits=None):
812    # type: (Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
813    """
814    Preform a comparison using options from the command line.
815
816    *limits* are the limits on the values to use, either to set the y-axis
817    for 1D or to set the colormap scale for 2D.  If None, then they are
818    inferred from the data and returned. When exploring using Bumps,
819    the limits are set when the model is initially called, and maintained
820    as the values are adjusted, making it easier to see the effects of the
821    parameters.
822    """
823    limits = np.Inf, -np.Inf
824    for k in range(opts['sets']):
825        opts['pars'] = parse_pars(opts)
826        if opts['pars'] is None:
827            return
828        result = run_models(opts, verbose=True)
829        if opts['plot']:
830            limits = plot_models(opts, result, limits=limits, setnum=k)
831    if opts['plot']:
832        import matplotlib.pyplot as plt
833        plt.show()
834
835def run_models(opts, verbose=False):
836    # type: (Dict[str, Any]) -> Dict[str, Any]
837
838    base, comp = opts['engines']
839    base_n, comp_n = opts['count']
840    base_pars, comp_pars = opts['pars']
841    data = opts['data']
842
843    comparison = comp is not None
844
845    base_time = comp_time = None
846    base_value = comp_value = resid = relerr = None
847
848    # Base calculation
849    try:
850        base_raw, base_time = time_calculation(base, base_pars, base_n)
851        base_value = np.ma.masked_invalid(base_raw)
852        if verbose:
853            print("%s t=%.2f ms, intensity=%.0f"
854                  % (base.engine, base_time, base_value.sum()))
855        _show_invalid(data, base_value)
856    except ImportError:
857        traceback.print_exc()
858
859    # Comparison calculation
860    if comparison:
861        try:
862            comp_raw, comp_time = time_calculation(comp, comp_pars, comp_n)
863            comp_value = np.ma.masked_invalid(comp_raw)
864            if verbose:
865                print("%s t=%.2f ms, intensity=%.0f"
866                      % (comp.engine, comp_time, comp_value.sum()))
867            _show_invalid(data, comp_value)
868        except ImportError:
869            traceback.print_exc()
870
871    # Compare, but only if computing both forms
872    if comparison:
873        resid = (base_value - comp_value)
874        relerr = resid/np.where(comp_value != 0., abs(comp_value), 1.0)
875        if verbose:
876            _print_stats("|%s-%s|"
877                         % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
878                         resid)
879            _print_stats("|(%s-%s)/%s|"
880                         % (base.engine, comp.engine, comp.engine),
881                         relerr)
882
883    return dict(base_value=base_value, comp_value=comp_value,
884                base_time=base_time, comp_time=comp_time,
885                resid=resid, relerr=relerr)
886
887
888def _print_stats(label, err):
889    # type: (str, np.ma.ndarray) -> None
890    # work with trimmed data, not the full set
891    sorted_err = np.sort(abs(err.compressed()))
892    if len(sorted_err) == 0.:
893        print(label + "  no valid values")
894        return
895
896    p50 = int((len(sorted_err)-1)*0.50)
897    p98 = int((len(sorted_err)-1)*0.98)
898    data = [
899        "max:%.3e"%sorted_err[-1],
900        "median:%.3e"%sorted_err[p50],
901        "98%%:%.3e"%sorted_err[p98],
902        "rms:%.3e"%np.sqrt(np.mean(sorted_err**2)),
903        "zero-offset:%+.3e"%np.mean(sorted_err),
904        ]
905    print(label+"  "+"  ".join(data))
906
907
908def plot_models(opts, result, limits=(np.Inf, -np.Inf), setnum=0):
909    # type: (Dict[str, Any], Dict[str, Any], Optional[Tuple[float, float]]) -> Tuple[float, float]
910    base_value, comp_value = result['base_value'], result['comp_value']
911    base_time, comp_time = result['base_time'], result['comp_time']
912    resid, relerr = result['resid'], result['relerr']
913
914    have_base, have_comp = (base_value is not None), (comp_value is not None)
915    base, comp = opts['engines']
916    data = opts['data']
917    use_data = (opts['datafile'] is not None) and (have_base ^ have_comp)
918
919    # Plot if requested
920    view = opts['view']
921    import matplotlib.pyplot as plt
922    vmin, vmax = limits
923    if have_base:
924        vmin = min(vmin, base_value.min())
925        vmax = max(vmax, base_value.max())
926    if have_comp:
927        vmin = min(vmin, comp_value.min())
928        vmax = max(vmax, comp_value.max())
929    limits = vmin, vmax
930
931    if have_base:
932        if have_comp:
933            plt.subplot(131)
934        plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
935        plt.title("%s t=%.2f ms"%(base.engine, base_time))
936        #cbar_title = "log I"
937    if have_comp:
938        if have_base:
939            plt.subplot(132)
940        if not opts['is2d'] and have_base:
941            plot_theory(data, base_value, view=view, use_data=use_data, limits=limits)
942        plot_theory(data, comp_value, view=view, use_data=use_data, limits=limits)
943        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
944        #cbar_title = "log I"
945    if have_base and have_comp:
946        plt.subplot(133)
947        if not opts['rel_err']:
948            err, errstr, errview = resid, "abs err", "linear"
949        else:
950            err, errstr, errview = abs(relerr), "rel err", "log"
951        if 0:  # 95% cutoff
952            sorted = np.sort(err.flatten())
953            cutoff = sorted[int(sorted.size*0.95)]
954            err[err > cutoff] = cutoff
955        #err,errstr = base/comp,"ratio"
956        plot_theory(data, None, resid=err, view=errview, use_data=use_data)
957        if view == 'linear':
958            plt.xscale('linear')
959        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
960        #cbar_title = errstr if errview=="linear" else "log "+errstr
961    #if is2D:
962    #    h = plt.colorbar()
963    #    h.ax.set_title(cbar_title)
964    fig = plt.gcf()
965    extra_title = ' '+opts['title'] if opts['title'] else ''
966    fig.suptitle(":".join(opts['name']) + extra_title)
967
968    if have_base and have_comp and opts['show_hist']:
969        plt.figure()
970        v = relerr
971        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
972        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
973        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
974                   % (base.engine, comp.engine, comp.engine))
975        plt.ylabel('P(err)')
976        plt.title('Distribution of relative error between calculation engines')
977
978    return limits
979
980
981# ===========================================================================
982#
983
984# Set of command line options.
985# Normal options such as -plot/-noplot are specified as 'name'.
986# For options such as -nq=500 which require a value use 'name='.
987#
988OPTIONS = [
989    # Plotting
990    'plot', 'noplot',
991    'linear', 'log', 'q4',
992    'rel', 'abs',
993    'hist', 'nohist',
994    'title=',
995
996    # Data generation
997    'data=', 'noise=', 'res=',
998    'nq=', 'lowq', 'midq', 'highq', 'exq', 'zero',
999    '2d', '1d',
1000
1001    # Parameter set
1002    'preset', 'random', 'random=', 'sets=',
1003    'demo', 'default',  # TODO: remove demo/default
1004    'nopars', 'pars',
1005
1006    # Calculation options
1007    'poly', 'mono', 'cutoff=',
1008    'magnetic', 'nonmagnetic',
1009    'accuracy=',
1010
1011    # Precision options
1012    'calc=',
1013    'half', 'fast', 'single', 'double', 'single!', 'double!', 'quad!',
1014    'sasview',  # TODO: remove sasview 3.x support
1015    'timing=',
1016
1017    # Output options
1018    'help', 'html', 'edit',
1019    ]
1020
1021NAME_OPTIONS = set(k for k in OPTIONS if not k.endswith('='))
1022VALUE_OPTIONS = [k[:-1] for k in OPTIONS if k.endswith('=')]
1023
1024
1025def columnize(items, indent="", width=79):
1026    # type: (List[str], str, int) -> str
1027    """
1028    Format a list of strings into columns.
1029
1030    Returns a string with carriage returns ready for printing.
1031    """
1032    column_width = max(len(w) for w in items) + 1
1033    num_columns = (width - len(indent)) // column_width
1034    num_rows = len(items) // num_columns
1035    items = items + [""] * (num_rows * num_columns - len(items))
1036    columns = [items[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
1037    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
1038             for row in zip(*columns)]
1039    output = indent + ("\n"+indent).join(lines)
1040    return output
1041
1042
1043def get_pars(model_info, use_demo=False):
1044    # type: (ModelInfo, bool) -> ParameterSet
1045    """
1046    Extract demo parameters from the model definition.
1047    """
1048    # Get the default values for the parameters
1049    pars = {}
1050    for p in model_info.parameters.call_parameters:
1051        parts = [('', p.default)]
1052        if p.polydisperse:
1053            parts.append(('_pd', 0.0))
1054            parts.append(('_pd_n', 0))
1055            parts.append(('_pd_nsigma', 3.0))
1056            parts.append(('_pd_type', "gaussian"))
1057        for ext, val in parts:
1058            if p.length > 1:
1059                dict(("%s%d%s" % (p.id, k, ext), val)
1060                     for k in range(1, p.length+1))
1061            else:
1062                pars[p.id + ext] = val
1063
1064    # Plug in values given in demo
1065    if use_demo:
1066        pars.update(model_info.demo)
1067    return pars
1068
1069INTEGER_RE = re.compile("^[+-]?[1-9][0-9]*$")
1070def isnumber(str):
1071    match = FLOAT_RE.match(str)
1072    isfloat = (match and not str[match.end():])
1073    return isfloat or INTEGER_RE.match(str)
1074
1075# For distinguishing pairs of models for comparison
1076# key-value pair separator =
1077# shell characters  | & ; <> $ % ' " \ # `
1078# model and parameter names _
1079# parameter expressions - + * / . ( )
1080# path characters including tilde expansion and windows drive ~ / :
1081# not sure about brackets [] {}
1082# maybe one of the following @ ? ^ ! ,
1083PAR_SPLIT = ','
1084def parse_opts(argv):
1085    # type: (List[str]) -> Dict[str, Any]
1086    """
1087    Parse command line options.
1088    """
1089    MODELS = core.list_models()
1090    flags = [arg for arg in argv
1091             if arg.startswith('-')]
1092    values = [arg for arg in argv
1093              if not arg.startswith('-') and '=' in arg]
1094    positional_args = [arg for arg in argv
1095                       if not arg.startswith('-') and '=' not in arg]
1096    models = "\n    ".join("%-15s"%v for v in MODELS)
1097    if len(positional_args) == 0:
1098        print(USAGE)
1099        print("\nAvailable models:")
1100        print(columnize(MODELS, indent="  "))
1101        return None
1102
1103    invalid = [o[1:] for o in flags
1104               if o[1:] not in NAME_OPTIONS
1105               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
1106    if invalid:
1107        print("Invalid options: %s"%(", ".join(invalid)))
1108        return None
1109
1110    name = positional_args[-1]
1111
1112    # pylint: disable=bad-whitespace
1113    # Interpret the flags
1114    opts = {
1115        'plot'      : True,
1116        'view'      : 'log',
1117        'is2d'      : False,
1118        'qmax'      : 0.05,
1119        'nq'        : 128,
1120        'res'       : 0.0,
1121        'noise'     : 0.0,
1122        'accuracy'  : 'Low',
1123        'cutoff'    : '0.0',
1124        'seed'      : -1,  # default to preset
1125        'mono'      : True,
1126        # Default to magnetic a magnetic moment is set on the command line
1127        'magnetic'  : False,
1128        'show_pars' : False,
1129        'show_hist' : False,
1130        'rel_err'   : True,
1131        'explore'   : False,
1132        'use_demo'  : True,
1133        'zero'      : False,
1134        'html'      : False,
1135        'title'     : None,
1136        'datafile'  : None,
1137        'sets'      : 1,
1138        'engine'    : 'default',
1139        'evals'     : '1',
1140    }
1141    for arg in flags:
1142        if arg == '-noplot':    opts['plot'] = False
1143        elif arg == '-plot':    opts['plot'] = True
1144        elif arg == '-linear':  opts['view'] = 'linear'
1145        elif arg == '-log':     opts['view'] = 'log'
1146        elif arg == '-q4':      opts['view'] = 'q4'
1147        elif arg == '-1d':      opts['is2d'] = False
1148        elif arg == '-2d':      opts['is2d'] = True
1149        elif arg == '-exq':     opts['qmax'] = 10.0
1150        elif arg == '-highq':   opts['qmax'] = 1.0
1151        elif arg == '-midq':    opts['qmax'] = 0.2
1152        elif arg == '-lowq':    opts['qmax'] = 0.05
1153        elif arg == '-zero':    opts['zero'] = True
1154        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
1155        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
1156        elif arg.startswith('-noise='):    opts['noise'] = float(arg[7:])
1157        elif arg.startswith('-sets='):     opts['sets'] = int(arg[6:])
1158        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
1159        elif arg.startswith('-cutoff='):   opts['cutoff'] = arg[8:]
1160        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
1161        elif arg.startswith('-title='):    opts['title'] = arg[7:]
1162        elif arg.startswith('-data='):     opts['datafile'] = arg[6:]
1163        elif arg.startswith('-calc='):     opts['engine'] = arg[6:]
1164        elif arg.startswith('-neval='):    opts['evals'] = arg[7:]
1165        elif arg == '-random':  opts['seed'] = np.random.randint(1000000)
1166        elif arg == '-preset':  opts['seed'] = -1
1167        elif arg == '-mono':    opts['mono'] = True
1168        elif arg == '-poly':    opts['mono'] = False
1169        elif arg == '-magnetic':       opts['magnetic'] = True
1170        elif arg == '-nonmagnetic':    opts['magnetic'] = False
1171        elif arg == '-pars':    opts['show_pars'] = True
1172        elif arg == '-nopars':  opts['show_pars'] = False
1173        elif arg == '-hist':    opts['show_hist'] = True
1174        elif arg == '-nohist':  opts['show_hist'] = False
1175        elif arg == '-rel':     opts['rel_err'] = True
1176        elif arg == '-abs':     opts['rel_err'] = False
1177        elif arg == '-half':    opts['engine'] = 'half'
1178        elif arg == '-fast':    opts['engine'] = 'fast'
1179        elif arg == '-single':  opts['engine'] = 'single'
1180        elif arg == '-double':  opts['engine'] = 'double'
1181        elif arg == '-single!': opts['engine'] = 'single!'
1182        elif arg == '-double!': opts['engine'] = 'double!'
1183        elif arg == '-quad!':   opts['engine'] = 'quad!'
1184        elif arg == '-sasview': opts['engine'] = 'sasview'
1185        elif arg == '-edit':    opts['explore'] = True
1186        elif arg == '-demo':    opts['use_demo'] = True
1187        elif arg == '-default': opts['use_demo'] = False
1188        elif arg == '-html':    opts['html'] = True
1189        elif arg == '-help':    opts['html'] = True
1190    # pylint: enable=bad-whitespace
1191
1192    # Magnetism forces 2D for now
1193    if opts['magnetic']:
1194        opts['is2d'] = True
1195
1196    # Force random if more than one set
1197    if opts['sets'] > 1 and opts['seed'] < 0:
1198        opts['seed'] = np.random.randint(1000000)
1199
1200    # Create the computational engines
1201    if opts['datafile'] is not None:
1202        data = load_data(os.path.expanduser(opts['datafile']))
1203    else:
1204        data, _ = make_data(opts)
1205
1206    comparison = any(PAR_SPLIT in v for v in values)
1207    if PAR_SPLIT in name:
1208        names = name.split(PAR_SPLIT, 2)
1209        comparison = True
1210    else:
1211        names = [name]*2
1212    try:
1213        model_info = [core.load_model_info(k) for k in names]
1214    except ImportError as exc:
1215        print(str(exc))
1216        print("Could not find model; use one of:\n    " + models)
1217        return None
1218
1219    if PAR_SPLIT in opts['engine']:
1220        engine_types = opts['engine'].split(PAR_SPLIT, 2)
1221        comparison = True
1222    else:
1223        engine_types = [opts['engine']]*2
1224
1225    if PAR_SPLIT in opts['evals']:
1226        evals = [int(k) for k in opts['evals'].split(PAR_SPLIT, 2)]
1227        comparison = True
1228    else:
1229        evals = [int(opts['evals'])]*2
1230
1231    if PAR_SPLIT in opts['cutoff']:
1232        cutoff = [float(k) for k in opts['cutoff'].split(PAR_SPLIT, 2)]
1233        comparison = True
1234    else:
1235        cutoff = [float(opts['cutoff'])]*2
1236
1237    base = make_engine(model_info[0], data, engine_types[0], cutoff[0])
1238    if comparison:
1239        comp = make_engine(model_info[1], data, engine_types[1], cutoff[1])
1240    else:
1241        comp = None
1242
1243    # pylint: disable=bad-whitespace
1244    # Remember it all
1245    opts.update({
1246        'data'      : data,
1247        'name'      : names,
1248        'def'       : model_info,
1249        'count'     : evals,
1250        'engines'   : [base, comp],
1251        'values'    : values,
1252    })
1253    # pylint: enable=bad-whitespace
1254
1255    return opts
1256
1257def parse_pars(opts):
1258    model_info, model_info2 = opts['def']
1259
1260    # Get demo parameters from model definition, or use default parameters
1261    # if model does not define demo parameters
1262    pars = get_pars(model_info, opts['use_demo'])
1263    pars2 = get_pars(model_info2, opts['use_demo'])
1264    pars2.update((k, v) for k, v in pars.items() if k in pars2)
1265    # randomize parameters
1266    #pars.update(set_pars)  # set value before random to control range
1267    if opts['seed'] > -1:
1268        pars = randomize_pars(model_info, pars)
1269        if model_info != model_info2:
1270            pars2 = randomize_pars(model_info2, pars2)
1271            # Share values for parameters with the same name
1272            for k, v in pars.items():
1273                if k in pars2:
1274                    pars2[k] = v
1275        else:
1276            pars2 = pars.copy()
1277        constrain_pars(model_info, pars)
1278        constrain_pars(model_info2, pars2)
1279    pars = suppress_pd(pars, opts['mono'])
1280    pars2 = suppress_pd(pars2, opts['mono'])
1281    pars = suppress_magnetism(pars, not opts['magnetic'])
1282    pars2 = suppress_magnetism(pars2, not opts['magnetic'])
1283
1284    # Fill in parameters given on the command line
1285    presets = {}
1286    presets2 = {}
1287    for arg in opts['values']:
1288        k, v = arg.split('=', 1)
1289        if k not in pars and k not in pars2:
1290            # extract base name without polydispersity info
1291            s = set(p.split('_pd')[0] for p in pars)
1292            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
1293            return None
1294        v1, v2 = v.split(PAR_SPLIT, 2) if PAR_SPLIT in v else (v,v)
1295        if v1 and k in pars:
1296            presets[k] = float(v1) if isnumber(v1) else v1
1297        if v2 and k in pars2:
1298            presets2[k] = float(v2) if isnumber(v2) else v2
1299
1300    # If pd given on the command line, default pd_n to 35
1301    for k, v in list(presets.items()):
1302        if k.endswith('_pd'):
1303            presets.setdefault(k+'_n', 35.)
1304    for k, v in list(presets2.items()):
1305        if k.endswith('_pd'):
1306            presets2.setdefault(k+'_n', 35.)
1307
1308    # Evaluate preset parameter expressions
1309    context = MATH.copy()
1310    context['np'] = np
1311    context.update(pars)
1312    context.update((k, v) for k, v in presets.items() if isinstance(v, float))
1313    for k, v in presets.items():
1314        if not isinstance(v, float) and not k.endswith('_type'):
1315            presets[k] = eval(v, context)
1316    context.update(presets)
1317    context.update((k, v) for k, v in presets2.items() if isinstance(v, float))
1318    for k, v in presets2.items():
1319        if not isinstance(v, float) and not k.endswith('_type'):
1320            presets2[k] = eval(v, context)
1321
1322    # update parameters with presets
1323    pars.update(presets)  # set value after random to control value
1324    pars2.update(presets2)  # set value after random to control value
1325    #import pprint; pprint.pprint(model_info)
1326
1327    if opts['show_pars']:
1328        if model_info.name != model_info2.name or pars != pars2:
1329            print("==== %s ====="%model_info.name)
1330            print(str(parlist(model_info, pars, opts['is2d'])))
1331            print("==== %s ====="%model_info2.name)
1332            print(str(parlist(model_info2, pars2, opts['is2d'])))
1333        else:
1334            print(str(parlist(model_info, pars, opts['is2d'])))
1335
1336    return pars, pars2
1337
1338def show_docs(opts):
1339    # type: (Dict[str, Any]) -> None
1340    """
1341    show html docs for the model
1342    """
1343    import os
1344    from .generate import make_html
1345    from . import rst2html
1346
1347    info = opts['def'][0]
1348    html = make_html(info)
1349    path = os.path.dirname(info.filename)
1350    url = "file://"+path.replace("\\","/")[2:]+"/"
1351    rst2html.view_html_qtapp(html, url)
1352
1353def explore(opts):
1354    # type: (Dict[str, Any]) -> None
1355    """
1356    explore the model using the bumps gui.
1357    """
1358    import wx  # type: ignore
1359    from bumps.names import FitProblem  # type: ignore
1360    from bumps.gui.app_frame import AppFrame  # type: ignore
1361    from bumps.gui import signal
1362
1363    is_mac = "cocoa" in wx.version()
1364    # Create an app if not running embedded
1365    app = wx.App() if wx.GetApp() is None else None
1366    model = Explore(opts)
1367    problem = FitProblem(model)
1368    frame = AppFrame(parent=None, title="explore", size=(1000, 700))
1369    if not is_mac:
1370        frame.Show()
1371    frame.panel.set_model(model=problem)
1372    frame.panel.Layout()
1373    frame.panel.aui.Split(0, wx.TOP)
1374    def reset_parameters(event):
1375        model.revert_values()
1376        signal.update_parameters(problem)
1377    frame.Bind(wx.EVT_TOOL, reset_parameters, frame.ToolBar.GetToolByPos(1))
1378    if is_mac: frame.Show()
1379    # If running withing an app, start the main loop
1380    if app:
1381        app.MainLoop()
1382
1383class Explore(object):
1384    """
1385    Bumps wrapper for a SAS model comparison.
1386
1387    The resulting object can be used as a Bumps fit problem so that
1388    parameters can be adjusted in the GUI, with plots updated on the fly.
1389    """
1390    def __init__(self, opts):
1391        # type: (Dict[str, Any]) -> None
1392        from bumps.cli import config_matplotlib  # type: ignore
1393        from . import bumps_model
1394        config_matplotlib()
1395        self.opts = opts
1396        opts['pars'] = list(opts['pars'])
1397        p1, p2 = opts['pars']
1398        m1, m2 = opts['def']
1399        self.fix_p2 = m1 != m2 or p1 != p2
1400        model_info = m1
1401        pars, pd_types = bumps_model.create_parameters(model_info, **p1)
1402        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
1403        if not opts['is2d']:
1404            for p in model_info.parameters.user_parameters({}, is2d=False):
1405                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
1406                    k = p.name+ext
1407                    v = pars.get(k, None)
1408                    if v is not None:
1409                        v.range(*parameter_range(k, v.value))
1410        else:
1411            for k, v in pars.items():
1412                v.range(*parameter_range(k, v.value))
1413
1414        self.pars = pars
1415        self.starting_values = dict((k, v.value) for k, v in pars.items())
1416        self.pd_types = pd_types
1417        self.limits = np.Inf, -np.Inf
1418
1419    def revert_values(self):
1420        for k, v in self.starting_values.items():
1421            self.pars[k].value = v
1422
1423    def model_update(self):
1424        pass
1425
1426    def numpoints(self):
1427        # type: () -> int
1428        """
1429        Return the number of points.
1430        """
1431        return len(self.pars) + 1  # so dof is 1
1432
1433    def parameters(self):
1434        # type: () -> Any   # Dict/List hierarchy of parameters
1435        """
1436        Return a dictionary of parameters.
1437        """
1438        return self.pars
1439
1440    def nllf(self):
1441        # type: () -> float
1442        """
1443        Return cost.
1444        """
1445        # pylint: disable=no-self-use
1446        return 0.  # No nllf
1447
1448    def plot(self, view='log'):
1449        # type: (str) -> None
1450        """
1451        Plot the data and residuals.
1452        """
1453        pars = dict((k, v.value) for k, v in self.pars.items())
1454        pars.update(self.pd_types)
1455        self.opts['pars'][0] = pars
1456        if not self.fix_p2:
1457            self.opts['pars'][1] = pars
1458        result = run_models(self.opts)
1459        limits = plot_models(self.opts, result, limits=self.limits)
1460        if self.limits is None:
1461            vmin, vmax = limits
1462            self.limits = vmax*1e-7, 1.3*vmax
1463            import pylab; pylab.clf()
1464            plot_models(self.opts, result, limits=self.limits)
1465
1466
1467def main(*argv):
1468    # type: (*str) -> None
1469    """
1470    Main program.
1471    """
1472    opts = parse_opts(argv)
1473    if opts is not None:
1474        if opts['seed'] > -1:
1475            print("Randomize using -random=%i"%opts['seed'])
1476            np.random.seed(opts['seed'])
1477        if opts['html']:
1478            show_docs(opts)
1479        elif opts['explore']:
1480            opts['pars'] = parse_pars(opts)
1481            if opts['pars'] is None:
1482                return
1483            explore(opts)
1484        else:
1485            compare(opts)
1486
1487if __name__ == "__main__":
1488    main(*sys.argv[1:])
Note: See TracBrowser for help on using the repository browser.