source: sasmodels/sasmodels/jitter.py @ cff2939

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since cff2939 was cff2939, checked in by Paul Kienzle <pkienzle@…>, 6 months ago

move extra made available for use of unicode greek letters into slider width

  • Property mode set to 100755
File size: 50.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Jitter Explorer
5===============
6
7Application to explore orientation angle and angular dispersity.
8
9From the command line::
10
11    # Show docs
12    python -m sasmodels.jitter --help
13
14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
16
17From a jupyter cell::
18
19    import ipyvolume as ipv
20    from sasmodels import jitter
21    import importlib; importlib.reload(jitter)
22    jitter.set_plotter("ipv")
23
24    size = (10, 40, 100)
25    view = (20, 0, 0)
26
27    #size = (15, 15, 100)
28    #view = (60, 60, 0)
29
30    dview = (0, 0, 0)
31    #dview = (5, 5, 0)
32    #dview = (15, 180, 0)
33    #dview = (180, 15, 0)
34
35    projection = 'equirectangular'
36    #projection = 'azimuthal_equidistance'
37    #projection = 'guyou'
38    #projection = 'sinusoidal'
39    #projection = 'azimuthal_equal_area'
40
41    dist = 'uniform'
42    #dist = 'gaussian'
43
44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
46    #ipv.savefig(filename+'.png')
47"""
48from __future__ import division, print_function
49
50import argparse
51
52import numpy as np
53from numpy import pi, cos, sin, sqrt, exp, degrees, radians
54
def draw_beam(axes, view=(0, 0), alpha=0.5, steps=25):
    """
    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1)

    The beam is rendered on *axes* as a thin yellow cylinder of radius 0.02
    running along z from -1.3 to +1.3, rotated by *view* = (theta, phi)
    using Rz(phi)*Ry(theta).  *steps* is the number of points around the
    circumference and *alpha* is the surface transparency.
    """
    beam_radius = 0.02
    half_length = 1.3

    # Angle around the cylinder, and the two ends of the cylinder.
    around = np.linspace(0, 2 * np.pi, steps)
    ends = np.linspace(-1, 1, 2)

    x = beam_radius*np.outer(np.cos(around), np.ones_like(ends))
    y = beam_radius*np.outer(np.sin(around), np.ones_like(ends))
    z = half_length*np.outer(np.ones_like(around), ends)

    # Flatten the grid into a 3 x n matrix, rotate it, then restore the grid.
    theta, phi = view
    grid_shape = x.shape
    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    points = Rz(phi)*Ry(theta)*points
    x, y, z = [row.reshape(grid_shape) for row in points]
    axes.plot_surface(x, y, z, color='yellow', alpha=alpha)

    # TODO: draw endcaps on beam
    ## Drawing tiny balls on the end will work
    #draw_sphere(axes, radius=0.02, center=(0, 0, 1.3), color='yellow', alpha=alpha)
    #draw_sphere(axes, radius=0.02, center=(0, 0, -1.3), color='yellow', alpha=alpha)
    ## Drawing the end circles with plot_trisurf was tried and does not work.
87
88
def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid.

    *size* gives the (a, b, c) semi-axes along x, y and z.  The surface is
    sampled on a *steps* x *steps* mesh, sent through the *jitter* and *view*
    transforms, and drawn on *axes* in white with transparency *alpha*.
    Labels are then placed at the ends of each axis.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)
    cos_polar = np.cos(polar)

    x = a*np.outer(np.cos(azimuth), sin_polar)
    y = b*np.outer(np.sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), cos_polar)
    x, y, z = transform_xyz(view, jitter, x, y, z)

    axes.plot_surface(x, y, z, color='w', alpha=alpha)

    # Each entry is (label, location, orientation).
    axis_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, axis_labels)
109
def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for simple cubic paracrystal

    *steps* and *alpha* are accepted so the signature matches the other
    shape drawing functions; they are not used here.
    """
    _draw_crystal(axes, size, view, jitter, atoms=_build_sc())
114
def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for face-centered cubic paracrystal

    The lattice is the simple cubic lattice from :func:`_build_sc` plus the
    centre of every cell face.  *steps* and *alpha* are accepted so the
    signature matches the other shape drawing functions; they are not used.
    """
    # Build the simple cubic crystal
    atoms = _build_sc()
    # Define the centers for each face
    # x planes at -1, 0, 1 have four centers per plane, at +/- 0.5 in y and z
    # All three lists must have the same length (12) so that the cyclic
    # concatenation below stays aligned.
    x, y, z = (
        [-1]*4 + [0]*4 + [1]*4,
        ([-0.5]*2 + [0.5]*2)*3,
        [-0.5, 0.5]*6,
    )
    # y and z planes can be generated by substituting x for y and z respectively
    atoms.extend(zip(x+y+z, y+z+x, z+x+y))
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
129
def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for body-centered cubic paracrystal

    The lattice is the simple cubic lattice from :func:`_build_sc` plus the
    centre of each octant.  *steps* and *alpha* are accepted so the signature
    matches the other shape drawing functions; they are not used.
    """
    # Start from the simple cubic crystal.
    atoms = _build_sc()
    # One centre per octant: every combination of +/- 0.5 in x, y and z,
    # with x varying slowest and z fastest.
    half = (-0.5, 0.5)
    centers = [(x, y, z) for x in half for y in half for z in half]
    atoms.extend(centers)
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
143
def _draw_crystal(axes, size, view, jitter, atoms=None):
    """
    Scatter-plot the lattice points *atoms* scaled by *size*.

    *atoms* is a sequence of (x, y, z) lattice positions.  The first point
    is drawn in yellow and the rest in red.
    """
    lattice = np.asarray(atoms, 'd').T
    scale = np.asarray(size, 'd')
    # Scale each coordinate row by the matching entry of size.
    x, y, z = lattice*scale[:, None]
    x, y, z = transform_xyz(view, jitter, x, y, z)
    # Highlighted point first, then everything else.
    axes.scatter([x[0]], [y[0]], [z[0]], c='yellow', marker='o')
    axes.scatter(x[1:], y[1:], z[1:], c='r', marker='o')
150
151def _build_sc():
152    # three planes of 9 dots for x at -1, 0 and 1
153    x, y, z = (
154        [-1]*9 + [0]*9 + [1]*9,
155        ([-1]*3 + [0]*3 + [1]*3)*3,
156        [-1, 0, 1]*9,
157    )
158    atoms = list(zip(x, y, z))
159    #print(list(enumerate(atoms)))
160    # Pull the dot at (0, 0, 1) to the front of the list
161    # It will be highlighted in the view
162    index = 14
163    highlight = atoms[index]
164    del atoms[index]
165    atoms.insert(0, highlight)
166    return atoms
167
def draw_box(axes, size, view):
    """
    Draw six black line segments between corners of a box of half-widths *size*.

    The corners are sent through the *view* transform with no jitter.  Three
    segments start at corner 0 and three at corner 7.
    """
    a, b, c = size
    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    x, y, z = transform_xyz(view, None, x, y, z)

    # NOTE(review): (0, 3) and (7, 4) join corners that differ in both x and
    # y, so they are face diagonals rather than box edges -- confirm intent.
    segments = [(0, 1), (0, 2), (0, 3), (7, 4), (7, 5), (7, 6)]
    for i, j in segments:
        axes.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], color='black')
