source: sasmodels/compare.py @ 09e15be

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 09e15be was 09e15be, checked in by HMP1 <helen.park@…>, 10 years ago

Attempt at faster kernel for TEST,
updated fit.py,
errors in the kernels fixed

  • Property mode set to 100644
File size: 5.3 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import datetime
5from sasmodel import SasModel, load_data, set_beam_stop, plot_data
6
7TIC = None
8def tic():
9    global TIC
10    TIC = datetime.datetime.now()
11
12def toc():
13    now = datetime.datetime.now()
14    return (now-TIC).total_seconds()
15
16def sasview_model(modelname, **pars):
17    modelname = modelname.capitalize()+"Model"
18    sans = __import__('sans.models.'+modelname)
19    ModelClass = getattr(getattr(sans.models,modelname,None),modelname,None)
20    if ModelClass is None:
21        raise ValueError("could not find model %r in sans.models"%modelname)
22    model = ModelClass()
23
24    for k,v in pars.items():
25        if k.endswith("_pd"):
26            model.dispersion[k[:-3]]['width'] = v
27        elif k.endswith("_pd_n"):
28            model.dispersion[k[:-5]]['npts'] = v
29        elif k.endswith("_pd_nsigma"):
30            model.dispersion[k[:-10]]['nsigmas'] = v
31        else:
32            model.setParam(k, v)
33    return model
34
35
36def sasview_eval(model, data):
37    theory = model.evalDistribution([data.qx_data, data.qy_data])
38    return theory
39
40def cyl(N=1):
41    import sys
42    import matplotlib.pyplot as plt
43
44    if len(sys.argv) > 1:
45        N = int(sys.argv[1])
46    data = load_data('JUN03289.DAT')
47    set_beam_stop(data, 0.004)
48
49    pars = dict(
50        scale=1, radius=64.1, length=266.96, sldCyl=.291e-6, sldSolv=5.77e-6, background=0,
51        cyl_theta=0, cyl_phi=0, radius_pd=0.1, radius_pd_n=10, radius_pd_nsigma=3,length_pd=0.1,
52        length_pd_n=5, length_pd_nsigma=3, cyl_theta_pd=0.1, cyl_theta_pd_n=5, cyl_theta_pd_nsigma=3,
53        cyl_phi_pd=0.1, cyl_phi_pd_n=10, cyl_phi_pd_nsigma=3,
54        )
55
56    model = sasview_model('cylinder', **pars)
57    tic()
58    for i in range(N):
59        cpu = sasview_eval(model, data)
60    cpu_time = toc()*1000./N
61
62    from code_cylinder import GpuCylinder
63    model = SasModel(data, GpuCylinder, dtype='f', **pars)
64    tic()
65    for i in range(N):
66        gpu = model.theory()
67    gpu_time = toc()*1000./N
68
69    relerr = (gpu - cpu)/cpu
70    print "max(|(ocl-omp)/ocl|)", max(abs(relerr))
71    print "omp t=%.1f ms"%cpu_time
72    print "ocl t=%.1f ms"%gpu_time
73
74    plt.subplot(131); plot_data(data, cpu); plt.title("omp t=%.1f ms"%cpu_time)
75    plt.subplot(132); plot_data(data, gpu); plt.title("ocl t=%.1f ms"%gpu_time)
76    plt.subplot(133); plot_data(data, 1e8*relerr); plt.title("relerr x 10^8"); plt.colorbar()
77    plt.show()
78
79def ellipse(N=4):
80    import sys
81    import matplotlib.pyplot as plt
82
83    if len(sys.argv) > 1:
84        N = int(sys.argv[1])
85    data = load_data('DEC07133.DAT')
86    set_beam_stop(data, 0.004)
87
88    pars = dict(scale=.027, radius_a=60, radius_b=180, sldEll=.297e-6, sldSolv=5.773e-6, background=4.9,
89                axis_theta=0, axis_phi=90, radius_a_pd=0.1, radius_a_pd_n=10, radius_a_pd_nsigma=3, radius_b_pd=0.1, radius_b_pd_n=10,
90                radius_b_pd_nsigma=3, axis_theta_pd=0.1, axis_theta_pd_n=6, axis_theta_pd_nsigma=3, axis_phi_pd=0.1,
91                axis_phi_pd_n=6, axis_phi_pd_nsigma=3,)
92
93    model = sasview_model('ellipsoid', **pars)
94    tic()
95    for i in range(N):
96        cpu = sasview_eval(model, data)
97    cpu_time = toc()*1000./N
98
99    from code_ellipse import GpuEllipse
100    model = SasModel(data, GpuEllipse, dtype='f', **pars)
101    tic()
102    for i in range(N):
103        gpu = model.theory()
104    gpu_time = toc()*1000./N
105
106    relerr = (gpu - cpu)/cpu
107    print "max(|(ocl-omp)/ocl|)", max(abs(relerr))
108    print "omp t=%.1f ms"%cpu_time
109    print "ocl t=%.1f ms"%gpu_time
110
111    plt.subplot(131); plot_data(data, cpu); plt.title("omp t=%.1f ms"%cpu_time)
112    plt.subplot(132); plot_data(data, gpu); plt.title("ocl t=%.1f ms"%gpu_time)
113    plt.subplot(133); plot_data(data, 1e8*relerr); plt.title("relerr x 10^8"); plt.colorbar()
114    plt.show()
115
116def coreshell(N=4):
117    import sys
118    import matplotlib.pyplot as plt
119
120    if len(sys.argv) > 1:
121        N = int(sys.argv[1])
122    data = load_data('DEC07133.DAT')
123    set_beam_stop(data, 0.004)
124
125    pars = dict(scale= 1.77881e-06, radius=325, thickness=25, length=34.2709,
126                 core_sld=1e-6, shell_sld=.291e-6, solvent_sld=7.105e-6,
127                 background=223.827, axis_theta=90, axis_phi=0,
128                 axis_theta_pd=15.8,
129                 radius_pd=0.1, radius_pd_n=1, radius_pd_nsigma=0,
130                 length_pd=0.1, length_pd_n=1, length_pd_nsigma=0,
131                 thickness_pd=0.1, thickness_pd_n=1, thickness_pd_nsigma=0,
132                 axis_theta_pd_n=10, axis_theta_pd_nsigma=3,
133                 axis_phi_pd=0.0008748, axis_phi_pd_n=10, axis_phi_pd_nsigma=3,)
134
135    model = sasview_model('CoreShellCylinder', **pars)
136    tic()
137    for i in range(N):
138        cpu = sasview_eval(model, data)
139    cpu_time = toc()*1000./N
140
141    from code_coreshellcyl import GpuCoreShellCylinder
142    model = SasModel(data, GpuCoreShellCylinder, dtype='f', **pars)
143    tic()
144    for i in range(N):
145        gpu = model.theory()
146    gpu_time = toc()*1000./N
147
148    relerr = (gpu - cpu)/cpu
149    print "max(|(ocl-omp)/ocl|)", max(abs(relerr))
150    print "omp t=%.1f ms"%cpu_time
151    print "ocl t=%.1f ms"%gpu_time
152
153    plt.subplot(131); plot_data(data, cpu); plt.title("omp t=%.1f ms"%cpu_time)
154    plt.subplot(132); plot_data(data, gpu); plt.title("ocl t=%.1f ms"%gpu_time)
155    plt.subplot(133); plot_data(data, 1e8*relerr); plt.title("relerr x 10^8"); plt.colorbar()
156    plt.show()
157
158if __name__ == "__main__":
159    coreshell()
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
Note: See TracBrowser for help on using the repository browser.