source: sasmodels/sasmodels/jitter.py @ 9ec9c67

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9ec9c67 was 9ec9c67, checked in by Paul Kienzle <pkienzle@…>, 3 years ago

note techniques for smooth rotation and transparency in ipyvolume jitter viewer code

  • Property mode set to 100755
File size: 49.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Jitter Explorer
5===============
6
7Application to explore orientation angle and angular dispersity.
8
9From the command line::
10
11    # Show docs
12    python -m sasmodels.jitter --help
13
14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
16
17From a jupyter cell::
18
19    import ipyvolume as ipv
20    from sasmodels import jitter
21    import importlib; importlib.reload(jitter)
22    jitter.set_plotter("ipv")
23
24    size = (10, 40, 100)
25    view = (20, 0, 0)
26
27    #size = (15, 15, 100)
28    #view = (60, 60, 0)
29
30    dview = (0, 0, 0)
31    #dview = (5, 5, 0)
32    #dview = (15, 180, 0)
33    #dview = (180, 15, 0)
34
35    projection = 'equirectangular'
36    #projection = 'azimuthal_equidistance'
37    #projection = 'guyou'
38    #projection = 'sinusoidal'
39    #projection = 'azimuthal_equal_area'
40
41    dist = 'uniform'
42    #dist = 'gaussian'
43
44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
46    #ipv.savefig(filename+'.png')
47"""
48from __future__ import division, print_function
49
50import argparse
51
52import numpy as np
53from numpy import pi, cos, sin, sqrt, exp, degrees, radians
54
def draw_beam(axes, view=(0, 0), alpha=0.5, steps=6):
    """
    Draw the beam as a thin yellow tube along the z axis.

    Before rotation the tube has radius 0.02 and runs from z = +1.3 on the
    source side to z = -1.3 on the detector side.  *view* is (theta, phi)
    in degrees, applied as Rz(phi)*Ry(theta).  *steps* is the number of
    points around the tube circumference.
    """
    tube_radius = 0.02
    half_length = 1.3
    around = np.linspace(0, 2 * np.pi, steps)
    along = np.linspace(-1, 1, 2)

    tube_x = tube_radius*np.outer(np.cos(around), np.ones_like(along))
    tube_y = tube_radius*np.outer(np.sin(around), np.ones_like(along))
    tube_z = half_length*np.outer(np.ones_like(around), along)
    grid_shape = tube_x.shape

    theta, phi = view
    rotation = Rz(phi)*Ry(theta)
    flat = np.matrix([tube_x.flatten(), tube_y.flatten(), tube_z.flatten()])
    rotated = rotation*flat
    tube_x, tube_y, tube_z = (row.reshape(grid_shape) for row in rotated)
    axes.plot_surface(tube_x, tube_y, tube_z, color='yellow', alpha=alpha)

    # TODO: draw endcaps on the beam.  Tiny yellow spheres of radius 0.02
    # at z = +/-1.3 work; triangulated caps drawn with plot_trisurf over
    # the first and last ring of points did not.
87
88
def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
    """
    Draw a white ellipsoid with semi-axes *size* = (a, b, c).

    The surface is a *steps* x *steps* mesh that is rotated by *jitter*
    and then by *view* through :func:`transform_xyz`.  The two ends of
    each semi-axis are labelled.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)
    surf_x = a*np.outer(np.cos(azimuth), sin_polar)
    surf_y = b*np.outer(np.sin(azimuth), sin_polar)
    surf_z = c*np.outer(np.ones_like(azimuth), np.cos(polar))
    surf_x, surf_y, surf_z = transform_xyz(view, jitter, surf_x, surf_y, surf_z)

    axes.plot_surface(surf_x, surf_y, surf_z, color='w', alpha=alpha)

    axis_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, axis_labels)