182
def draw_parallelepiped(axes, size, view, jitter, steps=None,
                        color=(0.6, 1.0, 0.6), alpha=1):
    """
    Draw a parallelepiped.

    *size* gives the (a, b, c) half-widths along x, y and z.  The eight
    corners are sent through the *jitter* and *view* transforms and the faces
    are drawn on *axes* as a triangulated surface in the given *color* and
    *alpha*.  Labels are then placed at the centre of each face.  *steps* is
    accepted for signature compatibility with the other shapes and is unused.
    """
    a, b, c = size
    # Sign pattern of each of the eight corners.
    sign_x = np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    sign_y = np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    sign_z = np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    tri = np.array([
        # counter clockwise triangles
        # z: up/down, x: right/left, y: front/back
        [0, 1, 2], [3, 2, 1], # top face
        [6, 5, 4], [5, 6, 7], # bottom face
        [0, 2, 6], [6, 4, 0], # right face
        [1, 5, 7], [7, 3, 1], # left face
        [2, 3, 6], [7, 6, 3], # front face
        [4, 1, 0], [5, 1, 4], # back face
    ])

    x, y, z = transform_xyz(view, jitter, a*sign_x, b*sign_y, c*sign_z)
    axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha,
                      linewidth=0)

    # Colour the c+ face of the box.
    # Since face color can't be controlled, the idea is to draw a thin box
    # situated just in front of the "c+" face.  Use the c face so that
    # rotations about psi rotate that face.  Currently disabled.
    if 0:
        color = (1, 0.6, 0.6)  # pink
        x, y, z = a*sign_x, b*sign_y, c*sign_z
        x, y, z = transform_xyz(view, jitter, x, y, abs(z)+0.001)
        axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha)

    # Each entry is (label, location, orientation).
    face_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, face_labels)
225
def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w', alpha=1.):
    """
    Draw a sphere

    The surface of the sphere of the given *radius* about *center* is sampled
    on a *steps* x *steps* mesh and drawn on *axes* with ``plot_surface``
    using *color* and *alpha*.
    """
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)
    cos_polar = np.cos(polar)

    x = radius * np.outer(np.cos(azimuth), sin_polar) + center[0]
    y = radius * np.outer(np.sin(azimuth), sin_polar) + center[1]
    z = radius * np.outer(np.ones(np.size(azimuth)), cos_polar) + center[2]
    axes.plot_surface(x, y, z, color=color, alpha=alpha)
236
def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
    """
    Draw three black lines from *origin*, one along each of x, y and z.

    *length* gives the extent of the line drawn along each axis.
    """
    x0, y0, z0 = origin
    dx, dy, dz = length
    endpoints = (
        (x0 + dx, y0, z0),  # along x
        (x0, y0 + dy, z0),  # along y
        (x0, y0, z0 + dz),  # along z
    )
    for x1, y1, z1 in endpoints:
        axes.plot([x0, x1], [y0, y1], [z0, z1], color='black')
243
def draw_person_on_sphere(axes, view, height=0.5, radius=0.5):
    """
    Draw a stick figure of the given *height* standing at distance *radius*.

    The figure is drawn in the x-z plane as black line segments, raised by
    *radius* along z and sent through the *view* transform.  The axes limits
    are set to +/- (radius + height) and the axis decorations are turned off.
    """
    # Body proportions, all derived from the overall height.
    limb_offset = height * 0.05
    head_radius = height * 0.10
    head_height = height - head_radius
    neck_length = head_radius * 0.50
    shoulder_height = height - 2*head_radius - neck_length
    torso_length = shoulder_height * 0.55
    torso_radius = torso_length * 0.30
    leg_length = shoulder_height - torso_length
    arm_length = torso_length * 0.90

    def _stroke(px, pz):
        # Draw one polyline lying in the y=0 plane, lifted onto the sphere.
        py = np.zeros_like(px)
        tx, ty, tz = transform_xyz(view, None, px, py, pz + radius)
        axes.plot(tx, ty, tz, color='k')

    # circle for head
    angle = np.linspace(0, 2 * np.pi, 40)
    _stroke(head_radius * np.cos(angle), head_radius * np.sin(angle) + head_height)

    # rectangle for body
    body_x = np.array([-torso_radius, torso_radius, torso_radius, -torso_radius, -torso_radius])
    body_z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length
    _stroke(body_x, body_z)

    # arms: one side, then its mirror image in x
    arm_x = np.array([-torso_radius - limb_offset, -torso_radius - limb_offset, -torso_radius])
    arm_z = np.array([shoulder_height - arm_length, shoulder_height, shoulder_height])
    _stroke(arm_x, arm_z)
    _stroke(-arm_x, arm_z)

    # legs: one side, then its mirror image in x
    leg_x = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset])
    leg_z = np.array([0, leg_length])
    _stroke(leg_x, leg_z)
    _stroke(-leg_x, leg_z)

    limits = [-radius-height, radius+height]
    axes.set_xlim(limits)
    axes.set_ylim(limits)
    axes.set_zlim(limits)
    axes.set_axis_off()
288
def draw_jitter(axes, view, jitter, dist='gaussian',
                size=(0.1, 0.4, 1.0),
                draw_shape=draw_parallelepiped,
                projection='equirectangular',
                alpha=0.8,
                views=None):
    """
    Represent jitter as a set of shapes at different orientations.

    *jitter* is (dtheta, dphi, dpsi).  For each angle with a positive width
    a set of offsets is generated spanning +/- width scaled by a factor that
    depends on *dist* (3 for gaussian, sqrt(3) for rectangle, 1 for uniform).
    *views* fixes the number of offsets per angle; if None the number is
    chosen from the width, between 3 and 25.  One copy of the shape is drawn
    with *draw_shape* for every combination whose projection weight is
    positive.  *size* is rescaled so that its diagonal is 0.95.
    """
    project, project_weight = get_projection(projection)

    # set max diagonal to 0.95
    diagonal = sqrt(sum(v**2 for v in size))
    scale = 0.95/diagonal
    size = tuple(scale*v for v in size)

    dtheta, dphi, dpsi = jitter
    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]

    def steps(delta):
        if views is None:
            n = max(3, min(25, 2*int(base*delta/5)))
        else:
            n = views
        if delta > 0:
            return base*delta*np.linspace(-1, 1, n)
        return [0.]

    theta_steps = steps(dtheta)
    phi_steps = steps(dphi)
    psi_steps = steps(dpsi)
    for theta in theta_steps:
        for phi in phi_steps:
            # The weight depends only on theta and phi.
            if not project_weight(theta, phi, 1.0, 1.0) > 0:
                continue
            for psi in psi_steps:
                dview = project(theta, phi, psi)
                draw_shape(axes, size, view, dview, alpha=alpha)

    # Use the same limits, the length of the shape diagonal, on all axes.
    a, b, c = size
    lim = np.sqrt(a**2 + b**2 + c**2)
    axes.set_xlim([-lim, lim])
    axes.set_ylim([-lim, lim])
    axes.set_zlim([-lim, lim])
