Changeset 13d86bc in sasmodels for compare-new.py


Ignore:
Timestamp:
Aug 26, 2014 9:06:58 AM (10 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
a7684e5
Parents:
32c160a
Message:

1D model comparison with sasview

File:
1 edited

Legend:

Unmodified
Added
Removed
  • compare-new.py

    rce27e21 r13d86bc  
    66import numpy as np 
    77 
    8 from sasmodels.core import BumpsModel, fake_data2D, set_beam_stop, plot_data, \ 
    9     tic, opencl_model, dll_model 
     8from sasmodels.core import BumpsModel, plot_data, tic, opencl_model, dll_model 
    109 
    1110def sasview_model(modelname, **pars): 
     
    2726        elif k.endswith("_pd_nsigma"): 
    2827            model.dispersion[k[:-10]]['nsigmas'] = v 
     28        elif k.endswith("_pd_type"): 
     29            model.dispersion[k[:-8]]['type'] = v 
    2930        else: 
    3031            model.setParam(k, v) 
    3132    return model 
    3233 
     34def load_opencl(modelname, dtype='single'): 
     35    sasmodels = __import__('sasmodels.models.'+modelname) 
     36    module = getattr(sasmodels.models, modelname, None) 
     37    kernel = opencl_model(module, dtype=dtype) 
     38    return kernel 
     39 
     40def load_dll(modelname, dtype='single'): 
     41    sasmodels = __import__('sasmodels.models.'+modelname) 
     42    module = getattr(sasmodels.models, modelname, None) 
     43    kernel = dll_model(module, dtype=dtype) 
     44    return kernel 
     45 
    3346 
    3447def compare(Ncpu, cpuname, cpupars, Ngpu, gpuname, gpupars): 
     
    3649    #from sasmodels.core import load_data 
    3750    #data = load_data('December/DEC07098.DAT') 
    38     data = fake_data2D(np.linspace(-0.05, 0.05, 128)) 
    39     set_beam_stop(data, 0.004) 
     51    from sasmodels.core import empty_data1D 
     52    data = empty_data1D(np.logspace(-4, -1, 128)) 
     53    #from sasmodels.core import empty_2D, set_beam_stop 
     54    #data = empty_data2D(np.linspace(-0.05, 0.05, 128)) 
     55    #set_beam_stop(data, 0.004) 
     56    is2D = hasattr(data, 'qx_data') 
    4057 
    4158    if Ngpu > 0: 
    42         gpumodel = opencl_model(gpuname, dtype='single') 
     59        gpumodel = load_opencl(gpuname, dtype='single') 
    4360        model = BumpsModel(data, gpumodel, **gpupars) 
    4461        toc = tic() 
     
    5269 
    5370    if 0 and Ncpu > 0: # Hack to compare ctypes vs. opencl 
    54         dllmodel = dll_model(gpuname) 
     71        dllmodel = load_dll(gpuname, dtype='double') 
    5572        model = BumpsModel(data, dllmodel, **gpupars) 
    5673        toc = tic() 
     
    7592        cpumodel = sasview_model(cpuname, **cpupars) 
    7693        toc = tic() 
    77         for i in range(Ncpu): 
    78             cpu = cpumodel.evalDistribution([data.qx_data, data.qy_data]) 
     94        if is2D: 
     95            for i in range(Ncpu): 
     96                cpu = cpumodel.evalDistribution([data.qx_data, data.qy_data]) 
     97        else: 
     98            for i in range(Ncpu): 
     99                cpu = cpumodel.evalDistribution(data.x) 
    79100        cpu_time = toc()*1000./Ncpu 
    80101        print "sasview t=%.1f ms, intensity=%.0f"%(cpu_time, sum(cpu[model.index])) 
     
    92113    if Ncpu > 0: 
    93114        if Ngpu > 0: plt.subplot(131) 
    94         plot_data(data, cpu) 
     115        plot_data(data, cpu, scale='log') 
    95116        plt.title("omp t=%.1f ms"%cpu_time) 
    96117    if Ngpu > 0: 
    97118        if Ncpu > 0: plt.subplot(132) 
    98         plot_data(data, gpu) 
     119        plot_data(data, gpu, scale='log') 
    99120        plt.title("ocl t=%.1f ms"%gpu_time) 
    100121    if Ncpu > 0 and Ngpu > 0: 
    101122        plt.subplot(133) 
    102         plot_data(data, 1e8*relerr) 
     123        plot_data(data, 1e8*relerr, scale='linear') 
    103124        plt.title("max rel err = %.3g"%max(abs(relerr))) 
    104         plt.colorbar() 
     125        if is2D: plt.colorbar() 
    105126    plt.show() 
    106127 
Note: See TracChangeset for help on using the changeset viewer.