109
def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a simple cubic paracrystal.

    *steps* and *alpha* are accepted so the signature matches the other
    shape-drawing functions; they are not used.
    """
    _draw_crystal(axes, size, view, jitter, atoms=_build_sc())
114
def _fcc_face_centers():
    """
    Return the 36 face centers of the 2 x 2 x 2 block of unit cells.

    The cells span the simple cubic lattice on {-1, 0, 1}^3, so each face
    center has one coordinate in {-1, 0, 1} and the other two at +/-0.5.
    """
    # x planes at -1, 0, 1 have four centers per plane, at +/- 0.5 in y and z
    x = [-1]*4 + [0]*4 + [1]*4
    y = ([-0.5]*2 + [0.5]*2)*3
    # z must have the same 12 entries as x and y: the cyclic concatenation
    # below only lines up the coordinates when all three lists are equal
    # in length.
    z = [-0.5, 0.5]*6
    # Cycling the coordinate roles turns the x planes (x, y, z) into the
    # z planes (y, z, x) and the y planes (z, x, y).
    return list(zip(x+y+z, y+z+x, z+x+y))

def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a face-centered cubic paracrystal.

    The simple cubic points come first, followed by the face centers of
    every unit cell.  *steps* and *alpha* are accepted for signature
    compatibility with the other shape functions and are not used.
    """
    atoms = _build_sc()
    atoms.extend(_fcc_face_centers())
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
129
def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a body-centered cubic paracrystal.

    The simple cubic points come first, followed by the centers of the
    eight unit cells, which sit at +/-0.5 on every axis.  *steps* and
    *alpha* are accepted for signature compatibility and are not used.
    """
    atoms = _build_sc()
    half = (-0.5, 0.5)
    atoms.extend([(cx, cy, cz) for cx in half for cy in half for cz in half])
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
143
def _draw_crystal(axes, size, view, jitter, atoms=None):
    """
    Scatter-plot lattice *atoms* scaled by *size* and rotated by jitter/view.

    Each atom coordinate is multiplied by the matching entry of *size*.
    The first atom is drawn in yellow as an orientation marker and the
    rest are drawn in red.
    """
    scale = np.asarray(size, 'd')[:, None]
    scaled = np.asarray(atoms, 'd').T * scale
    x, y, z = scaled
    x, y, z = transform_xyz(view, jitter, x, y, z)
    axes.scatter([x[0]], [y[0]], [z[0]], c='yellow', marker='o')
    axes.scatter(x[1:], y[1:], z[1:], c='r', marker='o')
150
def _build_sc():
    """
    Return the 27 points of a simple cubic lattice on {-1, 0, 1}^3.

    Points are ordered with x varying slowest and z fastest, except that
    (0, 0, 1) is moved to the front of the list so it can be highlighted.
    """
    grid = (-1, 0, 1)
    atoms = [(x, y, z) for x in grid for y in grid for z in grid]
    # In x-major order (0, 0, 1) sits at index 9*1 + 3*1 + 2 = 14.
    atoms.insert(0, atoms.pop(14))
    return atoms
167
def draw_box(axes, size, view):
    """
    Draw the black wireframe outline of a box with half-widths *size*.

    *size* is (a, b, c).  The corners are rotated by *view* (no jitter)
    through :func:`transform_xyz`, and all 12 edges are drawn.
    """
    a, b, c = size
    # Corner k negates x, y and z according to bits 0, 1 and 2 of k, so two
    # corners share an edge exactly when their indices differ in one bit.
    # (Pairs such as (0, 3) or (4, 7) are face diagonals, not edges.)
    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    x, y, z = transform_xyz(view, None, x, y, z)
    edges = [
        (0, 1), (2, 3), (4, 5), (6, 7),  # edges along a
        (0, 2), (1, 3), (4, 6), (5, 7),  # edges along b
        (0, 4), (1, 5), (2, 6), (3, 7),  # edges along c
    ]
    for i, j in edges:
        axes.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], color='black')
182
def draw_parallelepiped(axes, size, view, jitter, steps=None,
                        color=(0.6, 1.0, 0.6), alpha=1):
    """
    Draw a solid rectangular box with half-widths *size* = (a, b, c).

    The eight corners are rotated by *jitter* and then *view* through
    :func:`transform_xyz`.  They are rendered as twelve triangles with
    the given *color* and *alpha*, and the ends of the a, b and c axes
    are labelled.  *steps* is accepted for compatibility with the other
    shape functions and is not used.
    """
    a, b, c = size
    # Corner k negates x, y and z according to bits 0, 1 and 2 of k.
    corners_x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    corners_y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    corners_z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    # Two counter-clockwise triangles per face
    # (z: up/down, x: right/left, y: front/back).
    faces = np.array([
        [0, 1, 2], [3, 2, 1],  # top
        [6, 5, 4], [5, 6, 7],  # bottom
        [0, 2, 6], [6, 4, 0],  # right
        [1, 5, 7], [7, 3, 1],  # left
        [2, 3, 6], [7, 6, 3],  # front
        [4, 1, 0], [5, 1, 4],  # back
    ])

    px, py, pz = transform_xyz(view, jitter, corners_x, corners_y, corners_z)
    axes.plot_trisurf(px, py, triangles=faces, Z=pz, color=color, alpha=alpha)

    # plot_trisurf cannot colour individual faces.  To highlight the c+
    # face, draw a second, flat box just in front of it.  Transform the
    # same corners with z replaced by abs(z)+0.001 and draw them in pink,
    # (1, 0.6, 0.6).  The c face is used so that rotations about psi rotate
    # the highlight.  This is currently disabled.

    draw_labels(axes, view, jitter, [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ])
224
def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w'):
    """
    Draw a solid sphere of *radius* at *center* on a *steps* x *steps* mesh.
    """
    cx, cy, cz = center[0], center[1], center[2]
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)

    x = radius * np.outer(np.cos(azimuth), sin_polar) + cx
    y = radius * np.outer(np.sin(azimuth), sin_polar) + cy
    z = radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar)) + cz
    axes.plot_surface(x, y, z, color=color)
235
def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
    """
    Draw three black coordinate axes starting at *origin*.

    The x, y and z axes are line segments whose lengths are given by the
    corresponding entries of *length*.
    """
    x0, y0, z0 = origin
    lx, ly, lz = length
    segments = (
        ([x0, x0+lx], [y0, y0], [z0, z0]),
        ([x0, x0], [y0, y0+ly], [z0, z0]),
        ([x0, x0], [y0, y0], [z0, z0+lz]),
    )
    for xs, ys, zs in segments:
        axes.plot(xs, ys, zs, color='black')
242
def draw_person_on_sphere(axes, view, height=0.5, radius=0.5):
    """
    Draw a stick figure of the given *height* standing at z = *radius*.

    The figure lies in the x-z plane (y = 0) with its feet lifted by
    *radius*, as if standing on top of a sphere of that radius.  The
    sphere itself is not drawn.  Each stroke is rotated by *view* through
    :func:`transform_xyz`.  The axes limits are set to +/-(radius + height)
    and the axis decorations are hidden.
    """
    # Body proportions, all derived from the total height.
    limb_offset = height * 0.05
    head_radius = height * 0.10
    head_height = height - head_radius
    neck_length = head_radius * 0.50
    shoulder_height = height - 2*head_radius - neck_length
    torso_length = shoulder_height * 0.55
    torso_radius = torso_length * 0.30
    leg_length = shoulder_height - torso_length
    arm_length = torso_length * 0.90

    def _stroke(xs, zs):
        # Lift the stroke onto the sphere, then rotate it into the view.
        xv, yv, zv = transform_xyz(view, None, xs, np.zeros_like(xs), zs + radius)
        axes.plot(xv, yv, zv, color='k')

    circle = np.linspace(0, 2 * np.pi, 40)
    head = (head_radius * np.cos(circle),
            head_radius * np.sin(circle) + head_height)

    torso = (np.array([-torso_radius, torso_radius, torso_radius,
                       -torso_radius, -torso_radius]),
             np.array([0., 0, torso_length, torso_length, 0]) + leg_length)

    arm_x = np.array([-torso_radius - limb_offset,
                      -torso_radius - limb_offset,
                      -torso_radius])
    arm_z = np.array([shoulder_height - arm_length,
                      shoulder_height,
                      shoulder_height])

    leg_x = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset])
    leg_z = np.array([0, leg_length])

    strokes = (head, torso,
               (arm_x, arm_z), (-arm_x, arm_z),
               (leg_x, leg_z), (-leg_x, leg_z))
    for xs, zs in strokes:
        _stroke(xs, zs)

    bound = [-radius-height, radius+height]
    for setter in ('set_xlim', 'set_ylim', 'set_zlim'):
        getattr(axes, setter)(bound)
    axes.set_axis_off()
287
def draw_jitter(axes, view, jitter, dist='gaussian',
                size=(0.1, 0.4, 1.0),
                draw_shape=draw_parallelepiped,
                projection='equirectangular',
                alpha=0.8,
                views=None):
    """
    Represent jitter as a set of shapes at different orientations.

    *view* is the (theta, phi, psi) viewing orientation.  *jitter* is
    (dtheta, dphi, dpsi), the orientation dispersity in degrees.

    *dist* sets the sampled span of each dispersed angle: +/-3*jitter for
    'gaussian', +/-sqrt(3)*jitter for 'rectangle' and +/-jitter for
    'uniform'.  An unknown *dist* raises KeyError.

    *size* is rescaled so the shape's diagonal is 0.95 before drawing.
    *draw_shape(axes, size, view, dview, alpha=alpha)* is called once per
    sampled orientation whose *projection* weight is positive.

    *views* is the number of samples per dispersed angle.  By default it
    grows with the span and is clamped to between 3 and 25.  Angles with
    zero jitter use a single sample at 0.
    """
    project, weight = get_projection(projection)

    # set max diagonal to 0.95
    scale = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(scale*v for v in size)

    dtheta, dphi, dpsi = jitter
    # Half-width of the sampled span in multiples of the jitter value.
    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]

    def steps(delta):
        """Return the sampled offsets for one angle with jitter *delta*."""
        if not delta > 0:
            return [0]
        if views is None:
            n = max(3, min(25, 2*int(base*delta/5)))
        else:
            n = views
        return base*delta*np.linspace(-1, 1, n)

    stheta = steps(dtheta)
    sphi = steps(dphi)
    spsi = steps(dpsi)
    for theta in stheta:
        for phi in sphi:
            # The weight does not depend on psi, so evaluate it once per
            # (theta, phi) and skip the whole psi sweep when it is zero.
            if weight(theta, phi, 1.0, 1.0) > 0:
                for psi in spsi:
                    dview = project(theta, phi, psi)
                    draw_shape(axes, size, view, dview, alpha=alpha)

    # Use the (rescaled) diagonal as a symmetric limit on every axis.
    a, b, c = size
    lim = np.sqrt(a**2 + b**2 + c**2)
    for v in 'xyz':
        getattr(axes, 'set_'+v+'lim')([-lim, lim])
329
# Projection names listed in order of their PROJECTION number.  Do not
# reorder without also updating the matching constants in kernel_iq.c.
PROJECTIONS = [
    'equirectangular',
    'sinusoidal',
    'guyou',
    'azimuthal_equidistance',
    'azimuthal_equal_area',
]
def get_projection(projection):
    """
    Return the *(project, weight)* function pair for a jitter projection.

    *project(theta_i, phi_j, psi)* maps a jitter mesh point (angles in
    degrees) to the orientation used to draw the jittered shape.  Most
    projections return a 3-tuple of angles for the Rx Ry Rz convention.
    The azimuthal projections return a 4-tuple ending in 'zyz' because
    they compute angles for Rz Ry Rz rotations instead.

    *weight(theta_i, phi_j, w_i, w_j)* combines the distribution weights
    *w_i* and *w_j* of the mesh point with the projection's area
    correction, if any.  A weight of zero marks points that lie outside
    the projection and should be skipped.

    Raises *ValueError* if *projection* is not recognized.

    jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangular.  Note that the poles
        will be slightly overweighted for theta > 90 with the circle
        from theta at 90+dt winding backwards around the pole, overlapping
        the circle from theta at 90-dt.
    Guyou (hemisphere-in-a-square) **not weighted**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        With tiling, allows rotation in phi or theta through +/- 180, with
        uniform spacing.  Both theta and phi allow free rotation, with wobble
        in the orthogonal direction reasonably well behaved (though not as
        good as equirectangular phi). The forward/reverse transformations
        rely on elliptic integrals that are somewhat expensive, so the
        behaviour has to be very good to justify the cost and complexity.
        The weighting function for each point has not yet been computed.
        Note: run the module *guyou.py* directly and it will show the forward
        and reverse mappings.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not that
        useful and not completely implemented.
    Gauss-Krueger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    # TODO: try Kent distribution instead of a gaussian warped by projection

    if projection == 'equirectangular':  #define PROJECTION 1
        def _project(theta_i, phi_j, psi):
            latitude, longitude = theta_i, phi_j
            return latitude, longitude, psi
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j*abs(cos(radians(theta_i)))
    elif projection == 'sinusoidal':  #define PROJECTION 2
        def _project(theta_i, phi_j, psi):
            latitude = theta_i
            scale = cos(radians(latitude))
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            return latitude, longitude, psi
        # This was previously also named _project, which replaced the
        # projector above and left _weight undefined for this branch.
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Points beyond the sinusoid boundary are outside the map.
            active = 1 if abs(phi_j) < abs(scale)*180 else 0
            return active*w_i*w_j
    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
        def _project(theta_i, phi_j, psi):
            from .guyou import guyou_invert
            longitude, latitude = guyou_invert([phi_j], [theta_i])
            return latitude, longitude, psi
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j
    elif projection == 'azimuthal_equidistance':
        # Note that this calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(np.arctan2(phi_j, theta_i))
            return latitude, longitude, psi-longitude, 'zyz'
        def _weight(theta_i, phi_j, w_i, w_j):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # We are mapping from disk to sphere, so we need a change of
            # variables g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dlat dlong = |det(J)| dtheta dphi
            # where J is the jacobian from the partials of g. Using
            #     R = sqrt(theta^2 + phi^2),
            # then
            #     J = [[theta/R, phi/R], [-phi/R^2, theta/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            # R must be in radians so that sin(R)/R -> 1 as R -> 0, matching
            # the value used at the center of the disk.
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            arc = radians(latitude)
            weight = sin(arc)/arc if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # Note that this calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*np.arccos(radius))
            longitude = degrees(np.arctan2(phi_j, theta_i))
            return latitude, longitude, psi, "zyz"
        def _weight(theta_i, phi_j, w_i, w_j):
            # Same (unverified) weighting as azimuthal_equidistance; the
            # angle is converted to radians so the center value of 1 is
            # the limit of sin(R)/R.
            latitude = sqrt(theta_i**2 + phi_j**2)
            arc = radians(latitude)
            weight = sin(arc)/arc if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    return _project, _weight
