source: sasmodels/sasmodels/jitter.py @ 1511a60c

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 1511a60c was 1511a60c, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

more jitter plot tweaks to help with euler angle figures

  • Property mode set to 100755
File size: 48.6 KB
Line 
1#!/usr/bin/env python
2"""
3Jitter Explorer
4===============
5
6Application to explore orientation angle and angular dispersity.
7
8From the command line::
9
10    # Show docs
11    python -m sasmodels.jitter --help
12
13    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
14    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
15
16From a jupyter cell::
17
18    import ipyvolume as ipv
19    from sasmodels import jitter
20    import importlib; importlib.reload(jitter)
21    jitter.set_plotter("ipv")
22
23    size = (10, 40, 100)
24    view = (20, 0, 0)
25
26    #size = (15, 15, 100)
27    #view = (60, 60, 0)
28
29    dview = (0, 0, 0)
30    #dview = (5, 5, 0)
31    #dview = (15, 180, 0)
32    #dview = (180, 15, 0)
33
34    projection = 'equirectangular'
35    #projection = 'azimuthal_equidistance'
36    #projection = 'guyou'
37    #projection = 'sinusoidal'
38    #projection = 'azimuthal_equal_area'
39
40    dist = 'uniform'
41    #dist = 'gaussian'
42
43    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
44    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
45    #ipv.savefig(filename+'.png')
46"""
47from __future__ import division, print_function
48
49import argparse
50
51import numpy as np
52from numpy import pi, cos, sin, sqrt, exp, degrees, radians
53
def draw_beam(axes, view=(0, 0), alpha=0.5, steps=6):
    """
    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1).

    The beam is rendered on *axes* as a thin yellow tube of radius 0.02
    running along z from -1.3 to +1.3, built from *steps* points around
    the circumference.  *view* is (theta, phi) in degrees; the tube is
    rotated about y by theta and then about z by phi.  *alpha* is the
    surface transparency.
    """
    beam_radius = 0.02
    half_length = 1.3

    # Parametrize the tube: angle around the axis and the two end planes.
    around = np.linspace(0, 2 * np.pi, steps)
    ends = np.linspace(-1, 1, 2)
    ring = np.ones_like(ends)

    x = beam_radius*np.outer(np.cos(around), ring)
    y = beam_radius*np.outer(np.sin(around), ring)
    z = half_length*np.outer(np.ones_like(around), ends)

    # Rotate the flattened surface points, then restore the grid shape.
    theta, phi = view
    grid_shape = x.shape
    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    points = Rz(phi)*Ry(theta)*points
    x, y, z = [row.reshape(grid_shape) for row in points]
    axes.plot_surface(x, y, z, color='yellow', alpha=alpha)
75
76    # TODO: draw endcaps on beam
77    ## Draw tiny balls on the end will work
78    #draw_sphere(axes, radius=0.02, center=(0, 0, 1.3), color='yellow')
79    #draw_sphere(axes, radius=0.02, center=(0, 0, -1.3), color='yellow')
80    ## The following does not work
81    #triangles = [(0, i+1, i+2) for i in range(steps-2)]
82    #x_cap, y_cap = x[:, 0], y[:, 0]
83    #for z_cap in z[:, 0], z[:, -1]:
84    #    axes.plot_trisurf(x_cap, y_cap, z_cap, triangles,
85    #                      color='yellow', alpha=alpha)
86
87
def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid.

    *size* is the (a, b, c) semi-axis lengths along x, y and z.  The surface
    is sampled on a *steps* x *steps* grid, sent through the *jitter* and
    *view* transforms, and drawn in white on *axes* with transparency
    *alpha*.  The ends of the three axes are labelled.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar, cos_polar = np.sin(polar), np.cos(polar)

    x = a*np.outer(np.cos(azimuth), sin_polar)
    y = b*np.outer(np.sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), cos_polar)
    x, y, z = transform_xyz(view, jitter, x, y, z)

    axes.plot_surface(x, y, z, color='w', alpha=alpha)

    # (label, location, text direction) for each end of each axis
    axis_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, axis_labels)
