source: sasmodels/sasmodels/compare.py @ af92b73

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since af92b73 was af92b73, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

report an extra digit on timing comparison

  • Property mode set to 100755
File size: 31.2 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import product
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_pars, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71
72Any two calculation engines can be selected for comparison:
73
74    -single/-double/-half/-fast sets an OpenCL calculation engine
75    -single!/-double!/-quad! sets an OpenMP calculation engine
76    -sasview sets the sasview calculation engine
77
78The default is -single -sasview.  Note that the interpretation of quad
79precision depends on architecture, and may vary from 64-bit to 128-bit,
80with 80-bit floats being common (1e-19 precision).
81
82Key=value pairs allow you to set specific values for the model parameters.
83"""
84
85# Update docs with command line usage string.   This is separate from the usual
86# doc string so that we can display it at run time if there is an error.
87# lin
88__doc__ = (__doc__  # pylint: disable=redefined-builtin
89           + """
90Program description
91-------------------
92
93"""
94           + USAGE)
95
96kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
97
98MODELS = core.list_models()
99
100# CRUFT python 2.6
101if not hasattr(datetime.timedelta, 'total_seconds'):
102    def delay(dt):
103        """Return number date-time delta as number seconds"""
104        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
105else:
106    def delay(dt):
107        """Return number date-time delta as number seconds"""
108        return dt.total_seconds()
109
110
111class push_seed(object):
112    """
113    Set the seed value for the random number generator.
114
115    When used in a with statement, the random number generator state is
116    restored after the with statement is complete.
117
118    :Parameters:
119
120    *seed* : int or array_like, optional
121        Seed for RandomState
122
123    :Example:
124
125    Seed can be used directly to set the seed::
126
127        >>> from numpy.random import randint
128        >>> push_seed(24)
129        <...push_seed object at...>
130        >>> print(randint(0,1000000,3))
131        [242082    899 211136]
132
133    Seed can also be used in a with statement, which sets the random
134    number generator state for the enclosed computations and restores
135    it to the previous state on completion::
136
137        >>> with push_seed(24):
138        ...    print(randint(0,1000000,3))
139        [242082    899 211136]
140
141    Using nested contexts, we can demonstrate that state is indeed
142    restored after the block completes::
143
144        >>> with push_seed(24):
145        ...    print(randint(0,1000000))
146        ...    with push_seed(24):
147        ...        print(randint(0,1000000,3))
148        ...    print(randint(0,1000000))
149        242082
150        [242082    899 211136]
151        899
152
153    The restore step is protected against exceptions in the block::
154
155        >>> with push_seed(24):
156        ...    print(randint(0,1000000))
157        ...    try:
158        ...        with push_seed(24):
159        ...            print(randint(0,1000000,3))
160        ...            raise Exception()
161        ...    except:
162        ...        print("Exception raised")
163        ...    print(randint(0,1000000))
164        242082
165        [242082    899 211136]
166        Exception raised
167        899
168    """
169    def __init__(self, seed=None):
170        self._state = np.random.get_state()
171        np.random.seed(seed)
172
173    def __enter__(self):
174        return None
175
176    def __exit__(self, *args):
177        np.random.set_state(self._state)
178
179def tic():
180    """
181    Timer function.
182
183    Use "toc=tic()" to start the clock and "toc()" to measure
184    a time interval.
185    """
186    then = datetime.datetime.now()
187    return lambda: delay(datetime.datetime.now() - then)
188
189
190def set_beam_stop(data, radius, outer=None):
191    """
192    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
193
194    Note: this function does not use the sasview package
195    """
196    if hasattr(data, 'qx_data'):
197        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
198        data.mask = (q < radius)
199        if outer is not None:
200            data.mask |= (q >= outer)
201    else:
202        data.mask = (data.x < radius)
203        if outer is not None:
204            data.mask |= (data.x >= outer)
205
206
207def parameter_range(p, v):
208    """
209    Choose a parameter range based on parameter name and initial value.
210    """
211    if p.endswith('_pd_n'):
212        return [0, 100]
213    elif p.endswith('_pd_nsigma'):
214        return [0, 5]
215    elif p.endswith('_pd_type'):
216        return v
217    elif any(s in p for s in ('theta', 'phi', 'psi')):
218        # orientation in [-180,180], orientation pd in [0,45]
219        if p.endswith('_pd'):
220            return [0, 45]
221        else:
222            return [-180, 180]
223    elif 'sld' in p:
224        return [-0.5, 10]
225    elif p.endswith('_pd'):
226        return [0, 1]
227    elif p == 'background':
228        return [0, 10]
229    elif p == 'scale':
230        return [0, 1e3]
231    elif p == 'case_num':
232        # RPA hack
233        return [0, 10]
234    elif v < 0:
235        # Kxy parameters in rpa model can be negative
236        return [2*v, -2*v]
237    else:
238        return [0, (2*v if v > 0 else 1)]
239
240
241def _randomize_one(p, v):
242    """
243    Randomize a single parameter.
244    """
245    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
246        return v
247    else:
248        return np.random.uniform(*parameter_range(p, v))
249
250
251def randomize_pars(pars, seed=None):
252    """
253    Generate random values for all of the parameters.
254
255    Valid ranges for the random number generator are guessed from the name of
256    the parameter; this will not account for constraints such as cap radius
257    greater than cylinder radius in the capped_cylinder model, so
258    :func:`constrain_pars` needs to be called afterward..
259    """
260    with push_seed(seed):
261        # Note: the sort guarantees order `of calls to random number generator
262        pars = dict((p, _randomize_one(p, v))
263                    for p, v in sorted(pars.items()))
264    return pars
265
266def constrain_pars(model_info, pars):
267    """
268    Restrict parameters to valid values.
269
270    This includes model specific code for models such as capped_cylinder
271    which need to support within model constraints (cap radius more than
272    cylinder radius in this case).
273    """
274    name = model_info['id']
275    # if it is a product model, then just look at the form factor since
276    # none of the structure factors need any constraints.
277    if '*' in name:
278        name = name.split('*')[0]
279
280    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
281        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
282    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
283        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
284
285    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
286    if name == 'guinier':
287        #q_max = 0.2  # mid q maximum
288        q_max = 1.0  # high q maximum
289        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
290        pars['rg'] = min(pars['rg'], rg_max)
291
292    if name == 'rpa':
293        # Make sure phi sums to 1.0
294        if pars['case_num'] < 2:
295            pars['Phia'] = 0.
296            pars['Phib'] = 0.
297        elif pars['case_num'] < 5:
298            pars['Phia'] = 0.
299        total = sum(pars['Phi'+c] for c in 'abcd')
300        for c in 'abcd':
301            pars['Phi'+c] /= total
302
303def parlist(model_info, pars, is2d):
304    """
305    Format the parameter list for printing.
306    """
307    if is2d:
308        exclude = lambda n: False
309    else:
310        partype = model_info['partype']
311        par1d = set(partype['fixed-1d']+partype['pd-1d'])
312        exclude = lambda n: n not in par1d
313    lines = []
314    for p in model_info['parameters']:
315        if exclude(p.name): continue
316        fields = dict(
317            value=pars.get(p.name, p.default),
318            pd=pars.get(p.name+"_pd", 0.),
319            n=int(pars.get(p.name+"_pd_n", 0)),
320            nsigma=pars.get(p.name+"_pd_nsgima", 3.),
321            type=pars.get(p.name+"_pd_type", 'gaussian'))
322        lines.append(_format_par(p.name, **fields))
323    return "\n".join(lines)
324
325    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
326
327def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
328    line = "%s: %g"%(name, value)
329    if pd != 0.  and n != 0:
330        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
331                % (pd, n, nsigma, nsigma, type)
332    return line
333
334def suppress_pd(pars):
335    """
336    Suppress theta_pd for now until the normalization is resolved.
337
338    May also suppress complete polydispersity of the model to test
339    models more quickly.
340    """
341    pars = pars.copy()
342    for p in pars:
343        if p.endswith("_pd_n"): pars[p] = 0
344    return pars
345
346def eval_sasview(model_info, data):
347    """
348    Return a model calculator using the SasView fitting engine.
349    """
350    # importing sas here so that the error message will be that sas failed to
351    # import rather than the more obscure smear_selection not imported error
352    import sas
353    from sas.models.qsmearing import smear_selection
354
355    def get_model(name):
356        #print("new",sorted(_pars.items()))
357        sas = __import__('sas.models.' + name)
358        ModelClass = getattr(getattr(sas.models, name, None), name, None)
359        if ModelClass is None:
360            raise ValueError("could not find model %r in sas.models"%name)
361        return ModelClass()
362
363    # grab the sasview model, or create it if it is a product model
364    if model_info['composition']:
365        composition_type, parts = model_info['composition']
366        if composition_type == 'product':
367            from sas.models.MultiplicationModel import MultiplicationModel
368            P, S = [get_model(p) for p in model_info['oldname']]
369            model = MultiplicationModel(P, S)
370        else:
371            raise ValueError("mixture models not handled yet")
372    else:
373        model = get_model(model_info['oldname'])
374
375    # build a smearer with which to call the model, if necessary
376    smearer = smear_selection(data, model=model)
377    if hasattr(data, 'qx_data'):
378        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
379        index = ((~data.mask) & (~np.isnan(data.data))
380                 & (q >= data.qmin) & (q <= data.qmax))
381        if smearer is not None:
382            smearer.model = model  # because smear_selection has a bug
383            smearer.accuracy = data.accuracy
384            smearer.set_index(index)
385            theory = lambda: smearer.get_value()
386        else:
387            theory = lambda: model.evalDistribution([data.qx_data[index],
388                                                     data.qy_data[index]])
389    elif smearer is not None:
390        theory = lambda: smearer(model.evalDistribution(data.x))
391    else:
392        theory = lambda: model.evalDistribution(data.x)
393
394    def calculator(**pars):
395        """
396        Sasview calculator for model.
397        """
398        # paying for parameter conversion each time to keep life simple, if not fast
399        pars = revert_pars(model_info, pars)
400        for k, v in pars.items():
401            parts = k.split('.')  # polydispersity components
402            if len(parts) == 2:
403                model.dispersion[parts[0]][parts[1]] = v
404            else:
405                model.setParam(k, v)
406        return theory()
407
408    calculator.engine = "sasview"
409    return calculator
410
411DTYPE_MAP = {
412    'half': '16',
413    'fast': 'fast',
414    'single': '32',
415    'double': '64',
416    'quad': '128',
417    'f16': '16',
418    'f32': '32',
419    'f64': '64',
420    'longdouble': '128',
421}
422def eval_opencl(model_info, data, dtype='single', cutoff=0.):
423    """
424    Return a model calculator using the OpenCL calculation engine.
425    """
426    def builder(model_info):
427        try:
428            return core.build_model(model_info, dtype=dtype, platform="ocl")
429        except Exception as exc:
430            print(exc)
431            print("... trying again with single precision")
432            return core.build_model(model_info, dtype='single', platform="ocl")
433    if model_info['composition']:
434        composition_type, parts = model_info['composition']
435        if composition_type == 'product':
436            P, S = [builder(p) for p in parts]
437            model = product.ProductModel(P, S)
438        else:
439            raise ValueError("mixture models not handled yet")
440    else:
441        model = builder(model_info)
442    calculator = DirectModel(data, model, cutoff=cutoff)
443    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
444    return calculator
445
446def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
447    """
448    Return a model calculator using the DLL calculation engine.
449    """
450    if dtype == 'quad':
451        dtype = 'longdouble'
452    def builder(model_info):
453        return core.build_model(model_info, dtype=dtype, platform="dll")
454
455    if model_info['composition']:
456        composition_type, parts = model_info['composition']
457        if composition_type == 'product':
458            P, S = [builder(p) for p in parts]
459            model = product.ProductModel(P, S)
460        else:
461            raise ValueError("mixture models not handled yet")
462    else:
463        model = builder(model_info)
464    calculator = DirectModel(data, model, cutoff=cutoff)
465    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
466    return calculator
467
468def time_calculation(calculator, pars, Nevals=1):
469    """
470    Compute the average calculation time over N evaluations.
471
472    An additional call is generated without polydispersity in order to
473    initialize the calculation engine, and make the average more stable.
474    """
475    # initialize the code so time is more accurate
476    value = calculator(**suppress_pd(pars))
477    toc = tic()
478    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
479        value = calculator(**pars)
480    average_time = toc()*1000./Nevals
481    return value, average_time
482
483def make_data(opts):
484    """
485    Generate an empty dataset, used with the model to set Q points
486    and resolution.
487
488    *opts* contains the options, with 'qmax', 'nq', 'res',
489    'accuracy', 'is2d' and 'view' parsed from the command line.
490    """
491    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
492    if opts['is2d']:
493        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
494        data.accuracy = opts['accuracy']
495        set_beam_stop(data, 0.004)
496        index = ~data.mask
497    else:
498        if opts['view'] == 'log':
499            qmax = math.log10(qmax)
500            q = np.logspace(qmax-3, qmax, nq)
501        else:
502            q = np.linspace(0.001*qmax, qmax, nq)
503        data = empty_data1D(q, resolution=res)
504        index = slice(None, None)
505    return data, index
506
507def make_engine(model_info, data, dtype, cutoff):
508    """
509    Generate the appropriate calculation engine for the given datatype.
510
511    Datatypes with '!' appended are evaluated using external C DLLs rather
512    than OpenCL.
513    """
514    if dtype == 'sasview':
515        return eval_sasview(model_info, data)
516    elif dtype.endswith('!'):
517        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
518    else:
519        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
520
521def compare(opts, limits=None):
522    """
523    Preform a comparison using options from the command line.
524
525    *limits* are the limits on the values to use, either to set the y-axis
526    for 1D or to set the colormap scale for 2D.  If None, then they are
527    inferred from the data and returned. When exploring using Bumps,
528    the limits are set when the model is initially called, and maintained
529    as the values are adjusted, making it easier to see the effects of the
530    parameters.
531    """
532    Nbase, Ncomp = opts['n1'], opts['n2']
533    pars = opts['pars']
534    data = opts['data']
535
536    # Base calculation
537    if Nbase > 0:
538        base = opts['engines'][0]
539        try:
540            base_value, base_time = time_calculation(base, pars, Nbase)
541            print("%s t=%.2f ms, intensity=%.0f"
542                  % (base.engine, base_time, sum(base_value)))
543        except ImportError:
544            traceback.print_exc()
545            Nbase = 0
546
547    # Comparison calculation
548    if Ncomp > 0:
549        comp = opts['engines'][1]
550        try:
551            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
552            print("%s t=%.2f ms, intensity=%.0f"
553                  % (comp.engine, comp_time, sum(comp_value)))
554        except ImportError:
555            traceback.print_exc()
556            Ncomp = 0
557
558    # Compare, but only if computing both forms
559    if Nbase > 0 and Ncomp > 0:
560        resid = (base_value - comp_value)
561        relerr = resid/comp_value
562        _print_stats("|%s-%s|"
563                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
564                     resid)
565        _print_stats("|(%s-%s)/%s|"
566                     % (base.engine, comp.engine, comp.engine),
567                     relerr)
568
569    # Plot if requested
570    if not opts['plot'] and not opts['explore']: return
571    view = opts['view']
572    import matplotlib.pyplot as plt
573    if limits is None:
574        vmin, vmax = np.Inf, -np.Inf
575        if Nbase > 0:
576            vmin = min(vmin, min(base_value))
577            vmax = max(vmax, max(base_value))
578        if Ncomp > 0:
579            vmin = min(vmin, min(comp_value))
580            vmax = max(vmax, max(comp_value))
581        limits = vmin, vmax
582
583    if Nbase > 0:
584        if Ncomp > 0: plt.subplot(131)
585        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
586        plt.title("%s t=%.2f ms"%(base.engine, base_time))
587        #cbar_title = "log I"
588    if Ncomp > 0:
589        if Nbase > 0: plt.subplot(132)
590        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
591        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
592        #cbar_title = "log I"
593    if Ncomp > 0 and Nbase > 0:
594        plt.subplot(133)
595        if not opts['rel_err']:
596            err, errstr, errview = resid, "abs err", "linear"
597        else:
598            err, errstr, errview = abs(relerr), "rel err", "log"
599        #err,errstr = base/comp,"ratio"
600        plot_theory(data, None, resid=err, view=errview, use_data=False)
601        if view == 'linear':
602            plt.xscale('linear')
603        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
604        #cbar_title = errstr if errview=="linear" else "log "+errstr
605    #if is2D:
606    #    h = plt.colorbar()
607    #    h.ax.set_title(cbar_title)
608
609    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
610        plt.figure()
611        v = relerr
612        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
613        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
614        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
615                   % (base.engine, comp.engine, comp.engine))
616        plt.ylabel('P(err)')
617        plt.title('Distribution of relative error between calculation engines')
618
619    if not opts['explore']:
620        plt.show()
621
622    return limits
623
624def _print_stats(label, err):
625    sorted_err = np.sort(abs(err))
626    p50 = int((len(err)-1)*0.50)
627    p98 = int((len(err)-1)*0.98)
628    data = [
629        "max:%.3e"%sorted_err[-1],
630        "median:%.3e"%sorted_err[p50],
631        "98%%:%.3e"%sorted_err[p98],
632        "rms:%.3e"%np.sqrt(np.mean(err**2)),
633        "zero-offset:%+.3e"%np.mean(err),
634        ]
635    print(label+"  "+"  ".join(data))
636
637
638
639# ===========================================================================
640#
641NAME_OPTIONS = set([
642    'plot', 'noplot',
643    'half', 'fast', 'single', 'double',
644    'single!', 'double!', 'quad!', 'sasview',
645    'lowq', 'midq', 'highq', 'exq',
646    '2d', '1d',
647    'preset', 'random',
648    'poly', 'mono',
649    'nopars', 'pars',
650    'rel', 'abs',
651    'linear', 'log', 'q4',
652    'hist', 'nohist',
653    'edit',
654    ])
655VALUE_OPTIONS = [
656    # Note: random is both a name option and a value option
657    'cutoff', 'random', 'nq', 'res', 'accuracy',
658    ]
659
660def columnize(L, indent="", width=79):
661    """
662    Format a list of strings into columns.
663
664    Returns a string with carriage returns ready for printing.
665    """
666    column_width = max(len(w) for w in L) + 1
667    num_columns = (width - len(indent)) // column_width
668    num_rows = len(L) // num_columns
669    L = L + [""] * (num_rows*num_columns - len(L))
670    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
671    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
672             for row in zip(*columns)]
673    output = indent + ("\n"+indent).join(lines)
674    return output
675
676
677def get_demo_pars(model_info):
678    """
679    Extract demo parameters from the model definition.
680    """
681    # Get the default values for the parameters
682    pars = dict((p.name, p.default) for p in model_info['parameters'])
683
684    # Fill in default values for the polydispersity parameters
685    for p in model_info['parameters']:
686        if p.type in ('volume', 'orientation'):
687            pars[p.name+'_pd'] = 0.0
688            pars[p.name+'_pd_n'] = 0
689            pars[p.name+'_pd_nsigma'] = 3.0
690            pars[p.name+'_pd_type'] = "gaussian"
691
692    # Plug in values given in demo
693    pars.update(model_info['demo'])
694    return pars
695
696
697def parse_opts():
698    """
699    Parse command line options.
700    """
701    MODELS = core.list_models()
702    flags = [arg for arg in sys.argv[1:]
703             if arg.startswith('-')]
704    values = [arg for arg in sys.argv[1:]
705              if not arg.startswith('-') and '=' in arg]
706    args = [arg for arg in sys.argv[1:]
707            if not arg.startswith('-') and '=' not in arg]
708    models = "\n    ".join("%-15s"%v for v in MODELS)
709    if len(args) == 0:
710        print(USAGE)
711        print("\nAvailable models:")
712        print(columnize(MODELS, indent="  "))
713        sys.exit(1)
714    if len(args) > 3:
715        print("expected parameters: model N1 N2")
716
717    def _get_info(name):
718        try:
719            model_info = core.load_model_info(name)
720        except ImportError, exc:
721            print(str(exc))
722            print("Use one of:\n    " + models)
723            sys.exit(1)
724        return model_info
725
726    name = args[0]
727    if '*' in name:
728        parts = [_get_info(k) for k in name.split('*')]
729        model_info = product.make_product_info(*parts)
730    else:
731        model_info = _get_info(name)
732
733    invalid = [o[1:] for o in flags
734               if o[1:] not in NAME_OPTIONS
735               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
736    if invalid:
737        print("Invalid options: %s"%(", ".join(invalid)))
738        sys.exit(1)
739
740
741    # pylint: disable=bad-whitespace
742    # Interpret the flags
743    opts = {
744        'plot'      : True,
745        'view'      : 'log',
746        'is2d'      : False,
747        'qmax'      : 0.05,
748        'nq'        : 128,
749        'res'       : 0.0,
750        'accuracy'  : 'Low',
751        'cutoff'    : 1e-5,
752        'seed'      : -1,  # default to preset
753        'mono'      : False,
754        'show_pars' : False,
755        'show_hist' : False,
756        'rel_err'   : True,
757        'explore'   : False,
758    }
759    engines = []
760    for arg in flags:
761        if arg == '-noplot':    opts['plot'] = False
762        elif arg == '-plot':    opts['plot'] = True
763        elif arg == '-linear':  opts['view'] = 'linear'
764        elif arg == '-log':     opts['view'] = 'log'
765        elif arg == '-q4':      opts['view'] = 'q4'
766        elif arg == '-1d':      opts['is2d'] = False
767        elif arg == '-2d':      opts['is2d'] = True
768        elif arg == '-exq':     opts['qmax'] = 10.0
769        elif arg == '-highq':   opts['qmax'] = 1.0
770        elif arg == '-midq':    opts['qmax'] = 0.2
771        elif arg == '-lowq':    opts['qmax'] = 0.05
772        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
773        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
774        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
775        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
776        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
777        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
778        elif arg == '-preset':  opts['seed'] = -1
779        elif arg == '-mono':    opts['mono'] = True
780        elif arg == '-poly':    opts['mono'] = False
781        elif arg == '-pars':    opts['show_pars'] = True
782        elif arg == '-nopars':  opts['show_pars'] = False
783        elif arg == '-hist':    opts['show_hist'] = True
784        elif arg == '-nohist':  opts['show_hist'] = False
785        elif arg == '-rel':     opts['rel_err'] = True
786        elif arg == '-abs':     opts['rel_err'] = False
787        elif arg == '-half':    engines.append(arg[1:])
788        elif arg == '-fast':    engines.append(arg[1:])
789        elif arg == '-single':  engines.append(arg[1:])
790        elif arg == '-double':  engines.append(arg[1:])
791        elif arg == '-single!': engines.append(arg[1:])
792        elif arg == '-double!': engines.append(arg[1:])
793        elif arg == '-quad!':   engines.append(arg[1:])
794        elif arg == '-sasview': engines.append(arg[1:])
795        elif arg == '-edit':    opts['explore'] = True
796    # pylint: enable=bad-whitespace
797
798    if len(engines) == 0:
799        engines.extend(['single', 'sasview'])
800    elif len(engines) == 1:
801        if engines[0][0] != 'sasview':
802            engines.append('sasview')
803        else:
804            engines.append('single')
805    elif len(engines) > 2:
806        del engines[2:]
807
808    n1 = int(args[1]) if len(args) > 1 else 1
809    n2 = int(args[2]) if len(args) > 2 else 1
810
811    # Get demo parameters from model definition, or use default parameters
812    # if model does not define demo parameters
813    pars = get_demo_pars(model_info)
814
815    # Fill in parameters given on the command line
816    presets = {}
817    for arg in values:
818        k, v = arg.split('=', 1)
819        if k not in pars:
820            # extract base name without polydispersity info
821            s = set(p.split('_pd')[0] for p in pars)
822            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
823            sys.exit(1)
824        presets[k] = float(v) if not k.endswith('type') else v
825
826    # randomize parameters
827    #pars.update(set_pars)  # set value before random to control range
828    if opts['seed'] > -1:
829        pars = randomize_pars(pars, seed=opts['seed'])
830        print("Randomize using -random=%i"%opts['seed'])
831    if opts['mono']:
832        pars = suppress_pd(pars)
833    pars.update(presets)  # set value after random to control value
834    #import pprint; pprint.pprint(model_info)
835    constrain_pars(model_info, pars)
836    constrain_new_to_old(model_info, pars)
837    if opts['show_pars']:
838        print(str(parlist(model_info, pars, opts['is2d'])))
839
840    # Create the computational engines
841    data, _ = make_data(opts)
842    if n1:
843        base = make_engine(model_info, data, engines[0], opts['cutoff'])
844    else:
845        base = None
846    if n2:
847        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
848    else:
849        comp = None
850
851    # pylint: disable=bad-whitespace
852    # Remember it all
853    opts.update({
854        'name'      : name,
855        'def'       : model_info,
856        'n1'        : n1,
857        'n2'        : n2,
858        'presets'   : presets,
859        'pars'      : pars,
860        'data'      : data,
861        'engines'   : [base, comp],
862    })
863    # pylint: enable=bad-whitespace
864
865    return opts
866
867def explore(opts):
868    """
869    Explore the model using the Bumps GUI.
870    """
871    import wx
872    from bumps.names import FitProblem
873    from bumps.gui.app_frame import AppFrame
874
875    problem = FitProblem(Explore(opts))
876    is_mac = "cocoa" in wx.version()
877    app = wx.App()
878    frame = AppFrame(parent=None, title="explore")
879    if not is_mac: frame.Show()
880    frame.panel.set_model(model=problem)
881    frame.panel.Layout()
882    frame.panel.aui.Split(0, wx.TOP)
883    if is_mac: frame.Show()
884    app.MainLoop()
885
886class Explore(object):
887    """
888    Bumps wrapper for a SAS model comparison.
889
890    The resulting object can be used as a Bumps fit problem so that
891    parameters can be adjusted in the GUI, with plots updated on the fly.
892    """
893    def __init__(self, opts):
894        from bumps.cli import config_matplotlib
895        from . import bumps_model
896        config_matplotlib()
897        self.opts = opts
898        model_info = opts['def']
899        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
900        if not opts['is2d']:
901            active = [base + ext
902                      for base in model_info['partype']['pd-1d']
903                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
904            active.extend(model_info['partype']['fixed-1d'])
905            for k in active:
906                v = pars[k]
907                v.range(*parameter_range(k, v.value))
908        else:
909            for k, v in pars.items():
910                v.range(*parameter_range(k, v.value))
911
912        self.pars = pars
913        self.pd_types = pd_types
914        self.limits = None
915
916    def numpoints(self):
917        """
918        Return the number of points.
919        """
920        return len(self.pars) + 1  # so dof is 1
921
922    def parameters(self):
923        """
924        Return a dictionary of parameters.
925        """
926        return self.pars
927
928    def nllf(self):
929        """
930        Return cost.
931        """
932        # pylint: disable=no-self-use
933        return 0.  # No nllf
934
935    def plot(self, view='log'):
936        """
937        Plot the data and residuals.
938        """
939        pars = dict((k, v.value) for k, v in self.pars.items())
940        pars.update(self.pd_types)
941        self.opts['pars'] = pars
942        limits = compare(self.opts, limits=self.limits)
943        if self.limits is None:
944            vmin, vmax = limits
945            vmax = 1.3*vmax
946            vmin = vmax*1e-7
947            self.limits = vmin, vmax
948
949
950def main():
951    """
952    Main program.
953    """
954    opts = parse_opts()
955    if opts['explore']:
956        explore(opts)
957    else:
958        compare(opts)
959
960if __name__ == "__main__":
961    main()
Note: See TracBrowser for help on using the repository browser.