source: sasmodels/sasmodels/compare.py @ 69aa451

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 69aa451 was 69aa451, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor parameter representation

  • Property mode set to 100755
File size: 30.2 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import product
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_pars, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71
72Any two calculation engines can be selected for comparison:
73
74    -single/-double/-half/-fast sets an OpenCL calculation engine
75    -single!/-double!/-quad! sets an OpenMP calculation engine
76    -sasview sets the sasview calculation engine
77
78The default is -single -sasview.  Note that the interpretation of quad
79precision depends on architecture, and may vary from 64-bit to 128-bit,
80with 80-bit floats being common (1e-19 precision).
81
82Key=value pairs allow you to set specific values for the model parameters.
83"""
84
85# Update docs with command line usage string.   This is separate from the usual
86# doc string so that we can display it at run time if there is an error.
87# lin
88__doc__ = (__doc__  # pylint: disable=redefined-builtin
89           + """
90Program description
91-------------------
92
93"""
94           + USAGE)
95
96kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
97
98MODELS = core.list_models()
99
100# CRUFT python 2.6
101if not hasattr(datetime.timedelta, 'total_seconds'):
102    def delay(dt):
103        """Return number date-time delta as number seconds"""
104        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
105else:
106    def delay(dt):
107        """Return number date-time delta as number seconds"""
108        return dt.total_seconds()
109
110
111class push_seed(object):
112    """
113    Set the seed value for the random number generator.
114
115    When used in a with statement, the random number generator state is
116    restored after the with statement is complete.
117
118    :Parameters:
119
120    *seed* : int or array_like, optional
121        Seed for RandomState
122
123    :Example:
124
125    Seed can be used directly to set the seed::
126
127        >>> from numpy.random import randint
128        >>> push_seed(24)
129        <...push_seed object at...>
130        >>> print(randint(0,1000000,3))
131        [242082    899 211136]
132
133    Seed can also be used in a with statement, which sets the random
134    number generator state for the enclosed computations and restores
135    it to the previous state on completion::
136
137        >>> with push_seed(24):
138        ...    print(randint(0,1000000,3))
139        [242082    899 211136]
140
141    Using nested contexts, we can demonstrate that state is indeed
142    restored after the block completes::
143
144        >>> with push_seed(24):
145        ...    print(randint(0,1000000))
146        ...    with push_seed(24):
147        ...        print(randint(0,1000000,3))
148        ...    print(randint(0,1000000))
149        242082
150        [242082    899 211136]
151        899
152
153    The restore step is protected against exceptions in the block::
154
155        >>> with push_seed(24):
156        ...    print(randint(0,1000000))
157        ...    try:
158        ...        with push_seed(24):
159        ...            print(randint(0,1000000,3))
160        ...            raise Exception()
161        ...    except:
162        ...        print("Exception raised")
163        ...    print(randint(0,1000000))
164        242082
165        [242082    899 211136]
166        Exception raised
167        899
168    """
169    def __init__(self, seed=None):
170        self._state = np.random.get_state()
171        np.random.seed(seed)
172
173    def __enter__(self):
174        return None
175
176    def __exit__(self, *args):
177        np.random.set_state(self._state)
178
179def tic():
180    """
181    Timer function.
182
183    Use "toc=tic()" to start the clock and "toc()" to measure
184    a time interval.
185    """
186    then = datetime.datetime.now()
187    return lambda: delay(datetime.datetime.now() - then)
188
189
190def set_beam_stop(data, radius, outer=None):
191    """
192    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
193
194    Note: this function does not use the sasview package
195    """
196    if hasattr(data, 'qx_data'):
197        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
198        data.mask = (q < radius)
199        if outer is not None:
200            data.mask |= (q >= outer)
201    else:
202        data.mask = (data.x < radius)
203        if outer is not None:
204            data.mask |= (data.x >= outer)
205
206
207def parameter_range(p, v):
208    """
209    Choose a parameter range based on parameter name and initial value.
210    """
211    if p.endswith('_pd_n'):
212        return [0, 100]
213    elif p.endswith('_pd_nsigma'):
214        return [0, 5]
215    elif p.endswith('_pd_type'):
216        return v
217    elif any(s in p for s in ('theta', 'phi', 'psi')):
218        # orientation in [-180,180], orientation pd in [0,45]
219        if p.endswith('_pd'):
220            return [0, 45]
221        else:
222            return [-180, 180]
223    elif 'sld' in p:
224        return [-0.5, 10]
225    elif p.endswith('_pd'):
226        return [0, 1]
227    elif p == 'background':
228        return [0, 10]
229    elif p == 'scale':
230        return [0, 1e3]
231    elif p == 'case_num':
232        # RPA hack
233        return [0, 10]
234    elif v < 0:
235        # Kxy parameters in rpa model can be negative
236        return [2*v, -2*v]
237    else:
238        return [0, (2*v if v > 0 else 1)]
239
240
241def _randomize_one(p, v):
242    """
243    Randomize a single parameter.
244    """
245    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
246        return v
247    else:
248        return np.random.uniform(*parameter_range(p, v))
249
250
251def randomize_pars(pars, seed=None):
252    """
253    Generate random values for all of the parameters.
254
255    Valid ranges for the random number generator are guessed from the name of
256    the parameter; this will not account for constraints such as cap radius
257    greater than cylinder radius in the capped_cylinder model, so
258    :func:`constrain_pars` needs to be called afterward..
259    """
260    with push_seed(seed):
261        # Note: the sort guarantees order `of calls to random number generator
262        pars = dict((p, _randomize_one(p, v))
263                    for p, v in sorted(pars.items()))
264    return pars
265
266def constrain_pars(model_info, pars):
267    """
268    Restrict parameters to valid values.
269
270    This includes model specific code for models such as capped_cylinder
271    which need to support within model constraints (cap radius more than
272    cylinder radius in this case).
273    """
274    name = model_info['id']
275    # if it is a product model, then just look at the form factor since
276    # none of the structure factors need any constraints.
277    if '*' in name:
278        name = name.split('*')[0]
279
280    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
281        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
282    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
283        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
284
285    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
286    if name == 'guinier':
287        #q_max = 0.2  # mid q maximum
288        q_max = 1.0  # high q maximum
289        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
290        pars['rg'] = min(pars['rg'], rg_max)
291
292    if name == 'rpa':
293        # Make sure phi sums to 1.0
294        if pars['case_num'] < 2:
295            pars['Phia'] = 0.
296            pars['Phib'] = 0.
297        elif pars['case_num'] < 5:
298            pars['Phia'] = 0.
299        total = sum(pars['Phi'+c] for c in 'abcd')
300        for c in 'abcd':
301            pars['Phi'+c] /= total
302
303def parlist(model_info, pars, is2d):
304    """
305    Format the parameter list for printing.
306    """
307    if is2d:
308        exclude = lambda n: False
309    else:
310        par1d = model_info['par_type']['1d']
311        exclude = lambda n: n not in par1d
312    lines = []
313    for p in model_info['parameters']:
314        if exclude(p.name): continue
315        fields = dict(
316            value=pars.get(p.name, p.default),
317            pd=pars.get(p.name+"_pd", 0.),
318            n=int(pars.get(p.name+"_pd_n", 0)),
319            nsigma=pars.get(p.name+"_pd_nsgima", 3.),
320            type=pars.get(p.name+"_pd_type", 'gaussian'))
321        lines.append(_format_par(p.name, **fields))
322    return "\n".join(lines)
323
324    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
325
326def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
327    line = "%s: %g"%(name, value)
328    if pd != 0.  and n != 0:
329        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
330                % (pd, n, nsigma, nsigma, type)
331    return line
332
333def suppress_pd(pars):
334    """
335    Suppress theta_pd for now until the normalization is resolved.
336
337    May also suppress complete polydispersity of the model to test
338    models more quickly.
339    """
340    pars = pars.copy()
341    for p in pars:
342        if p.endswith("_pd_n"): pars[p] = 0
343    return pars
344
345def eval_sasview(model_info, data):
346    """
347    Return a model calculator using the SasView fitting engine.
348    """
349    # importing sas here so that the error message will be that sas failed to
350    # import rather than the more obscure smear_selection not imported error
351    import sas
352    from sas.models.qsmearing import smear_selection
353
354    def get_model(name):
355        #print("new",sorted(_pars.items()))
356        sas = __import__('sas.models.' + name)
357        ModelClass = getattr(getattr(sas.models, name, None), name, None)
358        if ModelClass is None:
359            raise ValueError("could not find model %r in sas.models"%name)
360        return ModelClass()
361
362    # grab the sasview model, or create it if it is a product model
363    if model_info['composition']:
364        composition_type, parts = model_info['composition']
365        if composition_type == 'product':
366            from sas.models.MultiplicationModel import MultiplicationModel
367            P, S = [get_model(p) for p in model_info['oldname']]
368            model = MultiplicationModel(P, S)
369        else:
370            raise ValueError("sasview mixture models not supported by compare")
371    else:
372        model = get_model(model_info['oldname'])
373
374    # build a smearer with which to call the model, if necessary
375    smearer = smear_selection(data, model=model)
376    if hasattr(data, 'qx_data'):
377        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
378        index = ((~data.mask) & (~np.isnan(data.data))
379                 & (q >= data.qmin) & (q <= data.qmax))
380        if smearer is not None:
381            smearer.model = model  # because smear_selection has a bug
382            smearer.accuracy = data.accuracy
383            smearer.set_index(index)
384            theory = lambda: smearer.get_value()
385        else:
386            theory = lambda: model.evalDistribution([data.qx_data[index],
387                                                     data.qy_data[index]])
388    elif smearer is not None:
389        theory = lambda: smearer(model.evalDistribution(data.x))
390    else:
391        theory = lambda: model.evalDistribution(data.x)
392
393    def calculator(**pars):
394        """
395        Sasview calculator for model.
396        """
397        # paying for parameter conversion each time to keep life simple, if not fast
398        pars = revert_pars(model_info, pars)
399        for k, v in pars.items():
400            parts = k.split('.')  # polydispersity components
401            if len(parts) == 2:
402                model.dispersion[parts[0]][parts[1]] = v
403            else:
404                model.setParam(k, v)
405        return theory()
406
407    calculator.engine = "sasview"
408    return calculator
409
410DTYPE_MAP = {
411    'half': '16',
412    'fast': 'fast',
413    'single': '32',
414    'double': '64',
415    'quad': '128',
416    'f16': '16',
417    'f32': '32',
418    'f64': '64',
419    'longdouble': '128',
420}
421def eval_opencl(model_info, data, dtype='single', cutoff=0.):
422    """
423    Return a model calculator using the OpenCL calculation engine.
424    """
425    try:
426        model = core.build_model(model_info, dtype=dtype, platform="ocl")
427    except Exception as exc:
428        print(exc)
429        print("... trying again with single precision")
430        model = core.build_model(model_info, dtype='single', platform="ocl")
431    calculator = DirectModel(data, model, cutoff=cutoff)
432    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
433    return calculator
434
435def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
436    """
437    Return a model calculator using the DLL calculation engine.
438    """
439    if dtype == 'quad':
440        dtype = 'longdouble'
441    model = core.build_model(model_info, dtype=dtype, platform="dll")
442    calculator = DirectModel(data, model, cutoff=cutoff)
443    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
444    return calculator
445
446def time_calculation(calculator, pars, Nevals=1):
447    """
448    Compute the average calculation time over N evaluations.
449
450    An additional call is generated without polydispersity in order to
451    initialize the calculation engine, and make the average more stable.
452    """
453    # initialize the code so time is more accurate
454    if Nevals > 1:
455        value = calculator(**suppress_pd(pars))
456    toc = tic()
457    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
458        value = calculator(**pars)
459    average_time = toc()*1000./Nevals
460    return value, average_time
461
462def make_data(opts):
463    """
464    Generate an empty dataset, used with the model to set Q points
465    and resolution.
466
467    *opts* contains the options, with 'qmax', 'nq', 'res',
468    'accuracy', 'is2d' and 'view' parsed from the command line.
469    """
470    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
471    if opts['is2d']:
472        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
473        data.accuracy = opts['accuracy']
474        set_beam_stop(data, 0.004)
475        index = ~data.mask
476    else:
477        if opts['view'] == 'log':
478            qmax = math.log10(qmax)
479            q = np.logspace(qmax-3, qmax, nq)
480        else:
481            q = np.linspace(0.001*qmax, qmax, nq)
482        data = empty_data1D(q, resolution=res)
483        index = slice(None, None)
484    return data, index
485
486def make_engine(model_info, data, dtype, cutoff):
487    """
488    Generate the appropriate calculation engine for the given datatype.
489
490    Datatypes with '!' appended are evaluated using external C DLLs rather
491    than OpenCL.
492    """
493    if dtype == 'sasview':
494        return eval_sasview(model_info, data)
495    elif dtype.endswith('!'):
496        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
497    else:
498        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
499
500def compare(opts, limits=None):
501    """
502    Preform a comparison using options from the command line.
503
504    *limits* are the limits on the values to use, either to set the y-axis
505    for 1D or to set the colormap scale for 2D.  If None, then they are
506    inferred from the data and returned. When exploring using Bumps,
507    the limits are set when the model is initially called, and maintained
508    as the values are adjusted, making it easier to see the effects of the
509    parameters.
510    """
511    Nbase, Ncomp = opts['n1'], opts['n2']
512    pars = opts['pars']
513    data = opts['data']
514
515    # Base calculation
516    if Nbase > 0:
517        base = opts['engines'][0]
518        try:
519            base_value, base_time = time_calculation(base, pars, Nbase)
520            print("%s t=%.2f ms, intensity=%.0f"
521                  % (base.engine, base_time, sum(base_value)))
522        except ImportError:
523            traceback.print_exc()
524            Nbase = 0
525
526    # Comparison calculation
527    if Ncomp > 0:
528        comp = opts['engines'][1]
529        try:
530            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
531            print("%s t=%.2f ms, intensity=%.0f"
532                  % (comp.engine, comp_time, sum(comp_value)))
533        except ImportError:
534            traceback.print_exc()
535            Ncomp = 0
536
537    # Compare, but only if computing both forms
538    if Nbase > 0 and Ncomp > 0:
539        resid = (base_value - comp_value)
540        relerr = resid/comp_value
541        _print_stats("|%s-%s|"
542                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
543                     resid)
544        _print_stats("|(%s-%s)/%s|"
545                     % (base.engine, comp.engine, comp.engine),
546                     relerr)
547
548    # Plot if requested
549    if not opts['plot'] and not opts['explore']: return
550    view = opts['view']
551    import matplotlib.pyplot as plt
552    if limits is None:
553        vmin, vmax = np.Inf, -np.Inf
554        if Nbase > 0:
555            vmin = min(vmin, min(base_value))
556            vmax = max(vmax, max(base_value))
557        if Ncomp > 0:
558            vmin = min(vmin, min(comp_value))
559            vmax = max(vmax, max(comp_value))
560        limits = vmin, vmax
561
562    if Nbase > 0:
563        if Ncomp > 0: plt.subplot(131)
564        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
565        plt.title("%s t=%.2f ms"%(base.engine, base_time))
566        #cbar_title = "log I"
567    if Ncomp > 0:
568        if Nbase > 0: plt.subplot(132)
569        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
570        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
571        #cbar_title = "log I"
572    if Ncomp > 0 and Nbase > 0:
573        plt.subplot(133)
574        if not opts['rel_err']:
575            err, errstr, errview = resid, "abs err", "linear"
576        else:
577            err, errstr, errview = abs(relerr), "rel err", "log"
578        #err,errstr = base/comp,"ratio"
579        plot_theory(data, None, resid=err, view=errview, use_data=False)
580        if view == 'linear':
581            plt.xscale('linear')
582        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
583        #cbar_title = errstr if errview=="linear" else "log "+errstr
584    #if is2D:
585    #    h = plt.colorbar()
586    #    h.ax.set_title(cbar_title)
587
588    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
589        plt.figure()
590        v = relerr
591        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
592        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
593        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
594                   % (base.engine, comp.engine, comp.engine))
595        plt.ylabel('P(err)')
596        plt.title('Distribution of relative error between calculation engines')
597
598    if not opts['explore']:
599        plt.show()
600
601    return limits
602
603def _print_stats(label, err):
604    sorted_err = np.sort(abs(err))
605    p50 = int((len(err)-1)*0.50)
606    p98 = int((len(err)-1)*0.98)
607    data = [
608        "max:%.3e"%sorted_err[-1],
609        "median:%.3e"%sorted_err[p50],
610        "98%%:%.3e"%sorted_err[p98],
611        "rms:%.3e"%np.sqrt(np.mean(err**2)),
612        "zero-offset:%+.3e"%np.mean(err),
613        ]
614    print(label+"  "+"  ".join(data))
615
616
617
618# ===========================================================================
619#
620NAME_OPTIONS = set([
621    'plot', 'noplot',
622    'half', 'fast', 'single', 'double',
623    'single!', 'double!', 'quad!', 'sasview',
624    'lowq', 'midq', 'highq', 'exq',
625    '2d', '1d',
626    'preset', 'random',
627    'poly', 'mono',
628    'nopars', 'pars',
629    'rel', 'abs',
630    'linear', 'log', 'q4',
631    'hist', 'nohist',
632    'edit',
633    ])
634VALUE_OPTIONS = [
635    # Note: random is both a name option and a value option
636    'cutoff', 'random', 'nq', 'res', 'accuracy',
637    ]
638
639def columnize(L, indent="", width=79):
640    """
641    Format a list of strings into columns.
642
643    Returns a string with carriage returns ready for printing.
644    """
645    column_width = max(len(w) for w in L) + 1
646    num_columns = (width - len(indent)) // column_width
647    num_rows = len(L) // num_columns
648    L = L + [""] * (num_rows*num_columns - len(L))
649    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
650    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
651             for row in zip(*columns)]
652    output = indent + ("\n"+indent).join(lines)
653    return output
654
655
656def get_demo_pars(model_info):
657    """
658    Extract demo parameters from the model definition.
659    """
660    # Get the default values for the parameters
661    pars = dict((p.name, p.default) for p in model_info['parameters'])
662
663    # Fill in default values for the polydispersity parameters
664    for p in model_info['parameters']:
665        if p.type in ('volume', 'orientation'):
666            pars[p.name+'_pd'] = 0.0
667            pars[p.name+'_pd_n'] = 0
668            pars[p.name+'_pd_nsigma'] = 3.0
669            pars[p.name+'_pd_type'] = "gaussian"
670
671    # Plug in values given in demo
672    pars.update(model_info['demo'])
673    return pars
674
675
676def parse_opts():
677    """
678    Parse command line options.
679    """
680    MODELS = core.list_models()
681    flags = [arg for arg in sys.argv[1:]
682             if arg.startswith('-')]
683    values = [arg for arg in sys.argv[1:]
684              if not arg.startswith('-') and '=' in arg]
685    args = [arg for arg in sys.argv[1:]
686            if not arg.startswith('-') and '=' not in arg]
687    models = "\n    ".join("%-15s"%v for v in MODELS)
688    if len(args) == 0:
689        print(USAGE)
690        print("\nAvailable models:")
691        print(columnize(MODELS, indent="  "))
692        sys.exit(1)
693    if len(args) > 3:
694        print("expected parameters: model N1 N2")
695
696    name = args[0]
697    try:
698        model_info = core.load_model_info(name)
699    except ImportError, exc:
700        print(str(exc))
701        print("Could not find model; use one of:\n    " + models)
702        sys.exit(1)
703
704    invalid = [o[1:] for o in flags
705               if o[1:] not in NAME_OPTIONS
706               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
707    if invalid:
708        print("Invalid options: %s"%(", ".join(invalid)))
709        sys.exit(1)
710
711
712    # pylint: disable=bad-whitespace
713    # Interpret the flags
714    opts = {
715        'plot'      : True,
716        'view'      : 'log',
717        'is2d'      : False,
718        'qmax'      : 0.05,
719        'nq'        : 128,
720        'res'       : 0.0,
721        'accuracy'  : 'Low',
722        'cutoff'    : 0.0,
723        'seed'      : -1,  # default to preset
724        'mono'      : False,
725        'show_pars' : False,
726        'show_hist' : False,
727        'rel_err'   : True,
728        'explore'   : False,
729    }
730    engines = []
731    for arg in flags:
732        if arg == '-noplot':    opts['plot'] = False
733        elif arg == '-plot':    opts['plot'] = True
734        elif arg == '-linear':  opts['view'] = 'linear'
735        elif arg == '-log':     opts['view'] = 'log'
736        elif arg == '-q4':      opts['view'] = 'q4'
737        elif arg == '-1d':      opts['is2d'] = False
738        elif arg == '-2d':      opts['is2d'] = True
739        elif arg == '-exq':     opts['qmax'] = 10.0
740        elif arg == '-highq':   opts['qmax'] = 1.0
741        elif arg == '-midq':    opts['qmax'] = 0.2
742        elif arg == '-lowq':    opts['qmax'] = 0.05
743        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
744        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
745        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
746        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
747        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
748        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
749        elif arg == '-preset':  opts['seed'] = -1
750        elif arg == '-mono':    opts['mono'] = True
751        elif arg == '-poly':    opts['mono'] = False
752        elif arg == '-pars':    opts['show_pars'] = True
753        elif arg == '-nopars':  opts['show_pars'] = False
754        elif arg == '-hist':    opts['show_hist'] = True
755        elif arg == '-nohist':  opts['show_hist'] = False
756        elif arg == '-rel':     opts['rel_err'] = True
757        elif arg == '-abs':     opts['rel_err'] = False
758        elif arg == '-half':    engines.append(arg[1:])
759        elif arg == '-fast':    engines.append(arg[1:])
760        elif arg == '-single':  engines.append(arg[1:])
761        elif arg == '-double':  engines.append(arg[1:])
762        elif arg == '-single!': engines.append(arg[1:])
763        elif arg == '-double!': engines.append(arg[1:])
764        elif arg == '-quad!':   engines.append(arg[1:])
765        elif arg == '-sasview': engines.append(arg[1:])
766        elif arg == '-edit':    opts['explore'] = True
767    # pylint: enable=bad-whitespace
768
769    if len(engines) == 0:
770        engines.extend(['single', 'sasview'])
771    elif len(engines) == 1:
772        if engines[0][0] != 'sasview':
773            engines.append('sasview')
774        else:
775            engines.append('single')
776    elif len(engines) > 2:
777        del engines[2:]
778
779    n1 = int(args[1]) if len(args) > 1 else 1
780    n2 = int(args[2]) if len(args) > 2 else 1
781
782    # Get demo parameters from model definition, or use default parameters
783    # if model does not define demo parameters
784    pars = get_demo_pars(model_info)
785
786    # Fill in parameters given on the command line
787    presets = {}
788    for arg in values:
789        k, v = arg.split('=', 1)
790        if k not in pars:
791            # extract base name without polydispersity info
792            s = set(p.split('_pd')[0] for p in pars)
793            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
794            sys.exit(1)
795        presets[k] = float(v) if not k.endswith('type') else v
796
797    # randomize parameters
798    #pars.update(set_pars)  # set value before random to control range
799    if opts['seed'] > -1:
800        pars = randomize_pars(pars, seed=opts['seed'])
801        print("Randomize using -random=%i"%opts['seed'])
802    if opts['mono']:
803        pars = suppress_pd(pars)
804    pars.update(presets)  # set value after random to control value
805    #import pprint; pprint.pprint(model_info)
806    constrain_pars(model_info, pars)
807    constrain_new_to_old(model_info, pars)
808    if opts['show_pars']:
809        print(str(parlist(model_info, pars, opts['is2d'])))
810
811    # Create the computational engines
812    data, _ = make_data(opts)
813    if n1:
814        base = make_engine(model_info, data, engines[0], opts['cutoff'])
815    else:
816        base = None
817    if n2:
818        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
819    else:
820        comp = None
821
822    # pylint: disable=bad-whitespace
823    # Remember it all
824    opts.update({
825        'name'      : name,
826        'def'       : model_info,
827        'n1'        : n1,
828        'n2'        : n2,
829        'presets'   : presets,
830        'pars'      : pars,
831        'data'      : data,
832        'engines'   : [base, comp],
833    })
834    # pylint: enable=bad-whitespace
835
836    return opts
837
838def explore(opts):
839    """
840    Explore the model using the Bumps GUI.
841    """
842    import wx
843    from bumps.names import FitProblem
844    from bumps.gui.app_frame import AppFrame
845
846    problem = FitProblem(Explore(opts))
847    is_mac = "cocoa" in wx.version()
848    app = wx.App()
849    frame = AppFrame(parent=None, title="explore")
850    if not is_mac: frame.Show()
851    frame.panel.set_model(model=problem)
852    frame.panel.Layout()
853    frame.panel.aui.Split(0, wx.TOP)
854    if is_mac: frame.Show()
855    app.MainLoop()
856
857class Explore(object):
858    """
859    Bumps wrapper for a SAS model comparison.
860
861    The resulting object can be used as a Bumps fit problem so that
862    parameters can be adjusted in the GUI, with plots updated on the fly.
863    """
864    def __init__(self, opts):
865        from bumps.cli import config_matplotlib
866        from . import bumps_model
867        config_matplotlib()
868        self.opts = opts
869        model_info = opts['def']
870        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
871        if not opts['is2d']:
872            for p in model_info['parameters'].type['1d']:
873                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
874                    k = p.name+ext
875                    v = pars.get(k, None)
876                    if v is not None:
877                        v.range(*parameter_range(k, v.value))
878        else:
879            for k, v in pars.items():
880                v.range(*parameter_range(k, v.value))
881
882        self.pars = pars
883        self.pd_types = pd_types
884        self.limits = None
885
886    def numpoints(self):
887        """
888        Return the number of points.
889        """
890        return len(self.pars) + 1  # so dof is 1
891
892    def parameters(self):
893        """
894        Return a dictionary of parameters.
895        """
896        return self.pars
897
898    def nllf(self):
899        """
900        Return cost.
901        """
902        # pylint: disable=no-self-use
903        return 0.  # No nllf
904
905    def plot(self, view='log'):
906        """
907        Plot the data and residuals.
908        """
909        pars = dict((k, v.value) for k, v in self.pars.items())
910        pars.update(self.pd_types)
911        self.opts['pars'] = pars
912        limits = compare(self.opts, limits=self.limits)
913        if self.limits is None:
914            vmin, vmax = limits
915            vmax = 1.3*vmax
916            vmin = vmax*1e-7
917            self.limits = vmin, vmax
918
919
920def main():
921    """
922    Main program.
923    """
924    opts = parse_opts()
925    if opts['explore']:
926        explore(opts)
927    else:
928        compare(opts)
929
930if __name__ == "__main__":
931    main()
Note: See TracBrowser for help on using the repository browser.