108
def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for simple cubic paracrystal.

    *steps* and *alpha* are accepted so the signature matches the other
    shape drawing functions; they are not used.
    """
    _draw_crystal(axes, size, view, jitter, atoms=_build_sc())
113
def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for face-centered cubic paracrystal.

    The lattice is the simple cubic lattice from :func:`_build_sc` plus the
    centre of every unit-cell face.  *steps* and *alpha* are accepted so the
    signature matches the other shape drawing functions; they are not used.
    """
    # Build the simple cubic crystal
    atoms = _build_sc()
    # Define the centers for each face
    # x planes at -1, 0, 1 have four centers per plane, at +/- 0.5 in y and z.
    # All three lists must have the same length (12) so that the cyclic
    # concatenation below stays aligned.
    x, y, z = (
        [-1]*4 + [0]*4 + [1]*4,
        ([-0.5]*2 + [0.5]*2)*3,
        [-0.5, 0.5]*6,
    )
    # y and z planes can be generated by substituting x for y and z respectively
    atoms.extend(zip(x+y+z, y+z+x, z+x+y))
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
128
def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for body-centered cubic paracrystal.

    The lattice is the simple cubic lattice from :func:`_build_sc` plus one
    point at the centre of each of the eight octants.  *steps* and *alpha*
    are accepted so the signature matches the other shape drawing functions;
    they are not used.
    """
    # Start from the simple cubic crystal
    atoms = _build_sc()
    # One body centre per octant: every combination of +/- 0.5 in x, y and z,
    # with x varying slowest and z fastest.
    half = (-0.5, 0.5)
    atoms.extend((x, y, z) for x in half for y in half for z in half)
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
142
def _draw_crystal(axes, size, view, jitter, atoms=None):
    """
    Scatter-plot a list of lattice points.

    *atoms* is a sequence of (x, y, z) lattice coordinates which are scaled
    by *size* along each axis and sent through the *jitter* and *view*
    transforms.  The first atom is drawn in yellow, the rest in red.
    """
    coords = np.asarray(atoms, 'd').T
    scale = np.asarray(size, 'd')
    x, y, z = coords*scale[:, None]
    x, y, z = transform_xyz(view, jitter, x, y, z)
    # highlighted atom first, then all the others
    axes.scatter([x[0]], [y[0]], [z[0]], c='yellow', marker='o')
    axes.scatter(x[1:], y[1:], z[1:], c='r', marker='o')
149
def _build_sc():
    """
    Return the 27 points of a 3x3x3 simple cubic lattice as (x, y, z) tuples.

    Coordinates take the values -1, 0 and 1, with x varying slowest and z
    fastest, except that the point (0, 0, 1) is moved to the front of the
    list so that it can be highlighted in the view.
    """
    grid = (-1, 0, 1)
    atoms = [(x, y, z) for x in grid for y in grid for z in grid]
    # (0, 0, 1) sits at index 9*1 + 3*1 + 2 = 14 in the x-major ordering
    atoms.insert(0, atoms.pop(14))
    return atoms
166
def draw_box(axes, size, view):
    """
    Draw part of a wireframe box at a particular view.

    *size* is the (a, b, c) half-width of the box along x, y and z, and
    *view* is the orientation passed to :func:`transform_xyz` (no jitter is
    applied).  The three edges meeting at the (+a, +b, +c) corner and the
    three edges meeting at the opposite (-a, -b, -c) corner are drawn in
    black on *axes*.
    """
    a, b, c = size
    # Corner k has bit 0 selecting -x, bit 1 selecting -y and bit 2 selecting -z
    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    x, y, z = transform_xyz(view, None, x, y, z)

    # Corners joined by an edge differ in exactly one coordinate.  Corner 0
    # connects to 1 (x), 2 (y) and 4 (z); corner 7 connects to 6, 5 and 3.
    # (Pairs 0-3 and 7-4 would be face diagonals, not edges.)
    edges = ((0, 1), (0, 2), (0, 4), (7, 3), (7, 5), (7, 6))
    for i, j in edges:
        axes.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], color='black')
181
def draw_parallelepiped(axes, size, view, jitter, steps=None,
                        color=(0.6, 1.0, 0.6), alpha=1):
    """
    Draw a parallelepiped.

    *size* is the (a, b, c) half-width along x, y and z.  The eight corners
    are sent through the *jitter* and *view* transforms and the six faces
    are drawn on *axes* as a triangulated surface with the given *color*
    and *alpha*.  The centres of the faces are labelled.  *steps* is
    accepted so the signature matches the other shape drawing functions;
    it is not used.
    """
    a, b, c = size
    # Sign of each coordinate for the eight corners
    sign_x = np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    sign_y = np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    sign_z = np.array([+1, +1, +1, +1, -1, -1, -1, -1])

    # Two counter clockwise triangles per face.
    # z: up/down, x: right/left, y: front/back
    faces = [
        ([0, 1, 2], [3, 2, 1]),  # top face
        ([6, 5, 4], [5, 6, 7]),  # bottom face
        ([0, 2, 6], [6, 4, 0]),  # right face
        ([1, 5, 7], [7, 3, 1]),  # left face
        ([2, 3, 6], [7, 6, 3]),  # front face
        ([4, 1, 0], [5, 1, 4]),  # back face
    ]
    tri = np.array([corners for pair in faces for corners in pair])

    x, y, z = transform_xyz(view, jitter, a*sign_x, b*sign_y, c*sign_z)
    axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha)

    # Colour the c+ face of the box.
    # Since I can't control face color, instead draw a thin box situated just
    # in front of the "c+" face.  Use the c face so that rotations about psi
    # rotate that face.  (Currently disabled.)
    if 0:
        color = (1, 0.6, 0.6)  # pink
        x, y, z = transform_xyz(view, jitter, a*sign_x, b*sign_y,
                                abs(c*sign_z)+0.001)
        axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha)

    # (label, location, text direction) for the centre of each face
    face_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, face_labels)
223
def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w'):
    """
    Draw a sphere.

    The sphere of the given *radius* about *center* is sampled on a
    *steps* x *steps* grid and drawn on *axes* as a surface of the given
    *color*.
    """
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar, cos_polar = np.sin(polar), np.cos(polar)
    cx, cy, cz = center[0], center[1], center[2]

    x = radius * np.outer(np.cos(azimuth), sin_polar) + cx
    y = radius * np.outer(np.sin(azimuth), sin_polar) + cy
    z = radius * np.outer(np.ones_like(azimuth), cos_polar) + cz
    axes.plot_surface(x, y, z, color=color)
233    #axes.plot_wireframe(x, y, z)
234
def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
    """
    Draw the coordinate axes as three black line segments.

    Each segment starts at *origin* and extends along +x, +y or +z by the
    corresponding entry of *length*.
    """
    x0, y0, z0 = origin
    dx, dy, dz = length
    tips = (
        (x0+dx, y0, z0),
        (x0, y0+dy, z0),
        (x0, y0, z0+dz),
    )
    for x1, y1, z1 in tips:
        axes.plot([x0, x1], [y0, y1], [z0, z1], color='black')
241
def draw_person_on_sphere(axes, view, height=0.5, radius=0.5):
    """
    Draw a stick figure standing on top of a sphere.

    The figure is a black outline of total *height* drawn in the x-z plane
    and raised by *radius* along z before being oriented by *view*.  The
    axis limits are set to +/- (radius + height) in every direction and the
    axis decorations are turned off.
    """
    # Body proportions, all derived from the total height
    head_radius = height * 0.10
    head_height = height - head_radius
    neck_length = head_radius * 0.50
    shoulder_height = height - 2*head_radius - neck_length
    torso_length = shoulder_height * 0.55
    torso_radius = torso_length * 0.30
    leg_length = shoulder_height - torso_length
    arm_length = torso_length * 0.90
    limb_offset = height * 0.05

    def _stroke(x, z):
        # Outline is flat in y; lift it onto the sphere before rotating.
        flat = np.zeros_like(x)
        xp, yp, zp = transform_xyz(view, None, x, flat, z + radius)
        axes.plot(xp, yp, zp, color='k')

    # circle for head
    angle = np.linspace(0, 2 * np.pi, 40)
    _stroke(head_radius * np.cos(angle),
            head_radius * np.sin(angle) + head_height)

    # rectangle for body
    body_x = np.array([-torso_radius, torso_radius, torso_radius,
                       -torso_radius, -torso_radius])
    body_z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length
    _stroke(body_x, body_z)

    # arms: drawn on one side then mirrored in x
    outer = -torso_radius - limb_offset
    arm_x = np.array([outer, outer, -torso_radius])
    arm_z = np.array([shoulder_height - arm_length, shoulder_height,
                      shoulder_height])
    _stroke(arm_x, arm_z)
    _stroke(-arm_x, arm_z)

    # legs: drawn on one side then mirrored in x
    inner = -torso_radius + limb_offset
    leg_x = np.array([inner, inner])
    leg_z = np.array([0, leg_length])
    _stroke(leg_x, leg_z)
    _stroke(-leg_x, leg_z)

    limits = [-radius-height, radius+height]
    axes.set_xlim(limits)
    axes.set_ylim(limits)
    axes.set_zlim(limits)
    axes.set_axis_off()
286
def draw_jitter(axes, view, jitter, dist='gaussian',
                size=(0.1, 0.4, 1.0),
                draw_shape=draw_parallelepiped,
                projection='equirectangular',
                alpha=0.8,
                views=None):
    """
    Represent jitter as a set of shapes at different orientations.

    *jitter* is (dtheta, dphi, dpsi), the width of the orientation
    distribution *dist* ('gaussian', 'rectangle' or 'uniform') in each
    angle.  For every angle with a positive width a set of offsets is laid
    out symmetrically about zero; *views* fixes the number of offsets,
    otherwise between 3 and 25 are used depending on the width.  Each
    (theta, phi, psi) combination with positive weight under *projection*
    is drawn with *draw_shape* using transparency *alpha*.

    *size* is rescaled so that its diagonal is 0.95, and the limits of
    *axes* are set to +/- that diagonal.
    """
    project, weight = get_projection(projection)

    # set max diagonal to 0.95
    norm = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(norm*v for v in size)

    dtheta, dphi, dpsi = jitter
    # extent of each distribution in multiples of its width
    width = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]

    def _offsets(delta):
        if views is None:
            count = max(3, min(25, 2*int(width*delta/5)))
        else:
            count = views
        if delta > 0:
            return width*delta*np.linspace(-1, 1, count)
        return [0]

    theta_set, phi_set, psi_set = [_offsets(d) for d in (dtheta, dphi, dpsi)]
    #print(theta_set, phi_set, psi_set)
    for theta in theta_set:
        for phi in phi_set:
            for psi in psi_set:
                if weight(theta, phi, 1.0, 1.0) > 0:
                    dview = project(theta, phi, psi)
                    draw_shape(axes, size, view, dview, alpha=alpha)

    a, b, c = size
    lim = np.sqrt(a**2 + b**2 + c**2)
    for name in 'xyz':
        getattr(axes, 'set_'+name+'lim')([-lim, lim])
        #getattr(axes, name+'axis').label.set_text(name)
328
PROJECTIONS = [
    # Projection names handled by get_projection(), listed
    # in order of PROJECTION number; do not change without updating the
    # constants in kernel_iq.c
    'equirectangular', 'sinusoidal', 'guyou', 'azimuthal_equidistance',
    'azimuthal_equal_area',
]
def get_projection(projection):
    """
    Return the *(project, weight)* function pair for a jitter projection.

    *project(theta_i, phi_j, psi)* converts a point of the jitter mesh into
    the rotation angles *(latitude, longitude, psi)* in degrees.  The
    azimuthal projections return a fourth element, the string 'zyz', to
    flag that the angles are to be applied as Rz Ry Rz rather than Rx Ry Rz.

    *weight(theta_i, phi_j, w_i, w_j)* returns the weight of the mesh point
    given the weights *w_i* and *w_j* of the theta and phi offsets.

    Raises *ValueError* if *projection* is not a known projection name.

    jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangular.  Note that the poles
        will be slightly overweighted for theta > 90 with the circle
        from theta at 90+dt winding backwards around the pole, overlapping
        the circle from theta at 90-dt.
    Guyou (hemisphere-in-a-square) **not weighted**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        With tiling, allows rotation in phi or theta through +/- 180, with
        uniform spacing.  Both theta and phi allow free rotation, with wobble
        in the orthogonal direction reasonably well behaved (though not as
        good as equirectangular phi). The forward/reverse transformations
        rely on elliptic integrals that are somewhat expensive, so the
        behaviour has to be very good to justify the cost and complexity.
        The weighting function for each point has not yet been computed.
        Note: run the module *guyou.py* directly and it will show the forward
        and reverse mappings.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not that
        useful and not completely implemented
    Gauss-Kreuger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    # TODO: try Kent distribution instead of a gaussian warped by projection

    if projection == 'equirectangular':  #define PROJECTION 1
        def _project(theta_i, phi_j, psi):
            latitude, longitude = theta_i, phi_j
            return latitude, longitude, psi
            #return Rx(phi_j)*Ry(theta_i)
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j*abs(cos(radians(theta_i)))
    elif projection == 'sinusoidal':  #define PROJECTION 2
        def _project(theta_i, phi_j, psi):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Stretch longitude to preserve arc length along the line of
            # latitude; points outside the sinusoid are sent to longitude 0
            # (they get zero weight below).
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi
            #return Rx(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Only points lying inside the sinusoid contribute.
            active = 1 if abs(phi_j) < abs(scale)*180 else 0
            return active*w_i*w_j
    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
        def _project(theta_i, phi_j, psi):
            from .guyou import guyou_invert
            #latitude, longitude = guyou_invert([theta_i], [phi_j])
            longitude, latitude = guyou_invert([phi_j], [theta_i])
            return latitude, longitude, psi
            #return Rx(longitude[0])*Ry(latitude[0])
        def _weight(theta_i, phi_j, w_i, w_j):
            # The weighting function has not been computed for this projection.
            return w_i*w_j
    elif projection == 'azimuthal_equidistance':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi-longitude, 'zyz'
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # We are doing a conformal mapping from disk to sphere, so we need
            # a change of variables g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dtheta dphi = det(J) dlat dlong
            # where J is the jacobian from the partials of g. Using
            #     x = theta, y = phi, R = sqrt(theta^2 + phi^2),
            # then
            #     J = [[x/R, y/R], [-y/R^2, x/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*np.arccos(radius))
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi, "zyz"
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            # Same weight as azimuthal_equidistance.
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    return _project, _weight
483
def R_to_xyz(R):
    """
    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix.

    *R* is indexed as ``R[row, col]``; the angles are returned in degrees.

    Extracting Euler Angles from a Rotation Matrix
    Mike Day, Insomniac Games
    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
    Based on: Shoemake's "Euler Angle Conversion", Graphics Gems IV, pp.  222-229
    """
    angles = (
        np.arctan2(R[1, 2], R[2, 2]),
        np.arctan2(-R[0, 2], np.sqrt(R[0, 0]**2 + R[0, 1]**2)),
        np.arctan2(R[0, 1], R[0, 0]),
    )
    phi, theta, psi = [np.degrees(angle) for angle in angles]
    return phi, theta, psi