324
PROJECTIONS = [
    # Names accepted by get_projection(), listed in order of PROJECTION
    # number; do not change the order without updating the matching
    # constants in kernel_iq.c.
    'equirectangular', 'sinusoidal', 'guyou', 'azimuthal_equidistance',
    'azimuthal_equal_area',
]
def get_projection(projection):
    """
    Return the pair of functions (project, weight) for the named projection.

    *project(theta_i, phi_j, psi)* maps a mesh point to the jitter angles
    (latitude, longitude, psi); the azimuthal projections append the string
    'zyz' as a fourth element to flag that the angles are for Rz Ry Rz
    rather than Rx Ry Rz.  *weight(theta_i, phi_j, w_i, w_j)* returns the
    weight of the mesh point given the weights along each dimension.  All
    angles are in degrees.  Raises *ValueError* for an unknown projection.

    jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangle.  Note that the poles
        will be slightly overweighted for theta > 90 with the circle
        from theta at 90+dt winding backwards around the pole, overlapping
        the circle from theta at 90-dt.
    Guyou (hemisphere-in-a-square) **not weighted**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        With tiling, allows rotation in phi or theta through +/- 180, with
        uniform spacing.  Both theta and phi allow free rotation, with wobble
        in the orthogonal direction reasonably well behaved (though not as
        good as equirectangular phi). The forward/reverse transformations
        rely on elliptic integrals that are somewhat expensive, so the
        behaviour has to be very good to justify the cost and complexity.
        The weighting function for each point has not yet been computed.
        Note: run the module *guyou.py* directly and it will show the forward
        and reverse mappings.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not that
        useful and not completely implemented
    Gauss-Kreuger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    # TODO: try Kent distribution instead of a gaussian warped by projection

    if projection == 'equirectangular':  #define PROJECTION 1
        def _project(theta_i, phi_j, psi):
            latitude, longitude = theta_i, phi_j
            return latitude, longitude, psi
            #return Rx(phi_j)*Ry(theta_i)
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j*abs(cos(radians(theta_i)))
    elif projection == 'sinusoidal':  #define PROJECTION 2
        def _project(theta_i, phi_j, psi):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Points outside the sinusoid are sent to longitude 0; they are
            # given zero weight by _weight below.
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi
            #return Rx(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Only points lying inside the sinusoid contribute.
            active = 1 if abs(phi_j) < abs(scale)*180 else 0
            return active*w_i*w_j
    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
        def _project(theta_i, phi_j, psi):
            from .guyou import guyou_invert
            #latitude, longitude = guyou_invert([theta_i], [phi_j])
            longitude, latitude = guyou_invert([phi_j], [theta_i])
            return latitude, longitude, psi
            #return Rx(longitude[0])*Ry(latitude[0])
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j
    elif projection == 'azimuthal_equidistance':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi-longitude, 'zyz'
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # We are doing a conformal mapping from disk to sphere, so we need
            # a change of variables g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dtheta dphi = det(J) dlat dlong
            # where J is the jacobian from the partials of g. Using
            #     R = sqrt(theta^2 + phi^2),
            # and writing x, y for theta, phi,
            #     J = [[x/R, y/R], [-y/R^2, x/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*np.arccos(radius))
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi, "zyz"
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    return _project, _weight
479
def R_to_xyz(R):
    """
    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix.

    *R* is indexed as ``R[row, col]``.  The angles are returned in degrees.

    Extracting Euler Angles from a Rotation Matrix
    Mike Day, Insomniac Games
    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
    Based on: Shoemake's "Euler Angle Conversion", Graphics Gems IV, pp.  222-229
    """
    r00, r01, r02 = R[0, 0], R[0, 1], R[0, 2]
    phi = np.arctan2(R[1, 2], R[2, 2])
    theta = np.arctan2(-r02, np.sqrt(r00**2 + r01**2))
    psi = np.arctan2(r01, r00)
    return tuple(np.degrees(angle) for angle in (phi, theta, psi))
493
def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    An *n* x *n* mesh in theta-phi is built from *jitter* and *dist*, each
    mesh point is used to rotate the point (0, 0, radius) according to
    *projection*, and the points with positive weight are drawn on *axes*
    as a scatter plot coloured by weight relative to the maximum weight.
    Raises *ValueError* if *dist* is not gaussian, rectangle or uniform.
    """
    _project, _weight = get_projection(projection)

    def _rotate(theta, phi, z):
        dview = _project(theta, phi, 0.)
        # hack for zyz coords: a fourth element means the first rotation
        # is about z rather than x
        first = Rz if len(dview) == 4 else Rx
        return first(dview[1])*Ry(dview[0])*z

    # Offsets and weights along one dimension of the mesh.
    dist_x = np.linspace(-1, 1, n)
    weights = np.ones_like(dist_x)
    if dist == 'gaussian':
        dist_x = dist_x*3
        weights = exp(-0.5*dist_x**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        dist_x = dist_x*sqrt(3)
    elif dist != 'uniform':
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # mesh in theta, phi formed by rotating z
    dtheta, dphi, dpsi = jitter
    theta_mesh = dtheta*dist_x
    phi_mesh = dphi*dist_x
    z = np.matrix([[0], [0], [radius]])
    points = np.hstack([_rotate(theta_i, phi_j, z)
                        for theta_i in theta_mesh
                        for phi_j in phi_mesh])
    dist_w = np.array([_weight(theta_i, phi_j, w_i, w_j)
                       for w_i, theta_i in zip(weights, theta_mesh)
                       for w_j, phi_j in zip(weights, phi_mesh)])

    # Keep only the points that contribute, and normalize to the peak weight.
    keep = dist_w > 0
    points = points[:, keep]
    dist_w = dist_w[keep]
    dist_w /= max(dist_w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).flatten() for row in points]
    axes.scatter(x, y, z, c=dist_w, marker='o', vmin=0., vmax=1.)
