source: sasmodels/sasmodels/compare.py @ d5e650d

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since d5e650d was d5e650d, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix option for plotting absolute rather than relative error in compare

  • Property mode set to 100755
File size: 29.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33from os.path import basename, dirname, join as joinpath
34import glob
35import datetime
36import traceback
37
38import numpy as np
39
40from . import core
41from . import kerneldll
42from . import generate
43from .data import plot_theory, empty_data1D, empty_data2D
44from .direct_model import DirectModel
45from .convert import revert_model, constrain_new_to_old
46
47USAGE = """
48usage: compare.py model N1 N2 [options...] [key=val]
49
50Compare the speed and value for a model between the SasView original and the
51sasmodels rewrite.
52
53model is the name of the model to compare (see below).
54N1 is the number of times to run sasmodels (default=1).
55N2 is the number times to run sasview (default=1).
56
57Options (* for default):
58
59    -plot*/-noplot plots or suppress the plot of the model
60    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
61    -nq=128 sets the number of Q points in the data set
62    -1d*/-2d computes 1d or 2d data
63    -preset*/-random[=seed] preset or random parameters
64    -mono/-poly* force monodisperse/polydisperse
65    -cutoff=1e-5* cutoff value for including a point in polydispersity
66    -pars/-nopars* prints the parameter set or not
67    -abs/-rel* plot relative or absolute error
68    -linear/-log*/-q4 intensity scaling
69    -hist/-nohist* plot histogram of relative error
70    -res=0 sets the resolution width dQ/Q if calculating with resolution
71    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
72    -edit starts the parameter explorer
73
74Any two calculation engines can be selected for comparison:
75
76    -single/-double/-half/-fast sets an OpenCL calculation engine
77    -single!/-double!/-quad! sets an OpenMP calculation engine
78    -sasview sets the sasview calculation engine
79
80The default is -single -sasview.  Note that the interpretation of quad
81precision depends on architecture, and may vary from 64-bit to 128-bit,
82with 80-bit floats being common (1e-19 precision).
83
84Key=value pairs allow you to set specific values for the model parameters.
85"""
86
87# Update docs with command line usage string.   This is separate from the usual
88# doc string so that we can display it at run time if there is an error.
89# lin
90__doc__ = (__doc__  # pylint: disable=redefined-builtin
91           + """
92Program description
93-------------------
94
95"""
96           + USAGE)
97
98kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
99
100# List of available models
101ROOT = dirname(__file__)
102MODELS = [basename(f)[:-3]
103          for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))]
104
105# CRUFT python 2.6
106if not hasattr(datetime.timedelta, 'total_seconds'):
107    def delay(dt):
108        """Return number date-time delta as number seconds"""
109        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
110else:
111    def delay(dt):
112        """Return number date-time delta as number seconds"""
113        return dt.total_seconds()
114
115
116class push_seed(object):
117    """
118    Set the seed value for the random number generator.
119
120    When used in a with statement, the random number generator state is
121    restored after the with statement is complete.
122
123    :Parameters:
124
125    *seed* : int or array_like, optional
126        Seed for RandomState
127
128    :Example:
129
130    Seed can be used directly to set the seed::
131
132        >>> from numpy.random import randint
133        >>> push_seed(24)
134        <...push_seed object at...>
135        >>> print(randint(0,1000000,3))
136        [242082    899 211136]
137
138    Seed can also be used in a with statement, which sets the random
139    number generator state for the enclosed computations and restores
140    it to the previous state on completion::
141
142        >>> with push_seed(24):
143        ...    print(randint(0,1000000,3))
144        [242082    899 211136]
145
146    Using nested contexts, we can demonstrate that state is indeed
147    restored after the block completes::
148
149        >>> with push_seed(24):
150        ...    print(randint(0,1000000))
151        ...    with push_seed(24):
152        ...        print(randint(0,1000000,3))
153        ...    print(randint(0,1000000))
154        242082
155        [242082    899 211136]
156        899
157
158    The restore step is protected against exceptions in the block::
159
160        >>> with push_seed(24):
161        ...    print(randint(0,1000000))
162        ...    try:
163        ...        with push_seed(24):
164        ...            print(randint(0,1000000,3))
165        ...            raise Exception()
166        ...    except:
167        ...        print("Exception raised")
168        ...    print(randint(0,1000000))
169        242082
170        [242082    899 211136]
171        Exception raised
172        899
173    """
174    def __init__(self, seed=None):
175        self._state = np.random.get_state()
176        np.random.seed(seed)
177
178    def __enter__(self):
179        return None
180
181    def __exit__(self, *args):
182        np.random.set_state(self._state)
183
184def tic():
185    """
186    Timer function.
187
188    Use "toc=tic()" to start the clock and "toc()" to measure
189    a time interval.
190    """
191    then = datetime.datetime.now()
192    return lambda: delay(datetime.datetime.now() - then)
193
194
195def set_beam_stop(data, radius, outer=None):
196    """
197    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
198
199    Note: this function does not use the sasview package
200    """
201    if hasattr(data, 'qx_data'):
202        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
203        data.mask = (q < radius)
204        if outer is not None:
205            data.mask |= (q >= outer)
206    else:
207        data.mask = (data.x < radius)
208        if outer is not None:
209            data.mask |= (data.x >= outer)
210
211
212def parameter_range(p, v):
213    """
214    Choose a parameter range based on parameter name and initial value.
215    """
216    if p.endswith('_pd_n'):
217        return [0, 100]
218    elif p.endswith('_pd_nsigma'):
219        return [0, 5]
220    elif p.endswith('_pd_type'):
221        return v
222    elif any(s in p for s in ('theta', 'phi', 'psi')):
223        # orientation in [-180,180], orientation pd in [0,45]
224        if p.endswith('_pd'):
225            return [0, 45]
226        else:
227            return [-180, 180]
228    elif 'sld' in p:
229        return [-0.5, 10]
230    elif p.endswith('_pd'):
231        return [0, 1]
232    elif p == 'background':
233        return [0, 10]
234    elif p == 'scale':
235        return [0, 1e3]
236    elif p == 'case_num':
237        # RPA hack
238        return [0, 10]
239    elif v < 0:
240        # Kxy parameters in rpa model can be negative
241        return [2*v, -2*v]
242    else:
243        return [0, (2*v if v > 0 else 1)]
244
245
246def _randomize_one(p, v):
247    """
248    Randomize a single parameter.
249    """
250    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
251        return v
252    else:
253        return np.random.uniform(*parameter_range(p, v))
254
255
256def randomize_pars(pars, seed=None):
257    """
258    Generate random values for all of the parameters.
259
260    Valid ranges for the random number generator are guessed from the name of
261    the parameter; this will not account for constraints such as cap radius
262    greater than cylinder radius in the capped_cylinder model, so
263    :func:`constrain_pars` needs to be called afterward..
264    """
265    with push_seed(seed):
266        # Note: the sort guarantees order `of calls to random number generator
267        pars = dict((p, _randomize_one(p, v))
268                    for p, v in sorted(pars.items()))
269    return pars
270
271def constrain_pars(model_definition, pars):
272    """
273    Restrict parameters to valid values.
274
275    This includes model specific code for models such as capped_cylinder
276    which need to support within model constraints (cap radius more than
277    cylinder radius in this case).
278    """
279    name = model_definition.name
280    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
281        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
282    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
283        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
284
285    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
286    if name == 'guinier':
287        #q_max = 0.2  # mid q maximum
288        q_max = 1.0  # high q maximum
289        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
290        pars['rg'] = min(pars['rg'], rg_max)
291
292    if name == 'rpa':
293        # Make sure phi sums to 1.0
294        if pars['case_num'] < 2:
295            pars['Phia'] = 0.
296            pars['Phib'] = 0.
297        elif pars['case_num'] < 5:
298            pars['Phia'] = 0.
299        total = sum(pars['Phi'+c] for c in 'abcd')
300        for c in 'abcd':
301            pars['Phi'+c] /= total
302
303def parlist(pars):
304    """
305    Format the parameter list for printing.
306    """
307    return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
308
309def suppress_pd(pars):
310    """
311    Suppress theta_pd for now until the normalization is resolved.
312
313    May also suppress complete polydispersity of the model to test
314    models more quickly.
315    """
316    pars = pars.copy()
317    for p in pars:
318        if p.endswith("_pd_n"): pars[p] = 0
319    return pars
320
321def eval_sasview(model_definition, data):
322    """
323    Return a model calculator using the SasView fitting engine.
324    """
325    # importing sas here so that the error message will be that sas failed to
326    # import rather than the more obscure smear_selection not imported error
327    import sas
328    from sas.models.qsmearing import smear_selection
329
330    # convert model parameters from sasmodel form to sasview form
331    #print("old",sorted(pars.items()))
332    modelname, _ = revert_model(model_definition, {})
333    #print("new",sorted(_pars.items()))
334    sas = __import__('sas.models.'+modelname)
335    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None)
336    if ModelClass is None:
337        raise ValueError("could not find model %r in sas.models"%modelname)
338    model = ModelClass()
339    smearer = smear_selection(data, model=model)
340
341    if hasattr(data, 'qx_data'):
342        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
343        index = ((~data.mask) & (~np.isnan(data.data))
344                 & (q >= data.qmin) & (q <= data.qmax))
345        if smearer is not None:
346            smearer.model = model  # because smear_selection has a bug
347            smearer.accuracy = data.accuracy
348            smearer.set_index(index)
349            theory = lambda: smearer.get_value()
350        else:
351            theory = lambda: model.evalDistribution([data.qx_data[index],
352                                                     data.qy_data[index]])
353    elif smearer is not None:
354        theory = lambda: smearer(model.evalDistribution(data.x))
355    else:
356        theory = lambda: model.evalDistribution(data.x)
357
358    def calculator(**pars):
359        """
360        Sasview calculator for model.
361        """
362        # paying for parameter conversion each time to keep life simple, if not fast
363        _, pars = revert_model(model_definition, pars)
364        for k, v in pars.items():
365            parts = k.split('.')  # polydispersity components
366            if len(parts) == 2:
367                model.dispersion[parts[0]][parts[1]] = v
368            else:
369                model.setParam(k, v)
370        return theory()
371
372    calculator.engine = "sasview"
373    return calculator
374
375DTYPE_MAP = {
376    'half': '16',
377    'fast': 'fast',
378    'single': '32',
379    'double': '64',
380    'quad': '128',
381    'f16': '16',
382    'f32': '32',
383    'f64': '64',
384    'longdouble': '128',
385}
386def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
387    """
388    Return a model calculator using the OpenCL calculation engine.
389    """
390    try:
391        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
392    except Exception as exc:
393        print(exc)
394        print("... trying again with single precision")
395        dtype = 'single'
396        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
397    calculator = DirectModel(data, model, cutoff=cutoff)
398    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
399    return calculator
400
401def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
402    """
403    Return a model calculator using the DLL calculation engine.
404    """
405    if dtype == 'quad':
406        dtype = 'longdouble'
407    model = core.load_model(model_definition, dtype=dtype, platform="dll")
408    calculator = DirectModel(data, model, cutoff=cutoff)
409    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
410    return calculator
411
412def time_calculation(calculator, pars, Nevals=1):
413    """
414    Compute the average calculation time over N evaluations.
415
416    An additional call is generated without polydispersity in order to
417    initialize the calculation engine, and make the average more stable.
418    """
419    # initialize the code so time is more accurate
420    value = calculator(**suppress_pd(pars))
421    toc = tic()
422    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
423        value = calculator(**pars)
424    average_time = toc()*1000./Nevals
425    return value, average_time
426
427def make_data(opts):
428    """
429    Generate an empty dataset, used with the model to set Q points
430    and resolution.
431
432    *opts* contains the options, with 'qmax', 'nq', 'res',
433    'accuracy', 'is2d' and 'view' parsed from the command line.
434    """
435    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
436    if opts['is2d']:
437        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
438        data.accuracy = opts['accuracy']
439        set_beam_stop(data, 0.004)
440        index = ~data.mask
441    else:
442        if opts['view'] == 'log':
443            qmax = math.log10(qmax)
444            q = np.logspace(qmax-3, qmax, nq)
445        else:
446            q = np.linspace(0.001*qmax, qmax, nq)
447        data = empty_data1D(q, resolution=res)
448        index = slice(None, None)
449    return data, index
450
451def make_engine(model_definition, data, dtype, cutoff):
452    """
453    Generate the appropriate calculation engine for the given datatype.
454
455    Datatypes with '!' appended are evaluated using external C DLLs rather
456    than OpenCL.
457    """
458    if dtype == 'sasview':
459        return eval_sasview(model_definition, data)
460    elif dtype.endswith('!'):
461        return eval_ctypes(model_definition, data, dtype=dtype[:-1],
462                           cutoff=cutoff)
463    else:
464        return eval_opencl(model_definition, data, dtype=dtype,
465                           cutoff=cutoff)
466
467def compare(opts, limits=None):
468    """
469    Preform a comparison using options from the command line.
470
471    *limits* are the limits on the values to use, either to set the y-axis
472    for 1D or to set the colormap scale for 2D.  If None, then they are
473    inferred from the data and returned. When exploring using Bumps,
474    the limits are set when the model is initially called, and maintained
475    as the values are adjusted, making it easier to see the effects of the
476    parameters.
477    """
478    Nbase, Ncomp = opts['n1'], opts['n2']
479    pars = opts['pars']
480    data = opts['data']
481
482    # Base calculation
483    if Nbase > 0:
484        base = opts['engines'][0]
485        try:
486            base_value, base_time = time_calculation(base, pars, Nbase)
487            print("%s t=%.1f ms, intensity=%.0f"
488                  % (base.engine, base_time, sum(base_value)))
489        except ImportError:
490            traceback.print_exc()
491            Nbase = 0
492
493    # Comparison calculation
494    if Ncomp > 0:
495        comp = opts['engines'][1]
496        try:
497            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
498            print("%s t=%.1f ms, intensity=%.0f"
499                  % (comp.engine, comp_time, sum(comp_value)))
500        except ImportError:
501            traceback.print_exc()
502            Ncomp = 0
503
504    # Compare, but only if computing both forms
505    if Nbase > 0 and Ncomp > 0:
506        resid = (base_value - comp_value)
507        relerr = resid/comp_value
508        _print_stats("|%s-%s|"
509                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
510                     resid)
511        _print_stats("|(%s-%s)/%s|"
512                     % (base.engine, comp.engine, comp.engine),
513                     relerr)
514
515    # Plot if requested
516    if not opts['plot'] and not opts['explore']: return
517    view = opts['view']
518    import matplotlib.pyplot as plt
519    if limits is None:
520        vmin, vmax = np.Inf, -np.Inf
521        if Nbase > 0:
522            vmin = min(vmin, min(base_value))
523            vmax = max(vmax, max(base_value))
524        if Ncomp > 0:
525            vmin = min(vmin, min(comp_value))
526            vmax = max(vmax, max(comp_value))
527        limits = vmin, vmax
528
529    if Nbase > 0:
530        if Ncomp > 0: plt.subplot(131)
531        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
532        plt.title("%s t=%.1f ms"%(base.engine, base_time))
533        #cbar_title = "log I"
534    if Ncomp > 0:
535        if Nbase > 0: plt.subplot(132)
536        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
537        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
538        #cbar_title = "log I"
539    if Ncomp > 0 and Nbase > 0:
540        plt.subplot(133)
541        if not opts['rel_err']:
542            err, errstr, errview = resid, "abs err", "linear"
543        else:
544            err, errstr, errview = abs(relerr), "rel err", "log"
545        #err,errstr = base/comp,"ratio"
546        plot_theory(data, None, resid=err, view=errview, use_data=False)
547        if view == 'linear':
548            plt.xscale('linear')
549        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
550        #cbar_title = errstr if errview=="linear" else "log "+errstr
551    #if is2D:
552    #    h = plt.colorbar()
553    #    h.ax.set_title(cbar_title)
554
555    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
556        plt.figure()
557        v = relerr
558        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
559        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
560        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
561                   % (base.engine, comp.engine, comp.engine))
562        plt.ylabel('P(err)')
563        plt.title('Distribution of relative error between calculation engines')
564
565    if not opts['explore']:
566        plt.show()
567
568    return limits
569
570def _print_stats(label, err):
571    sorted_err = np.sort(abs(err))
572    p50 = int((len(err)-1)*0.50)
573    p98 = int((len(err)-1)*0.98)
574    data = [
575        "max:%.3e"%sorted_err[-1],
576        "median:%.3e"%sorted_err[p50],
577        "98%%:%.3e"%sorted_err[p98],
578        "rms:%.3e"%np.sqrt(np.mean(err**2)),
579        "zero-offset:%+.3e"%np.mean(err),
580        ]
581    print(label+"  "+"  ".join(data))
582
583
584
585# ===========================================================================
586#
587NAME_OPTIONS = set([
588    'plot', 'noplot',
589    'half', 'fast', 'single', 'double',
590    'single!', 'double!', 'quad!', 'sasview',
591    'lowq', 'midq', 'highq', 'exq',
592    '2d', '1d',
593    'preset', 'random',
594    'poly', 'mono',
595    'nopars', 'pars',
596    'rel', 'abs',
597    'linear', 'log', 'q4',
598    'hist', 'nohist',
599    'edit',
600    ])
601VALUE_OPTIONS = [
602    # Note: random is both a name option and a value option
603    'cutoff', 'random', 'nq', 'res', 'accuracy',
604    ]
605
606def columnize(L, indent="", width=79):
607    """
608    Format a list of strings into columns.
609
610    Returns a string with carriage returns ready for printing.
611    """
612    column_width = max(len(w) for w in L) + 1
613    num_columns = (width - len(indent)) // column_width
614    num_rows = len(L) // num_columns
615    L = L + [""] * (num_rows*num_columns - len(L))
616    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
617    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
618             for row in zip(*columns)]
619    output = indent + ("\n"+indent).join(lines)
620    return output
621
622
623def get_demo_pars(model_definition):
624    """
625    Extract demo parameters from the model definition.
626    """
627    info = generate.make_info(model_definition)
628    # Get the default values for the parameters
629    pars = dict((p[0], p[2]) for p in info['parameters'])
630
631    # Fill in default values for the polydispersity parameters
632    for p in info['parameters']:
633        if p[4] in ('volume', 'orientation'):
634            pars[p[0]+'_pd'] = 0.0
635            pars[p[0]+'_pd_n'] = 0
636            pars[p[0]+'_pd_nsigma'] = 3.0
637            pars[p[0]+'_pd_type'] = "gaussian"
638
639    # Plug in values given in demo
640    pars.update(info['demo'])
641    return pars
642
643def parse_opts():
644    """
645    Parse command line options.
646    """
647    flags = [arg for arg in sys.argv[1:]
648             if arg.startswith('-')]
649    values = [arg for arg in sys.argv[1:]
650              if not arg.startswith('-') and '=' in arg]
651    args = [arg for arg in sys.argv[1:]
652            if not arg.startswith('-') and '=' not in arg]
653    models = "\n    ".join("%-15s"%v for v in MODELS)
654    if len(args) == 0:
655        print(USAGE)
656        print("\nAvailable models:")
657        print(columnize(MODELS, indent="  "))
658        sys.exit(1)
659    if args[0] not in MODELS:
660        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
661        sys.exit(1)
662    if len(args) > 3:
663        print("expected parameters: model N1 N2")
664
665    invalid = [o[1:] for o in flags
666               if o[1:] not in NAME_OPTIONS
667               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
668    if invalid:
669        print("Invalid options: %s"%(", ".join(invalid)))
670        sys.exit(1)
671
672
673    # pylint: disable=bad-whitespace
674    # Interpret the flags
675    opts = {
676        'plot'      : True,
677        'view'      : 'log',
678        'is2d'      : False,
679        'qmax'      : 0.05,
680        'nq'        : 128,
681        'res'       : 0.0,
682        'accuracy'  : 'Low',
683        'cutoff'    : 1e-5,
684        'seed'      : -1,  # default to preset
685        'mono'      : False,
686        'show_pars' : False,
687        'show_hist' : False,
688        'rel_err'   : True,
689        'explore'   : False,
690    }
691    engines = []
692    for arg in flags:
693        if arg == '-noplot':    opts['plot'] = False
694        elif arg == '-plot':    opts['plot'] = True
695        elif arg == '-linear':  opts['view'] = 'linear'
696        elif arg == '-log':     opts['view'] = 'log'
697        elif arg == '-q4':      opts['view'] = 'q4'
698        elif arg == '-1d':      opts['is2d'] = False
699        elif arg == '-2d':      opts['is2d'] = True
700        elif arg == '-exq':     opts['qmax'] = 10.0
701        elif arg == '-highq':   opts['qmax'] = 1.0
702        elif arg == '-midq':    opts['qmax'] = 0.2
703        elif arg == '-loq':     opts['qmax'] = 0.05
704        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
705        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
706        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
707        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
708        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
709        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
710        elif arg == '-preset':  opts['seed'] = -1
711        elif arg == '-mono':    opts['mono'] = True
712        elif arg == '-poly':    opts['mono'] = False
713        elif arg == '-pars':    opts['show_pars'] = True
714        elif arg == '-nopars':  opts['show_pars'] = False
715        elif arg == '-hist':    opts['show_hist'] = True
716        elif arg == '-nohist':  opts['show_hist'] = False
717        elif arg == '-rel':     opts['rel_err'] = True
718        elif arg == '-abs':     opts['rel_err'] = False
719        elif arg == '-half':    engines.append(arg[1:])
720        elif arg == '-fast':    engines.append(arg[1:])
721        elif arg == '-single':  engines.append(arg[1:])
722        elif arg == '-double':  engines.append(arg[1:])
723        elif arg == '-single!': engines.append(arg[1:])
724        elif arg == '-double!': engines.append(arg[1:])
725        elif arg == '-quad!':   engines.append(arg[1:])
726        elif arg == '-sasview': engines.append(arg[1:])
727        elif arg == '-edit':    opts['explore'] = True
728    # pylint: enable=bad-whitespace
729
730    if len(engines) == 0:
731        engines.extend(['single', 'sasview'])
732    elif len(engines) == 1:
733        if engines[0][0] != 'sasview':
734            engines.append('sasview')
735        else:
736            engines.append('single')
737    elif len(engines) > 2:
738        del engines[2:]
739
740    name = args[0]
741    model_definition = core.load_model_definition(name)
742
743    n1 = int(args[1]) if len(args) > 1 else 1
744    n2 = int(args[2]) if len(args) > 2 else 1
745
746    # Get demo parameters from model definition, or use default parameters
747    # if model does not define demo parameters
748    pars = get_demo_pars(model_definition)
749
750    # Fill in parameters given on the command line
751    presets = {}
752    for arg in values:
753        k, v = arg.split('=', 1)
754        if k not in pars:
755            # extract base name without polydispersity info
756            s = set(p.split('_pd')[0] for p in pars)
757            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
758            sys.exit(1)
759        presets[k] = float(v) if not k.endswith('type') else v
760
761    # randomize parameters
762    #pars.update(set_pars)  # set value before random to control range
763    if opts['seed'] > -1:
764        pars = randomize_pars(pars, seed=opts['seed'])
765        print("Randomize using -random=%i"%opts['seed'])
766    if opts['mono']:
767        pars = suppress_pd(pars)
768    pars.update(presets)  # set value after random to control value
769    constrain_pars(model_definition, pars)
770    constrain_new_to_old(model_definition, pars)
771    if opts['show_pars']:
772        print("pars " + str(parlist(pars)))
773
774    # Create the computational engines
775    data, _ = make_data(opts)
776    if n1:
777        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
778    else:
779        base = None
780    if n2:
781        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
782    else:
783        comp = None
784
785    # pylint: disable=bad-whitespace
786    # Remember it all
787    opts.update({
788        'name'      : name,
789        'def'       : model_definition,
790        'n1'        : n1,
791        'n2'        : n2,
792        'presets'   : presets,
793        'pars'      : pars,
794        'data'      : data,
795        'engines'   : [base, comp],
796    })
797    # pylint: enable=bad-whitespace
798
799    return opts
800
801def explore(opts):
802    """
803    Explore the model using the Bumps GUI.
804    """
805    import wx
806    from bumps.names import FitProblem
807    from bumps.gui.app_frame import AppFrame
808
809    problem = FitProblem(Explore(opts))
810    is_mac = "cocoa" in wx.version()
811    app = wx.App()
812    frame = AppFrame(parent=None, title="explore")
813    if not is_mac: frame.Show()
814    frame.panel.set_model(model=problem)
815    frame.panel.Layout()
816    frame.panel.aui.Split(0, wx.TOP)
817    if is_mac: frame.Show()
818    app.MainLoop()
819
820class Explore(object):
821    """
822    Bumps wrapper for a SAS model comparison.
823
824    The resulting object can be used as a Bumps fit problem so that
825    parameters can be adjusted in the GUI, with plots updated on the fly.
826    """
827    def __init__(self, opts):
828        from bumps.cli import config_matplotlib
829        from . import bumps_model
830        config_matplotlib()
831        self.opts = opts
832        info = generate.make_info(opts['def'])
833        pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])
834        if not opts['is2d']:
835            active = [base + ext
836                      for base in info['partype']['pd-1d']
837                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
838            active.extend(info['partype']['fixed-1d'])
839            for k in active:
840                v = pars[k]
841                v.range(*parameter_range(k, v.value))
842        else:
843            for k, v in pars.items():
844                v.range(*parameter_range(k, v.value))
845
846        self.pars = pars
847        self.pd_types = pd_types
848        self.limits = None
849
850    def numpoints(self):
851        """
852        Return the number of points.
853        """
854        return len(self.pars) + 1  # so dof is 1
855
856    def parameters(self):
857        """
858        Return a dictionary of parameters.
859        """
860        return self.pars
861
862    def nllf(self):
863        """
864        Return cost.
865        """
866        # pylint: disable=no-self-use
867        return 0.  # No nllf
868
869    def plot(self, view='log'):
870        """
871        Plot the data and residuals.
872        """
873        pars = dict((k, v.value) for k, v in self.pars.items())
874        pars.update(self.pd_types)
875        self.opts['pars'] = pars
876        limits = compare(self.opts, limits=self.limits)
877        if self.limits is None:
878            vmin, vmax = limits
879            vmax = 1.3*vmax
880            vmin = vmax*1e-7
881            self.limits = vmin, vmax
882
883
884def main():
885    """
886    Main program.
887    """
888    opts = parse_opts()
889    if opts['explore']:
890        explore(opts)
891    else:
892        compare(opts)
893
894if __name__ == "__main__":
895    main()
Note: See TracBrowser for help on using the repository browser.