source: sasmodels/sasmodels/compare.py @ 608e31e

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 608e31e was 608e31e, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

reduce lint

  • Property mode set to 100755
File size: 26.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Program to compare models using different compute engines.
5
6This program lets you compare results between OpenCL and DLL versions
7of the code and between precision (half, fast, single, double, quad),
8where fast precision is single precision using native functions for
9trig, etc., and may not be completely IEEE 754 compliant.  This lets
10make sure that the model calculations are stable, or if you need to
11tag the model as double precision only.
12
13Run using ./compare.sh (Linux, Mac) or compare.bat (Windows) in the
14sasmodels root to see the command line options.
15
16Note that there is no way within sasmodels to select between an
17OpenCL CPU device and a GPU device, but you can do so by setting the
18PYOPENCL_CTX environment variable ahead of time.  Start a python
19interpreter and enter::
20
21    import pyopencl as cl
22    cl.create_some_context()
23
24This will prompt you to select from the available OpenCL devices
25and tell you which string to use for the PYOPENCL_CTX variable.
26On Windows you will need to remove the quotes.
27"""
28
29from __future__ import print_function
30
31USAGE = """
32usage: compare.py model N1 N2 [options...] [key=val]
33
34Compare the speed and value for a model between the SasView original and the
35sasmodels rewrite.
36
37model is the name of the model to compare (see below).
38N1 is the number of times to run sasmodels (default=1).
39N2 is the number times to run sasview (default=1).
40
41Options (* for default):
42
43    -plot*/-noplot plots or suppress the plot of the model
44    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
45    -nq=128 sets the number of Q points in the data set
46    -1d*/-2d computes 1d or 2d data
47    -preset*/-random[=seed] preset or random parameters
48    -mono/-poly* force monodisperse/polydisperse
49    -cutoff=1e-5* cutoff value for including a point in polydispersity
50    -pars/-nopars* prints the parameter set or not
51    -abs/-rel* plot relative or absolute error
52    -linear/-log*/-q4 intensity scaling
53    -hist/-nohist* plot histogram of relative error
54    -res=0 sets the resolution width dQ/Q if calculating with resolution
55    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
56    -edit starts the parameter explorer
57
58Any two calculation engines can be selected for comparison:
59
60    -single/-double/-half/-fast sets an OpenCL calculation engine
61    -single!/-double!/-quad! sets an OpenMP calculation engine
62    -sasview sets the sasview calculation engine
63
64The default is -single -sasview.  Note that the interpretation of quad
65precision depends on architecture, and may vary from 64-bit to 128-bit,
66with 80-bit floats being common (1e-19 precision).
67
68Key=value pairs allow you to set specific values for the model parameters.
69"""
70
71# Update docs with command line usage string.   This is separate from the usual
72# doc string so that we can display it at run time if there is an error.
73# lin
74__doc__ = __doc__ + """
75Program description
76-------------------
77
78""" + USAGE
79
80
81
82import sys
83import math
84from os.path import basename, dirname, join as joinpath
85import glob
86import datetime
87import traceback
88
89import numpy as np
90
91ROOT = dirname(__file__)
92sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
93
94
95from . import core
96from . import kerneldll
97from . import generate
98from .data import plot_theory, empty_data1D, empty_data2D
99from .direct_model import DirectModel
100from .convert import revert_model, constrain_new_to_old
101kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
102
103# List of available models
104MODELS = [basename(f)[:-3]
105          for f in sorted(glob.glob(joinpath(ROOT, "models", "[a-zA-Z]*.py")))]
106
107# CRUFT python 2.6
108if not hasattr(datetime.timedelta, 'total_seconds'):
109    def delay(dt):
110        """Return number date-time delta as number seconds"""
111        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
112else:
113    def delay(dt):
114        """Return number date-time delta as number seconds"""
115        return dt.total_seconds()
116
117
118def tic():
119    """
120    Timer function.
121
122    Use "toc=tic()" to start the clock and "toc()" to measure
123    a time interval.
124    """
125    then = datetime.datetime.now()
126    return lambda: delay(datetime.datetime.now() - then)
127
128
129def set_beam_stop(data, radius, outer=None):
130    """
131    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
132
133    Note: this function does not use the sasview package
134    """
135    if hasattr(data, 'qx_data'):
136        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
137        data.mask = (q < radius)
138        if outer is not None:
139            data.mask |= (q >= outer)
140    else:
141        data.mask = (data.x < radius)
142        if outer is not None:
143            data.mask |= (data.x >= outer)
144
145
146def parameter_range(p, v):
147    """
148    Choose a parameter range based on parameter name and initial value.
149    """
150    if p.endswith('_pd_n'):
151        return [0, 100]
152    elif p.endswith('_pd_nsigma'):
153        return [0, 5]
154    elif p.endswith('_pd_type'):
155        return v
156    elif any(s in p for s in ('theta', 'phi', 'psi')):
157        # orientation in [-180,180], orientation pd in [0,45]
158        if p.endswith('_pd'):
159            return [0, 45]
160        else:
161            return [-180, 180]
162    elif 'sld' in p:
163        return [-0.5, 10]
164    elif p.endswith('_pd'):
165        return [0, 1]
166    elif p == 'background':
167        return [0, 10]
168    elif p == 'scale':
169        return [0, 1e3]
170    elif p == 'case_num':
171        # RPA hack
172        return [0, 10]
173    elif v < 0:
174        # Kxy parameters in rpa model can be negative
175        return [2*v, -2*v]
176    else:
177        return [0, (2*v if v > 0 else 1)]
178
179def _randomize_one(p, v):
180    """
181    Randomize a single parameter.
182    """
183    if any(p.endswith(s) for s in ('_pd_n', '_pd_nsigma', '_pd_type')):
184        return v
185    else:
186        return np.random.uniform(*parameter_range(p, v))
187
188def randomize_pars(pars, seed=None):
189    """
190    Generate random values for all of the parameters.
191
192    Valid ranges for the random number generator are guessed from the name of
193    the parameter; this will not account for constraints such as cap radius
194    greater than cylinder radius in the capped_cylinder model, so
195    :func:`constrain_pars` needs to be called afterward..
196    """
197    np.random.seed(seed)
198    # Note: the sort guarantees order `of calls to random number generator
199    pars = dict((p, _randomize_one(p, v))
200                for p, v in sorted(pars.items()))
201    return pars
202
203def constrain_pars(model_definition, pars):
204    """
205    Restrict parameters to valid values.
206
207    This includes model specific code for models such as capped_cylinder
208    which need to support within model constraints (cap radius more than
209    cylinder radius in this case).
210    """
211    name = model_definition.name
212    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
213        pars['radius'], pars['cap_radius'] = pars['cap_radius'], pars['radius']
214    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
215        pars['radius'], pars['bell_radius'] = pars['bell_radius'], pars['radius']
216
217    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
218    if name == 'guinier':
219        #q_max = 0.2  # mid q maximum
220        q_max = 1.0  # high q maximum
221        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
222        pars['rg'] = min(pars['rg'], rg_max)
223
224    if name == 'rpa':
225        # Make sure phi sums to 1.0
226        if pars['case_num'] < 2:
227            pars['Phia'] = 0.
228            pars['Phib'] = 0.
229        elif pars['case_num'] < 5:
230            pars['Phia'] = 0.
231        total = sum(pars['Phi'+c] for c in 'abcd')
232        for c in 'abcd':
233            pars['Phi'+c] /= total
234
235def parlist(pars):
236    """
237    Format the parameter list for printing.
238    """
239    return "\n".join("%s: %s"%(p, v) for p, v in sorted(pars.items()))
240
241def suppress_pd(pars):
242    """
243    Suppress theta_pd for now until the normalization is resolved.
244
245    May also suppress complete polydispersity of the model to test
246    models more quickly.
247    """
248    pars = pars.copy()
249    for p in pars:
250        if p.endswith("_pd_n"): pars[p] = 0
251    return pars
252
253def eval_sasview(model_definition, data):
254    """
255    Return a model calculator using the SasView fitting engine.
256    """
257    # importing sas here so that the error message will be that sas failed to
258    # import rather than the more obscure smear_selection not imported error
259    import sas
260    from sas.models.qsmearing import smear_selection
261
262    # convert model parameters from sasmodel form to sasview form
263    #print("old",sorted(pars.items()))
264    modelname, _ = revert_model(model_definition, {})
265    #print("new",sorted(_pars.items()))
266    sas = __import__('sas.models.'+modelname)
267    ModelClass = getattr(getattr(sas.models, modelname, None), modelname, None)
268    if ModelClass is None:
269        raise ValueError("could not find model %r in sas.models"%modelname)
270    model = ModelClass()
271    smearer = smear_selection(data, model=model)
272
273    if hasattr(data, 'qx_data'):
274        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
275        index = ((~data.mask) & (~np.isnan(data.data))
276                 & (q >= data.qmin) & (q <= data.qmax))
277        if smearer is not None:
278            smearer.model = model  # because smear_selection has a bug
279            smearer.accuracy = data.accuracy
280            smearer.set_index(index)
281            theory = lambda: smearer.get_value()
282        else:
283            theory = lambda: model.evalDistribution([data.qx_data[index], data.qy_data[index]])
284    elif smearer is not None:
285        theory = lambda: smearer(model.evalDistribution(data.x))
286    else:
287        theory = lambda: model.evalDistribution(data.x)
288
289    def calculator(**pars):
290        """
291        Sasview calculator for model.
292        """
293        # paying for parameter conversion each time to keep life simple, if not fast
294        _, pars = revert_model(model_definition, pars)
295        for k, v in pars.items():
296            parts = k.split('.')  # polydispersity components
297            if len(parts) == 2:
298                model.dispersion[parts[0]][parts[1]] = v
299            else:
300                model.setParam(k, v)
301        return theory()
302
303    calculator.engine = "sasview"
304    return calculator
305
306DTYPE_MAP = {
307    'half': '16',
308    'fast': 'fast',
309    'single': '32',
310    'double': '64',
311    'quad': '128',
312    'f16': '16',
313    'f32': '32',
314    'f64': '64',
315    'longdouble': '128',
316}
317def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
318    """
319    Return a model calculator using the OpenCL calculation engine.
320    """
321    try:
322        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
323    except Exception as exc:
324        print(exc)
325        print("... trying again with single precision")
326        dtype = 'single'
327        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
328    calculator = DirectModel(data, model, cutoff=cutoff)
329    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
330    return calculator
331
332def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
333    """
334    Return a model calculator using the DLL calculation engine.
335    """
336    if dtype == 'quad':
337        dtype = 'longdouble'
338    model = core.load_model(model_definition, dtype=dtype, platform="dll")
339    calculator = DirectModel(data, model, cutoff=cutoff)
340    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
341    return calculator
342
343def time_calculation(calculator, pars, Nevals=1):
344    """
345    Compute the average calculation time over N evaluations.
346
347    An additional call is generated without polydispersity in order to
348    initialize the calculation engine, and make the average more stable.
349    """
350    # initialize the code so time is more accurate
351    value = calculator(**suppress_pd(pars))
352    toc = tic()
353    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
354        value = calculator(**pars)
355    average_time = toc()*1000./Nevals
356    return value, average_time
357
358def make_data(opts):
359    """
360    Generate an empty dataset, used with the model to set Q points
361    and resolution.
362
363    *opts* contains the options, with 'qmax', 'nq', 'res',
364    'accuracy', 'is2d' and 'view' parsed from the command line.
365    """
366    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
367    if opts['is2d']:
368        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
369        data.accuracy = opts['accuracy']
370        set_beam_stop(data, 0.004)
371        index = ~data.mask
372    else:
373        if opts['view'] == 'log':
374            qmax = math.log10(qmax)
375            q = np.logspace(qmax-3, qmax, nq)
376        else:
377            q = np.linspace(0.001*qmax, qmax, nq)
378        data = empty_data1D(q, resolution=res)
379        index = slice(None, None)
380    return data, index
381
382def make_engine(model_definition, data, dtype, cutoff):
383    """
384    Generate the appropriate calculation engine for the given datatype.
385
386    Datatypes with '!' appended are evaluated using external C DLLs rather
387    than OpenCL.
388    """
389    if dtype == 'sasview':
390        return eval_sasview(model_definition, data)
391    elif dtype.endswith('!'):
392        return eval_ctypes(model_definition, data, dtype=dtype[:-1],
393                           cutoff=cutoff)
394    else:
395        return eval_opencl(model_definition, data, dtype=dtype,
396                           cutoff=cutoff)
397
398def compare(opts, limits=None):
399    """
400    Preform a comparison using options from the command line.
401
402    *limits* are the limits on the values to use, either to set the y-axis
403    for 1D or to set the colormap scale for 2D.  If None, then they are
404    inferred from the data and returned. When exploring using Bumps,
405    the limits are set when the model is initially called, and maintained
406    as the values are adjusted, making it easier to see the effects of the
407    parameters.
408    """
409    Nbase, Ncomp = opts['n1'], opts['n2']
410    pars = opts['pars']
411    data = opts['data']
412
413    # Base calculation
414    if Nbase > 0:
415        base = opts['engines'][0]
416        try:
417            base_value, base_time = time_calculation(base, pars, Nbase)
418            print("%s t=%.1f ms, intensity=%.0f"%(base.engine, base_time, sum(base_value)))
419        except ImportError:
420            traceback.print_exc()
421            Nbase = 0
422
423    # Comparison calculation
424    if Ncomp > 0:
425        comp = opts['engines'][1]
426        try:
427            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
428            print("%s t=%.1f ms, intensity=%.0f"%(comp.engine, comp_time, sum(comp_value)))
429        except ImportError:
430            traceback.print_exc()
431            Ncomp = 0
432
433    # Compare, but only if computing both forms
434    if Nbase > 0 and Ncomp > 0:
435        #print("speedup %.2g"%(comp_time/base_time))
436        #print("max |base/comp|", max(abs(base_value/comp_value)), "%.15g"%max(abs(base_value)), "%.15g"%max(abs(comp_value)))
437        #comp *= max(base_value/comp_value)
438        resid = (base_value - comp_value)
439        relerr = resid/comp_value
440        _print_stats("|%s-%s|"%(base.engine, comp.engine) + (" "*(3+len(comp.engine))),
441                     resid)
442        _print_stats("|(%s-%s)/%s|"%(base.engine, comp.engine, comp.engine),
443                     relerr)
444
445    # Plot if requested
446    if not opts['plot'] and not opts['explore']: return
447    view = opts['view']
448    import matplotlib.pyplot as plt
449    if limits is None:
450        vmin, vmax = np.Inf, -np.Inf
451        if Nbase > 0:
452            vmin = min(vmin, min(base_value))
453            vmax = max(vmax, max(base_value))
454        if Ncomp > 0:
455            vmin = min(vmin, min(comp_value))
456            vmax = max(vmax, max(comp_value))
457        limits = vmin, vmax
458
459    if Nbase > 0:
460        if Ncomp > 0: plt.subplot(131)
461        plot_theory(data, base_value, view=view, plot_data=False, limits=limits)
462        plt.title("%s t=%.1f ms"%(base.engine, base_time))
463        #cbar_title = "log I"
464    if Ncomp > 0:
465        if Nbase > 0: plt.subplot(132)
466        plot_theory(data, comp_value, view=view, plot_data=False, limits=limits)
467        plt.title("%s t=%.1f ms"%(comp.engine, comp_time))
468        #cbar_title = "log I"
469    if Ncomp > 0 and Nbase > 0:
470        plt.subplot(133)
471        if '-abs' in opts:
472            err, errstr, errview = resid, "abs err", "linear"
473        else:
474            err, errstr, errview = abs(relerr), "rel err", "log"
475        #err,errstr = base/comp,"ratio"
476        plot_theory(data, None, resid=err, view=errview, plot_data=False)
477        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
478        #cbar_title = errstr if errview=="linear" else "log "+errstr
479    #if is2D:
480    #    h = plt.colorbar()
481    #    h.ax.set_title(cbar_title)
482
483    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
484        plt.figure()
485        v = relerr
486        v[v == 0] = 0.5*np.min(np.abs(v[v != 0]))
487        plt.hist(np.log10(np.abs(v)), normed=1, bins=50)
488        plt.xlabel('log10(err), err = |(%s - %s) / %s|'
489                   % (base.engine, comp.engine, comp.engine))
490        plt.ylabel('P(err)')
491        plt.title('Distribution of relative error between calculation engines')
492
493    if not opts['explore']:
494        plt.show()
495
496    return limits
497
498def _print_stats(label, err):
499    sorted_err = np.sort(abs(err))
500    p50 = int((len(err)-1)*0.50)
501    p98 = int((len(err)-1)*0.98)
502    data = [
503        "max:%.3e"%sorted_err[-1],
504        "median:%.3e"%sorted_err[p50],
505        "98%%:%.3e"%sorted_err[p98],
506        "rms:%.3e"%np.sqrt(np.mean(err**2)),
507        "zero-offset:%+.3e"%np.mean(err),
508        ]
509    print(label+"  "+"  ".join(data))
510
511
512
513# ===========================================================================
514#
515NAME_OPTIONS = set([
516    'plot', 'noplot',
517    'half', 'fast', 'single', 'double',
518    'single!', 'double!', 'quad!', 'sasview',
519    'lowq', 'midq', 'highq', 'exq',
520    '2d', '1d',
521    'preset', 'random',
522    'poly', 'mono',
523    'nopars', 'pars',
524    'rel', 'abs',
525    'linear', 'log', 'q4',
526    'hist', 'nohist',
527    'edit',
528    ])
529VALUE_OPTIONS = [
530    # Note: random is both a name option and a value option
531    'cutoff', 'random', 'nq', 'res', 'accuracy',
532    ]
533
534def columnize(L, indent="", width=79):
535    """
536    Format a list of strings into columns for printing.
537    """
538    column_width = max(len(w) for w in L) + 1
539    num_columns = (width - len(indent)) // column_width
540    num_rows = len(L) // num_columns
541    L = L + [""] * (num_rows*num_columns - len(L))
542    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
543    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
544             for row in zip(*columns)]
545    output = indent + ("\n"+indent).join(lines)
546    return output
547
548
549def get_demo_pars(model_definition):
550    """
551    Extract demo parameters from the model definition.
552    """
553    info = generate.make_info(model_definition)
554    # Get the default values for the parameters
555    pars = dict((p[0], p[2]) for p in info['parameters'])
556
557    # Fill in default values for the polydispersity parameters
558    for p in info['parameters']:
559        if p[4] in ('volume', 'orientation'):
560            pars[p[0]+'_pd'] = 0.0
561            pars[p[0]+'_pd_n'] = 0
562            pars[p[0]+'_pd_nsigma'] = 3.0
563            pars[p[0]+'_pd_type'] = "gaussian"
564
565    # Plug in values given in demo
566    pars.update(info['demo'])
567    return pars
568
569def parse_opts():
570    """
571    Parse command line options.
572    """
573    flags = [arg for arg in sys.argv[1:]
574             if arg.startswith('-')]
575    values = [arg for arg in sys.argv[1:]
576              if not arg.startswith('-') and '=' in arg]
577    args = [arg for arg in sys.argv[1:]
578            if not arg.startswith('-') and '=' not in arg]
579    models = "\n    ".join("%-15s"%v for v in MODELS)
580    if len(args) == 0:
581        print(USAGE)
582        print("\nAvailable models:")
583        print(columnize(MODELS, indent="  "))
584        sys.exit(1)
585    if args[0] not in MODELS:
586        print("Model %r not available. Use one of:\n    %s"%(args[0], models))
587        sys.exit(1)
588    if len(args) > 3:
589        print("expected parameters: model N1 N2")
590
591    invalid = [o[1:] for o in flags
592               if o[1:] not in NAME_OPTIONS
593                   and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
594    if invalid:
595        print("Invalid options: %s"%(", ".join(invalid)))
596        sys.exit(1)
597
598
599    # Interpret the flags
600    opts = {
601        'plot'      : True,
602        'view'      : 'log',
603        'is2d'      : False,
604        'qmax'      : 0.05,
605        'nq'        : 128,
606        'res'       : 0.0,
607        'accuracy'  : 'Low',
608        'cutoff'    : 1e-5,
609        'seed'      : -1,  # default to preset
610        'mono'      : False,
611        'show_pars' : False,
612        'show_hist' : False,
613        'rel_err'   : True,
614        'explore'   : False,
615    }
616    engines = []
617    for arg in flags:
618        if arg == '-noplot':    opts['plot'] = False
619        elif arg == '-plot':    opts['plot'] = True
620        elif arg == '-linear':  opts['view'] = 'linear'
621        elif arg == '-log':     opts['view'] = 'log'
622        elif arg == '-q4':      opts['view'] = 'q4'
623        elif arg == '-1d':      opts['is2d'] = False
624        elif arg == '-2d':      opts['is2d'] = True
625        elif arg == '-exq':     opts['qmax'] = 10.0
626        elif arg == '-highq':   opts['qmax'] = 1.0
627        elif arg == '-midq':    opts['qmax'] = 0.2
628        elif arg == '-loq':     opts['qmax'] = 0.05
629        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
630        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
631        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
632        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
633        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
634        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
635        elif arg == '-preset':  opts['seed'] = -1
636        elif arg == '-mono':    opts['mono'] = True
637        elif arg == '-poly':    opts['mono'] = False
638        elif arg == '-pars':    opts['show_pars'] = True
639        elif arg == '-nopars':  opts['show_pars'] = False
640        elif arg == '-hist':    opts['show_hist'] = True
641        elif arg == '-nohist':  opts['show_hist'] = False
642        elif arg == '-rel':     opts['rel_err'] = True
643        elif arg == '-abs':     opts['rel_err'] = False
644        elif arg == '-half':    engines.append(arg[1:])
645        elif arg == '-fast':    engines.append(arg[1:])
646        elif arg == '-single':  engines.append(arg[1:])
647        elif arg == '-double':  engines.append(arg[1:])
648        elif arg == '-single!': engines.append(arg[1:])
649        elif arg == '-double!': engines.append(arg[1:])
650        elif arg == '-quad!':   engines.append(arg[1:])
651        elif arg == '-sasview': engines.append(arg[1:])
652        elif arg == '-edit':    opts['explore'] = True
653
654    if len(engines) == 0:
655        engines.extend(['single', 'sasview'])
656    elif len(engines) == 1:
657        if engines[0][0] != 'sasview':
658            engines.append('sasview')
659        else:
660            engines.append('single')
661    elif len(engines) > 2:
662        del engines[2:]
663
664    name = args[0]
665    model_definition = core.load_model_definition(name)
666
667    n1 = int(args[1]) if len(args) > 1 else 1
668    n2 = int(args[2]) if len(args) > 2 else 1
669
670    # Get demo parameters from model definition, or use default parameters
671    # if model does not define demo parameters
672    pars = get_demo_pars(model_definition)
673
674    # Fill in parameters given on the command line
675    presets = {}
676    for arg in values:
677        k,v = arg.split('=',1)
678        if k not in pars:
679            # extract base name without polydispersity info
680            s = set(p.split('_pd')[0] for p in pars)
681            print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s))))
682            sys.exit(1)
683        presets[k] = float(v) if not k.endswith('type') else v
684
685    # randomize parameters
686    #pars.update(set_pars)  # set value before random to control range
687    if opts['seed'] > -1:
688        pars = randomize_pars(pars, seed=opts['seed'])
689        print("Randomize using -random=%i"%opts['seed'])
690    if opts['mono']:
691        pars = suppress_pd(pars)
692    pars.update(presets)  # set value after random to control value
693    constrain_pars(model_definition, pars)
694    constrain_new_to_old(model_definition, pars)
695    if opts['show_pars']:
696        print("pars " + str(parlist(pars)))
697
698    # Create the computational engines
699    data, _index = make_data(opts)
700    if n1:
701        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
702    else:
703        base = None
704    if n2:
705        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
706    else:
707        comp = None
708
709    # Remember it all
710    opts.update({
711        'name'      : name,
712        'def'       : model_definition,
713        'n1'        : n1,
714        'n2'        : n2,
715        'presets'   : presets,
716        'pars'      : pars,
717        'data'      : data,
718        'engines'   : [base, comp],
719    })
720
721    return opts
722
723def main():
724    opts = parse_opts()
725    if opts['explore']:
726        explore(opts)
727    else:
728        compare(opts)
729
730def explore(opts):
731    import wx
732    from bumps.names import FitProblem
733    from bumps.gui.app_frame import AppFrame
734
735    problem = FitProblem(Explore(opts))
736    isMac = "cocoa" in wx.version()
737    app = wx.App()
738    frame = AppFrame(parent=None, title="explore")
739    if not isMac: frame.Show()
740    frame.panel.set_model(model=problem)
741    frame.panel.Layout()
742    frame.panel.aui.Split(0, wx.TOP)
743    if isMac: frame.Show()
744    app.MainLoop()
745
746class Explore(object):
747    """
748    Return a bumps wrapper for a SAS model comparison.
749    """
750    def __init__(self, opts):
751        from bumps.cli import config_matplotlib
752        from . import bumps_model
753        config_matplotlib()
754        self.opts = opts
755        info = generate.make_info(opts['def'])
756        pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])
757        if not opts['is2d']:
758            active = [base + ext
759                      for base in info['partype']['pd-1d']
760                      for ext in ['', '_pd', '_pd_n', '_pd_nsigma']]
761            active.extend(info['partype']['fixed-1d'])
762            for k in active:
763                v = pars[k]
764                v.range(*parameter_range(k, v.value))
765        else:
766            for k, v in pars.items():
767                v.range(*parameter_range(k, v.value))
768
769        self.pars = pars
770        self.pd_types = pd_types
771        self.limits = None
772
773    def numpoints(self):
774        """
775        Return the number of points.
776        """
777        return len(self.pars) + 1  # so dof is 1
778
779    def parameters(self):
780        """
781        Return a dictionary of parameters.
782        """
783        return self.pars
784
785    def nllf(self):
786        """
787        Return cost.
788        """
789        return 0.  # No nllf
790
791    def plot(self, view='log'):
792        """
793        Plot the data and residuals.
794        """
795        pars = dict((k, v.value) for k, v in self.pars.items())
796        pars.update(self.pd_types)
797        self.opts['pars'] = pars
798        limits = compare(self.opts, limits=self.limits)
799        if self.limits is None:
800            vmin, vmax = limits
801            vmax = 1.3*vmax
802            vmin = vmax*1e-7
803            self.limits = vmin, vmax
804
805
806if __name__ == "__main__":
807    main()
Note: See TracBrowser for help on using the repository browser.