Changeset 7cf2cfd in sasmodels for compare.py


Ignore:
Timestamp:
Nov 22, 2015 9:37:15 PM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3b4243d
Parents:
677ccf1
Message:

refactor compare.py so that bumps/sasview not required for simple tests

File:
1 edited

Legend:

Unmodified
Added
Removed
  • compare.py

    r29fc2a3 r7cf2cfd  
    66from os.path import basename, dirname, join as joinpath 
    77import glob 
     8import datetime 
    89 
    910import numpy as np 
     
    1314 
    1415 
    15 from sasmodels.bumps_model import Model, Experiment, plot_theory, tic 
    1616from sasmodels import core 
    1717from sasmodels import kerneldll 
     18from sasmodels.data import plot_theory, empty_data1D, empty_data2D 
     19from sasmodels.direct_model import DirectModel 
    1820from sasmodels.convert import revert_model 
    1921kerneldll.ALLOW_SINGLE_PRECISION_DLLS = True 
     
    2224MODELS = [basename(f)[:-3] 
    2325          for f in sorted(glob.glob(joinpath(ROOT,"sasmodels","models","[a-zA-Z]*.py")))] 
     26 
     27# CRUFT python 2.6 
     28if not hasattr(datetime.timedelta, 'total_seconds'): 
     29    def delay(dt): 
     30        """Return number date-time delta as number seconds""" 
     31        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds 
     32else: 
     33    def delay(dt): 
     34        """Return number date-time delta as number seconds""" 
     35        return dt.total_seconds() 
     36 
     37 
     38def tic(): 
     39    """ 
     40    Timer function. 
     41 
     42    Use "toc=tic()" to start the clock and "toc()" to measure 
     43    a time interval. 
     44    """ 
     45    then = datetime.datetime.now() 
     46    return lambda: delay(datetime.datetime.now() - then) 
     47 
     48 
     49def set_beam_stop(data, radius, outer=None): 
     50    """ 
     51    Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
     52 
     53    Note: this function does not use the sasview package 
     54    """ 
     55    if hasattr(data, 'qx_data'): 
     56        q = np.sqrt(data.qx_data**2 + data.qy_data**2) 
     57        data.mask = (q < radius) 
     58        if outer is not None: 
     59            data.mask |= (q >= outer) 
     60    else: 
     61        data.mask = (data.x < radius) 
     62        if outer is not None: 
     63            data.mask |= (data.x >= outer) 
    2464 
    2565 
     
    99139        if p.endswith("_pd"): pars[p] = 0 
    100140 
    101 def eval_sasview(name, pars, data, Nevals=1): 
     141def eval_sasview(model_definition, pars, data, Nevals=1): 
    102142    from sas.models.qsmearing import smear_selection 
    103     model = sasview_model(name, **pars) 
     143    model = sasview_model(model_definition, **pars) 
    104144    smearer = smear_selection(data, model=model) 
    105145    value = None  # silence the linter 
     
    131171        print "... trying again with single precision" 
    132172        model = core.load_model(model_definition, dtype='single', platform="ocl") 
    133     problem = Experiment(data, Model(model, **pars), cutoff=cutoff) 
     173    calculator = DirectModel(data, model, cutoff=cutoff) 
    134174    value = None  # silence the linter 
    135175    toc = tic() 
    136176    for _ in range(max(Nevals, 1)):  # force at least one eval 
    137         #pars['scale'] = np.random.rand() 
    138         problem.update() 
    139         value = problem.theory() 
     177        value = calculator(**pars) 
    140178    average_time = toc()*1000./Nevals 
    141179    return value, average_time 
    142180 
     181 
    143182def eval_ctypes(model_definition, pars, data, dtype='double', Nevals=1, cutoff=0.): 
    144183    model = core.load_model(model_definition, dtype=dtype, platform="dll") 
    145     problem = Experiment(data, Model(model, **pars), cutoff=cutoff) 
     184    calculator = DirectModel(data, model, cutoff=cutoff) 
    146185    value = None  # silence the linter 
    147186    toc = tic() 
    148187    for _ in range(max(Nevals, 1)):  # force at least one eval 
    149         problem.update() 
    150         value = problem.theory() 
     188        value = calculator(**pars) 
    151189    average_time = toc()*1000./Nevals 
    152190    return value, average_time 
    153191 
     192 
    154193def make_data(qmax, is2D, Nq=128, resolution=0.0, accuracy='Low', view='log'): 
    155194    if is2D: 
    156         from sasmodels.bumps_model import empty_data2D, set_beam_stop 
    157195        data = empty_data2D(np.linspace(-qmax, qmax, Nq), resolution=resolution) 
    158196        data.accuracy = accuracy 
     
    160198        index = ~data.mask 
    161199    else: 
    162         from sasmodels.bumps_model import empty_data1D 
    163200        if view == 'log': 
    164201            qmax = math.log10(qmax) 
     
    190227 
    191228    # randomize parameters 
    192     pars.update(set_pars) 
     229    #pars.update(set_pars)  # set value before random to control range 
    193230    if '-random' in opts or '-random' in opt_values: 
    194231        seed = int(opt_values['-random']) if '-random' in opt_values else None 
    195232        pars, seed = randomize_model(name, pars, seed=seed) 
    196233        print "Randomize using -random=%i"%seed 
     234    pars.update(set_pars)  # set value after random to control value 
    197235 
    198236    # parameter selection 
     
    217255        print "ctypes t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)) 
    218256    elif Ncpu > 0: 
    219         cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu) 
    220         comp = "sasview" 
    221         print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)) 
     257        try: 
     258            cpu, cpu_time = eval_sasview(model_definition, pars, data, Ncpu) 
     259            comp = "sasview" 
     260            #print "ocl/sasview", (ocl-pars['background'])/(cpu-pars['background']) 
     261            print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu)) 
     262        except ImportError: 
     263            Ncpu = 0 
    222264 
    223265    # Compare, but only if computing both forms 
     
    238280    if Ncpu > 0: 
    239281        if Nocl > 0: plt.subplot(131) 
    240         plot_theory(data, cpu, view=view) 
     282        plot_theory(data, cpu, view=view, plot_data=False) 
    241283        plt.title("%s t=%.1f ms"%(comp,cpu_time)) 
    242         cbar_title = "log I" 
     284        #cbar_title = "log I" 
    243285    if Nocl > 0: 
    244286        if Ncpu > 0: plt.subplot(132) 
    245         plot_theory(data, ocl, view=view) 
     287        plot_theory(data, ocl, view=view, plot_data=False) 
    246288        plt.title("opencl t=%.1f ms"%ocl_time) 
    247         cbar_title = "log I" 
     289        #cbar_title = "log I" 
    248290    if Ncpu > 0 and Nocl > 0: 
    249291        plt.subplot(133) 
     
    253295            err,errstr,errview = abs(relerr), "rel err", "log" 
    254296        #err,errstr = ocl/cpu,"ratio" 
    255         plot_theory(data, err, view=errview) 
     297        plot_theory(data, None, resid=err, view=errview, plot_data=False) 
    256298        plt.title("max %s = %.3g"%(errstr, max(abs(err)))) 
    257         cbar_title = errstr if errview=="linear" else "log "+errstr 
    258     if is2D: 
    259         h = plt.colorbar() 
    260         h.ax.set_title(cbar_title) 
     299        #cbar_title = errstr if errview=="linear" else "log "+errstr 
     300    #if is2D: 
     301    #    h = plt.colorbar() 
     302    #    h.ax.set_title(cbar_title) 
    261303 
    262304    if Ncpu > 0 and Nocl > 0 and '-hist' in opts: 
     
    320362 
    321363Available models: 
    322  
    323     %s 
    324364""" 
     365 
    325366 
    326367NAME_OPTIONS = set([ 
     
    342383    ] 
    343384 
     385def columnize(L, indent="", width=79): 
     386    column_width = max(len(w) for w in L) + 1 
     387    num_columns = (width - len(indent)) // column_width 
     388    num_rows = len(L) // num_columns 
     389    L = L + [""] * (num_rows*num_columns - len(L)) 
     390    columns = [L[k*num_rows:(k+1)*num_rows] for k in range(num_columns)] 
     391    lines = [" ".join("%-*s"%(column_width, entry) for entry in row) 
     392             for row in zip(*columns)] 
     393    output = indent + ("\n"+indent).join(lines) 
     394    return output 
     395 
     396 
    344397def get_demo_pars(name): 
    345398    import sasmodels.models 
     
    355408    models = "\n    ".join("%-15s"%v for v in MODELS) 
    356409    if len(args) == 0: 
    357         print(USAGE%models) 
     410        print(USAGE) 
     411        print(columnize(MODELS, indent="  ")) 
    358412        sys.exit(1) 
    359413    if args[0] not in MODELS: 
Note: See TracChangeset for help on using the changeset viewer.