543
def draw_labels(axes, view, jitter, text):
    """
    Draw text at a particular location.

    *text* is a sequence of (label, location, orientation) entries, where
    location and orientation are (x, y, z) triples.  Both are sent through
    the *jitter* and *view* transforms before the label is drawn.
    """
    labels, locations, orientations = zip(*text)
    px, py, pz = zip(*locations)
    dx, dy, dz = zip(*orientations)

    px, py, pz = transform_xyz(view, jitter, px, py, pz)
    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for k, label in enumerate(labels):
        zdir = np.asarray((dx[k], dy[k], dz[k])).flatten()
        axes.text(px[k], py[k], pz[k], label, zdir=zdir)
559
560# Definition of rotation matrices comes from wikipedia:
561#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """
    Construct a matrix to rotate points about *x* by *angle* degrees.

    Returns a 3 x 3 numpy matrix.
    """
    angle = radians(angle)
    c, s = cos(angle), sin(angle)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
569
def Ry(angle):
    """
    Construct a matrix to rotate points about *y* by *angle* degrees.

    Returns a 3 x 3 numpy matrix.
    """
    angle = radians(angle)
    c, s = cos(angle), sin(angle)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
577
def Rz(angle):
    """
    Construct a matrix to rotate points about *z* by *angle* degrees.

    Returns a 3 x 3 numpy matrix.
    """
    angle = radians(angle)
    c, s = cos(angle), sin(angle)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
585
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    *x*, *y* and *z* are array-like.  The returned x, y, z arrays have the
    same shape as the input *x*.
    """
    coords = [np.asarray(v) for v in (x, y, z)]
    shape = coords[0].shape
    # Pack the coordinates as the rows of a 3 x n matrix for the rotations.
    points = np.matrix([c.flatten() for c in coords])
    points = orient_relative_to_beam(view, apply_jitter(jitter, points))
    xt, yt, zt = (np.array(row).reshape(shape) for row in points)
    return xt, yt, zt
597
def apply_jitter(jitter, points):
    """
    Apply the jitter transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    *jitter* is None (points are returned unchanged), a triple
    (dtheta, dphi, dpsi) applied as Rx(dphi)*Ry(dtheta)*Rz(dpsi), or a
    4-element sequence whose first three entries are applied as
    Rz(dphi)*Ry(dtheta)*Rz(dpsi).
    """
    if jitter is None:
        return points
    # Hack to deal with the fact that azimuthal_equidistance uses euler angles
    if len(jitter) == 4:
        dtheta, dphi, dpsi, _ = jitter
        outer = Rz(dphi)
    else:
        dtheta, dphi, dpsi = jitter
        outer = Rx(dphi)
    return outer*Ry(dtheta)*Rz(dpsi)*points
614
def orient_relative_to_beam(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    *view* is (theta, phi, psi) in degrees, applied as
    Rz(phi)*Ry(theta)*Rz(psi).
    """
    theta, phi, psi = view
    # viewing angle
    rotation = Rz(phi)*Ry(theta)*Rz(psi)
    # Alternatives that were considered:
    #rotation = Rz(psi)*Ry(pi/2-theta)*Rz(phi) # 1-D integration angles
    #rotation = Rx(phi)*Ry(theta)*Rz(psi)  # angular dispersion angle
    return rotation*points
626
def orient_relative_to_beam_quaternion(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    This variant uses quaternions rather than rotation matrices for the
    computation.  It works but it is not used because it doesn't solve
    any problems.  The challenge of mapping theta/phi/psi to SO(3) does
    not disappear by calculating the transform differently.
    """
    theta, phi, psi = view
    # theta rotates about x, phi about y and psi about z.
    rotations = (
        (theta, [1, 0, 0]),
        (phi, [0, 1, 0]),
        (psi, [0, 0, 1]),
    )
    # Start from the identity and post-multiply each rotation in turn.
    # Composing by rotating the unit vectors before applying each rotation
    # and pre-multiplying turns out to be equivalent to reversing the order
    # of application, so the simpler post-multiply form is used.
    q = Quaternion(1, [0, 0, 0])
    for angle, axis in rotations:
        q = q * Quaternion.from_angle_axis(angle, axis)
    return q.rot(points)
657#orient_relative_to_beam = orient_relative_to_beam_quaternion
658
659# Simple stand-alone quaternion class
660import numpy as np
661from copy import copy
class Quaternion(object):
    """
    Minimal quaternion *w + r[0] i + r[1] j + r[2] k* for rotating points.

    *w* is the scalar part and *r* is the vector part, stored as a
    length 3 array of doubles.
    """
    def __init__(self, w, r):
        self.w = w
        self.r = np.asarray(r, dtype='d')

    @staticmethod
    def from_angle_axis(theta, r):
        """
        Return the unit quaternion for a rotation of *theta* degrees about *r*.

        The axis *r* does not need to be normalized.
        """
        half_angle = np.radians(theta)/2
        r = np.asarray(r, dtype='d')
        # Scale the axis to unit length (divide by |r|, not |r|^2) so that
        # the result is a unit quaternion whatever the length of the axis.
        axis = r/np.sqrt(np.dot(r, r))
        return Quaternion(np.cos(half_angle), np.sin(half_angle)*axis)

    def __mul__(self, other):
        """Return the Hamilton product self*other of two quaternions."""
        if not isinstance(other, Quaternion):
            # Let python raise TypeError rather than silently returning None.
            return NotImplemented
        w = self.w*other.w - np.dot(self.r, other.r)
        r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
        return Quaternion(w, r)

    def rot(self, v):
        """
        Rotate the point or points *v* by this quaternion.

        *v* is a single point of length 3, or a set of points stored as
        3 x n or n x 3.  The result has the same layout as the input.  The
        formula assumes a unit quaternion.
        """
        v = np.asarray(v).T
        # Work with the x,y,z components along the last axis.
        use_transpose = (v.shape[-1] != 3)
        if use_transpose:
            v = v.T
        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
        #v = v + 2*self.w*np.cross(self.r, v) + np.cross(2*self.r, np.cross(self.r, v))
        if use_transpose:
            v = v.T
        return v.T

    def conj(self):
        """Return the conjugate quaternion."""
        return Quaternion(self.w, -self.r)

    def inv(self):
        """Return the inverse quaternion, conj(q)/norm(q)**2."""
        norm_sq = self.w**2 + np.sum(self.r**2)
        return Quaternion(self.w/norm_sq, -self.r/norm_sq)

    def norm(self):
        """Return the length of the quaternion."""
        return np.sqrt(self.w**2 + np.sum(self.r**2))

    def __str__(self):
        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
def test_qrot():
    """Check quaternion rotation of a point against a known result."""
    # The rotation is 60 degrees around an axis in y-z that is 60 degrees
    # from y.  The axis is found by rotating the point [0, 1, 0] about x.
    axis = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0])
    rotation = Quaternion.from_angle_axis(60, axis)
    # The point to be rotated, and its expected rotated position.
    point = [1, -1, 2]
    expected = [(10+4*np.sqrt(3))/8, (1+2*np.sqrt(3))/8, (14-3*np.sqrt(3))/8]
    error = rotation.rot(point) - expected
    assert max(abs(error)) < 1e-14
