source: sasmodels/sasmodels/compare.py @ c499331

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since c499331 was c499331, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

progress on having compare.py recognize vector parameters

  • Property mode set to 100755
File size: 31.0 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import product
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_pars, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71    -default/-demo* use demo vs default parameters
72
73Any two calculation engines can be selected for comparison:
74
75    -single/-double/-half/-fast sets an OpenCL calculation engine
76    -single!/-double!/-quad! sets an OpenMP calculation engine
77    -sasview sets the sasview calculation engine
78
79The default is -single -sasview.  Note that the interpretation of quad
80precision depends on architecture, and may vary from 64-bit to 128-bit,
81with 80-bit floats being common (1e-19 precision).
82
83Key=value pairs allow you to set specific values for the model parameters.
84"""
85
86# Update docs with command line usage string.   This is separate from the usual
87# doc string so that we can display it at run time if there is an error.
88# lin
89__doc__ = (__doc__  # pylint: disable=redefined-builtin
90           + """
91Program description
92-------------------
93
94"""
95           + USAGE)
96
97kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
98
99MODELS = core.list_models()
100
101# CRUFT python 2.6
102if not hasattr(datetime.timedelta, 'total_seconds'):
103    def delay(dt):
104        """Return number date-time delta as number seconds"""
105        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
106else:
107    def delay(dt):
108        """Return number date-time delta as number seconds"""
109        return dt.total_seconds()
110
111
112class push_seed(object):
113    """
114    Set the seed value for the random number generator.
115
116    When used in a with statement, the random number generator state is
117    restored after the with statement is complete.
118
119    :Parameters:
120
121    *seed* : int or array_like, optional
122        Seed for RandomState
123
124    :Example:
125
126    Seed can be used directly to set the seed::
127
128        >>> from numpy.random import randint
129        >>> push_seed(24)
130        <...push_seed object at...>
131        >>> print(randint(0,1000000,3))
132        [242082    899 211136]
133
134    Seed can also be used in a with statement, which sets the random
135    number generator state for the enclosed computations and restores
136    it to the previous state on completion::
137
138        >>> with push_seed(24):
139        ...    print(randint(0,1000000,3))
140        [242082    899 211136]
141
142    Using nested contexts, we can demonstrate that state is indeed
143    restored after the block completes::
144
145        >>> with push_seed(24):
146        ...    print(randint(0,1000000))
147        ...    with push_seed(24):
148        ...        print(randint(0,1000000,3))
149        ...    print(randint(0,1000000))
150        242082
151        [242082    899 211136]
152        899
153
154    The restore step is protected against exceptions in the block::
155
156        >>> with push_seed(24):
157        ...    print(randint(0,1000000))
158        ...    try:
159        ...        with push_seed(24):
160        ...            print(randint(0,1000000,3))
161        ...            raise Exception()
162        ...    except:
163        ...        print("Exception raised")
164        ...    print(randint(0,1000000))
165        242082
166        [242082    899 211136]
167        Exception raised
168        899
169    """
170    def __init__(self, seed=None):
171        self._state = np.random.get_state()
172        np.random.seed(seed)
173
174    def __enter__(self):
175        return None
176
177    def __exit__(self, *args):
178        np.random.set_state(self._state)
179
180def tic():
181    """
182    Timer function.
183
184    Use "toc=tic()" to start the clock and "toc()" to measure
185    a time interval.
186    """
187    then = datetime.datetime.now()
188    return lambda: delay(datetime.datetime.now() - then)
189
190
191def set_beam_stop(data, radius, outer=None):
192    """
193    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
194
195    Note: this function does not use the sasview package
196    """
197    if hasattr(data, 'qx_data'):
198        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
199        data.mask = (q < radius)
200        if outer is not None:
201            data.mask |= (q >= outer)
202    else:
203        data.mask = (data.x < radius)
204        if outer is not None:
205            data.mask |= (data.x >= outer)
206
207
208def parameter_range(p, v):
209    """
210    Choose a parameter range based on parameter name and initial value.
211    """
212    if p.endswith('_pd_n'):
213        return [0, 100]
214    elif p.endswith('_pd_nsigma'):
215        return [0, 5]
216    elif p.endswith('_pd_type'):
217        return v
218    elif any(s in p for s in ('theta', 'phi', 'psi')):
219        # orientation in [-180,180], orientation pd in [0,45]
220        if p.endswith('_pd'):
221            return [0, 45]
222        else:
223            return [-180, 180]
224    elif 'sld' in p:
225        return [-0.5, 10]
226    elif p.endswith('_pd'):
227        return [0, 1]
228    elif p == 'background':
229        return [0, 10]
230    elif p == 'scale':
231        return [0, 1e3]
232    elif p == 'case_num':
233        # RPA hack
234        return [0, 10]
235    elif v < 0:
236        # Kxy parameters in rpa model can be negative
237        return [2*v, -2*v]
238    else:
239        return [0, (2*v if v > 0 else 1)]
240
241
242def _randomize_one(p, v):
243    """
244    Randomize a single parameter.
245    """
246    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
247        return v
248    else:
249        return np.random.uniform(*parameter_range(p, v))
250
251
252def randomize_pars(pars, seed=None):
253    """
254    Generate random values for all of the parameters.
255
256    Valid ranges for the random number generator are guessed from the name of
257    the parameter; this will not account for constraints such as cap radius
258    greater than cylinder radius in the capped_cylinder model, so
259    :func:`constrain_pars` needs to be called afterward..
260    """
261    with push_seed(seed):
262        # Note: the sort guarantees order `of calls to random number generator
263        pars = dict((p, _randomize_one(p, v))
264                    for p, v in sorted(pars.items()))
265    return pars
266
267def constrain_pars(model_info, pars):
268    """
269    Restrict parameters to valid values.
270
271    This includes model specific code for models such as capped_cylinder
272    which need to support within model constraints (cap radius more than
273    cylinder radius in this case).
274    """
275    name = model_info['id']
276    # if it is a product model, then just look at the form factor since
277    # none of the structure factors need any constraints.
278    if '*' in name:
279        name = name.split('*')[0]
280
281    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
282        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
283    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
284        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
285
286    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
287    if name == 'guinier':
288        #q_max = 0.2  # mid q maximum
289        q_max = 1.0  # high q maximum
290        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
291        pars['rg'] = min(pars['rg'], rg_max)
292
293    if name == 'rpa':
294        # Make sure phi sums to 1.0
295        if pars['case_num'] < 2:
296            pars['Phia'] = 0.
297            pars['Phib'] = 0.
298        elif pars['case_num'] < 5:
299            pars['Phia'] = 0.
300        total = sum(pars['Phi'+c] for c in 'abcd')
301        for c in 'abcd':
302            pars['Phi'+c] /= total
303
304def parlist(model_info, pars, is2d):
305    """
306    Format the parameter list for printing.
307    """
308    lines = []
309    parameters = model_info['parameters']
310    for p in parameters.type['2d' if is2d else '1d']:
311        if p.length > 1:
312            for k in range(p.length):
313                ext = "[%d]"%k
314                fields = dict(
315                    value=pars.get(p.id+ext, p.default),
316                    pd=pars.get(p.id+"_pd"+ext, 0.),
317                    n=int(pars.get(p.id+"_pd_n"+ext, 0)),
318                    nsigma=pars.get(p.id+"_pd_nsgima"+ext, 3.),
319                    type=pars.get(p.id+"_pd_type"+ext, 'gaussian'))
320                lines.append(_format_par(p.id+ext, **fields))
321        else:
322            fields = dict(
323                value=pars.get(p.id, p.default),
324                pd=pars.get(p.id+"_pd", 0.),
325                n=int(pars.get(p.id+"_pd_n", 0)),
326                nsigma=pars.get(p.id+"_pd_nsgima", 3.),
327                type=pars.get(p.id+"_pd_type", 'gaussian'))
328            lines.append(_format_par(p.name, **fields))
329    return "\n".join(lines)
330
331    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
332
333def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
334    line = "%s: %g"%(name, value)
335    if pd != 0.  and n != 0:
336        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
337                % (pd, n, nsigma, nsigma, type)
338    return line
339
340def suppress_pd(pars):
341    """
342    Suppress theta_pd for now until the normalization is resolved.
343
344    May also suppress complete polydispersity of the model to test
345    models more quickly.
346    """
347    pars = pars.copy()
348    for p in pars:
349        if p.endswith("_pd_n"): pars[p] = 0
350    return pars
351
352def eval_sasview(model_info, data):
353    """
354    Return a model calculator using the SasView fitting engine.
355    """
356    # importing sas here so that the error message will be that sas failed to
357    # import rather than the more obscure smear_selection not imported error
358    import sas
359    from sas.models.qsmearing import smear_selection
360
361    def get_model(name):
362        #print("new",sorted(_pars.items()))
363        sas = __import__('sas.models.' + name)
364        ModelClass = getattr(getattr(sas.models, name, None), name, None)
365        if ModelClass is None:
366            raise ValueError("could not find model %r in sas.models"%name)
367        return ModelClass()
368
369    # grab the sasview model, or create it if it is a product model
370    if model_info['composition']:
371        composition_type, parts = model_info['composition']
372        if composition_type == 'product':
373            from sas.sascalc.fit.MultiplicationModel import MultiplicationModel
374            P, S = [get_model(p) for p in model_info['oldname']]
375            model = MultiplicationModel(P, S)
376        else:
377            raise ValueError("sasview mixture models not supported by compare")
378    else:
379        model = get_model(model_info['oldname'])
380
381    # build a smearer with which to call the model, if necessary
382    smearer = smear_selection(data, model=model)
383    if hasattr(data, 'qx_data'):
384        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
385        index = ((~data.mask) & (~np.isnan(data.data))
386                 & (q >= data.qmin) & (q <= data.qmax))
387        if smearer is not None:
388            smearer.model = model  # because smear_selection has a bug
389            smearer.accuracy = data.accuracy
390            smearer.set_index(index)
391            theory = lambda: smearer.get_value()
392        else:
393            theory = lambda: model.evalDistribution([data.qx_data[index],
394                                                     data.qy_data[index]])
395    elif smearer is not None:
396        theory = lambda: smearer(model.evalDistribution(data.x))
397    else:
398        theory = lambda: model.evalDistribution(data.x)
399
400    def calculator(**pars):
401        """
402        Sasview calculator for model.
403        """
404        # paying for parameter conversion each time to keep life simple, if not fast
405        pars = revert_pars(model_info, pars)
406        for k, v in pars.items():
407            parts = k.split('.')  # polydispersity components
408            if len(parts) == 2:
409                model.dispersion[parts[0]][parts[1]] = v
410            else:
411                model.setParam(k, v)
412        return theory()
413
414    calculator.engine = "sasview"
415    return calculator
416
417DTYPE_MAP = {
418    'half': '16',
419    'fast': 'fast',
420    'single': '32',
421    'double': '64',
422    'quad': '128',
423    'f16': '16',
424    'f32': '32',
425    'f64': '64',
426    'longdouble': '128',
427}
428def eval_opencl(model_info, data, dtype='single', cutoff=0.):
429    """
430    Return a model calculator using the OpenCL calculation engine.
431    """
432    try:
433        model = core.build_model(model_info, dtype=dtype, platform="ocl")
434    except Exception as exc:
435        print(exc)
436        print("... trying again with single precision")
437        model = core.build_model(model_info, dtype='single', platform="ocl")
438    calculator = DirectModel(data, model, cutoff=cutoff)
439    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
440    return calculator
441
442def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
443    """
444    Return a model calculator using the DLL calculation engine.
445    """
446    if dtype == 'quad':
447        dtype = 'longdouble'
448    model = core.build_model(model_info, dtype=dtype, platform="dll")
449    calculator = DirectModel(data, model, cutoff=cutoff)
450    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
451    return calculator
452
453def time_calculation(calculator, pars, Nevals=1):
454    """
455    Compute the average calculation time over N evaluations.
456
457    An additional call is generated without polydispersity in order to
458    initialize the calculation engine, and make the average more stable.
459    """
460    # initialize the code so time is more accurate
461    if Nevals > 1:
462        value = calculator(**suppress_pd(pars))
463    toc = tic()
464    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
465        value = calculator(**pars)
466    average_time = toc()*1000./Nevals
467    return value, average_time
468
469def make_data(opts):
470    """
471    Generate an empty dataset, used with the model to set Q points
472    and resolution.
473
474    *opts* contains the options, with 'qmax', 'nq', 'res',
475    'accuracy', 'is2d' and 'view' parsed from the command line.
476    """
477    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
478    if opts['is2d']:
479        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
480        data.accuracy = opts['accuracy']
481        set_beam_stop(data, 0.004)
482        index = ~data.mask
483    else:
484        if opts['view'] == 'log':
485            qmax = math.log10(qmax)
486            q = np.logspace(qmax-3, qmax, nq)
487        else:
488            q = np.linspace(0.001*qmax, qmax, nq)
489        data = empty_data1D(q, resolution=res)
490        index = slice(None, None)
491    return data, index
492
493def make_engine(model_info, data, dtype, cutoff):
494    """
495    Generate the appropriate calculation engine for the given datatype.
496
497    Datatypes with '!' appended are evaluated using external C DLLs rather
498    than OpenCL.
499    """
500    if dtype == 'sasview':
501        return eval_sasview(model_info, data)
502    elif dtype.endswith('!'):
503        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
504    else:
505        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
506
507def compare(opts, limits=None):
508    """
509    Preform a comparison using options from the command line.
510
511    *limits* are the limits on the values to use, either to set the y-axis
512    for 1D or to set the colormap scale for 2D.  If None, then they are
513    inferred from the data and returned. When exploring using Bumps,
514    the limits are set when the model is initially called, and maintained
515    as the values are adjusted, making it easier to see the effects of the
516    parameters.
517    """
518    Nbase, Ncomp = opts['n1'], opts['n2']
519    pars = opts['pars']
520    data = opts['data']
521
522    # Base calculation
523    if Nbase > 0:
524        base = opts['engines'][0]
525        try:
526            base_value, base_time = time_calculation(base, pars, Nbase)
527            print("%s t=%.2f ms, intensity=%.0f"
528                  % (base.engine, base_time, sum(base_value)))
529        except ImportError:
530            traceback.print_exc()
531            Nbase = 0
532
533    # Comparison calculation
534    if Ncomp > 0:
535        comp = opts['engines'][1]
536        try:
537            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
538            print("%s t=%.2f ms, intensity=%.0f"
539                  % (comp.engine, comp_time, sum(comp_value)))
540        except ImportError:
541            traceback.print_exc()
542            Ncomp = 0
543
544    # Compare, but only if computing both forms
545    if Nbase > 0 and Ncomp > 0:
546        resid = (base_value - comp_value)
547        relerr = resid/comp_value
548        _print_stats("|%s-%s|"
549                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
550                     resid)
551        _print_stats("|(%s-%s)/%s|"
552                     % (base.engine, comp.engine, comp.engine),
553                     relerr)
554
555    # Plot if requested
556    if not opts['plot'] and not opts['explore']: return
557    view = opts['view']
558    import matplotlib.pyplot as plt
559    if limits is None:
560        vmin, vmax = np.Inf, -np.Inf
561        if Nbase > 0:
562            vmin = min(vmin, min(base_value))
563            vmax = max(vmax, max(base_value))
564        if Ncomp > 0:
565            vmin = min(vmin, min(comp_value))
566            vmax = max(vmax, max(comp_value))
567        limits = vmin, vmax
568
569    if Nbase > 0:
570        if Ncomp > 0: plt.subplot(131)
571        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
572        plt.title("%s t=%.2f ms"%(base.engine, base_time))
573        #cbar_title = "log I"
574    if Ncomp > 0:
575        if Nbase > 0: plt.subplot(132)
576        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
577        plt.title("%s t=%.2f ms"%(comp.engine, comp_time))
578        #cbar_title = "log I"
579    if Ncomp > 0 and Nbase > 0:
580        plt.subplot(133)
581        if not opts['rel_err']:
582            err, errstr, errview = resid, "abs err", "linear"
583        else:
584            err, errstr, errview = abs(relerr), "rel err", "log"
585        #err,errstr = base/comp,"ratio"
586        plot_theory(data, None, resid=err, view=errview, use_data=False)
587        if view == 'linear':
588            plt.xscale('linear')
589        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
590        #cbar_title = errstr if errview=="linear" else "log "+errstr
591    #if is2D:
592    #    h = plt.colorbar()
593    #    h.ax.set_title(cbar_title)
594
595    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
596        plt.figure()
597        v = relerr
598        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
599        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
600        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
601                   % (base.engine, comp.engine, comp.engine))
602        plt.ylabel('P(err)')
603        plt.title('Distribution of relative error between calculation engines')
604
605    if not opts['explore']:
606        plt.show()
607
608    return limits
609
610def _print_stats(label, err):
611    sorted_err = np.sort(abs(err))
612    p50 = int((len(err)-1)*0.50)
613    p98 = int((len(err)-1)*0.98)
614    data = [
615        "max:%.3e"%sorted_err[-1],
616        "median:%.3e"%sorted_err[p50],
617        "98%%:%.3e"%sorted_err[p98],
618        "rms:%.3e"%np.sqrt(np.mean(err**2)),
619        "zero-offset:%+.3e"%np.mean(err),
620        ]
621    print(label+"  "+"  ".join(data))
622
623
624
625# ===========================================================================
626#
627NAME_OPTIONS = set([
628    'plot', 'noplot',
629    'half', 'fast', 'single', 'double',
630    'single!', 'double!', 'quad!', 'sasview',
631    'lowq', 'midq', 'highq', 'exq',
632    '2d', '1d',
633    'preset', 'random',
634    'poly', 'mono',
635    'nopars', 'pars',
636    'rel', 'abs',
637    'linear', 'log', 'q4',
638    'hist', 'nohist',
639    'edit',
640    'demo', 'default',
641    ])
642VALUE_OPTIONS = [
643    # Note: random is both a name option and a value option
644    'cutoff', 'random', 'nq', 'res', 'accuracy',
645    ]
646
647def columnize(L, indent="", width=79):
648    """
649    Format a list of strings into columns.
650
651    Returns a string with carriage returns ready for printing.
652    """
653    column_width = max(len(w) for w in L) + 1
654    num_columns = (width - len(indent)) // column_width
655    num_rows = len(L) // num_columns
656    L = L + [""] * (num_rows*num_columns - len(L))
657    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
658    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
659             for row in zip(*columns)]
660    output = indent + ("\n"+indent).join(lines)
661    return output
662
663
664def get_pars(model_info, use_demo=False):
665    """
666    Extract demo parameters from the model definition.
667    """
668    # Get the default values for the parameters
669    pars = {}
670    for p in model_info['parameters']:
671        parts = [('', p.default)]
672        if p.polydisperse:
673            parts.append(('_pd', 0.0))
674            parts.append(('_pd_n', 0))
675            parts.append(('_pd_nsigma', 3.0))
676            parts.append(('_pd_type', "gaussian"))
677        for ext, val in parts:
678            if p.length > 1:
679                dict(("%s%s[%d]"%(p.id,ext,k), val) for k in range(p.length))
680            else:
681                pars[p.id+ext] = val
682
683    # Plug in values given in demo
684    if use_demo:
685        pars.update(model_info['demo'])
686    return pars
687
688
689def parse_opts():
690    """
691    Parse command line options.
692    """
693    MODELS = core.list_models()
694    flags = [arg for arg in sys.argv[1:]
695             if arg.startswith('-')]
696    values = [arg for arg in sys.argv[1:]
697              if not arg.startswith('-') and '=' in arg]
698    args = [arg for arg in sys.argv[1:]
699            if not arg.startswith('-') and '=' not in arg]
700    models = "\n    ".join("%-15s"%v for v in MODELS)
701    if len(args) == 0:
702        print(USAGE)
703        print("\nAvailable models:")
704        print(columnize(MODELS, indent="  "))
705        sys.exit(1)
706    if len(args) > 3:
707        print("expected parameters: model N1 N2")
708
709    name = args[0]
710    try:
711        if name.endswith('.py'):
712            model_info = core.load_model_info_from_path(name)
713        else:
714            model_info = core.load_model_info(name)
715    except ImportError, exc:
716        print(str(exc))
717        print("Could not find model; use one of:\n    " + models)
718        sys.exit(1)
719
720    invalid = [o[1:] for o in flags
721               if o[1:] not in NAME_OPTIONS
722               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
723    if invalid:
724        print("Invalid options: %s"%(", ".join(invalid)))
725        sys.exit(1)
726
727
728    # pylint: disable=bad-whitespace
729    # Interpret the flags
730    opts = {
731        'plot'      : True,
732        'view'      : 'log',
733        'is2d'      : False,
734        'qmax'      : 0.05,
735        'nq'        : 128,
736        'res'       : 0.0,
737        'accuracy'  : 'Low',
738        'cutoff'    : 0.0,
739        'seed'      : -1,  # default to preset
740        'mono'      : False,
741        'show_pars' : False,
742        'show_hist' : False,
743        'rel_err'   : True,
744        'explore'   : False,
745        'use_demo'  : True,
746    }
747    engines = []
748    for arg in flags:
749        if arg == '-noplot':    opts['plot'] = False
750        elif arg == '-plot':    opts['plot'] = True
751        elif arg == '-linear':  opts['view'] = 'linear'
752        elif arg == '-log':     opts['view'] = 'log'
753        elif arg == '-q4':      opts['view'] = 'q4'
754        elif arg == '-1d':      opts['is2d'] = False
755        elif arg == '-2d':      opts['is2d'] = True
756        elif arg == '-exq':     opts['qmax'] = 10.0
757        elif arg == '-highq':   opts['qmax'] = 1.0
758        elif arg == '-midq':    opts['qmax'] = 0.2
759        elif arg == '-lowq':    opts['qmax'] = 0.05
760        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
761        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
762        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
763        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
764        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
765        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
766        elif arg == '-preset':  opts['seed'] = -1
767        elif arg == '-mono':    opts['mono'] = True
768        elif arg == '-poly':    opts['mono'] = False
769        elif arg == '-pars':    opts['show_pars'] = True
770        elif arg == '-nopars':  opts['show_pars'] = False
771        elif arg == '-hist':    opts['show_hist'] = True
772        elif arg == '-nohist':  opts['show_hist'] = False
773        elif arg == '-rel':     opts['rel_err'] = True
774        elif arg == '-abs':     opts['rel_err'] = False
775        elif arg == '-half':    engines.append(arg[1:])
776        elif arg == '-fast':    engines.append(arg[1:])
777        elif arg == '-single':  engines.append(arg[1:])
778        elif arg == '-double':  engines.append(arg[1:])
779        elif arg == '-single!': engines.append(arg[1:])
780        elif arg == '-double!': engines.append(arg[1:])
781        elif arg == '-quad!':   engines.append(arg[1:])
782        elif arg == '-sasview': engines.append(arg[1:])
783        elif arg == '-edit':    opts['explore'] = True
784        elif arg == '-demo':    opts['use_demo'] = True
785        elif arg == '-default':    opts['use_demo'] = False
786    # pylint: enable=bad-whitespace
787
788    if len(engines) == 0:
789        engines.extend(['single', 'double'])
790    elif len(engines) == 1:
791        if engines[0][0] != 'double':
792            engines.append('double')
793        else:
794            engines.append('single')
795    elif len(engines) > 2:
796        del engines[2:]
797
798    n1 = int(args[1]) if len(args) > 1 else 1
799    n2 = int(args[2]) if len(args) > 2 else 1
800
801    # Get demo parameters from model definition, or use default parameters
802    # if model does not define demo parameters
803    pars = get_pars(model_info, opts['use_demo'])
804
805
806    # Fill in parameters given on the command line
807    presets = {}
808    for arg in values:
809        k, v = arg.split('=', 1)
810        if k not in pars:
811            # extract base name without polydispersity info
812            s = set(p.split('_pd')[0] for p in pars)
813            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
814            sys.exit(1)
815        presets[k] = float(v) if not k.endswith('type') else v
816
817    # randomize parameters
818    #pars.update(set_pars)  # set value before random to control range
819    if opts['seed'] > -1:
820        pars = randomize_pars(pars, seed=opts['seed'])
821        print("Randomize using -random=%i"%opts['seed'])
822    if opts['mono']:
823        pars = suppress_pd(pars)
824    pars.update(presets)  # set value after random to control value
825    #import pprint; pprint.pprint(model_info)
826    constrain_pars(model_info, pars)
827    constrain_new_to_old(model_info, pars)
828    if opts['show_pars']:
829        print(str(parlist(model_info, pars, opts['is2d'])))
830
831    # Create the computational engines
832    data, _ = make_data(opts)
833    if n1:
834        base = make_engine(model_info, data, engines[0], opts['cutoff'])
835    else:
836        base = None
837    if n2:
838        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
839    else:
840        comp = None
841
842    # pylint: disable=bad-whitespace
843    # Remember it all
844    opts.update({
845        'name'      : name,
846        'def'       : model_info,
847        'n1'        : n1,
848        'n2'        : n2,
849        'presets'   : presets,
850        'pars'      : pars,
851        'data'      : data,
852        'engines'   : [base, comp],
853    })
854    # pylint: enable=bad-whitespace
855
856    return opts
857
858def explore(opts):
859    """
860    Explore the model using the Bumps GUI.
861    """
862    import wx
863    from bumps.names import FitProblem
864    from bumps.gui.app_frame import AppFrame
865
866    problem = FitProblem(Explore(opts))
867    is_mac = "cocoa" in wx.version()
868    app = wx.App()
869    frame = AppFrame(parent=None, title="explore")
870    if not is_mac: frame.Show()
871    frame.panel.set_model(model=problem)
872    frame.panel.Layout()
873    frame.panel.aui.Split(0, wx.TOP)
874    if is_mac: frame.Show()
875    app.MainLoop()
876
877class Explore(object):
878    """
879    Bumps wrapper for a SAS model comparison.
880
881    The resulting object can be used as a Bumps fit problem so that
882    parameters can be adjusted in the GUI, with plots updated on the fly.
883    """
884    def __init__(self, opts):
885        from bumps.cli import config_matplotlib
886        from . import bumps_model
887        config_matplotlib()
888        self.opts = opts
889        model_info = opts['def']
890        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
891        if not opts['is2d']:
892            for p in model_info['parameters'].type['1d']:
893                for ext in ['', '_pd', '_pd_n', '_pd_nsigma']:
894                    k = p.name+ext
895                    v = pars.get(k, None)
896                    if v is not None:
897                        v.range(*parameter_range(k, v.value))
898        else:
899            for k, v in pars.items():
900                v.range(*parameter_range(k, v.value))
901
902        self.pars = pars
903        self.pd_types = pd_types
904        self.limits = None
905
906    def numpoints(self):
907        """
908        Return the number of points.
909        """
910        return len(self.pars) + 1  # so dof is 1
911
912    def parameters(self):
913        """
914        Return a dictionary of parameters.
915        """
916        return self.pars
917
918    def nllf(self):
919        """
920        Return cost.
921        """
922        # pylint: disable=no-self-use
923        return 0.  # No nllf
924
925    def plot(self, view='log'):
926        """
927        Plot the data and residuals.
928        """
929        pars = dict((k, v.value) for k, v in self.pars.items())
930        pars.update(self.pd_types)
931        self.opts['pars'] = pars
932        limits = compare(self.opts, limits=self.limits)
933        if self.limits is None:
934            vmin, vmax = limits
935            vmax = 1.3*vmax
936            vmin = vmax*1e-7
937            self.limits = vmin, vmax
938
939
940def main():
941    """
942    Main program.
943    """
944    opts = parse_opts()
945    if opts['explore']:
946        explore(opts)
947    else:
948        compare(opts)
949
950if __name__ == "__main__":
951    main()
Note: See TracBrowser for help on using the repository browser.