484
def R_to_xyz(R):
    """
    Return (phi, theta, psi) Tait-Bryan angles, in degrees, for matrix *R*.

    Extracting Euler Angles from a Rotation Matrix
    Mike Day, Insomniac Games
    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
    Based on: Shoemake’s "Euler Angle Conversion", Graphics Gems IV, pp.  222-229
    """
    r00, r01, r02 = R[0, 0], R[0, 1], R[0, 2]
    r12, r22 = R[1, 2], R[2, 2]
    radians_xyz = (
        np.arctan2(r12, r22),                        # phi
        np.arctan2(-r02, np.sqrt(r00**2 + r01**2)),  # theta
        np.arctan2(r01, r00),                        # psi
    )
    return tuple(np.degrees(angle) for angle in radians_xyz)
498
def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    Each mesh point is the tip of the vector (0, 0, *radius*) after it is
    rotated by one (theta, phi) jitter offset and then by *view*.  There
    are *n* offsets per angle, and *dist* ('gaussian', 'rectangle' or
    'uniform') sets their span and distribution weights.  Points are
    coloured by their combined distribution and *projection* weight,
    scaled so the largest is 1.  Points with zero weight are not drawn.

    Raises *ValueError* for an unknown *dist*.
    """
    project, weight = get_projection(projection)

    def _rotate(theta, phi, vec):
        dview = project(theta, phi, 0.)
        # hack for zyz coords: a 4-tuple comes from the azimuthal projections
        outer = Rz if len(dview) == 4 else Rx
        return outer(dview[1])*Ry(dview[0])*vec

    offsets = np.linspace(-1, 1, n)
    offset_w = np.ones_like(offsets)
    if dist == 'gaussian':
        # span +/- 3 standard deviations
        offsets *= 3
        offset_w = exp(-0.5*offsets**2)
    elif dist == 'rectangle':
        # The sasmodels rectangle width acts like a standard deviation,
        # so the half-width is sqrt(3) times the jitter value.
        offsets *= sqrt(3)
    elif dist != 'uniform':
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # mesh in theta, phi formed by rotating z
    dtheta, dphi, dpsi = jitter
    thetas = dtheta*offsets
    phis = dphi*offsets
    tip = np.matrix([[0], [0], [radius]])
    points = np.hstack([_rotate(theta_i, phi_j, tip)
                        for theta_i in thetas
                        for phi_j in phis])
    mesh_w = np.array([weight(theta_i, phi_j, w_i, w_j)
                       for w_i, theta_i in zip(offset_w, thetas)
                       for w_j, phi_j in zip(offset_w, phis)])
    keep = mesh_w > 0
    points = points[:, keep]
    mesh_w = mesh_w[keep]
    mesh_w /= max(mesh_w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).flatten() for row in points]
    axes.scatter(x, y, z, c=mesh_w, marker='o', vmin=0., vmax=1.)
548
def draw_labels(axes, view, jitter, text):
    """
    Draw text labels at points on a shape.

    *text* is a sequence of (label, location, orientation) triples, where
    location and orientation are (x, y, z) triples.  Both are rotated by
    *jitter* and then *view* before the label is drawn with
    ``axes.text``, using the rotated orientation as *zdir*.
    """
    labels, locations, orientations = zip(*text)
    loc_x, loc_y, loc_z = zip(*locations)
    dir_x, dir_y, dir_z = zip(*orientations)

    loc_x, loc_y, loc_z = transform_xyz(view, jitter, loc_x, loc_y, loc_z)
    dir_x, dir_y, dir_z = transform_xyz(view, jitter, dir_x, dir_y, dir_z)

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for k, label in enumerate(labels):
        direction = np.asarray((dir_x[k], dir_y[k], dir_z[k])).flatten()
        axes.text(loc_x[k], loc_y[k], loc_z[k], label, zdir=direction)
564
# Definition of rotation matrices comes from wikipedia:
#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """
    Return the 3x3 :class:`numpy.matrix` rotating points about *x*.

    *angle* is in degrees; a positive angle rotates +y toward +z.
    """
    theta = radians(angle)
    c, s = cos(theta), sin(theta)
    return np.matrix([[1, 0, 0],
                      [0, c, -s],
                      [0, s, c]])
574
def Ry(angle):
    """
    Return the 3x3 :class:`numpy.matrix` rotating points about *y*.

    *angle* is in degrees; a positive angle rotates +z toward +x.
    """
    theta = radians(angle)
    c, s = cos(theta), sin(theta)
    return np.matrix([[c, 0, s],
                      [0, 1, 0],
                      [-s, 0, c]])
582
def Rz(angle):
    """
    Return the 3x3 :class:`numpy.matrix` rotating points about *z*.

    *angle* is in degrees; a positive angle rotates +x toward +y.
    """
    theta = radians(angle)
    c, s = cos(theta), sin(theta)
    return np.matrix([[c, -s, 0],
                      [s, c, 0],
                      [0, 0, 1]])
590
def transform_xyz(view, jitter, x, y, z):
    """
    Rotate (x, y, z) points by *jitter* and then by *view*.

    *x*, *y* and *z* are array-like with the same number of elements.
    The three returned arrays are reshaped to the shape of *x*.
    """
    x, y, z = (np.asarray(coord) for coord in (x, y, z))
    shape = x.shape
    stacked = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    rotated = orient_relative_to_beam(view, apply_jitter(jitter, stacked))
    return tuple(np.array(row).reshape(shape) for row in rotated)
602
def apply_jitter(jitter, points):
    """
    Rotate *points* by the orientation dispersity angles in *jitter*.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    *jitter* may be None (points are returned unchanged), a
    (dtheta, dphi, dpsi) triple in degrees applied as
    Rx(dphi)*Ry(dtheta)*Rz(dpsi), or a 4-tuple whose extra last entry
    selects the Euler (zyz) form Rz(dphi)*Ry(dtheta)*Rz(dpsi).  The zyz
    form is a hack used by the azimuthal projections.
    """
    if jitter is None:
        return points
    if len(jitter) == 4:
        dtheta, dphi, dpsi, _ = jitter
        outer = Rz
    else:
        dtheta, dphi, dpsi = jitter
        outer = Rx
    return outer(dphi)*Ry(dtheta)*Rz(dpsi)*points
619
def orient_relative_to_beam(view, points):
    """
    Rotate *points* by the viewing angles in *view*.

    *view* is (theta, phi, psi) in degrees, applied as
    Rz(phi)*Ry(theta)*Rz(psi).  Points are stored in a 3 x n numpy
    matrix, not a numpy array or tuple.
    """
    theta, phi, psi = view
    # Alternative orderings that have been considered:
    #   Rz(psi)*Ry(pi/2-theta)*Rz(phi)  -- 1-D integration angles
    #   Rx(phi)*Ry(theta)*Rz(psi)       -- angular dispersion angles
    view_rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return view_rotation*points
631
def orient_relative_to_beam_quaternion(view, points):
    """
    Apply the view transform to a set of points using quaternions.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    This variant uses quaternions rather than rotation matrices for the
    computation.  It works but it is not used because it doesn't solve
    any problems.  The challenge of mapping theta/phi/psi to SO(3) does
    not disappear by calculating the transform differently.
    """
    theta, phi, psi = view
    # Compose theta about x, phi about y and psi about z by post-multiplying.
    # Building q by rotating the unit axes with the partial rotation and
    # pre-multiplying turned out to be equivalent.  Pre-multiplying with
    # fixed axes would reverse the order of application.
    q = Quaternion(1, [0, 0, 0])
    for angle, axis in ((theta, [1, 0, 0]), (phi, [0, 1, 0]), (psi, [0, 0, 1])):
        q = q * Quaternion.from_angle_axis(angle, axis)
    return q.rot(points)
662#orient_relative_to_beam = orient_relative_to_beam_quaternion
663
664# Simple stand-alone quaternion class
665import numpy as np
666from copy import copy
class Quaternion(object):
    """
    Simple stand-alone quaternion *w + r[0] i + r[1] j + r[2] k*.

    *w* is the scalar part and *r* the 3-element vector part.  :meth:`rot`
    treats the quaternion as a rotation.  That is only valid for unit
    quaternions, such as those returned by :meth:`from_angle_axis` and
    their products.
    """
    def __init__(self, w, r):
        self.w = w
        self.r = np.asarray(r, dtype='d')

    @staticmethod
    def from_angle_axis(theta, r):
        """
        Return the unit quaternion rotating by *theta* degrees about axis *r*.

        *r* need not be normalized.  Raises *ValueError* if *r* has zero
        length, since the rotation axis is then undefined.
        """
        half_angle = np.radians(theta)/2
        axis = np.asarray(r, dtype='d')
        length = np.sqrt(np.dot(axis, axis))
        if length == 0:
            raise ValueError("rotation axis must have nonzero length")
        # Divide by the length (not the squared length) so the vector part
        # is sin(theta/2) times a *unit* axis for any input axis.
        return Quaternion(np.cos(half_angle), np.sin(half_angle)*axis/length)

    def __mul__(self, other):
        """Return the Hamilton product self*other of two quaternions."""
        if not isinstance(other, Quaternion):
            # Let Python raise TypeError rather than silently returning None.
            return NotImplemented
        w = self.w*other.w - np.dot(self.r, other.r)
        r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
        return Quaternion(w, r)

    def rot(self, v):
        """
        Rotate the vector(s) *v* by this (unit) quaternion.

        *v* is a single 3-vector or a 3 x n set of column vectors, such as
        the 3 x n point matrices used in this module.  An n x 3 array with
        n != 3 is also accepted and returned as n x 3.  The result is a
        numpy array.
        """
        v = np.asarray(v).T
        use_transpose = (v.shape[-1] != 3)
        if use_transpose:
            v = v.T
        # v' = v + 2 r x (r x v + w v), the expanded form of q v q*
        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
        if use_transpose:
            v = v.T
        return v.T

    def conj(self):
        """Return the conjugate quaternion w - r."""
        return Quaternion(self.w, -self.r)

    def inv(self):
        """Return the multiplicative inverse conj(q)/|q|^2."""
        # Quaternion defines no division operator, so scale the parts
        # of the conjugate directly.
        norm_sq = self.norm()**2
        return Quaternion(self.w/norm_sq, -self.r/norm_sq)

    def norm(self):
        """Return the Euclidean norm sqrt(w^2 + |r|^2)."""
        return np.sqrt(self.w**2 + np.sum(self.r**2))

    def __str__(self):
        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
def test_qrot():
    """Check Quaternion.rot against a hand-computed 60 degree rotation."""
    # The rotation axis lies in the y-z plane, 60 degrees from y.  It is
    # found by rotating [0, 1, 0] by 60 degrees about x.
    axis = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0])
    rotation = Quaternion.from_angle_axis(60, axis)
    # Point to rotate and its expected rotated position.
    point = [1, -1, 2]
    root3 = np.sqrt(3)
    expected = [(10+4*root3)/8, (1+2*root3)/8, (14-3*root3)/8]
    error = abs(rotation.rot(point) - expected)
    assert max(error) < 1e-14
709#test_qrot()
710#import sys; sys.exit()
711
# Number of dispersity points to use along (theta, phi, psi), keyed by
# which of the three angles have nonzero dispersity.  The per-angle count
# shrinks as more angles are active, keeping the total bounded:
# 100 for one angle, 30*30 = 900 for two, 15*15*15 = 3375 for three.
PD_N_TABLE = dict([
    ((0, 0, 0), (0, 0, 0)),
    ((1, 0, 0), (100, 0, 0)),
    ((0, 1, 0), (0, 100, 0)),
    ((0, 0, 1), (0, 0, 100)),
    ((1, 1, 0), (30, 30, 0)),
    ((1, 0, 1), (30, 0, 30)),
    ((0, 1, 1), (0, 30, 30)),
    ((1, 1, 1), (15, 15, 15)),
])
724
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine a (low, high) display range from *data*.

    *portion* is the fraction of the data points to keep.  If it is 1 the
    full (min, max) range is returned.  Otherwise *mode* selects which
    points are kept:

    - 'central' discards equal numbers of the smallest and largest points.
    - 'top' keeps only the largest points, so the upper limit is always
      the maximum.

    Raises *ValueError* if *portion* is not 1 and *mode* is neither
    'central' nor 'top'.
    """
    data = np.asarray(data)
    if portion == 1.0:
        return data.min(), data.max()
    if mode not in ('central', 'top'):
        raise ValueError("expected mode to be 'central' or 'top'")
    data = np.sort(data.flatten())
    last = len(data) - 1
    if mode == 'central':
        # Drop (1-portion)/2 of the points from each end.  The upper limit
        # is indexed from the front so that an offset of 0 gives the max
        # (data[-0] would be the min).  Clamp so the range never inverts.
        offset = int((1 - portion)*len(data)/2 + 0.5)
        offset = min(max(offset, 0), last//2)
        return data[offset], data[last - offset]
    # mode == 'top': keep the largest *portion* of the points.
    start = len(data) - int(portion*len(data) + 0.5)
    start = min(max(start, 0), last)
    return data[start], data[last]
742
def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    log(I) is drawn as filled contours on the plane z = -1.1, with qx and qy
    scaled by their maximum values.  The colour range is *calculator.limits*
    when set; otherwise it covers the seven decades of intensity below the
    brightest point of this view.
    """
    if dist == 'uniform':  # uniform is not yet in this branch
        # The sasmodels rectangle is sqrt(3) times wider than its nominal
        # width, so scale the jitter down to get a uniform of that width.
        dist, scale = 'rectangle', 1/sqrt(3)
    else:
        scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    # model parameters stored by build_model take precedence
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    Iqxy = calculator(**pars).reshape(len(qx), len(qy))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        # Seven decades below the peak.  Iqxy already holds log(I), so the
        # decades are subtracted: log(1e-7*Imax) = log(Imax) - 7*log(10).
        # Multiplying log(Imax) by 1e-7 would give vmin ~ 0, which is above
        # vmax whenever Imax < 1 and makes the contour levels decrease.
        vmax = Iqxy.max()
        vmin = vmax - 7*np.log(10)
    # Alternative renderings: axes.plot_surface(...) with facecolors taken
    # from a colormap, or axes.pcolormesh(qx, qy, Iqxy).
    axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                  levels=np.linspace(vmin, vmax, 24))
800
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  A background of 1e-3 is used unless *background* is given.

    Returns a *calculator* function which takes the model parameters as
    keyword arguments and produces Iqxy.  The Iqxy value needs to be reshaped
    to an n x n matrix for plotting.  See the
    :class:`sasmodels.direct_model.DirectModel` class for details.
    *calculator.limits* holds a (vmin, vmax) range for log(Iqxy) computed at
    orientation (0, 0, 0).
    """
    from sasmodels.core import load_model_info, build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # setdefault keeps any background supplied by the caller (e.g. 0).
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
838
def select_calculator(model_name, n=150, size=(10, 40, 100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped, sc_paracrystal, fcc_paracrystal or bcc_paracrystal.
    *n* is the number of points to use along each q axis.  *size* gives the
    (a, b, c) dimensions of the shape.  Some models override the default
    *qmax* of :func:`build_model` so that the pattern shows something
    interesting.

    Returns *calculator* and tuple *size* (a, b, c) giving minor and major
    equatorial axes and polar axis respectively, adjusted to the symmetry of
    the shape (e.g., a = b = c for a sphere).  See :func:`build_model` for
    details on the returned calculator.

    Raises ValueError for an unknown *model_name*.
    """
    a, b, c = size
    d_factor = 0.06  # for paracrystal models

    # Paracrystal lattices: (dnn, radius) as a function of the lattice size.
    # For fcc and bcc the nearest neighbour distance dnn should be 2 radius,
    # but the model appears to use the lattice spacing rather than dnn in
    # its calculations, hence dnn = size/2 for those two.
    lattices = {
        'sc_paracrystal': lambda s: (s, 0.5*s),
        'fcc_paracrystal': lambda s: (0.5*s, sqrt(2)/4 * s),
        'bcc_paracrystal': lambda s: (0.5*s, sqrt(3)/2 * s),
    }

    if model_name == 'sphere':
        calculator = build_model('sphere', n=n, radius=c)
        a = b = c
    elif model_name in ('sc_paracrystal', 'fcc_paracrystal', 'bcc_paracrystal'):
        dnn, radius = lattices[model_name](c)
        calculator = build_model(model_name, n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
        a = b = c
    elif model_name == 'cylinder':
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
        a = b
    elif model_name == 'ellipsoid':
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
        a = b
    elif model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
    elif model_name == 'parallelepiped':
        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    return calculator, (a, b, c)
900
# Shapes selectable from the command line; SHAPES[0] is the default shape
# used by main().
SHAPES = [
    'parallelepiped',
    'sphere', 'ellipsoid', 'triaxial_ellipsoid',
    'cylinder',
    'fcc_paracrystal', 'bcc_paracrystal', 'sc_paracrystal',
]

# Shape-specific drawing functions; run() falls back to draw_parallelepiped
# for shapes not listed here.
DRAW_SHAPES = {
    'fcc_paracrystal': draw_fcc,
    'bcc_paracrystal': draw_bcc,
    'sc_paracrystal': draw_sc,
    'parallelepiped': draw_parallelepiped,
}

# Jitter distributions selectable from the command line; DISTRIBUTIONS[0]
# is the default used by main().
DISTRIBUTIONS = [
    'gaussian', 'rectangle', 'uniform',
]
# Per-distribution jitter scale for the matplotlib viewer, whose jitter
# sliders run from 0 to twice this value.  The rectangle value is divided by
# sqrt(3) because the sasmodels rectangle is sqrt(3) times the given width.
DIST_LIMITS = {
    'gaussian': 30,
    'rectangle': 90/sqrt(3),
    'uniform': 90,
}
923
924
def run(model_name='parallelepiped', size=(10, 40, 100),
        view=(0, 0, 0), jitter=(0, 0, 0),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal

    *size* gives the dimensions (a, b, c) of the shape; symmetric shapes
    adjust it (see :func:`select_calculator`).

    *view* gives the initial view (theta, phi, psi) of the shape.

    *jitter* gives the initial jitter (dtheta, dphi, dpsi) of the shape.

    *dist* is the type of dispersion: gaussian, rectangle, or uniform.

    *mesh* is the number of points in the dispersion mesh.

    *projection* is the map projection to use for the mesh: equirectangular,
    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.
    The scattering kernel only implements projection numbers 1 and 2; for
    later entries of PROJECTIONS a warning is printed and the scattering is
    computed with projection number 2.

    The plot is drawn by the engine chosen with :func:`set_plotter`.
    """
    from sasmodels import generate

    # The kernel identifies the projection by its 1-based position in
    # PROJECTIONS, and only numbers 1 and 2 are implemented so far.
    projection_id = PROJECTIONS.index(projection) + 1
    if projection_id > 2:
        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
        projection_id = 2
    generate.PROJECTION = projection_id

    calculator, size = select_calculator(model_name, n=100, size=size)
    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped)

    # Discard the fixed colour range computed by build_model so that each
    # view gets its own range; delete this line to keep the range of the
    # (0, 0, 0) orientation for every view.
    calculator.limits = None

    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
966
def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the jitter explorer in an interactive matplotlib window.

    The 3D axes show the beam, the jittered shape, the dispersity mesh and
    the scattering pattern.  Sliders below them control the view angles
    (theta, phi, psi) and the jitter (dtheta, dphi, dpsi).  *calculator*,
    *draw_shape* and *size* come from :func:`select_calculator`; the other
    arguments are as for :func:`run`.  Ends by calling ``plt.show()``.
    """
    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    ## create the plot window
    plt.subplots(num=None, figsize=(5.5, 5.5))
    plt.set_cmap('gist_earth')
    plt.clf()
    # FigureCanvasBase.set_window_title was deprecated in matplotlib 3.4 and
    # removed in 3.6; the window title belongs to the figure manager.
    canvas = plt.gcf().canvas
    manager = getattr(canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(projection)
    else:  # CRUFT: old matplotlib without canvas.manager
        canvas.set_window_title(projection)
    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        axes.axis('square')
    except Exception:
        pass

    axcolor = 'lightgoldenrodyellow'

    ## add control widgets to plot
    axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], facecolor=axcolor)
    axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], facecolor=axcolor)
    axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], facecolor=axcolor)
    stheta = Slider(axes_theta, u'Ξ', -90, 90, valinit=0)
    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0)
    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0)

    axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], facecolor=axcolor)
    axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], facecolor=axcolor)
    axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], facecolor=axcolor)
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axes_dtheta, u'Δξ', 0, 2*dlimit, valinit=0)
    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0)
    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0)

    ## initial view and jitter
    theta, phi, psi = view
    stheta.set_val(theta)
    sphi.set_val(phi)
    spsi.set_val(psi)
    dtheta, dphi, dpsi = jitter
    sdtheta.set_val(dtheta)
    sdphi.set_val(dphi)
    sdpsi.set_val(dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        """Redraw everything from the current slider values (args unused)."""
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims: the threshold is 0.5
        # with one dispersed axis and 5 with two or three
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        axes.cla()

        ## Visualize as person on globe
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, -view[:2])
        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0))
        draw_jitter(axes, view, jitter, dist=dist, size=size,
                    draw_shape=draw_shape, projection=projection, views=3)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
1057
1058
def map_colors(z, kw):
    """
    Resolve matplotlib-style colour keywords in *kw* into a single 'color'.

    *kw* is modified in place: 'cmap', 'alpha', 'vmin', 'vmax' and 'c' are
    removed and 'color' is set.  Other keywords are left alone.

    If neither 'color' nor 'c' is given, *z* is scaled so that [vmin, vmax]
    (default [z.min(), z.max()]) maps to [0, 1], clipped, and passed through
    *cmap* (default matplotlib coolwarm).  A constant *z* maps to the bottom
    of the colormap.  A 'color'/'c' array with the same shape as *z* is
    passed through *cmap* unscaled; any other colour is used as given.

    Without *alpha*, array colours are trimmed to RGB.  With *alpha*, the
    colour must be RGBA; its alpha channel is set on a copy so the caller's
    array is not modified.
    """
    z = np.asarray(z)
    cmap = kw.pop('cmap', None)
    if cmap is None:
        from matplotlib import cm  # only needed for the default colormap
        cmap = cm.coolwarm
    alpha = kw.pop('alpha', None)
    vmin = kw.pop('vmin', None)
    vmax = kw.pop('vmax', None)
    c = kw.pop('c', None)
    color = kw.pop('color', c)
    if color is None:
        vmin = z.min() if vmin is None else vmin
        vmax = z.max() if vmax is None else vmax
        span = vmax - vmin
        if span == 0:
            # avoid 0/0 -> NaN colours; same choice as matplotlib's Normalize
            znorm = np.zeros(z.shape)
        else:
            znorm = ((z - vmin) / span).clip(0, 1)
        color = cmap(znorm)
    elif isinstance(color, np.ndarray) and color.shape == z.shape:
        color = cmap(color)
    if alpha is None:
        if isinstance(color, np.ndarray):
            color = color[..., :3]
    else:
        if isinstance(color, np.ndarray):
            color = color.copy()  # don't modify the caller's array
        color[..., 3] = alpha
    kw['color'] = color
1079
def make_vec(*args):
    """
    Convert each argument to a float64 numpy array, keeping its shape.

    Inputs are not flattened, and arrays that are already float64 are
    returned as-is rather than copied.  Returns a list with one array per
    argument.
    """
    return [np.asarray(item, dtype=np.float64) for item in args]
1083
def make_image(z, kw):
    """
    Render the 2D array *z* as an RGB :class:`PIL.Image.Image`.

    *z* is scaled linearly so that its minimum maps to 0 and its maximum to 1,
    then passed through the colormap ``kw['cmap']`` (removed from *kw*;
    default matplotlib coolwarm).  Data with no finite positive spread
    (constant, or containing NaN) maps to the bottom of the colormap rather
    than dividing by zero.  Other entries of *kw* are left untouched.
    """
    import PIL.Image
    from matplotlib import cm

    cmap = kw.pop('cmap', cm.coolwarm)

    z = np.asarray(z, dtype='d')
    # np.ptp rather than z.ptp(): the ndarray method was removed in NumPy 2.0.
    span = np.ptp(z)
    if span > 0:
        znorm = (z - z.min())/span
    else:
        znorm = np.zeros_like(z)
    rgb = np.asarray(cmap(znorm)[..., :3]*255, 'u1')
    # A (rows, cols, 3) uint8 array is interpreted as RGB by fromarray, so the
    # (newly deprecated) mode argument is not needed.
    return PIL.Image.fromarray(rgb)
1096
1097
# matplotlib marker codes translated to ipyvolume marker names by the
# ipyvolume Plotter.scatter; codes not listed are passed through unchanged.
_IPV_MARKERS = {
    'o': 'sphere',
}
# matplotlib single-letter colour codes translated to colour names for
# ipyvolume by ipv_fix_color; other strings are passed through unchanged.
_IPV_COLORS = {
    'w': 'white',
    'k': 'black',
    'c': 'cyan',
    'm': 'magenta',
    'y': 'yellow',
    'r': 'red',
    'g': 'green',
    'b': 'blue',
}
def ipv_fix_color(kw):
    """
    Adapt matplotlib-style colour keywords in *kw* for ipyvolume, in place.

    A string 'color' that is a single-letter code in _IPV_COLORS is expanded
    to its colour name.  'alpha' is always removed from *kw*.  If 'color' is
    an RGBA array (last dimension 4), its alpha channel is set to *alpha* on
    a copy, so the caller's array is not modified.  Otherwise the alpha value
    is dropped, including when *kw* has no 'color' at all.
    """
    alpha = kw.pop('alpha', None)
    color = kw.get('color', None)
    if isinstance(color, str):
        color = _IPV_COLORS.get(color, color)
        kw['color'] = color
    #TODO: convert color to [r, g, b, a] if not already
    if (alpha is not None and isinstance(color, np.ndarray)
            and color.shape[-1] == 4):
        color = color.copy()  # don't modify the caller's array
        color[..., 3] = alpha
        kw['color'] = color
1123
def ipv_set_transparency(kw, obj):
    """
    Turn on transparent rendering for *obj* when its colour needs it.

    If ``kw['color']`` is an RGBA array (last dimension 4) with any alpha
    value other than 1, the material of *obj* is marked transparent and set
    to draw front faces only.  Otherwise *obj* is left unchanged.
    """
    color = kw.get('color', None)
    if not isinstance(color, np.ndarray) or color.shape[-1] != 4:
        return
    if np.any(color[..., 3] != 1.0):
        obj.material.transparent = True
        obj.material.side = "FrontSide"
1131
def ipv_axes():
    """
    Return an object that mimics some matplotlib 3D axes methods using ipyvolume.

    Supported: plot, plot_surface, plot_trisurf, scatter, contourf, pcolor,
    text (ignored), set_xlim/set_ylim/set_zlim and set_axis_on/set_axis_off.
    matplotlib-style colour and marker keywords are translated before the
    ipyvolume call, and the drawing methods return the ipyvolume object.
    """
    import ipyvolume as ipv

    class Plotter:
        # transparency can be achieved by setting the following:
        #    mesh.color = [r, g, b, alpha]
        #    mesh.material.transparent = True
        #    mesh.material.side = "FrontSide"
        # smooth(ish) rotation can be achieved by setting:
        #    slide.continuous_update = True
        #    figure.animation = 0.
        #    mesh.material.x = x
        #    mesh.material.y = y
        #    mesh.material.z = z
        # maybe need to synchronize update of x/y/z to avoid shimmy when moving
        def plot(self, x, y, z, **kw):
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            return ipv.plot(x, y, z, **kw)
        def plot_surface(self, x, y, z, **kw):
            # matplotlib's per-face colours become ipyvolume's color array
            facecolors = kw.pop('facecolors', None)
            if facecolors is not None:
                kw['color'] = facecolors
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            h = ipv.plot_surface(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            return h
        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, Z)
            if triangles is not None:
                triangles = np.asarray(triangles)
            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
            ipv_set_transparency(kw, h)
            return h
        def scatter(self, x, y, z, **kw):
            x, y, z = make_vec(x, y, z)
            map_colors(z, kw)
            marker = kw.get('marker', None)
            kw['marker'] = _IPV_MARKERS.get(marker, marker)
            h = ipv.scatter(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            return h
        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
            # Don't use contour for now (although we might want to later);
            # the data is drawn as an image and *levels* is ignored.
            return self.pcolor(x, y, v, zdir='z', offset=offset, **kw)
        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
            # Draw *v* as a texture on a rectangle covering the x, y extent.
            # *zdir* is ignored: the rectangle always lies at z = *offset*.
            x, y, v = make_vec(x, y, v)
            image = make_image(v, kw)
            xmin, xmax = x.min(), x.max()
            ymin, ymax = y.min(), y.max()
            x = np.array([[xmin, xmax], [xmin, xmax]])
            y = np.array([[ymin, ymin], [ymax, ymax]])
            z = x*0 + offset
            # texture coordinates of the rectangle corners
            tex_u = np.array([[0., 1], [0, 1]])
            tex_v = np.array([[0., 0], [1, 1]])
            h = ipv.plot_mesh(x, y, z, u=tex_u, v=tex_v, texture=image,
                              wireframe=False)
            ipv_set_transparency(kw, h)
            return h
        def text(self, *args, **kw):
            # text labels are not drawn in the ipyvolume view
            pass
        def set_xlim(self, limits):
            ipv.xlim(*limits)
        def set_ylim(self, limits):
            ipv.ylim(*limits)
        def set_zlim(self, limits):
            ipv.zlim(*limits)
        def set_axis_on(self):
            # ipyvolume spells it axes_on; there is no style.axis_on
            ipv.style.axes_on()
        # CRUFT: previous method name, kept for existing callers
        set_axes_on = set_axis_on
        def set_axis_off(self):
            ipv.style.axes_off()
    return Plotter()
1205
def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the jitter explorer in a Jupyter notebook using ipyvolume.

    Arguments are as for :func:`mpl_plot`.  Six sliders control the view
    angles (theta, phi, psi) and the jitter (dtheta, dphi, dpsi).  With
    ``continuous_update=False`` the figure is redrawn when a slider is
    released, and the camera is carried over from the previous figure.
    """
    import ipywidgets as widgets
    import ipyvolume as ipv
    # display() is injected as a builtin inside IPython, but is undefined when
    # this module is used elsewhere, so import it explicitly.
    from IPython.display import display

    axes = ipv_axes()

    def draw(view, jitter):
        """Redraw the scene for *view* and *jitter*, preserving the camera."""
        camera = ipv.gcf().camera
        ipv.figure(animation=0.)  # no animation when updating object mesh

        # set small jitter as 0 if multiple pd dims: the threshold is 0.5
        # with one dispersed axis and 5 with two or three
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]

        ## Visualize as person on globe
        #draw_beam(axes, (0, 0))
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, view=(-view[0], -view[1]))
        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0))

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0))
        draw_jitter(axes, view, jitter, dist=dist, size=size,
                    draw_shape=draw_shape, projection=projection)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        draw_axes(axes, origin=(-1, -1, -1.1))
        ipv.style.box_off()
        ipv.style.axes_off()
        ipv.xyzlabel(" ", " ", " ")

        ipv.gcf().camera = camera
        ipv.show()

    # (min, max, step) for the angle and jitter sliders
    trange, prange = (-180., 180., 1.), (-180., 180., 1.)
    dtrange, dprange = (0., 180., 1.), (0., 180., 1.)

    ## Super simple interface, but uses non-ascii variable names
    # Ξ φ ψ Δξ Δφ Δψ
    #def update(**kw):
    #    view = kw['Ξ'], kw['φ'], kw['ψ']
    #    jitter = kw['Δξ'], kw['Δφ'], kw['Δψ']
    #    draw(view, jitter)
    #widgets.interact(update, Ξ=trange, φ=prange, ψ=prange, Δξ=dtrange, Δφ=dprange, Δψ=dprange)

    def update(theta, phi, psi, dtheta, dphi, dpsi):
        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))

    def slider(name, bounds, init=0.):
        """Build a FloatSlider labelled *name* over (min, max, step) *bounds*."""
        lo, hi, step = bounds
        return widgets.FloatSlider(
            value=init,
            min=lo,
            max=hi,
            step=step,
            description=name,
            disabled=False,
            # continuous_update=True gives smooth(ish) rotation instead
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            )
    theta = slider(u'Ξ', trange, view[0])
    phi = slider(u'φ', prange, view[1])
    psi = slider(u'ψ', prange, view[2])
    dtheta = slider(u'Δξ', dtrange, jitter[0])
    dphi = slider(u'Δφ', dprange, jitter[1])
    dpsi = slider(u'Δψ', dprange, jitter[2])
    fields = {
        'theta': theta, 'phi': phi, 'psi': psi,
        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi,
    }
    ui = widgets.HBox([
        widgets.VBox([theta, phi, psi]),
        widgets.VBox([dtheta, dphi, dpsi])
    ])

    out = widgets.interactive_output(update, fields)
    display(ui, out)