704#test_qrot()
705#import sys; sys.exit()
706
# Translate between which dimensions have dispersity and the number of
# points to use along each dimension.  The key is a triple of flags
# (theta, phi, psi), one per orientation angle, and the value is the
# matching triple of point counts.  The trailing comments give the total
# number of mesh points for each group of entries.
PD_N_TABLE = {
    (0, 0, 0): (0, 0, 0),     # 0
    (1, 0, 0): (100, 0, 0),   # 100
    (0, 1, 0): (0, 100, 0),
    (0, 0, 1): (0, 0, 100),
    (1, 1, 0): (30, 30, 0),   # 900
    (1, 0, 1): (30, 0, 30),
    (0, 1, 1): (0, 30, 30),
    (1, 1, 1): (15, 15, 15),  # 3375
}
719
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine range from data.

    If *portion* is 1, use full range, otherwise use the center of the range
    or the top of the range, depending on whether *mode* is 'central' or 'top'.

    *data* is a numpy array.  Returns the pair (low, high).  For 'central'
    the sorted data is clipped by *portion*/2 of the points from each end,
    and for 'top' by *portion* of the points from the low end.  Raises
    *ValueError* if *portion* is not 1 and *mode* is not recognized.
    """
    # NOTE(review): with portion just below 1 nearly all of the data is
    # clipped, whereas portion == 1 keeps everything.  This suggests the
    # offsets were meant to use (1 - portion) -- confirm before relying on it.
    if portion == 1.0:
        return data.min(), data.max()
    if mode not in ('central', 'top'):
        raise ValueError("expected mode to be central or top")

    data = np.sort(data.flatten())
    last = len(data) - 1
    if mode == 'central':
        offset = max(0, min(int(portion*len(data)/2 + 0.5), last))
        # data[-0] is data[0], so a zero offset needs the last point instead.
        upper = data[-offset] if offset > 0 else data[-1]
        return data[offset], upper

    # mode == 'top'; keep the offset inside the array when portion is near 1.
    offset = max(0, min(int(portion*len(data) + 0.5), last))
    return data[offset], data[-1]
737
def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity, each as (theta, phi, psi).
    *dist* is one of the sasmodels weight distributions, or 'uniform'.

    The log of the intensity is drawn as filled contours on the plane
    z = -1.1.  The colour range is *calculator.limits* if that is set;
    otherwise it runs from seven decades below the maximum of the current
    pattern up to that maximum.
    """
    if dist == 'uniform':  # uniform is not yet in this branch
        # Emulate it with a rectangle distribution of rescaled width.
        dist, scale = 'rectangle', 1/sqrt(3)
    else:
        scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    # The non-orientation parameters stored on the calculator take precedence.
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    Iqxy = calculator(**pars).reshape(len(qx), len(qy))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        vmax = Iqxy.max()
        # Iqxy is on a natural log scale, so seven decades below the peak
        # is a subtraction rather than a product.  Multiplying by 1e-7 gave
        # vmin > vmax whenever the peak of log(I) was negative, which
        # contourf rejects because the levels must be increasing.
        vmin = vmax - 7*np.log(10)

    axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                  levels=np.linspace(vmin, vmax, 24))
798
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model, covering [-qmax, qmax] in both qx and qy.
    The remaining parameters are stored in the returned calculator as
    *calculator.pars*.  They are used by :func:`draw_scattering` to set the
    non-orientation parameters in the calculation.  A *background* of 1e-3
    is added to them unless the caller provides one.

    Returns a *calculator* function which takes a dictionary of parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.  The colour limits computed for orientation (0, 0, 0) are
    stored as *calculator.limits*.
    """
    from sasmodels.core import load_model_info, build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # Note: this key was previously misspelled 'backgound', so the default
    # was stored under a name that is not the 'background' parameter.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
836
def select_calculator(model_name, n=150, size=(10, 40, 100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or sc/fcc/bcc_paracrystal.  *n* is the number of points
    to use in the q range.  *size* is the (a, b, c) triple of dimensions from
    which the model parameters are taken.  *qmax* is chosen for each model
    to show something interesting.

    Returns *calculator* and tuple *size* (a, b, c) giving minor and major
    equatorial axes and polar axis respectively, after forcing equal any
    axes that the model does not distinguish.  See :func:`build_model`
    for details on the returned calculator.

    Raises *ValueError* if *model_name* is not one of the shapes above.
    """
    a, b, c = size
    paracrystals = ('sc_paracrystal', 'fcc_paracrystal', 'bcc_paracrystal')

    if model_name in paracrystals:
        if model_name == 'sc_paracrystal':
            spacing, sphere_radius = c, 0.5*c
        elif model_name == 'fcc_paracrystal':
            # nearest neighbour distance dnn should be 2 radius, but I think
            # the model uses lattice spacing rather than dnn in its
            # calculations
            spacing, sphere_radius = 0.5*c, sqrt(2)/4 * c
        else:  # bcc_paracrystal
            # same caveat about dnn versus lattice spacing as for fcc
            spacing, sphere_radius = 0.5*c, sqrt(3)/2 * c
        distortion = 0.06  # d_factor for paracrystal models
        calculator = build_model(model_name, n=n, dnn=spacing,
                                 d_factor=distortion,
                                 radius=(1-distortion)*sphere_radius,
                                 background=0)
        a = b = c
    elif model_name == 'sphere':
        calculator = build_model('sphere', n=n, radius=c)
        a = b = c
    elif model_name == 'cylinder':
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
        a = b
    elif model_name == 'ellipsoid':
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
        a = b
    elif model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
    elif model_name == 'parallelepiped':
        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    return calculator, (a, b, c)
898
# Model names offered as choices for the 'shape' argument on the command
# line.  The first entry is the command-line default.
SHAPES = [
    'parallelepiped',
    'sphere', 'ellipsoid', 'triaxial_ellipsoid',
    'cylinder',
    'fcc_paracrystal', 'bcc_paracrystal', 'sc_paracrystal',
]

# Function used to draw the shape for each model.  Models that are not
# listed here are drawn with draw_parallelepiped (see run()).
DRAW_SHAPES = {
    'fcc_paracrystal': draw_fcc,
    'bcc_paracrystal': draw_bcc,
    'sc_paracrystal': draw_sc,
    'parallelepiped': draw_parallelepiped,
}

