source: sasmodels/sasmodels/jitter.py @ 4057e06

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 4057e06 was 4057e06, checked in by GitHub <noreply@…>, 9 months ago

Merge branch 'py3' into webgl_jitter_viewer

  • Property mode set to 100755
File size: 50.8 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Jitter Explorer
5===============
6
7Application to explore orientation angle and angular dispersity.
8
9From the command line::
10
11    # Show docs
12    python -m sasmodels.jitter --help
13
14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
16
17From a jupyter cell::
18
19    import ipyvolume as ipv
20    from sasmodels import jitter
21    import importlib; importlib.reload(jitter)
22    jitter.set_plotter("ipv")
23
24    size = (10, 40, 100)
25    view = (20, 0, 0)
26
27    #size = (15, 15, 100)
28    #view = (60, 60, 0)
29
30    dview = (0, 0, 0)
31    #dview = (5, 5, 0)
32    #dview = (15, 180, 0)
33    #dview = (180, 15, 0)
34
35    projection = 'equirectangular'
36    #projection = 'azimuthal_equidistance'
37    #projection = 'guyou'
38    #projection = 'sinusoidal'
39    #projection = 'azimuthal_equal_area'
40
41    dist = 'uniform'
42    #dist = 'gaussian'
43
44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
46    #ipv.savefig(filename+'.png')
47"""
48from __future__ import division, print_function
49
50import argparse
51
52import numpy as np
53from numpy import pi, cos, sin, sqrt, exp, degrees, radians
54
def draw_beam(axes, view=(0, 0), alpha=0.5, steps=25):
    """
    Draw the beam as a thin yellow cylinder along z, from z=+1.3 (source
    side) to z=-1.3 (detector side).

    *view* is a (theta, phi) pair in degrees; the cylinder is rotated by
    Ry(theta) and then Rz(phi).  *steps* is the number of samples around the
    circumference and *alpha* the surface transparency.
    """
    radius, half_length = 0.02, 1.3
    around = np.linspace(0, 2 * np.pi, steps)
    ends = np.linspace(-1, 1, 2)

    x = radius*np.outer(np.cos(around), np.ones_like(ends))
    y = radius*np.outer(np.sin(around), np.ones_like(ends))
    z = half_length*np.outer(np.ones_like(around), ends)

    theta, phi = view
    grid_shape = x.shape
    flat = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    rotated = Rz(phi)*Ry(theta)*flat
    x, y, z = [row.reshape(grid_shape) for row in rotated]
    axes.plot_surface(x, y, z, color='yellow', alpha=alpha)

    # TODO: draw endcaps on the beam.  Tiny spheres at (0, 0, +/-1.3) drawn
    # with draw_sphere would work; triangle fans over the end circles drawn
    # with plot_trisurf did not.
87
88
def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c) along x, y and z.

    The surface is sampled on a *steps* x *steps* longitude/colatitude grid,
    sent through :func:`transform_xyz` with *view* and *jitter*, and drawn
    as a white surface with transparency *alpha*.  The tips of the three
    semi-axes are labelled.
    """
    a, b, c = size
    longitude = np.linspace(0, 2 * np.pi, steps)
    colatitude = np.linspace(0, np.pi, steps)
    ring = np.sin(colatitude)
    x = a*np.outer(np.cos(longitude), ring)
    y = b*np.outer(np.sin(longitude), ring)
    z = c*np.outer(np.ones_like(longitude), np.cos(colatitude))
    x, y, z = transform_xyz(view, jitter, x, y, z)

    axes.plot_surface(x, y, z, color='w', alpha=alpha)

    # (label, position on the surface, text direction) for each axis tip
    tips = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, tips)
