source: sasmodels/explore/jitter.py @ a0ebc96

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since a0ebc96 was 8678a34, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

code cleanup on jitter visualization

  • Property mode set to 100644
File size: 8.6 KB
Line 
1"""
2Application to explore the difference between sasview 3.x orientation
3dispersity and possible replacement algorithms.
4"""
5import sys
6
7import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
8import matplotlib.pyplot as plt
9from matplotlib.widgets import Slider, CheckButtons
10from matplotlib import cm
11import numpy as np
12from numpy import pi, cos, sin, sqrt, exp, degrees, radians
13
14def draw_beam(ax, view=(0, 0)):
15    #ax.plot([0,0],[0,0],[1,-1])
16    #ax.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8)
17
18    steps = 25
19    u = np.linspace(0, 2 * np.pi, steps)
20    v = np.linspace(-1, 1, steps)
21
22    r = 0.02
23    x = r*np.outer(np.cos(u), np.ones_like(v))
24    y = r*np.outer(np.sin(u), np.ones_like(v))
25    z = 1.3*np.outer(np.ones_like(u), v)
26
27    theta, phi = view
28    shape = x.shape
29    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
30    points = Rz(phi)*Ry(theta)*points
31    x, y, z = [v.reshape(shape) for v in points]
32
33    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5)
34
35def draw_jitter(ax, view, jitter):
36    size = [0.1, 0.4, 1.0]
37    draw_shape = draw_parallelepiped
38    #draw_shape = draw_ellipsoid
39
40    #np.random.seed(10)
41    #cloud = np.random.randn(10,3)
42    cloud = [
43        [-1, -1, -1],
44        [-1, -1,  0],
45        [-1, -1,  1],
46        [-1,  0, -1],
47        [-1,  0,  0],
48        [-1,  0,  1],
49        [-1,  1, -1],
50        [-1,  1,  0],
51        [-1,  1,  1],
52        [ 0, -1, -1],
53        [ 0, -1,  0],
54        [ 0, -1,  1],
55        [ 0,  0, -1],
56        [ 0,  0,  0],
57        [ 0,  0,  1],
58        [ 0,  1, -1],
59        [ 0,  1,  0],
60        [ 0,  1,  1],
61        [ 1, -1, -1],
62        [ 1, -1,  0],
63        [ 1, -1,  1],
64        [ 1,  0, -1],
65        [ 1,  0,  0],
66        [ 1,  0,  1],
67        [ 1,  1, -1],
68        [ 1,  1,  0],
69        [ 1,  1,  1],
70    ]
71    dtheta, dphi, dpsi = jitter
72    if dtheta == 0:
73        cloud = [v for v in cloud if v[0] == 0]
74    if dphi == 0:
75        cloud = [v for v in cloud if v[1] == 0]
76    if dpsi == 0:
77        cloud = [v for v in cloud if v[2] == 0]
78    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)
79    for point in cloud:
80        delta = [dtheta*point[0], dphi*point[1], dpsi*point[2]]
81        draw_shape(ax, size, view, delta, alpha=0.8)
82    for v in 'xyz':
83        a, b, c = size
84        lim = np.sqrt(a**2+b**2+c**2)
85        getattr(ax, 'set_'+v+'lim')([-lim, lim])
86        getattr(ax, v+'axis').label.set_text(v)
87
88def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
89    a,b,c = size
90    u = np.linspace(0, 2 * np.pi, steps)
91    v = np.linspace(0, np.pi, steps)
92    x = a*np.outer(np.cos(u), np.sin(v))
93    y = b*np.outer(np.sin(u), np.sin(v))
94    z = c*np.outer(np.ones_like(u), np.cos(v))
95    x, y, z = transform_xyz(view, jitter, x, y, z)
96
97    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha)
98
99    draw_labels(ax, view, jitter, [
100         ('c+', [ 0, 0, c], [ 1, 0, 0]),
101         ('c-', [ 0, 0,-c], [ 0, 0,-1]),
102         ('a+', [ a, 0, 0], [ 0, 0, 1]),
103         ('a-', [-a, 0, 0], [ 0, 0,-1]),
104         ('b+', [ 0, b, 0], [-1, 0, 0]),
105         ('b-', [ 0,-b, 0], [-1, 0, 0]),
106    ])
107
108def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
109    a,b,c = size
110    x = a*np.array([ 1,-1, 1,-1, 1,-1, 1,-1])
111    y = b*np.array([ 1, 1,-1,-1, 1, 1,-1,-1])
112    z = c*np.array([ 1, 1, 1, 1,-1,-1,-1,-1])
113    tri = np.array([
114        # counter clockwise triangles
115        # z: up/down, x: right/left, y: front/back
116        [0,1,2], [3,2,1], # top face
117        [6,5,4], [5,6,7], # bottom face
118        [0,2,6], [6,4,0], # right face
119        [1,5,7], [7,3,1], # left face
120        [2,3,6], [7,6,3], # front face
121        [4,1,0], [5,1,4], # back face
122    ])
123
124    x, y, z = transform_xyz(view, jitter, x, y, z)
125    ax.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha)
126
127    draw_labels(ax, view, jitter, [
128         ('c+', [ 0, 0, c], [ 1, 0, 0]),
129         ('c-', [ 0, 0,-c], [ 0, 0,-1]),
130         ('a+', [ a, 0, 0], [ 0, 0, 1]),
131         ('a-', [-a, 0, 0], [ 0, 0,-1]),
132         ('b+', [ 0, b, 0], [-1, 0, 0]),
133         ('b-', [ 0,-b, 0], [-1, 0, 0]),
134    ])
135
136def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gauss'):
137    theta, phi, psi = view
138    dtheta, dphi, dpsi = jitter
139    if dist == 'gauss':
140        t = np.linspace(-3, 3, n)
141        weights = exp(-0.5*t**2)
142    elif dist == 'rect':
143        t = np.linspace(0, 1, n)
144        weights = np.ones_like(t)
145    else:
146        raise ValueError("expected dist to be 'gauss' or 'rect'")
147
148    # mesh in theta, phi formed by rotating z
149    z = np.matrix([[0], [0], [radius]])
150    points = np.hstack([Rx(phi_i)*Ry(theta_i)*z
151                        for theta_i in dtheta*t
152                        for phi_i in dphi*t])
153    # rotate relative to beam
154    points = orient_relative_to_beam(view, points)
155
156    w = np.outer(weights, weights)
157
158    x, y, z = [np.array(v).flatten() for v in points]
159    ax.scatter(x, y, z, c=w.flatten(), marker='o', vmin=0., vmax=1.)
160
161def Rx(angle):
162    a = radians(angle)
163    R = [[1., 0., 0.],
164         [0.,  cos(a), sin(a)],
165         [0., -sin(a), cos(a)]]
166    return np.matrix(R)
167
168def Ry(angle):
169    a = radians(angle)
170    R = [[cos(a), 0., -sin(a)],
171         [0., 1., 0.],
172         [sin(a), 0.,  cos(a)]]
173    return np.matrix(R)
174
175def Rz(angle):
176    a = radians(angle)
177    R = [[cos(a), -sin(a), 0.],
178         [sin(a),  cos(a), 0.],
179         [0., 0., 1.]]
180    return np.matrix(R)
181
182def transform_xyz(view, jitter, x, y, z):
183    x, y, z = [np.asarray(v) for v in (x, y, z)]
184    shape = x.shape
185    points = np.matrix([x.flatten(),y.flatten(),z.flatten()])
186    points = apply_jitter(jitter, points)
187    points = orient_relative_to_beam(view, points)
188    x, y, z = [np.array(v).reshape(shape) for v in points]
189    return x, y, z
190
191def apply_jitter(jitter, points):
192    dtheta, dphi, dpsi = jitter
193    points = Rz(dpsi)*Ry(dtheta)*Rx(dphi)*points
194    return points
195
196def orient_relative_to_beam(view, points):
197    theta, phi, psi = view
198    points = Rz(phi)*Ry(theta)*Rz(psi)*points
199    return points
200
201def draw_labels(ax, view, jitter, text):
202    labels, locations, orientations = zip(*text)
203    px, py, pz = zip(*locations)
204    dx, dy, dz = zip(*orientations)
205
206    px, py, pz = transform_xyz(view, jitter, px, py, pz)
207    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)
208
209    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)):
210        zdir = np.asarray(zdir).flatten()
211        ax.text(p[0], p[1], p[2], label, zdir=zdir)
212
213def draw_sphere(ax, radius=10., steps=100):
214    u = np.linspace(0, 2 * np.pi, steps)
215    v = np.linspace(0, np.pi, steps)
216
217    x = radius * np.outer(np.cos(u), np.sin(v))
218    y = radius * np.outer(np.sin(u), np.sin(v))
219    z = radius * np.outer(np.ones(np.size(u)), np.cos(v))
220    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
221
222def main():
223    #plt.hold(True)
224    plt.set_cmap('gist_earth')
225    plt.clf()
226    #gs = gridspec.GridSpec(2,1,height_ratios=[4,1])
227    #ax = plt.subplot(gs[0], projection='3d')
228    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
229
230    theta, dtheta = 70., 10.
231    phi, dphi = -45., 3.
232    psi, dpsi = -45., 3.
233    theta, phi, psi = 0, 0, 0
234    dtheta, dphi, dpsi = 0, 0, 0
235    #dist = 'rect'
236    dist = 'gauss'
237
238    axcolor = 'lightgoldenrodyellow'
239
240    axtheta  = plt.axes([0.1, 0.15, 0.45, 0.04], axisbg=axcolor)
241    axphi = plt.axes([0.1, 0.1, 0.45, 0.04], axisbg=axcolor)
242    axpsi = plt.axes([0.1, 0.05, 0.45, 0.04], axisbg=axcolor)
243    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
244    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
245    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)
246
247    axdtheta  = plt.axes([0.75, 0.15, 0.15, 0.04], axisbg=axcolor)
248    axdphi = plt.axes([0.75, 0.1, 0.15, 0.04], axisbg=axcolor)
249    axdpsi= plt.axes([0.75, 0.05, 0.15, 0.04], axisbg=axcolor)
250    sdtheta = Slider(axdtheta, 'dTheta', 0, 30, valinit=dtheta)
251    sdphi = Slider(axdphi, 'dPhi', 0, 30, valinit=dphi)
252    sdpsi = Slider(axdpsi, 'dPsi', 0, 30, valinit=dpsi)
253
254    def update(val, axis=None):
255        view = stheta.val, sphi.val, spsi.val
256        jitter = sdtheta.val, sdphi.val, sdpsi.val
257        ax.cla()
258        draw_beam(ax, (0, 0))
259        if 0:
260            draw_jitter(ax, view, jitter)
261        else:
262            draw_jitter(ax, view, (0,0,0))
263            draw_mesh(ax, view, jitter)
264        plt.gcf().canvas.draw()
265
266    stheta.on_changed(lambda v: update(v,'theta'))
267    sphi.on_changed(lambda v: update(v, 'phi'))
268    spsi.on_changed(lambda v: update(v, 'psi'))
269    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
270    sdphi.on_changed(lambda v: update(v, 'dphi'))
271    sdpsi.on_changed(lambda v: update(v, 'dpsi'))
272
273    update(None, 'phi')
274
275    plt.show()
276
277if __name__ == "__main__":
278    main()
Note: See TracBrowser for help on using the repository browser.