# Jitter distributions offered as choices on the command line.  The first
# entry is the command-line default.
DISTRIBUTIONS = [
    'gaussian', 'rectangle', 'uniform',
]
# Width limit for each distribution.  mpl_plot() runs its jitter sliders from
# zero to twice this value.  The rectangle limit is divided by sqrt(3) because
# the rectangle width in sasmodels is sqrt(3) times the given width.
DIST_LIMITS = {
    'gaussian': 30,
    'rectangle': 90/sqrt(3),
    'uniform': 90,
}
921
922
def run(model_name='parallelepiped', size=(10, 40, 100),
        view=(0, 0, 0), jitter=(0, 0, 0),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal

    *size* gives the dimensions (a, b, c) of the shape.

    *view* gives the initial view (theta, phi, psi) of the shape.

    *jitter* gives the initial jitter (dtheta, dphi, dpsi) of the shape.

    *dist* is the type of dispersion: gaussian, rectangle, or uniform.

    *mesh* is the number of points in the dispersion mesh.

    *projection* is the map projection to use for the mesh: equirectangular,
    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.

    The demo is displayed by the current plot engine (see *PLOT_ENGINE*).
    As a side effect, *sasmodels.generate.PROJECTION* is set to the number
    of the projection used for the scattering calculation.
    """
    from sasmodels import generate

    # Projections are numbered from one, in the order that they appear in
    # PROJECTIONS.  Only projections 1 and 2 are implemented in the scattering
    # function so far, so any later projection is computed with number 2.
    projection_id = 1 + PROJECTIONS.index(projection)
    if projection_id > 2:
        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
        projection_id = 2
    generate.PROJECTION = projection_id

    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)
    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped)

    # With the limits cleared, draw_scattering() picks a separate colour
    # range for every view.  Comment out this line to keep the colour range
    # fixed at the limits computed in build_model().
    calculator.limits = None

    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
964
def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the interactive jitter explorer in a matplotlib window.

    *calculator* and *size* are as returned by :func:`select_calculator`.
    *draw_shape* is the function used to draw the shape.  *view* and *jitter*
    are the initial (theta, phi, psi) orientation and dispersity.  *dist* is
    the jitter distribution, *mesh* the number of points in the dispersion
    mesh and *projection* the name of the map projection, which is also used
    as the window title.

    Sliders below the 3D axes control the view and jitter; the scene is
    redrawn each time one of them changes.  The call returns when
    ``plt.show()`` returns.
    """
    # Note: travis-ci does not support mpl_toolkits.mplot3d, but this shouldn't be
    # an issue since we are lazy-loading the package on a path that isn't tested.
    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    ## create the plot window
    plt.subplots(num=None, figsize=(5.5, 5.5))
    plt.set_cmap('gist_earth')
    plt.clf()
    # CRUFT: canvas.set_window_title is not available in recent matplotlib.
    # The figure manager provides the same method in old and new versions.
    manager = getattr(plt.gcf().canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(projection)
    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        axes.axis('square')
    except Exception:
        pass

    # CRUFT: use axisbg instead of facecolor for matplotlib<2
    # Compare the major version as a number; comparing version strings puts
    # '10' before '2'.
    mpl_major = int(mpl.__version__.split('.')[0])
    facecolor_prop = 'facecolor' if mpl_major >= 2 else 'axisbg'
    props = {facecolor_prop: 'lightgoldenrodyellow'}

    ## add control widgets to plot
    axes_theta = plt.axes([0.05, 0.15, 0.50, 0.04], **props)
    axes_phi = plt.axes([0.05, 0.10, 0.50, 0.04], **props)
    axes_psi = plt.axes([0.05, 0.05, 0.50, 0.04], **props)
    stheta = Slider(axes_theta, u'θ', -90, 90, valinit=0)
    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0)
    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0)

    axes_dtheta = plt.axes([0.70, 0.15, 0.20, 0.04], **props)
    axes_dphi = plt.axes([0.70, 0.1, 0.20, 0.04], **props)
    axes_dpsi = plt.axes([0.70, 0.05, 0.20, 0.04], **props)

    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axes_dtheta, u'Δθ', 0, 2*dlimit, valinit=0)
    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0)
    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0)

    ## initial view and jitter
    # These are set before the callbacks are bound, so they do not trigger
    # a redraw; the first draw is the explicit update() call below.
    theta, phi, psi = view
    stheta.set_val(theta)
    sphi.set_val(phi)
    spsi.set_val(psi)
    dtheta, dphi, dpsi = jitter
    sdtheta.set_val(dtheta)
    sdphi.set_val(dphi)
    sdpsi.set_val(dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        """Redraw the scene using the current slider values."""
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        # The cutoff is 0.5 with one dispersed angle, 5 with two or three.
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        axes.cla()

        ## Visualize as person on globe
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, -view[:2])
        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), alpha=1.)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.,
                    draw_shape=draw_shape, projection=projection, views=3)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
1061
1062
def map_colors(z, kw):
    """
    Resolve matplotlib-style colour keywords in *kw* into a 'color' entry.

    The keys 'cmap', 'alpha', 'vmin', 'vmax', 'c' and 'color' are removed
    from *kw* and replaced by a single 'color' entry:

    * with neither 'color' nor 'c', the array *z* is scaled to [0, 1] using
      'vmin' and 'vmax' (default: the range of *z*) and mapped through
      'cmap' (default: the matplotlib coolwarm map);
    * a 'color' array with the same shape as *z* is mapped through 'cmap'
      directly;
    * any other colour is used as given.

    Without 'alpha', array colours are trimmed to their first three (RGB)
    channels.  With 'alpha', the colour is converted to an RGBA array whose
    alpha channel is set to *alpha*; the caller's array is not modified.

    *kw* is updated in place and nothing is returned.
    """
    cmap = kw.pop('cmap', None)
    if cmap is None:
        # matplotlib is only needed when the caller does not supply a map
        from matplotlib import cm
        cmap = cm.coolwarm
    alpha = kw.pop('alpha', None)
    vmin = kw.pop('vmin', None)
    vmax = kw.pop('vmax', None)
    c = kw.pop('c', None)
    color = kw.pop('color', c)

    if color is None:
        if vmin is None:
            vmin = z.min()
        if vmax is None:
            vmax = z.max()
        span = vmax - vmin
        if span != 0:
            znorm = ((z - vmin) / span).clip(0, 1)
        else:
            # Constant data: avoid 0/0 and use the bottom of the colour map.
            znorm = np.zeros(z.shape)
        color = cmap(znorm)
    elif isinstance(color, np.ndarray) and color.shape == z.shape:
        color = cmap(color)

    if alpha is None:
        if isinstance(color, np.ndarray):
            color = color[..., :3]
    else:
        if isinstance(color, np.ndarray):
            # work on a copy so the array passed by the caller is untouched
            color = color.copy()
        else:
            # named colour or tuple: let matplotlib turn it into RGBA
            from matplotlib.colors import to_rgba
            color = np.array(to_rgba(color))
        if color.shape[-1] == 3:
            # RGB only: add an alpha channel before setting it
            opaque = np.ones(color.shape[:-1] + (1,))
            color = np.concatenate((color, opaque), axis=-1)
        color[..., 3] = alpha
    kw['color'] = color