497
def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    An *n* x *n* grid of (theta, phi) offsets spanning the *jitter* widths
    for distribution *dist* is mapped through *projection* onto a sphere of
    the given *radius*, oriented by *view*, and scatter-plotted on *axes*
    with each point coloured by its weight relative to the largest weight.
    Points with zero weight are not drawn.

    Raises *ValueError* if *dist* is not gaussian, rectangle or uniform.
    """
    project, weight = get_projection(projection)

    def _place(theta, phi, pole):
        dview = project(theta, phi, 0.)
        # hack for zyz coords: a fourth element flags Rz Ry rather than Rx Ry
        first = Rz if len(dview) == 4 else Rx
        return first(dview[1])*Ry(dview[0])*pole

    # Offsets in units of the jitter width, with the weight of each offset
    grid = np.linspace(-1, 1, n)
    if dist == 'gaussian':
        grid *= 3
        grid_w = exp(-0.5*grid**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        grid *= sqrt(3)
        grid_w = np.ones_like(grid)
    elif dist == 'uniform':
        grid_w = np.ones_like(grid)
    else:
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # mesh in theta, phi formed by rotating z
    dtheta, dphi, _ = jitter
    thetas, phis = dtheta*grid, dphi*grid
    pole = np.matrix([[0], [0], [radius]])
    points = np.hstack([_place(theta_i, phi_j, pole)
                        for theta_i in thetas
                        for phi_j in phis])
    dist_w = np.array([weight(theta_i, phi_j, w_i, w_j)
                       for w_i, theta_i in zip(grid_w, thetas)
                       for w_j, phi_j in zip(grid_w, phis)])

    # keep only the points that contribute, scaled to a maximum weight of 1
    keep = dist_w > 0
    points = points[:, keep]
    dist_w = dist_w[keep]
    dist_w /= max(dist_w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).flatten() for row in points]
    axes.scatter(x, y, z, c=dist_w, marker='o', vmin=0., vmax=1.)
547
def draw_labels(axes, view, jitter, text):
    """
    Draw text at a particular location.

    *text* is a sequence of (label, location, orientation) triples, with
    location and orientation given as (x, y, z).  Both are sent through the
    *jitter* and *view* transforms; the label is drawn at the transformed
    location with the transformed orientation as its *zdir*.
    """
    labels, locations, orientations = zip(*text)
    px, py, pz = zip(*locations)
    dx, dy, dz = zip(*orientations)

    px, py, pz = transform_xyz(view, jitter, px, py, pz)
    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)

    # TODO: zdir for labels is broken, and labels aren't appearing.
    directions = zip(dx, dy, dz)
    for label, x, y, z, direction in zip(labels, px, py, pz, directions):
        axes.text(x, y, z, label, zdir=np.asarray(direction).flatten())
563
564# Definition of rotation matrices comes from wikipedia:
565#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """Return the 3x3 matrix which rotates points about *x* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
573
def Ry(angle):
    """Return the 3x3 matrix which rotates points about *y* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
581
def Rz(angle):
    """Return the 3x3 matrix which rotates points about *z* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
