Changeset 7cf2cfd in sasmodels


Ignore:
Timestamp:
Nov 22, 2015 11:37:15 PM (9 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3b4243d
Parents:
677ccf1
Message:

refactor compare.py so that bumps/sasview not required for simple tests

Files:
1 added
1 deleted
6 edited

Legend:

Unmodified
Added
Removed
  • compare.py

    r29fc2a3 r7cf2cfd  
    66from os.path import basename, dirname, join as joinpath 
    77import glob 
     8import datetime 
    89 
    910import numpy as np 
     
    1314 
    1415 
    15 from sasmodels.bumps_model import Model, Experiment, plot_theory, tic 
    1616from sasmodels import core 
    1717from sasmodels import kerneldll 
     18from sasmodels.data import plot_theory, empty_data1D, empty_data2D 
     19from sasmodels.direct_model import DirectModel 
    1820from sasmodels.convert import revert_model 
    1921kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 
     
    2224MODELS = [basename(f)[:-3] 
    2325          for f in sorted(glob.glob(joinpath(ROOT,"sasmodels","models","[a-zA-Z]*.py")))] 
     26 
     27# CRUFT python 2.6 
     28if not hasattr(datetime.timedelta, 'total_seconds'): 
     29    def delay(dt): 
     30        """Return number date-time delta as number seconds""" 
     31        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds 
     32else: 
     33    def delay(dt): 
     34        """Return number date-time delta as number seconds""" 
     35        return dt.total_seconds() 
     36 
     37 
     38def tic(): 
     39    """ 
     40    Timer function. 
     41 
     42    Use "toc=tic()" to start the clock and "toc()" to measure 
     43    a time interval. 
     44    """ 
     45    then = datetime.datetime.now() 
     46    return lambda: delay(datetime.datetime.now() - then) 
     47 
     48 
     49def set_beam_stop(data, radius, outer=None): 
     50    """ 
     51    Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
     52 
     53    Note: this function does not use the sasview package 
     54    """ 
     55    if hasattr(data, 'qx_data'): 
     56        q = np.sqrt(data.qx_data**2 + data.qy_data**2) 
     57        data.mask = (q < radius) 
     58        if outer is not None: 
     59            data.mask |= (q >= outer) 
     60    else: 
     61        data.mask = (data.x < radius) 
     62        if outer is not None: 
     63            data.mask |= (data.x >= outer) 
    2464 
    2565 
     
    99139        if p.endswith("_pd"): pars[p] = 0 
    100140 
    101 def eval_sasview(name, pars, data, Nevals=1): 
     141def eval_sasview(model_definition, pars, data, Nevals=1): 
    102142    from sas.models.qsmearing import smear_selection 
    103     model = sasview_model(name, **pars) 
     143    model = sasview_model(model_definition, **pars) 
    104144    smearer = smear_selection(data, model=model) 
    105145    value = None  # silence the linter 
     
    131171        print "... trying again with single precision" 
    132172        model = core.load_model(model_definition, dtype='single', platform="ocl") 
    133     problem = Experiment(data, Model(model, **pars), cutoff=cutoff) 
     173    calculator = DirectModel(data, model, cutoff=cutoff) 
    134174    value = None  # silence the linter 
    135175    toc = tic() 
    136176    for _ in range(max(Nevals, 1)):  # force at least one eval 
    137         #pars['scale'] = np.random.rand() 
    138         problem.update() 
    139         value = problem.theory() 
     177        value = calculator(**pars) 
    140178    average_time = toc()*1000./Nevals 
    141179    return value, average_time 
    142180 
     181 
    143182def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.): 
    144183    model = core.load_model(model_definition, dtype=dtype, platform="dll") 
    145     problem = Experiment(data, Model(model, **pars), cutoff=cutoff) 
     184    calculator = DirectModel(data, model, cutoff=cutoff) 
    146185    value = None  # silence the linter 
    147186    toc = tic() 
    148187    for _ in range(max(Nevals, 1)):  # force at least one eval 
    149         problem.update() 
    150         value = problem.theory() 
     188        value = calculator(**pars) 
    151189    average_time = toc()*1000./Nevals 
    152190    return value, average_time 
    153191 
     192 
    154193def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'): 
    155194    if is2D: 
    156         from sasmodels.bumps_model import empty_data2D, set_beam_stop 
    157195        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution) 
    158196        data.accuracy = accuracy 
     
    160198        index = ~data.mask 
    161199    else: 
    162         from sasmodels.bumps_model import empty_data1D 
    163200        if view == 'log': 
    164201            qmax = math.log10(qmax) 
     
    190227 
    191228    # randomize parameters 
    192     pars.update(set_pars) 
     229    #pars.update(set_pars)  # set value before random to control range 
    193230    if '-random' in opts or '-random' in opt_values: 
    194231        seed = int(opt_values['-random']) if '-random' in opt_values else None 
    195232        pars, seed = randomize_model(name, pars, seed=seed) 
    196233        print "Randomize using -random=%i"%seed 
     234    pars.update(set_pars)  # set value after random to control value 
    197235 
    198236    # parameter selection 
     
    217255        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)) 
    218256    elif Ncpu > 0: 
    219         cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu) 
    220         comp = "sasview" 
    221         print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)) 
     257        try: 
     258            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu) 
     259            comp = "sasview" 
     260            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background']) 
     261            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)) 
     262        except ImportError: 
     263            Ncpu = 0 
    222264 
    223265    # Compare, but only if computing both forms 
     
    238280    if Ncpu > 0: 
    239281        if Nocl > 0: plt.subplot(131) 
    240         plot_theory(data, cpu, view=view) 
     282        plot_theory(data, cpu, view=view, plot_data=False) 
    241283        plt.title("%s t=%.1f ms"%(comp,cpu_time)) 
    242         cbar_title = "log I" 
     284        #cbar_title = "log I" 
    243285    if Nocl > 0: 
    244286        if Ncpu > 0: plt.subplot(132) 
    245         plot_theory(data, ocl, view=view) 
     287        plot_theory(data, ocl, view=view, plot_data=False) 
    246288        plt.title("opencl t=%.1f ms"%ocl_time) 
    247         cbar_title = "log I" 
     289        #cbar_title = "log I" 
    248290    if Ncpu > 0 and Nocl > 0: 
    249291        plt.subplot(133) 
     
    253295            err,errstr,errview = abs(relerr), "rel err", "log" 
    254296        #err,errstr = ocl/cpu,"ratio" 
    255         plot_theory(data, err, view=errview) 
     297        plot_theory(data, None, resid=err, view=errview, plot_data=False) 
    256298        plt.title("max %s = %.3g"%(errstr, max(abs(err)))) 
    257         cbar_title = errstr if errview=="linear" else "log "+errstr 
    258     if is2D: 
    259         h = plt.colorbar() 
    260         h.ax.set_title(cbar_title) 
     299        #cbar_title = errstr if errview=="linear" else "log "+errstr 
     300    #if is2D: 
     301    #    h = plt.colorbar() 
     302    #    h.ax.set_title(cbar_title) 
    261303 
    262304    if Ncpu > 0 and Nocl > 0 and '-hist' in opts: 
     
    320362 
    321363Available models: 
    322  
    323     %s 
    324364""" 
     365 
    325366 
    326367NAME_OPTIONS = set([ 
     
    342383    ] 
    343384 
     385def columnize(L, indent="", width=79): 
     386    column_width = max(len(w) for w in L) + 1 
     387    num_columns = (width - len(indent)) // column_width 
     388    num_rows = len(L) // num_columns 
     389    L = L + [""] * (num_rows*num_columns - len(L)) 
     390    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)] 
     391    lines = [" ".join("%-*s"%(column_width, entry) for entry in row) 
     392             for row in zip(*columns)] 
     393    output = indent + ("\n"+indent).join(lines) 
     394    return output 
     395 
     396 
    344397def get_demo_pars(name): 
    345398    import sasmodels.models 
     
    355408    models = "\n    ".join("%-15s"%v for v in MODELS) 
    356409    if len(args) == 0: 
    357         print(USAGE%models) 
     410        print(USAGE) 
     411        print(columnize(MODELS, indent="  ")) 
    358412        sys.exit(1) 
    359413    if args[0] not in MODELS: 
  • compare_many.py

    rab55943 r7cf2cfd  
    11#!/usr/bin/env python 
    2  
    32import sys 
     3import traceback 
    44 
    55import numpy as np 
    66 
     7from sasmodels import core 
    78from sasmodels.kernelcl import environment 
    89from compare import (MODELS, randomize_model, suppress_pd, eval_sasview, 
    9                      eval_opencl, eval_ctypes, make_data, get_demo_pars) 
     10                     eval_opencl, eval_ctypes, make_data, get_demo_pars, 
     11                     columnize) 
    1012 
    1113def get_stats(target, value, index): 
     
    1315    relerr = resid/target[index] 
    1416    srel = np.argsort(relerr) 
    15     p90 = int(len(relerr)*0.90) 
     17    #p90 = int(len(relerr)*0.90) 
    1618    p95 = int(len(relerr)*0.95) 
    1719    maxrel = np.max(relerr) 
     
    2729        groups.append(p) 
    2830        groups.extend(['']*(len(stats)-1)) 
     31    groups.append("Parameters") 
    2932    columns = ['Seed'] + stats*len(parts) +  list(sorted(pars.keys())) 
    3033    print(','.join('"%s"'%c for c in groups)) 
     
    3235 
    3336def compare_instance(name, data, index, N=1, mono=True, cutoff=1e-5): 
     37    model_definition = core.load_model_definition(name) 
    3438    pars = get_demo_pars(name) 
    3539    header = '\n"Model","%s","Count","%d"'%(name, N) 
    3640    if not mono: header += ',"Cutoff",%g'%(cutoff,) 
    3741    print(header) 
     42 
     43    # Stuff the failure flag into a mutable object so we can update it from 
     44    # within the nested function.  Note that the nested function uses "pars" 
     45    # which is dynamically scoped, not lexically scoped in this context.  That 
     46    # is, pars is replaced each time in the loop, so don't assume that it is 
     47    # the default values defined above. 
     48    def trymodel(fn, *args, **kw): 
     49        try: 
     50            result, _ = fn(model_definition, pars, data, *args, **kw) 
     51        except: 
     52            result = np.NaN 
     53            traceback.print_exc() 
     54        return result 
     55 
     56    num_good = 0 
    3857    first = True 
    3958    for _ in range(N): 
     
    4160        if mono: suppress_pd(pars) 
    4261 
    43         target, _ = eval_sasview(name, pars, data) 
     62        # Force parameter constraints on a per-model basis. 
     63        if name in ('teubner_strey','broad_peak'): 
     64            pars['scale'] = 1.0 
     65        #if name == 'parallelepiped': 
     66        #    pars['a_side'],pars['b_side'],pars['c_side'] = \ 
     67        #        sorted([pars['a_side'],pars['b_side'],pars['c_side']]) 
    4468 
    45         env = environment() 
    46         gpu_single_value,_ = eval_opencl(name, pars, data, dtype='single', cutoff=cutoff) 
    47         gpu_single = get_stats(target, gpu_single_value, index) 
    48         if env.has_double: 
    49             gpu_double_value,_ = eval_opencl(name, pars, data, dtype='double', cutoff=cutoff) 
    50             gpu_double = get_stats(target, gpu_double_value, index) 
     69 
     70        good = True 
     71        labels = [] 
     72        columns = [] 
     73        if 1: 
     74            sasview_value = trymodel(eval_sasview) 
     75        if 0: 
     76            gpu_single_value = trymodel(eval_opencl, dtype='single', cutoff=cutoff) 
     77            stats = get_stats(sasview_value, gpu_single_value, index) 
     78            columns.extend(stats) 
     79            labels.append('GPU single') 
     80            good = good and (stats[0] < 1e-14) 
     81        if 0 and environment().has_double: 
     82            gpu_double_value = trymodel(eval_opencl, dtype='double', cutoff=cutoff) 
     83            stats = get_stats(sasview_value, gpu_double_value, index) 
     84            columns.extend(stats) 
     85            labels.append('GPU double') 
     86            good = good and (stats[0] < 1e-14) 
     87        if 1: 
     88            cpu_double_value = trymodel(eval_ctypes, dtype='double', cutoff=cutoff) 
     89            stats = get_stats(sasview_value, cpu_double_value, index) 
     90            columns.extend(stats) 
     91            labels.append('CPU double') 
     92            good = good and (stats[0] < 1e-14) 
     93        if 0: 
     94            stats = get_stats(cpu_double_value, gpu_single_value, index) 
     95            columns.extend(stats) 
     96            labels.append('single/double') 
     97            good = good and (stats[0] < 1e-14) 
     98 
     99        columns += [v for _,v in sorted(pars.items())] 
     100        if first: 
     101            print_column_headers(pars, labels) 
     102            first = False 
     103        if good: 
     104            num_good += 1 
    51105        else: 
    52             gpu_double = [0]*len(gpu_single) 
    53         cpu_double_value,_ =  eval_ctypes(name, pars, data, dtype='double', cutoff=cutoff) 
    54         cpu_double = get_stats(target, cpu_double_value, index) 
    55         single_double = get_stats(cpu_double_value, gpu_single_value, index) 
     106            print(("%d,"%seed)+','.join("%g"%v for v in columns)) 
     107    print '"%d/%d good"'%(num_good, N) 
    56108 
    57         values = (list(gpu_single) + list(gpu_double) + list(cpu_double) 
    58                   + list(single_double) + [v for _,v in sorted(pars.items())]) 
    59         if gpu_single[0] > 5e-5: 
    60             if first: 
    61                 print_column_headers(pars,'GPU single|GPU double|CPU double|single/double'.split('|')) 
    62                 first = False 
    63             print(("%d,"%seed)+','.join("%g"%v for v in values)) 
    64109 
    65 def main(): 
    66     try: 
    67         model = sys.argv[1] 
    68         assert (model in MODELS) or (model == "all") 
    69         count = int(sys.argv[2]) 
    70         is2D = sys.argv[3].startswith('2d') 
    71         assert sys.argv[3][1] == 'd' 
    72         Nq = int(sys.argv[3][2:]) 
    73         mono = sys.argv[4] == 'mono' 
    74         cutoff = float(sys.argv[4]) if not mono else 0 
    75     except: 
    76         import traceback; traceback.print_exc() 
    77         models = "\n    ".join("%-7s: %s"%(k,v.__name__.replace('_',' ')) 
    78                                for k,v in sorted(MODELS.items())) 
    79         print("""\ 
    80 usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono) 
     110def print_usage(): 
     111    print "usage: compare_many.py MODEL COUNT (1dNQ|2dNQ) (CUTOFF|mono)" 
    81112 
    82 MODEL is the model name of the model, which is one of: 
    83     %s 
    84 or "all" for all the models in alphabetical order. 
     113 
     114def print_models(): 
     115    print(columnize(MODELS, indent="  ")) 
     116 
     117 
     118def print_help(): 
     119    print_usage() 
     120    print("""\ 
     121 
     122MODEL is the model name of the model or "all" for all the models 
     123in alphabetical order. 
    85124 
    86125COUNT is the number of randomly generated parameter sets to try. A value 
     
    98137choice of polydisperse parameters, and the number of points in the distribution 
    99138is set in compare.py defaults for each model. 
    100 """%(models,)) 
     139 
     140Available models: 
     141""") 
     142    print_models() 
     143 
     144def main(): 
     145    if len(sys.argv) == 1: 
     146        print_help() 
     147        sys.exit(1) 
     148 
     149    model = sys.argv[1] 
     150    if not (model in MODELS) and (model != "all"): 
     151        print 'Bad model %s.  Use "all" or one of:' 
     152        print_models() 
     153        sys.exit(1) 
     154    try: 
     155        count = int(sys.argv[2]) 
     156        is2D = sys.argv[3].startswith('2d') 
     157        assert sys.argv[3][1] == 'd' 
     158        Nq = int(sys.argv[3][2:]) 
     159        mono = sys.argv[4] == 'mono' 
     160        cutoff = float(sys.argv[4]) if not mono else 0 
     161    except: 
     162        print_usage() 
    101163        sys.exit(1) 
    102164 
  • example/fit.py

    r346bc88 r7cf2cfd  
    55from bumps.names import * 
    66from sasmodels.core import load_model 
    7 from sasmodels import bumps_model as sas 
     7from sasmodels.bumps_model import Model, Experiment 
     8from sasmodels.data import load_data, set_beam_stop, set_top 
    89 
    910""" IMPORT THE DATA USED """ 
    10 radial_data = sas.load_data('DEC07267.DAT') 
    11 sas.set_beam_stop(radial_data, 0.00669, outer=0.025) 
    12 sas.set_top(radial_data, -.0185) 
     11radial_data = load_data('DEC07267.DAT') 
     12set_beam_stop(radial_data, 0.00669, outer=0.025) 
     13set_top(radial_data, -.0185) 
    1314 
    14 tan_data = sas.load_data('DEC07266.DAT') 
    15 sas.set_beam_stop(tan_data, 0.00669, outer=0.025) 
    16 sas.set_top(tan_data, -.0185) 
     15tan_data = load_data('DEC07266.DAT') 
     16set_beam_stop(tan_data, 0.00669, outer=0.025) 
     17set_top(tan_data, -.0185) 
    1718#sas.set_half(tan_data, 'right') 
    1819 
     
    2829 
    2930if name == "ellipsoid": 
    30     model = sas.Model(kernel, 
     31    model = Model(kernel, 
    3132        scale=0.08, 
    3233        rpolar=15, requatorial=800, 
     
    5152 
    5253elif name == "lamellar": 
    53     model = sas.Model(kernel, 
     54    model = Model(kernel, 
    5455        scale=0.08, 
    5556        thickness=19.2946, 
     
    8788        theta_pd=10, theta_pd_n=50, theta_pd_nsigma=3, 
    8889        phi_pd=0, phi_pd_n=10, phi_pd_nsigma=3) 
    89     model = sas.Model(kernel, **pars) 
     90    model = Model(kernel, **pars) 
    9091 
    9192    # SET THE FITTING PARAMETERS 
     
    102103 
    103104elif name == "core_shell_cylinder": 
    104     model = sas.Model(kernel, 
     105    model = Model(kernel, 
    105106        scale= .031, radius=19.5, thickness=30, length=22, 
    106107        core_sld=7.105, shell_sld=.291, solvent_sld=7.105, 
     
    129130 
    130131elif name == "capped_cylinder": 
    131     model = sas.Model(kernel, 
     132    model = Model(kernel, 
    132133        scale=.08, radius=20, cap_radius=40, length=400, 
    133134        sld_capcyl=1, sld_solv=6.3, 
     
    144145 
    145146elif name == "triaxial_ellipsoid": 
    146     model = sas.Model(kernel, 
     147    model = Model(kernel, 
    147148        scale=0.08, req_minor=15, req_major=20, rpolar=500, 
    148149        sldEll=7.105, solvent_sld=.291, 
     
    170171 
    171172model.cutoff = cutoff 
    172 M = sas.Experiment(data=data, model=model) 
     173M = Experiment(data=data, model=model) 
    173174if section == "both": 
    174    tan_model = sas.Model(model.kernel, **model.parameters()) 
     175   tan_model = Model(model.kernel, **model.parameters()) 
    175176   tan_model.phi = model.phi - 90 
    176177   tan_model.cutoff = cutoff 
    177    tan_M = sas.Experiment(data=tan_data, model=tan_model) 
     178   tan_M = Experiment(data=tan_data, model=tan_model) 
    178179   problem = FitProblem([M, tan_M]) 
    179180else: 
  • sasmodels/bumps_model.py

    r5d80bbf r7cf2cfd  
    1010how far the polydispersity integral extends. 
    1111 
    12 A variety of helper functions are provided: 
    13  
    14     :func:`load_data` loads a sasview data file. 
    15  
    16     :func:`empty_data1D` creates an empty dataset, which is useful for plotting 
    17     a theory function before the data is measured. 
    18  
    19     :func:`empty_data2D` creates an empty 2D dataset. 
    20  
    21     :func:`set_beam_stop` masks the beam stop from the data. 
    22  
    23     :func:`set_half` selects the right or left half of the data, which can 
    24     be useful for shear measurements which have not been properly corrected 
    25     for path length and reflections. 
    26  
    27     :func:`set_top` cuts the top part off the data. 
    28  
    29     :func:`plot_data` plots the data file. 
    30  
    31     :func:`plot_theory` plots a calculated result from the model. 
    32  
    3312""" 
    3413 
    3514import datetime 
    3615import warnings 
    37 import traceback 
    3816 
    3917import numpy as np 
    4018 
     19from bumps.names import Parameter 
     20 
    4121from . import sesans 
    42 from .resolution import Perfect1D, Pinhole1D, Slit1D 
    43 from .resolution2d import Pinhole2D 
    44  
    45 # CRUFT python 2.6 
    46 if not hasattr(datetime.timedelta, 'total_seconds'): 
    47     def delay(dt): 
    48         """Return number date-time delta as number seconds""" 
    49         return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds 
    50 else: 
    51     def delay(dt): 
    52         """Return number date-time delta as number seconds""" 
    53         return dt.total_seconds() 
    54  
     22from . import weights 
     23from .data import plot_theory 
     24from .direct_model import DataMixin 
    5525 
    5626# CRUFT: old style bumps wrapper which doesn't separate data and model 
     
    6434 
    6535 
    66  
    67 def tic(): 
    68     """ 
    69     Timer function. 
    70  
    71     Use "toc=tic()" to start the clock and "toc()" to measure 
    72     a time interval. 
    73     """ 
    74     then = datetime.datetime.now() 
    75     return lambda: delay(datetime.datetime.now() - then) 
    76  
    77  
    78 def load_data(filename): 
    79     """ 
    80     Load data using a sasview loader. 
    81     """ 
    82     from sas.dataloader.loader import Loader 
    83     loader = Loader() 
    84     data = loader.load(filename) 
    85     if data is None: 
    86         raise IOError("Data %r could not be loaded" % filename) 
    87     return data 
    88  
    89 def plot_data(data, view='log'): 
    90     """ 
    91     Plot data loaded by the sasview loader. 
    92     """ 
    93     if hasattr(data, 'qx_data'): 
    94         _plot_2d_signal(data, data.data, view=view) 
    95     else: 
    96         # Note: kind of weird using the _plot_result1D to plot just the 
    97         # data, but it handles the masking and graph markup already, so 
    98         # do not repeat. 
    99         _plot_result1D(data, None, None, view) 
    100  
    101 def plot_theory(data, theory, view='log'): 
    102     if hasattr(data, 'qx_data'): 
    103         _plot_2d_signal(data, theory, view=view) 
    104     else: 
    105         _plot_result1D(data, theory, None, view, include_data=False) 
    106  
    107  
    108 def empty_data1D(q, resolution=0.05): 
    109     """ 
    110     Create empty 1D data using the given *q* as the x value. 
    111  
    112     *resolution* dq/q defaults to 5%. 
    113     """ 
    114  
    115     from sas.dataloader.data_info import Data1D 
    116  
    117     Iq = 100 * np.ones_like(q) 
    118     dIq = np.sqrt(Iq) 
    119     data = Data1D(q, Iq, dx=resolution * q, dy=dIq) 
    120     data.filename = "fake data" 
    121     data.qmin, data.qmax = q.min(), q.max() 
    122     data.mask = np.zeros(len(Iq), dtype='bool') 
    123     return data 
    124  
    125  
    126 def empty_data2D(qx, qy=None, resolution=0.05): 
    127     """ 
    128     Create empty 2D data using the given mesh. 
    129  
    130     If *qy* is missing, create a square mesh with *qy=qx*. 
    131  
    132     *resolution* dq/q defaults to 5%. 
    133     """ 
    134     from sas.dataloader.data_info import Data2D, Detector 
    135  
    136     if qy is None: 
    137         qy = qx 
    138     Qx, Qy = np.meshgrid(qx, qy) 
    139     Qx, Qy = Qx.flatten(), Qy.flatten() 
    140     Iq = 100 * np.ones_like(Qx) 
    141     dIq = np.sqrt(Iq) 
    142     mask = np.ones(len(Iq), dtype='bool') 
    143  
    144     data = Data2D() 
    145     data.filename = "fake data" 
    146     data.qx_data = Qx 
    147     data.qy_data = Qy 
    148     data.data = Iq 
    149     data.err_data = dIq 
    150     data.mask = mask 
    151     data.qmin = 1e-16 
    152     data.qmax = np.inf 
    153  
    154     # 5% dQ/Q resolution 
    155     if resolution != 0: 
    156         # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf 
    157         # Should have an additional constant which depends on distances and 
    158         # radii of the aperture, pixel dimensions and wavelength spread 
    159         # Instead, assume radial dQ/Q is constant, and perpendicular matches 
    160         # radial (which instead it should be inverse). 
    161         Q = np.sqrt(Qx**2 + Qy**2) 
    162         data.dqx_data = resolution * Q 
    163         data.dqy_data = resolution * Q 
    164  
    165     detector = Detector() 
    166     detector.pixel_size.x = 5 # mm 
    167     detector.pixel_size.y = 5 # mm 
    168     detector.distance = 4 # m 
    169     data.detector.append(detector) 
    170     data.xbins = qx 
    171     data.ybins = qy 
    172     data.source.wavelength = 5 # angstroms 
    173     data.source.wavelength_unit = "A" 
    174     data.Q_unit = "1/A" 
    175     data.I_unit = "1/cm" 
    176     data.q_data = np.sqrt(Qx ** 2 + Qy ** 2) 
    177     data.xaxis("Q_x", "A^{-1}") 
    178     data.yaxis("Q_y", "A^{-1}") 
    179     data.zaxis("Intensity", r"\text{cm}^{-1}") 
    180     return data 
    181  
    182  
    183 def set_beam_stop(data, radius, outer=None): 
    184     """ 
    185     Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
    186     """ 
    187     from sas.dataloader.manipulations import Ringcut 
    188     if hasattr(data, 'qx_data'): 
    189         data.mask = Ringcut(0, radius)(data) 
    190         if outer is not None: 
    191             data.mask += Ringcut(outer, np.inf)(data) 
    192     else: 
    193         data.mask = (data.x >= radius) 
    194         if outer is not None: 
    195             data.mask &= (data.x < outer) 
    196  
    197  
    198 def set_half(data, half): 
    199     """ 
    200     Select half of the data, either "right" or "left". 
    201     """ 
    202     from sas.dataloader.manipulations import Boxcut 
    203     if half == 'right': 
    204         data.mask += \ 
    205             Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data) 
    206     if half == 'left': 
    207         data.mask += \ 
    208             Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data) 
    209  
    210  
    211 def set_top(data, cutoff): 
    212     """ 
    213     Chop the top off the data, above *cutoff*. 
    214     """ 
    215     from sas.dataloader.manipulations import Boxcut 
    216     data.mask += \ 
    217         Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data) 
    218  
    219 def protect(fn): 
    220     def wrapper(*args, **kw): 
    221         try:  
    222             return fn(*args, **kw) 
    223         except: 
    224             traceback.print_exc() 
    225     return wrapper 
    226  
    227 @protect 
    228 def _plot_result1D(data, theory, resid, view, include_data=True): 
    229     """ 
    230     Plot the data and residuals for 1D data. 
    231     """ 
    232     import matplotlib.pyplot as plt 
    233     from numpy.ma import masked_array, masked 
    234     #print "not a number",sum(np.isnan(data.y)) 
    235     #data.y[data.y<0.05] = 0.5 
    236     mdata = masked_array(data.y, data.mask) 
    237     mdata[~np.isfinite(mdata)] = masked 
    238     if view is 'log': 
    239         mdata[mdata <= 0] = masked 
    240  
    241     scale = data.x**4 if view == 'q4' else 1.0 
    242     if resid is not None: 
    243         plt.subplot(121) 
    244  
    245     positive = False 
    246     if include_data: 
    247         plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.') 
    248         positive = positive or (mdata>0).any() 
    249     if theory is not None: 
    250         mtheory = masked_array(theory, mdata.mask) 
    251         plt.plot(data.x, scale*mtheory, '-', hold=True) 
    252         positive = positive or (mtheory>0).any() 
    253     plt.xscale(view) 
    254     plt.yscale('linear' if view == 'q4' or not positive else view) 
    255     plt.xlabel('Q') 
    256     plt.ylabel('I(Q)') 
    257     if resid is not None: 
    258         mresid = masked_array(resid, mdata.mask) 
    259         plt.subplot(122) 
    260         plt.plot(data.x, mresid, 'x') 
    261         plt.ylabel('residuals') 
    262         plt.xlabel('Q') 
    263         plt.xscale(view) 
    264  
    265 # pylint: disable=unused-argument 
    266 @protect 
    267 def _plot_sesans(data, theory, resid, view): 
    268     import matplotlib.pyplot as plt 
    269     plt.subplot(121) 
    270     plt.errorbar(data.x, data.y, yerr=data.dy) 
    271     plt.plot(data.x, theory, '-', hold=True) 
    272     plt.xlabel('spin echo length (nm)') 
    273     plt.ylabel('polarization (P/P0)') 
    274     plt.subplot(122) 
    275     plt.plot(data.x, resid, 'x') 
    276     plt.xlabel('spin echo length (nm)') 
    277     plt.ylabel('residuals (P/P0)') 
    278  
    279 @protect 
    280 def _plot_result2D(data, theory, resid, view): 
    281     """ 
    282     Plot the data and residuals for 2D data. 
    283     """ 
    284     import matplotlib.pyplot as plt 
    285     target = data.data[~data.mask] 
    286     if view == 'log': 
    287         vmin = min(target[target>0].min(), theory[theory>0].min()) 
    288         vmax = max(target.max(), theory.max()) 
    289     else: 
    290         vmin = min(target.min(), theory.min()) 
    291         vmax = max(target.max(), theory.max()) 
    292     #print vmin, vmax 
    293     plt.subplot(131) 
    294     _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax) 
    295     plt.title('data') 
    296     plt.colorbar() 
    297     plt.subplot(132) 
    298     _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax) 
    299     plt.title('theory') 
    300     plt.colorbar() 
    301     plt.subplot(133) 
    302     _plot_2d_signal(data, resid, view='linear') 
    303     plt.title('residuals') 
    304     plt.colorbar() 
    305  
    306 @protect 
    307 def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'): 
    308     """ 
    309     Plot the target value for the data.  This could be the data itself, 
    310     the theory calculation, or the residuals. 
    311  
    312     *scale* can be 'log' for log scale data, or 'linear'. 
    313     """ 
    314     import matplotlib.pyplot as plt 
    315     from numpy.ma import masked_array 
    316  
    317     image = np.zeros_like(data.qx_data) 
    318     image[~data.mask] = signal 
    319     valid = np.isfinite(image) 
    320     if view == 'log': 
    321         valid[valid] = (image[valid] > 0) 
    322         image[valid] = np.log10(image[valid]) 
    323     elif view == 'q4': 
    324         image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2 
    325     image[~valid | data.mask] = 0 
    326     #plottable = Iq 
    327     plottable = masked_array(image, ~valid | data.mask) 
    328     xmin, xmax = min(data.qx_data), max(data.qx_data) 
    329     ymin, ymax = min(data.qy_data), max(data.qy_data) 
    330     # TODO: fix vmin, vmax so it is shared for theory/resid 
    331     vmin = vmax = None 
    332     try: 
    333         if vmin is None: vmin = image[valid & ~data.mask].min() 
    334         if vmax is None: vmax = image[valid & ~data.mask].max() 
    335     except: 
    336         vmin, vmax = 0, 1 
    337     #print vmin,vmax 
    338     plt.imshow(plottable.reshape(128, 128), 
    339                interpolation='nearest', aspect=1, origin='upper', 
    340                extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax) 
    341  
    342  
    34336class Model(object): 
    344     def __init__(self, kernel, **kw): 
    345         from bumps.names import Parameter 
    346  
    347         self.kernel = kernel 
    348         partype = kernel.info['partype'] 
     37    def __init__(self, model, **kw): 
     38        self._sasmodel = model 
     39        partype = model.info['partype'] 
    34940 
    35041        pars = [] 
    351         for p in kernel.info['parameters']: 
     42        for p in model.info['parameters']: 
    35243            name, default, limits = p[0], p[2], p[3] 
    35344            value = kw.pop(name, default) 
     
    37869        return dict((k, getattr(self, k)) for k in self._parameter_names) 
    37970 
    380 class Experiment(object): 
     71 
     72class Experiment(DataMixin): 
    38173    """ 
    38274    Return a bumps wrapper for a SAS model. 
     
    39789 
    39890        # remember inputs so we can inspect from outside 
    399         self.data = data 
    40091        self.model = model 
    40192        self.cutoff = cutoff 
    402         if hasattr(data, 'lam'): 
    403             self.data_type = 'sesans' 
    404         elif hasattr(data, 'qx_data'): 
    405             self.data_type = 'Iqxy' 
    406         else: 
    407             self.data_type = 'Iq' 
    408  
    409         # interpret data 
    410         partype = model.kernel.info['partype'] 
    411         if self.data_type == 'sesans': 
    412             q = sesans.make_q(data.sample.zacceptance, data.Rmax) 
    413             self.index = slice(None, None) 
    414             self.Iq = data.y 
    415             self.dIq = data.dy 
    416             #self._theory = np.zeros_like(q) 
    417             q_vectors = [q] 
    418         elif self.data_type == 'Iqxy': 
    419             q = np.sqrt(data.qx_data**2 + data.qy_data**2) 
    420             qmin = getattr(data, 'qmin', 1e-16) 
    421             qmax = getattr(data, 'qmax', np.inf) 
    422             accuracy = getattr(data, 'accuracy', 'Low') 
    423             self.index = (~data.mask) & (~np.isnan(data.data)) \ 
    424                          & (q >= qmin) & (q <= qmax) 
    425             self.Iq = data.data[self.index] 
    426             self.dIq = data.err_data[self.index] 
    427             self.resolution = Pinhole2D(data=data, index=self.index, 
    428                                         nsigma=3.0, accuracy=accuracy) 
    429             #self._theory = np.zeros_like(self.Iq) 
    430             if not partype['orientation'] and not partype['magnetic']: 
    431                 raise ValueError("not 2D without orientation or magnetic parameters") 
    432                 #qx,qy = self.resolution.q_calc 
    433                 #q_vectors = [np.sqrt(qx**2 + qy**2)] 
    434             else: 
    435                 q_vectors = self.resolution.q_calc 
    436         elif self.data_type == 'Iq': 
    437             self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y) 
    438             self.Iq = data.y[self.index] 
    439             self.dIq = data.dy[self.index] 
    440             if getattr(data, 'dx', None) is not None: 
    441                 q, dq = data.x[self.index], data.dx[self.index] 
    442                 if (dq>0).any(): 
    443                     self.resolution = Pinhole1D(q, dq) 
    444                 else: 
    445                     self.resolution = Perfect1D(q) 
    446             elif (getattr(data, 'dxl', None) is not None and 
    447                   getattr(data, 'dxw', None) is not None): 
    448                 q = data.x[self.index] 
    449                 width = data.dxh[self.index]  # Note: dx 
    450                 self.resolution = Slit1D(data.x[self.index], 
    451                                          width=data.dxh[self.index], 
    452                                          height=data.dxw[self.index]) 
    453             else: 
    454                 self.resolution = Perfect1D(data.x[self.index]) 
    455  
    456             #self._theory = np.zeros_like(self.Iq) 
    457             q_vectors = [self.resolution.q_calc] 
    458         else: 
    459             raise ValueError("Unknown data type") # never gets here 
    460  
    461         # Remember function inputs so we can delay loading the function and 
    462         # so we can save/restore state 
    463         self._fn_inputs = [v for v in q_vectors] 
    464         self._fn = None 
    465  
     93        self._interpret_data(data, model._sasmodel) 
    46694        self.update() 
    46795 
     
    483111    def theory(self): 
    484112        if 'theory' not in self._cache: 
     113            pars = dict((k, v.value) for k,v in self.model.parameters().items()) 
     114            self._cache['theory'] = self._calc_theory(pars, cutoff=self.cutoff) 
     115            """ 
    485116            if self._fn is None: 
    486                 q_input = self.model.kernel.make_input(self._fn_inputs) 
     117                q_input = self.model.kernel.make_input(self._kernel_inputs) 
    487118                self._fn = self.model.kernel(q_input) 
    488119 
     
    495126                result = sesans.hankel(self.data.x, self.data.lam * 1e-9, 
    496127                                       self.data.sample.thickness / 10, 
    497                                        self._fn_inputs[0], Iq_calc) 
     128                                       self._kernel_inputs[0], Iq_calc) 
    498129                self._cache['theory'] = result 
    499130            else: 
    500131                Iq = self.resolution.apply(Iq_calc) 
    501132                self._cache['theory'] = Iq 
     133            """ 
    502134        return self._cache['theory'] 
    503135 
     
    518150        Plot the data and residuals. 
    519151        """ 
    520         data, theory, resid = self.data, self.theory(), self.residuals() 
    521         if self.data_type == 'Iq': 
    522             _plot_result1D(data, theory, resid, view) 
    523         elif self.data_type == 'Iqxy': 
    524             _plot_result2D(data, theory, resid, view) 
    525         elif self.data_type == 'sesans': 
    526             _plot_sesans(data, theory, resid, view) 
    527         else: 
    528             raise ValueError("Unknown data type") 
     152        data, theory, resid = self._data, self.theory(), self.residuals() 
     153        plot_theory(data, theory, resid, view) 
    529154 
    530155    def simulate_data(self, noise=None): 
    531         theory = self.theory() 
    532         if noise is not None: 
    533             self.dIq = theory*noise*0.01 
    534         dy = self.dIq 
    535         y = theory + np.random.randn(*dy.shape) * dy 
    536         self.Iq = y 
    537         if self.data_type == 'Iq': 
    538             self.data.dy[self.index] = dy 
    539             self.data.y[self.index] = y 
    540         elif self.data_type == 'Iqxy': 
    541             self.data.data[self.index] = y 
    542         elif self.data_type == 'sesans': 
    543             self.data.y[self.index] = y 
    544         else: 
    545             raise ValueError("Unknown model") 
     156        Iq = self.theory() 
     157        self._set_data(Iq, noise) 
    546158 
    547159    def save(self, basename): 
    548160        pass 
    549161 
    550     def _get_weights(self, par): 
     162    def remove_get_weights(self, name): 
    551163        """ 
    552164        Get parameter dispersion weights 
    553165        """ 
    554         from . import weights 
    555  
    556         relative = self.model.kernel.info['partype']['pd-rel'] 
    557         limits = self.model.kernel.info['limits'] 
     166        info = self.model.kernel.info 
     167        relative = name in info['partype']['pd-rel'] 
     168        limits = info['limits'][name] 
    558169        disperser, value, npts, width, nsigma = [ 
    559             getattr(self.model, par + ext) 
     170            getattr(self.model, name + ext) 
    560171            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')] 
    561172        value, weight = weights.get_weights( 
    562173            disperser, int(npts.value), width.value, nsigma.value, 
    563             value.value, limits[par], par in relative) 
     174            value.value, limits, relative) 
    564175        return value, weight / np.sum(weight) 
    565176 
     
    567178        # Can't pickle gpu functions, so instead make them lazy 
    568179        state = self.__dict__.copy() 
    569         state['_fn'] = None 
     180        state['_kernel'] = None 
    570181        return state 
    571182 
     
    573184        # pylint: disable=attribute-defined-outside-init 
    574185        self.__dict__ = state 
    575  
    576  
    577 def demo(): 
    578     data = load_data('DEC07086.DAT') 
    579     set_beam_stop(data, 0.004) 
    580     plot_data(data) 
    581     import matplotlib.pyplot as plt; plt.show() 
    582  
    583  
    584 if __name__ == "__main__": 
    585     demo() 
  • sasmodels/core.py

    rbcd3aa3 r7cf2cfd  
    123123    """ 
    124124    relative = name in info['partype']['pd-rel'] 
    125     limits = info['limits'] 
     125    limits = info['limits'][name] 
    126126    disperser = pars.get(name+'_pd_type', 'gaussian') 
    127127    value = pars.get(name, info['defaults'][name]) 
     
    130130    nsigma = pars.get(name+'_pd_nsigma', 3.0) 
    131131    value,weight = weights.get_weights( 
    132         disperser, npts, width, nsigma, 
    133         value, limits[name], relative) 
    134     return value,weight/np.sum(weight) 
     132        disperser, npts, width, nsigma, value, limits, relative) 
     133    return value, weight / np.sum(weight) 
    135134 
    136135def dispersion_mesh(pars): 
  • sasmodels/direct_model.py

    raa4946b r7cf2cfd  
    55from .core import load_model_definition, load_model, make_kernel 
    66from .core import call_kernel, call_ER, call_VR 
     7from . import sesans 
     8from . import resolution 
     9from . import resolution2d 
    710 
    8 class DirectModel: 
    9     def __init__(self, name, q_vectors=None, dtype='single'): 
    10         self.model_definition = load_model_definition(name) 
    11         self.model = load_model(self.model_definition, dtype=dtype) 
    12         if q_vectors is not None: 
    13             q_vectors = [np.ascontiguousarray(q,dtype=dtype) for q in q_vectors] 
    14             self.kernel = make_kernel(self.model, q_vectors) 
     11class DataMixin(object): 
     12    """ 
     13    DataMixin captures the common aspects of evaluating a SAS model for a 
     14    particular data set, including calculating Iq and evaluating the 
     15    resolution function.  It is used in particular by :class:`DirectModel`, 
     16    which evaluates a SAS model parameters as key word arguments to the 
     17    calculator method, and by :class:`bumps_model.Experiment`, which wraps the 
     18    model and data for use with the Bumps fitting engine.  It is not 
     19    currently used by :class:`sasview_model.SasviewModel` since this will 
     20    require a number of changes to SasView before we can do it. 
     21    """ 
     22    def _interpret_data(self, data, model): 
     23        self._data = data 
     24        self._model = model 
     25 
     26        # interpret data 
     27        if hasattr(data, 'lam'): 
     28            self.data_type = 'sesans' 
     29        elif hasattr(data, 'qx_data'): 
     30            self.data_type = 'Iqxy' 
     31        else: 
     32            self.data_type = 'Iq' 
     33 
     34        partype = model.info['partype'] 
     35 
     36        if self.data_type == 'sesans': 
     37            q = sesans.make_q(data.sample.zacceptance, data.Rmax) 
     38            self.index = slice(None, None) 
     39            if data.y is not None: 
     40                self.Iq = data.y 
     41                self.dIq = data.dy 
     42            #self._theory = np.zeros_like(q) 
     43            q_vectors = [q] 
     44        elif self.data_type == 'Iqxy': 
     45            if not partype['orientation'] and not partype['magnetic']: 
     46                raise ValueError("not 2D without orientation or magnetic parameters") 
     47            q = np.sqrt(data.qx_data**2 + data.qy_data**2) 
     48            qmin = getattr(data, 'qmin', 1e-16) 
     49            qmax = getattr(data, 'qmax', np.inf) 
     50            accuracy = getattr(data, 'accuracy', 'Low') 
     51            self.index = ~data.mask & (q >= qmin) & (q <= qmax) 
     52            if data.data is not None: 
     53                self.index &= ~np.isnan(data.data) 
     54                self.Iq = data.data[self.index] 
     55                self.dIq = data.err_data[self.index] 
     56            self.resolution = resolution2d.Pinhole2D(data=data, index=self.index, 
     57                                                     nsigma=3.0, accuracy=accuracy) 
     58            #self._theory = np.zeros_like(self.Iq) 
     59            q_vectors = self.resolution.q_calc 
     60        elif self.data_type == 'Iq': 
     61            self.index = (data.x >= data.qmin) & (data.x <= data.qmax) 
     62            if data.y is not None: 
     63                self.index &= ~np.isnan(data.y) 
     64                self.Iq = data.y[self.index] 
     65                self.dIq = data.dy[self.index] 
     66            if getattr(data, 'dx', None) is not None: 
     67                q, dq = data.x[self.index], data.dx[self.index] 
     68                if (dq>0).any(): 
     69                    self.resolution = resolution.Pinhole1D(q, dq) 
     70                else: 
     71                    self.resolution = resolution.Perfect1D(q) 
     72            elif (getattr(data, 'dxl', None) is not None and 
     73                          getattr(data, 'dxw', None) is not None): 
     74                self.resolution = resolution.Slit1D(data.x[self.index], 
     75                                                    width=data.dxh[self.index], 
     76                                                    height=data.dxw[self.index]) 
     77            else: 
     78                self.resolution = resolution.Perfect1D(data.x[self.index]) 
     79 
     80            #self._theory = np.zeros_like(self.Iq) 
     81            q_vectors = [self.resolution.q_calc] 
     82        else: 
     83            raise ValueError("Unknown data type") # never gets here 
     84 
     85        # Remember function inputs so we can delay loading the function and 
     86        # so we can save/restore state 
     87        self._kernel_inputs = [v for v in q_vectors] 
     88        self._kernel = None 
     89 
     90    def _set_data(self, Iq, noise=None): 
     91        if noise is not None: 
     92            self.dIq = Iq*noise*0.01 
     93        dy = self.dIq 
     94        y = Iq + np.random.randn(*dy.shape) * dy 
     95        self.Iq = y 
     96        if self.data_type == 'Iq': 
     97            self._data.dy[self.index] = dy 
     98            self._data.y[self.index] = y 
     99        elif self.data_type == 'Iqxy': 
     100            self._data.data[self.index] = y 
     101        elif self.data_type == 'sesans': 
     102            self._data.y[self.index] = y 
     103        else: 
     104            raise ValueError("Unknown model") 
     105 
     106    def _calc_theory(self, pars, cutoff=0.0): 
     107        if self._kernel is None: 
     108            q_input = self._model.make_input(self._kernel_inputs) 
     109            self._kernel = self._model(q_input) 
     110 
     111        Iq_calc = call_kernel(self._kernel, pars, cutoff=cutoff) 
     112        if self.data_type == 'sesans': 
     113            result = sesans.hankel(self._data.x, self._data.lam * 1e-9, 
     114                                   self._data.sample.thickness / 10, 
     115                                   self._kernel_inputs[0], Iq_calc) 
     116        else: 
     117            result = self.resolution.apply(Iq_calc) 
     118        return result 
     119 
     120 
     121class DirectModel(DataMixin): 
     122    def __init__(self, data, model, cutoff=1e-5): 
     123        self.model = model 
     124        self.cutoff = cutoff 
     125        self._interpret_data(data, model) 
     126        self.kernel = make_kernel(self.model, self._kernel_inputs) 
    15127    def __call__(self, **pars): 
    16         return call_kernel(self.kernel, pars) 
     128        return self._calc_theory(pars, cutoff=self.cutoff) 
    17129    def ER(self, **pars): 
    18130        return call_ER(self.model.info, pars) 
    19131    def VR(self, **pars): 
    20132        return call_VR(self.model.info, pars) 
     133    def simulate_data(self, noise=None, **pars): 
     134        Iq = self.__call__(**pars) 
     135        self._set_data(Iq, noise=noise) 
    21136 
    22137def demo(): 
    23138    import sys 
     139    from .data import empty_data1D, empty_data2D 
     140 
    24141    if len(sys.argv) < 3: 
    25142        print "usage: python -m sasmodels.direct_model modelname (q|qx,qy) par=val ..." 
     
    27144    model_name = sys.argv[1] 
    28145    call = sys.argv[2].upper() 
    29     if call in ("ER","VR"): 
    30         q_vectors = None 
    31     else: 
    32         values = [float(v) for v in sys.argv[2].split(',')] 
     146    if call not in ("ER","VR"): 
     147        try: 
     148            values = [float(v) for v in call.split(',')] 
     149        except: 
     150            values = [] 
    33151        if len(values) == 1: 
    34             q = values[0] 
    35             q_vectors = [[q]] 
     152            q, = values 
     153            data = empty_data1D([q]) 
    36154        elif len(values) == 2: 
    37155            qx,qy = values 
    38             q_vectors = [[qx],[qy]] 
     156            data = empty_data2D([qx],[qy]) 
    39157        else: 
    40158            print "use q or qx,qy or ER or VR" 
    41159            sys.exit(1) 
    42     model = DirectModel(model_name, q_vectors) 
     160    else: 
     161        data = empty_data1D([0.001])  # Data not used in ER/VR 
     162 
     163    model_definition = load_model_definition(model_name) 
     164    model = load_model(model_definition, dtype='single') 
     165    calculator = DirectModel(data, model) 
    43166    pars = dict((k,float(v)) 
    44167                for pair in sys.argv[3:] 
    45168                for k,v in [pair.split('=')]) 
    46169    if call == "ER": 
    47         print model.ER(**pars) 
     170        print calculator.ER(**pars) 
    48171    elif call == "VR": 
    49         print model.VR(**pars) 
     172        print calculator.VR(**pars) 
    50173    else: 
    51         Iq = model(**pars) 
     174        Iq = calculator(**pars) 
    52175        print Iq[0] 
    53176 
Note: See TracChangeset for help on using the changeset viewer.