1083
def make_vec(*args):
    """
    Return a list holding each argument as a double precision numpy array.

    The shape of each argument is preserved; arrays are not flattened.
    """
    return [np.asarray(value, dtype=np.float64) for value in args]
1087
def make_image(z, kw):
    """
    Convert the array *z* into an RGB image using a colour map.

    The values in *z* are scaled linearly so that the minimum maps to the
    bottom of the colour map and the maximum to the top.  The colour map is
    taken from the 'cmap' entry of *kw*, which is removed from *kw*; the
    default is the matplotlib coolwarm map.

    Returns a *PIL.Image* in 'RGB' mode.
    """
    import PIL.Image
    from matplotlib import cm

    cmap = kw.pop('cmap', cm.coolwarm)

    # Note: ndarray.ptp() is not available in numpy 2, so compute the span
    # of the data directly.
    zmin = z.min()
    span = z.max() - zmin
    if span != 0:
        znorm = (z - zmin)/span
    else:
        # Constant data: avoid 0/0 and use the bottom of the colour map.
        znorm = np.zeros(z.shape)
    c = cmap(znorm)
    c = c[..., :3]  # drop the alpha channel
    rgb = np.asarray(c*255, 'u1')
    image = PIL.Image.fromarray(rgb, mode='RGB')
    return image
1100
1101
# Translation from matplotlib marker codes to ipyvolume marker names.
# Markers which are not listed are passed through unchanged by the scatter
# method of the ipyvolume plotter.
_IPV_MARKERS = {
    'o': 'sphere',
}
# Translation from single letter colour codes to colour names.  Colour
# strings which are not listed are left unchanged by ipv_fix_color().
_IPV_COLORS = {
    'w': 'white',
    'k': 'black',
    'c': 'cyan',
    'm': 'magenta',
    'y': 'yellow',
    'r': 'red',
    'g': 'green',
    'b': 'blue',
}
def ipv_fix_color(kw):
    """
    Adjust the colour keywords in *kw* for use with ipyvolume.

    The 'alpha' entry is always removed from *kw*.  A 'color' entry which is
    a single letter colour code is replaced by the colour name from
    *_IPV_COLORS*.  If *alpha* was given, then:

    * a 'color' tuple or list (r, g, b) or (r, g, b, a) is replaced by the
      array [r, g, b, alpha];
    * a 'color' array with four channels is replaced by a copy in which
      the alpha channel is set to *alpha*;
    * any other colour, including colour names, is left unchanged and
      *alpha* is dropped.

    *kw* is updated in place and nothing is returned.
    """
    alpha = kw.pop('alpha', None)
    color = kw.get('color', None)
    if isinstance(color, str):
        color = _IPV_COLORS.get(color, color)
        kw['color'] = color
    if alpha is None or color is None:
        # Nothing to apply, or no colour to apply it to.
        return

    #TODO: convert named colors to [r, g, b, a] so that alpha can be applied
    if isinstance(color, (tuple, list)):
        # alpha replaces any alpha value already present in the colour
        kw['color'] = np.array(tuple(color[:3]) + (alpha,))
    elif isinstance(color, np.ndarray) and color.shape[-1] == 4:
        # work on a copy so the array passed by the caller is untouched
        color = color.copy()
        color[..., 3] = alpha
        kw['color'] = color
1133
def ipv_set_transparency(kw, obj):
    """
    Mark *obj* as transparent if the colour in *kw* is not fully opaque.

    The material of *obj* is only changed when the 'color' entry of *kw* is
    an array with four channels in which some alpha value differs from 1.
    """
    color = kw.get('color', None)
    if not isinstance(color, np.ndarray):
        return
    if color.shape[-1] != 4:
        return
    has_transparency = (color[..., 3] != 1.0).any()
    if has_transparency:
        obj.material.transparent = True
        obj.material.side = "FrontSide"
