source: sasmodels/sasmodels/compare.py @ b151003

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since b151003 was b151003, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

Merge remote-tracking branch 'origin/master' into polydisp

Conflicts:

sasmodels/model_test.py

  • Property mode set to 100755
File size: 31.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np  # type: ignore
37
38from . import core
39from . import kerneldll
40from .data import plot_theory, empty_data1D, empty_data2D
41from .direct_model import DirectModel
42from .convert import revert_name, revert_pars, constrain_new_to_old
43
44USAGE = """
45usage: compare.py model N1 N2 [options...] [key=val]
46
47Compare the speed and value for a model between the SasView original and the
48sasmodels rewrite.
49
50model is the name of the model to compare (see below).
51N1 is the number of times to run sasmodels (default=1).
52N2 is the number times to run sasview (default=1).
53
54Options (* for default):
55
56    -plot*/-noplot plots or suppress the plot of the model
57    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
58    -nq=128 sets the number of Q points in the data set
59    -zero indicates that q=0 should be included
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71    -default/-demo* use demo vs default parameters
72
73Any two calculation engines can be selected for comparison:
74
75    -single/-double/-half/-fast sets an OpenCL calculation engine
76    -single!/-double!/-quad! sets an OpenMP calculation engine
77    -sasview sets the sasview calculation engine
78
79The default is -single -sasview.  Note that the interpretation of quad
80precision depends on architecture, and may vary from 64-bit to 128-bit,
81with 80-bit floats being common (1e-19 precision).
82
83Key=value pairs allow you to set specific values for the model parameters.
84"""
85
86# Update docs with command line usage string.   This is separate from the usual
87# doc string so that we can display it at run time if there is an error.
88# lin
89__doc__ = (__doc__  # pylint: disable=redefined-builtin
90           + """
91Program description
92-------------------
93
94"""
95           + USAGE)
96
97kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
98
99MODELS = core.list_models()
100
101# CRUFT python 2.6
102if not hasattr(datetime.timedelta, 'total_seconds'):
103    def delay(dt):
104        """Return number date-time delta as number seconds"""
105        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
106else:
107    def delay(dt):
108        """Return number date-time delta as number seconds"""
109        return dt.total_seconds()
110
111
112class push_seed(object):
113    """
114    Set the seed value for the random number generator.
115
116    When used in a with statement, the random number generator state is
117    restored after the with statement is complete.
118
119    :Parameters:
120
121    *seed* : int or array_like, optional
122        Seed for RandomState
123
124    :Example:
125
126    Seed can be used directly to set the seed::
127
128        >>> from numpy.random import randint
129        >>> push_seed(24)
130        <...push_seed object at...>
131        >>> print(randint(0,1000000,3))
132        [242082    899 211136]
133
134    Seed can also be used in a with statement, which sets the random
135    number generator state for the enclosed computations and restores
136    it to the previous state on completion::
137
138        >>> with push_seed(24):
139        ...    print(randint(0,1000000,3))
140        [242082    899 211136]
141
142    Using nested contexts, we can demonstrate that state is indeed
143    restored after the block completes::
144
145        >>> with push_seed(24):
146        ...    print(randint(0,1000000))
147        ...    with push_seed(24):
148        ...        print(randint(0,1000000,3))
149        ...    print(randint(0,1000000))
150        242082
151        [242082    899 211136]
152        899
153
154    The restore step is protected against exceptions in the block::
155
156        >>> with push_seed(24):
157        ...    print(randint(0,1000000))
158        ...    try:
159        ...        with push_seed(24):
160        ...            print(randint(0,1000000,3))
161        ...            raise Exception()
162        ...    except:
163        ...        print("Exception raised")
164        ...    print(randint(0,1000000))
165        242082
166        [242082    899 211136]
167        Exception raised
168        899
169    """
170    def __init__(self, seed=None):
171        self._state = np.random.get_state()
172        np.random.seed(seed)
173
174    def __enter__(self):
175        return None
176
177    def __exit__(self, *args):
178        np.random.set_state(self._state)
179
180def tic():
181    """
182    Timer function.
183
184    Use "toc=tic()" to start the clock and "toc()" to measure
185    a time interval.
186    """
187    then = datetime.datetime.now()
188    return lambda: delay(datetime.datetime.now() - then)
189
190
191def set_beam_stop(data, radius, outer=None):
192    """
193    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
194
195    Note: this function does not use the sasview package
196    """
197    if hasattr(data, 'qx_data'):
198        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
199        data.mask = (q < radius)
200        if outer is not None:
201            data.mask |= (q >= outer)
202    else:
203        data.mask = (data.x < radius)
204        if outer is not None:
205            data.mask |= (data.x >= outer)
206
207
208def parameter_range(p, v):
209    """
210    Choose a parameter range based on parameter name and initial value.
211    """
212    # process the polydispersity options
213    if p.endswith('_pd_n'):
214        return [0, 100]
215    elif p.endswith('_pd_nsigma'):
216        return [0, 5]
217    elif p.endswith('_pd_type'):
218        return v
219    elif any(s in p for s in ('theta', 'phi', 'psi')):
220        # orientation in [-180,180], orientation pd in [0,45]
221        if p.endswith('_pd'):
222            return [0, 45]
223        else:
224            return [-180, 180]
225    elif p.endswith('_pd'):
226        return [0, 1]
227    elif 'sld' in p:
228        return [-0.5, 10]
229    elif p == 'background':
230        return [0, 10]
231    elif p == 'scale':
232        return [0, 1e3]
233    elif v < 0:
234        return [2*v, -2*v]
235    else:
236        return [0, (2*v if v > 0 else 1)]
237
238
239def _randomize_one(model_info, p, v):
240    """
241    Randomize a single parameter.
242    """
243    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
244        return v
245
246    # Find the parameter definition
247    for par in model_info.parameters.call_parameters:
248        if par.name == p:
249            break
250    else:
251        raise ValueError("unknown parameter %r"%p)
252
253    if len(par.limits) > 2:  # choice list
254        return np.random.randint(len(par.limits))
255
256    limits = par.limits
257    if not np.isfinite(limits).all():
258        low, high = parameter_range(p, v)
259        limits = (max(limits[0], low), min(limits[1], high))
260
261    return np.random.uniform(*limits)
262
263
264def randomize_pars(model_info, pars, seed=None):
265    """
266    Generate random values for all of the parameters.
267
268    Valid ranges for the random number generator are guessed from the name of
269    the parameter; this will not account for constraints such as cap radius
270    greater than cylinder radius in the capped_cylinder model, so
271    :func:`constrain_pars` needs to be called afterward..
272    """
273    with push_seed(seed):
274        # Note: the sort guarantees order `of calls to random number generator
275        pars = dict((p, _randomize_one(model_info, p, v))
276                    for p, v in sorted(pars.items()))
277    return pars
278
279def constrain_pars(model_info, pars):
280    """
281    Restrict parameters to valid values.
282
283    This includes model specific code for models such as capped_cylinder
284    which need to support within model constraints (cap radius more than
285    cylinder radius in this case).
286    """
287    name = model_info.id
288    # if it is a product model, then just look at the form factor since
289    # none of the structure factors need any constraints.
290    if '*' in name:
291        name = name.split('*')[0]
292
293    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
294        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
295    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
296        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
297
298    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
299    if name == 'guinier':
300        #q_max = 0.2  # mid q maximum
301        q_max = 1.0  # high q maximum
302        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
303        pars['rg'] = min(pars['rg'], rg_max)
304
305    if name == 'rpa':
306        # Make sure phi sums to 1.0
307        if pars['case_num'] < 2:
308            pars['Phi1'] = 0.
309            pars['Phi2'] = 0.
310        elif pars['case_num'] < 5:
311            pars['Phi1'] = 0.
312        total = sum(pars['Phi'+c] for c in '1234')
313        for c in '1234':
314            pars['Phi'+c] /= total
315
316def parlist(model_info, pars, is2d):
317    """
318    Format the parameter list for printing.
319    """
320    lines = []
321    parameters = model_info.parameters
322    for p in parameters.user_parameters(pars, is2d):
323        fields = dict(
324            value=pars.get(p.id, p.default),
325            pd=pars.get(p.id+"_pd", 0.),
326            n=int(pars.get(p.id+"_pd_n", 0)),
327            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
328            type=pars.get(p.id+"_pd_type", 'gaussian'))
329        lines.append(_format_par(p.name, **fields))
330    return "\n".join(lines)
331
332    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
333
334def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
335    line = "%s: %g"%(name, value)
336    if pd != 0.  and n != 0:
337        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
338                % (pd, n, nsigma, nsigma, type)
339    return line
340
341def suppress_pd(pars):
342    """
343    Suppress theta_pd for now until the normalization is resolved.
344
345    May also suppress complete polydispersity of the model to test
346    models more quickly.
347    """
348    pars = pars.copy()
349    for p in pars:
350        if p.endswith("_pd_n"): pars[p] = 0
351    return pars
352
353def eval_sasview(model_info, data):
354    """
355    Return a model calculator using the pre-4.0 SasView models.
356    """
357    # importing sas here so that the error message will be that sas failed to
358    # import rather than the more obscure smear_selection not imported error
359    import sas
360    from sas.models.qsmearing import smear_selection
361
362    def get_model(name):
363        #print("new",sorted(_pars.items()))
364        sas = __import__('sas.models.' + name)
365        ModelClass = getattr(getattr(sas.models, name, None), name, None)
366        if ModelClass is None:
367            raise ValueError("could not find model %r in sas.models"%name)
368        return ModelClass()
369
370    # grab the sasview model, or create it if it is a product model
371    if model_info.composition:
372        composition_type, parts = model_info.composition
373        if composition_type == 'product':
374            from sas.models.MultiplicationModel import MultiplicationModel
375            P, S = [get_model(revert_name(p)) for p in parts]
376            model = MultiplicationModel(P, S)
377        else:
378            raise ValueError("sasview mixture models not supported by compare")
379    else:
380        model = get_model(revert_name(model_info))
381
382    # build a smearer with which to call the model, if necessary
383    smearer = smear_selection(data, model=model)
384    if hasattr(data, 'qx_data'):
385        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
386        index = ((~data.mask) & (~np.isnan(data.data))
387                 & (q >= data.qmin) & (q <= data.qmax))
388        if smearer is not None:
389            smearer.model = model  # because smear_selection has a bug
390            smearer.accuracy = data.accuracy
391            smearer.set_index(index)
392            theory = lambda: smearer.get_value()
393        else:
394            theory = lambda: model.evalDistribution([data.qx_data[index],
395                                                     data.qy_data[index]])
396    elif smearer is not None:
397        theory = lambda: smearer(model.evalDistribution(data.x))
398    else:
399        theory = lambda: model.evalDistribution(data.x)
400
401    def calculator(**pars):
402        """
403        Sasview calculator for model.
404        """
405        # paying for parameter conversion each time to keep life simple, if not fast
406        pars = revert_pars(model_info, pars)
407        for k, v in pars.items():
408            parts = k.split('.')  # polydispersity components
409            if len(parts) == 2:
410                model.dispersion[parts[0]][parts[1]] = v
411            else:
412                model.setParam(k, v)
413        return theory()
414
415    calculator.engine = "sasview"
416    return calculator
417
418DTYPE_MAP = {
419    'half': '16',
420    'fast': 'fast',
421    'single': '32',
422    'double': '64',
423    'quad': '128',
424    'f16': '16',
425    'f32': '32',
426    'f64': '64',
427    'longdouble': '128',
428}
429def eval_opencl(model_info, data, dtype='single', cutoff=0.):
430    """
431    Return a model calculator using the OpenCL calculation engine.
432    """
433    try:
434        model = core.build_model(model_info, dtype=dtype, platform="ocl")
435    except Exception as exc:
436        print(exc)
437        print("... trying again with single precision")
438        model = core.build_model(model_info, dtype='single', platform="ocl")
439    calculator = DirectModel(data, model, cutoff=cutoff)
440    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
441    return calculator
442
443def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
444    """
445    Return a model calculator using the DLL calculation engine.
446    """
447    if dtype == 'quad':
448        dtype = 'longdouble'
449    model = core.build_model(model_info, dtype=dtype, platform="dll")
450    calculator = DirectModel(data, model, cutoff=cutoff)
451    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
452    return calculator
453
454def time_calculation(calculator, pars, Nevals=1):
455    """
456    Compute the average calculation time over N evaluations.
457
458    An additional call is generated without polydispersity in order to
459    initialize the calculation engine, and make the average more stable.
460    """
461    # initialize the code so time is more accurate
462    if Nevals > 1:
463        value = calculator(**suppress_pd(pars))
464    toc = tic()
465    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
466        value = calculator(**pars)
467    average_time = toc()*1000./Nevals
468    return value, average_time
469
470def make_data(opts):
471    """
472    Generate an empty dataset, used with the model to set Q points
473    and resolution.
474
475    *opts* contains the options, with 'qmax', 'nq', 'res',
476    'accuracy', 'is2d' and 'view' parsed from the command line.
477    """
478    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
479    if opts['is2d']:
480        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
481        data.accuracy = opts['accuracy']
482        set_beam_stop(data, 0.0004)
483        index = ~data.mask
484    else:
485        if opts['view'] == 'log' and not opts['zero']:
486            qmax = math.log10(qmax)
487            q = np.logspace(qmax-3, qmax, nq)
488        else:
489            q = np.linspace(0.001*qmax, qmax, nq)
490        if opts['zero']:
491            q = np.hstack((0, q))
492        data = empty_data1D(q, resolution=res)
493        index = slice(None, None)
494    return data, index
495
496def make_engine(model_info, data, dtype, cutoff):
497    """
498    Generate the appropriate calculation engine for the given datatype.
499
500    Datatypes with '!' appended are evaluated using external C DLLs rather
501    than OpenCL.
502    """
503    if dtype == 'sasview':
504        return eval_sasview(model_info, data)
505    elif dtype.endswith('!'):
506        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
507    else:
508        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
509
510def _show_invalid(data, theory):
511    if not theory.mask.any():
512        return
513
514    if hasattr(data, 'x'):
515        bad = zip(data.x[theory.mask], theory[theory.mask])
516        print("   *** ", ", ".join("I(%g)=%g"%(x, y) for x,y in bad))
517
518
519def compare(opts, limits=None):
520    """
521    Preform a comparison using options from the command line.
522
523    *limits* are the limits on the values to use, either to set the y-axis
524    for 1D or to set the colormap scale for 2D.  If None, then they are
525    inferred from the data and returned. When exploring using Bumps,
526    the limits are set when the model is initially called, and maintained
527    as the values are adjusted, making it easier to see the effects of the
528    parameters.
529    """
530    Nbase, Ncomp = opts['n1'], opts['n2']
531    pars = opts['pars']
532    data = opts['data']
533
534    # Base calculation
535    if Nbase > 0:
536        base = opts['engines'][0]
537        try:
538            base_value, base_time = time_calculation(base, pars, Nbase)
539            base_value = np.ma.masked_invalid(base_value)
540            print("%s t=%.2f ms, intensity=%.0f"
541                  % (base.engine, base_time, base_value.sum()))
542            _show_invalid(data, base_value)
543        except ImportError:
544            traceback.print_exc()
545            Nbase = 0
546
547    # Comparison calculation
548    if Ncomp > 0:
549        comp = opts['engines'][1]
550        try:
551            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
552            comp_value = np.ma.masked_invalid(comp_value)
553            print("%s t=%.2f ms, intensity=%.0f"
554                  % (comp.engine, comp_time, comp_value.sum()))
555            _show_invalid(data, comp_value)
556        except ImportError:
557            traceback.print_exc()
558            Ncomp = 0
559
560    # Compare, but only if computing both forms
561    if Nbase > 0 and Ncomp > 0:
562        resid = (base_value - comp_value)
563        relerr = resid/comp_value
564        _print_stats("|%s-%s|"
565                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
566                     resid)
567        _print_stats("|(%s-%s)/%s|"
568                     % (base.engine, comp.engine, comp.engine),
569                     relerr)
570
571    # Plot if requested
572    if not opts['plot'] and not opts['explore']: return
573    view = opts['view']
574    import matplotlib.pyplot as plt
575    if limits is None:
576        vmin, vmax = np.Inf, -np.Inf
577        if Nbase > 0:
578            vmin = min(vmin, base_value.min())
579            vmax = max(vmax, base_value.max())
580        if Ncomp > 0:
581            vmin = min(vmin, comp_value.min())
582            vmax = max(vmax, comp_value.max())
583        limits = vmin, vmax
584
585    if Nbase > 0:
586        if Ncomp > 0: plt.subplot(131)
587        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
588        plt.title("%s t=%.2f ms"%(base.engine, base_time))
589        #cbar_title = "log I"
590    if Ncomp > 0:
591        if Nbase > 0: plt.subplot(132)
592        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
593        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
594        #cbar_title = "log I"
595    if Ncomp > 0 and Nbase > 0:
596        plt.subplot(133)
597        if not opts['rel_err']:
598            err, errstr, errview = resid, "abs err", "linear"
599        else:
600            err, errstr, errview = abs(relerr), "rel err", "log"
601        #err,errstr = base/comp,"ratio"
602        plot_theory(data, None, resid=err, view=errview, use_data=False)
603        if view == 'linear':
604            plt.xscale('linear')
605        plt.title("max %s = %.3g"%(errstr, abs(err).max()))
606        #cbar_title = errstr if errview=="linear" else "log "+errstr
607    #if is2D:
608    #    h = plt.colorbar()
609    #    h.ax.set_title(cbar_title)
610
611    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
612        plt.figure()
613        v = relerr
614        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
615        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
616        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
617                   % (base.engine, comp.engine, comp.engine))
618        plt.ylabel('P(err)')
619        plt.title('Distribution of relative error between calculation engines')
620
621    if not opts['explore']:
622        plt.show()
623
624    return limits
625
626def _print_stats(label, err):
627    sorted_err = np.sort(abs(err.compressed()))
628    p50 = int((len(err)-1)*0.50)
629    p98 = int((len(err)-1)*0.98)
630    data = [
631        "max:%.3e"%sorted_err[-1],
632        "median:%.3e"%sorted_err[p50],
633        "98%%:%.3e"%sorted_err[p98],
634        "rms:%.3e"%np.sqrt(np.mean(err**2)),
635        "zero-offset:%+.3e"%np.mean(err),
636        ]
637    print(label+"  "+"  ".join(data))
638
639
640
641# ===========================================================================
642#
643NAME_OPTIONS = set([
644    'plot', 'noplot',
645    'half', 'fast', 'single', 'double',
646    'single!', 'double!', 'quad!', 'sasview',
647    'lowq', 'midq', 'highq', 'exq', 'zero',
648    '2d', '1d',
649    'preset', 'random',
650    'poly', 'mono',
651    'nopars', 'pars',
652    'rel', 'abs',
653    'linear', 'log', 'q4',
654    'hist', 'nohist',
655    'edit',
656    'demo', 'default',
657    ])
658VALUE_OPTIONS = [
659    # Note: random is both a name option and a value option
660    'cutoff', 'random', 'nq', 'res', 'accuracy',
661    ]
662
663def columnize(L, indent="", width=79):
664    """
665    Format a list of strings into columns.
666
667    Returns a string with carriage returns ready for printing.
668    """
669    column_width = max(len(w) for w in L) + 1
670    num_columns = (width - len(indent)) // column_width
671    num_rows = len(L) // num_columns
672    L = L + [""] * (num_rows*num_columns - len(L))
673    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
674    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
675             for row in zip(*columns)]
676    output = indent + ("\n"+indent).join(lines)
677    return output
678
679
680def get_pars(model_info, use_demo=False):
681    """
682    Extract demo parameters from the model definition.
683    """
684    # Get the default values for the parameters
685    pars = {}
686    for p in model_info.parameters.call_parameters:
687        parts = [('', p.default)]
688        if p.polydisperse:
689            parts.append(('_pd', 0.0))
690            parts.append(('_pd_n', 0))
691            parts.append(('_pd_nsigma', 3.0))
692            parts.append(('_pd_type', "gaussian"))
693        for ext, val in parts:
694            if p.length > 1:
695                dict(("%s%d%s"%(p.id,k,ext), val) for k in range(p.length))
696            else:
697                pars[p.id+ext] = val
698
699    # Plug in values given in demo
700    if use_demo:
701        pars.update(model_info.demo)
702    return pars
703
704
705def parse_opts():
706    """
707    Parse command line options.
708    """
709    MODELS = core.list_models()
710    flags = [arg for arg in sys.argv[1:]
711             if arg.startswith('-')]
712    values = [arg for arg in sys.argv[1:]
713              if not arg.startswith('-') and '=' in arg]
714    args = [arg for arg in sys.argv[1:]
715            if not arg.startswith('-') and '=' not in arg]
716    models = "\n    ".join("%-15s"%v for v in MODELS)
717    if len(args) == 0:
718        print(USAGE)
719        print("\nAvailable models:")
720        print(columnize(MODELS, indent="  "))
721        sys.exit(1)
722    if len(args) > 3:
723        print("expected parameters: model N1 N2")
724
725    name = args[0]
726    try:
727        model_info = core.load_model_info(name)
728    except ImportError as exc:
729        print(str(exc))
730        print("Could not find model; use one of:\n    " + models)
731        sys.exit(1)
732
733    invalid = [o[1:] for o in flags
734               if o[1:] not in NAME_OPTIONS
735               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
736    if invalid:
737        print("Invalid options: %s"%(", ".join(invalid)))
738        sys.exit(1)
739
740
741    # pylint: disable=bad-whitespace
742    # Interpret the flags
743    opts = {
744        'plot'      : True,
745        'view'      : 'log',
746        'is2d'      : False,
747        'qmax'      : 0.05,
748        'nq'        : 128,
749        'res'       : 0.0,
750        'accuracy'  : 'Low',
751        'cutoff'    : 0.0,
752        'seed'      : -1,  # default to preset
753        'mono'      : False,
754        'show_pars' : False,
755        'show_hist' : False,
756        'rel_err'   : True,
757        'explore'   : False,
758        'use_demo'  : True,
759    }
760    engines = []
761    for arg in flags:
762        if arg == '-noplot':    opts['plot'] = False
763        elif arg == '-plot':    opts['plot'] = True
764        elif arg == '-linear':  opts['view'] = 'linear'
765        elif arg == '-log':     opts['view'] = 'log'
766        elif arg == '-q4':      opts['view'] = 'q4'
767        elif arg == '-1d':      opts['is2d'] = False
768        elif arg == '-2d':      opts['is2d'] = True
769        elif arg == '-exq':     opts['qmax'] = 10.0
770        elif arg == '-highq':   opts['qmax'] = 1.0
771        elif arg == '-midq':    opts['qmax'] = 0.2
772        elif arg == '-lowq':    opts['qmax'] = 0.05
773        elif arg == '-zero':    opts['zero'] = True
774        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
775        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
776        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
777        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
778        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
779        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
780        elif arg == '-preset':  opts['seed'] = -1
781        elif arg == '-mono':    opts['mono'] = True
782        elif arg == '-poly':    opts['mono'] = False
783        elif arg == '-pars':    opts['show_pars'] = True
784        elif arg == '-nopars':  opts['show_pars'] = False
785        elif arg == '-hist':    opts['show_hist'] = True
786        elif arg == '-nohist':  opts['show_hist'] = False
787        elif arg == '-rel':     opts['rel_err'] = True
788        elif arg == '-abs':     opts['rel_err'] = False
789        elif arg == '-half':    engines.append(arg[1:])
790        elif arg == '-fast':    engines.append(arg[1:])
791        elif arg == '-single':  engines.append(arg[1:])
792        elif arg == '-double':  engines.append(arg[1:])
793        elif arg == '-single!': engines.append(arg[1:])
794        elif arg == '-double!': engines.append(arg[1:])
795        elif arg == '-quad!':   engines.append(arg[1:])
796        elif arg == '-sasview': engines.append(arg[1:])
797        elif arg == '-edit':    opts['explore'] = True
798        elif arg == '-demo':    opts['use_demo'] = True
799        elif arg == '-default':    opts['use_demo'] = False
800    # pylint: enable=bad-whitespace
801
802    if len(engines) == 0:
803        engines.extend(['single', 'double'])
804    elif len(engines) == 1:
805        if engines[0][0] != 'double':
806            engines.append('double')
807        else:
808            engines.append('single')
809    elif len(engines) > 2:
810        del engines[2:]
811
812    n1 = int(args[1]) if len(args) > 1 else 1
813    n2 = int(args[2]) if len(args) > 2 else 1
814    use_sasview = any(engine=='sasview' and count>0
815                      for engine, count in zip(engines, [n1, n2]))
816
817    # Get demo parameters from model definition, or use default parameters
818    # if model does not define demo parameters
819    pars = get_pars(model_info, opts['use_demo'])
820
821
822    # Fill in parameters given on the command line
823    presets = {}
824    for arg in values:
825        k, v = arg.split('=', 1)
826        if k not in pars:
827            # extract base name without polydispersity info
828            s = set(p.split('_pd')[0] for p in pars)
829            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
830            sys.exit(1)
831        presets[k] = float(v) if not k.endswith('type') else v
832
833    # randomize parameters
834    #pars.update(set_pars)  # set value before random to control range
835    if opts['seed'] > -1:
836        pars = randomize_pars(model_info, pars, seed=opts['seed'])
837        print("Randomize using -random=%i"%opts['seed'])
838    if opts['mono']:
839        pars = suppress_pd(pars)
840    pars.update(presets)  # set value after random to control value
841    #import pprint; pprint.pprint(model_info)
842    constrain_pars(model_info, pars)
843    if use_sasview:
844        constrain_new_to_old(model_info, pars)
845    if opts['show_pars']:
846        print(str(parlist(model_info, pars, opts['is2d'])))
847
848    # Create the computational engines
849    data, _ = make_data(opts)
850    if n1:
851        base = make_engine(model_info, data, engines[0], opts['cutoff'])
852    else:
853        base = None
854    if n2:
855        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
856    else:
857        comp = None
858
859    # pylint: disable=bad-whitespace
860    # Remember it all
861    opts.update({
862        'name'      : name,
863        'def'       : model_info,
864        'n1'        : n1,
865        'n2'        : n2,
866        'presets'   : presets,
867        'pars'      : pars,
868        'data'      : data,
869        'engines'   : [base, comp],
870    })
871    # pylint: enable=bad-whitespace
872
873    return opts
874
875def explore(opts):
876    """
877    Explore the model using the Bumps GUI.
878    """
879    import wx  # type: ignore
880    from bumps.names import FitProblem  # type: ignore
881    from bumps.gui.app_frame import AppFrame  # type: ignore
882
883    problem = FitProblem(Explore(opts))
884    is_mac = "cocoa" in wx.version()
885    app = wx.App()
886    frame = AppFrame(parent=None, title="explore")
887    if not is_mac: frame.Show()
888    frame.panel.set_model(model=problem)
889    frame.panel.Layout()
890    frame.panel.aui.Split(0, wx.TOP)
891    if is_mac: frame.Show()
892    app.MainLoop()
893
894class Explore(object):
895    """
896    Bumps wrapper for a SAS model comparison.
897
898    The resulting object can be used as a Bumps fit problem so that
899    parameters can be adjusted in the GUI, with plots updated on the fly.
900    """
901    def __init__(self, opts):
902        from bumps.cli import config_matplotlib  # type: ignore
903        from . import bumps_model
904        config_matplotlib()
905        self.opts = opts
906        model_info = opts['def']
907        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
908        # Initialize parameter ranges, fixing the 2D parameters for 1D data.
909        if not opts['is2d']:
910            for p in model_info.parameters.user_parameters(is2d=False):
911                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
912                    k = p.name+ext
913                    v = pars.get(k, None)
914                    if v is not None:
915                        v.range(*parameter_range(k, v.value))
916        else:
917            for k, v in pars.items():
918                v.range(*parameter_range(k, v.value))
919
920        self.pars = pars
921        self.pd_types = pd_types
922        self.limits = None
923
924    def numpoints(self):
925        """
926        Return the number of points.
927        """
928        return len(self.pars) + 1  # so dof is 1
929
930    def parameters(self):
931        """
932        Return a dictionary of parameters.
933        """
934        return self.pars
935
936    def nllf(self):
937        """
938        Return cost.
939        """
940        # pylint: disable=no-self-use
941        return 0.  # No nllf
942
943    def plot(self, view='log'):
944        """
945        Plot the data and residuals.
946        """
947        pars = dict((k, v.value) for k, v in self.pars.items())
948        pars.update(self.pd_types)
949        self.opts['pars'] = pars
950        limits = compare(self.opts, limits=self.limits)
951        if self.limits is None:
952            vmin, vmax = limits
953            vmax = 1.3*vmax
954            vmin = vmax*1e-7
955            self.limits = vmin, vmax
956
957
958def main():
959    """
960    Main program.
961    """
962    opts = parse_opts()
963    if opts['explore']:
964        explore(opts)
965    else:
966        compare(opts)
967
968if __name__ == "__main__":
969    main()
Note: See TracBrowser for help on using the repository browser.