source: sasmodels/sasmodels/compare.py @ 9c79c32

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 9c79c32 was 8b25ee1, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

in compare, -mono applies before specific parameters so that '-mono parname_pd_n=10' is a quick way to examine polydispersity on a single parameter

  • Property mode set to 100755
File size: 22.6 KB
RevLine 
[8a20be5]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
[87985ca]4import sys
5import math
[d547f16]6from os.path import basename, dirname, join as joinpath
7import glob
[7cf2cfd]8import datetime
[5753e4e]9import traceback
[87985ca]10
[1726b21]11import numpy as np
[473183c]12
[29fc2a3]13ROOT = dirname(__file__)
14sys.path.insert(0, ROOT)  # Make sure sasmodels is first on the path
15
16
[e922c5d]17from . import core
18from . import kerneldll
[cd3dba0]19from . import generate
[e922c5d]20from .data import plot_theory, empty_data1D, empty_data2D
21from .direct_model import DirectModel
[9a66e65]22from .convert import revert_model, constrain_new_to_old
[750ffa5]23kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True
[87985ca]24
[d547f16]25# List of available models
26MODELS = [basename(f)[:-3]
[e922c5d]27          for f in sorted(glob.glob(joinpath(ROOT,"models","[a-zA-Z]*.py")))]
[d547f16]28
[7cf2cfd]29# CRUFT python 2.6
30if not hasattr(datetime.timedelta, 'total_seconds'):
31    def delay(dt):
32        """Return number date-time delta as number seconds"""
33        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
34else:
35    def delay(dt):
36        """Return number date-time delta as number seconds"""
37        return dt.total_seconds()
38
39
40def tic():
41    """
42    Timer function.
43
44    Use "toc=tic()" to start the clock and "toc()" to measure
45    a time interval.
46    """
47    then = datetime.datetime.now()
48    return lambda: delay(datetime.datetime.now() - then)
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54
55    Note: this function does not use the sasview package
56    """
57    if hasattr(data, 'qx_data'):
58        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
59        data.mask = (q < radius)
60        if outer is not None:
61            data.mask |= (q >= outer)
62    else:
63        data.mask = (data.x < radius)
64        if outer is not None:
65            data.mask |= (data.x >= outer)
66
[8a20be5]67
[ec7e360]68def parameter_range(p, v):
[87985ca]69    """
[ec7e360]70    Choose a parameter range based on parameter name and initial value.
[87985ca]71    """
[ec7e360]72    if p.endswith('_pd_n'):
73        return [0, 100]
74    elif p.endswith('_pd_nsigma'):
75        return [0, 5]
76    elif p.endswith('_pd_type'):
[87985ca]77        return v
78    elif any(s in p for s in ('theta','phi','psi')):
79        # orientation in [-180,180], orientation pd in [0,45]
80        if p.endswith('_pd'):
[ec7e360]81            return [0,45]
[87985ca]82        else:
[ec7e360]83            return [-180, 180]
[87985ca]84    elif 'sld' in p:
[ec7e360]85        return [-0.5, 10]
[87985ca]86    elif p.endswith('_pd'):
[ec7e360]87        return [0, 1]
88    elif p in ['background', 'scale']:
89        return [0, 1e3]
[87985ca]90    else:
[ec7e360]91        return [0, (2*v if v>0 else 1)]
[87985ca]92
[ec7e360]93def _randomize_one(p, v):
94    """
95    Randomizing parameter.
96    """
97    if any(p.endswith(s) for s in ('_pd_n','_pd_nsigma','_pd_type')):
98        return v
99    else:
100        return np.random.uniform(*parameter_range(p, v))
[cd3dba0]101
[ec7e360]102def randomize_pars(pars, seed=None):
103    np.random.seed(seed)
104    # Note: the sort guarantees order `of calls to random number generator
105    pars = dict((p,_randomize_one(p,v))
106                for p,v in sorted(pars.items()))
107    return pars
[cd3dba0]108
109def constrain_pars(model_definition, pars):
[9a66e65]110    """
111    Restrict parameters to valid values.
112    """
[cd3dba0]113    name = model_definition.name
[216a9e1]114    if name == 'capped_cylinder' and pars['cap_radius'] < pars['radius']:
115        pars['radius'],pars['cap_radius'] = pars['cap_radius'],pars['radius']
[b514adf]116    if name == 'barbell' and pars['bell_radius'] < pars['radius']:
117        pars['radius'],pars['bell_radius'] = pars['bell_radius'],pars['radius']
118
119    # Limit guinier to an Rg such that Iq > 1e-30 (single precision cutoff)
120    if name == 'guinier':
121        #q_max = 0.2  # mid q maximum
122        q_max = 1.0  # high q maximum
123        rg_max = np.sqrt(90*np.log(10) + 3*np.log(pars['scale']))/q_max
124        pars['rg'] = min(pars['rg'],rg_max)
[cd3dba0]125
[87985ca]126def parlist(pars):
127    return "\n".join("%s: %s"%(p,v) for p,v in sorted(pars.items()))
128
129def suppress_pd(pars):
130    """
131    Suppress theta_pd for now until the normalization is resolved.
132
133    May also suppress complete polydispersity of the model to test
134    models more quickly.
135    """
[f4f3919]136    pars = pars.copy()
[87985ca]137    for p in pars:
[8b25ee1]138        if p.endswith("_pd_n"): pars[p] = 0
[f4f3919]139    return pars
[87985ca]140
[ec7e360]141def eval_sasview(model_definition, data):
[dc056b9]142    # importing sas here so that the error message will be that sas failed to
143    # import rather than the more obscure smear_selection not imported error
[2bebe2b]144    import sas
[346bc88]145    from sas.models.qsmearing import smear_selection
[ec7e360]146
147    # convert model parameters from sasmodel form to sasview form
148    #print("old",sorted(pars.items()))
149    modelname, pars = revert_model(model_definition, {})
150    #print("new",sorted(pars.items()))
151    sas = __import__('sas.models.'+modelname)
152    ModelClass = getattr(getattr(sas.models,modelname,None),modelname,None)
153    if ModelClass is None:
154        raise ValueError("could not find model %r in sas.models"%modelname)
155    model = ModelClass()
[346bc88]156    smearer = smear_selection(data, model=model)
[216a9e1]157
[ec7e360]158    if hasattr(data, 'qx_data'):
159        q = np.sqrt(data.qx_data**2 + data.qy_data**2)
160        index = ((~data.mask) & (~np.isnan(data.data))
161                 & (q >= data.qmin) & (q <= data.qmax))
162        if smearer is not None:
163            smearer.model = model  # because smear_selection has a bug
164            smearer.accuracy = data.accuracy
165            smearer.set_index(index)
166            theory = lambda: smearer.get_value()
167        else:
168            theory = lambda: model.evalDistribution([data.qx_data[index], data.qy_data[index]])
169    elif smearer is not None:
170        theory = lambda: smearer(model.evalDistribution(data.x))
171    else:
172        theory = lambda: model.evalDistribution(data.x)
173
174    def calculator(**pars):
175        # paying for parameter conversion each time to keep life simple, if not fast
176        _, pars = revert_model(model_definition, pars)
177        for k,v in pars.items():
178            parts = k.split('.')  # polydispersity components
179            if len(parts) == 2:
180                model.dispersion[parts[0]][parts[1]] = v
181            else:
182                model.setParam(k, v)
183        return theory()
184
185    calculator.engine = "sasview"
186    return calculator
187
188DTYPE_MAP = {
189    'half': '16',
190    'fast': 'fast',
191    'single': '32',
192    'double': '64',
193    'quad': '128',
194    'f16': '16',
195    'f32': '32',
196    'f64': '64',
197    'longdouble': '128',
198}
199def eval_opencl(model_definition, data, dtype='single', cutoff=0.):
[216a9e1]200    try:
[ec7e360]201        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
[9404dd3]202    except Exception as exc:
203        print(exc)
204        print("... trying again with single precision")
[ec7e360]205        dtype = 'single'
206        model = core.load_model(model_definition, dtype=dtype, platform="ocl")
[7cf2cfd]207    calculator = DirectModel(data, model, cutoff=cutoff)
[ec7e360]208    calculator.engine = "OCL%s"%DTYPE_MAP[dtype]
209    return calculator
[216a9e1]210
[ec7e360]211def eval_ctypes(model_definition, data, dtype='double', cutoff=0.):
212    if dtype=='quad':
213        dtype = 'longdouble'
[aa4946b]214    model = core.load_model(model_definition, dtype=dtype, platform="dll")
[7cf2cfd]215    calculator = DirectModel(data, model, cutoff=cutoff)
[ec7e360]216    calculator.engine = "OMP%s"%DTYPE_MAP[dtype]
217    return calculator
218
219def time_calculation(calculator, pars, Nevals=1):
220    # initialize the code so time is more accurate
[f4f3919]221    value = calculator(**suppress_pd(pars))
[216a9e1]222    toc = tic()
[ec7e360]223    for _ in range(max(Nevals, 1)):  # make sure there is at least one eval
[7cf2cfd]224        value = calculator(**pars)
[216a9e1]225    average_time = toc()*1000./Nevals
226    return value, average_time
227
[ec7e360]228def make_data(opts):
229    qmax, nq, res = opts['qmax'], opts['nq'], opts['res']
230    if opts['is2d']:
231        data = empty_data2D(np.linspace(-qmax, qmax, nq), resolution=res)
232        data.accuracy = opts['accuracy']
[87985ca]233        set_beam_stop(data, 0.004)
234        index = ~data.mask
[216a9e1]235    else:
[ec7e360]236        if opts['view'] == 'log':
[b89f519]237            qmax = math.log10(qmax)
[ec7e360]238            q = np.logspace(qmax-3, qmax, nq)
[b89f519]239        else:
[ec7e360]240            q = np.linspace(0.001*qmax, qmax, nq)
241        data = empty_data1D(q, resolution=res)
[216a9e1]242        index = slice(None, None)
243    return data, index
244
[ec7e360]245def make_engine(model_definition, data, dtype, cutoff):
246    if dtype == 'sasview':
247        return eval_sasview(model_definition, data)
248    elif dtype.endswith('!'):
249        return eval_ctypes(model_definition, data, dtype=dtype[:-1],
250                           cutoff=cutoff)
251    else:
252        return eval_opencl(model_definition, data, dtype=dtype,
253                           cutoff=cutoff)
[87985ca]254
[013adb7]255def compare(opts, limits=None):
[ec7e360]256    Nbase, Ncomp = opts['N1'], opts['N2']
257    pars = opts['pars']
258    data = opts['data']
[87985ca]259
[4b41184]260    # Base calculation
[ec7e360]261    if Nbase > 0:
262        base = opts['engines'][0]
[319ab14]263        try:
[ec7e360]264            base_value, base_time = time_calculation(base, pars, Nbase)
265            print("%s t=%.1f ms, intensity=%.0f"%(base.engine, base_time, sum(base_value)))
[319ab14]266        except ImportError:
267            traceback.print_exc()
[1ec7efa]268            Nbase = 0
[4b41184]269
270    # Comparison calculation
[ec7e360]271    if Ncomp > 0:
272        comp = opts['engines'][1]
[7cf2cfd]273        try:
[ec7e360]274            comp_value, comp_time = time_calculation(comp, pars, Ncomp)
275            print("%s t=%.1f ms, intensity=%.0f"%(comp.engine, comp_time, sum(comp_value)))
[7cf2cfd]276        except ImportError:
[5753e4e]277            traceback.print_exc()
[4b41184]278            Ncomp = 0
[87985ca]279
280    # Compare, but only if computing both forms
[4b41184]281    if Nbase > 0 and Ncomp > 0:
[9404dd3]282        #print("speedup %.2g"%(comp_time/base_time))
[ec7e360]283        #print("max |base/comp|", max(abs(base_value/comp_value)), "%.15g"%max(abs(base_value)), "%.15g"%max(abs(comp_value)))
284        #comp *= max(base_value/comp_value)
285        resid = (base_value - comp_value)
286        relerr = resid/comp_value
287        _print_stats("|%s - %s|"%(base.engine,comp.engine)+(" "*(3+len(comp.engine))), resid)
288        _print_stats("|(%s - %s) / %s|"%(base.engine,comp.engine,comp.engine), relerr)
[87985ca]289
290    # Plot if requested
[ec7e360]291    if not opts['plot'] and not opts['explore']: return
292    view = opts['view']
[1726b21]293    import matplotlib.pyplot as plt
[013adb7]294    if limits is None:
295        vmin, vmax = np.Inf, -np.Inf
296        if Nbase > 0:
297            vmin = min(vmin, min(base_value))
298            vmax = max(vmax, max(base_value))
299        if Ncomp > 0:
300            vmin = min(vmin, min(comp_value))
301            vmax = max(vmax, max(comp_value))
302        limits = vmin, vmax
303
[4b41184]304    if Nbase > 0:
[ec7e360]305        if Ncomp > 0: plt.subplot(131)
[013adb7]306        plot_theory(data, base_value, view=view, plot_data=False, limits=limits)
[ec7e360]307        plt.title("%s t=%.1f ms"%(base.engine, base_time))
308        #cbar_title = "log I"
309    if Ncomp > 0:
310        if Nbase > 0: plt.subplot(132)
[013adb7]311        plot_theory(data, comp_value, view=view, plot_data=False, limits=limits)
[ec7e360]312        plt.title("%s t=%.1f ms"%(comp.engine,comp_time))
[7cf2cfd]313        #cbar_title = "log I"
[4b41184]314    if Ncomp > 0 and Nbase > 0:
[87985ca]315        plt.subplot(133)
[29f5536]316        if '-abs' in opts:
[b89f519]317            err,errstr,errview = resid, "abs err", "linear"
[29f5536]318        else:
[b89f519]319            err,errstr,errview = abs(relerr), "rel err", "log"
[4b41184]320        #err,errstr = base/comp,"ratio"
[7cf2cfd]321        plot_theory(data, None, resid=err, view=errview, plot_data=False)
[346bc88]322        plt.title("max %s = %.3g"%(errstr, max(abs(err))))
[7cf2cfd]323        #cbar_title = errstr if errview=="linear" else "log "+errstr
324    #if is2D:
325    #    h = plt.colorbar()
326    #    h.ax.set_title(cbar_title)
[ba69383]327
[4b41184]328    if Ncomp > 0 and Nbase > 0 and '-hist' in opts:
[ba69383]329        plt.figure()
[346bc88]330        v = relerr
[ba69383]331        v[v==0] = 0.5*np.min(np.abs(v[v!=0]))
332        plt.hist(np.log10(np.abs(v)), normed=1, bins=50);
[ec7e360]333        plt.xlabel('log10(err), err = |(%s - %s) / %s|'%(base.engine, comp.engine, comp.engine));
[ba69383]334        plt.ylabel('P(err)')
[ec7e360]335        plt.title('Distribution of relative error between calculation engines')
[ba69383]336
[ec7e360]337    if not opts['explore']:
338        plt.show()
[8a20be5]339
[013adb7]340    return limits
341
[0763009]342def _print_stats(label, err):
343    sorted_err = np.sort(abs(err))
344    p50 = int((len(err)-1)*0.50)
345    p98 = int((len(err)-1)*0.98)
346    data = [
347        "max:%.3e"%sorted_err[-1],
348        "median:%.3e"%sorted_err[p50],
349        "98%%:%.3e"%sorted_err[p98],
350        "rms:%.3e"%np.sqrt(np.mean(err**2)),
351        "zero-offset:%+.3e"%np.mean(err),
352        ]
[9404dd3]353    print(label+"  ".join(data))
[0763009]354
355
356
[87985ca]357# ===========================================================================
358#
359USAGE="""
[ec7e360]360usage: compare.py model N1 N2 [options...] [key=val]
[87985ca]361
362Compare the speed and value for a model between the SasView original and the
[ec7e360]363sasmodels rewrite.
[87985ca]364
365model is the name of the model to compare (see below).
[ec7e360]366N1 is the number of times to run sasmodels (default=1).
367N2 is the number times to run sasview (default=1).
[87985ca]368
369Options (* for default):
370
371    -plot*/-noplot plots or suppress the plot of the model
[29f5536]372    -lowq*/-midq/-highq/-exq use q values up to 0.05, 0.2, 1.0, 10.0
[ec7e360]373    -nq=128 sets the number of Q points in the data set
[73a3e22]374    -1d*/-2d computes 1d or 2d data
[2d0aced]375    -preset*/-random[=seed] preset or random parameters
376    -mono/-poly* force monodisperse/polydisperse
[3e6aaad]377    -cutoff=1e-5* cutoff value for including a point in polydispersity
[2d0aced]378    -pars/-nopars* prints the parameter set or not
379    -abs/-rel* plot relative or absolute error
[ec7e360]380    -linear/-log*/-q4 intensity scaling
[ba69383]381    -hist/-nohist* plot histogram of relative error
[346bc88]382    -res=0 sets the resolution width dQ/Q if calculating with resolution
[5d316e9]383    -accuracy=Low accuracy of the resolution calculation Low, Mid, High, Xhigh
[ec7e360]384    -edit starts the parameter explorer
[87985ca]385
[ec7e360]386Any two calculation engines can be selected for comparison:
387
388    -single/-double/-half/-fast sets an OpenCL calculation engine
389    -single!/-double!/-quad! sets an OpenMP calculation engine
390    -sasview sets the sasview calculation engine
391
[e21cc31]392The default is -single -sasview.  Note that the interpretation of quad
393precision depends on architecture, and may vary from 64-bit to 128-bit,
394with 80-bit floats being common (1e-19 precision).
[ec7e360]395
396Key=value pairs allow you to set specific values for the model parameters.
[87985ca]397
398Available models:
399"""
400
[7cf2cfd]401
[216a9e1]402NAME_OPTIONS = set([
[5d316e9]403    'plot', 'noplot',
[ec7e360]404    'half', 'fast', 'single', 'double',
405    'single!', 'double!', 'quad!', 'sasview',
[5d316e9]406    'lowq', 'midq', 'highq', 'exq',
407    '2d', '1d',
408    'preset', 'random',
409    'poly', 'mono',
410    'nopars', 'pars',
411    'rel', 'abs',
[b89f519]412    'linear', 'log', 'q4',
[5d316e9]413    'hist', 'nohist',
[ec7e360]414    'edit',
[216a9e1]415    ])
416VALUE_OPTIONS = [
417    # Note: random is both a name option and a value option
[ec7e360]418    'cutoff', 'random', 'nq', 'res', 'accuracy',
[87985ca]419    ]
420
[7cf2cfd]421def columnize(L, indent="", width=79):
422    column_width = max(len(w) for w in L) + 1
423    num_columns = (width - len(indent)) // column_width
424    num_rows = len(L) // num_columns
425    L = L + [""] * (num_rows*num_columns - len(L))
426    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)]
427    lines = [" ".join("%-*s"%(column_width, entry) for entry in row)
428             for row in zip(*columns)]
429    output = indent + ("\n"+indent).join(lines)
430    return output
431
432
[cd3dba0]433def get_demo_pars(model_definition):
434    info = generate.make_info(model_definition)
[ec7e360]435    # Get the default values for the parameters
[cd3dba0]436    pars = dict((p[0],p[2]) for p in info['parameters'])
[ec7e360]437
438    # Fill in default values for the polydispersity parameters
439    for p in info['parameters']:
440        if p[4] in ('volume', 'orientation'):
441            pars[p[0]+'_pd'] = 0.0
442            pars[p[0]+'_pd_n'] = 0
443            pars[p[0]+'_pd_nsigma'] = 3.0
444            pars[p[0]+'_pd_type'] = "gaussian"
445
446    # Plug in values given in demo
[cd3dba0]447    pars.update(info['demo'])
[373d1b6]448    return pars
449
[ec7e360]450def parse_opts():
451    flags = [arg for arg in sys.argv[1:] if arg.startswith('-')]
452    values = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' in arg]
[319ab14]453    args = [arg for arg in sys.argv[1:] if not arg.startswith('-') and '=' not in arg]
[d547f16]454    models = "\n    ".join("%-15s"%v for v in MODELS)
[87985ca]455    if len(args) == 0:
[7cf2cfd]456        print(USAGE)
457        print(columnize(MODELS, indent="  "))
[87985ca]458        sys.exit(1)
459    if args[0] not in MODELS:
[9404dd3]460        print("Model %r not available. Use one of:\n    %s"%(args[0],models))
[87985ca]461        sys.exit(1)
[319ab14]462    if len(args) > 3:
463        print("expected parameters: model Nopencl Nsasview")
[87985ca]464
[ec7e360]465    invalid = [o[1:] for o in flags
[216a9e1]466               if o[1:] not in NAME_OPTIONS
467                  and not any(o.startswith('-%s='%t) for t in VALUE_OPTIONS)]
[87985ca]468    if invalid:
[9404dd3]469        print("Invalid options: %s"%(", ".join(invalid)))
[87985ca]470        sys.exit(1)
471
[ec7e360]472
473    # Interpret the flags
474    opts = {
475        'plot'      : True,
476        'view'      : 'log',
477        'is2d'      : False,
478        'qmax'      : 0.05,
479        'nq'        : 128,
480        'res'       : 0.0,
481        'accuracy'  : 'Low',
482        'cutoff'    : 1e-5,
483        'seed'      : -1,  # default to preset
484        'mono'      : False,
485        'show_pars' : False,
486        'show_hist' : False,
487        'rel_err'   : True,
488        'explore'   : False,
489    }
490    engines = []
491    for arg in flags:
492        if arg == '-noplot':    opts['plot'] = False
493        elif arg == '-plot':    opts['plot'] = True
494        elif arg == '-linear':  opts['view'] = 'linear'
495        elif arg == '-log':     opts['view'] = 'log'
496        elif arg == '-q4':      opts['view'] = 'q4'
497        elif arg == '-1d':      opts['is2d'] = False
498        elif arg == '-2d':      opts['is2d'] = True
499        elif arg == '-exq':     opts['qmax'] = 10.0
500        elif arg == '-highq':   opts['qmax'] = 1.0
501        elif arg == '-midq':    opts['qmax'] = 0.2
502        elif arg == '-loq':     opts['qmax'] = 0.05
503        elif arg.startswith('-nq='):       opts['nq'] = int(arg[4:])
504        elif arg.startswith('-res='):      opts['res'] = float(arg[5:])
505        elif arg.startswith('-accuracy='): opts['accuracy'] = arg[10:]
506        elif arg.startswith('-cutoff='):   opts['cutoff'] = float(arg[8:])
507        elif arg.startswith('-random='):   opts['seed'] = int(arg[8:])
508        elif arg == '-random':  opts['seed'] = np.random.randint(1e6)
509        elif arg == '-preset':  opts['seed'] = -1
510        elif arg == '-mono':    opts['mono'] = True
511        elif arg == '-poly':    opts['mono'] = False
512        elif arg == '-pars':    opts['show_pars'] = True
513        elif arg == '-nopars':  opts['show_pars'] = False
514        elif arg == '-hist':    opts['show_hist'] = True
515        elif arg == '-nohist':  opts['show_hist'] = False
516        elif arg == '-rel':     opts['rel_err'] = True
517        elif arg == '-abs':     opts['rel_err'] = False
518        elif arg == '-half':    engines.append(arg[1:])
519        elif arg == '-fast':    engines.append(arg[1:])
520        elif arg == '-single':  engines.append(arg[1:])
521        elif arg == '-double':  engines.append(arg[1:])
522        elif arg == '-single!': engines.append(arg[1:])
523        elif arg == '-double!': engines.append(arg[1:])
524        elif arg == '-quad!':   engines.append(arg[1:])
525        elif arg == '-sasview': engines.append(arg[1:])
526        elif arg == '-edit':    opts['explore'] = True
527
528    if len(engines) == 0:
529        engines.extend(['single','sasview'])
530    elif len(engines) == 1:
531        if engines[0][0] != 'sasview':
532            engines.append('sasview')
533        else:
534            engines.append('single')
535    elif len(engines) > 2:
536        del engines[2:]
537
[d547f16]538    name = args[0]
[cd3dba0]539    model_definition = core.load_model_definition(name)
[d547f16]540
[ec7e360]541    N1 = int(args[1]) if len(args) > 1 else 1
542    N2 = int(args[2]) if len(args) > 2 else 1
[87985ca]543
[ec7e360]544    # Get demo parameters from model definition, or use default parameters
545    # if model does not define demo parameters
546    pars = get_demo_pars(model_definition)
[87985ca]547
548    # Fill in parameters given on the command line
[ec7e360]549    presets = {}
550    for arg in values:
[319ab14]551        k,v = arg.split('=',1)
[87985ca]552        if k not in pars:
[ec7e360]553            # extract base name without polydispersity info
[87985ca]554            s = set(p.split('_pd')[0] for p in pars)
[9404dd3]555            print("%r invalid; parameters are: %s"%(k,", ".join(sorted(s))))
[87985ca]556            sys.exit(1)
[ec7e360]557        presets[k] = float(v) if not k.endswith('type') else v
558
559    # randomize parameters
560    #pars.update(set_pars)  # set value before random to control range
561    if opts['seed'] > -1:
562        pars = randomize_pars(pars, seed=opts['seed'])
563        print("Randomize using -random=%i"%opts['seed'])
[8b25ee1]564    if opts['mono']:
565        pars = suppress_pd(pars)
[ec7e360]566    pars.update(presets)  # set value after random to control value
567    constrain_pars(model_definition, pars)
568    constrain_new_to_old(model_definition, pars)
569    if opts['show_pars']:
570        print("pars " + str(parlist(pars)))
571
572    # Create the computational engines
573    data, _index = make_data(opts)
574    if N1:
575        base = make_engine(model_definition, data, engines[0], opts['cutoff'])
576    else:
577        base = None
578    if N2:
579        comp = make_engine(model_definition, data, engines[1], opts['cutoff'])
580    else:
581        comp = None
582
583    # Remember it all
584    opts.update({
585        'name'      : name,
586        'def'       : model_definition,
587        'N1'        : N1,
588        'N2'        : N2,
589        'presets'   : presets,
590        'pars'      : pars,
591        'data'      : data,
592        'engines'   : [base, comp],
593    })
594
595    return opts
596
597def main():
598    opts = parse_opts()
599    if opts['explore']:
600        explore(opts)
601    else:
602        compare(opts)
603
604def explore(opts):
605    import wx
606    from bumps.names import FitProblem
607    from bumps.gui.app_frame import AppFrame
608
609    problem = FitProblem(Explore(opts))
610    isMac = "cocoa" in wx.version()
611    app = wx.App()
612    frame = AppFrame(parent=None, title="explore")
613    if not isMac: frame.Show()
614    frame.panel.set_model(model=problem)
615    frame.panel.Layout()
616    frame.panel.aui.Split(0, wx.TOP)
617    if isMac: frame.Show()
618    app.MainLoop()
619
620class Explore(object):
621    """
622    Return a bumps wrapper for a SAS model comparison.
623    """
624    def __init__(self, opts):
625        from bumps.cli import config_matplotlib
626        import bumps_model
627        config_matplotlib()
628        self.opts = opts
629        info = generate.make_info(opts['def'])
630        pars, pd_types = bumps_model.create_parameters(info, **opts['pars'])
631        if not opts['is2d']:
632            active = [base + ext
633                      for base in info['partype']['pd-1d']
634                      for ext in ['','_pd','_pd_n','_pd_nsigma']]
635            active.extend(info['partype']['fixed-1d'])
636            for k in active:
637                v = pars[k]
638                v.range(*parameter_range(k, v.value))
639        else:
[013adb7]640            for k, v in pars.items():
[ec7e360]641                v.range(*parameter_range(k, v.value))
642
643        self.pars = pars
644        self.pd_types = pd_types
[013adb7]645        self.limits = None
[ec7e360]646
647    def numpoints(self):
648        """
649        Return the number of points
650        """
651        return len(self.pars) + 1  # so dof is 1
652
653    def parameters(self):
654        """
655        Return a dictionary of parameters
656        """
657        return self.pars
658
659    def nllf(self):
660        return 0.  # No nllf
661
662    def plot(self, view='log'):
663        """
664        Plot the data and residuals.
665        """
666        pars = dict((k, v.value) for k,v in self.pars.items())
667        pars.update(self.pd_types)
668        self.opts['pars'] = pars
[013adb7]669        limits = compare(self.opts, limits=self.limits)
670        if self.limits is None:
671            vmin, vmax = limits
672            vmax = 1.3*vmax
673            vmin = vmax*1e-7
674            self.limits = vmin, vmax
[87985ca]675
676
[8a20be5]677if __name__ == "__main__":
[87985ca]678    main()
Note: See TracBrowser for help on using the repository browser.