1141
def ipv_axes():
    """
    Return an object which draws on the current ipyvolume figure using the
    subset of the matplotlib 3D axes interface needed by the drawing code.
    """
    import ipyvolume as ipv

    class Plotter:
        """
        Adaptor from matplotlib axes method calls to ipyvolume calls.
        """
        # transparency can be achieved by setting the following:
        #    mesh.color = [r, g, b, alpha]
        #    mesh.material.transparent = True
        #    mesh.material.side = "FrontSide"
        # smooth(ish) rotation can be achieved by setting:
        #    slide.continuous_update = True
        #    figure.animation = 0.
        #    mesh.material.x = x
        #    mesh.material.y = y
        #    mesh.material.z = z
        # maybe need to synchronize update of x/y/z to avoid shimmy when moving
        def plot(self, x, y, z, **kw):
            """Draw a line through the points (x, y, z)."""
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            return ipv.plot(x, y, z, **kw)
        def plot_surface(self, x, y, z, **kw):
            """Draw a surface; *facecolors* is treated as *color*."""
            facecolors = kw.pop('facecolors', None)
            if facecolors is not None:
                kw['color'] = facecolors
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            h = ipv.plot_surface(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            #h.material.side = "DoubleSide"
            return h
        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
            """Draw a triangulated surface; *linewidth* is ignored."""
            kw.pop('linewidth', None)
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, Z)
            if triangles is not None:
                triangles = np.asarray(triangles)
            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
            ipv_set_transparency(kw, h)
            return h
        def scatter(self, x, y, z, **kw):
            """Draw markers at the points (x, y, z), coloured by *z*."""
            x, y, z = make_vec(x, y, z)
            map_colors(z, kw)
            marker = kw.get('marker', None)
            kw['marker'] = _IPV_MARKERS.get(marker, marker)
            h = ipv.scatter(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            return h
        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
            """Draw *v* as an image; *zdir* and *levels* are ignored."""
            # Don't use contour for now (although we might want to later)
            return self.pcolor(x, y, v, zdir='z', offset=offset, **kw)
        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
            """
            Draw *v* as an image on the plane z=*offset*, stretched over the
            range of *x* and *y*.  *zdir* is ignored.
            """
            x, y, v = make_vec(x, y, v)
            image = make_image(v, kw)
            xmin, xmax = x.min(), x.max()
            ymin, ymax = y.min(), y.max()
            # A single rectangle spanning the data, with the image as texture.
            x = np.array([[xmin, xmax], [xmin, xmax]])
            y = np.array([[ymin, ymin], [ymax, ymax]])
            z = x*0 + offset
            # texture coordinates for the corners of the rectangle
            tex_u = np.array([[0., 1], [0, 1]])
            tex_v = np.array([[0., 0], [1, 1]])
            h = ipv.plot_mesh(x, y, z, u=tex_u, v=tex_v, texture=image,
                              wireframe=False)
            ipv_set_transparency(kw, h)
            h.material.side = "DoubleSide"
            return h
        def text(self, *args, **kw):
            """Text is not drawn by this plotter."""
            pass
        def set_xlim(self, limits):
            ipv.xlim(*limits)
        def set_ylim(self, limits):
            ipv.ylim(*limits)
        def set_zlim(self, limits):
            ipv.zlim(*limits)
        def set_axes_on(self):
            # Note: the ipyvolume style functions are axes_on/axes_off;
            # there is no axis_on.
            ipv.style.axes_on()
        # matplotlib spells this set_axis_on; support both names
        set_axis_on = set_axes_on
        def set_axis_off(self):
            ipv.style.axes_off()
    return Plotter()
1218
def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the interactive jitter explorer in a jupyter cell using ipyvolume.

    *calculator* and *size* are as returned by :func:`select_calculator`.
    *draw_shape* is the function used to draw the shape.  *view* and *jitter*
    are the initial (theta, phi, psi) orientation and dispersity.  *dist* is
    the jitter distribution, *mesh* the number of points in the dispersion
    mesh and *projection* the name of the map projection.

    Six sliders control the view and jitter; the scene is redrawn when a
    slider is released.
    """
    import ipywidgets as widgets
    import ipyvolume as ipv
    # Note: display() is only injected as a builtin inside an IPython
    # session; import it so that the name is always defined in this module.
    from IPython.display import display

    axes = ipv_axes()

    def draw(view, jitter):
        """Draw the scene in a new figure, keeping the current camera."""
        camera = ipv.gcf().camera
        #print(ipv.gcf().__dict__.keys())
        #print(dir(ipv.gcf()))
        ipv.figure(animation=0.)  # no animation when updating object mesh

        # set small jitter as 0 if multiple pd dims
        # The cutoff is 0.5 with one dispersed angle, 5 with two or three.
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]

        ## Visualize as person on globe
        #draw_beam(axes, (0, 0))
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, view=(-view[0], -view[1]))
        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0))

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), steps=25)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0,
                    draw_shape=draw_shape, projection=projection)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95,
                  projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        draw_axes(axes, origin=(-1, -1, -1.1))
        ipv.style.box_off()
        ipv.style.axes_off()
        ipv.xyzlabel(" ", " ", " ")

        ipv.gcf().camera = camera
        ipv.show()


    # (min, max, step) for the view and jitter sliders
    trange, prange = (-180., 180., 1.), (-180., 180., 1.)
    dtrange, dprange = (0., 180., 1.), (0., 180., 1.)

    ## Super simple interface, but uses non-ascii variable names
    # θ φ ψ Δθ Δφ Δψ
    #def update(**kw):
    #    view = kw['θ'], kw['φ'], kw['ψ']
    #    jitter = kw['Δθ'], kw['Δφ'], kw['Δψ']
    #    draw(view, jitter)
    #widgets.interact(update, θ=trange, φ=prange, ψ=prange, Δθ=dtrange, Δφ=dprange, Δψ=dprange)

    def update(theta, phi, psi, dtheta, dphi, dpsi):
        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))

    def slider(name, limits, init=0.):
        """Create a slider labelled *name* with *limits* (min, max, step)."""
        return widgets.FloatSlider(
            value=init,
            min=limits[0],
            max=limits[1],
            step=limits[2],
            description=name,
            disabled=False,
            #continuous_update=True,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            )
    theta = slider(u'θ', trange, view[0])
    phi = slider(u'φ', prange, view[1])
    psi = slider(u'ψ', prange, view[2])
    dtheta = slider(u'Δθ', dtrange, jitter[0])
    dphi = slider(u'Δφ', dprange, jitter[1])
    dpsi = slider(u'Δψ', dprange, jitter[2])
    # The keys must match the argument names of update().
    fields = {
        'theta': theta, 'phi': phi, 'psi': psi,
        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi,
    }
    ui = widgets.HBox([
        widgets.VBox([theta, phi, psi]),
        widgets.VBox([dtheta, dphi, dpsi])
    ])

    out = widgets.interactive_output(update, fields)
    display(ui, out)
1307
1308
# Available plot engines by name.  Each engine is a function called as
#     engine(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
_ENGINES = {
    "matplotlib": mpl_plot,
    "mpl": mpl_plot,
    #"plotly": plotly_plot,
    "ipyvolume": ipv_plot,
    "ipvolume": ipv_plot,  # misspelling kept for backward compatibility
    "ipv": ipv_plot,
}
# Plot engine used by run(); change it with set_plotter().
PLOT_ENGINE = _ENGINES["matplotlib"]
def set_plotter(name):
    """
    Select the plot engine used by :func:`run`.

    *name* is one of "matplotlib" or "mpl" for the matplotlib engine, or
    "ipyvolume" or "ipv" for the ipyvolume engine.

    Raises *KeyError* if *name* is not a known engine.
    """
    global PLOT_ENGINE
    if name not in _ENGINES:
        raise KeyError("unknown plotter %r; choose from %s"
                       % (name, ", ".join(sorted(_ENGINES))))
    PLOT_ENGINE = _ENGINES[name]
1320
def _parse_triple(text):
    """
    Convert the command line value 'a,b,c' into a tuple of three floats.

    Raises *argparse.ArgumentTypeError* if *text* is not three numbers
    separated by commas, so that argparse reports a usage error.
    """
    try:
        values = tuple(float(v) for v in text.split(','))
    except ValueError:
        raise argparse.ArgumentTypeError(
            "expected three comma-separated numbers but got %r" % text)
    if len(values) != 3:
        raise argparse.ArgumentTypeError(
            "expected three comma-separated numbers but got %r" % text)
    return values

def main():
    """
    Command line interface: parse the options and show the jitter explorer.
    """
    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--projection', choices=PROJECTIONS,
                        default=PROJECTIONS[0],
                        help='coordinate projection')
    # Note: the defaults below are strings, so argparse converts them with
    # the same type function as values given on the command line.
    parser.add_argument('-s', '--size', type=_parse_triple, default='10,40,100',
                        help='a,b,c lengths')
    parser.add_argument('-v', '--view', type=_parse_triple, default='0,0,0',
                        help='initial view angles')
    parser.add_argument('-j', '--jitter', type=_parse_triple, default='0,0,0',
                        help='initial angular dispersion')
    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS,
                        default=DISTRIBUTIONS[0],
                        help='jitter distribution')
    parser.add_argument('-m', '--mesh', type=int, default=30,
                        help='#points in theta-phi mesh')
    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0],
                        help='oriented shape')
    opts = parser.parse_args()
    run(opts.shape, size=opts.size, view=opts.view, jitter=opts.jitter,
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.