source: sasmodels/sasmodels/compare.py @ 17bbadd

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 17bbadd was 17bbadd, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor so all model defintion queries use model_info; better documentation of model_info structure; initial implementation of product model (broken)

  • Property mode set to 100755
File size: 31.1 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31import sys
32import math
33import datetime
34import traceback
35
36import numpy as np
37
38from . import core
39from . import kerneldll
40from . import product
41from .data import plot_theory, empty_data1D, empty_data2D
42from .direct_model import DirectModel
43from .convert import revert_pars, constrain_new_to_old
44
45USAGE = """
46usage: compare.py model N1 N2 [options...] [key=val]
47
48Compare the speed and value for a model between the SasView original and the
49sasmodels rewrite.
50
51model is the name of the model to compare (see below).
52N1 is the number of times to run sasmodels (default=1).
53N2 is the number times to run sasview (default=1).
54
55Options (* for default):
56
57    -plot*/-noplot plots or suppress the plot of the model
58    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
59    -nq=128 sets the number of Q points in the data set
60    -1d*/-2d computes 1d or 2d data
61    -preset*/-random[=seed] preset or random parameters
62    -mono/-poly* force monodisperse/polydisperse
63    -cutoff=1e-5* cutoff value for including a point in polydispersity
64    -pars/-nopars* prints the parameter set or not
65    -abs/-rel* plot relative or absolute error
66    -linear/-log*/-q4 intensity scaling
67    -hist/-nohist* plot histogram of relative error
68    -res=0 sets the resolution width dQ/Q if calculating with resolution
69    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
70    -edit starts the parameter explorer
71
72Any two calculation engines can be selected for comparison:
73
74    -single/-double/-half/-fast sets an OpenCL calculation engine
75    -single!/-double!/-quad! sets an OpenMP calculation engine
76    -sasview sets the sasview calculation engine
77
78The default is -single -sasview.  Note that the interpretation of quad
79precision depends on architecture, and may vary from 64-bit to 128-bit,
80with 80-bit floats being common (1e-19 precision).
81
82Key=value pairs allow you to set specific values for the model parameters.
83"""
84
85# Update docs with command line usage string.   This is separate from the usual
86# doc string so that we can display it at run time if there is an error.
87# lin
88__doc__ = (__doc__  # pylint: disable=redefined-builtin
89           + """
90Program description
91-------------------
92
93"""
94           + USAGE)
95
96kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
97
98MODELS = core.list_models()
99
100# CRUFT python 2.6
101if not hasattr(datetime.timedelta, 'total_seconds'):
102    def delay(dt):
103        """Return number date-time delta as number seconds"""
104        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
105else:
106    def delay(dt):
107        """Return number date-time delta as number seconds"""
108        return dt.total_seconds()
109
110
111class push_seed(object):
112    """
113    Set the seed value for the random number generator.
114
115    When used in a with statement, the random number generator state is
116    restored after the with statement is complete.
117
118    :Parameters:
119
120    *seed* : int or array_like, optional
121        Seed for RandomState
122
123    :Example:
124
125    Seed can be used directly to set the seed::
126
127        >>> from numpy.random import randint
128        >>> push_seed(24)
129        <...push_seed object at...>
130        >>> print(randint(0,1000000,3))
131        [242082    899 211136]
132
133    Seed can also be used in a with statement, which sets the random
134    number generator state for the enclosed computations and restores
135    it to the previous state on completion::
136
137        >>> with push_seed(24):
138        ...    print(randint(0,1000000,3))
139        [242082    899 211136]
140
141    Using nested contexts, we can demonstrate that state is indeed
142    restored after the block completes::
143
144        >>> with push_seed(24):
145        ...    print(randint(0,1000000))
146        ...    with push_seed(24):
147        ...        print(randint(0,1000000,3))
148        ...    print(randint(0,1000000))
149        242082
150        [242082    899 211136]
151        899
152
153    The restore step is protected against exceptions in the block::
154
155        >>> with push_seed(24):
156        ...    print(randint(0,1000000))
157        ...    try:
158        ...        with push_seed(24):
159        ...            print(randint(0,1000000,3))
160        ...            raise Exception()
161        ...    except:
162        ...        print("Exception raised")
163        ...    print(randint(0,1000000))
164        242082
165        [242082    899 211136]
166        Exception raised
167        899
168    """
169    def __init__(self, seed=None):
170        self._state = np.random.get_state()
171        np.random.seed(seed)
172
173    def __enter__(self):
174        return None
175
176    def __exit__(self, *args):
177        np.random.set_state(self._state)
178
179def tic():
180    """
181    Timer function.
182
183    Use "toc=tic()" to start the clock and "toc()" to measure
184    a time interval.
185    """
186    then = datetime.datetime.now()
187    return lambda: delay(datetime.datetime.now() - then)
188
189
190def set_beam_stop(data, radius, outer=None):
191    """
192    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
193
194    Note: this function does not use the sasview package
195    """
196    if hasattr(data, 'qx_data'):
197        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
198        data.mask = (q < radius)
199        if outer is not None:
200            data.mask |= (q >= outer)
201    else:
202        data.mask = (data.x < radius)
203        if outer is not None:
204            data.mask |= (data.x >= outer)
205
206
207def parameter_range(p, v):
208    """
209    Choose a parameter range based on parameter name and initial value.
210    """
211    if p.endswith('_pd_n'):
212        return [0, 100]
213    elif p.endswith('_pd_nsigma'):
214        return [0, 5]
215    elif p.endswith('_pd_type'):
216        return v
217    elif any(s in p for s in ('theta', 'phi', 'psi')):
218        # orientation in [-180,180], orientation pd in [0,45]
219        if p.endswith('_pd'):
220            return [0, 45]
221        else:
222            return [-180, 180]
223    elif 'sld' in p:
224        return [-0.5, 10]
225    elif p.endswith('_pd'):
226        return [0, 1]
227    elif p == 'background':
228        return [0, 10]
229    elif p == 'scale':
230        return [0, 1e3]
231    elif p == 'case_num':
232        # RPA hack
233        return [0, 10]
234    elif v < 0:
235        # Kxy parameters in rpa model can be negative
236        return [2*v, -2*v]
237    else:
238        return [0, (2*v if v > 0 else 1)]
239
240
241def _randomize_one(p, v):
242    """
243    Randomize a single parameter.
244    """
245    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
246        return v
247    else:
248        return np.random.uniform(*parameter_range(p, v))
249
250
251def randomize_pars(pars, seed=None):
252    """
253    Generate random values for all of the parameters.
254
255    Valid ranges for the random number generator are guessed from the name of
256    the parameter; this will not account for constraints such as cap radius
257    greater than cylinder radius in the capped_cylinder model, so
258    :func:`constrain_pars` needs to be called afterward..
259    """
260    with push_seed(seed):
261        # Note: the sort guarantees order `of calls to random number generator
262        pars = dict((p, _randomize_one(p, v))
263                    for p, v in sorted(pars.items()))
264    return pars
265
266def constrain_pars(model_info, pars):
267    """
268    Restrict parameters to valid values.
269
270    This includes model specific code for models such as capped_cylinder
271    which need to support within model constraints (cap radius more than
272    cylinder radius in this case).
273    """
274    name = model_info['id']
275    # if it is a product model, then just look at the form factor since
276    # none of the structure factors need any constraints.
277    if '*' in name:
278        name = name.split('*')[0]
279
280    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
281        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
282    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
283        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
284
285    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
286    if name == 'guinier':
287        #q_max = 0.2  # mid q maximum
288        q_max = 1.0  # high q maximum
289        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
290        pars['rg'] = min(pars['rg'], rg_max)
291
292    if name == 'rpa':
293        # Make sure phi sums to 1.0
294        if pars['case_num'] < 2:
295            pars['Phia'] = 0.
296            pars['Phib'] = 0.
297        elif pars['case_num'] < 5:
298            pars['Phia'] = 0.
299        total = sum(pars['Phi'+c] for c in 'abcd')
300        for c in 'abcd':
301            pars['Phi'+c] /= total
302
303def parlist(pars):
304    """
305    Format the parameter list for printing.
306    """
307    active = None
308    fields = {}
309    lines = []
310    for k, v in sorted(pars.items()):
311        parts = k.split('_pd')
312        #print(k, active, parts)
313        if len(parts) == 1:
314            if active: lines.append(_format_par(active, **fields))
315            active = k
316            fields = {'value': v}
317        else:
318            assert parts[0] == active
319            if parts[1]:
320                fields[parts[1][1:]] = v
321            else:
322                fields['pd'] = v
323    if active: lines.append(_format_par(active, **fields))
324    return "\n".join(lines)
325
326    #return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
327
328def _format_par(name, value=0., pd=0., n=0, nsigma=3., type='gaussian'):
329    line = "%s: %g"%(name, value)
330    if pd != 0.  and n != 0:
331        line += " +/- %g  (%d points in [-%g,%g] sigma %s)"\
332                % (pd, n, nsigma, nsigma, type)
333    return line
334
335def suppress_pd(pars):
336    """
337    Suppress theta_pd for now until the normalization is resolved.
338
339    May also suppress complete polydispersity of the model to test
340    models more quickly.
341    """
342    pars = pars.copy()
343    for p in pars:
344        if p.endswith("_pd_n"): pars[p] = 0
345    return pars
346
347def eval_sasview(model_info, data):
348    """
349    Return a model calculator using the SasView fitting engine.
350    """
351    # importing sas here so that the error message will be that sas failed to
352    # import rather than the more obscure smear_selection not imported error
353    import sas
354    from sas.models.qsmearing import smear_selection
355
356    def get_model(name):
357        #print("new",sorted(_pars.items()))
358        sas = __import__('sas.models.' + name)
359        ModelClass = getattr(getattr(sas.models, name, None), name, None)
360        if ModelClass is None:
361            raise ValueError("could not find model %r in sas.models"%name)
362        return ModelClass()
363
364    # grab the sasview model, or create it if it is a product model
365    if model_info['composition']:
366        composition_type, parts = model_info['composition']
367        if composition_type == 'product':
368            from sas.models import MultiplicationModel
369            P, S = [get_model(p) for p in model_info['oldname']]
370            model = MultiplicationModel(P, S)
371        else:
372            raise ValueError("mixture models not handled yet")
373    else:
374        model = get_model(model_info['oldname'])
375
376    # build a smearer with which to call the model, if necessary
377    smearer = smear_selection(data, model=model)
378    if hasattr(data, 'qx_data'):
379        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
380        index = ((~data.mask) & (~np.isnan(data.data))
381                 & (q >= data.qmin) & (q <= data.qmax))
382        if smearer is not None:
383            smearer.model = model  # because smear_selection has a bug
384            smearer.accuracy = data.accuracy
385            smearer.set_index(index)
386            theory = lambda: smearer.get_value()
387        else:
388            theory = lambda: model.evalDistribution([data.qx_data[index],
389                                                     data.qy_data[index]])
390    elif smearer is not None:
391        theory = lambda: smearer(model.evalDistribution(data.x))
392    else:
393        theory = lambda: model.evalDistribution(data.x)
394
395    def calculator(**pars):
396        """
397        Sasview calculator for model.
398        """
399        # paying for parameter conversion each time to keep life simple, if not fast
400        pars = revert_pars(model_info, pars)
401        for k, v in pars.items():
402            parts = k.split('.')  # polydispersity components
403            if len(parts) == 2:
404                model.dispersion[parts[0]][parts[1]] = v
405            else:
406                model.setParam(k, v)
407        return theory()
408
409    calculator.engine = "sasview"
410    return calculator
411
412DTYPE_MAP = {
413    'half': '16',
414    'fast': 'fast',
415    'single': '32',
416    'double': '64',
417    'quad': '128',
418    'f16': '16',
419    'f32': '32',
420    'f64': '64',
421    'longdouble': '128',
422}
423def eval_opencl(model_info, data, dtype='single', cutoff=0.):
424    """
425    Return a model calculator using the OpenCL calculation engine.
426    """
427    def builder(model_info):
428        try:
429            return core.build_model(model_info, dtype=dtype, platform="ocl")
430        except Exception as exc:
431            print(exc)
432            print("... trying again with single precision")
433            dtype = 'single'
434            return core.build_model(model_info, dtype=dtype, platform="ocl")
435    if model_info['composition']:
436        composition_type, parts = model_info['composition']
437        if composition_type == 'product':
438            P, S = [builder(p) for p in parts]
439            model = product.ProductModel(P, S)
440        else:
441            raise ValueError("mixture models not handled yet")
442    else:
443        model = builder(model_info)
444    calculator = DirectModel(data, model, cutoff=cutoff)
445    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
446    return calculator
447
448def eval_ctypes(model_info, data, dtype='double', cutoff=0.):
449    """
450    Return a model calculator using the DLL calculation engine.
451    """
452    if dtype == 'quad':
453        dtype = 'longdouble'
454    def builder(model_info):
455        return core.build_model(model_info, dtype=dtype, platform="dll")
456
457    if model_info['composition']:
458        composition_type, parts = model_info['composition']
459        if composition_type == 'product':
460            P, S = [builder(p) for p in parts]
461            model = product.ProductModel(P, S)
462        else:
463            raise ValueError("mixture models not handled yet")
464    else:
465        model = builder(model_info)
466    calculator = DirectModel(data, model, cutoff=cutoff)
467    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
468    return calculator
469
470def time_calculation(calculator, pars, Nevals=1):
471    """
472    Compute the average calculation time over N evaluations.
473
474    An additional call is generated without polydispersity in order to
475    initialize the calculation engine, and make the average more stable.
476    """
477    # initialize the code so time is more accurate
478    value = calculator(**suppress_pd(pars))
479    toc = tic()
480    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
481        value = calculator(**pars)
482    average_time = toc()*1000./Nevals
483    return value, average_time
484
485def make_data(opts):
486    """
487    Generate an empty dataset, used with the model to set Q points
488    and resolution.
489
490    *opts* contains the options, with 'qmax', 'nq', 'res',
491    'accuracy', 'is2d' and 'view' parsed from the command line.
492    """
493    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
494    if opts['is2d']:
495        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
496        data.accuracy = opts['accuracy']
497        set_beam_stop(data, 0.004)
498        index = ~data.mask
499    else:
500        if opts['view'] == 'log':
501            qmax = math.log10(qmax)
502            q = np.logspace(qmax-3, qmax, nq)
503        else:
504            q = np.linspace(0.001*qmax, qmax, nq)
505        data = empty_data1D(q, resolution=res)
506        index = slice(None, None)
507    return data, index
508
509def make_engine(model_info, data, dtype, cutoff):
510    """
511    Generate the appropriate calculation engine for the given datatype.
512
513    Datatypes with '!' appended are evaluated using external C DLLs rather
514    than OpenCL.
515    """
516    if dtype == 'sasview':
517        return eval_sasview(model_info, data)
518    elif dtype.endswith('!'):
519        return eval_ctypes(model_info, data, dtype=dtype[:-1], cutoff=cutoff)
520    else:
521        return eval_opencl(model_info, data, dtype=dtype, cutoff=cutoff)
522
523def compare(opts, limits=None):
524    """
525    Preform a comparison using options from the command line.
526
527    *limits* are the limits on the values to use, either to set the y-axis
528    for 1D or to set the colormap scale for 2D.  If None, then they are
529    inferred from the data and returned. When exploring using Bumps,
530    the limits are set when the model is initially called, and maintained
531    as the values are adjusted, making it easier to see the effects of the
532    parameters.
533    """
534    Nbase, Ncomp = opts['n1'], opts['n2']
535    pars = opts['pars']
536    data = opts['data']
537
538    # Base calculation
539    if Nbase > 0:
540        base = opts['engines'][0]
541        try:
542            base_value, base_time = time_calculation(base, pars, Nbase)
543            print("%s t=%.1f ms, intensity=%.0f"
544                  % (base.engine, base_time, sum(base_value)))
545        except ImportError:
546            traceback.print_exc()
547            Nbase = 0
548
549    # Comparison calculation
550    if Ncomp > 0:
551        comp = opts['engines'][1]
552        try:
553            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
554            print("%s t=%.1f ms, intensity=%.0f"
555                  % (comp.engine, comp_time, sum(comp_value)))
556        except ImportError:
557            traceback.print_exc()
558            Ncomp = 0
559
560    # Compare, but only if computing both forms
561    if Nbase > 0 and Ncomp > 0:
562        resid = (base_value - comp_value)
563        relerr = resid/comp_value
564        _print_stats("|%s-%s|"
565                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
566                     resid)
567        _print_stats("|(%s-%s)/%s|"
568                     % (base.engine, comp.engine, comp.engine),
569                     relerr)
570
571    # Plot if requested
572    if not opts['plot'] and not opts['explore']: return
573    view = opts['view']
574    import matplotlib.pyplot as plt
575    if limits is None:
576        vmin, vmax = np.Inf, -np.Inf
577        if Nbase > 0:
578            vmin = min(vmin, min(base_value))
579            vmax = max(vmax, max(base_value))
580        if Ncomp > 0:
581            vmin = min(vmin, min(comp_value))
582            vmax = max(vmax, max(comp_value))
583        limits = vmin, vmax
584
585    if Nbase > 0:
586        if Ncomp > 0: plt.subplot(131)
587        plot_theory(data, base_value, view=view, use_data=False, limits=limits)
588        plt.title("%s t=%.1f ms"%(base.engine, base_time))
589        #cbar_title = "log I"
590    if Ncomp > 0:
591        if Nbase > 0: plt.subplot(132)
592        plot_theory(data, comp_value, view=view, use_data=False, limits=limits)
593        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
594        #cbar_title = "log I"
595    if Ncomp > 0 and Nbase > 0:
596        plt.subplot(133)
597        if not opts['rel_err']:
598            err, errstr, errview = resid, "abs err", "linear"
599        else:
600            err, errstr, errview = abs(relerr), "rel err", "log"
601        #err,errstr = base/comp,"ratio"
602        plot_theory(data, None, resid=err, view=errview, use_data=False)
603        if view == 'linear':
604            plt.xscale('linear')
605        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
606        #cbar_title = errstr if errview=="linear" else "log "+errstr
607    #if is2D:
608    #    h = plt.colorbar()
609    #    h.ax.set_title(cbar_title)
610
611    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
612        plt.figure()
613        v = relerr
614        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
615        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
616        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
617                   % (base.engine, comp.engine, comp.engine))
618        plt.ylabel('P(err)')
619        plt.title('Distribution of relative error between calculation engines')
620
621    if not opts['explore']:
622        plt.show()
623
624    return limits
625
626def _print_stats(label, err):
627    sorted_err = np.sort(abs(err))
628    p50 = int((len(err)-1)*0.50)
629    p98 = int((len(err)-1)*0.98)
630    data = [
631        "max:%.3e"%sorted_err[-1],
632        "median:%.3e"%sorted_err[p50],
633        "98%%:%.3e"%sorted_err[p98],
634        "rms:%.3e"%np.sqrt(np.mean(err**2)),
635        "zero-offset:%+.3e"%np.mean(err),
636        ]
637    print(label+"  "+"  ".join(data))
638
639
640
641# ===========================================================================
642#
643NAME_OPTIONS = set([
644    'plot', 'noplot',
645    'half', 'fast', 'single', 'double',
646    'single!', 'double!', 'quad!', 'sasview',
647    'lowq', 'midq', 'highq', 'exq',
648    '2d', '1d',
649    'preset', 'random',
650    'poly', 'mono',
651    'nopars', 'pars',
652    'rel', 'abs',
653    'linear', 'log', 'q4',
654    'hist', 'nohist',
655    'edit',
656    ])
657VALUE_OPTIONS = [
658    # Note: random is both a name option and a value option
659    'cutoff', 'random', 'nq', 'res', 'accuracy',
660    ]
661
662def columnize(L, indent="", width=79):
663    """
664    Format a list of strings into columns.
665
666    Returns a string with carriage returns ready for printing.
667    """
668    column_width = max(len(w) for w in L) + 1
669    num_columns = (width - len(indent)) // column_width
670    num_rows = len(L) // num_columns
671    L = L + [""] * (num_rows*num_columns - len(L))
672    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
673    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
674             for row in zip(*columns)]
675    output = indent + ("\n"+indent).join(lines)
676    return output
677
678
679def get_demo_pars(model_info):
680    """
681    Extract demo parameters from the model definition.
682    """
683    # Get the default values for the parameters
684    pars = dict((p[0], p[2]) for p in model_info['parameters'])
685
686    # Fill in default values for the polydispersity parameters
687    for p in model_info['parameters']:
688        if p[4] in ('volume', 'orientation'):
689            pars[p[0]+'_pd'] = 0.0
690            pars[p[0]+'_pd_n'] = 0
691            pars[p[0]+'_pd_nsigma'] = 3.0
692            pars[p[0]+'_pd_type'] = "gaussian"
693
694    # Plug in values given in demo
695    pars.update(model_info['demo'])
696    return pars
697
698
699def parse_opts():
700    """
701    Parse command line options.
702    """
703    MODELS = core.list_models()
704    flags = [arg for arg in sys.argv[1:]
705             if arg.startswith('-')]
706    values = [arg for arg in sys.argv[1:]
707              if not arg.startswith('-') and '=' in arg]
708    args = [arg for arg in sys.argv[1:]
709            if not arg.startswith('-') and '=' not in arg]
710    models = "\n    ".join("%-15s"%v for v in MODELS)
711    if len(args) == 0:
712        print(USAGE)
713        print("\nAvailable models:")
714        print(columnize(MODELS, indent="  "))
715        sys.exit(1)
716    if len(args) > 3:
717        print("expected parameters: model N1 N2")
718
719    def load_model(name):
720        try:
721            model_info = core.load_model_info(name)
722        except ImportError, exc:
723            print(str(exc))
724            print("Use one of:\n    " + models)
725            sys.exit(1)
726        return model_info
727
728    name = args[0]
729    if '*' in name:
730        parts = [load_model(k) for k in name.split('*')]
731        model_info = product.make_product_info(*parts)
732    else:
733        model_info = load_model(name)
734
735    invalid = [o[1:] for o in flags
736               if o[1:] not in NAME_OPTIONS
737               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
738    if invalid:
739        print("Invalid options: %s"%(", ".join(invalid)))
740        sys.exit(1)
741
742
743    # pylint: disable=bad-whitespace
744    # Interpret the flags
745    opts = {
746        'plot'      : True,
747        'view'      : 'log',
748        'is2d'      : False,
749        'qmax'      : 0.05,
750        'nq'        : 128,
751        'res'       : 0.0,
752        'accuracy'  : 'Low',
753        'cutoff'    : 1e-5,
754        'seed'      : -1,  # default to preset
755        'mono'      : False,
756        'show_pars' : False,
757        'show_hist' : False,
758        'rel_err'   : True,
759        'explore'   : False,
760    }
761    engines = []
762    for arg in flags:
763        if arg == '-noplot':    opts['plot'] = False
764        elif arg == '-plot':    opts['plot'] = True
765        elif arg == '-linear':  opts['view'] = 'linear'
766        elif arg == '-log':     opts['view'] = 'log'
767        elif arg == '-q4':      opts['view'] = 'q4'
768        elif arg == '-1d':      opts['is2d'] = False
769        elif arg == '-2d':      opts['is2d'] = True
770        elif arg == '-exq':     opts['qmax'] = 10.0
771        elif arg == '-highq':   opts['qmax'] = 1.0
772        elif arg == '-midq':    opts['qmax'] = 0.2
773        elif arg == '-lowq':    opts['qmax'] = 0.05
774        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
775        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
776        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
777        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
778        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
779        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
780        elif arg == '-preset':  opts['seed'] = -1
781        elif arg == '-mono':    opts['mono'] = True
782        elif arg == '-poly':    opts['mono'] = False
783        elif arg == '-pars':    opts['show_pars'] = True
784        elif arg == '-nopars':  opts['show_pars'] = False
785        elif arg == '-hist':    opts['show_hist'] = True
786        elif arg == '-nohist':  opts['show_hist'] = False
787        elif arg == '-rel':     opts['rel_err'] = True
788        elif arg == '-abs':     opts['rel_err'] = False
789        elif arg == '-half':    engines.append(arg[1:])
790        elif arg == '-fast':    engines.append(arg[1:])
791        elif arg == '-single':  engines.append(arg[1:])
792        elif arg == '-double':  engines.append(arg[1:])
793        elif arg == '-single!': engines.append(arg[1:])
794        elif arg == '-double!': engines.append(arg[1:])
795        elif arg == '-quad!':   engines.append(arg[1:])
796        elif arg == '-sasview': engines.append(arg[1:])
797        elif arg == '-edit':    opts['explore'] = True
798    # pylint: enable=bad-whitespace
799
800    if len(engines) == 0:
801        engines.extend(['single', 'sasview'])
802    elif len(engines) == 1:
803        if engines[0][0] != 'sasview':
804            engines.append('sasview')
805        else:
806            engines.append('single')
807    elif len(engines) > 2:
808        del engines[2:]
809
810    n1 = int(args[1]) if len(args) > 1 else 1
811    n2 = int(args[2]) if len(args) > 2 else 1
812
813    # Get demo parameters from model definition, or use default parameters
814    # if model does not define demo parameters
815    pars = get_demo_pars(model_info)
816
817    # Fill in parameters given on the command line
818    presets = {}
819    for arg in values:
820        k, v = arg.split('=', 1)
821        if k not in pars:
822            # extract base name without polydispersity info
823            s = set(p.split('_pd')[0] for p in pars)
824            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
825            sys.exit(1)
826        presets[k] = float(v) if not k.endswith('type') else v
827
828    # randomize parameters
829    #pars.update(set_pars)  # set value before random to control range
830    if opts['seed'] > -1:
831        pars = randomize_pars(pars, seed=opts['seed'])
832        print("Randomize using -random=%i"%opts['seed'])
833    if opts['mono']:
834        pars = suppress_pd(pars)
835    pars.update(presets)  # set value after random to control value
836    constrain_pars(model_info, pars)
837    constrain_new_to_old(model_info, pars)
838    if opts['show_pars']:
839        print(str(parlist(pars)))
840
841    # Create the computational engines
842    data, _ = make_data(opts)
843    if n1:
844        base = make_engine(model_info, data, engines[0], opts['cutoff'])
845    else:
846        base = None
847    if n2:
848        comp = make_engine(model_info, data, engines[1], opts['cutoff'])
849    else:
850        comp = None
851
852    # pylint: disable=bad-whitespace
853    # Remember it all
854    opts.update({
855        'name'      : name,
856        'def'       : model_info,
857        'n1'        : n1,
858        'n2'        : n2,
859        'presets'   : presets,
860        'pars'      : pars,
861        'data'      : data,
862        'engines'   : [base, comp],
863    })
864    # pylint: enable=bad-whitespace
865
866    return opts
867
868def explore(opts):
869    """
870    Explore the model using the Bumps GUI.
871    """
872    import wx
873    from bumps.names import FitProblem
874    from bumps.gui.app_frame import AppFrame
875
876    problem = FitProblem(Explore(opts))
877    is_mac = "cocoa" in wx.version()
878    app = wx.App()
879    frame = AppFrame(parent=None, title="explore")
880    if not is_mac: frame.Show()
881    frame.panel.set_model(model=problem)
882    frame.panel.Layout()
883    frame.panel.aui.Split(0, wx.TOP)
884    if is_mac: frame.Show()
885    app.MainLoop()
886
887class Explore(object):
888    """
889    Bumps wrapper for a SAS model comparison.
890
891    The resulting object can be used as a Bumps fit problem so that
892    parameters can be adjusted in the GUI, with plots updated on the fly.
893    """
894    def __init__(self, opts):
895        from bumps.cli import config_matplotlib
896        from . import bumps_model
897        config_matplotlib()
898        self.opts = opts
899        model_info = opts['def']
900        pars, pd_types = bumps_model.create_parameters(model_info, **opts['pars'])
901        if not opts['is2d']:
902            active = [base + ext
903                      for base in model_info['partype']['pd-1d']
904                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
905            active.extend(model_info['partype']['fixed-1d'])
906            for k in active:
907                v = pars[k]
908                v.range(*parameter_range(k, v.value))
909        else:
910            for k, v in pars.items():
911                v.range(*parameter_range(k, v.value))
912
913        self.pars = pars
914        self.pd_types = pd_types
915        self.limits = None
916
917    def numpoints(self):
918        """
919        Return the number of points.
920        """
921        return len(self.pars) + 1  # so dof is 1
922
923    def parameters(self):
924        """
925        Return a dictionary of parameters.
926        """
927        return self.pars
928
929    def nllf(self):
930        """
931        Return cost.
932        """
933        # pylint: disable=no-self-use
934        return 0.  # No nllf
935
936    def plot(self, view='log'):
937        """
938        Plot the data and residuals.
939        """
940        pars = dict((k, v.value) for k, v in self.pars.items())
941        pars.update(self.pd_types)
942        self.opts['pars'] = pars
943        limits = compare(self.opts, limits=self.limits)
944        if self.limits is None:
945            vmin, vmax = limits
946            vmax = 1.3*vmax
947            vmin = vmax*1e-7
948            self.limits = vmin, vmax
949
950
951def main():
952    """
953    Main program.
954    """
955    opts = parse_opts()
956    if opts['explore']:
957        explore(opts)
958    else:
959        compare(opts)
960
961if __name__ == "__main__":
962    main()
Note: See TracBrowser for help on using the repository browser.