source: sasmodels/sasmodels/compare.py @ d15a908

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since d15a908 was d15a908, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

doc and delint

  • Property mode set to 100755
File size: 27.1 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31USAGE = """
32usage: compare.py model N1 N2 [options...] [key=val]
33
34Compare the speed and value for a model between the SasView original and the
35sasmodels rewrite.
36
37model is the name of the model to compare (see below).
38N1 is the number of times to run sasmodels (default=1).
39N2 is the number times to run sasview (default=1).
40
41Options (* for default):
42
43    -plot*/-noplot plots or suppress the plot of the model
44    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
45    -nq=128 sets the number of Q points in the data set
46    -1d*/-2d computes 1d or 2d data
47    -preset*/-random[=seed] preset or random parameters
48    -mono/-poly* force monodisperse/polydisperse
49    -cutoff=1e-5* cutoff value for including a point in polydispersity
50    -pars/-nopars* prints the parameter set or not
51    -abs/-rel* plot relative or absolute error
52    -linear/-log*/-q4 intensity scaling
53    -hist/-nohist* plot histogram of relative error
54    -res=0 sets the resolution width dQ/Q if calculating with resolution
55    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
56    -edit starts the parameter explorer
57
58Any two calculation engines can be selected for comparison:
59
60    -single/-double/-half/-fast sets an OpenCL calculation engine
61    -single!/-double!/-quad! sets an OpenMP calculation engine
62    -sasview sets the sasview calculation engine
63
64The default is -single -sasview.  Note that the interpretation of quad
65precision depends on architecture, and may vary from 64-bit to 128-bit,
66with 80-bit floats being common (1e-19 precision).
67
68Key=value pairs allow you to set specific values for the model parameters.
69"""
70
71# Update docs with command line usage string.   This is separate from the usual
72# doc string so that we can display it at run time if there is an error.
73# lin
74__doc__ = (__doc__  # pylint: disable=redefined-builtin
75           + """
76Program description
77-------------------
78
79"""
80           + USAGE)
81
82
83
84import sys
85import math
86from os.path import basename, dirname, join as joinpath
87import glob
88import datetime
89import traceback
90
91import numpy as np
92
93ROOT = dirname(__file__)
94sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
95
96
97from . import core
98from . import kerneldll
99from . import generate
100from .data import plot_theory, empty_data1D, empty_data2D
101from .direct_model import DirectModel
102from .convert import revert_model, constrain_new_to_old
103kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
104
105# List of available models
106MODELS = [basename(f)[:-3]
107          for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))]
108
109# CRUFT python 2.6
110if not hasattr(datetime.timedelta, 'total_seconds'):
111    def delay(dt):
112        """Return number date-time delta as number seconds"""
113        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
114else:
115    def delay(dt):
116        """Return number date-time delta as number seconds"""
117        return dt.total_seconds()
118
119
120def tic():
121    """
122    Timer function.
123
124    Use "toc=tic()" to start the clock and "toc()" to measure
125    a time interval.
126    """
127    then = datetime.datetime.now()
128    return lambda: delay(datetime.datetime.now() - then)
129
130
131def set_beam_stop(data, radius, outer=None):
132    """
133    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
134
135    Note: this function does not use the sasview package
136    """
137    if hasattr(data, 'qx_data'):
138        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
139        data.mask = (q < radius)
140        if outer is not None:
141            data.mask |= (q >= outer)
142    else:
143        data.mask = (data.x < radius)
144        if outer is not None:
145            data.mask |= (data.x >= outer)
146
147
148def parameter_range(p, v):
149    """
150    Choose a parameter range based on parameter name and initial value.
151    """
152    if p.endswith('_pd_n'):
153        return [0, 100]
154    elif p.endswith('_pd_nsigma'):
155        return [0, 5]
156    elif p.endswith('_pd_type'):
157        return v
158    elif any(s in p for s in ('theta', 'phi', 'psi')):
159        # orientation in [-180,180], orientation pd in [0,45]
160        if p.endswith('_pd'):
161            return [0, 45]
162        else:
163            return [-180, 180]
164    elif 'sld' in p:
165        return [-0.5, 10]
166    elif p.endswith('_pd'):
167        return [0, 1]
168    elif p == 'background':
169        return [0, 10]
170    elif p == 'scale':
171        return [0, 1e3]
172    elif p == 'case_num':
173        # RPA hack
174        return [0, 10]
175    elif v < 0:
176        # Kxy parameters in rpa model can be negative
177        return [2*v, -2*v]
178    else:
179        return [0, (2*v if v > 0 else 1)]
180
181def _randomize_one(p, v):
182    """
183    Randomize a single parameter.
184    """
185    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
186        return v
187    else:
188        return np.random.uniform(*parameter_range(p, v))
189
190def randomize_pars(pars, seed=None):
191    """
192    Generate random values for all of the parameters.
193
194    Valid ranges for the random number generator are guessed from the name of
195    the parameter; this will not account for constraints such as cap radius
196    greater than cylinder radius in the capped_cylinder model, so
197    :func:`constrain_pars` needs to be called afterward..
198    """
199    np.random.seed(seed)
200    # Note: the sort guarantees order `of calls to random number generator
201    pars = dict((p, _randomize_one(p, v))
202                for p, v in sorted(pars.items()))
203    return pars
204
205def constrain_pars(model_definition, pars):
206    """
207    Restrict parameters to valid values.
208
209    This includes model specific code for models such as capped_cylinder
210    which need to support within model constraints (cap radius more than
211    cylinder radius in this case).
212    """
213    name = model_definition.name
214    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
215        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
216    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
217        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
218
219    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
220    if name == 'guinier':
221        #q_max = 0.2  # mid q maximum
222        q_max = 1.0  # high q maximum
223        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
224        pars['rg'] = min(pars['rg'], rg_max)
225
226    if name == 'rpa':
227        # Make sure phi sums to 1.0
228        if pars['case_num'] < 2:
229            pars['Phia'] = 0.
230            pars['Phib'] = 0.
231        elif pars['case_num'] < 5:
232            pars['Phia'] = 0.
233        total = sum(pars['Phi'+c] for c in 'abcd')
234        for c in 'abcd':
235            pars['Phi'+c] /= total
236
237def parlist(pars):
238    """
239    Format the parameter list for printing.
240    """
241    return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
242
243def suppress_pd(pars):
244    """
245    Suppress theta_pd for now until the normalization is resolved.
246
247    May also suppress complete polydispersity of the model to test
248    models more quickly.
249    """
250    pars = pars.copy()
251    for p in pars:
252        if p.endswith("_pd_n"): pars[p] = 0
253    return pars
254
255def eval_sasview(model_definition, data):
256    """
257    Return a model calculator using the SasView fitting engine.
258    """
259    # importing sas here so that the error message will be that sas failed to
260    # import rather than the more obscure smear_selection not imported error
261    import sas
262    from sas.models.qsmearing import smear_selection
263
264    # convert model parameters from sasmodel form to sasview form
265    #print("old",sorted(pars.items()))
266    modelname, _ = revert_model(model_definition, {})
267    #print("new",sorted(_pars.items()))
268    sas = __import__('sas.models.'+modelname)
269    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None)
270    if ModelClass is None:
271        raise ValueError("could not find model %r in sas.models"%modelname)
272    model = ModelClass()
273    smearer = smear_selection(data, model=model)
274
275    if hasattr(data, 'qx_data'):
276        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
277        index = ((~data.mask) & (~np.isnan(data.data))
278                 & (q >= data.qmin) & (q <= data.qmax))
279        if smearer is not None:
280            smearer.model = model  # because smear_selection has a bug
281            smearer.accuracy = data.accuracy
282            smearer.set_index(index)
283            theory = lambda: smearer.get_value()
284        else:
285            theory = lambda: model.evalDistribution([data.qx_data[index],
286                                                     data.qy_data[index]])
287    elif smearer is not None:
288        theory = lambda: smearer(model.evalDistribution(data.x))
289    else:
290        theory = lambda: model.evalDistribution(data.x)
291
292    def calculator(**pars):
293        """
294        Sasview calculator for model.
295        """
296        # paying for parameter conversion each time to keep life simple, if not fast
297        _, pars = revert_model(model_definition, pars)
298        for k, v in pars.items():
299            parts = k.split('.')  # polydispersity components
300            if len(parts) == 2:
301                model.dispersion[parts[0]][parts[1]] = v
302            else:
303                model.setParam(k, v)
304        return theory()
305
306    calculator.engine = "sasview"
307    return calculator
308
309DTYPE_MAP = {
310    'half': '16',
311    'fast': 'fast',
312    'single': '32',
313    'double': '64',
314    'quad': '128',
315    'f16': '16',
316    'f32': '32',
317    'f64': '64',
318    'longdouble': '128',
319}
320def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
321    """
322    Return a model calculator using the OpenCL calculation engine.
323    """
324    try:
325        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
326    except Exception as exc:
327        print(exc)
328        print("... trying again with single precision")
329        dtype = 'single'
330        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
331    calculator = DirectModel(data, model, cutoff=cutoff)
332    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
333    return calculator
334
335def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
336    """
337    Return a model calculator using the DLL calculation engine.
338    """
339    if dtype == 'quad':
340        dtype = 'longdouble'
341    model = core.load_model(model_definition, dtype=dtype, platform="dll")
342    calculator = DirectModel(data, model, cutoff=cutoff)
343    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
344    return calculator
345
346def time_calculation(calculator, pars, Nevals=1):
347    """
348    Compute the average calculation time over N evaluations.
349
350    An additional call is generated without polydispersity in order to
351    initialize the calculation engine, and make the average more stable.
352    """
353    # initialize the code so time is more accurate
354    value = calculator(**suppress_pd(pars))
355    toc = tic()
356    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
357        value = calculator(**pars)
358    average_time = toc()*1000./Nevals
359    return value, average_time
360
361def make_data(opts):
362    """
363    Generate an empty dataset, used with the model to set Q points
364    and resolution.
365
366    *opts* contains the options, with 'qmax', 'nq', 'res',
367    'accuracy', 'is2d' and 'view' parsed from the command line.
368    """
369    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
370    if opts['is2d']:
371        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
372        data.accuracy = opts['accuracy']
373        set_beam_stop(data, 0.004)
374        index = ~data.mask
375    else:
376        if opts['view'] == 'log':
377            qmax = math.log10(qmax)
378            q = np.logspace(qmax-3, qmax, nq)
379        else:
380            q = np.linspace(0.001*qmax, qmax, nq)
381        data = empty_data1D(q, resolution=res)
382        index = slice(None, None)
383    return data, index
384
385def make_engine(model_definition, data, dtype, cutoff):
386    """
387    Generate the appropriate calculation engine for the given datatype.
388
389    Datatypes with '!' appended are evaluated using external C DLLs rather
390    than OpenCL.
391    """
392    if dtype == 'sasview':
393        return eval_sasview(model_definition, data)
394    elif dtype.endswith('!'):
395        return eval_ctypes(model_definition, data, dtype=dtype[:-1],
396                           cutoff=cutoff)
397    else:
398        return eval_opencl(model_definition, data, dtype=dtype,
399                           cutoff=cutoff)
400
401def compare(opts, limits=None):
402    """
403    Preform a comparison using options from the command line.
404
405    *limits* are the limits on the values to use, either to set the y-axis
406    for 1D or to set the colormap scale for 2D.  If None, then they are
407    inferred from the data and returned. When exploring using Bumps,
408    the limits are set when the model is initially called, and maintained
409    as the values are adjusted, making it easier to see the effects of the
410    parameters.
411    """
412    Nbase, Ncomp = opts['n1'], opts['n2']
413    pars = opts['pars']
414    data = opts['data']
415
416    # Base calculation
417    if Nbase > 0:
418        base = opts['engines'][0]
419        try:
420            base_value, base_time = time_calculation(base, pars, Nbase)
421            print("%s t=%.1f ms, intensity=%.0f"
422                  % (base.engine, base_time, sum(base_value)))
423        except ImportError:
424            traceback.print_exc()
425            Nbase = 0
426
427    # Comparison calculation
428    if Ncomp > 0:
429        comp = opts['engines'][1]
430        try:
431            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
432            print("%s t=%.1f ms, intensity=%.0f"
433                  % (comp.engine, comp_time, sum(comp_value)))
434        except ImportError:
435            traceback.print_exc()
436            Ncomp = 0
437
438    # Compare, but only if computing both forms
439    if Nbase > 0 and Ncomp > 0:
440        resid = (base_value - comp_value)
441        relerr = resid/comp_value
442        _print_stats("|%s-%s|"
443                     % (base.engine, comp.engine) + (" "*(3+len(comp.engine))),
444                     resid)
445        _print_stats("|(%s-%s)/%s|"
446                     % (base.engine, comp.engine, comp.engine),
447                     relerr)
448
449    # Plot if requested
450    if not opts['plot'] and not opts['explore']: return
451    view = opts['view']
452    import matplotlib.pyplot as plt
453    if limits is None:
454        vmin, vmax = np.Inf, -np.Inf
455        if Nbase > 0:
456            vmin = min(vmin, min(base_value))
457            vmax = max(vmax, max(base_value))
458        if Ncomp > 0:
459            vmin = min(vmin, min(comp_value))
460            vmax = max(vmax, max(comp_value))
461        limits = vmin, vmax
462
463    if Nbase > 0:
464        if Ncomp > 0: plt.subplot(131)
465        plot_theory(data, base_value, view=view, plot_data=False, limits=limits)
466        plt.title("%s t=%.1f ms"%(base.engine, base_time))
467        #cbar_title = "log I"
468    if Ncomp > 0:
469        if Nbase > 0: plt.subplot(132)
470        plot_theory(data, comp_value, view=view, plot_data=False, limits=limits)
471        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
472        #cbar_title = "log I"
473    if Ncomp > 0 and Nbase > 0:
474        plt.subplot(133)
475        if '-abs' in opts:
476            err, errstr, errview = resid, "abs err", "linear"
477        else:
478            err, errstr, errview = abs(relerr), "rel err", "log"
479        #err,errstr = base/comp,"ratio"
480        plot_theory(data, None, resid=err, view=errview, plot_data=False)
481        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
482        #cbar_title = errstr if errview=="linear" else "log "+errstr
483    #if is2D:
484    #    h = plt.colorbar()
485    #    h.ax.set_title(cbar_title)
486
487    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
488        plt.figure()
489        v = relerr
490        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
491        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
492        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
493                   % (base.engine, comp.engine, comp.engine))
494        plt.ylabel('P(err)')
495        plt.title('Distribution of relative error between calculation engines')
496
497    if not opts['explore']:
498        plt.show()
499
500    return limits
501
502def _print_stats(label, err):
503    sorted_err = np.sort(abs(err))
504    p50 = int((len(err)-1)*0.50)
505    p98 = int((len(err)-1)*0.98)
506    data = [
507        "max:%.3e"%sorted_err[-1],
508        "median:%.3e"%sorted_err[p50],
509        "98%%:%.3e"%sorted_err[p98],
510        "rms:%.3e"%np.sqrt(np.mean(err**2)),
511        "zero-offset:%+.3e"%np.mean(err),
512        ]
513    print(label+"  "+"  ".join(data))
514
515
516
517# ===========================================================================
518#
519NAME_OPTIONS = set([
520    'plot', 'noplot',
521    'half', 'fast', 'single', 'double',
522    'single!', 'double!', 'quad!', 'sasview',
523    'lowq', 'midq', 'highq', 'exq',
524    '2d', '1d',
525    'preset', 'random',
526    'poly', 'mono',
527    'nopars', 'pars',
528    'rel', 'abs',
529    'linear', 'log', 'q4',
530    'hist', 'nohist',
531    'edit',
532    ])
533VALUE_OPTIONS = [
534    # Note: random is both a name option and a value option
535    'cutoff', 'random', 'nq', 'res', 'accuracy',
536    ]
537
538def columnize(L, indent="", width=79):
539    """
540    Format a list of strings into columns for printing.
541    """
542    column_width = max(len(w) for w in L) + 1
543    num_columns = (width - len(indent)) // column_width
544    num_rows = len(L) // num_columns
545    L = L + [""] * (num_rows*num_columns - len(L))
546    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
547    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
548             for row in zip(*columns)]
549    output = indent + ("\n"+indent).join(lines)
550    return output
551
552
553def get_demo_pars(model_definition):
554    """
555    Extract demo parameters from the model definition.
556    """
557    info = generate.make_info(model_definition)
558    # Get the default values for the parameters
559    pars = dict((p[0], p[2]) for p in info['parameters'])
560
561    # Fill in default values for the polydispersity parameters
562    for p in info['parameters']:
563        if p[4] in ('volume', 'orientation'):
564            pars[p[0]+'_pd'] = 0.0
565            pars[p[0]+'_pd_n'] = 0
566            pars[p[0]+'_pd_nsigma'] = 3.0
567            pars[p[0]+'_pd_type'] = "gaussian"
568
569    # Plug in values given in demo
570    pars.update(info['demo'])
571    return pars
572
573def parse_opts():
574    """
575    Parse command line options.
576    """
577    flags = [arg for arg in sys.argv[1:]
578             if arg.startswith('-')]
579    values = [arg for arg in sys.argv[1:]
580              if not arg.startswith('-') and '=' in arg]
581    args = [arg for arg in sys.argv[1:]
582            if not arg.startswith('-') and '=' not in arg]
583    models = "\n    ".join("%-15s"%v for v in MODELS)
584    if len(args) == 0:
585        print(USAGE)
586        print("\nAvailable models:")
587        print(columnize(MODELS, indent="  "))
588        sys.exit(1)
589    if args[0] not in MODELS:
590        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
591        sys.exit(1)
592    if len(args) > 3:
593        print("expected parameters: model N1 N2")
594
595    invalid = [o[1:] for o in flags
596               if o[1:] not in NAME_OPTIONS
597               and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
598    if invalid:
599        print("Invalid options: %s"%(", ".join(invalid)))
600        sys.exit(1)
601
602
603    # pylint: disable=bad-whitespace
604    # Interpret the flags
605    opts = {
606        'plot'      : True,
607        'view'      : 'log',
608        'is2d'      : False,
609        'qmax'      : 0.05,
610        'nq'        : 128,
611        'res'       : 0.0,
612        'accuracy'  : 'Low',
613        'cutoff'    : 1e-5,
614        'seed'      : -1,  # default to preset
615        'mono'      : False,
616        'show_pars' : False,
617        'show_hist' : False,
618        'rel_err'   : True,
619        'explore'   : False,
620    }
621    engines = []
622    for arg in flags:
623        if arg == '-noplot':    opts['plot'] = False
624        elif arg == '-plot':    opts['plot'] = True
625        elif arg == '-linear':  opts['view'] = 'linear'
626        elif arg == '-log':     opts['view'] = 'log'
627        elif arg == '-q4':      opts['view'] = 'q4'
628        elif arg == '-1d':      opts['is2d'] = False
629        elif arg == '-2d':      opts['is2d'] = True
630        elif arg == '-exq':     opts['qmax'] = 10.0
631        elif arg == '-highq':   opts['qmax'] = 1.0
632        elif arg == '-midq':    opts['qmax'] = 0.2
633        elif arg == '-loq':     opts['qmax'] = 0.05
634        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
635        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
636        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
637        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
638        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
639        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
640        elif arg == '-preset':  opts['seed'] = -1
641        elif arg == '-mono':    opts['mono'] = True
642        elif arg == '-poly':    opts['mono'] = False
643        elif arg == '-pars':    opts['show_pars'] = True
644        elif arg == '-nopars':  opts['show_pars'] = False
645        elif arg == '-hist':    opts['show_hist'] = True
646        elif arg == '-nohist':  opts['show_hist'] = False
647        elif arg == '-rel':     opts['rel_err'] = True
648        elif arg == '-abs':     opts['rel_err'] = False
649        elif arg == '-half':    engines.append(arg[1:])
650        elif arg == '-fast':    engines.append(arg[1:])
651        elif arg == '-single':  engines.append(arg[1:])
652        elif arg == '-double':  engines.append(arg[1:])
653        elif arg == '-single!': engines.append(arg[1:])
654        elif arg == '-double!': engines.append(arg[1:])
655        elif arg == '-quad!':   engines.append(arg[1:])
656        elif arg == '-sasview': engines.append(arg[1:])
657        elif arg == '-edit':    opts['explore'] = True
658    # pylint: enable=bad-whitespace
659
660    if len(engines) == 0:
661        engines.extend(['single', 'sasview'])
662    elif len(engines) == 1:
663        if engines[0][0] != 'sasview':
664            engines.append('sasview')
665        else:
666            engines.append('single')
667    elif len(engines) > 2:
668        del engines[2:]
669
670    name = args[0]
671    model_definition = core.load_model_definition(name)
672
673    n1 = int(args[1]) if len(args) > 1 else 1
674    n2 = int(args[2]) if len(args) > 2 else 1
675
676    # Get demo parameters from model definition, or use default parameters
677    # if model does not define demo parameters
678    pars = get_demo_pars(model_definition)
679
680    # Fill in parameters given on the command line
681    presets = {}
682    for arg in values:
683        k, v = arg.split('=', 1)
684        if k not in pars:
685            # extract base name without polydispersity info
686            s = set(p.split('_pd')[0] for p in pars)
687            print("%r invalid; parameters are: %s"%(k, ", ".join(sorted(s))))
688            sys.exit(1)
689        presets[k] = float(v) if not k.endswith('type') else v
690
691    # randomize parameters
692    #pars.update(set_pars)  # set value before random to control range
693    if opts['seed'] > -1:
694        pars = randomize_pars(pars, seed=opts['seed'])
695        print("Randomize using -random=%i"%opts['seed'])
696    if opts['mono']:
697        pars = suppress_pd(pars)
698    pars.update(presets)  # set value after random to control value
699    constrain_pars(model_definition, pars)
700    constrain_new_to_old(model_definition, pars)
701    if opts['show_pars']:
702        print("pars " + str(parlist(pars)))
703
704    # Create the computational engines
705    data, _ = make_data(opts)
706    if n1:
707        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
708    else:
709        base = None
710    if n2:
711        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
712    else:
713        comp = None
714
715    # pylint: disable=bad-whitespace
716    # Remember it all
717    opts.update({
718        'name'      : name,
719        'def'       : model_definition,
720        'n1'        : n1,
721        'n2'        : n2,
722        'presets'   : presets,
723        'pars'      : pars,
724        'data'      : data,
725        'engines'   : [base, comp],
726    })
727    # pylint: enable=bad-whitespace
728
729    return opts
730
731def explore(opts):
732    """
733    Explore the model using the Bumps GUI.
734    """
735    import wx
736    from bumps.names import FitProblem
737    from bumps.gui.app_frame import AppFrame
738
739    problem = FitProblem(Explore(opts))
740    is_mac = "cocoa" in wx.version()
741    app = wx.App()
742    frame = AppFrame(parent=None, title="explore")
743    if not is_mac: frame.Show()
744    frame.panel.set_model(model=problem)
745    frame.panel.Layout()
746    frame.panel.aui.Split(0, wx.TOP)
747    if is_mac: frame.Show()
748    app.MainLoop()
749
750class Explore(object):
751    """
752    Bumps wrapper for a SAS model comparison.
753
754    The resulting object can be used as a Bumps fit problem so that
755    parameters can be adjusted in the GUI, with plots updated on the fly.
756    """
757    def __init__(self, opts):
758        from bumps.cli import config_matplotlib
759        from . import bumps_model
760        config_matplotlib()
761        self.opts = opts
762        info = generate.make_info(opts['def'])
763        pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])
764        if not opts['is2d']:
765            active = [base + ext
766                      for base in info['partype']['pd-1d']
767                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
768            active.extend(info['partype']['fixed-1d'])
769            for k in active:
770                v = pars[k]
771                v.range(*parameter_range(k, v.value))
772        else:
773            for k, v in pars.items():
774                v.range(*parameter_range(k, v.value))
775
776        self.pars = pars
777        self.pd_types = pd_types
778        self.limits = None
779
780    def numpoints(self):
781        """
782        Return the number of points.
783        """
784        return len(self.pars) + 1  # so dof is 1
785
786    def parameters(self):
787        """
788        Return a dictionary of parameters.
789        """
790        return self.pars
791
792    def nllf(self):
793        """
794        Return cost.
795        """
796        # pylint: disable=no-self-use
797        return 0.  # No nllf
798
799    def plot(self, view='log'):
800        """
801        Plot the data and residuals.
802        """
803        pars = dict((k, v.value) for k, v in self.pars.items())
804        pars.update(self.pd_types)
805        self.opts['pars'] = pars
806        limits = compare(self.opts, limits=self.limits)
807        if self.limits is None:
808            vmin, vmax = limits
809            vmax = 1.3*vmax
810            vmin = vmax*1e-7
811            self.limits = vmin, vmax
812
813
814def main():
815    """
816    Main program.
817    """
818    opts = parse_opts()
819    if opts['explore']:
820        explore(opts)
821    else:
822        compare(opts)
823
824if __name__ == "__main__":
825    main()
Note: See TracBrowser for help on using the repository browser.