109
def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a simple cubic paracrystal.

    *steps* and *alpha* are accepted for signature compatibility with the
    other draw_* shape functions but are not used.
    """
    _draw_crystal(axes, size, view, jitter, atoms=_build_sc())
114
def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a face-centered cubic paracrystal.

    Starts from the 3x3x3 simple cubic grid of :func:`_build_sc` and adds the
    centre of every face of its eight unit cells.  For each axis, the planes
    at -1, 0 and 1 carry four face centres each, at +/-0.5 in the other two
    coordinates, giving 36 extra points.

    *steps* and *alpha* are accepted for signature compatibility with the
    other draw_* shape functions but are not used.
    """
    atoms = _build_sc()
    half = (-0.5, 0.5)
    # Face centres in the x = -1, 0, 1 planes (12 points).  Every coordinate
    # list is built with the same length so the three orientations below
    # stay aligned.
    x_faces = [(plane, u, v) for plane in (-1, 0, 1) for u in half for v in half]
    atoms.extend(x_faces)
    # The same pattern with the plane coordinate moved to z, then to y.
    atoms.extend((u, v, plane) for plane, u, v in x_faces)
    atoms.extend((v, plane, u) for plane, u, v in x_faces)
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
129
def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a body-centered cubic paracrystal.

    Adds the body centre of each of the eight unit cells (every point whose
    coordinates are all +/-0.5) to the simple cubic grid of
    :func:`_build_sc`.  *steps* and *alpha* are accepted for signature
    compatibility but not used.
    """
    atoms = _build_sc()
    half = (-0.5, 0.5)
    atoms.extend((x, y, z) for x in half for y in half for z in half)
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
143
def _draw_crystal(axes, size, view, jitter, atoms=None):
    """
    Scatter-plot lattice points scaled by *size* and oriented by the jitter
    and view transforms.

    *atoms* is a sequence of (x, y, z) points in lattice units.  Every caller
    in this file passes it explicitly; the ``None`` default is not
    meaningful.  The first point is drawn in yellow as a reference marker,
    the rest in red.
    """
    lattice = np.asarray(atoms, 'd').T
    scale = np.asarray(size, 'd')
    x, y, z = lattice*scale[:, None]
    x, y, z = transform_xyz(view, jitter, x, y, z)
    axes.scatter([x[0]], [y[0]], [z[0]], c='yellow', marker='o')
    axes.scatter(x[1:], y[1:], z[1:], c='r', marker='o')
150
def _build_sc():
    """
    Return the 27 points of a 3x3x3 simple cubic grid as a list of integer
    (x, y, z) tuples with coordinates in {-1, 0, 1}.

    The point (0, 0, 1) is moved to the front of the list so that it can be
    highlighted when drawn.  The remaining points follow in x-major order,
    with z varying fastest.
    """
    unit = (-1, 0, 1)
    grid = [(x, y, z) for x in unit for y in unit for z in unit]
    # (0, 0, 1) sits at position 9*1 + 3*1 + 2 = 14 in x-major order.
    grid.insert(0, grid.pop(14))
    return grid
167
def draw_box(axes, size, view):
    """
    Draw the wireframe of a box with half-widths *size* = (a, b, c).

    The eight corners (+/-a, +/-b, +/-c) are rotated by *view* (no jitter)
    and each of the twelve edges is drawn as a black line.
    """
    a, b, c = size
    # Corner k takes its x sign from bit 0 of k, its y sign from bit 1 and
    # its z sign from bit 2 (bit clear -> +, bit set -> -).
    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    x, y, z = transform_xyz(view, None, x, y, z)
    # Two corners share an edge exactly when their indices differ in one
    # bit.  Visiting only the (i, i | bit) pairs with that bit clear in i
    # draws each of the 12 edges once and never a face diagonal.
    for i in range(8):
        for bit in (1, 2, 4):
            j = i | bit
            if j != i:
                axes.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                          color='black')
182
def draw_parallelepiped(axes, size, view, jitter, steps=None,
                        color=(0.6, 1.0, 0.6), alpha=1):
    """
    Draw a solid parallelepiped with half-widths *size* = (a, b, c).

    The eight corners (+/-a, +/-b, +/-c) are joined into twelve triangles,
    two per face.  They are sent through :func:`transform_xyz` with *view*
    and *jitter* and drawn with ``plot_trisurf`` in *color* with no edge
    lines.  The tips of the three semi-axes are then labelled.  *steps* is
    accepted for signature compatibility with the other draw_* shapes but
    is not used.
    """
    a, b, c = size
    signs_x = np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    signs_y = np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    signs_z = np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    # Two counter-clockwise triangles per face.
    # z: up/down, x: right/left, y: front/back
    faces = np.array([
        [0, 1, 2], [3, 2, 1],  # top face (+c)
        [6, 5, 4], [5, 6, 7],  # bottom face (-c)
        [0, 2, 6], [6, 4, 0],  # right face (+a)
        [1, 5, 7], [7, 3, 1],  # left face (-a)
        [2, 3, 6], [7, 6, 3],  # front face (-b)
        [4, 1, 0], [5, 1, 4],  # back face (+b)
    ])

    x, y, z = transform_xyz(view, jitter, a*signs_x, b*signs_y, c*signs_z)
    axes.plot_trisurf(x, y, triangles=faces, Z=z, color=color, alpha=alpha,
                      linewidth=0)

    # plot_trisurf cannot colour individual faces.  To highlight the c+ face
    # (which turns with psi), one could draw a thin box in another colour
    # just in front of that face; this is not done here.

    draw_labels(axes, view, jitter, [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ])
225
def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w', alpha=1.):
    """
    Draw a sphere of the given *radius* centred at *center*.

    The surface is sampled on a *steps* x *steps* longitude/colatitude grid
    and passed to ``axes.plot_surface`` with *color* and *alpha*.
    """
    cx, cy, cz = center[0], center[1], center[2]
    longitude = np.linspace(0, 2 * np.pi, steps)
    colatitude = np.linspace(0, np.pi, steps)
    ring = np.sin(colatitude)
    x = radius * np.outer(np.cos(longitude), ring) + cx
    y = radius * np.outer(np.sin(longitude), ring) + cy
    z = radius * np.outer(np.ones_like(longitude), np.cos(colatitude)) + cz
    axes.plot_surface(x, y, z, color=color, alpha=alpha)
236
def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
    """
    Draw three black line segments from *origin* along +x, +y and +z, with
    lengths taken from *length*.
    """
    x0, y0, z0 = origin
    dx, dy, dz = length
    segments = (
        ([x0, x0 + dx], [y0, y0], [z0, z0]),
        ([x0, x0], [y0, y0 + dy], [z0, z0]),
        ([x0, x0], [y0, y0], [z0, z0 + dz]),
    )
    for xs, ys, zs in segments:
        axes.plot(xs, ys, zs, color='black')
243
def draw_person_on_sphere(axes, view, height=0.5, radius=0.5):
    """
    Draw a stick figure standing on top of a sphere of the given *radius*.

    The figure is *height* tall.  It is built in the x-z plane with its feet
    at z = *radius*, rotated by *view* (no jitter) and drawn as black lines.
    The axis limits are set to +/-(radius + height) and the axes are hidden.
    """
    # Body proportions, all derived from the total height.
    limb_offset = height * 0.05
    head_radius = height * 0.10
    head_center = height - head_radius
    neck_length = head_radius * 0.50
    shoulder_height = height - 2*head_radius - neck_length
    torso_length = shoulder_height * 0.55
    torso_halfwidth = torso_length * 0.30
    leg_length = shoulder_height - torso_length
    arm_length = torso_length * 0.90

    def stroke(px, pz):
        # Lift the outline onto the top of the sphere and apply the view.
        py = np.zeros_like(px)
        sx, sy, sz = transform_xyz(view, None, px, py, pz + radius)
        axes.plot(sx, sy, sz, color='k')

    # head: a circle
    t = np.linspace(0, 2 * np.pi, 40)
    stroke(head_radius * np.cos(t), head_radius * np.sin(t) + head_center)

    # torso: a closed rectangle sitting on top of the legs
    torso_x = np.array([-torso_halfwidth, torso_halfwidth, torso_halfwidth,
                        -torso_halfwidth, -torso_halfwidth])
    torso_z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length
    stroke(torso_x, torso_z)

    # arms: mirror images hanging from the shoulders
    arm_x = np.array([-torso_halfwidth - limb_offset,
                      -torso_halfwidth - limb_offset,
                      -torso_halfwidth])
    arm_z = np.array([shoulder_height - arm_length, shoulder_height,
                      shoulder_height])
    for xs in (arm_x, -arm_x):
        stroke(xs, arm_z)

    # legs: mirror-image vertical strokes from the ground to the torso
    leg_x = np.array([-torso_halfwidth + limb_offset,
                      -torso_halfwidth + limb_offset])
    leg_z = np.array([0, leg_length])
    for xs in (leg_x, -leg_x):
        stroke(xs, leg_z)

    limit = [-radius-height, radius+height]
    axes.set_xlim(limit)
    axes.set_ylim(limit)
    axes.set_zlim(limit)
    axes.set_axis_off()
288
def draw_jitter(axes, view, jitter, dist='gaussian',
                size=(0.1, 0.4, 1.0),
                draw_shape=draw_parallelepiped,
                projection='equirectangular',
                alpha=0.8,
                views=None):
    """
    Represent orientation jitter as copies of a shape drawn at a range of
    orientations.

    *jitter* is the (dtheta, dphi, dpsi) spread in degrees.  Each angle with
    a positive spread is sampled over ``base*delta*[-1, 1]``, where *base* is
    3 for 'gaussian', sqrt(3) for 'rectangle' and 1 for 'uniform'.  The
    number of samples is *views* if given; otherwise it is chosen from the
    spread and lies between 3 and 25.  An angle with no spread is sampled
    only at 0.  (theta, phi) pairs that the *projection* weights as zero are
    skipped.  *size* is rescaled so that its diagonal is 0.95, and the axis
    limits are set to match.
    """
    project, project_weight = get_projection(projection)

    # Rescale the shape so that its diagonal is 0.95.
    scale = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(scale*v for v in size)

    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]

    def sample(delta):
        if delta <= 0:
            return [0.]
        if views is not None:
            count = views
        else:
            count = max(3, min(25, 2*int(base*delta/5)))
        return base*delta*np.linspace(-1, 1, count)

    dtheta, dphi, dpsi = jitter
    thetas = sample(dtheta)
    phis = sample(dphi)
    psis = sample(dpsi)
    for theta in thetas:
        for phi in phis:
            # The weight does not depend on psi, so test it once per pair.
            if not project_weight(theta, phi, 1.0, 1.0) > 0:
                continue
            for psi in psis:
                dview = project(theta, phi, psi)
                draw_shape(axes, size, view, dview, alpha=alpha)

    a, b, c = size
    lim = np.sqrt(a**2 + b**2 + c**2)
    for axis in 'xyz':
        getattr(axes, 'set_' + axis + 'lim')([-lim, lim])
324
# Projection names listed in order of PROJECTION number (the constants in
# kernel_iq.c).  Do not reorder without updating those constants.
PROJECTIONS = [
    'equirectangular',
    'sinusoidal',
    'guyou',
    'azimuthal_equidistance',
    'azimuthal_equal_area',
]
def get_projection(projection):
    """
    Return the ``(project, weight)`` functions for the named jitter projection.

    ``project(theta_i, phi_j, psi)`` maps a (theta, phi) point of the jitter
    mesh, in degrees, to the ``(latitude, longitude, psi)`` angles used to
    orient the particle.  The azimuthal projections return a fourth element,
    ``'zyz'``, to flag that the angles are for an Rz Ry Rz rotation rather
    than Rx Ry Rz.  ``weight(theta_i, phi_j, w_i, w_j)`` combines the 1-D
    weights *w_i* and *w_j* into the weight for that mesh point.  Points
    with weight 0 lie outside the projection and should be skipped.

    Raises ValueError if *projection* is not one of the names below.

    Jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>:

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangular.  Note that the poles
        will be slightly overweighted for theta > 90 with the circle
        from theta at 90+dt winding backwards around the pole, overlapping
        the circle from theta at 90-dt.
    Guyou (hemisphere-in-a-square) **not weighted**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        With tiling, allows rotation in phi or theta through +/- 180, with
        uniform spacing.  Both theta and phi allow free rotation, with wobble
        in the orthogonal direction reasonably well behaved (though not as
        good as equirectangular phi). The forward/reverse transformations
        rely on elliptic integrals that are somewhat expensive, so the
        behaviour has to be very good to justify the cost and complexity.
        The weighting function for each point has not yet been computed.
        Note: run the module *guyou.py* directly and it will show the forward
        and reverse mappings.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not that
        useful and not completely implemented.
    Gauss-Kruger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    # TODO: try Kent distribution instead of a gaussian warped by projection

    if projection == 'equirectangular':  #define PROJECTION 1
        def _project(theta_i, phi_j, psi):
            latitude, longitude = theta_i, phi_j
            return latitude, longitude, psi
            #return Rx(phi_j)*Ry(theta_i)
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j*abs(cos(radians(theta_i)))
    elif projection == 'sinusoidal':  #define PROJECTION 2
        def _project(theta_i, phi_j, psi):
            latitude = theta_i
            scale = cos(radians(latitude))
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi
            #return Rx(longitude)*Ry(latitude)
        # This was previously misnamed _project, which shadowed the mapping
        # above and left _weight undefined for this projection.
        def _weight(theta_i, phi_j, w_i, w_j):
            # Points outside the sinusoid |phi| < 180 cos(theta) are excluded.
            latitude = theta_i
            scale = cos(radians(latitude))
            active = 1 if abs(phi_j) < abs(scale)*180 else 0
            return active*w_i*w_j
    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
        def _project(theta_i, phi_j, psi):
            from .guyou import guyou_invert
            #latitude, longitude = guyou_invert([theta_i], [phi_j])
            longitude, latitude = guyou_invert([phi_j], [theta_i])
            return latitude, longitude, psi
            #return Rx(longitude[0])*Ry(latitude[0])
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j
    elif projection == 'azimuthal_equidistance':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi-longitude, 'zyz'
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # We are doing a conformal mapping from disk to sphere, so we need
            # a change of variables g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dtheta dphi = det(J) dlat dlong
            # where J is the jacobian from the partials of g. Using
            #     R = sqrt(theta^2 + phi^2)  with x = theta, y = phi,
            # then
            #     J = [[x/R, y/R], [-y/R^2, x/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*np.arccos(radius))
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi, "zyz"
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    return _project, _weight
479
def R_to_xyz(R):
    """
    Return the (phi, theta, psi) Tait-Bryan angles, in degrees, that
    correspond to the rotation matrix *R*.

    Method from Mike Day, "Extracting Euler Angles from a Rotation Matrix",
    Insomniac Games,
    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
    based on Shoemake's "Euler Angle Conversion", Graphics Gems IV,
    pp. 222-229.
    """
    cos_theta = np.sqrt(R[0, 0]**2 + R[0, 1]**2)
    angles = (
        np.arctan2(R[1, 2], R[2, 2]),    # phi
        np.arctan2(-R[0, 2], cos_theta),  # theta
        np.arctan2(R[0, 1], R[0, 0]),    # psi
    )
    return tuple(np.degrees(angle) for angle in angles)