1293
1294
# Plotting back ends for run(), keyed by name; select one with set_plotter().
_ENGINES = {
    "matplotlib": mpl_plot,
    "mpl": mpl_plot,
    #"plotly": plotly_plot,
    "ipyvolume": ipv_plot,
    "ipvolume": ipv_plot,  # CRUFT: misspelled name kept for compatibility
    "ipv": ipv_plot,
}
PLOT_ENGINE = _ENGINES["matplotlib"]

def set_plotter(name):
    """
    Select the plotting back end used by :func:`run`.

    *name* is "matplotlib" (or "mpl") or "ipyvolume" (or "ipv").
    Raises KeyError, listing the valid names, for anything else.
    """
    global PLOT_ENGINE
    try:
        PLOT_ENGINE = _ENGINES[name]
    except KeyError:
        raise KeyError("unknown plotter %r; use one of %s"
                       % (name, ", ".join(sorted(_ENGINES))))
1306
def main():
    """
    Command line entry point: parse the options and start :func:`run`.

    *--size*, *--view* and *--jitter* each take exactly three
    comma-separated numbers; anything else is reported as a usage error.
    """
    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--projection', choices=PROJECTIONS,
                        default=PROJECTIONS[0],
                        help='coordinate projection')
    parser.add_argument('-s', '--size', type=str, default='10,40,100',
                        help='a,b,c lengths')
    parser.add_argument('-v', '--view', type=str, default='0,0,0',
                        help='initial view angles')
    parser.add_argument('-j', '--jitter', type=str, default='0,0,0',
                        help='initial angular dispersion')
    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS,
                        default=DISTRIBUTIONS[0],
                        help='jitter distribution')
    parser.add_argument('-m', '--mesh', type=int, default=30,
                        help='#points in theta-phi mesh')
    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0],
                        help='oriented shape')
    opts = parser.parse_args()

    def triple(text, option):
        """Parse *text* as three comma-separated floats for *option*."""
        try:
            values = tuple(float(v) for v in text.split(','))
        except ValueError:
            values = ()
        if len(values) != 3:
            # parser.error prints usage and exits with status 2
            parser.error("%s expects three comma-separated numbers, got %r"
                         % (option, text))
        return values

    size = triple(opts.size, '--size')
    view = triple(opts.view, '--view')
    jitter = triple(opts.jitter, '--jitter')
    run(opts.shape, size=size, view=view, jitter=jitter,
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.