source: sasmodels/explore/jitter.py @ d4c33d6

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since d4c33d6 was d4c33d6, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

send cos-weighted theta values to kernel rather than computing them inside

  • Property mode set to 100644
File size: 8.5 KB
Line 
1"""
2Application to explore the difference between sasview 3.x orientation
3dispersity and possible replacement algorithms.
4"""
5import sys
6
7import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
8import matplotlib.pyplot as plt
9from matplotlib.widgets import Slider, CheckButtons
10from matplotlib import cm
11import numpy as np
12from numpy import pi, cos, sin, sqrt, exp, degrees, radians
13
14def draw_beam(ax, view=(0, 0)):
15    #ax.plot([0,0],[0,0],[1,-1])
16    #ax.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8)
17
18    steps = 25
19    u = np.linspace(0, 2 * np.pi, steps)
20    v = np.linspace(-1, 1, steps)
21
22    r = 0.02
23    x = r*np.outer(np.cos(u), np.ones_like(v))
24    y = r*np.outer(np.sin(u), np.ones_like(v))
25    z = 1.3*np.outer(np.ones_like(u), v)
26
27    theta, phi = view
28    shape = x.shape
29    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
30    points = Rz(phi)*Ry(theta)*points
31    x, y, z = [v.reshape(shape) for v in points]
32
33    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5)
34
35def draw_jitter(ax, view, jitter):
36    size = [0.1, 0.4, 1.0]
37    draw_shape = draw_parallelepiped
38    #draw_shape = draw_ellipsoid
39
40    #np.random.seed(10)
41    #cloud = np.random.randn(10,3)
42    cloud = [
43        [-1, -1, -1],
44        [-1, -1,  0],
45        [-1, -1,  1],
46        [-1,  0, -1],
47        [-1,  0,  0],
48        [-1,  0,  1],
49        [-1,  1, -1],
50        [-1,  1,  0],
51        [-1,  1,  1],
52        [ 0, -1, -1],
53        [ 0, -1,  0],
54        [ 0, -1,  1],
55        [ 0,  0, -1],
56        [ 0,  0,  0],
57        [ 0,  0,  1],
58        [ 0,  1, -1],
59        [ 0,  1,  0],
60        [ 0,  1,  1],
61        [ 1, -1, -1],
62        [ 1, -1,  0],
63        [ 1, -1,  1],
64        [ 1,  0, -1],
65        [ 1,  0,  0],
66        [ 1,  0,  1],
67        [ 1,  1, -1],
68        [ 1,  1,  0],
69        [ 1,  1,  1],
70    ]
71    dtheta, dphi, dpsi = jitter
72    if dtheta == 0:
73        cloud = [v for v in cloud if v[0] == 0]
74    if dphi == 0:
75        cloud = [v for v in cloud if v[1] == 0]
76    if dpsi == 0:
77        cloud = [v for v in cloud if v[2] == 0]
78    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)
79    for point in cloud:
80        delta = [dtheta*point[0], dphi*point[1], dpsi*point[2]]
81        draw_shape(ax, size, view, delta, alpha=0.8)
82    for v in 'xyz':
83        a, b, c = size
84        lim = np.sqrt(a**2+b**2+c**2)
85        getattr(ax, 'set_'+v+'lim')([-lim, lim])
86        getattr(ax, v+'axis').label.set_text(v)
87
88def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
89    a,b,c = size
90    u = np.linspace(0, 2 * np.pi, steps)
91    v = np.linspace(0, np.pi, steps)
92    x = a*np.outer(np.cos(u), np.sin(v))
93    y = b*np.outer(np.sin(u), np.sin(v))
94    z = c*np.outer(np.ones_like(u), np.cos(v))
95    x, y, z = transform_xyz(view, jitter, x, y, z)
96
97    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha)
98
99    draw_labels(ax, view, jitter, [
100         ('c+', [ 0, 0, c], [ 1, 0, 0]),
101         ('c-', [ 0, 0,-c], [ 0, 0,-1]),
102         ('a+', [ a, 0, 0], [ 0, 0, 1]),
103         ('a-', [-a, 0, 0], [ 0, 0,-1]),
104         ('b+', [ 0, b, 0], [-1, 0, 0]),
105         ('b-', [ 0,-b, 0], [-1, 0, 0]),
106    ])
107
108def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
109    a,b,c = size
110    x = a*np.array([ 1,-1, 1,-1, 1,-1, 1,-1])
111    y = b*np.array([ 1, 1,-1,-1, 1, 1,-1,-1])
112    z = c*np.array([ 1, 1, 1, 1,-1,-1,-1,-1])
113    tri = np.array([
114        # counter clockwise triangles
115        # z: up/down, x: right/left, y: front/back
116        [0,1,2], [3,2,1], # top face
117        [6,5,4], [5,6,7], # bottom face
118        [0,2,6], [6,4,0], # right face
119        [1,5,7], [7,3,1], # left face
120        [2,3,6], [7,6,3], # front face
121        [4,1,0], [5,1,4], # back face
122    ])
123
124    x, y, z = transform_xyz(view, jitter, x, y, z)
125    ax.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha)
126
127    draw_labels(ax, view, jitter, [
128         ('c+', [ 0, 0, c], [ 1, 0, 0]),
129         ('c-', [ 0, 0,-c], [ 0, 0,-1]),
130         ('a+', [ a, 0, 0], [ 0, 0, 1]),
131         ('a-', [-a, 0, 0], [ 0, 0,-1]),
132         ('b+', [ 0, b, 0], [-1, 0, 0]),
133         ('b-', [ 0,-b, 0], [-1, 0, 0]),
134    ])
135
136def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gauss'):
137    theta, phi, psi = view
138    dtheta, dphi, dpsi = jitter
139    if dist == 'gauss':
140        t = np.linspace(-3, 3, n)
141        weights = exp(-0.5*t**2)
142    elif dist == 'rect':
143        t = np.linspace(0, 1, n)
144        weights = np.ones_like(t)
145    else:
146        raise ValueError("expected dist to be 'gauss' or 'rect'")
147
148    # mesh in theta, phi formed by rotating z
149    z = np.matrix([[0], [0], [radius]])
150    points = np.hstack([Rx(phi_i)*Ry(theta_i)*z
151                        for theta_i in dtheta*t
152                        for phi_i in dphi*t])
153    # rotate relative to beam
154    points = orient_relative_to_beam(view, points)
155
156    w = np.outer(weights, weights)
157
158    x, y, z = [np.array(v).flatten() for v in points]
159    ax.scatter(x, y, z, c=w.flatten(), marker='o', vmin=0., vmax=1.)
160
161def Rx(angle):
162    a = radians(angle)
163    R = [[1., 0., 0.],
164         [0.,  cos(a), sin(a)],
165         [0., -sin(a), cos(a)]]
166    return np.matrix(R)
167
168def Ry(angle):
169    a = radians(angle)
170    R = [[cos(a), 0., -sin(a)],
171         [0., 1., 0.],
172         [sin(a), 0.,  cos(a)]]
173    return np.matrix(R)
174
175def Rz(angle):
176    a = radians(angle)
177    R = [[cos(a), -sin(a), 0.],
178         [sin(a),  cos(a), 0.],
179         [0., 0., 1.]]
180    return np.matrix(R)
181
182def transform_xyz(view, jitter, x, y, z):
183    x, y, z = [np.asarray(v) for v in (x, y, z)]
184    shape = x.shape
185    points = np.matrix([x.flatten(),y.flatten(),z.flatten()])
186    points = apply_jitter(jitter, points)
187    points = orient_relative_to_beam(view, points)
188    x, y, z = [np.array(v).reshape(shape) for v in points]
189    return x, y, z
190
191def apply_jitter(jitter, points):
192    dtheta, dphi, dpsi = jitter
193    points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points
194    return points
195
196def orient_relative_to_beam(view, points):
197    theta, phi, psi = view
198    points = Rz(phi)*Ry(theta)*Rz(psi)*points
199    return points
200
201def draw_labels(ax, view, jitter, text):
202    labels, locations, orientations = zip(*text)
203    px, py, pz = zip(*locations)
204    dx, dy, dz = zip(*orientations)
205
206    px, py, pz = transform_xyz(view, jitter, px, py, pz)
207    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)
208
209    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)):
210        zdir = np.asarray(zdir).flatten()
211        ax.text(p[0], p[1], p[2], label, zdir=zdir)
212
213def draw_sphere(ax, radius=10., steps=100):
214    u = np.linspace(0, 2 * np.pi, steps)
215    v = np.linspace(0, np.pi, steps)
216
217    x = radius * np.outer(np.cos(u), np.sin(v))
218    y = radius * np.outer(np.sin(u), np.sin(v))
219    z = radius * np.outer(np.ones(np.size(u)), np.cos(v))
220    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
221
222def main():
223    #plt.hold(True)
224    plt.set_cmap('gist_earth')
225    plt.clf()
226    #gs = gridspec.GridSpec(2,1,height_ratios=[4,1])
227    #ax = plt.subplot(gs[0], projection='3d')
228    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
229
230    theta, dtheta = 70., 10.
231    phi, dphi = -45., 3.
232    psi, dpsi = -45., 3.
233    theta, phi, psi = 0, 0, 0
234    dtheta, dphi, dpsi = 0, 0, 0
235    #dist = 'rect'
236    dist = 'gauss'
237
238    axcolor = 'lightgoldenrodyellow'
239
240    axtheta  = plt.axes([0.1, 0.15, 0.45, 0.04], axisbg=axcolor)
241    axphi = plt.axes([0.1, 0.1, 0.45, 0.04], axisbg=axcolor)
242    axpsi = plt.axes([0.1, 0.05, 0.45, 0.04], axisbg=axcolor)
243    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
244    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
245    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)
246
247    axdtheta  = plt.axes([0.75, 0.15, 0.15, 0.04], axisbg=axcolor)
248    axdphi = plt.axes([0.75, 0.1, 0.15, 0.04], axisbg=axcolor)
249    axdpsi= plt.axes([0.75, 0.05, 0.15, 0.04], axisbg=axcolor)
250    sdtheta = Slider(axdtheta, 'dTheta', 0, 30, valinit=dtheta)
251    sdphi = Slider(axdphi, 'dPhi', 0, 30, valinit=dphi)
252    sdpsi = Slider(axdpsi, 'dPsi', 0, 30, valinit=dpsi)
253
254    def update(val, axis=None):
255        view = stheta.val, sphi.val, spsi.val
256        jitter = sdtheta.val, sdphi.val, sdpsi.val
257        ax.cla()
258        draw_beam(ax, (0, 0))
259        draw_jitter(ax, view, jitter)
260        #draw_jitter(ax, view, (0,0,0))
261        draw_mesh(ax, view, jitter)
262        plt.gcf().canvas.draw()
263
264    stheta.on_changed(lambda v: update(v,'theta'))
265    sphi.on_changed(lambda v: update(v, 'phi'))
266    spsi.on_changed(lambda v: update(v, 'psi'))
267    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
268    sdphi.on_changed(lambda v: update(v, 'dphi'))
269    sdpsi.on_changed(lambda v: update(v, 'dpsi'))
270
271    update(None, 'phi')
272
273    plt.show()
274
275if __name__ == "__main__":
276    main()
Note: See TracBrowser for help on using the repository browser.