source: sasmodels/sasmodels/jitter.py @ 3d7f364

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 3d7f364 was 3d7f364, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

tweak web jitter viewer details

  • Property mode set to 100755
File size: 50.3 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Jitter Explorer
5===============
6
7Application to explore orientation angle and angular dispersity.
8
9From the command line::
10
11    # Show docs
12    python -m sasmodels.jitter --help
13
14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
16
17From a jupyter cell::
18
19    import ipyvolume as ipv
20    from sasmodels import jitter
21    import importlib; importlib.reload(jitter)
22    jitter.set_plotter("ipv")
23
24    size = (10, 40, 100)
25    view = (20, 0, 0)
26
27    #size = (15, 15, 100)
28    #view = (60, 60, 0)
29
30    dview = (0, 0, 0)
31    #dview = (5, 5, 0)
32    #dview = (15, 180, 0)
33    #dview = (180, 15, 0)
34
35    projection = 'equirectangular'
36    #projection = 'azimuthal_equidistance'
37    #projection = 'guyou'
38    #projection = 'sinusoidal'
39    #projection = 'azimuthal_equal_area'
40
41    dist = 'uniform'
42    #dist = 'gaussian'
43
44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
46    #ipv.savefig(filename+'.png')
47"""
48from __future__ import division, print_function
49
50import argparse
51
52import numpy as np
53from numpy import pi, cos, sin, sqrt, exp, degrees, radians
54
55def draw_beam(axes, view=(0, 0), alpha=0.5, steps=6):
56    """
57    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1)
58    """
59    #axes.plot([0,0],[0,0],[1,-1])
60    #axes.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=alpha)
61
62    u = np.linspace(0, 2 * np.pi, steps)
63    v = np.linspace(-1, 1, 2)
64
65    r = 0.02
66    x = r*np.outer(np.cos(u), np.ones_like(v))
67    y = r*np.outer(np.sin(u), np.ones_like(v))
68    z = 1.3*np.outer(np.ones_like(u), v)
69
70    theta, phi = view
71    shape = x.shape
72    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
73    points = Rz(phi)*Ry(theta)*points
74    x, y, z = [v.reshape(shape) for v in points]
75    axes.plot_surface(x, y, z, color='yellow', alpha=alpha)
76
77    # TODO: draw endcaps on beam
78    ## Draw tiny balls on the end will work
79    draw_sphere(axes, radius=0.02, center=(0, 0, 1.3), color='yellow')
80    draw_sphere(axes, radius=0.02, center=(0, 0, -1.3), color='yellow')
81    ## The following does not work
82    #triangles = [(0, i+1, i+2) for i in range(steps-2)]
83    #x_cap, y_cap = x[:, 0], y[:, 0]
84    #for z_cap in z[:, 0], z[:, -1]:
85    #    axes.plot_trisurf(x_cap, y_cap, z_cap, triangles,
86    #                      color='yellow', alpha=alpha)
87
88
89def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
90    """Draw an ellipsoid."""
91    a, b, c = size
92    u = np.linspace(0, 2 * np.pi, steps)
93    v = np.linspace(0, np.pi, steps)
94    x = a*np.outer(np.cos(u), np.sin(v))
95    y = b*np.outer(np.sin(u), np.sin(v))
96    z = c*np.outer(np.ones_like(u), np.cos(v))
97    x, y, z = transform_xyz(view, jitter, x, y, z)
98
99    axes.plot_surface(x, y, z, color='w', alpha=alpha)
100
101    draw_labels(axes, view, jitter, [
102        ('c+', [+0, +0, +c], [+1, +0, +0]),
103        ('c-', [+0, +0, -c], [+0, +0, -1]),
104        ('a+', [+a, +0, +0], [+0, +0, +1]),
105        ('a-', [-a, +0, +0], [+0, +0, -1]),
106        ('b+', [+0, +b, +0], [-1, +0, +0]),
107        ('b-', [+0, -b, +0], [-1, +0, +0]),
108    ])
109
110def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
111    """Draw points for simple cubic paracrystal"""
112    atoms = _build_sc()
113    _draw_crystal(axes, size, view, jitter, atoms=atoms)
114
115def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
116    """Draw points for face-centered cubic paracrystal"""
117    # Build the simple cubic crystal
118    atoms = _build_sc()
119    # Define the centers for each face
120    # x planes at -1, 0, 1 have four centers per plane, at +/- 0.5 in y and z
121    x, y, z = (
122        [-1]*4 + [0]*4 + [1]*4,
123        ([-0.5]*2 + [0.5]*2)*3,
124        [-0.5, 0.5]*12,
125    )
126    # y and z planes can be generated by substituting x for y and z respectively
127    atoms.extend(zip(x+y+z, y+z+x, z+x+y))
128    _draw_crystal(axes, size, view, jitter, atoms=atoms)
129
130def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
131    """Draw points for body-centered cubic paracrystal"""
132    # Build the simple cubic crystal
133    atoms = _build_sc()
134    # Define the centers for each octant
135    # x plane at +/- 0.5 have four centers per plane at +/- 0.5 in y and z
136    x, y, z = (
137        [-0.5]*4 + [0.5]*4,
138        ([-0.5]*2 + [0.5]*2)*2,
139        [-0.5, 0.5]*8,
140    )
141    atoms.extend(zip(x, y, z))
142    _draw_crystal(axes, size, view, jitter, atoms=atoms)
143
144def _draw_crystal(axes, size, view, jitter, atoms=None):
145    atoms, size = np.asarray(atoms, 'd').T, np.asarray(size, 'd')
146    x, y, z = atoms*size[:, None]
147    x, y, z = transform_xyz(view, jitter, x, y, z)
148    axes.scatter([x[0]], [y[0]], [z[0]], c='yellow', marker='o')
149    axes.scatter(x[1:], y[1:], z[1:], c='r', marker='o')
150
151def _build_sc():
152    # three planes of 9 dots for x at -1, 0 and 1
153    x, y, z = (
154        [-1]*9 + [0]*9 + [1]*9,
155        ([-1]*3 + [0]*3 + [1]*3)*3,
156        [-1, 0, 1]*9,
157    )
158    atoms = list(zip(x, y, z))
159    #print(list(enumerate(atoms)))
160    # Pull the dot at (0, 0, 1) to the front of the list
161    # It will be highlighted in the view
162    index = 14
163    highlight = atoms[index]
164    del atoms[index]
165    atoms.insert(0, highlight)
166    return atoms
167
168def draw_box(axes, size, view):
169    a, b, c = size
170    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
171    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
172    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
173    x, y, z = transform_xyz(view, None, x, y, z)
174    def draw(i, j):
175        axes.plot([x[i],x[j]], [y[i], y[j]], [z[i], z[j]], color='black')
176    draw(0, 1)
177    draw(0, 2)
178    draw(0, 3)
179    draw(7, 4)
180    draw(7, 5)
181    draw(7, 6)
182
183def draw_parallelepiped(axes, size, view, jitter, steps=None,
184                        color=(0.6, 1.0, 0.6), alpha=1):
185    """Draw a parallelepiped."""
186    a, b, c = size
187    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
188    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
189    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
190    tri = np.array([
191        # counter clockwise triangles
192        # z: up/down, x: right/left, y: front/back
193        [0, 1, 2], [3, 2, 1], # top face
194        [6, 5, 4], [5, 6, 7], # bottom face
195        [0, 2, 6], [6, 4, 0], # right face
196        [1, 5, 7], [7, 3, 1], # left face
197        [2, 3, 6], [7, 6, 3], # front face
198        [4, 1, 0], [5, 1, 4], # back face
199    ])
200
201    x, y, z = transform_xyz(view, jitter, x, y, z)
202    axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha)
203
204    # Colour the c+ face of the box.
205    # Since I can't control face color, instead draw a thin box situated just
206    # in front of the "c+" face.  Use the c face so that rotations about psi
207    # rotate that face.
208    if 0:
209        color = (1, 0.6, 0.6)  # pink
210        x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
211        y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
212        z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
213        x, y, z = transform_xyz(view, jitter, x, y, abs(z)+0.001)
214        axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha)
215
216    draw_labels(axes, view, jitter, [
217        ('c+', [+0, +0, +c], [+1, +0, +0]),
218        ('c-', [+0, +0, -c], [+0, +0, -1]),
219        ('a+', [+a, +0, +0], [+0, +0, +1]),
220        ('a-', [-a, +0, +0], [+0, +0, -1]),
221        ('b+', [+0, +b, +0], [-1, +0, +0]),
222        ('b-', [+0, -b, +0], [-1, +0, +0]),
223    ])
224
225def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w'):
226    """Draw a sphere"""
227    u = np.linspace(0, 2 * np.pi, steps)
228    v = np.linspace(0, np.pi, steps)
229
230    x = radius * np.outer(np.cos(u), np.sin(v)) + center[0]
231    y = radius * np.outer(np.sin(u), np.sin(v)) + center[1]
232    z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) + center[2]
233    axes.plot_surface(x, y, z, color=color)
234    #axes.plot_wireframe(x, y, z)
235
236def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
237    x, y, z = origin
238    dx, dy, dz = length
239    axes.plot([x, x+dx], [y, y], [z, z], color='black')
240    axes.plot([x, x], [y, y+dy], [z, z], color='black')
241    axes.plot([x, x], [y, y], [z, z+dz], color='black')
242
243def draw_person_on_sphere(axes, view, height=0.5, radius=0.5):
244    limb_offset = height * 0.05
245    head_radius = height * 0.10
246    head_height = height - head_radius
247    neck_length = head_radius * 0.50
248    shoulder_height = height - 2*head_radius - neck_length
249    torso_length = shoulder_height * 0.55
250    torso_radius = torso_length * 0.30
251    leg_length = shoulder_height - torso_length
252    arm_length = torso_length * 0.90
253
254    def _draw_part(x, z):
255        y = np.zeros_like(x)
256        xp, yp, zp = transform_xyz(view, None, x, y, z + radius)
257        axes.plot(xp, yp, zp, color='k')
258
259    # circle for head
260    u = np.linspace(0, 2 * np.pi, 40)
261    x = head_radius * np.cos(u)
262    z = head_radius * np.sin(u) + head_height
263    _draw_part(x, z)
264
265    # rectangle for body
266    x = np.array([-torso_radius, torso_radius, torso_radius, -torso_radius, -torso_radius])
267    z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length
268    _draw_part(x, z)
269
270    # arms
271    x = np.array([-torso_radius - limb_offset, -torso_radius - limb_offset, -torso_radius])
272    z = np.array([shoulder_height - arm_length, shoulder_height, shoulder_height])
273    _draw_part(x, z)
274    _draw_part(-x, z)
275
276    # legs
277    x = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset])
278    z = np.array([0, leg_length])
279    _draw_part(x, z)
280    _draw_part(-x, z)
281
282    limits = [-radius-height, radius+height]
283    axes.set_xlim(limits)
284    axes.set_ylim(limits)
285    axes.set_zlim(limits)
286    axes.set_axis_off()
287
288def draw_jitter(axes, view, jitter, dist='gaussian',
289                size=(0.1, 0.4, 1.0),
290                draw_shape=draw_parallelepiped,
291                projection='equirectangular',
292                alpha=0.8,
293                views=None):
294    """
295    Represent jitter as a set of shapes at different orientations.
296    """
297    project, project_weight = get_projection(projection)
298
299    # set max diagonal to 0.95
300    scale = 0.95/sqrt(sum(v**2 for v in size))
301    size = tuple(scale*v for v in size)
302
303    dtheta, dphi, dpsi = jitter
304    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]
305    def steps(delta):
306        if views is None:
307            n = max(3, min(25, 2*int(base*delta/5)))
308        else:
309            n = views
310        return base*delta*np.linspace(-1, 1, n) if delta > 0 else [0.]
311    for theta in steps(dtheta):
312        for phi in steps(dphi):
313            for psi in steps(dpsi):
314                w = project_weight(theta, phi, 1.0, 1.0)
315                if w > 0:
316                    dview = project(theta, phi, psi)
317                    draw_shape(axes, size, view, dview, alpha=alpha)
318    for v in 'xyz':
319        a, b, c = size
320        lim = np.sqrt(a**2 + b**2 + c**2)
321        getattr(axes, 'set_'+v+'lim')([-lim, lim])
322        #getattr(axes, v+'axis').label.set_text(v)
323
324PROJECTIONS = [
325    # in order of PROJECTION number; do not change without updating the
326    # constants in kernel_iq.c
327    'equirectangular', 'sinusoidal', 'guyou', 'azimuthal_equidistance',
328    'azimuthal_equal_area',
329]
330def get_projection(projection):
331
332    """
333    jitter projections
334    <https://en.wikipedia.org/wiki/List_of_map_projections>
335
336    equirectangular (standard latitude-longitude mesh)
337        <https://en.wikipedia.org/wiki/Equirectangular_projection>
338        Allows free movement in phi (around the equator), but theta is
339        limited to +/- 90, and points are cos-weighted. Jitter in phi is
340        uniform in weight along a line of latitude.  With small theta and
341        phi ranging over +/- 180 this forms a wobbling disk.  With small
342        phi and theta ranging over +/- 90 this forms a wedge like a slice
343        of an orange.
344    azimuthal_equidistance (Postel)
345        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
346        Preserves distance from center, and so is an excellent map for
347        representing a bivariate gaussian on the surface.  Theta and phi
348        operate identically, cutting wegdes from the antipode of the viewing
349        angle.  This unfortunately does not allow free movement in either
350        theta or phi since the orthogonal wobble decreases to 0 as the body
351        rotates through 180 degrees.
352    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
353        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
354        Preserves arc length with latitude, giving bad behaviour at
355        theta near +/- 90.  Theta and phi operate somewhat differently,
356        so a system with a-b-c dtheta-dphi-dpsi will not give the same
357        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
358        for azimuthal equidistance.  Free movement using theta or phi
359        uniform over +/- 180 will work, but not as well as equirectangular
360        phi, with theta being slightly worse.  Computationally it is much
361        cheaper for wide theta-phi meshes since it excludes points which
362        lie outside the sinusoid near theta +/- 90 rather than packing
363        them close together as in equirectangle.  Note that the poles
364        will be slightly overweighted for theta > 90 with the circle
365        from theta at 90+dt winding backwards around the pole, overlapping
366        the circle from theta at 90-dt.
367    Guyou (hemisphere-in-a-square) **not weighted**
368        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
369        With tiling, allows rotation in phi or theta through +/- 180, with
370        uniform spacing.  Both theta and phi allow free rotation, with wobble
371        in the orthogonal direction reasonably well behaved (though not as
372        good as equirectangular phi). The forward/reverse transformations
373        relies on elliptic integrals that are somewhat expensive, so the
374        behaviour has to be very good to justify the cost and complexity.
375        The weighting function for each point has not yet been computed.
376        Note: run the module *guyou.py* directly and it will show the forward
377        and reverse mappings.
378    azimuthal_equal_area  **incomplete**
379        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
380        Preserves the relative density of the surface patches.  Not that
381        useful and not completely implemented
382    Gauss-Kreuger **not implemented**
383        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
384        Should allow free movement in theta, but phi is distorted.
385    """
386    # TODO: try Kent distribution instead of a gaussian warped by projection
387
388    if projection == 'equirectangular':  #define PROJECTION 1
389        def _project(theta_i, phi_j, psi):
390            latitude, longitude = theta_i, phi_j
391            return latitude, longitude, psi
392            #return Rx(phi_j)*Ry(theta_i)
393        def _weight(theta_i, phi_j, w_i, w_j):
394            return w_i*w_j*abs(cos(radians(theta_i)))
395    elif projection == 'sinusoidal':  #define PROJECTION 2
396        def _project(theta_i, phi_j, psi):
397            latitude = theta_i
398            scale = cos(radians(latitude))
399            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
400            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
401            return latitude, longitude, psi
402            #return Rx(longitude)*Ry(latitude)
403        def _project(theta_i, phi_j, w_i, w_j):
404            latitude = theta_i
405            scale = cos(radians(latitude))
406            active = 1 if abs(phi_j) < abs(scale)*180 else 0
407            return active*w_i*w_j
408    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
409        def _project(theta_i, phi_j, psi):
410            from .guyou import guyou_invert
411            #latitude, longitude = guyou_invert([theta_i], [phi_j])
412            longitude, latitude = guyou_invert([phi_j], [theta_i])
413            return latitude, longitude, psi
414            #return Rx(longitude[0])*Ry(latitude[0])
415        def _weight(theta_i, phi_j, w_i, w_j):
416            return w_i*w_j
417    elif projection == 'azimuthal_equidistance':
418        # Note that calculates angles for Rz Ry rather than Rx Ry
419        def _project(theta_i, phi_j, psi):
420            latitude = sqrt(theta_i**2 + phi_j**2)
421            longitude = degrees(np.arctan2(phi_j, theta_i))
422            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
423            return latitude, longitude, psi-longitude, 'zyz'
424            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
425            #return R_to_xyz(R)
426            #return Rz(longitude)*Ry(latitude)
427        def _weight(theta_i, phi_j, w_i, w_j):
428            # Weighting for each point comes from the integral:
429            #     \int\int I(q, lat, log) sin(lat) dlat dlog
430            # We are doing a conformal mapping from disk to sphere, so we need
431            # a change of variables g(theta, phi) -> (lat, long):
432            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
433            # giving:
434            #     dtheta dphi = det(J) dlat dlong
435            # where J is the jacobian from the partials of g. Using
436            #     R = sqrt(theta^2 + phi^2),
437            # then
438            #     J = [[x/R, Y/R], -y/R^2, x/R^2]]
439            # and
440            #     det(J) = 1/R
441            # with the final integral being:
442            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
443            #
444            # This does approximately the right thing, decreasing the weight
445            # of each point as you go farther out on the disk, but it hasn't
446            # yet been checked against the 1D integral results. Prior
447            # to declaring this "good enough" and checking that integrals
448            # work in practice, we will examine alternative mappings.
449            #
450            # The issue is that the mapping does not support the case of free
451            # rotation about a single axis correctly, with a small deviation
452            # in the orthogonal axis independent of the first axis.  Like the
453            # usual polar coordiates integration, the integrated sections
454            # form wedges, though at least in this case the wedge cuts through
455            # the entire sphere, and treats theta and phi identically.
456            latitude = sqrt(theta_i**2 + phi_j**2)
457            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
458            return weight*w_i*w_j if latitude < 180 else 0
459    elif projection == 'azimuthal_equal_area':
460        # Note that calculates angles for Rz Ry rather than Rx Ry
461        def _project(theta_i, phi_j, psi):
462            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
463            latitude = 180-degrees(2*np.arccos(radius))
464            longitude = degrees(np.arctan2(phi_j, theta_i))
465            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
466            return latitude, longitude, psi, "zyz"
467            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
468            #return R_to_xyz(R)
469            #return Rz(longitude)*Ry(latitude)
470        def _weight(theta_i, phi_j, w_i, w_j):
471            latitude = sqrt(theta_i**2 + phi_j**2)
472            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
473            return weight*w_i*w_j if latitude < 180 else 0
474    else:
475        raise ValueError("unknown projection %r"%projection)
476
477    return _project, _weight
478
479def R_to_xyz(R):
480    """
481    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix.
482
483    Extracting Euler Angles from a Rotation Matrix
484    Mike Day, Insomniac Games
485    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
486    Based on: Shoemake’s "Euler Angle Conversion", Graphics Gems IV, pp.  222-229
487    """
488    phi = np.arctan2(R[1, 2], R[2, 2])
489    theta = np.arctan2(-R[0, 2], np.sqrt(R[0, 0]**2 + R[0, 1]**2))
490    psi = np.arctan2(R[0, 1], R[0, 0])
491    return np.degrees(phi), np.degrees(theta), np.degrees(psi)
492
493def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
494              projection='equirectangular'):
495    """
496    Draw the dispersion mesh showing the theta-phi orientations at which
497    the model will be evaluated.
498    """
499
500    _project, _weight = get_projection(projection)
501    def _rotate(theta, phi, z):
502        dview = _project(theta, phi, 0.)
503        if len(dview) == 4: # hack for zyz coords
504            return Rz(dview[1])*Ry(dview[0])*z
505        else:
506            return Rx(dview[1])*Ry(dview[0])*z
507
508
509    dist_x = np.linspace(-1, 1, n)
510    weights = np.ones_like(dist_x)
511    if dist == 'gaussian':
512        dist_x *= 3
513        weights = exp(-0.5*dist_x**2)
514    elif dist == 'rectangle':
515        # Note: uses sasmodels ridiculous definition of rectangle width
516        dist_x *= sqrt(3)
517    elif dist == 'uniform':
518        pass
519    else:
520        raise ValueError("expected dist to be gaussian, rectangle or uniform")
521
522    # mesh in theta, phi formed by rotating z
523    dtheta, dphi, dpsi = jitter
524    z = np.matrix([[0], [0], [radius]])
525    points = np.hstack([_rotate(theta_i, phi_j, z)
526                        for theta_i in dtheta*dist_x
527                        for phi_j in dphi*dist_x])
528    dist_w = np.array([_weight(theta_i, phi_j, w_i, w_j)
529                       for w_i, theta_i in zip(weights, dtheta*dist_x)
530                       for w_j, phi_j in zip(weights, dphi*dist_x)])
531    #print(max(dist_w), min(dist_w), min(dist_w[dist_w > 0]))
532    points = points[:, dist_w > 0]
533    dist_w = dist_w[dist_w > 0]
534    dist_w /= max(dist_w)
535
536    # rotate relative to beam
537    points = orient_relative_to_beam(view, points)
538
539    x, y, z = [np.array(v).flatten() for v in points]
540    #plt.figure(2); plt.clf(); plt.hist(z, bins=np.linspace(-1, 1, 51))
541    axes.scatter(x, y, z, c=dist_w, marker='o', vmin=0., vmax=1.)
542
543def draw_labels(axes, view, jitter, text):
544    """
545    Draw text at a particular location.
546    """
547    labels, locations, orientations = zip(*text)
548    px, py, pz = zip(*locations)
549    dx, dy, dz = zip(*orientations)
550
551    px, py, pz = transform_xyz(view, jitter, px, py, pz)
552    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)
553
554    # TODO: zdir for labels is broken, and labels aren't appearing.
555    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)):
556        zdir = np.asarray(zdir).flatten()
557        axes.text(p[0], p[1], p[2], label, zdir=zdir)
558
559# Definition of rotation matrices comes from wikipedia:
560#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
561def Rx(angle):
562    """Construct a matrix to rotate points about *x* by *angle* degrees."""
563    angle = radians(angle)
564    rot = [[1, 0, 0],
565           [0, +cos(angle), -sin(angle)],
566           [0, +sin(angle), +cos(angle)]]
567    return np.matrix(rot)
568
569def Ry(angle):
570    """Construct a matrix to rotate points about *y* by *angle* degrees."""
571    angle = radians(angle)
572    rot = [[+cos(angle), 0, +sin(angle)],
573           [0, 1, 0],
574           [-sin(angle), 0, +cos(angle)]]
575    return np.matrix(rot)
576
577def Rz(angle):
578    """Construct a matrix to rotate points about *z* by *angle* degrees."""
579    angle = radians(angle)
580    rot = [[+cos(angle), -sin(angle), 0],
581           [+sin(angle), +cos(angle), 0],
582           [0, 0, 1]]
583    return np.matrix(rot)
584
585def transform_xyz(view, jitter, x, y, z):
586    """
587    Send a set of (x,y,z) points through the jitter and view transforms.
588    """
589    x, y, z = [np.asarray(v) for v in (x, y, z)]
590    shape = x.shape
591    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
592    points = apply_jitter(jitter, points)
593    points = orient_relative_to_beam(view, points)
594    x, y, z = [np.array(v).reshape(shape) for v in points]
595    return x, y, z
596
597def apply_jitter(jitter, points):
598    """
599    Apply the jitter transform to a set of points.
600
601    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
602    """
603    if jitter is None:
604        return points
605    # Hack to deal with the fact that azimuthal_equidistance uses euler angles
606    if len(jitter) == 4:
607        dtheta, dphi, dpsi, _ = jitter
608        points = Rz(dphi)*Ry(dtheta)*Rz(dpsi)*points
609    else:
610        dtheta, dphi, dpsi = jitter
611        points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points
612    return points
613
614def orient_relative_to_beam(view, points):
615    """
616    Apply the view transform to a set of points.
617
618    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
619    """
620    theta, phi, psi = view
621    points = Rz(phi)*Ry(theta)*Rz(psi)*points # viewing angle
622    #points = Rz(psi)*Ry(pi/2-theta)*Rz(phi)*points # 1-D integration angles
623    #points = Rx(phi)*Ry(theta)*Rz(psi)*points  # angular dispersion angle
624    return points
625
626def orient_relative_to_beam_quaternion(view, points):
627    """
628    Apply the view transform to a set of points.
629
630    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
631
632    This variant uses quaternions rather than rotation matrices for the
633    computation.  It works but it is not used because it doesn't solve
634    any problems.  The challenge of mapping theta/phi/psi to SO(3) does
635    not disappear by calculating the transform differently.
636    """
637    theta, phi, psi = view
638    x, y, z = [1, 0, 0], [0, 1, 0], [0, 0, 1]
639    q = Quaternion(1, [0, 0, 0])
640    ## Compose a rotation about the three axes by rotating
641    ## the unit vectors before applying the rotation.
642    #q = Quaternion.from_angle_axis(theta, q.rot(x)) * q
643    #q = Quaternion.from_angle_axis(phi, q.rot(y)) * q
644    #q = Quaternion.from_angle_axis(psi, q.rot(z)) * q
645    ## The above turns out to be equivalent to reversing
646    ## the order of application, so ignore it and use below.
647    q = q * Quaternion.from_angle_axis(theta, x)
648    q = q * Quaternion.from_angle_axis(phi, y)
649    q = q * Quaternion.from_angle_axis(psi, z)
650    ## Reverse the order by post-multiply rather than pre-multiply
651    #q = Quaternion.from_angle_axis(theta, x) * q
652    #q = Quaternion.from_angle_axis(phi, y) * q
653    #q = Quaternion.from_angle_axis(psi, z) * q
654    #print("axes psi", q.rot(np.matrix([x, y, z]).T))
655    return q.rot(points)
656#orient_relative_to_beam = orient_relative_to_beam_quaternion
657
658# Simple stand-alone quaternion class
659import numpy as np
660from copy import copy
661class Quaternion(object):
662    def __init__(self, w, r):
663         self.w = w
664         self.r = np.asarray(r, dtype='d')
665    @staticmethod
666    def from_angle_axis(theta, r):
667         theta = np.radians(theta)/2
668         r = np.asarray(r)
669         w = np.cos(theta)
670         r = np.sin(theta)*r/np.dot(r,r)
671         return Quaternion(w, r)
672    def __mul__(self, other):
673        if isinstance(other, Quaternion):
674            w = self.w*other.w - np.dot(self.r, other.r)
675            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
676            return Quaternion(w, r)
677    def rot(self, v):
678        v = np.asarray(v).T
679        use_transpose = (v.shape[-1] != 3)
680        if use_transpose: v = v.T
681        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
682        #v = v + 2*self.w*np.cross(self.r, v) + np.cross(2*self.r, np.cross(self.r, v))
683        if use_transpose: v = v.T
684        return v.T
685    def conj(self):
686        return Quaternion(self.w, -self.r)
687    def inv(self):
688        return self.conj()/self.norm()**2
689    def norm(self):
690        return np.sqrt(self.w**2 + np.sum(self.r**2))
691    def __str__(self):
692        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
693def test_qrot():
694    # Define rotation of 60 degrees around an axis in y-z that is 60 degrees from y.
695    # The rotation axis is determined by rotating the point [0, 1, 0] about x.
696    ax = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0])
697    q = Quaternion.from_angle_axis(60, ax)
698    # Set the point to be rotated, and its expected rotated position.
699    p = [1, -1, 2]
700    target = [(10+4*np.sqrt(3))/8, (1+2*np.sqrt(3))/8, (14-3*np.sqrt(3))/8]
701    #print(q, q.rot(p) - target)
702    assert max(abs(q.rot(p) - target)) < 1e-14
703#test_qrot()
704#import sys; sys.exit()
705
706# translate between number of dimension of dispersity and the number of
707# points along each dimension.
708PD_N_TABLE = {
709    (0, 0, 0): (0, 0, 0),     # 0
710    (1, 0, 0): (100, 0, 0),   # 100
711    (0, 1, 0): (0, 100, 0),
712    (0, 0, 1): (0, 0, 100),
713    (1, 1, 0): (30, 30, 0),   # 900
714    (1, 0, 1): (30, 0, 30),
715    (0, 1, 1): (0, 30, 30),
716    (1, 1, 1): (15, 15, 15),  # 3375
717}
718
719def clipped_range(data, portion=1.0, mode='central'):
720    """
721    Determine range from data.
722
723    If *portion* is 1, use full range, otherwise use the center of the range
724    or the top of the range, depending on whether *mode* is 'central' or 'top'.
725    """
726    if portion == 1.0:
727        return data.min(), data.max()
728    elif mode == 'central':
729        data = np.sort(data.flatten())
730        offset = int(portion*len(data)/2 + 0.5)
731        return data[offset], data[-offset]
732    elif mode == 'top':
733        data = np.sort(data.flatten())
734        offset = int(portion*len(data) + 0.5)
735        return data[offset], data[-1]
736
737def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
738    """
739    Plot the scattering for the particular view.
740
741    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
742    on which the data will be plotted.  *view* and *jitter* are the current
743    orientation and orientation dispersity.  *dist* is one of the sasmodels
744    weight distributions.
745    """
746    if dist == 'uniform':  # uniform is not yet in this branch
747        dist, scale = 'rectangle', 1/sqrt(3)
748    else:
749        scale = 1
750
751    # add the orientation parameters to the model parameters
752    theta, phi, psi = view
753    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
754    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]
755    ## increase pd_n for testing jitter integration rather than simple viz
756    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]
757
758    pars = dict(
759        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
760        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
761        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
762    )
763    pars.update(calculator.pars)
764
765    # compute the pattern
766    qx, qy = calculator._data.x_bins, calculator._data.y_bins
767    Iqxy = calculator(**pars).reshape(len(qx), len(qy))
768
769    # scale it and draw it
770    Iqxy = np.log(Iqxy)
771    if calculator.limits:
772        # use limits from orientation (0,0,0)
773        vmin, vmax = calculator.limits
774    else:
775        vmax = Iqxy.max()
776        vmin = vmax*10**-7
777        #vmin, vmax = clipped_range(Iqxy, portion=portion, mode='top')
778    #vmin, vmax = Iqxy.min(), Iqxy.max()
779    #print("range",(vmin,vmax))
780    #qx, qy = np.meshgrid(qx, qy)
781    if 0:
782        from matplotlib import cm
783        level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i')
784        level[level < 0] = 0
785        colors = plt.get_cmap()(level)
786        #colors = cm.coolwarm(level)
787        #colors = cm.gist_yarg(level)
788        #colors = cm.Wistia(level)
789        colors[level<=0, 3] = 0.  # set floor to transparent
790        x, y = np.meshgrid(qx/qx.max(), qy/qy.max())
791        axes.plot_surface(x, y, -1.1*np.ones_like(x), facecolors=colors)
792    elif 1:
793        axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
794                      levels=np.linspace(vmin, vmax, 24))
795    else:
796        axes.pcolormesh(qx, qy, Iqxy)
797
798def build_model(model_name, n=150, qmax=0.5, **pars):
799    """
800    Build a calculator for the given shape.
801
802    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
803    on which to evaluate the model.  The remaining parameters are stored in
804    the returned calculator as *calculator.pars*.  They are used by
805    :func:`draw_scattering` to set the non-orientation parameters in the
806    calculation.
807
808    Returns a *calculator* function which takes a dictionary or parameters and
809    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
810    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
811    for details.
812    """
813    from sasmodels.core import load_model_info, build_model as build_sasmodel
814    from sasmodels.data import empty_data2D
815    from sasmodels.direct_model import DirectModel
816
817    model_info = load_model_info(model_name)
818    model = build_sasmodel(model_info) #, dtype='double!')
819    q = np.linspace(-qmax, qmax, n)
820    data = empty_data2D(q, q)
821    calculator = DirectModel(data, model)
822
823    # stuff the values for non-orientation parameters into the calculator
824    calculator.pars = pars.copy()
825    calculator.pars.setdefault('backgound', 1e-3)
826
827    # fix the data limits so that we can see if the pattern fades
828    # under rotation or angular dispersion
829    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
830    Iqxy = np.log(Iqxy)
831    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
832    calculator.limits = vmin, vmax+1
833
834    return calculator
835
836def select_calculator(model_name, n=150, size=(10, 40, 100)):
837    """
838    Create a model calculator for the given shape.
839
840    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
841    parallelepiped or bcc_paracrystal. *n* is the number of points to use
842    in the q range.  *qmax* is chosen based on model parameters for the
843    given model to show something intersting.
844
845    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
846    equitorial axes and polar axis respectively.  See :func:`build_model`
847    for details on the returned calculator.
848    """
849    a, b, c = size
850    d_factor = 0.06  # for paracrystal models
851    if model_name == 'sphere':
852        calculator = build_model('sphere', n=n, radius=c)
853        a = b = c
854    elif model_name == 'sc_paracrystal':
855        a = b = c
856        dnn = c
857        radius = 0.5*c
858        calculator = build_model('sc_paracrystal', n=n, dnn=dnn,
859                                 d_factor=d_factor, radius=(1-d_factor)*radius,
860                                 background=0)
861    elif model_name == 'fcc_paracrystal':
862        a = b = c
863        # nearest neigbour distance dnn should be 2 radius, but I think the
864        # model uses lattice spacing rather than dnn in its calculations
865        dnn = 0.5*c
866        radius = sqrt(2)/4 * c
867        calculator = build_model('fcc_paracrystal', n=n, dnn=dnn,
868                                 d_factor=d_factor, radius=(1-d_factor)*radius,
869                                 background=0)
870    elif model_name == 'bcc_paracrystal':
871        a = b = c
872        # nearest neigbour distance dnn should be 2 radius, but I think the
873        # model uses lattice spacing rather than dnn in its calculations
874        dnn = 0.5*c
875        radius = sqrt(3)/2 * c
876        calculator = build_model('bcc_paracrystal', n=n, dnn=dnn,
877                                 d_factor=d_factor, radius=(1-d_factor)*radius,
878                                 background=0)
879    elif model_name == 'cylinder':
880        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
881        a = b
882    elif model_name == 'ellipsoid':
883        calculator = build_model('ellipsoid', n=n, qmax=1.0,
884                                 radius_polar=c, radius_equatorial=b)
885        a = b
886    elif model_name == 'triaxial_ellipsoid':
887        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
888                                 radius_equat_minor=a,
889                                 radius_equat_major=b,
890                                 radius_polar=c)
891    elif model_name == 'parallelepiped':
892        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
893    else:
894        raise ValueError("unknown model %s"%model_name)
895
896    return calculator, (a, b, c)
897
898SHAPES = [
899    'parallelepiped',
900    'sphere', 'ellipsoid', 'triaxial_ellipsoid',
901    'cylinder',
902    'fcc_paracrystal', 'bcc_paracrystal', 'sc_paracrystal',
903]
904
905DRAW_SHAPES = {
906    'fcc_paracrystal': draw_fcc,
907    'bcc_paracrystal': draw_bcc,
908    'sc_paracrystal': draw_sc,
909    'parallelepiped': draw_parallelepiped,
910}
911
912DISTRIBUTIONS = [
913    'gaussian', 'rectangle', 'uniform',
914]
915DIST_LIMITS = {
916    'gaussian': 30,
917    'rectangle': 90/sqrt(3),
918    'uniform': 90,
919}
920
921
922def run(model_name='parallelepiped', size=(10, 40, 100),
923        view=(0, 0, 0), jitter=(0, 0, 0),
924        dist='gaussian', mesh=30,
925        projection='equirectangular'):
926    """
927    Show an interactive orientation and jitter demo.
928
929    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
930    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal
931
932    *size* gives the dimensions (a, b, c) of the shape.
933
934    *view* gives the initial view (theta, phi, psi) of the shape.
935
936    *view* gives the initial jitter (dtheta, dphi, dpsi) of the shape.
937
938    *dist* is the type of dispersition: gaussian, rectangle, or uniform.
939
940    *mesh* is the number of points in the dispersion mesh.
941
942    *projection* is the map projection to use for the mesh: equirectangular,
943    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.
944    """
945    # projection number according to 1-order position in list, but
946    # only 1 and 2 are implemented so far.
947    from sasmodels import generate
948    generate.PROJECTION = PROJECTIONS.index(projection) + 1
949    if generate.PROJECTION > 2:
950        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
951        generate.PROJECTION = 2
952
953    # set up calculator
954    calculator, size = select_calculator(model_name, n=150, size=size)
955    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped)
956    #draw_shape = draw_fcc
957
958    ## uncomment to set an independent the colour range for every view
959    ## If left commented, the colour range is fixed for all views
960    calculator.limits = None
961
962    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
963
964def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
965    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot
966    import matplotlib.pyplot as plt
967    from matplotlib.widgets import Slider
968
969    ## create the plot window
970    #plt.hold(True)
971    plt.subplots(num=None, figsize=(5.5, 5.5))
972    plt.set_cmap('gist_earth')
973    plt.clf()
974    plt.gcf().canvas.set_window_title(projection)
975    #gs = gridspec.GridSpec(2,1,height_ratios=[4,1])
976    #axes = plt.subplot(gs[0], projection='3d')
977    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
978    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
979        axes.axis('square')
980    except Exception:
981        pass
982
983    axcolor = 'lightgoldenrodyellow'
984
985    ## add control widgets to plot
986    axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], facecolor=axcolor)
987    axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], facecolor=axcolor)
988    axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], facecolor=axcolor)
989    stheta = Slider(axes_theta, u'Ξ', -90, 90, valinit=0)
990    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0)
991    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0)
992
993    axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], facecolor=axcolor)
994    axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], facecolor=axcolor)
995    axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], facecolor=axcolor)
996    # Note: using ridiculous definition of rectangle distribution, whose width
997    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
998    # the maximum width to 90.
999    dlimit = DIST_LIMITS[dist]
1000    sdtheta = Slider(axes_dtheta, u'Δξ', 0, 2*dlimit, valinit=0)
1001    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0)
1002    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0)
1003
1004    ## initial view and jitter
1005    theta, phi, psi = view
1006    stheta.set_val(theta)
1007    sphi.set_val(phi)
1008    spsi.set_val(psi)
1009    dtheta, dphi, dpsi = jitter
1010    sdtheta.set_val(dtheta)
1011    sdphi.set_val(dphi)
1012    sdpsi.set_val(dpsi)
1013
1014    ## callback to draw the new view
1015    def update(val, axis=None):
1016        view = stheta.val, sphi.val, spsi.val
1017        jitter = sdtheta.val, sdphi.val, sdpsi.val
1018        # set small jitter as 0 if multiple pd dims
1019        dims = sum(v > 0 for v in jitter)
1020        limit = [0, 0.5, 5, 5][dims]
1021        jitter = [0 if v < limit else v for v in jitter]
1022        axes.cla()
1023
1024        ## Visualize as person on globe
1025        #draw_sphere(axes)
1026        #draw_person_on_sphere(axes, view)
1027
1028        ## Move beam instead of shape
1029        #draw_beam(axes, -view[:2])
1030        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)
1031
1032        ## Move shape and draw scattering
1033        draw_beam(axes, (0, 0))
1034        draw_jitter(axes, view, jitter, dist=dist, size=size,
1035                    draw_shape=draw_shape, projection=projection, views=3)
1036        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
1037        draw_scattering(calculator, axes, view, jitter, dist=dist)
1038
1039        plt.gcf().canvas.draw()
1040
1041    ## bind control widgets to view updater
1042    stheta.on_changed(lambda v: update(v, 'theta'))
1043    sphi.on_changed(lambda v: update(v, 'phi'))
1044    spsi.on_changed(lambda v: update(v, 'psi'))
1045    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
1046    sdphi.on_changed(lambda v: update(v, 'dphi'))
1047    sdpsi.on_changed(lambda v: update(v, 'dpsi'))
1048
1049    ## initialize view
1050    update(None, 'phi')
1051
1052    ## go interactive
1053    plt.show()
1054
1055
1056def map_colors(z, kw):
1057    from matplotlib import cm
1058
1059    cmap = kw.pop('cmap', cm.coolwarm)
1060    alpha = kw.pop('alpha', None)
1061    vmin = kw.pop('vmin', z.min())
1062    vmax = kw.pop('vmax', z.max())
1063    c = kw.pop('c', None)
1064    color = kw.pop('color', c)
1065    if color is None:
1066        znorm = ((z - vmin) / (vmax - vmin)).clip(0, 1)
1067        color = cmap(znorm)
1068    elif isinstance(color, np.ndarray) and color.shape == z.shape:
1069        color = cmap(color)
1070    if alpha is None:
1071        if isinstance(color, np.ndarray):
1072            color = color[..., :3]
1073    else:
1074        color[..., 3] = alpha
1075    kw['color'] = color
1076
1077def make_vec(*args):
1078    #return [np.asarray(v, 'd').flatten() for v in args]
1079    return [np.asarray(v, 'd') for v in args]
1080
1081def make_image(z, kw):
1082    import PIL.Image
1083    from matplotlib import cm
1084
1085    cmap = kw.pop('cmap', cm.coolwarm)
1086
1087    znorm = (z-z.min())/z.ptp()
1088    c = cmap(znorm)
1089    c = c[..., :3]
1090    rgb = np.asarray(c*255, 'u1')
1091    image = PIL.Image.fromarray(rgb, mode='RGB')
1092    return image
1093
1094
1095_IPV_MARKERS = {
1096    'o': 'sphere',
1097}
1098_IPV_COLORS = {
1099    'w': 'white',
1100    'k': 'black',
1101    'c': 'cyan',
1102    'm': 'magenta',
1103    'y': 'yellow',
1104    'r': 'red',
1105    'g': 'green',
1106    'b': 'blue',
1107}
1108def ipv_fix_color(kw):
1109    alpha = kw.pop('alpha', None)
1110    color = kw.get('color', None)
1111    if isinstance(color, str):
1112        color = _IPV_COLORS.get(color, color)
1113        kw['color'] = color
1114    if alpha is not None:
1115        color = kw['color']
1116        #TODO: convert color to [r, g, b, a] if not already
1117        if isinstance(color, (tuple, list)):
1118            if len(color) == 3:
1119                color = (color[0], color[1], color[2], alpha)
1120            else:
1121                color = (color[0], color[1], color[2], alpha*color[3])
1122            color = np.array(color)
1123        if isinstance(color, np.ndarray) and color.shape[-1] == 4:
1124            color[..., 3] = alpha
1125            kw['color'] = color
1126
1127def ipv_set_transparency(kw, obj):
1128    color = kw.get('color', None)
1129    if (isinstance(color, np.ndarray)
1130            and color.shape[-1] == 4
1131            and (color[..., 3] != 1.0).any()):
1132        obj.material.transparent = True
1133        obj.material.side = "FrontSide"
1134
1135def ipv_axes():
1136    import ipyvolume as ipv
1137
1138    class Plotter:
1139        # transparency can be achieved by setting the following:
1140        #    mesh.color = [r, g, b, alpha]
1141        #    mesh.material.transparent = True
1142        #    mesh.material.side = "FrontSide"
1143        # smooth(ish) rotation can be achieved by setting:
1144        #    slide.continuous_update = True
1145        #    figure.animation = 0.
1146        #    mesh.material.x = x
1147        #    mesh.material.y = y
1148        #    mesh.material.z = z
1149        # maybe need to synchronize update of x/y/z to avoid shimmy when moving
1150        def plot(self, x, y, z, **kw):
1151            ipv_fix_color(kw)
1152            x, y, z = make_vec(x, y, z)
1153            ipv.plot(x, y, z, **kw)
1154        def plot_surface(self, x, y, z, **kw):
1155            facecolors = kw.pop('facecolors', None)
1156            if facecolors is not None:
1157                kw['color'] = facecolors
1158            ipv_fix_color(kw)
1159            x, y, z = make_vec(x, y, z)
1160            h = ipv.plot_surface(x, y, z, **kw)
1161            ipv_set_transparency(kw, h)
1162            #h.material.side = "DoubleSide"
1163            return h
1164        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
1165            ipv_fix_color(kw)
1166            x, y, z = make_vec(x, y, Z)
1167            if triangles is not None:
1168                triangles = np.asarray(triangles)
1169            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
1170            ipv_set_transparency(kw, h)
1171            return h
1172        def scatter(self, x, y, z, **kw):
1173            x, y, z = make_vec(x, y, z)
1174            map_colors(z, kw)
1175            marker = kw.get('marker', None)
1176            kw['marker'] = _IPV_MARKERS.get(marker, marker)
1177            h = ipv.scatter(x, y, z, **kw)
1178            ipv_set_transparency(kw, h)
1179            return h
1180        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
1181            # Don't use contour for now (although we might want to later)
1182            self.pcolor(x, y, v, zdir='z', offset=offset, **kw)
1183        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
1184            x, y, v = make_vec(x, y, v)
1185            image = make_image(v, kw)
1186            xmin, xmax = x.min(), x.max()
1187            ymin, ymax = y.min(), y.max()
1188            x = np.array([[xmin, xmax], [xmin, xmax]])
1189            y = np.array([[ymin, ymin], [ymax, ymax]])
1190            z = x*0 + offset
1191            u = np.array([[0., 1], [0, 1]])
1192            v = np.array([[0., 0], [1, 1]])
1193            h = ipv.plot_mesh(x, y, z, u=u, v=v, texture=image, wireframe=False)
1194            ipv_set_transparency(kw, h)
1195            h.material.side = "DoubleSide"
1196            return h
1197        def text(self, *args, **kw):
1198            pass
1199        def set_xlim(self, limits):
1200            ipv.xlim(*limits)
1201        def set_ylim(self, limits):
1202            ipv.ylim(*limits)
1203        def set_zlim(self, limits):
1204            ipv.zlim(*limits)
1205        def set_axes_on(self):
1206            ipv.style.axis_on()
1207        def set_axis_off(self):
1208            ipv.style.axes_off()
1209    return Plotter()
1210
1211def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
1212    import ipywidgets as widgets
1213    import ipyvolume as ipv
1214
1215    axes = ipv_axes()
1216
1217    def draw(view, jitter):
1218        camera = ipv.gcf().camera
1219        #print(ipv.gcf().__dict__.keys())
1220        #print(dir(ipv.gcf()))
1221        ipv.figure(animation=0.)  # no animation when updating object mesh
1222
1223        # set small jitter as 0 if multiple pd dims
1224        dims = sum(v > 0 for v in jitter)
1225        limit = [0, 0.5, 5, 5][dims]
1226        jitter = [0 if v < limit else v for v in jitter]
1227
1228        ## Visualize as person on globe
1229        #draw_beam(axes, (0, 0))
1230        #draw_sphere(axes)
1231        #draw_person_on_sphere(axes, view)
1232
1233        ## Move beam instead of shape
1234        #draw_beam(axes, view=(-view[0], -view[1]))
1235        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0))
1236
1237        ## Move shape and draw scattering
1238        draw_beam(axes, (0, 0), steps=25)
1239        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0,
1240                    draw_shape=draw_shape, projection=projection)
1241        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95,
1242                  projection=projection)
1243        draw_scattering(calculator, axes, view, jitter, dist=dist)
1244
1245        draw_axes(axes, origin=(-1, -1, -1.1))
1246        ipv.style.box_off()
1247        ipv.style.axes_off()
1248        ipv.xyzlabel(" ", " ", " ")
1249
1250        ipv.gcf().camera = camera
1251        ipv.show()
1252
1253
1254    trange, prange = (-180., 180., 1.), (-180., 180., 1.)
1255    dtrange, dprange = (0., 180., 1.), (0., 180., 1.)
1256
1257    ## Super simple interfaca, but uses non-ascii variable namese
1258    # Ξ φ ψ Δξ Δφ Δψ
1259    #def update(**kw):
1260    #    view = kw['Ξ'], kw['φ'], kw['ψ']
1261    #    jitter = kw['Δξ'], kw['Δφ'], kw['Δψ']
1262    #    draw(view, jitter)
1263    #widgets.interact(update, Ξ=trange, φ=prange, ψ=prange, Δξ=dtrange, Δφ=dprange, Δψ=dprange)
1264
1265    def update(theta, phi, psi, dtheta, dphi, dpsi):
1266        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))
1267
1268    def slider(name, slice, init=0.):
1269        return widgets.FloatSlider(
1270            value=init,
1271            min=slice[0],
1272            max=slice[1],
1273            step=slice[2],
1274            description=name,
1275            disabled=False,
1276            #continuous_update=True,
1277            continuous_update=False,
1278            orientation='horizontal',
1279            readout=True,
1280            readout_format='.1f',
1281            )
1282    theta = slider(u'Ξ', trange, view[0])
1283    phi = slider(u'φ', prange, view[1])
1284    psi = slider(u'ψ', prange, view[2])
1285    dtheta = slider(u'Δξ', dtrange, jitter[0])
1286    dphi = slider(u'Δφ', dprange, jitter[1])
1287    dpsi = slider(u'Δψ', dprange, jitter[2])
1288    fields = {
1289        'theta': theta, 'phi': phi, 'psi': psi,
1290        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi,
1291    }
1292    ui = widgets.HBox([
1293        widgets.VBox([theta, phi, psi]),
1294        widgets.VBox([dtheta, dphi, dpsi])
1295    ])
1296
1297    out = widgets.interactive_output(update, fields)
1298    display(ui, out)
1299
1300
1301_ENGINES = {
1302    "matplotlib": mpl_plot,
1303    "mpl": mpl_plot,
1304    #"plotly": plotly_plot,
1305    "ipvolume": ipv_plot,
1306    "ipv": ipv_plot,
1307}
1308PLOT_ENGINE = _ENGINES["matplotlib"]
1309def set_plotter(name):
1310    global PLOT_ENGINE
1311    PLOT_ENGINE = _ENGINES[name]
1312
1313def main():
1314    parser = argparse.ArgumentParser(
1315        description="Display jitter",
1316        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
1317        )
1318    parser.add_argument('-p', '--projection', choices=PROJECTIONS,
1319                        default=PROJECTIONS[0],
1320                        help='coordinate projection')
1321    parser.add_argument('-s', '--size', type=str, default='10,40,100',
1322                        help='a,b,c lengths')
1323    parser.add_argument('-v', '--view', type=str, default='0,0,0',
1324                        help='initial view angles')
1325    parser.add_argument('-j', '--jitter', type=str, default='0,0,0',
1326                        help='initial angular dispersion')
1327    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS,
1328                        default=DISTRIBUTIONS[0],
1329                        help='jitter distribution')
1330    parser.add_argument('-m', '--mesh', type=int, default=30,
1331                        help='#points in theta-phi mesh')
1332    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0],
1333                        help='oriented shape')
1334    opts = parser.parse_args()
1335    size = tuple(float(v) for v in opts.size.split(','))
1336    view = tuple(float(v) for v in opts.view.split(','))
1337    jitter = tuple(float(v) for v in opts.jitter.split(','))
1338    run(opts.shape, size=size, view=view, jitter=jitter,
1339        mesh=opts.mesh, dist=opts.distribution,
1340        projection=opts.projection)
1341
1342if __name__ == "__main__":
1343    main()
Note: See TracBrowser for help on using the repository browser.