589
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    *x*, *y* and *z* are array-likes of matching shape.  Returns the
    transformed coordinates as three arrays with the shape of *x*.
    """
    coords = [np.asarray(v) for v in (x, y, z)]
    shape = coords[0].shape
    # One flattened coordinate per row: the 3 x n layout the transforms expect
    points = np.matrix([v.flatten() for v in coords])
    points = orient_relative_to_beam(view, apply_jitter(jitter, points))
    x, y, z = [np.array(row).reshape(shape) for row in points]
    return x, y, z
601
def apply_jitter(jitter, points):
    """
    Apply the jitter transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    *jitter* is None (points are returned unchanged), a (dtheta, dphi, dpsi)
    triple applied as Rx(dphi) Ry(dtheta) Rz(dpsi), or a four element
    sequence whose first three entries are applied as
    Rz(dphi) Ry(dtheta) Rz(dpsi).
    """
    if jitter is None:
        return points
    # Hack to deal with the fact that azimuthal_equidistance uses euler angles
    if len(jitter) == 4:
        dtheta, dphi, dpsi, _ = jitter
        outer = Rz
    else:
        dtheta, dphi, dpsi = jitter
        outer = Rx
    return outer(dphi)*Ry(dtheta)*Rz(dpsi)*points
618
def orient_relative_to_beam(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    *view* is (theta, phi, psi) in degrees, applied as
    Rz(phi) Ry(theta) Rz(psi).
    """
    theta, phi, psi = view
    # Alternatives that were considered:
    #   Rz(psi)*Ry(pi/2-theta)*Rz(phi)   1-D integration angles
    #   Rx(phi)*Ry(theta)*Rz(psi)        angular dispersion angle
    rotation = Rz(phi)*Ry(theta)*Rz(psi)  # viewing angle
    return rotation*points