493
def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    A point at distance *radius* along z is rotated to each (theta, phi) of
    an *n* x *n* mesh spanning the jitter, mapped through *projection*.
    Mesh points with zero weight are dropped.  The remaining points are
    oriented by *view* and scatter-plotted, coloured by their weight
    relative to the largest weight.

    Raises ValueError if *dist* is not 'gaussian', 'rectangle' or 'uniform'.
    """
    project, weight = get_projection(projection)

    def rotate(theta, phi, vec):
        # Orient *vec* by the projected jitter angles (psi is not used).
        angles = project(theta, phi, 0.)
        if len(angles) == 4:  # azimuthal projections flag zyz angles
            first = Rz(angles[1])
        else:
            first = Rx(angles[1])
        return first*Ry(angles[0])*vec

    offsets = np.linspace(-1, 1, n)
    weights = np.ones_like(offsets)
    if dist == 'gaussian':
        offsets *= 3
        weights = exp(-0.5*offsets**2)
    elif dist == 'rectangle':
        # Note: sasmodels scales the rectangle width by sqrt(3) relative to
        # its half-width (presumably so the width is the standard deviation).
        offsets *= sqrt(3)
    elif dist != 'uniform':
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # mesh in theta, phi formed by rotating the point on the z axis
    dtheta, dphi, dpsi = jitter
    thetas = dtheta*offsets
    phis = dphi*offsets
    tip = np.matrix([[0], [0], [radius]])
    points = np.hstack([rotate(t, p, tip) for t in thetas for p in phis])
    mesh_w = np.array([weight(t, p, wt, wp)
                       for wt, t in zip(weights, thetas)
                       for wp, p in zip(weights, phis)])
    keep = mesh_w > 0
    points = points[:, keep]
    mesh_w = mesh_w[keep]
    mesh_w /= max(mesh_w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).flatten() for row in points]
    axes.scatter(x, y, z, c=mesh_w, marker='o', vmin=0., vmax=1.)
543
def draw_labels(axes, view, jitter, text):
    """
    Draw text labels at points that move with the shape.

    *text* is a sequence of ``(label, (x, y, z) position, (dx, dy, dz)
    direction)`` triples.  Positions and directions are both sent through
    :func:`transform_xyz` with *view* and *jitter*, and each label is drawn
    with ``axes.text`` using its transformed direction as *zdir*.
    """
    labels, locations, orientations = zip(*text)
    loc_x, loc_y, loc_z = zip(*locations)
    dir_x, dir_y, dir_z = zip(*orientations)

    loc_x, loc_y, loc_z = transform_xyz(view, jitter, loc_x, loc_y, loc_z)
    dir_x, dir_y, dir_z = transform_xyz(view, jitter, dir_x, dir_y, dir_z)

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for k, label in enumerate(labels):
        zdir = np.asarray((dir_x[k], dir_y[k], dir_z[k])).flatten()
        axes.text(loc_x[k], loc_y[k], loc_z[k], label, zdir=zdir)
559
560# Definition of rotation matrices comes from wikipedia:
561#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """Return the 3x3 ``np.matrix`` that rotates points about *x* by *angle* degrees."""
    theta = radians(angle)
    c, s = cos(theta), sin(theta)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
569
def Ry(angle):
    """Return the 3x3 ``np.matrix`` that rotates points about *y* by *angle* degrees."""
    theta = radians(angle)
    c, s = cos(theta), sin(theta)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
577
def Rz(angle):
    """Return the 3x3 ``np.matrix`` that rotates points about *z* by *angle* degrees."""
    theta = radians(angle)
    c, s = cos(theta), sin(theta)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
585
def transform_xyz(view, jitter, x, y, z):
    """
    Rotate (x, y, z) points by the *jitter* transform and then the *view*
    transform.

    *x*, *y* and *z* are scalars or arrays of a common shape (the shape of
    *x* is used).  The rotated coordinates are returned as three numpy
    arrays of that shape.
    """
    coords = [np.asarray(c) for c in (x, y, z)]
    shape = coords[0].shape
    stacked = np.matrix([c.flatten() for c in coords])
    rotated = orient_relative_to_beam(view, apply_jitter(jitter, stacked))
    new_x, new_y, new_z = [np.array(row).reshape(shape) for row in rotated]
    return new_x, new_y, new_z
597
def apply_jitter(jitter, points):
    """
    Rotate *points* by the orientation jitter.

    *jitter* is ``(dtheta, dphi, dpsi)`` in degrees, applied as
    Rx(dphi) Ry(dtheta) Rz(dpsi).  A four-element jitter (as returned by the
    azimuthal projections) is applied as Rz(dphi) Ry(dtheta) Rz(dpsi) and
    its fourth element is ignored.  A *jitter* of None leaves *points*
    unchanged.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    if jitter is None:
        return points
    if len(jitter) == 4:
        # Hack: azimuthal projections tag their angles as Euler zyz.
        dtheta, dphi, dpsi, _ = jitter
        outer = Rz(dphi)
    else:
        dtheta, dphi, dpsi = jitter
        outer = Rx(dphi)
    return outer*Ry(dtheta)*Rz(dpsi)*points
