source: sasmodels/sasmodels/compare.py @ d19962c

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since d19962c was d19962c, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

working vector parameter example using dll engine

  • Property mode set to 100755
File size: 30.4 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import product
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_pars, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71    -default/-demo* use demo vs default parameters
72
73Any two calculation engines can be selected for comparison:
74
75    -single/-double/-half/-fast sets an OpenCL calculation engine
76    -single!/-double!/-quad! sets an OpenMP calculation engine
77    -sasview sets the sasview calculation engine
78
79The default is -single -sasview.  Note that the interpretation of quad
80precision depends on architecture, and may vary from 64-bit to 128-bit,
81with 80-bit floats being common (1e-19 precision).
82
83Key=value pairs allow you to set specific values for the model parameters.
84"""
85
86# Update docs with command line usage string.   This is separate from the usual
87# doc string so that we can display it at run time if there is an error.
88# lin
89__doc__ = (__doc__  # pylint: disable=redefined-builtin
90           + """
91Program description
92-------------------
93
94"""
95           + USAGE)
96
97kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
98
99MODELS = core.list_models()
100
101# CRUFT python 2.6
102if not hasattr(datetime.timedelta, 'total_seconds'):
103    def delay(dt):
104        """Return number date-time delta as number seconds"""
105        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
106else:
107    def delay(dt):
108        """Return number date-time delta as number seconds"""
109        return dt.total_seconds()
110
111
112class push_seed(object):
113    """
114    Set the seed value for the random number generator.
115
116    When used in a with statement, the random number generator state is
117    restored after the with statement is complete.
118
119    :Parameters:
120
121    *seed* : int or array_like, optional
122        Seed for RandomState
123
124    :Example:
125
126    Seed can be used directly to set the seed::
127
128        >>> from numpy.random import randint
129        >>> push_seed(24)
130        <...push_seed object at...>
131        >>> print(randint(0,1000000,3))
132        [242082    899 211136]
133
134    Seed can also be used in a with statement, which sets the random
135    number generator state for the enclosed computations and restores
136    it to the previous state on completion::
137
138        >>> with push_seed(24):
139        ...    print(randint(0,1000000,3))
140        [242082    899 211136]
141
142    Using nested contexts, we can demonstrate that state is indeed
143    restored after the block completes::
144
145        >>> with push_seed(24):
146        ...    print(randint(0,1000000))
147        ...    with push_seed(24):
148        ...        print(randint(0,1000000,3))
149        ...    print(randint(0,1000000))
150        242082
151        [242082    899 211136]
152        899
153
154    The restore step is protected against exceptions in the block::
155
156        >>> with push_seed(24):
157        ...    print(randint(0,1000000))
158        ...    try:
159        ...        with push_seed(24):
160        ...            print(randint(0,1000000,3))
161        ...            raise Exception()
162        ...    except:
163        ...        print("Exception raised")
164        ...    print(randint(0,1000000))
165        242082
166        [242082    899 211136]
167        Exception raised
168        899
169    """
170    def __init__(self, seed=None):
171        self._state = np.random.get_state()
172        np.random.seed(seed)
173
174    def __enter__(self):
175        return None
176
177    def __exit__(self, *args):
178        np.random.set_state(self._state)
179
180def tic():
181    """
182    Timer function.
183
184    Use "toc=tic()" to start the clock and "toc()" to measure
185    a time interval.
186    """
187    then = datetime.datetime.now()
188    return lambda: delay(datetime.datetime.now() - then)
189
190
191def set_beam_stop(data, radius, outer=None):
192    """
193    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
194
195    Note: this function does not use the sasview package
196    """
197    if hasattr(data, 'qx_data'):
198        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
199        data.mask = (q < radius)
200        if outer is not None:
201            data.mask |= (q >= outer)
202    else:
203        data.mask = (data.x < radius)
204        if outer is not None:
205            data.mask |= (data.x >= outer)
206
207
208def parameter_range(p, v):
209    """
210    Choose a parameter range based on parameter name and initial value.
211    """
212    if p.endswith('_pd_n'):
213        return [0, 100]
214    elif p.endswith('_pd_nsigma'):
215        return [0, 5]
216    elif p.endswith('_pd_type'):
217        return v
218    elif any(s in p for s in ('theta', 'phi', 'psi')):
219        # orientation in [-180,180], orientation pd in [0,45]
220        if p.endswith('_pd'):
221            return [0, 45]
222        else:
223            return [-180, 180]
224    elif 'sld' in p:
225        return [-0.5, 10]
226    elif p.endswith('_pd'):
227        return [0, 1]
228    elif p == 'background':
229        return [0, 10]
230    elif p == 'scale':
231        return [0, 1e3]
232    elif p == 'case_num':
233        # RPA hack
234        return [0, 10]
235    elif v < 0:
236        # Kxy parameters in rpa model can be negative
237        return [2*v, -2*v]
238    else:
239        return [0, (2*v if v > 0 else 1)]
240
241
242def _randomize_one(p, v):
243    """
244    Randomize a single parameter.
245    """
246    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
247        return v
248    else:
249        return np.random.uniform(*parameter_range(p, v))
250
251
252def randomize_pars(pars, seed=None):
253    """
254    Generate random values for all of the parameters.
255
256    Valid ranges for the random number generator are guessed from the name of
257    the parameter; this will not account for constraints such as cap radius
258    greater than cylinder radius in the capped_cylinder model, so
259    :func:`constrain_pars` needs to be called afterward..
260    """
261    with push_seed(seed):
262        # Note: the sort guarantees order `of calls to random number generator
263        pars = dict((p, _randomize_one(p, v))
264                    for p, v in sorted(pars.items()))
265    return pars
266
267def constrain_pars(model_info, pars):
268    """
269    Restrict parameters to valid values.
270
271    This includes model specific code for models such as capped_cylinder
272    which need to support within model constraints (cap radius more than
273    cylinder radius in this case).
274    """
275    name = model_info['id']
276    # if it is a product model, then just look at the form factor since
277    # none of the structure factors need any constraints.
278    if '*' in name:
279        name = name.split('*')[0]
280
281    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
282        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
283    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
284        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
285
286    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
287    if name == 'guinier':
288        #q_max = 0.2  # mid q maximum
289        q_max = 1.0  # high q maximum
290        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
291        pars['rg'] = min(pars['rg'], rg_max)
292
293    if name == 'rpa':
294        # Make sure phi sums to 1.0
295        if pars['case_num'] < 2:
296            pars['Phia'] = 0.
297            pars['Phib'] = 0.
298        elif pars['case_num'] < 5:
299            pars['Phia'] = 0.
300        total = sum(pars['Phi'+c] for c in 'abcd')
301        for c in 'abcd':
302            pars['Phi'+c] /= total
303
304def parlist(model_info, pars, is2d):
305    """
306    Format the parameter list for printing.
307    """
308    lines = []
309    parameters = model_info['parameters']
310    for p in parameters.user_parameters(pars, is2d):
311        fields = dict(
312            value=pars.get(p.id, p.default),
313            pd=pars.get(p.id+"_pd", 0.),
314            n=int(pars.get(p.id+"_pd_n", 0)),
315            nsigma=pars.get(p.id+"_pd_nsgima", 3.),
316            type=pars.get(p.id+"_pd_type", 'gaussian'))
317        lines.append(_format_par(p.name, **fields))
318    return "\n".join(lines)
319
320    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
321
322def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
323    line = "%s: %g"%(name, value)
324    if pd != 0.  and n != 0:
325        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
326                % (pd, n, nsigma, nsigma, type)
327    return line
328
329def suppress_pd(pars):
330    """
331    Suppress theta_pd for now until the normalization is resolved.
332
333    May also suppress complete polydispersity of the model to test
334    models more quickly.
335    """
336    pars = pars.copy()
337    for p in pars:
338        if p.endswith("_pd_n"): pars[p] = 0
339    return pars
340
341def eval_sasview(model_info, data):
342    """
343    Return a model calculator using the SasView fitting engine.
344    """
345    # importing sas here so that the error message will be that sas failed to
346    # import rather than the more obscure smear_selection not imported error
347    import sas
348    from sas.models.qsmearing import smear_selection
349
350    def get_model(name):
351        #print("new",sorted(_pars.items()))
352        sas = __import__('sas.models.' + name)
353        ModelClass = getattr(getattr(sas.models, name, None), name, None)
354        if ModelClass is None:
355            raise ValueError("could not find model %r in sas.models"%name)
356        return ModelClass()
357
358    # grab the sasview model, or create it if it is a product model
359    if model_info['composition']:
360        composition_type, parts = model_info['composition']
361        if composition_type == 'product':
362            from sas.sascalc.fit.MultiplicationModel import MultiplicationModel
363            P, S = [get_model(p) for p in model_info['oldname']]
364            model = MultiplicationModel(P, S)
365        else:
366            raise ValueError("sasview mixture models not supported by compare")
367    else:
368        model = get_model(model_info['oldname'])
369
370    # build a smearer with which to call the model, if necessary
371    smearer = smear_selection(data, model=model)
372    if hasattr(data, 'qx_data'):
373        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
374        index = ((~data.mask) & (~np.isnan(data.data))
375                 & (q >= data.qmin) & (q <= data.qmax))
376        if smearer is not None:
377            smearer.model = model  # because smear_selection has a bug
378            smearer.accuracy = data.accuracy
379            smearer.set_index(index)
380            theory = lambda: smearer.get_value()
381        else:
382            theory = lambda: model.evalDistribution([data.qx_data[index],
383                                                     data.qy_data[index]])
384    elif smearer is not None:
385        theory = lambda: smearer(model.evalDistribution(data.x))
386    else:
387        theory = lambda: model.evalDistribution(data.x)
388
389    def calculator(**pars):
390        """
391        Sasview calculator for model.
392        """
393        # paying for parameter conversion each time to keep life simple, if not fast
394        pars = revert_pars(model_info, pars)
395        for k, v in pars.items():
396            parts = k.split('.')  # polydispersity components
397            if len(parts) == 2:
398                model.dispersion[parts[0]][parts[1]] = v
399            else:
400                model.setParam(k, v)
401        return theory()
402
403    calculator.engine = "sasview"
404    return calculator
405
406DTYPE_MAP = {
407    'half': '16',
408    'fast': 'fast',
409    'single': '32',
410    'double': '64',
411    'quad': '128',
412    'f16': '16',
413    'f32': '32',
414    'f64': '64',
415    'longdouble': '128',
416}
417def eval_opencl(model_info, data, dtype='single', cutoff=0.):
418    """
419    Return a model calculator using the OpenCL calculation engine.
420    """
421    try:
422        model = core.build_model(model_info, dtype=dtype, platform="ocl")
423    except Exception as exc:
424        print(exc)
425        print("... trying again with single precision")
426        model = core.build_model(model_info, dtype='single', platform="ocl")
427    calculator = DirectModel(data, model, cutoff=cutoff)
428    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
429    return calculator
430
431def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
432    """
433    Return a model calculator using the DLL calculation engine.
434    """
435    if dtype == 'quad':
436        dtype = 'longdouble'
437    model = core.build_model(model_info, dtype=dtype, platform="dll")
438    calculator = DirectModel(data, model, cutoff=cutoff)
439    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
440    return calculator
441
442def time_calculation(calculator, pars, Nevals=1):
443    """
444    Compute the average calculation time over N evaluations.
445
446    An additional call is generated without polydispersity in order to
447    initialize the calculation engine, and make the average more stable.
448    """
449    # initialize the code so time is more accurate
450    if Nevals > 1:
451        value = calculator(**suppress_pd(pars))
452    toc = tic()
453    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
454        value = calculator(**pars)
455    average_time = toc()*1000./Nevals
456    return value, average_time
457
458def make_data(opts):
459    """
460    Generate an empty dataset, used with the model to set Q points
461    and resolution.
462
463    *opts* contains the options, with 'qmax', 'nq', 'res',
464    'accuracy', 'is2d' and 'view' parsed from the command line.
465    """
466    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
467    if opts['is2d']:
468        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
469        data.accuracy = opts['accuracy']
470        set_beam_stop(data, 0.004)
471        index = ~data.mask
472    else:
473        if opts['view'] == 'log':
474            qmax = math.log10(qmax)
475            q = np.logspace(qmax-3, qmax, nq)
476        else:
477            q = np.linspace(0.001*qmax, qmax, nq)
478        data = empty_data1D(q, resolution=res)
479        index = slice(None, None)
480    return data, index
481
482def make_engine(model_info, data, dtype, cutoff):
483    """
484    Generate the appropriate calculation engine for the given datatype.
485
486    Datatypes with '!' appended are evaluated using external C DLLs rather
487    than OpenCL.
488    """
489    if dtype == 'sasview':
490        return eval_sasview(model_info, data)
491    elif dtype.endswith('!'):
492        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
493    else:
494        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
495
496def compare(opts, limits=None):
497    """
498    Preform a comparison using options from the command line.
499
500    *limits* are the limits on the values to use, either to set the y-axis
501    for 1D or to set the colormap scale for 2D.  If None, then they are
502    inferred from the data and returned. When exploring using Bumps,
503    the limits are set when the model is initially called, and maintained
504    as the values are adjusted, making it easier to see the effects of the
505    parameters.
506    """
507    Nbase, Ncomp = opts['n1'], opts['n2']
508    pars = opts['pars']
509    data = opts['data']
510
511    # Base calculation
512    if Nbase > 0:
513        base = opts['engines'][0]
514        try:
515            base_value, base_time = time_calculation(base, pars, Nbase)
516            print("%s t=%.2f ms, intensity=%.0f"
517                  % (base.engine, base_time, sum(base_value)))
518        except ImportError:
519            traceback.print_exc()
520            Nbase = 0
521
522    # Comparison calculation
523    if Ncomp > 0:
524        comp = opts['engines'][1]
525        try:
526            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
527            print("%s t=%.2f ms, intensity=%.0f"
528                  % (comp.engine, comp_time, sum(comp_value)))
529        except ImportError:
530            traceback.print_exc()
531            Ncomp = 0
532
533    # Compare, but only if computing both forms
534    if Nbase > 0 and Ncomp > 0:
535        resid = (base_value - comp_value)
536        relerr = resid/comp_value
537        _print_stats("|%s-%s|"
538                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
539                     resid)
540        _print_stats("|(%s-%s)/%s|"
541                     % (base.engine, comp.engine, comp.engine),
542                     relerr)
543
544    # Plot if requested
545    if not opts['plot'] and not opts['explore']: return
546    view = opts['view']
547    import matplotlib.pyplot as plt
548    if limits is None:
549        vmin, vmax = np.Inf, -np.Inf
550        if Nbase > 0:
551            vmin = min(vmin, min(base_value))
552            vmax = max(vmax, max(base_value))
553        if Ncomp > 0:
554            vmin = min(vmin, min(comp_value))
555            vmax = max(vmax, max(comp_value))
556        limits = vmin, vmax
557
558    if Nbase > 0:
559        if Ncomp > 0: plt.subplot(131)
560        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
561        plt.title("%s t=%.2f ms"%(base.engine, base_time))
562        #cbar_title = "log I"
563    if Ncomp > 0:
564        if Nbase > 0: plt.subplot(132)
565        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
566        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
567        #cbar_title = "log I"
568    if Ncomp > 0 and Nbase > 0:
569        plt.subplot(133)
570        if not opts['rel_err']:
571            err, errstr, errview = resid, "abs err", "linear"
572        else:
573            err, errstr, errview = abs(relerr), "rel err", "log"
574        #err,errstr = base/comp,"ratio"
575        plot_theory(data, None, resid=err, view=errview, use_data=False)
576        if view == 'linear':
577            plt.xscale('linear')
578        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
579        #cbar_title = errstr if errview=="linear" else "log "+errstr
580    #if is2D:
581    #    h = plt.colorbar()
582    #    h.ax.set_title(cbar_title)
583
584    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
585        plt.figure()
586        v = relerr
587        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
588        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
589        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
590                   % (base.engine, comp.engine, comp.engine))
591        plt.ylabel('P(err)')
592        plt.title('Distribution of relative error between calculation engines')
593
594    if not opts['explore']:
595        plt.show()
596
597    return limits
598
599def _print_stats(label, err):
600    sorted_err = np.sort(abs(err))
601    p50 = int((len(err)-1)*0.50)
602    p98 = int((len(err)-1)*0.98)
603    data = [
604        "max:%.3e"%sorted_err[-1],
605        "median:%.3e"%sorted_err[p50],
606        "98%%:%.3e"%sorted_err[p98],
607        "rms:%.3e"%np.sqrt(np.mean(err**2)),
608        "zero-offset:%+.3e"%np.mean(err),
609        ]
610    print(label+"  "+"  ".join(data))
611
612
613
614# ===========================================================================
615#
616NAME_OPTIONS = set([
617    'plot', 'noplot',
618    'half', 'fast', 'single', 'double',
619    'single!', 'double!', 'quad!', 'sasview',
620    'lowq', 'midq', 'highq', 'exq',
621    '2d', '1d',
622    'preset', 'random',
623    'poly', 'mono',
624    'nopars', 'pars',
625    'rel', 'abs',
626    'linear', 'log', 'q4',
627    'hist', 'nohist',
628    'edit',
629    'demo', 'default',
630    ])
631VALUE_OPTIONS = [
632    # Note: random is both a name option and a value option
633    'cutoff', 'random', 'nq', 'res', 'accuracy',
634    ]
635
636def columnize(L, indent="", width=79):
637    """
638    Format a list of strings into columns.
639
640    Returns a string with carriage returns ready for printing.
641    """
642    column_width = max(len(w) for w in L) + 1
643    num_columns = (width - len(indent)) // column_width
644    num_rows = len(L) // num_columns
645    L = L + [""] * (num_rows*num_columns - len(L))
646    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
647    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
648             for row in zip(*columns)]
649    output = indent + ("\n"+indent).join(lines)
650    return output
651
652
653def get_pars(model_info, use_demo=False):
654    """
655    Extract demo parameters from the model definition.
656    """
657    # Get the default values for the parameters
658    pars = {}
659    for p in model_info['parameters'].call_parameters:
660        parts = [('', p.default)]
661        if p.polydisperse:
662            parts.append(('_pd', 0.0))
663            parts.append(('_pd_n', 0))
664            parts.append(('_pd_nsigma', 3.0))
665            parts.append(('_pd_type', "gaussian"))
666        for ext, val in parts:
667            if p.length > 1:
668                dict(("%s%d%s"%(p.id,k,ext), val) for k in range(p.length))
669            else:
670                pars[p.id+ext] = val
671
672    # Plug in values given in demo
673    if use_demo:
674        pars.update(model_info['demo'])
675    return pars
676
677
678def parse_opts():
679    """
680    Parse command line options.
681    """
682    MODELS = core.list_models()
683    flags = [arg for arg in sys.argv[1:]
684             if arg.startswith('-')]
685    values = [arg for arg in sys.argv[1:]
686              if not arg.startswith('-') and '=' in arg]
687    args = [arg for arg in sys.argv[1:]
688            if not arg.startswith('-') and '=' not in arg]
689    models = "\n    ".join("%-15s"%v for v in MODELS)
690    if len(args) == 0:
691        print(USAGE)
692        print("\nAvailable models:")
693        print(columnize(MODELS, indent="  "))
694        sys.exit(1)
695    if len(args) > 3:
696        print("expected parameters: model N1 N2")
697
698    name = args[0]
699    try:
700        model_info = core.load_model_info(name)
701    except ImportError, exc:
702        print(str(exc))
703        print("Could not find model; use one of:\n    " + models)
704        sys.exit(1)
705
706    invalid = [o[1:] for o in flags
707               if o[1:] not in NAME_OPTIONS
708               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
709    if invalid:
710        print("Invalid options: %s"%(", ".join(invalid)))
711        sys.exit(1)
712
713
714    # pylint: disable=bad-whitespace
715    # Interpret the flags
716    opts = {
717        'plot'      : True,
718        'view'      : 'log',
719        'is2d'      : False,
720        'qmax'      : 0.05,
721        'nq'        : 128,
722        'res'       : 0.0,
723        'accuracy'  : 'Low',
724        'cutoff'    : 0.0,
725        'seed'      : -1,  # default to preset
726        'mono'      : False,
727        'show_pars' : False,
728        'show_hist' : False,
729        'rel_err'   : True,
730        'explore'   : False,
731        'use_demo'  : True,
732    }
733    engines = []
734    for arg in flags:
735        if arg == '-noplot':    opts['plot'] = False
736        elif arg == '-plot':    opts['plot'] = True
737        elif arg == '-linear':  opts['view'] = 'linear'
738        elif arg == '-log':     opts['view'] = 'log'
739        elif arg == '-q4':      opts['view'] = 'q4'
740        elif arg == '-1d':      opts['is2d'] = False
741        elif arg == '-2d':      opts['is2d'] = True
742        elif arg == '-exq':     opts['qmax'] = 10.0
743        elif arg == '-highq':   opts['qmax'] = 1.0
744        elif arg == '-midq':    opts['qmax'] = 0.2
745        elif arg == '-lowq':    opts['qmax'] = 0.05
746        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
747        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
748        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
749        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
750        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
751        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
752        elif arg == '-preset':  opts['seed'] = -1
753        elif arg == '-mono':    opts['mono'] = True
754        elif arg == '-poly':    opts['mono'] = False
755        elif arg == '-pars':    opts['show_pars'] = True
756        elif arg == '-nopars':  opts['show_pars'] = False
757        elif arg == '-hist':    opts['show_hist'] = True
758        elif arg == '-nohist':  opts['show_hist'] = False
759        elif arg == '-rel':     opts['rel_err'] = True
760        elif arg == '-abs':     opts['rel_err'] = False
761        elif arg == '-half':    engines.append(arg[1:])
762        elif arg == '-fast':    engines.append(arg[1:])
763        elif arg == '-single':  engines.append(arg[1:])
764        elif arg == '-double':  engines.append(arg[1:])
765        elif arg == '-single!': engines.append(arg[1:])
766        elif arg == '-double!': engines.append(arg[1:])
767        elif arg == '-quad!':   engines.append(arg[1:])
768        elif arg == '-sasview': engines.append(arg[1:])
769        elif arg == '-edit':    opts['explore'] = True
770        elif arg == '-demo':    opts['use_demo'] = True
771        elif arg == '-default':    opts['use_demo'] = False
772    # pylint: enable=bad-whitespace
773
774    if len(engines) == 0:
775        engines.extend(['single', 'double'])
776    elif len(engines) == 1:
777        if engines[0][0] != 'double':
778            engines.append('double')
779        else:
780            engines.append('single')
781    elif len(engines) > 2:
782        del engines[2:]
783
784    n1 = int(args[1]) if len(args) > 1 else 1
785    n2 = int(args[2]) if len(args) > 2 else 1
786
787    # Get demo parameters from model definition, or use default parameters
788    # if model does not define demo parameters
789    pars = get_pars(model_info, opts['use_demo'])
790
791
792    # Fill in parameters given on the command line
793    presets = {}
794    for arg in values:
795        k, v = arg.split('=', 1)
796        if k not in pars:
797            # extract base name without polydispersity info
798            s = set(p.split('_pd')[0] for p in pars)
799            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
800            sys.exit(1)
801        presets[k] = float(v) if not k.endswith('type') else v
802
803    # randomize parameters
804    #pars.update(set_pars)  # set value before random to control range
805    if opts['seed'] > -1:
806        pars = randomize_pars(pars, seed=opts['seed'])
807        print("Randomize using -random=%i"%opts['seed'])
808    if opts['mono']:
809        pars = suppress_pd(pars)
810    pars.update(presets)  # set value after random to control value
811    #import pprint; pprint.pprint(model_info)
812    constrain_pars(model_info, pars)
813    constrain_new_to_old(model_info, pars)
814    if opts['show_pars']:
815        print(str(parlist(model_info, pars, opts['is2d'])))
816
817    # Create the computational engines
818    data, _ = make_data(opts)
819    if n1:
820        base = make_engine(model_info, data, engines[0], opts['cutoff'])
821    else:
822        base = None
823    if n2:
824        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
825    else:
826        comp = None
827
828    # pylint: disable=bad-whitespace
829    # Remember it all
830    opts.update({
831        'name'      : name,
832        'def'       : model_info,
833        'n1'        : n1,
834        'n2'        : n2,
835        'presets'   : presets,
836        'pars'      : pars,
837        'data'      : data,
838        'engines'   : [base, comp],
839    })
840    # pylint: enable=bad-whitespace
841
842    return opts
843
844def explore(opts):
845    """
846    Explore the model using the Bumps GUI.
847    """
848    import wx
849    from bumps.names import FitProblem
850    from bumps.gui.app_frame import AppFrame
851
852    problem = FitProblem(Explore(opts))
853    is_mac = "cocoa" in wx.version()
854    app = wx.App()
855    frame = AppFrame(parent=None, title="explore")
856    if not is_mac: frame.Show()
857    frame.panel.set_model(model=problem)
858    frame.panel.Layout()
859    frame.panel.aui.Split(0, wx.TOP)
860    if is_mac: frame.Show()
861    app.MainLoop()
862
863class Explore(object):
864    """
865    Bumps wrapper for a SAS model comparison.
866
867    The resulting object can be used as a Bumps fit problem so that
868    parameters can be adjusted in the GUI, with plots updated on the fly.
869    """
870    def __init__(self, opts):
871        from bumps.cli import config_matplotlib
872        from . import bumps_model
873        config_matplotlib()
874        self.opts = opts
875        model_info = opts['def']
876        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
877        if not opts['is2d']:
878            for p in model_info['parameters'].type['1d']:
879                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
880                    k = p.name+ext
881                    v = pars.get(k, None)
882                    if v is not None:
883                        v.range(*parameter_range(k, v.value))
884        else:
885            for k, v in pars.items():
886                v.range(*parameter_range(k, v.value))
887
888        self.pars = pars
889        self.pd_types = pd_types
890        self.limits = None
891
892    def numpoints(self):
893        """
894        Return the number of points.
895        """
896        return len(self.pars) + 1  # so dof is 1
897
898    def parameters(self):
899        """
900        Return a dictionary of parameters.
901        """
902        return self.pars
903
904    def nllf(self):
905        """
906        Return cost.
907        """
908        # pylint: disable=no-self-use
909        return 0.  # No nllf
910
911    def plot(self, view='log'):
912        """
913        Plot the data and residuals.
914        """
915        pars = dict((k, v.value) for k, v in self.pars.items())
916        pars.update(self.pd_types)
917        self.opts['pars'] = pars
918        limits = compare(self.opts, limits=self.limits)
919        if self.limits is None:
920            vmin, vmax = limits
921            vmax = 1.3*vmax
922            vmin = vmax*1e-7
923            self.limits = vmin, vmax
924
925
926def main():
927    """
928    Main program.
929    """
930    opts = parse_opts()
931    if opts['explore']:
932        explore(opts)
933    else:
934        compare(opts)
935
936if __name__ == "__main__":
937    main()
Note: See TracBrowser for help on using the repository browser.