630
def orient_relative_to_beam_quaternion(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    This variant uses quaternions rather than rotation matrices for the
    computation.  It works but it is not used because it doesn't solve
    any problems.  The challenge of mapping theta/phi/psi to SO(3) does
    not disappear by calculating the transform differently.
    """
    theta, phi, psi = view
    # Rotation angle paired with the fixed axis it turns about.
    steps = (
        (theta, [1, 0, 0]),
        (phi, [0, 1, 0]),
        (psi, [0, 0, 1]),
    )
    # Start from the identity and post-multiply each rotation in turn.
    # Composing by rotating the unit vectors before applying each rotation
    # turns out to be equivalent to reversing the order of application,
    # so it is not done here.
    q = Quaternion(1, [0, 0, 0])
    for angle, axis in steps:
        q = q * Quaternion.from_angle_axis(angle, axis)
    return q.rot(points)
651    ## the order of application, so ignore it and use below.
652    q = q * Quaternion.from_angle_axis(theta, x)
653    q = q * Quaternion.from_angle_axis(phi, y)
654    q = q * Quaternion.from_angle_axis(psi, z)
655    ## Reverse the order by post-multiply rather than pre-multiply
656    #q = Quaternion.from_angle_axis(theta, x) * q
657    #q = Quaternion.from_angle_axis(phi, y) * q
658    #q = Quaternion.from_angle_axis(psi, z) * q
659    #print("axes psi", q.rot(np.matrix([x, y, z]).T))
660    return q.rot(points)
661#orient_relative_to_beam = orient_relative_to_beam_quaternion
662 
663# Simple stand-alone quaternion class
664import numpy as np
665from copy import copy
666class Quaternion(object):
667    def __init__(self, w, r):
668         self.w = w
669         self.r = np.asarray(r, dtype='d')
670    @staticmethod
671    def from_angle_axis(theta, r):
672         theta = np.radians(theta)/2
673         r = np.asarray(r)
674         w = np.cos(theta)
675         r = np.sin(theta)*r/np.dot(r,r)
676         return Quaternion(w, r)
677    def __mul__(self, other):
678        if isinstance(other, Quaternion):
679            w = self.w*other.w - np.dot(self.r, other.r)
680            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
681            return Quaternion(w, r)
682    def rot(self, v):
683        v = np.asarray(v).T
684        use_transpose = (v.shape[-1] != 3)
685        if use_transpose: v = v.T
686        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
687        #v = v + 2*self.w*np.cross(self.r, v) + np.cross(2*self.r, np.cross(self.r, v))
688        if use_transpose: v = v.T
689        return v.T
690    def conj(self):
691        return Quaternion(self.w, -self.r)
692    def inv(self):
693        return self.conj()/self.norm()**2
694    def norm(self):
695        return np.sqrt(self.w**2 + np.sum(self.r**2))
696    def __str__(self):
697        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
def test_qrot():
    """Check a quaternion rotation against a hand-computed result."""
    # Define rotation of 60 degrees around an axis in y-z that is 60 degrees
    # from y.  The rotation axis is determined by rotating the point
    # [0, 1, 0] about x.
    x_axis, y_axis = [1, 0, 0], [0, 1, 0]
    axis = Quaternion.from_angle_axis(60, x_axis).rot(y_axis)
    rotation = Quaternion.from_angle_axis(60, axis)
    # Set the point to be rotated, and its expected rotated position.
    point = [1, -1, 2]
    expected = [
        (10+4*np.sqrt(3))/8,
        (1+2*np.sqrt(3))/8,
        (14-3*np.sqrt(3))/8,
    ]
    #print(rotation, rotation.rot(point) - expected)
    error = abs(rotation.rot(point) - expected)
    assert max(error) < 1e-14
708#test_qrot()
709#import sys; sys.exit()
710
# translate between number of dimension of dispersity and the number of
# points along each dimension.
# Keys flag which of (theta, phi, psi) have a dispersity width greater than
# zero; values are the number of points to use for (theta, phi, psi).  The
# trailing comment on each group is the total number of orientations.
PD_N_TABLE = {
    (0, 0, 0): (0, 0, 0),     # 0
    (1, 0, 0): (100, 0, 0),   # 100
    (0, 1, 0): (0, 100, 0),
    (0, 0, 1): (0, 0, 100),
    (1, 1, 0): (30, 30, 0),   # 900
    (1, 0, 1): (30, 0, 30),
    (0, 1, 1): (0, 30, 30),
    (1, 1, 1): (15, 15, 15),  # 3375
}
723
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine range from data.

    If *portion* is 1, use full range, otherwise use the center of the range
    or the top of the range, depending on whether *mode* is 'central' or 'top'.

    *portion* is the fraction of the sorted data values to keep.  For
    'central' the discarded values are split evenly between the bottom and
    the top; for 'top' they are all taken from the bottom.  At least one
    data value is always retained.

    Returns the pair *(low, high)*.  Raises *ValueError* if *portion* is
    less than 1 and *mode* is not 'central' or 'top'.
    """
    data = np.asarray(data)
    if portion >= 1.0:
        return data.min(), data.max()

    ordered = np.sort(data.flatten())
    last = len(ordered) - 1
    # Number of values to discard; the range keeps the remaining *portion*.
    dropped = (1.0 - portion)*len(ordered)
    if mode == 'central':
        # Clamp so the lower index never passes the upper index.
        low = min(int(dropped/2 + 0.5), last//2)
        return ordered[low], ordered[last - low]
    if mode == 'top':
        # Clamp so the index stays within the data.
        low = min(int(dropped + 0.5), last)
        return ordered[low], ordered[-1]
    raise ValueError("unknown mode %r; expected 'central' or 'top'" % (mode,))
741
def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity, each given as a
    (theta, phi, psi) triple.  *dist* is one of the sasmodels weight
    distributions, or 'uniform'.

    The log of the intensity is drawn as filled contours on the plane
    z = -1.1.  If *calculator.limits* is set it is used as the colour range,
    otherwise the range covers the seven decades below the maximum of the
    current pattern.
    """
    if dist == 'uniform':  # uniform is not yet in this branch
        # Emulate it using a rectangle distribution with rescaled width.
        dist, scale = 'rectangle', 1/sqrt(3)
    else:
        scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    # The number of points in each dimension depends on which of the
    # three angles actually have dispersity.
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    # Parameters stored on the calculator take precedence over the above.
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    Iqxy = calculator(**pars).reshape(len(qx), len(qy))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        vmax = Iqxy.max()
        # Iqxy is already log scaled, so seven decades below the peak is a
        # subtraction.  Multiplying by 10**-7 gave vmin > vmax whenever
        # vmax was negative, producing decreasing contour levels.
        vmin = vmax - 7*np.log(10)
        #vmin, vmax = clipped_range(Iqxy, portion=portion, mode='top')
    #print("range",(vmin,vmax))

    # Filled contours on a plane just below the unit cube; q is normalized
    # so that the pattern spans [-1, 1] in both directions.
    axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                  levels=np.linspace(vmin, vmax, 24))
799
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  If no *background* is given it is set to 1e-3.

    Returns a *calculator* function which takes a dictionary of parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.

    The calculator also carries *calculator.limits*, a (vmin, vmax) pair on
    the log intensity scale computed for the orientation (0, 0, 0).
    """
    from sasmodels.core import load_model_info, build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # Note: this key was previously misspelled 'backgound', so the default
    # was stored under a name that is not the background parameter.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
837
def select_calculator(model_name, n=150, size=(10, 40, 100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped, or sc/fcc/bcc_paracrystal.  *n* is the number of points
    to use in the q range.  *size* is the (a, b, c) dimensions of the shape.
    The maximum q value is fixed per model to show something interesting.

    Returns *calculator* and tuple *size* (a, b, c) giving minor and major
    equatorial axes and polar axis respectively, adjusted for the symmetry
    of the selected model.  See :func:`build_model` for details on the
    returned calculator.

    Raises *ValueError* if *model_name* is not recognized.
    """
    a, b, c = size
    d_factor = 0.06  # for paracrystal models

    # (dnn, radius) for each lattice type.  For fcc and bcc the nearest
    # neighbour distance dnn should be 2 radius, but I think the model uses
    # lattice spacing rather than dnn in its calculations.
    lattice = {
        'sc_paracrystal': (c, 0.5*c),
        'fcc_paracrystal': (0.5*c, sqrt(2)/4 * c),
        'bcc_paracrystal': (0.5*c, sqrt(3)/2 * c),
    }

    if model_name in lattice:
        dnn, radius = lattice[model_name]
        calculator = build_model(model_name, n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
        return calculator, (c, c, c)

    if model_name == 'sphere':
        calculator = build_model('sphere', n=n, radius=c)
        return calculator, (c, c, c)

    if model_name == 'cylinder':
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
        return calculator, (b, b, c)

    if model_name == 'ellipsoid':
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
        return calculator, (b, b, c)

    if model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
        return calculator, (a, b, c)

    if model_name == 'parallelepiped':
        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
        return calculator, (a, b, c)

    raise ValueError("unknown model %s"%model_name)
899
# Models that can be selected from the command line.  The first entry is
# the default shape.
SHAPES = [
    'parallelepiped',
    'sphere',
    'ellipsoid',
    'triaxial_ellipsoid',
    'cylinder',
    'fcc_paracrystal',
    'bcc_paracrystal',
    'sc_paracrystal',
]

# Drawing function for each model.  Models that are not listed here are
# drawn with draw_parallelepiped (see run()).
DRAW_SHAPES = {
    'fcc_paracrystal': draw_fcc,
    'bcc_paracrystal': draw_bcc,
    'sc_paracrystal': draw_sc,
    'parallelepiped': draw_parallelepiped,
}

# Jitter distributions that can be selected from the command line.  The
# first entry is the default.
DISTRIBUTIONS = [
    'gaussian',
    'rectangle',
    'uniform',
]

# Width limit for each distribution.  The matplotlib jitter sliders run
# from 0 to twice this value.
DIST_LIMITS = {
    'gaussian': 30,
    'rectangle': 90/sqrt(3),
    'uniform': 90,
}
922
923
def run(model_name='parallelepiped', size=(10, 40, 100),
        view=(0, 0, 0), jitter=(0, 0, 0),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal

    *size* gives the dimensions (a, b, c) of the shape.

    *view* gives the initial view (theta, phi, psi) of the shape.

    *jitter* gives the initial jitter (dtheta, dphi, dpsi) of the shape.

    *dist* is the type of dispersion: gaussian, rectangle, or uniform.

    *mesh* is the number of points in the dispersion mesh.

    *projection* is the map projection to use for the mesh: equirectangular,
    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.

    The demo is displayed with the plotting engine currently stored in
    *PLOT_ENGINE*.
    """
    from sasmodels import generate

    # The projection number is the 1-origin position in the list, but only
    # 1 and 2 are implemented in the scattering function so far; fall back
    # to 2 for the rest.
    projection_id = PROJECTIONS.index(projection) + 1
    if projection_id > 2:
        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
        projection_id = 2
    generate.PROJECTION = projection_id

    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)
    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped)
    #draw_shape = draw_fcc

    ## Clearing the limits gives an independent colour range for every view.
    ## Comment this out to keep the colour range fixed for all views.
    calculator.limits = None

    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
965
def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Run the interactive jitter explorer in a matplotlib window.

    *calculator* and *size* come from :func:`select_calculator` and
    *draw_shape* is the function used to draw the shape.  *view* and
    *jitter* are the initial (theta, phi, psi) orientation and orientation
    dispersity.  *dist* is the jitter distribution, *mesh* is the number of
    points in the dispersion mesh and *projection* is the name of the map
    projection, which is also used as the window title.

    Six sliders control the view and the jitter; the figure is redrawn
    whenever one of them changes.  The call blocks in ``plt.show()``.
    """
    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    ## create the plot window
    #plt.hold(True)
    plt.subplots(num=None, figsize=(5.5, 5.5))
    plt.set_cmap('gist_earth')
    plt.clf()
    # The canvas method set_window_title was deprecated in matplotlib 3.4
    # and later removed; the figure manager provides the same call.
    manager = getattr(plt.gcf().canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(projection)
    #gs = gridspec.GridSpec(2,1,height_ratios=[4,1])
    #axes = plt.subplot(gs[0], projection='3d')
    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        axes.axis('square')
    except Exception:
        pass

    axcolor = 'lightgoldenrodyellow'

    ## add control widgets to plot
    axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], facecolor=axcolor)
    axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], facecolor=axcolor)
    axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], facecolor=axcolor)
    stheta = Slider(axes_theta, 'Theta', -90, 90, valinit=0)
    sphi = Slider(axes_phi, 'Phi', -180, 180, valinit=0)
    spsi = Slider(axes_psi, 'Psi', -180, 180, valinit=0)

    axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], facecolor=axcolor)
    axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], facecolor=axcolor)
    axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], facecolor=axcolor)
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axes_dtheta, 'dTheta', 0, 2*dlimit, valinit=0)
    sdphi = Slider(axes_dphi, 'dPhi', 0, 2*dlimit, valinit=0)
    sdpsi = Slider(axes_dpsi, 'dPsi', 0, 2*dlimit, valinit=0)

    ## initial view and jitter
    # The values are set before the callbacks are bound so that setting
    # them does not trigger six redraws.
    theta, phi, psi = view
    stheta.set_val(theta)
    sphi.set_val(phi)
    spsi.set_val(psi)
    dtheta, dphi, dpsi = jitter
    sdtheta.set_val(dtheta)
    sdphi.set_val(dphi)
    sdpsi.set_val(dpsi)

    ## callback to draw the new view
    def update(_val=None):
        # The slider value passed by matplotlib is ignored; the state of
        # all six sliders is read on every redraw.
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        axes.cla()

        ## Visualize as person on globe
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, -view[:2])
        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0))
        draw_jitter(axes, view, jitter, dist=dist, size=size,
                    draw_shape=draw_shape, projection=projection, views=3)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    for control in (stheta, sphi, spsi, sdtheta, sdphi, sdpsi):
        control.on_changed(update)

    ## initialize view
    update()

    ## go interactive
    plt.show()
1056
1057
def map_colors(z, kw):
    """
    Convert matplotlib-style colour keywords in *kw* into a single 'color'.

    The keywords *cmap*, *alpha*, *vmin*, *vmax*, *c* and *color* are removed
    from *kw*, and the resulting colour is stored in ``kw['color']``:

    * If neither *color* nor *c* is given, *z* is scaled to [0, 1] using
      *vmin* and *vmax* (default: the range of *z*) and passed through
      *cmap* (default: matplotlib's coolwarm).
    * If the colour is an array with the same shape as *z*, it is passed
      through *cmap* without further scaling.
    * Any other colour (name, tuple, explicit RGB/RGBA array) is used as is.

    If *alpha* is None the alpha channel is stripped from array colours.
    Otherwise *alpha* replaces (or is appended as) the fourth channel of
    RGBA (or RGB) array colours; it is ignored for colours that are not
    arrays.  Arrays supplied by the caller are never modified in place.

    *kw* is modified in place; nothing is returned.
    """
    cmap = kw.pop('cmap', None)
    alpha = kw.pop('alpha', None)
    vmin = kw.pop('vmin', None)
    vmax = kw.pop('vmax', None)
    c = kw.pop('c', None)
    color = kw.pop('color', c)

    needs_cmap = color is None or (
        isinstance(color, np.ndarray) and color.shape == z.shape)
    if needs_cmap:
        if cmap is None:
            # Only pull in matplotlib when the default colour map is needed.
            from matplotlib import cm
            cmap = cm.coolwarm
        if color is None:
            if vmin is None:
                vmin = z.min()
            if vmax is None:
                vmax = z.max()
            span = vmax - vmin
            if span == 0:
                # Constant data: avoid 0/0 and use the bottom of the map.
                znorm = np.zeros(np.shape(z))
            else:
                znorm = ((z - vmin) / span).clip(0, 1)
            color = cmap(znorm)
        else:
            color = cmap(color)

    if isinstance(color, np.ndarray) and color.ndim >= 1:
        if alpha is None:
            color = color[..., :3]
        elif color.shape[-1] == 4:
            # Work on a copy so the caller's array is left untouched.
            color = color.copy()
            color[..., 3] = alpha
        elif color.shape[-1] == 3:
            alpha_channel = np.full(color.shape[:-1] + (1,), alpha, dtype='d')
            color = np.concatenate([color, alpha_channel], axis=-1)
    # Colours that are not arrays (e.g., 'r') have no alpha channel to set,
    # so alpha is dropped rather than raising.
    kw['color'] = color
1078
def make_vec(*args, flat=False):
    """
    Convert each argument to a double precision numpy array.

    Returns a list with one array per argument.  If *flat* is True each
    array is flattened to one dimension.
    """
    vectors = [np.asarray(value, 'd') for value in args]
    if flat:
        vectors = [vector.flatten() for vector in vectors]
    return vectors
1084
def make_image(z, kw):
    """
    Render the array *z* as an RGB image.

    *z* is scaled linearly so that its minimum maps to 0 and its maximum
    to 1, then coloured using the *cmap* entry of *kw* (default: matplotlib's
    coolwarm).  The *cmap* entry is removed from *kw*.

    Returns a *PIL.Image* in RGB mode.
    """
    import PIL.Image
    from matplotlib import cm

    cmap = kw.pop('cmap', cm.coolwarm)

    z = np.asarray(z)
    # Use the np.ptp function: the ndarray.ptp method was removed in numpy 2.0.
    span = np.ptp(z)
    if span == 0:
        # Constant image: avoid 0/0 and use the bottom of the colour map.
        znorm = np.zeros(z.shape)
    else:
        znorm = (z - z.min())/span
    rgba = cmap(znorm)
    # Drop the alpha channel and convert to 8 bits per channel.
    rgb = np.asarray(rgba[..., :3]*255, 'u1')
    image = PIL.Image.fromarray(rgb, mode='RGB')
    return image
1097
1098
1099_IPV_MARKERS = {
1100    'o': 'sphere',
1101}
1102_IPV_COLORS = {
1103    'w': 'white',
1104    'k': 'black',
1105    'c': 'cyan',
1106    'm': 'magenta',
1107    'y': 'yellow',
1108    'r': 'red',
1109    'g': 'green',
1110    'b': 'blue',
1111}
1112def ipv_fix_color(kw):
1113    kw.pop('alpha', None)
1114    color = kw.get('color', None)
1115    if isinstance(color, str):
1116        color = _IPV_COLORS.get(color, color)
1117        kw['color'] = color
1118
def ipv_axes():
    """
    Return an object providing matplotlib-like 3D axes methods on top of
    ipyvolume.

    The returned plotter implements only the drawing calls that are used
    with the ipyvolume plot engine; each method translates its arguments
    and forwards them to the current ipyvolume figure.
    """
    import ipyvolume as ipv

    class Plotter:
        """Adaptor from matplotlib style axes calls to ipyvolume calls."""
        def plot(self, x, y, z, **kw):
            """Draw a line through the points (x, y, z)."""
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            ipv.plot(x, y, z, **kw)
        def plot_surface(self, x, y, z, **kw):
            """Draw a surface; *facecolors*, if given, is used as the colour."""
            facecolors = kw.pop('facecolors', None)
            if facecolors is not None:
                kw['color'] = facecolors
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            ipv.plot_surface(x, y, z, **kw)
        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
            """Draw a triangulated surface with heights *Z*."""
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, Z)
            if triangles is not None:
                triangles = np.asarray(triangles)
            ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
        def scatter(self, x, y, z, **kw):
            """Draw markers at the points (x, y, z)."""
            x, y, z = make_vec(x, y, z)
            map_colors(z, kw)
            # Only translate a marker that was actually requested.  Forcing
            # marker=None would override the ipyvolume default marker.
            if 'marker' in kw:
                marker = kw['marker']
                kw['marker'] = _IPV_MARKERS.get(marker, marker)
            ipv.scatter(x, y, z, **kw)
        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
            """Draw *v* as an image; *levels* is accepted but not used."""
            # Don't use contour for now (although we might want to later)
            self.pcolor(x, y, v, zdir=zdir, offset=offset, **kw)
        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
            """
            Draw *v* as a textured rectangle at height *offset*.

            The rectangle spans the range of *x* and *y*.  *zdir* is
            accepted for compatibility but the image is always drawn
            on a plane of constant z.
            """
            x, y, v = make_vec(x, y, v)
            image = make_image(v, kw)
            xmin, xmax = x.min(), x.max()
            ymin, ymax = y.min(), y.max()
            # Corners of the rectangle and the matching texture coordinates.
            corner_x = np.array([[xmin, xmax], [xmin, xmax]])
            corner_y = np.array([[ymin, ymin], [ymax, ymax]])
            corner_z = corner_x*0 + offset
            tex_u = np.array([[0., 1], [0, 1]])
            tex_v = np.array([[0., 0], [1, 1]])
            ipv.plot_mesh(corner_x, corner_y, corner_z, u=tex_u, v=tex_v,
                          texture=image, wireframe=False)
        def text(self, *args, **kw):
            """Text is not supported; the call is ignored."""
            pass
        def set_xlim(self, limits):
            """Set the x range from the pair *limits*."""
            ipv.xlim(*limits)
        def set_ylim(self, limits):
            """Set the y range from the pair *limits*."""
            ipv.ylim(*limits)
        def set_zlim(self, limits):
            """Set the z range from the pair *limits*."""
            ipv.zlim(*limits)
        def set_axis_on(self):
            """Show the axes."""
            # NOTE(review): the original called ipv.style.axis_on(), which
            # does not appear to exist in ipyvolume; axes_on() is the
            # counterpart of the axes_off() call used below - confirm.
            ipv.style.axes_on()
        # Keep the original method name available for existing callers.
        set_axes_on = set_axis_on
        def set_axis_off(self):
            """Hide the axes."""
            ipv.style.axes_off()
    return Plotter()
1173
def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Run the interactive jitter explorer in a jupyter cell using ipyvolume.

    *calculator* and *size* come from :func:`select_calculator` and
    *draw_shape* is the function used to draw the shape.  *view* and
    *jitter* are the initial (theta, phi, psi) orientation and orientation
    dispersity.  *dist* is the jitter distribution, *mesh* is the number of
    points in the dispersion mesh and *projection* is the name of the map
    projection.

    Six sliders control the view and the jitter; a new figure is drawn
    whenever one of them changes, keeping the camera of the previous figure.
    """
    import ipywidgets as widgets
    import ipyvolume as ipv

    axes = ipv_axes()

    def draw(view, jitter):
        # Remember the camera so the new figure is seen from the same place.
        camera = ipv.gcf().camera
        #print(ipv.gcf().__dict__.keys())
        #print(dir(ipv.gcf()))
        ipv.figure()

        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]

        ## Visualize as person on globe
        #draw_beam(axes, (0, 0))
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, view=(-view[0], -view[1]))
        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0))

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0))
        draw_jitter(axes, view, jitter, dist=dist, size=size,
                    draw_shape=draw_shape, projection=projection)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        draw_axes(axes, origin=(-1, -1, -1.1))
        ipv.style.box_off()
        ipv.style.axes_off()
        ipv.xyzlabel(" ", " ", " ")

        ipv.gcf().camera = camera
        ipv.show()

    # (min, max, step) for the view and jitter sliders
    trange, prange = (-180., 180., 1.), (-180., 180., 1.)
    dtrange, dprange = (0., 180., 1.), (0., 180., 1.)

    ## Super simple interface, but uses non-ascii variable names
    # θ φ ψ Δθ Δφ Δψ
    #def update(**kw):
    #    view = kw['θ'], kw['φ'], kw['ψ']
    #    jitter = kw['Δθ'], kw['Δφ'], kw['Δψ']
    #    draw(view, jitter)
    #widgets.interact(update, θ=trange, φ=prange, ψ=prange, Δθ=dtrange, Δφ=dprange, Δψ=dprange)

    def update(theta, phi, psi, dtheta, dphi, dpsi):
        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))

    def slider(name, limits, init=0.):
        # limits is (min, max, step)
        return widgets.FloatSlider(
            value=init,
            min=limits[0],
            max=limits[1],
            step=limits[2],
            description=name,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            )
    # The theta sliders were previously labelled with the letter xi.
    theta = slider('θ', trange, view[0])
    phi = slider('φ', prange, view[1])
    psi = slider('ψ', prange, view[2])
    dtheta = slider('Δθ', dtrange, jitter[0])
    dphi = slider('Δφ', dprange, jitter[1])
    dpsi = slider('Δψ', dprange, jitter[2])
    fields = {
        'theta': theta, 'phi': phi, 'psi': psi,
        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi,
    }
    ui = widgets.HBox([
        widgets.VBox([theta, phi, psi]),
        widgets.VBox([dtheta, dphi, dpsi])
    ])

    out = widgets.interactive_output(update, fields)
    # NOTE(review): display is not imported in this function; presumably it
    # relies on the builtin provided by the IPython session - confirm.
    display(ui, out)