614
def orient_relative_to_beam(view, points):
    """
    Apply the view transform to a set of points.

    *view* is ``(theta, phi, psi)`` in degrees, applied as
    Rz(phi) Ry(theta) Rz(psi).

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    theta, phi, psi = view
    rotation = Rz(phi)*Ry(theta)*Rz(psi)  # viewing angle
    # Alternative conventions, not used:
    #   Rz(psi)*Ry(pi/2-theta)*Rz(phi)   # 1-D integration angles
    #   Rx(phi)*Ry(theta)*Rz(psi)        # angular dispersion angle
    return rotation*points
626
def orient_relative_to_beam_quaternion(view, points):
    """
    Apply the view transform to a set of points using quaternions.

    Computes the same rotation as :func:`orient_relative_to_beam`,
    Rz(phi) Ry(theta) Rz(psi) for *view* = (theta, phi, psi) in degrees, by
    composing the quaternions of the three axis rotations.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple;
    the result is returned as a plain 3 x n numpy array.

    This variant is not used because it doesn't solve any problems.  The
    challenge of mapping theta/phi/psi to SO(3) does not disappear by
    calculating the transform differently.
    """
    theta, phi, psi = view
    y_axis, z_axis = [0, 1, 0], [0, 0, 1]
    # (q1*q2).rot(v) == q1.rot(q2.rot(v)), so the quaternions are multiplied
    # in the same left-to-right order as the rotation matrices.
    q = (Quaternion.from_angle_axis(phi, z_axis)
         * Quaternion.from_angle_axis(theta, y_axis)
         * Quaternion.from_angle_axis(psi, z_axis))
    return q.rot(points)
#orient_relative_to_beam = orient_relative_to_beam_quaternion
658
659# Simple stand-alone quaternion class
660import numpy as np
661from copy import copy
class Quaternion(object):
    """
    Minimal quaternion ``w + r[0] i + r[1] j + r[2] k`` for rotating points.

    *w* is the scalar part and *r* the 3-vector part.  :meth:`rot` is a pure
    rotation only for unit quaternions, such as those returned by
    :meth:`from_angle_axis`.
    """
    def __init__(self, w, r):
        self.w = w
        self.r = np.asarray(r, dtype='d')

    @staticmethod
    def from_angle_axis(theta, r):
        """
        Return the unit quaternion for a rotation of *theta* degrees about
        the axis *r*.

        *r* need not be normalized.  Raises ValueError if it has zero length.
        """
        half = np.radians(theta)/2
        axis = np.asarray(r, dtype='d')
        length = np.sqrt(np.dot(axis, axis))
        if length == 0:
            raise ValueError("rotation axis must be nonzero")
        # Normalize by the length (not the squared length) so that the
        # result is a unit quaternion for any axis length.
        return Quaternion(np.cos(half), np.sin(half)*axis/length)

    def __mul__(self, other):
        """Return the Hamilton product ``self * other``."""
        if not isinstance(other, Quaternion):
            return NotImplemented
        w = self.w*other.w - np.dot(self.r, other.r)
        r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
        return Quaternion(w, r)

    def rot(self, v):
        """
        Rotate *v* by this (unit) quaternion, computing ``q v q*``.

        *v* may be a single 3-vector, an n x 3 array of row vectors, or a
        3 x n array/matrix of column vectors.  A 3 x 3 input is treated as
        column vectors.  The result has the same layout as *v* and is a
        plain numpy array.
        """
        v = np.asarray(v).T
        use_transpose = (v.shape[-1] != 3)
        if use_transpose:
            v = v.T
        # q v q* = v + 2 w (r x v) + 2 r x (r x v), folded into two crosses.
        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
        if use_transpose:
            v = v.T
        return v.T

    def conj(self):
        """Return the conjugate ``w - r``."""
        return Quaternion(self.w, -self.r)

    def inv(self):
        """Return the multiplicative inverse ``conj(q) / |q|**2``."""
        norm_sq = self.w**2 + np.sum(self.r**2)
        return Quaternion(self.w/norm_sq, -self.r/norm_sq)

    def norm(self):
        """Return the Euclidean norm ``sqrt(w**2 + |r|**2)``."""
        return np.sqrt(self.w**2 + np.sum(self.r**2))

    def __str__(self):
        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
def test_qrot():
    """
    Check Quaternion.rot against a hand-computed rotation.

    The point [1, -1, 2] is rotated by 60 degrees about an axis in the y-z
    plane 60 degrees from y.  That axis is obtained by rotating [0, 1, 0]
    by 60 degrees about x.
    """
    root3 = np.sqrt(3)
    axis = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0])
    rotation = Quaternion.from_angle_axis(60, axis)
    point = [1, -1, 2]
    expected = [(10 + 4*root3)/8, (1 + 2*root3)/8, (14 - 3*root3)/8]
    assert max(abs(rotation.rot(point) - expected)) < 1e-14
704#test_qrot()
705#import sys; sys.exit()
706
# Translate between the number of dimensions of dispersity and the number of
# points along each dimension.  Keys flag which of (theta, phi, psi) have
# dispersity.  The per-angle count drops as more angles are active, so the
# total number of evaluations stays bounded: 100 for one angle, 30*30 = 900
# for two, and 15*15*15 = 3375 for all three.
PD_N_TABLE = {
    active: tuple(flag*{0: 0, 1: 100, 2: 30, 3: 15}[sum(active)]
                  for flag in active)
    for active in ((t, p, s) for t in (0, 1) for p in (0, 1) for s in (0, 1))
}
719
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine a display range ``(low, high)`` from the array *data*.

    If *portion* is 1 (or more), use the full range of the data.  Otherwise
    keep only that fraction of the sorted values.  With *mode* 'central',
    trim equally from both ends; with *mode* 'top', trim only from the
    bottom, keeping the largest values.  The kept range approaches the full
    range as *portion* approaches 1.

    Raises ValueError if *data* is empty, or if trimming is requested with a
    *mode* other than 'central' or 'top'.
    """
    if portion >= 1.0:
        return data.min(), data.max()
    if mode not in ('central', 'top'):
        raise ValueError("mode should be 'central' or 'top', not %r" % (mode,))
    values = np.sort(np.asarray(data).flatten())
    if values.size == 0:
        raise ValueError("cannot determine the range of empty data")
    last = values.size - 1
    # number of values excluded from the range
    drop = (1.0 - max(portion, 0.0))*values.size
    if mode == 'central':
        # Trim half of the excluded values from each end.  Index from the
        # front (last - offset) rather than using values[-offset], which is
        # values[0] when offset is 0.  Clamp so that low <= high.
        offset = min(int(drop/2 + 0.5), last//2)
        return values[offset], values[last - offset]
    offset = min(int(drop + 0.5), last)
    return values[offset], values[-1]
737
def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The log of the computed intensity is drawn as filled contours (24
    levels) on the plane z = -1.1, with qx and qy divided by their maximum.
    The colour range is *calculator.limits* when set, otherwise it is
    derived from the maximum of the current pattern.
    """
    if dist == 'uniform':  # uniform is not yet in this branch
        # emulate it with the rectangle distribution, whose width in
        # sasmodels is sqrt(3) times the value given
        dist, scale = 'rectangle', 1/sqrt(3)
    else:
        scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    # fewer points per dimension as more dimensions are dispersed; multiply
    # these by ~5 to test jitter integration rather than for visualization
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    Iqxy = calculator(**pars).reshape(len(qx), len(qy))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        vmax = Iqxy.max()
        # NOTE(review): Iqxy is already log intensity, so for vmax > 0 this
        # puts the floor just above log(I) = 0 (I = 1) rather than seven
        # decades below the peak.  Kept as is since it sets the current look.
        vmin = vmax*10**-7
        if not vmin < vmax:
            # vmax <= 0 would give non-increasing levels, which contourf
            # rejects; fall back to seven decades below the peak.
            vmin = vmax - 7*np.log(10)
    axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                  levels=np.linspace(vmin, vmax, 24))
798
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  *background* is set to 1e-3 unless given in *pars*.

    Returns a *calculator* function which takes the model parameters as
    keyword arguments and produces Iqxy.  The Iqxy value needs to be reshaped
    to an n x n matrix for plotting.  See the
    :class:`sasmodels.direct_model.DirectModel` class for details.

    The calculator also carries *calculator.limits*, a (vmin, vmax) colour
    range for log(Iqxy) computed at orientation (0, 0, 0): the top 95% of
    the values as returned by :func:`clipped_range`, with vmax raised by 1.
    """
    from sasmodels.core import load_model_info, build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
836
def select_calculator(model_name, n=150, size=(10, 40, 100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped, sc_paracrystal, fcc_paracrystal or bcc_paracrystal.
    *n* is the number of points to use in the q range.  *qmax* is chosen
    based on model parameters for the given model to show something
    interesting.

    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
    equatorial axes and polar axis respectively.  See :func:`build_model`
    for details on the returned calculator.

    Raises ValueError for an unknown *model_name*.
    """
    a, b, c = size
    d_factor = 0.06  # for paracrystal models
    if model_name == 'sphere':
        calculator = build_model('sphere', n=n, radius=c)
        a = b = c
    elif model_name == 'sc_paracrystal':
        a = b = c
        dnn = c
        radius = 0.5*c
        calculator = build_model('sc_paracrystal', n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
    elif model_name == 'fcc_paracrystal':
        a = b = c
        # nearest neighbour distance dnn should be 2 radius, but I think the
        # model uses lattice spacing rather than dnn in its calculations
        dnn = 0.5*c
        radius = sqrt(2)/4 * c
        calculator = build_model('fcc_paracrystal', n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
    elif model_name == 'bcc_paracrystal':
        a = b = c
        # nearest neighbour distance dnn should be 2 radius, but I think the
        # model uses lattice spacing rather than dnn in its calculations
        dnn = 0.5*c
        # NOTE(review): for lattice constant c, sqrt(3)/2*c is the bcc
        # nearest-neighbour distance; touching spheres would have radius
        # sqrt(3)/4*c (compare sqrt(2)/4*c in the fcc case).  Confirm intent.
        radius = sqrt(3)/2 * c
        calculator = build_model('bcc_paracrystal', n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
    elif model_name == 'cylinder':
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
        a = b
    elif model_name == 'ellipsoid':
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
        a = b
    elif model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
    elif model_name == 'parallelepiped':
        # sasmodels names the parallelepiped edge lengths length_a, length_b
        # and length_c; plain a/b/c do not match any model parameter.
        calculator = build_model('parallelepiped', n=n,
                                 length_a=a, length_b=b, length_c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    return calculator, (a, b, c)
898
# Shapes selectable on the command line; the first one is the default.
SHAPES = (
    'parallelepiped sphere ellipsoid triaxial_ellipsoid cylinder'
    ' fcc_paracrystal bcc_paracrystal sc_paracrystal'
).split()
905
# Drawing function for each model that needs a special shape; run() falls
# back to draw_parallelepiped for any model not listed here.
DRAW_SHAPES = dict(
    fcc_paracrystal=draw_fcc,
    bcc_paracrystal=draw_bcc,
    sc_paracrystal=draw_sc,
    parallelepiped=draw_parallelepiped,
)
912
# Orientation dispersity distributions offered on the command line; the
# first one is the default.
DISTRIBUTIONS = 'gaussian rectangle uniform'.split()
# Per-distribution jitter scale; mpl_plot lets the jitter sliders run from
# 0 to twice this value.  The sasmodels rectangle distribution is sqrt(3)
# times wider than the value given, hence the division by sqrt(3).
DIST_LIMITS = dict(
    gaussian=30,
    rectangle=90/sqrt(3),
    uniform=90,
)
921
922
def run(model_name='parallelepiped', size=(10, 40, 100),
        view=(0, 0, 0), jitter=(0, 0, 0),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal

    *size* gives the dimensions (a, b, c) of the shape.

    *view* gives the initial view (theta, phi, psi) of the shape.

    *jitter* gives the initial jitter (dtheta, dphi, dpsi) of the shape.

    *dist* is the type of dispersion: gaussian, rectangle, or uniform.

    *mesh* is the number of points in the dispersion mesh.

    *projection* is the map projection to use for the mesh: equirectangular,
    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.
    Only the first two are implemented in the scattering calculation; for
    the others a warning is printed and projection 2 is used there.

    The display is produced by the current plotting backend, PLOT_ENGINE
    (see :func:`set_plotter`).
    """
    from sasmodels import generate

    # generate.PROJECTION numbers projections by their 1-based position in
    # PROJECTIONS; the scattering kernel only implements numbers 1 and 2.
    kernel_projection = PROJECTIONS.index(projection) + 1
    if kernel_projection > 2:
        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
        kernel_projection = 2
    generate.PROJECTION = kernel_projection

    # set up calculator and the function that draws the shape
    calculator, size = select_calculator(model_name, n=150, size=size)
    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped)

    # Discard the colour range build_model computed at orientation (0, 0, 0)
    # so that every view gets its own range.  Delete this line to keep the
    # colour range fixed across views.
    calculator.limits = None

    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
964
def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the interactive jitter explorer in a matplotlib window.

    A 3D axes shows the beam, the jittered shape, the orientation mesh and
    the scattering pattern.  Sliders below it control the view angles
    (theta, phi, psi) and the jitter widths, and any slider change redraws
    the scene.  Arguments are as passed by :func:`run`.  The call ends in
    :func:`matplotlib.pyplot.show`.
    """
    # Note: travis-ci does not support mpl_toolkits.mplot3d, but this shouldn't be
    # an issue since we are lazy-loading the package on a path that isn't tested.
    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    # Unpack the initial view and jitter before building the sliders, which
    # are initialized from these values.
    theta, phi, psi = view
    dtheta, dphi, dpsi = jitter

    ## create the plot window
    plt.subplots(num=None, figsize=(5.5, 5.5))
    plt.set_cmap('gist_earth')
    plt.clf()
    canvas = plt.gcf().canvas
    # CRUFT: set_window_title moved from the canvas to the figure manager in
    # matplotlib 3.4; use the manager when there is one.
    manager = getattr(canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(projection)
    else:
        canvas.set_window_title(projection)
    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        axes.axis('square')
    except Exception:
        pass

    # CRUFT: use axisbg instead of facecolor for matplotlib<2
    mpl_major = int(mpl.__version__.split('.')[0])
    facecolor_prop = 'facecolor' if mpl_major >= 2 else 'axisbg'
    props = {facecolor_prop: 'lightgoldenrodyellow'}

    ## add control widgets to plot
    axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], **props)
    axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], **props)
    axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], **props)
    stheta = Slider(axes_theta, 'Theta', -90, 90, valinit=theta)
    sphi = Slider(axes_phi, 'Phi', -180, 180, valinit=phi)
    spsi = Slider(axes_psi, 'Psi', -180, 180, valinit=psi)

    axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], **props)
    axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], **props)
    axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], **props)

    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axes_dtheta, u'Δξ', 0, 2*dlimit, valinit=0)
    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0)
    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0)

    ## initial jitter (set before the callbacks are bound below)
    sdtheta.set_val(dtheta)
    sdphi.set_val(dphi)
    sdpsi.set_val(dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        axes.cla()

        ## Visualize as person on globe
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, -view[:2])
        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), alpha=1.)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.,
                    draw_shape=draw_shape, projection=projection, views=3)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
1062
1063
def map_colors(z, kw):
    """
    Replace the matplotlib colour keywords in *kw* with a single 'color' entry.

    *z* is the array of plotted values.  The 'cmap', 'alpha', 'vmin',
    'vmax', 'c' and 'color' entries are removed from *kw* (in place) and the
    resolved colour is stored in ``kw['color']``:

    * no colour given: *z* is scaled onto [0, 1] between *vmin* and *vmax*
      (default: the min and max of *z*) and mapped through *cmap*; a zero
      range maps every point to the bottom of the colour map;
    * an ndarray colour with the same shape as *z* is mapped through *cmap*;
    * any other colour is used as given.

    If *alpha* is None, ndarray colours are reduced to RGB.  Otherwise the
    colour becomes an RGBA float array whose alpha channel is *alpha*
    (colour names are converted with matplotlib); the caller's array is
    never modified.  *cmap* defaults to matplotlib's coolwarm.
    """
    cmap = kw.pop('cmap', None)
    if cmap is None:
        from matplotlib import cm
        cmap = cm.coolwarm
    alpha = kw.pop('alpha', None)
    vmin = kw.pop('vmin', None)
    vmax = kw.pop('vmax', None)
    c = kw.pop('c', None)
    color = kw.pop('color', c)
    if color is None:
        vmin = z.min() if vmin is None else vmin
        vmax = z.max() if vmax is None else vmax
        span = vmax - vmin
        if span:
            znorm = ((z - vmin) / span).clip(0, 1)
        else:
            # constant data: avoid 0/0 = NaN colours
            znorm = np.zeros(np.shape(z))
        color = cmap(znorm)
    elif isinstance(color, np.ndarray) and color.shape == z.shape:
        color = cmap(color)
    if alpha is None:
        if isinstance(color, np.ndarray):
            color = color[..., :3]
    else:
        if isinstance(color, str):
            from matplotlib.colors import to_rgba
            color = to_rgba(color)
        # np.array copies, so the caller's colour array is left untouched
        rgba = np.array(color, dtype='d')
        if rgba.shape[-1] == 3:
            rgba = np.concatenate([rgba, np.ones(rgba.shape[:-1] + (1,))], axis=-1)
        rgba[..., 3] = alpha
        color = rgba
    kw['color'] = color
1084
def make_vec(*args):
    """
    Convert each argument to a float64 numpy array.

    Returns a list holding one array per argument, in order.  Shapes are
    preserved (nothing is flattened), and an argument that is already a
    float64 array is passed through by np.asarray without copying.
    """
    vectors = []
    for arg in args:
        vectors.append(np.asarray(arg, dtype='d'))
    return vectors
1088
def make_image(z, kw):
    """
    Render the 2D array *z* as an RGB :class:`PIL.Image.Image`.

    The finite values of *z* are scaled linearly onto [0, 1] and coloured
    with ``kw['cmap']`` (popped from *kw*; default matplotlib's coolwarm).
    A constant array maps to the bottom of the colour map instead of
    dividing by zero, and -inf/+inf (e.g. the log of a zero intensity) are
    clipped to the ends of the map rather than spoiling the scaling.
    NaN values are passed to the colour map unchanged.
    """
    import PIL.Image

    cmap = kw.pop('cmap', None)
    if cmap is None:
        from matplotlib import cm
        cmap = cm.coolwarm

    z = np.asarray(z, dtype='d')
    finite = np.isfinite(z)
    if finite.any():
        zmin, zmax = z[finite].min(), z[finite].max()
    else:
        zmin = zmax = 0.
    span = zmax - zmin
    if span > 0:
        znorm = np.clip((z - zmin)/span, 0, 1)
    else:
        znorm = np.zeros_like(z)
    rgb = np.asarray(cmap(znorm)[..., :3]*255, 'u1')
    # Pillow infers mode RGB from a uint8 (rows, cols, 3) array; passing
    # mode= explicitly is deprecated in recent Pillow releases.
    return PIL.Image.fromarray(rgb)
1101
1102
# Matplotlib marker codes translated to ipyvolume marker names.
_IPV_MARKERS = dict(o='sphere')
# Matplotlib single-letter colour codes translated to full colour names
# for ipyvolume.
_IPV_COLORS = dict(
    w='white', k='black', c='cyan', m='magenta',
    y='yellow', r='red', g='green', b='blue',
)
def ipv_fix_color(kw):
    """
    Adapt matplotlib-style colour keywords in *kw* for ipyvolume, in place.

    A single-letter colour code in ``kw['color']`` is expanded to a colour
    name using _IPV_COLORS.  ``kw['alpha']`` is always removed; when given,
    it is folded into the colour:

    * a 3-element tuple/list becomes the RGBA array (r, g, b, alpha);
    * a 4-element tuple/list has its own alpha multiplied by *alpha*;
    * an ndarray whose last axis has length 4 gets its alpha channel set
      to *alpha*, on a copy so the caller's array is not modified.

    Alpha is dropped for other colours (e.g. colour names) and when no
    colour is given.
    """
    alpha = kw.pop('alpha', None)
    color = kw.get('color', None)
    if isinstance(color, str):
        color = _IPV_COLORS.get(color, color)
        kw['color'] = color
    if alpha is None or color is None:
        return
    #TODO: convert named colours to [r, g, b, a] so alpha applies to them too
    if isinstance(color, (tuple, list)):
        if len(color) == 3:
            rgba = (color[0], color[1], color[2], alpha)
        else:
            rgba = (color[0], color[1], color[2], alpha*color[3])
        kw['color'] = np.array(rgba)
    elif isinstance(color, np.ndarray) and color.shape[-1] == 4:
        color = color.copy()
        color[..., 3] = alpha
        kw['color'] = color
1134
def ipv_set_transparency(kw, obj):
    """
    Turn on transparency for the ipyvolume object *obj* when needed.

    Only an ndarray ``kw['color']`` whose last axis has length 4 (RGBA) is
    considered.  If any of its alpha values differs from 1, the object's
    material is marked transparent with side "FrontSide".  Otherwise *obj*
    is left alone.
    """
    rgba = kw.get('color', None)
    if not isinstance(rgba, np.ndarray) or rgba.shape[-1] != 4:
        return
    if (rgba[..., 3] == 1.0).all():
        return
    material = obj.material
    material.transparent = True
    material.side = "FrontSide"
1142
def ipv_axes():
    """
    Return a stand-in for a matplotlib 3D axes that draws with ipyvolume.

    Only the methods used by the drawing functions in this module are
    provided.  Matplotlib-style colour keywords are translated with
    :func:`ipv_fix_color` or :func:`map_colors`.  Filled contours are drawn
    as a flat textured image (see *pcolor*), and text labels are ignored.
    """
    import ipyvolume as ipv

    class Plotter:
        # transparency can be achieved by setting the following:
        #    mesh.color = [r, g, b, alpha]
        #    mesh.material.transparent = True
        #    mesh.material.side = "FrontSide"
        # smooth(ish) rotation can be achieved by setting:
        #    slide.continuous_update = True
        #    figure.animation = 0.
        #    mesh.material.x = x
        #    mesh.material.y = y
        #    mesh.material.z = z
        # maybe need to synchronize update of x/y/z to avoid shimmy when moving

        def plot(self, x, y, z, **kw):
            """Draw a line through the points (*x*, *y*, *z*)."""
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            ipv.plot(x, y, z, **kw)

        def plot_surface(self, x, y, z, **kw):
            """Draw a surface; matplotlib's *facecolors* becomes its colour."""
            facecolors = kw.pop('facecolors', None)
            if facecolors is not None:
                kw['color'] = facecolors
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            h = ipv.plot_surface(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            #h.material.side = "DoubleSide"
            return h

        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
            """Draw a triangulated surface; *linewidth* is ignored."""
            kw.pop('linewidth', None)
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, Z)
            if triangles is not None:
                triangles = np.asarray(triangles)
            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
            ipv_set_transparency(kw, h)
            return h

        def scatter(self, x, y, z, **kw):
            """Draw markers at (*x*, *y*, *z*), coloured by *z* by default."""
            x, y, z = make_vec(x, y, z)
            map_colors(z, kw)
            marker = kw.get('marker', None)
            kw['marker'] = _IPV_MARKERS.get(marker, marker)
            h = ipv.scatter(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            return h

        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
            """Show *v* as an image via *pcolor*; *levels* is ignored."""
            # Don't use contour for now (although we might want to later)
            self.pcolor(x, y, v, zdir='z', offset=offset, **kw)

        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
            """
            Draw *v* as a texture on a flat rectangle at height *offset*
            spanning the extent of *x* and *y*.  *zdir* is ignored.
            """
            x, y, v = make_vec(x, y, v)
            image = make_image(v, kw)
            xmin, xmax = x.min(), x.max()
            ymin, ymax = y.min(), y.max()
            x = np.array([[xmin, xmax], [xmin, xmax]])
            y = np.array([[ymin, ymin], [ymax, ymax]])
            z = x*0 + offset
            # texture coordinates of the four rectangle corners
            tex_u = np.array([[0., 1], [0, 1]])
            tex_v = np.array([[0., 0], [1, 1]])
            h = ipv.plot_mesh(x, y, z, u=tex_u, v=tex_v, texture=image,
                              wireframe=False)
            ipv_set_transparency(kw, h)
            h.material.side = "DoubleSide"
            return h

        def text(self, *args, **kw):
            """Text labels are not supported; the call is ignored."""

        def set_xlim(self, limits):
            ipv.xlim(*limits)

        def set_ylim(self, limits):
            ipv.ylim(*limits)

        def set_zlim(self, limits):
            ipv.zlim(*limits)

        def set_axis_on(self):
            """Show the axes (matplotlib-compatible name)."""
            ipv.style.axes_on()

        # backward-compatible alias for the original method name
        set_axes_on = set_axis_on

        def set_axis_off(self):
            """Hide the axes."""
            ipv.style.axes_off()

    return Plotter()
1219
def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the interactive jitter explorer in a notebook using ipyvolume.

    Displays sliders for the view angles (theta, phi, psi) and the jitter
    widths.  Each change redraws the beam, the jittered shape, the
    orientation mesh and the scattering pattern in a new ipyvolume figure,
    reusing the camera of the previous figure.  Arguments are as passed by
    :func:`run`.
    """
    import ipywidgets as widgets
    import ipyvolume as ipv
    # display() is only injected as a builtin inside IPython; import it so
    # the name is always defined.
    from IPython.display import display

    axes = ipv_axes()

    def draw(view, jitter):
        camera = ipv.gcf().camera
        ipv.figure(animation=0.)  # no animation when updating object mesh

        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]

        ## Visualize as person on globe
        #draw_beam(axes, (0, 0))
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, view=(-view[0], -view[1]))
        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0))

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), steps=25)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0,
                    draw_shape=draw_shape, projection=projection)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95,
                  projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        draw_axes(axes, origin=(-1, -1, -1.1))
        ipv.style.box_off()
        ipv.style.axes_off()
        ipv.xyzlabel(" ", " ", " ")

        # reuse the camera of the previous figure
        ipv.gcf().camera = camera
        ipv.show()

    # (min, max, step) for the view and jitter sliders, in degrees
    trange, prange = (-180., 180., 1.), (-180., 180., 1.)
    dtrange, dprange = (0., 180., 1.), (0., 180., 1.)

    def update(theta, phi, psi, dtheta, dphi, dpsi):
        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))

    def slider(name, limits, init=0.):
        return widgets.FloatSlider(
            value=init,
            min=limits[0],
            max=limits[1],
            step=limits[2],
            description=name,
            disabled=False,
            #continuous_update=True,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            )
    theta = slider(u'Ξ', trange, view[0])
    phi = slider(u'φ', prange, view[1])
    psi = slider(u'ψ', prange, view[2])
    dtheta = slider(u'Δξ', dtrange, jitter[0])
    dphi = slider(u'Δφ', dprange, jitter[1])
    dpsi = slider(u'Δψ', dprange, jitter[2])
    fields = {
        'theta': theta, 'phi': phi, 'psi': psi,
        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi,
    }
    ui = widgets.HBox([
        widgets.VBox([theta, phi, psi]),
        widgets.VBox([dtheta, dphi, dpsi])
    ])

    out = widgets.interactive_output(update, fields)
    display(ui, out)
1308
1309
# Plotting backends selectable with set_plotter(), keyed by name.
_ENGINES = {
    "matplotlib": mpl_plot,
    "mpl": mpl_plot,
    "ipyvolume": ipv_plot,
    "ipvolume": ipv_plot,  # misspelled key kept for backward compatibility
    "ipv": ipv_plot,
}
# Backend used by run(); matplotlib unless changed with set_plotter().
PLOT_ENGINE = _ENGINES["matplotlib"]
def set_plotter(name):
    """
    Select the plotting backend used by :func:`run`.

    *name* is a key of _ENGINES, such as "matplotlib", "mpl" or "ipv".
    An unknown name raises KeyError and leaves the current backend in place.
    """
    global PLOT_ENGINE
    engine = _ENGINES[name]
    PLOT_ENGINE = engine
1321
def main():
    """
    Command line entry point: parse the options and call :func:`run`.

    Exits with an argparse usage error if --size, --view or --jitter is not
    a list of three comma-separated numbers.
    """
    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--projection', choices=PROJECTIONS,
                        default=PROJECTIONS[0],
                        help='coordinate projection')
    parser.add_argument('-s', '--size', type=str, default='10,40,100',
                        help='a,b,c lengths')
    parser.add_argument('-v', '--view', type=str, default='0,0,0',
                        help='initial view angles')
    parser.add_argument('-j', '--jitter', type=str, default='0,0,0',
                        help='initial angular dispersion')
    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS,
                        default=DISTRIBUTIONS[0],
                        help='jitter distribution')
    parser.add_argument('-m', '--mesh', type=int, default=30,
                        help='#points in theta-phi mesh')
    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0],
                        help='oriented shape')
    opts = parser.parse_args()

    def parse_triple(text, option):
        """Parse 'x,y,z' into three floats, or exit with a usage error."""
        try:
            values = tuple(float(v) for v in text.split(','))
        except ValueError:
            values = None
        # every consumer unpacks exactly three values, so check that here
        if values is None or len(values) != 3:
            parser.error("%s expects three comma-separated numbers, got %r"
                         % (option, text))
        return values

    size = parse_triple(opts.size, '--size')
    view = parse_triple(opts.view, '--view')
    jitter = parse_triple(opts.jitter, '--jitter')
    run(opts.shape, size=size, view=view, jitter=jitter,
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)
1350
if __name__ == "__main__":
    # Start the command line interface when run as a script, e.g.
    # ``python -m sasmodels.jitter``.
    main()
Note: See TracBrowser for help on using the repository browser.