source: sasmodels/sasmodels/compare.py @ fa1582e

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since fa1582e was fa1582e, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

remove constraint which set n_stacking=1 on stacked disk model in compare

  • Property mode set to 100755
File size: 30.6 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import product
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_pars, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71    -default/-demo* use demo vs default parameters
72
73Any two calculation engines can be selected for comparison:
74
75    -single/-double/-half/-fast sets an OpenCL calculation engine
76    -single!/-double!/-quad! sets an OpenMP calculation engine
77    -sasview sets the sasview calculation engine
78
79The default is -single -sasview.  Note that the interpretation of quad
80precision depends on architecture, and may vary from 64-bit to 128-bit,
81with 80-bit floats being common (1e-19 precision).
82
83Key=value pairs allow you to set specific values for the model parameters.
84"""
85
86# Update docs with command line usage string.   This is separate from the usual
87# doc string so that we can display it at run time if there is an error.
88# lin
89__doc__ = (__doc__  # pylint: disable=redefined-builtin
90           + """
91Program description
92-------------------
93
94"""
95           + USAGE)
96
97kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
98
99MODELS = core.list_models()
100
101# CRUFT python 2.6
102if not hasattr(datetime.timedelta, 'total_seconds'):
103    def delay(dt):
104        """Return number date-time delta as number seconds"""
105        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
106else:
107    def delay(dt):
108        """Return number date-time delta as number seconds"""
109        return dt.total_seconds()
110
111
112class push_seed(object):
113    """
114    Set the seed value for the random number generator.
115
116    When used in a with statement, the random number generator state is
117    restored after the with statement is complete.
118
119    :Parameters:
120
121    *seed* : int or array_like, optional
122        Seed for RandomState
123
124    :Example:
125
126    Seed can be used directly to set the seed::
127
128        >>> from numpy.random import randint
129        >>> push_seed(24)
130        <...push_seed object at...>
131        >>> print(randint(0,1000000,3))
132        [242082    899 211136]
133
134    Seed can also be used in a with statement, which sets the random
135    number generator state for the enclosed computations and restores
136    it to the previous state on completion::
137
138        >>> with push_seed(24):
139        ...    print(randint(0,1000000,3))
140        [242082    899 211136]
141
142    Using nested contexts, we can demonstrate that state is indeed
143    restored after the block completes::
144
145        >>> with push_seed(24):
146        ...    print(randint(0,1000000))
147        ...    with push_seed(24):
148        ...        print(randint(0,1000000,3))
149        ...    print(randint(0,1000000))
150        242082
151        [242082    899 211136]
152        899
153
154    The restore step is protected against exceptions in the block::
155
156        >>> with push_seed(24):
157        ...    print(randint(0,1000000))
158        ...    try:
159        ...        with push_seed(24):
160        ...            print(randint(0,1000000,3))
161        ...            raise Exception()
162        ...    except:
163        ...        print("Exception raised")
164        ...    print(randint(0,1000000))
165        242082
166        [242082    899 211136]
167        Exception raised
168        899
169    """
170    def __init__(self, seed=None):
171        self._state = np.random.get_state()
172        np.random.seed(seed)
173
174    def __enter__(self):
175        return None
176
177    def __exit__(self, *args):
178        np.random.set_state(self._state)
179
180def tic():
181    """
182    Timer function.
183
184    Use "toc=tic()" to start the clock and "toc()" to measure
185    a time interval.
186    """
187    then = datetime.datetime.now()
188    return lambda: delay(datetime.datetime.now() - then)
189
190
191def set_beam_stop(data, radius, outer=None):
192    """
193    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
194
195    Note: this function does not use the sasview package
196    """
197    if hasattr(data, 'qx_data'):
198        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
199        data.mask = (q < radius)
200        if outer is not None:
201            data.mask |= (q >= outer)
202    else:
203        data.mask = (data.x < radius)
204        if outer is not None:
205            data.mask |= (data.x >= outer)
206
207
208def parameter_range(p, v):
209    """
210    Choose a parameter range based on parameter name and initial value.
211    """
212    if p.endswith('_pd_n'):
213        return [0, 100]
214    elif p.endswith('_pd_nsigma'):
215        return [0, 5]
216    elif p.endswith('_pd_type'):
217        return v
218    elif any(s in p for s in ('theta', 'phi', 'psi')):
219        # orientation in [-180,180], orientation pd in [0,45]
220        if p.endswith('_pd'):
221            return [0, 45]
222        else:
223            return [-180, 180]
224    elif 'sld' in p:
225        return [-0.5, 10]
226    elif p.endswith('_pd'):
227        return [0, 1]
228    elif p == 'background':
229        return [0, 10]
230    elif p == 'scale':
231        return [0, 1e3]
232    elif p == 'case_num':
233        # RPA hack
234        return [0, 10]
235    elif v < 0:
236        # Kxy parameters in rpa model can be negative
237        return [2*v, -2*v]
238    else:
239        return [0, (2*v if v > 0 else 1)]
240
241
242def _randomize_one(p, v):
243    """
244    Randomize a single parameter.
245    """
246    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
247        return v
248    else:
249        return np.random.uniform(*parameter_range(p, v))
250
251
252def randomize_pars(pars, seed=None):
253    """
254    Generate random values for all of the parameters.
255
256    Valid ranges for the random number generator are guessed from the name of
257    the parameter; this will not account for constraints such as cap radius
258    greater than cylinder radius in the capped_cylinder model, so
259    :func:`constrain_pars` needs to be called afterward..
260    """
261    with push_seed(seed):
262        # Note: the sort guarantees order `of calls to random number generator
263        pars = dict((p, _randomize_one(p, v))
264                    for p, v in sorted(pars.items()))
265    return pars
266
267def constrain_pars(model_info, pars):
268    """
269    Restrict parameters to valid values.
270
271    This includes model specific code for models such as capped_cylinder
272    which need to support within model constraints (cap radius more than
273    cylinder radius in this case).
274    """
275    name = model_info['id']
276    # if it is a product model, then just look at the form factor since
277    # none of the structure factors need any constraints.
278    if '*' in name:
279        name = name.split('*')[0]
280
281    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
282        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
283    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
284        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
285
286    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
287    if name == 'guinier':
288        #q_max = 0.2  # mid q maximum
289        q_max = 1.0  # high q maximum
290        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
291        pars['rg'] = min(pars['rg'], rg_max)
292
293    if name == 'rpa':
294        # Make sure phi sums to 1.0
295        if pars['case_num'] < 2:
296            pars['Phia'] = 0.
297            pars['Phib'] = 0.
298        elif pars['case_num'] < 5:
299            pars['Phia'] = 0.
300        total = sum(pars['Phi'+c] for c in 'abcd')
301        for c in 'abcd':
302            pars['Phi'+c] /= total
303
304def parlist(model_info, pars, is2d):
305    """
306    Format the parameter list for printing.
307    """
308    if is2d:
309        exclude = lambda n: False
310    else:
311        partype = model_info['partype']
312        par1d = set(partype['fixed-1d']+partype['pd-1d'])
313        exclude = lambda n: n not in par1d
314    lines = []
315    for p in model_info['parameters']:
316        if exclude(p.name): continue
317        fields = dict(
318            value=pars.get(p.name, p.default),
319            pd=pars.get(p.name+"_pd", 0.),
320            n=int(pars.get(p.name+"_pd_n", 0)),
321            nsigma=pars.get(p.name+"_pd_nsgima", 3.),
322            type=pars.get(p.name+"_pd_type", 'gaussian'))
323        lines.append(_format_par(p.name, **fields))
324    return "\n".join(lines)
325
326    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
327
328def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
329    line = "%s: %g"%(name, value)
330    if pd != 0.  and n != 0:
331        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
332                % (pd, n, nsigma, nsigma, type)
333    return line
334
335def suppress_pd(pars):
336    """
337    Suppress theta_pd for now until the normalization is resolved.
338
339    May also suppress complete polydispersity of the model to test
340    models more quickly.
341    """
342    pars = pars.copy()
343    for p in pars:
344        if p.endswith("_pd_n"): pars[p] = 0
345    return pars
346
347def eval_sasview(model_info, data):
348    """
349    Return a model calculator using the SasView fitting engine.
350    """
351    # importing sas here so that the error message will be that sas failed to
352    # import rather than the more obscure smear_selection not imported error
353    import sas
354    from sas.models.qsmearing import smear_selection
355
356    def get_model(name):
357        #print("new",sorted(_pars.items()))
358        sas = __import__('sas.models.' + name)
359        ModelClass = getattr(getattr(sas.models, name, None), name, None)
360        if ModelClass is None:
361            raise ValueError("could not find model %r in sas.models"%name)
362        return ModelClass()
363
364    # grab the sasview model, or create it if it is a product model
365    if model_info['composition']:
366        composition_type, parts = model_info['composition']
367        if composition_type == 'product':
368            from sas.sascalc.fit.MultiplicationModel import MultiplicationModel
369            P, S = [get_model(p) for p in model_info['oldname']]
370            model = MultiplicationModel(P, S)
371        else:
372            raise ValueError("sasview mixture models not supported by compare")
373    else:
374        model = get_model(model_info['oldname'])
375
376    # build a smearer with which to call the model, if necessary
377    smearer = smear_selection(data, model=model)
378    if hasattr(data, 'qx_data'):
379        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
380        index = ((~data.mask) & (~np.isnan(data.data))
381                 & (q >= data.qmin) & (q <= data.qmax))
382        if smearer is not None:
383            smearer.model = model  # because smear_selection has a bug
384            smearer.accuracy = data.accuracy
385            smearer.set_index(index)
386            theory = lambda: smearer.get_value()
387        else:
388            theory = lambda: model.evalDistribution([data.qx_data[index],
389                                                     data.qy_data[index]])
390    elif smearer is not None:
391        theory = lambda: smearer(model.evalDistribution(data.x))
392    else:
393        theory = lambda: model.evalDistribution(data.x)
394
395    def calculator(**pars):
396        """
397        Sasview calculator for model.
398        """
399        # paying for parameter conversion each time to keep life simple, if not fast
400        pars = revert_pars(model_info, pars)
401        for k, v in pars.items():
402            parts = k.split('.')  # polydispersity components
403            if len(parts) == 2:
404                model.dispersion[parts[0]][parts[1]] = v
405            else:
406                model.setParam(k, v)
407        return theory()
408
409    calculator.engine = "sasview"
410    return calculator
411
412DTYPE_MAP = {
413    'half': '16',
414    'fast': 'fast',
415    'single': '32',
416    'double': '64',
417    'quad': '128',
418    'f16': '16',
419    'f32': '32',
420    'f64': '64',
421    'longdouble': '128',
422}
423def eval_opencl(model_info, data, dtype='single', cutoff=0.):
424    """
425    Return a model calculator using the OpenCL calculation engine.
426    """
427    try:
428        model = core.build_model(model_info, dtype=dtype, platform="ocl")
429    except Exception as exc:
430        print(exc)
431        print("... trying again with single precision")
432        model = core.build_model(model_info, dtype='single', platform="ocl")
433    calculator = DirectModel(data, model, cutoff=cutoff)
434    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
435    return calculator
436
437def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
438    """
439    Return a model calculator using the DLL calculation engine.
440    """
441    if dtype == 'quad':
442        dtype = 'longdouble'
443    model = core.build_model(model_info, dtype=dtype, platform="dll")
444    calculator = DirectModel(data, model, cutoff=cutoff)
445    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
446    return calculator
447
448def time_calculation(calculator, pars, Nevals=1):
449    """
450    Compute the average calculation time over N evaluations.
451
452    An additional call is generated without polydispersity in order to
453    initialize the calculation engine, and make the average more stable.
454    """
455    # initialize the code so time is more accurate
456    value = calculator(**suppress_pd(pars))
457    toc = tic()
458    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
459        value = calculator(**pars)
460    average_time = toc()*1000./Nevals
461    return value, average_time
462
463def make_data(opts):
464    """
465    Generate an empty dataset, used with the model to set Q points
466    and resolution.
467
468    *opts* contains the options, with 'qmax', 'nq', 'res',
469    'accuracy', 'is2d' and 'view' parsed from the command line.
470    """
471    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
472    if opts['is2d']:
473        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
474        data.accuracy = opts['accuracy']
475        set_beam_stop(data, 0.004)
476        index = ~data.mask
477    else:
478        if opts['view'] == 'log':
479            qmax = math.log10(qmax)
480            q = np.logspace(qmax-3, qmax, nq)
481        else:
482            q = np.linspace(0.001*qmax, qmax, nq)
483        data = empty_data1D(q, resolution=res)
484        index = slice(None, None)
485    return data, index
486
487def make_engine(model_info, data, dtype, cutoff):
488    """
489    Generate the appropriate calculation engine for the given datatype.
490
491    Datatypes with '!' appended are evaluated using external C DLLs rather
492    than OpenCL.
493    """
494    if dtype == 'sasview':
495        return eval_sasview(model_info, data)
496    elif dtype.endswith('!'):
497        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
498    else:
499        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
500
501def compare(opts, limits=None):
502    """
503    Preform a comparison using options from the command line.
504
505    *limits* are the limits on the values to use, either to set the y-axis
506    for 1D or to set the colormap scale for 2D.  If None, then they are
507    inferred from the data and returned. When exploring using Bumps,
508    the limits are set when the model is initially called, and maintained
509    as the values are adjusted, making it easier to see the effects of the
510    parameters.
511    """
512    Nbase, Ncomp = opts['n1'], opts['n2']
513    pars = opts['pars']
514    data = opts['data']
515
516    # Base calculation
517    if Nbase > 0:
518        base = opts['engines'][0]
519        try:
520            base_value, base_time = time_calculation(base, pars, Nbase)
521            print("%s t=%.2f ms, intensity=%.0f"
522                  % (base.engine, base_time, sum(base_value)))
523        except ImportError:
524            traceback.print_exc()
525            Nbase = 0
526
527    # Comparison calculation
528    if Ncomp > 0:
529        comp = opts['engines'][1]
530        try:
531            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
532            print("%s t=%.2f ms, intensity=%.0f"
533                  % (comp.engine, comp_time, sum(comp_value)))
534        except ImportError:
535            traceback.print_exc()
536            Ncomp = 0
537
538    # Compare, but only if computing both forms
539    if Nbase > 0 and Ncomp > 0:
540        resid = (base_value - comp_value)
541        relerr = resid/comp_value
542        _print_stats("|%s-%s|"
543                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
544                     resid)
545        _print_stats("|(%s-%s)/%s|"
546                     % (base.engine, comp.engine, comp.engine),
547                     relerr)
548
549    # Plot if requested
550    if not opts['plot'] and not opts['explore']: return
551    view = opts['view']
552    import matplotlib.pyplot as plt
553    if limits is None:
554        vmin, vmax = np.Inf, -np.Inf
555        if Nbase > 0:
556            vmin = min(vmin, min(base_value))
557            vmax = max(vmax, max(base_value))
558        if Ncomp > 0:
559            vmin = min(vmin, min(comp_value))
560            vmax = max(vmax, max(comp_value))
561        limits = vmin, vmax
562
563    if Nbase > 0:
564        if Ncomp > 0: plt.subplot(131)
565        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
566        plt.title("%s t=%.2f ms"%(base.engine, base_time))
567        #cbar_title = "log I"
568    if Ncomp > 0:
569        if Nbase > 0: plt.subplot(132)
570        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
571        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
572        #cbar_title = "log I"
573    if Ncomp > 0 and Nbase > 0:
574        plt.subplot(133)
575        if not opts['rel_err']:
576            err, errstr, errview = resid, "abs err", "linear"
577        else:
578            err, errstr, errview = abs(relerr), "rel err", "log"
579        #err,errstr = base/comp,"ratio"
580        plot_theory(data, None, resid=err, view=errview, use_data=False)
581        if view == 'linear':
582            plt.xscale('linear')
583        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
584        #cbar_title = errstr if errview=="linear" else "log "+errstr
585    #if is2D:
586    #    h = plt.colorbar()
587    #    h.ax.set_title(cbar_title)
588
589    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
590        plt.figure()
591        v = relerr
592        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
593        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
594        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
595                   % (base.engine, comp.engine, comp.engine))
596        plt.ylabel('P(err)')
597        plt.title('Distribution of relative error between calculation engines')
598
599    if not opts['explore']:
600        plt.show()
601
602    return limits
603
604def _print_stats(label, err):
605    sorted_err = np.sort(abs(err))
606    p50 = int((len(err)-1)*0.50)
607    p98 = int((len(err)-1)*0.98)
608    data = [
609        "max:%.3e"%sorted_err[-1],
610        "median:%.3e"%sorted_err[p50],
611        "98%%:%.3e"%sorted_err[p98],
612        "rms:%.3e"%np.sqrt(np.mean(err**2)),
613        "zero-offset:%+.3e"%np.mean(err),
614        ]
615    print(label+"  "+"  ".join(data))
616
617
618
619# ===========================================================================
620#
621NAME_OPTIONS = set([
622    'plot', 'noplot',
623    'half', 'fast', 'single', 'double',
624    'single!', 'double!', 'quad!', 'sasview',
625    'lowq', 'midq', 'highq', 'exq',
626    '2d', '1d',
627    'preset', 'random',
628    'poly', 'mono',
629    'nopars', 'pars',
630    'rel', 'abs',
631    'linear', 'log', 'q4',
632    'hist', 'nohist',
633    'edit',
634    'demo', 'default',
635    ])
636VALUE_OPTIONS = [
637    # Note: random is both a name option and a value option
638    'cutoff', 'random', 'nq', 'res', 'accuracy',
639    ]
640
641def columnize(L, indent="", width=79):
642    """
643    Format a list of strings into columns.
644
645    Returns a string with carriage returns ready for printing.
646    """
647    column_width = max(len(w) for w in L) + 1
648    num_columns = (width - len(indent)) // column_width
649    num_rows = len(L) // num_columns
650    L = L + [""] * (num_rows*num_columns - len(L))
651    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
652    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
653             for row in zip(*columns)]
654    output = indent + ("\n"+indent).join(lines)
655    return output
656
657
658def get_pars(model_info, use_demo=False):
659    """
660    Extract demo parameters from the model definition.
661    """
662    # Get the default values for the parameters
663    pars = dict((p.name, p.default) for p in model_info['parameters'])
664
665    # Fill in default values for the polydispersity parameters
666    for p in model_info['parameters']:
667        if p.type in ('volume', 'orientation'):
668            pars[p.name+'_pd'] = 0.0
669            pars[p.name+'_pd_n'] = 0
670            pars[p.name+'_pd_nsigma'] = 3.0
671            pars[p.name+'_pd_type'] = "gaussian"
672
673    # Plug in values given in demo
674    if use_demo:
675        pars.update(model_info['demo'])
676    return pars
677
678
679def parse_opts():
680    """
681    Parse command line options.
682    """
683    MODELS = core.list_models()
684    flags = [arg for arg in sys.argv[1:]
685             if arg.startswith('-')]
686    values = [arg for arg in sys.argv[1:]
687              if not arg.startswith('-') and '=' in arg]
688    args = [arg for arg in sys.argv[1:]
689            if not arg.startswith('-') and '=' not in arg]
690    models = "\n    ".join("%-15s"%v for v in MODELS)
691    if len(args) == 0:
692        print(USAGE)
693        print("\nAvailable models:")
694        print(columnize(MODELS, indent="  "))
695        sys.exit(1)
696    if len(args) > 3:
697        print("expected parameters: model N1 N2")
698
699    name = args[0]
700    try:
701        model_info = core.load_model_info(name)
702    except ImportError, exc:
703        print(str(exc))
704        print("Could not find model; use one of:\n    " + models)
705        sys.exit(1)
706
707    invalid = [o[1:] for o in flags
708               if o[1:] not in NAME_OPTIONS
709               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
710    if invalid:
711        print("Invalid options: %s"%(", ".join(invalid)))
712        sys.exit(1)
713
714
715    # pylint: disable=bad-whitespace
716    # Interpret the flags
717    opts = {
718        'plot'      : True,
719        'view'      : 'log',
720        'is2d'      : False,
721        'qmax'      : 0.05,
722        'nq'        : 128,
723        'res'       : 0.0,
724        'accuracy'  : 'Low',
725        'cutoff'    : 0.0,
726        'seed'      : -1,  # default to preset
727        'mono'      : False,
728        'show_pars' : False,
729        'show_hist' : False,
730        'rel_err'   : True,
731        'explore'   : False,
732        'use_demo'  : True,
733    }
734    engines = []
735    for arg in flags:
736        if arg == '-noplot':    opts['plot'] = False
737        elif arg == '-plot':    opts['plot'] = True
738        elif arg == '-linear':  opts['view'] = 'linear'
739        elif arg == '-log':     opts['view'] = 'log'
740        elif arg == '-q4':      opts['view'] = 'q4'
741        elif arg == '-1d':      opts['is2d'] = False
742        elif arg == '-2d':      opts['is2d'] = True
743        elif arg == '-exq':     opts['qmax'] = 10.0
744        elif arg == '-highq':   opts['qmax'] = 1.0
745        elif arg == '-midq':    opts['qmax'] = 0.2
746        elif arg == '-lowq':    opts['qmax'] = 0.05
747        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
748        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
749        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
750        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
751        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
752        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
753        elif arg == '-preset':  opts['seed'] = -1
754        elif arg == '-mono':    opts['mono'] = True
755        elif arg == '-poly':    opts['mono'] = False
756        elif arg == '-pars':    opts['show_pars'] = True
757        elif arg == '-nopars':  opts['show_pars'] = False
758        elif arg == '-hist':    opts['show_hist'] = True
759        elif arg == '-nohist':  opts['show_hist'] = False
760        elif arg == '-rel':     opts['rel_err'] = True
761        elif arg == '-abs':     opts['rel_err'] = False
762        elif arg == '-half':    engines.append(arg[1:])
763        elif arg == '-fast':    engines.append(arg[1:])
764        elif arg == '-single':  engines.append(arg[1:])
765        elif arg == '-double':  engines.append(arg[1:])
766        elif arg == '-single!': engines.append(arg[1:])
767        elif arg == '-double!': engines.append(arg[1:])
768        elif arg == '-quad!':   engines.append(arg[1:])
769        elif arg == '-sasview': engines.append(arg[1:])
770        elif arg == '-edit':    opts['explore'] = True
771        elif arg == '-demo':    opts['use_demo'] = True
772        elif arg == '-default':    opts['use_demo'] = False
773    # pylint: enable=bad-whitespace
774
775    if len(engines) == 0:
776        engines.extend(['single', 'sasview'])
777    elif len(engines) == 1:
778        if engines[0][0] != 'sasview':
779            engines.append('sasview')
780        else:
781            engines.append('single')
782    elif len(engines) > 2:
783        del engines[2:]
784
785    n1 = int(args[1]) if len(args) > 1 else 1
786    n2 = int(args[2]) if len(args) > 2 else 1
787    use_sasview = any(engine=='sasview' and count>0
788                      for engine, count in zip(engines, [n1, n2]))
789
790    # Get demo parameters from model definition, or use default parameters
791    # if model does not define demo parameters
792    pars = get_pars(model_info, opts['use_demo'])
793
794
795    # Fill in parameters given on the command line
796    presets = {}
797    for arg in values:
798        k, v = arg.split('=', 1)
799        if k not in pars:
800            # extract base name without polydispersity info
801            s = set(p.split('_pd')[0] for p in pars)
802            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
803            sys.exit(1)
804        presets[k] = float(v) if not k.endswith('type') else v
805
806    # randomize parameters
807    #pars.update(set_pars)  # set value before random to control range
808    if opts['seed'] > -1:
809        pars = randomize_pars(pars, seed=opts['seed'])
810        print("Randomize using -random=%i"%opts['seed'])
811    if opts['mono']:
812        pars = suppress_pd(pars)
813    pars.update(presets)  # set value after random to control value
814    #import pprint; pprint.pprint(model_info)
815    constrain_pars(model_info, pars)
816    if use_sasview:
817        constrain_new_to_old(model_info, pars)
818    if opts['show_pars']:
819        print(str(parlist(model_info, pars, opts['is2d'])))
820
821    # Create the computational engines
822    data, _ = make_data(opts)
823    if n1:
824        base = make_engine(model_info, data, engines[0], opts['cutoff'])
825    else:
826        base = None
827    if n2:
828        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
829    else:
830        comp = None
831
832    # pylint: disable=bad-whitespace
833    # Remember it all
834    opts.update({
835        'name'      : name,
836        'def'       : model_info,
837        'n1'        : n1,
838        'n2'        : n2,
839        'presets'   : presets,
840        'pars'      : pars,
841        'data'      : data,
842        'engines'   : [base, comp],
843    })
844    # pylint: enable=bad-whitespace
845
846    return opts
847
848def explore(opts):
849    """
850    Explore the model using the Bumps GUI.
851    """
852    import wx
853    from bumps.names import FitProblem
854    from bumps.gui.app_frame import AppFrame
855
856    problem = FitProblem(Explore(opts))
857    is_mac = "cocoa" in wx.version()
858    app = wx.App()
859    frame = AppFrame(parent=None, title="explore")
860    if not is_mac: frame.Show()
861    frame.panel.set_model(model=problem)
862    frame.panel.Layout()
863    frame.panel.aui.Split(0, wx.TOP)
864    if is_mac: frame.Show()
865    app.MainLoop()
866
867class Explore(object):
868    """
869    Bumps wrapper for a SAS model comparison.
870
871    The resulting object can be used as a Bumps fit problem so that
872    parameters can be adjusted in the GUI, with plots updated on the fly.
873    """
874    def __init__(self, opts):
875        from bumps.cli import config_matplotlib
876        from . import bumps_model
877        config_matplotlib()
878        self.opts = opts
879        model_info = opts['def']
880        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
881        if not opts['is2d']:
882            active = [base + ext
883                      for base in model_info['partype']['pd-1d']
884                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
885            active.extend(model_info['partype']['fixed-1d'])
886            for k in active:
887                v = pars[k]
888                v.range(*parameter_range(k, v.value))
889        else:
890            for k, v in pars.items():
891                v.range(*parameter_range(k, v.value))
892
893        self.pars = pars
894        self.pd_types = pd_types
895        self.limits = None
896
897    def numpoints(self):
898        """
899        Return the number of points.
900        """
901        return len(self.pars) + 1  # so dof is 1
902
903    def parameters(self):
904        """
905        Return a dictionary of parameters.
906        """
907        return self.pars
908
909    def nllf(self):
910        """
911        Return cost.
912        """
913        # pylint: disable=no-self-use
914        return 0.  # No nllf
915
916    def plot(self, view='log'):
917        """
918        Plot the data and residuals.
919        """
920        pars = dict((k, v.value) for k, v in self.pars.items())
921        pars.update(self.pd_types)
922        self.opts['pars'] = pars
923        limits = compare(self.opts, limits=self.limits)
924        if self.limits is None:
925            vmin, vmax = limits
926            vmax = 1.3*vmax
927            vmin = vmax*1e-7
928            self.limits = vmin, vmax
929
930
931def main():
932    """
933    Main program.
934    """
935    opts = parse_opts()
936    if opts['explore']:
937        explore(opts)
938    else:
939        compare(opts)
940
941if __name__ == "__main__":
942    main()
Note: See TracBrowser for help on using the repository browser.