1260
1261
# Available plotting back ends, by name.  Each back end is available under
# a long and a short name.
_ENGINES = {
    "matplotlib": mpl_plot,
    "mpl": mpl_plot,
    #"plotly": plotly_plot,
    "ipvolume": ipv_plot,
    "ipv": ipv_plot,
}

# Plotting back end used by run(); matplotlib unless set_plotter() is called.
PLOT_ENGINE = _ENGINES["matplotlib"]

def set_plotter(name):
    """
    Select the plotting back end used by :func:`run`.

    *name* is one of the keys of *_ENGINES*: matplotlib, mpl, ipvolume
    or ipv.  Raises *KeyError* for any other name.
    """
    global PLOT_ENGINE
    engine = _ENGINES[name]
    PLOT_ENGINE = engine
1273
def main():
    """
    Command line interface to the jitter explorer.

    Parses the options for projection, size, view, jitter, distribution,
    mesh and shape, then calls :func:`run`.
    """
    def triple(text):
        # Convert a comma separated string such as "10,40,100" into floats.
        return tuple(float(part) for part in text.split(','))

    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument(
        '-p', '--projection', choices=PROJECTIONS, default=PROJECTIONS[0],
        help='coordinate projection')
    parser.add_argument(
        '-s', '--size', type=str, default='10,40,100',
        help='a,b,c lengths')
    parser.add_argument(
        '-v', '--view', type=str, default='0,0,0',
        help='initial view angles')
    parser.add_argument(
        '-j', '--jitter', type=str, default='0,0,0',
        help='initial angular dispersion')
    parser.add_argument(
        '-d', '--distribution', choices=DISTRIBUTIONS,
        default=DISTRIBUTIONS[0],
        help='jitter distribution')
    parser.add_argument(
        '-m', '--mesh', type=int, default=30,
        help='#points in theta-phi mesh')
    parser.add_argument(
        'shape', choices=SHAPES, nargs='?', default=SHAPES[0],
        help='oriented shape')
    opts = parser.parse_args()

    size = triple(opts.size)
    view = triple(opts.view)
    jitter = triple(opts.jitter)
    run(opts.shape, size=size, view=view, jitter=jitter,
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.