source: sasmodels/sasmodels/jitter.py @ 275511c

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 275511c was 275511c, checked in by Paul Kienzle <pkienzle@…>, 18 months ago

restore matplotlib version of jitter viewer

  • Property mode set to 100755
File size: 50.5 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Jitter Explorer
5===============
6
7Application to explore orientation angle and angular dispersity.
8
9From the command line::
10
11    # Show docs
12    python -m sasmodels.jitter --help
13
14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
16
17From a jupyter cell::
18
19    import ipyvolume as ipv
20    from sasmodels import jitter
21    import importlib; importlib.reload(jitter)
22    jitter.set_plotter("ipv")
23
24    size = (10, 40, 100)
25    view = (20, 0, 0)
26
27    #size = (15, 15, 100)
28    #view = (60, 60, 0)
29
30    dview = (0, 0, 0)
31    #dview = (5, 5, 0)
32    #dview = (15, 180, 0)
33    #dview = (180, 15, 0)
34
35    projection = 'equirectangular'
36    #projection = 'azimuthal_equidistance'
37    #projection = 'guyou'
38    #projection = 'sinusoidal'
39    #projection = 'azimuthal_equal_area'
40
41    dist = 'uniform'
42    #dist = 'gaussian'
43
44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
46    #ipv.savefig(filename+'.png')
47"""
48from __future__ import division, print_function
49
50import argparse
51
52import numpy as np
53from numpy import pi, cos, sin, sqrt, exp, degrees, radians
54
def draw_beam(axes, view=(0, 0), alpha=0.5, steps=25):
    """
    Draw the beam as a thin yellow tube along the z axis.

    Before rotation the tube has radius 0.02 and spans z from -1.3 to +1.3.
    It is then rotated by Rz(phi)*Ry(theta) for *view* = (theta, phi).
    *steps* is the number of samples around the tube circumference.
    """
    tube_radius = 0.02
    half_length = 1.3
    around = np.linspace(0, 2 * np.pi, steps)
    ends = np.linspace(-1, 1, 2)

    # Cylinder surface: one column for each end of the tube.
    x = tube_radius*np.outer(np.cos(around), np.ones_like(ends))
    y = tube_radius*np.outer(np.sin(around), np.ones_like(ends))
    z = half_length*np.outer(np.ones_like(around), ends)

    theta, phi = view
    grid_shape = x.shape
    flat = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    rotated = Rz(phi)*Ry(theta)*flat
    x, y, z = [row.reshape(grid_shape) for row in rotated]
    axes.plot_surface(x, y, z, color='yellow', alpha=alpha)

    # TODO: draw endcaps on the beam.  Small spheres at the tube ends, e.g.
    #     draw_sphere(axes, radius=0.02, center=(0, 0, +/-1.3),
    #                 color='yellow', alpha=alpha)
    # would work.  A plot_trisurf fan over the end rings was tried and did not.
87
88
def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid and label the tips of its three axes.

    *size* = (a, b, c) gives the semi-axes along x, y and z.  The surface is
    rotated by *jitter* and then *view*, and sampled with *steps* points in
    each angular direction.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    x = a*np.outer(np.cos(azimuth), np.sin(polar))
    y = b*np.outer(np.sin(azimuth), np.sin(polar))
    z = c*np.outer(np.ones_like(azimuth), np.cos(polar))
    x, y, z = transform_xyz(view, jitter, x, y, z)
    axes.plot_surface(x, y, z, color='w', alpha=alpha)

    # (label, point on the surface, text direction) in the body frame.
    axis_tips = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, axis_tips)
109
def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a simple cubic paracrystal.

    *steps* and *alpha* are not used.  They keep the signature compatible
    with the other draw_* shape functions.
    """
    lattice = _build_sc()
    _draw_crystal(axes, size, view, jitter, atoms=lattice)
114
def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a face-centered cubic paracrystal.

    Starts from the 27 simple cubic points and adds the 36 face centers.
    For each axis, the three planes perpendicular to it (at -1, 0 and 1)
    each hold four centers at +/- 0.5 in the other two coordinates.

    *steps* and *alpha* are not used.  They keep the signature compatible
    with the other draw_* shape functions.
    """
    atoms = _build_sc()
    # Face centers on the x planes: x in {-1, 0, 1}, y and z in {-0.5, 0.5}.
    # All three lists must have the same length (12) so that the cyclic
    # concatenations below line up element by element.
    x = [-1]*4 + [0]*4 + [1]*4
    y = ([-0.5]*2 + [0.5]*2)*3
    z = [-0.5, 0.5]*6
    # Cyclically permuting the coordinates gives the other two families:
    # the slice (y, z, x) lies on the z planes and (z, x, y) on the y planes.
    atoms.extend(zip(x + y + z, y + z + x, z + x + y))
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
129
def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw the lattice points of a body-centered cubic paracrystal.

    These are the 27 simple cubic points plus the eight octant centers at
    (+/-0.5, +/-0.5, +/-0.5).  The centers are ordered with x varying
    slowest and z fastest.

    *steps* and *alpha* are not used.  They keep the signature compatible
    with the other draw_* shape functions.
    """
    atoms = _build_sc()
    half = (-0.5, 0.5)
    atoms.extend((cx, cy, cz) for cx in half for cy in half for cz in half)
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
143
def _draw_crystal(axes, size, view, jitter, atoms=None):
    """
    Scatter-plot lattice points.

    Each coordinate of *atoms* is scaled by the matching entry of *size*,
    then the points are rotated by *jitter* and then *view*.  The first atom
    is drawn in yellow as a reference marker; the rest are drawn in red.
    """
    scale = np.asarray(size, 'd')
    coords = np.asarray(atoms, 'd').T * scale[:, None]
    x, y, z = coords
    x, y, z = transform_xyz(view, jitter, x, y, z)
    reference, others = slice(0, 1), slice(1, None)
    axes.scatter(list(x[reference]), list(y[reference]), list(z[reference]),
                 c='yellow', marker='o')
    axes.scatter(x[others], y[others], z[others], c='r', marker='o')
150
151def _build_sc():
152    # three planes of 9 dots for x at -1, 0 and 1
153    x, y, z = (
154        [-1]*9 + [0]*9 + [1]*9,
155        ([-1]*3 + [0]*3 + [1]*3)*3,
156        [-1, 0, 1]*9,
157    )
158    atoms = list(zip(x, y, z))
159    #print(list(enumerate(atoms)))
160    # Pull the dot at (0, 0, 1) to the front of the list
161    # It will be highlighted in the view
162    index = 14
163    highlight = atoms[index]
164    del atoms[index]
165    atoms.insert(0, highlight)
166    return atoms
167
def draw_box(axes, size, view):
    """
    Draw a partial wireframe of a box, oriented by *view*.

    *size* = (a, b, c) gives the half-widths of the box.  Only six edges are
    drawn: the three meeting at the (+a, +b, +c) corner and the three
    meeting at the opposite (-a, -b, -c) corner.
    """
    a, b, c = size
    # Vertex k takes its x sign from bit 0, its y sign from bit 1 and its
    # z sign from bit 2 (bit clear -> +, bit set -> -).  The neighbours of
    # vertex k along box edges are therefore k^1, k^2 and k^4.  Index
    # pairs such as (0, 3) or (7, 4) would be face diagonals, not edges.
    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    x, y, z = transform_xyz(view, None, x, y, z)

    def draw(i, j):
        """Draw the straight segment between vertices i and j."""
        axes.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], color='black')

    for corner in (0, 7):
        for bit in (1, 2, 4):
            draw(corner, corner ^ bit)
182
def draw_parallelepiped(axes, size, view, jitter, steps=None,
                        color=(0.6, 1.0, 0.6), alpha=1):
    """
    Draw a rectangular block and label its faces.

    *size* = (a, b, c) gives the half-widths.  The block is rotated by
    *jitter* and then *view*, and drawn as a triangulated surface.  *steps*
    is not used; it keeps the signature compatible with the other draw_*
    shape functions.
    """
    a, b, c = size
    # Corner k takes its x sign from bit 0, its y sign from bit 1 and its
    # z sign from bit 2 (bit clear -> +, bit set -> -).
    sign_x = np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    sign_y = np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    sign_z = np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    # Two counter-clockwise triangles per face
    # (z: up/down, x: right/left, y: front/back).
    faces = np.array([
        [0, 1, 2], [3, 2, 1],  # top face
        [6, 5, 4], [5, 6, 7],  # bottom face
        [0, 2, 6], [6, 4, 0],  # right face
        [1, 5, 7], [7, 3, 1],  # left face
        [2, 3, 6], [7, 6, 3],  # front face
        [4, 1, 0], [5, 1, 4],  # back face
    ])

    px, py, pz = transform_xyz(view, jitter, a*sign_x, b*sign_y, c*sign_z)
    axes.plot_trisurf(px, py, triangles=faces, Z=pz, color=color,
                      alpha=alpha, linewidth=0)

    # Highlighting the c+ face is currently disabled.  plot_trisurf cannot
    # colour single faces, so the experiment drew a thin pink box,
    # color (1, 0.6, 0.6), just in front of that face, using
    #     transform_xyz(view, jitter, a*sign_x, b*sign_y, abs(c*sign_z)+0.001)
    # with the same triangles.

    draw_labels(axes, view, jitter, [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ])
225
def draw_sphere(axes, radius=0.5, steps=25, center=(0,0,0), color='w', alpha=1.):
    """
    Draw a sphere of the given *radius* centred at *center*.

    The surface is sampled with *steps* points in each angular direction and
    drawn with a single ``axes.plot_surface`` call.
    """
    cx, cy, cz = center[0], center[1], center[2]
    longitude = np.linspace(0, 2 * np.pi, steps)
    colatitude = np.linspace(0, np.pi, steps)
    ring = np.sin(colatitude)

    x = radius * np.outer(np.cos(longitude), ring) + cx
    y = radius * np.outer(np.sin(longitude), ring) + cy
    z = radius * np.outer(np.ones(np.size(longitude)), np.cos(colatitude)) + cz
    axes.plot_surface(x, y, z, color=color, alpha=alpha)
236
def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
    """
    Draw three black line segments from *origin*.

    The segments run *length* = (dx, dy, dz) units along +x, +y and +z
    respectively.
    """
    ox, oy, oz = origin
    lx, ly, lz = length
    segments = (
        ([ox, ox + lx], [oy, oy], [oz, oz]),
        ([ox, ox], [oy, oy + ly], [oz, oz]),
        ([ox, ox], [oy, oy], [oz, oz + lz]),
    )
    for xs, ys, zs in segments:
        axes.plot(xs, ys, zs, color='black')
243
def draw_person_on_sphere(axes, view, height=0.5, radius=0.5):
    """
    Draw a stick figure standing on top of a sphere, then frame the view.

    The figure is *height* tall and drawn in the y = 0 plane with its feet
    at z = *radius*, then rotated by *view*.  Afterwards all three axis
    limits are set to +/-(radius + height) and the axes are hidden.  The
    sphere itself is not drawn here.
    """
    # Body proportions, all derived from the overall height.
    limb_offset = height * 0.05
    head_radius = height * 0.10
    head_height = height - head_radius
    neck_length = head_radius * 0.50
    shoulder_height = height - 2*head_radius - neck_length
    torso_length = shoulder_height * 0.55
    half_width = torso_length * 0.30
    leg_length = shoulder_height - torso_length
    arm_length = torso_length * 0.90

    def stroke(xs, zs):
        """Plot one polyline of the figure, lifted onto the sphere."""
        ys = np.zeros_like(xs)
        xr, yr, zr = transform_xyz(view, None, xs, ys, zs + radius)
        axes.plot(xr, yr, zr, color='k')

    # Head: a circle whose centre is head_radius below the top of the figure.
    t = np.linspace(0, 2 * np.pi, 40)
    stroke(head_radius * np.cos(t), head_radius * np.sin(t) + head_height)

    # Torso: a closed rectangle resting on top of the legs.
    stroke(np.array([-half_width, half_width, half_width, -half_width, -half_width]),
           np.array([0., 0, torso_length, torso_length, 0]) + leg_length)

    # Arms, then legs; each is drawn on the left and mirrored to the right.
    arm_x = np.array([-half_width - limb_offset, -half_width - limb_offset, -half_width])
    arm_z = np.array([shoulder_height - arm_length, shoulder_height, shoulder_height])
    leg_x = np.array([-half_width + limb_offset, -half_width + limb_offset])
    leg_z = np.array([0, leg_length])
    for xs, zs in ((arm_x, arm_z), (leg_x, leg_z)):
        stroke(xs, zs)
        stroke(-xs, zs)

    extent = [-radius-height, radius+height]
    for set_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
        set_limits(extent)
    axes.set_axis_off()
288
def draw_jitter(axes, view, jitter, dist='gaussian',
                size=(0.1, 0.4, 1.0),
                draw_shape=draw_parallelepiped,
                projection='equirectangular',
                alpha=0.8,
                views=None):
    """
    Represent jitter as a set of shapes at different orientations.

    *view* is the (theta, phi, psi) orientation and *jitter* is the
    (dtheta, dphi, dpsi) dispersity, both in degrees.  Each jittered angle
    is sampled symmetrically about zero out to 3, sqrt(3) or 1 times its
    width, for *dist* = 'gaussian', 'rectangle' or 'uniform'.  An angle with
    non-positive width gets the single sample 0.

    *views* fixes the number of samples per jittered angle.  When it is
    None, the count is 2*int(extent/5) clipped to [3, 25], where extent is
    the sampled half-range in degrees.

    *draw_shape* is called once for each (theta, phi, psi) sample whose
    weight under *projection* is positive.  Before drawing, *size* is
    rescaled so that sqrt(a**2 + b**2 + c**2) == 0.95.  All three axis
    limits are then set to +/- that length.
    """
    project, project_weight = get_projection(projection)

    scale = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(scale*v for v in size)

    dtheta, dphi, dpsi = jitter
    extent = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]

    def offsets(width):
        """Sample offsets (degrees) for one jittered angle of given width."""
        count = (views if views is not None
                 else max(3, min(25, 2*int(extent*width/5))))
        if width > 0:
            return extent*width*np.linspace(-1, 1, count)
        return [0.]

    for theta in offsets(dtheta):
        for phi in offsets(dphi):
            for psi in offsets(dpsi):
                if not project_weight(theta, phi, 1.0, 1.0) > 0:
                    continue
                draw_shape(axes, size, view, project(theta, phi, psi),
                           alpha=alpha)

    a, b, c = size
    lim = np.sqrt(a**2 + b**2 + c**2)
    for axis_name in 'xyz':
        getattr(axes, 'set_' + axis_name + 'lim')([-lim, lim])
324
PROJECTIONS = [
    # Jitter projection names, in order of PROJECTION number.  Do not change
    # the order without updating the constants in kernel_iq.c.
    # NOTE(review): get_projection labels equirectangular as PROJECTION 1
    # and sinusoidal as 2, which suggests the number is this list index
    # plus one.  Confirm against kernel_iq.c.
    'equirectangular', 'sinusoidal', 'guyou', 'azimuthal_equidistance',
    'azimuthal_equal_area',
]
def get_projection(projection):
    """
    jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>

    Return a pair of functions *(project, weight)* for the named *projection*.

    *project(theta_i, phi_j, psi)* maps a point of the (theta, phi) jitter
    mesh, in degrees, to rotation angles:

    * most projections return (latitude, longitude, psi), applied as
      Rx(longitude)*Ry(latitude)*Rz(psi);
    * the azimuthal projections return (latitude, longitude, psi', 'zyz'),
      applied as Rz(longitude)*Ry(latitude)*Rz(psi').

    *weight(theta_i, phi_j, w_i, w_j)* returns the weight of that mesh point,
    given the 1-D distribution weights *w_i* and *w_j*.  A weight of zero
    marks a point that lies outside the projection.

    Raises ValueError if *projection* is not recognised.

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangular.  Note that the poles
        will be slightly overweighted for theta > 90 with the circle
        from theta at 90+dt winding backwards around the pole, overlapping
        the circle from theta at 90-dt.
    Guyou (hemisphere-in-a-square) **not weighted**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        With tiling, allows rotation in phi or theta through +/- 180, with
        uniform spacing.  Both theta and phi allow free rotation, with wobble
        in the orthogonal direction reasonably well behaved (though not as
        good as equirectangular phi). The forward/reverse transformations
        rely on elliptic integrals that are somewhat expensive, so the
        behaviour has to be very good to justify the cost and complexity.
        The weighting function for each point has not yet been computed.
        Note: run the module *guyou.py* directly and it will show the forward
        and reverse mappings.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not very
        useful and not completely implemented.
    Gauss-Kreuger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    # TODO: try Kent distribution instead of a gaussian warped by projection

    if projection == 'equirectangular':  #define PROJECTION 1
        def _project(theta_i, phi_j, psi):
            latitude, longitude = theta_i, phi_j
            return latitude, longitude, psi
        def _weight(theta_i, phi_j, w_i, w_j):
            # |cos(latitude)| is the area element of the latitude-longitude grid.
            return w_i*w_j*abs(cos(radians(theta_i)))
    elif projection == 'sinusoidal':  #define PROJECTION 2
        def _project(theta_i, phi_j, psi):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Stretch longitude by 1/cos(latitude).  Points outside the
            # sinusoid map to longitude 0 and get zero weight below.
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            return latitude, longitude, psi
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = theta_i
            scale = cos(radians(latitude))
            active = 1 if abs(phi_j) < abs(scale)*180 else 0
            return active*w_i*w_j
    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
        def _project(theta_i, phi_j, psi):
            from .guyou import guyou_invert
            longitude, latitude = guyou_invert([phi_j], [theta_i])
            return latitude, longitude, psi
        def _weight(theta_i, phi_j, w_i, w_j):
            # The Guyou weighting has not been derived yet; use the raw
            # product of the 1-D weights.
            return w_i*w_j
    elif projection == 'azimuthal_equidistance':
        # Note: returns angles for Rz Ry Rz rather than Rx Ry Rz.
        def _project(theta_i, phi_j, psi):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(np.arctan2(phi_j, theta_i))
            return latitude, longitude, psi-longitude, 'zyz'
        def _weight(theta_i, phi_j, w_i, w_j):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # The mapping from disk to sphere needs a change of variables
            # g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dtheta dphi = det(J) dlat dlong
            # where J is the Jacobian from the partials of g.  Using
            #     R = sqrt(theta^2 + phi^2),
            # then
            #     J = [[theta/R, phi/R], [-phi/R^2, theta/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            # R must be in radians here.  The R = 0 value of 1 is then the
            # limit of sin(R)/R, so the centre of the mesh is not
            # overweighted relative to its neighbours.
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            angle = radians(latitude)
            weight = sin(angle)/angle if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # Note: returns angles for Rz Ry Rz rather than Rx Ry Rz.
        def _project(theta_i, phi_j, psi):
            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*np.arccos(radius))
            longitude = degrees(np.arctan2(phi_j, theta_i))
            return latitude, longitude, psi, "zyz"
        def _weight(theta_i, phi_j, w_i, w_j):
            # Same sin(R)/R weighting as azimuthal_equidistance, with R in
            # radians so that R = 0 is the continuous limit 1.
            latitude = sqrt(theta_i**2 + phi_j**2)
            angle = radians(latitude)
            weight = sin(angle)/angle if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    return _project, _weight
479
def R_to_xyz(R):
    """
    Convert a rotation matrix to Tait-Bryan angles.

    Returns the (phi, theta, psi) angles, in degrees, that correspond to the
    3x3 rotation matrix *R*.

    Extracting Euler Angles from a Rotation Matrix
    Mike Day, Insomniac Games
    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
    Based on: Shoemake's "Euler Angle Conversion", Graphics Gems IV, pp. 222-229
    """
    in_radians = (
        np.arctan2(R[1, 2], R[2, 2]),                             # phi
        np.arctan2(-R[0, 2], np.sqrt(R[0, 0]**2 + R[0, 1]**2)),  # theta
        np.arctan2(R[0, 1], R[0, 0]),                             # psi
    )
    return tuple(np.degrees(angle) for angle in in_radians)
493
def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    Each mesh point is the tip of the vector (0, 0, *radius*).  It is first
    rotated by the jitter angles from *projection* and then by the *view*
    orientation.  Points are coloured by their weight, normalised so the
    largest is 1; points with zero weight are dropped.

    *n* is the number of mesh points along each of theta and phi.  *dist*
    sets the mesh extent and weights:

    * 'gaussian': +/- 3 widths, weighted exp(-x^2/2);
    * 'rectangle': +/- sqrt(3) widths, equal weights;
    * 'uniform': +/- 1 width, equal weights.

    Any other *dist* raises ValueError.  Only dtheta and dphi of *jitter*
    are used; dpsi is unpacked but ignored.

    NOTE(review): if every mesh point has zero weight, ``max(dist_w)`` below
    is called on an empty array and raises ValueError.
    """

    _project, _weight = get_projection(projection)
    def _rotate(theta, phi, z):
        # Rotate column vector z by the jitter for mesh point (theta, phi).
        # Azimuthal projections return a 4-tuple and use Rz Ry (zyz) angles;
        # the others use Rx Ry.
        dview = _project(theta, phi, 0.)
        if len(dview) == 4: # hack for zyz coords
            return Rz(dview[1])*Ry(dview[0])*z
        else:
            return Rx(dview[1])*Ry(dview[0])*z


    dist_x = np.linspace(-1, 1, n)
    weights = np.ones_like(dist_x)
    if dist == 'gaussian':
        dist_x *= 3
        weights = exp(-0.5*dist_x**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width.
        # Presumably the "width" is a standard deviation, which makes the
        # half-width sqrt(3) times larger.
        dist_x *= sqrt(3)
    elif dist == 'uniform':
        pass
    else:
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # mesh in theta, phi formed by rotating z
    dtheta, dphi, dpsi = jitter
    z = np.matrix([[0], [0], [radius]])
    # points (3 x n*n matrix) and dist_w are built with the same loop order:
    # theta outer, phi inner.  Column k of points therefore matches
    # dist_w[k].
    points = np.hstack([_rotate(theta_i, phi_j, z)
                        for theta_i in dtheta*dist_x
                        for phi_j in dphi*dist_x])
    dist_w = np.array([_weight(theta_i, phi_j, w_i, w_j)
                       for w_i, theta_i in zip(weights, dtheta*dist_x)
                       for w_j, phi_j in zip(weights, dphi*dist_x)])
    #print(max(dist_w), min(dist_w), min(dist_w[dist_w > 0]))
    # Keep only the columns with positive weight, then normalise to max 1.
    points = points[:, dist_w > 0]
    dist_w = dist_w[dist_w > 0]
    dist_w /= max(dist_w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(v).flatten() for v in points]
    #plt.figure(2); plt.clf(); plt.hist(z, bins=np.linspace(-1, 1, 51))
    axes.scatter(x, y, z, c=dist_w, marker='o', vmin=0., vmax=1.)
543
def draw_labels(axes, view, jitter, text):
    """
    Draw text labels attached to an oriented shape.

    *text* is a sequence of (label, location, orientation) triples.
    *location* and *orientation* are (x, y, z) in the shape's own frame.
    Both are rotated by *jitter* and then *view*; the rotated orientation is
    passed to ``axes.text`` as the label's *zdir*.  An empty *text* draws
    nothing.
    """
    if not text:
        return
    labels, locations, orientations = zip(*text)
    px, py, pz = zip(*locations)
    dx, dy, dz = zip(*orientations)

    px, py, pz = transform_xyz(view, jitter, px, py, pz)
    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)

    # TODO: check that labels appear and are oriented along zdir.
    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)):
        # mplot3d compares zdir against 'x'/'y'/'z' before treating it as a
        # vector.  With a numpy array that comparison is elementwise (or
        # warns) instead of giving a plain bool, so pass a tuple of floats.
        zdir = tuple(float(v) for v in np.asarray(zdir).flatten())
        axes.text(p[0], p[1], p[2], label, zdir=zdir)
559
560# Definition of rotation matrices comes from wikipedia:
561#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """Return the 3x3 np.matrix that rotates points by *angle* degrees about *x*."""
    t = radians(angle)
    c, s = cos(t), sin(t)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
569
def Ry(angle):
    """Return the 3x3 np.matrix that rotates points by *angle* degrees about *y*."""
    t = radians(angle)
    c, s = cos(t), sin(t)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
577
def Rz(angle):
    """Return the 3x3 np.matrix that rotates points by *angle* degrees about *z*."""
    t = radians(angle)
    c, s = cos(t), sin(t)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
585
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x, y, z) points through the jitter and view rotations.

    *x*, *y* and *z* are array-likes of a common shape.  The rotated
    coordinates come back as three arrays of that same shape.  *jitter* may
    be None, meaning no jitter rotation.
    """
    coords = [np.asarray(v) for v in (x, y, z)]
    shape = coords[0].shape
    points = np.matrix([c.flatten() for c in coords])
    rotated = orient_relative_to_beam(view, apply_jitter(jitter, points))
    xr, yr, zr = [np.array(row).reshape(shape) for row in rotated]
    return xr, yr, zr
597
def apply_jitter(jitter, points):
    """
    Apply the jitter rotation to a set of points.

    *points* is a 3 x n numpy matrix, not a numpy array or tuple.  *jitter*
    may take three forms:

    * None: no rotation;
    * (dtheta, dphi, dpsi): applied as Rx(dphi)*Ry(dtheta)*Rz(dpsi);
    * (dtheta, dphi, dpsi, tag): applied as the Euler rotation
      Rz(dphi)*Ry(dtheta)*Rz(dpsi).

    The 4-tuple form is what the azimuthal projections in get_projection
    return.
    """
    if jitter is None:
        return points
    if len(jitter) == 4:
        dtheta, dphi, dpsi, _ = jitter
        outer = Rz(dphi)
    else:
        dtheta, dphi, dpsi = jitter
        outer = Rx(dphi)
    return outer*Ry(dtheta)*Rz(dpsi)*points
614
def orient_relative_to_beam(view, points):
    """
    Apply the view rotation to a set of points.

    *view* = (theta, phi, psi) is applied as Rz(phi)*Ry(theta)*Rz(psi).
    *points* is a 3 x n numpy matrix, not a numpy array or tuple.

    Alternatives that were tried:
      Rz(psi)*Ry(pi/2-theta)*Rz(phi)  -- 1-D integration angles
      Rx(phi)*Ry(theta)*Rz(psi)       -- angular dispersion angles
    """
    theta, phi, psi = view
    viewing_rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return viewing_rotation*points
626
def orient_relative_to_beam_quaternion(view, points):
    """
    Quaternion-based alternative to orient_relative_to_beam.

    *points* is a 3 x n numpy matrix, not a numpy array or tuple.  The
    rotated points are returned as the ndarray produced by Quaternion.rot.

    The rotation is q = q_x(theta) * q_y(phi) * q_z(psi), where q_a(angle)
    rotates by *angle* degrees about axis a.  Building it instead by
    pre-multiplying rotations about the already-rotated axes turned out to
    reverse the order of application, so that variant was dropped.

    This function is not used.  It works, but the difficulty of mapping
    theta/phi/psi onto SO(3) does not go away by computing the transform
    with quaternions.

    NOTE(review): composing theta about x, phi about y and psi about z is
    not the Rz(phi)*Ry(theta)*Rz(psi) convention of
    orient_relative_to_beam.  Confirm the intended convention before
    swapping one for the other.
    """
    theta, phi, psi = view
    unit_axes = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
    q = Quaternion(1, [0, 0, 0])
    for angle, axis in zip((theta, phi, psi), unit_axes):
        q = q * Quaternion.from_angle_axis(angle, axis)
    return q.rot(points)
657#orient_relative_to_beam = orient_relative_to_beam_quaternion
658
659# Simple stand-alone quaternion class
660import numpy as np
661from copy import copy
class Quaternion(object):
    """
    Simple stand-alone quaternion w + r[0] i + r[1] j + r[2] k.

    *w* is the scalar part and *r* the vector part; *r* is stored as a float
    array.  Used to rotate 3-D points.
    """
    # Operand types accepted for scalar multiplication and division.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __init__(self, w, r):
        self.w = w
        self.r = np.asarray(r, dtype='d')

    @staticmethod
    def from_angle_axis(theta, r):
        """
        Return the unit quaternion for a rotation of *theta* degrees about
        axis *r*.  *r* need not be normalised but must be non-zero.
        """
        half_angle = np.radians(theta)/2
        axis = np.asarray(r, dtype='d')
        # Divide by the axis length (not its squared length) so the result
        # is a unit quaternion for any non-zero axis.
        axis = axis/np.sqrt(np.dot(axis, axis))
        return Quaternion(np.cos(half_angle), np.sin(half_angle)*axis)

    def __mul__(self, other):
        """
        Hamilton product with another Quaternion, or scaling by a real
        scalar.  Returns NotImplemented for any other operand type.
        """
        if isinstance(other, Quaternion):
            w = self.w*other.w - np.dot(self.r, other.r)
            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
            return Quaternion(w, r)
        if isinstance(other, self._SCALAR_TYPES):
            return Quaternion(self.w*other, self.r*other)
        return NotImplemented

    def __rmul__(self, other):
        """Scaling by a real scalar on the left, ``s * q``."""
        if isinstance(other, self._SCALAR_TYPES):
            return Quaternion(other*self.w, other*self.r)
        return NotImplemented

    def __truediv__(self, other):
        """Division by a real scalar."""
        if isinstance(other, self._SCALAR_TYPES):
            return Quaternion(self.w/other, self.r/other)
        return NotImplemented

    # Python 2 spelling, for callers without "from __future__ import division".
    __div__ = __truediv__

    def rot(self, v):
        """
        Rotate the point(s) *v* by this quaternion (assumed unit length).

        *v* may be a single 3-vector, a 3 x n array/matrix of column vectors,
        or an n x 3 array of row vectors.  The result is an ndarray in the
        same layout; a 3 x 3 input is treated as column vectors.
        """
        v = np.asarray(v).T
        use_transpose = (v.shape[-1] != 3)
        if use_transpose:
            v = v.T
        # v' = v + 2 r x (r x v + w v), the expanded form of q v q*.
        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
        if use_transpose:
            v = v.T
        return v.T

    def conj(self):
        """Return the conjugate quaternion w - r."""
        return Quaternion(self.w, -self.r)

    def inv(self):
        """Return the multiplicative inverse conj(q)/|q|^2."""
        return self.conj()/self.norm()**2

    def norm(self):
        """Return the Euclidean norm sqrt(w^2 + |r|^2)."""
        return np.sqrt(self.w**2 + np.sum(self.r**2))

    def __str__(self):
        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
def test_qrot():
    """Check Quaternion.rot against a hand-computed rotation."""
    # Rotation axis: [0, 1, 0] turned 60 degrees about x, i.e. the unit
    # vector in the y-z plane 60 degrees from y.
    axis = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0])
    # Rotate the point by 60 degrees about that axis.
    q = Quaternion.from_angle_axis(60, axis)
    point = [1, -1, 2]
    root3 = np.sqrt(3)
    expected = [(10+4*root3)/8, (1+2*root3)/8, (14-3*root3)/8]
    worst = max(abs(q.rot(point) - expected))
    assert worst < 1e-14
704#test_qrot()
705#import sys; sys.exit()
706
# Translate between which orientation angles have dispersity and the number
# of points used along each dimension.  Keys are (theta, phi, psi) flags for
# "has dispersity"; draw_scattering looks them up with tuples of bools.
# Values are the matching (theta_pd_n, phi_pd_n, psi_pd_n).  The trailing
# comments give the total number of orientations that are evaluated.
PD_N_TABLE = {
    (0, 0, 0): (0, 0, 0),     # 0
    (1, 0, 0): (100, 0, 0),   # 100
    (0, 1, 0): (0, 100, 0),
    (0, 0, 1): (0, 0, 100),
    (1, 1, 0): (30, 30, 0),   # 900
    (1, 0, 1): (30, 0, 30),
    (0, 1, 1): (0, 30, 30),
    (1, 1, 1): (15, 15, 15),  # 3375
}
719
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine a (low, high) display range from *data*.

    If *portion* is 1, return the full range (data.min(), data.max()).
    Otherwise sort the flattened values and trim about portion*len(data) of
    them, as follows:

    * *mode* 'central': trim half of that count from each end;
    * *mode* 'top': trim that count from the bottom and keep the maximum.

    The trimmed count is clamped so that low <= high and indices stay in
    range.

    NOTE(review): for portion < 1, *portion* acts as the fraction trimmed,
    yet portion=1 means "trim nothing", so values just below 1 trim almost
    everything.  Confirm the intended meaning before relying on it.

    Raises ValueError for an unknown *mode* when *portion* is not 1.
    """
    if portion == 1.0:
        return data.min(), data.max()
    if mode not in ('central', 'top'):
        raise ValueError("mode must be 'central' or 'top', not %r" % (mode,))
    data = np.sort(np.ravel(data))
    last = len(data) - 1
    if mode == 'central':
        # Same count from each end.  Index from `last` so that an offset of
        # 0 keeps the maximum (data[-0] would be data[0]).
        offset = max(0, min(int(portion*len(data)/2 + 0.5), last//2))
        return data[offset], data[last - offset]
    # mode == 'top'
    offset = max(0, min(int(portion*len(data) + 0.5), last))
    return data[offset], data[-1]
737
def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The natural log of the intensity is drawn as filled contours in the
    plane z = -1.1, with qx and qy each divided by their maximum.  If
    *calculator.limits* is set it fixes the colour range; otherwise the range
    is derived from the current pattern.
    """
    if dist == 'uniform':  # uniform is not yet in this branch
        # Emulate with sasmodels' rectangle distribution, whose width is
        # sqrt(3) times the requested value.
        dist, scale = 'rectangle', 1/sqrt(3)
    else:
        scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    # fewer points per dimension as more dimensions have dispersity
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    # contourf(x, y, Z) needs Z with shape (len(y), len(x)).
    # NOTE(review): assumes the calculator returns y-major (meshgrid 'xy')
    # ordering; identical to the old (len(qx), len(qy)) for the n x n grid
    # built by build_model.
    Iqxy = calculator(**pars).reshape(len(qy), len(qx))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        vmax = Iqxy.max()
        # NOTE(review): this factor is applied to log intensity, so it is not
        # "seven decades below the peak"; for a positive peak it places the
        # floor just above log(I) = 0.  Kept to preserve the current display.
        vmin = vmax*10**-7
        if vmin >= vmax:
            # A non-positive log peak puts the floor at or above the peak,
            # giving contourf non-increasing levels; use the data range.
            vmin = Iqxy.min()

    axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                  levels=np.linspace(vmin, vmax, 24))
798
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  *background* defaults to 1e-3 unless given in *pars*.

    Returns a *calculator* function which takes a dictionary of parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.

    The calculator also gets *calculator.limits*, the (vmin, vmax) log
    intensity range of the pattern at orientation (0, 0, 0).
    """
    from sasmodels.core import load_model_info, build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator;
    # an explicit background (e.g. 0 for the paracrystal models) wins.
    # (Previously misspelled 'backgound', so the default was never applied.)
    calculator.pars = pars.copy()
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
836
def select_calculator(model_name, n=150, size=(10, 40, 100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped, or sc/fcc/bcc_paracrystal.  *n* is the number of points
    along each q axis.  Where it matters, *qmax* is chosen from the model
    parameters so that the pattern shows something interesting.

    Returns the *calculator* and the tuple *size* (a, b, c) giving the minor
    and major equatorial axes and the polar axis respectively.  For shapes
    with fewer independent dimensions, the unused axes are copied from the
    others.  See :func:`build_model` for details on the returned calculator.

    Raises ValueError if *model_name* is not recognized.
    """
    a, b, c = size
    d_factor = 0.06  # for paracrystal models

    if model_name in ('sc_paracrystal', 'fcc_paracrystal', 'bcc_paracrystal'):
        a = b = c
        if model_name == 'sc_paracrystal':
            dnn, radius = c, 0.5*c
        else:
            # nearest neighbour distance dnn should be 2 radius, but I think
            # the model uses lattice spacing rather than dnn in its calculations
            dnn = 0.5*c
            if model_name == 'fcc_paracrystal':
                radius = sqrt(2)/4 * c
            else:
                radius = sqrt(3)/2 * c
        calculator = build_model(model_name, n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
    elif model_name == 'sphere':
        calculator = build_model('sphere', n=n, radius=c)
        a = b = c
    elif model_name == 'cylinder':
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
        a = b
    elif model_name == 'ellipsoid':
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
        a = b
    elif model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
    elif model_name == 'parallelepiped':
        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    return calculator, (a, b, c)
898
# Shapes offered on the command line; the first entry is the default.
SHAPES = [
    'parallelepiped', 'sphere', 'ellipsoid', 'triaxial_ellipsoid', 'cylinder',
    'fcc_paracrystal', 'bcc_paracrystal', 'sc_paracrystal',
]
905
# Drawing routine for each shape; run() falls back to draw_parallelepiped
# for shapes not listed here.
DRAW_SHAPES = dict(
    fcc_paracrystal=draw_fcc,
    bcc_paracrystal=draw_bcc,
    sc_paracrystal=draw_sc,
    parallelepiped=draw_parallelepiped,
)
912
# Jitter distributions offered on the command line; the first is the default.
DISTRIBUTIONS = ['gaussian', 'rectangle', 'uniform']

# Jitter width limit per distribution (the matplotlib sliders span twice this).
# The rectangle limit is divided by sqrt(3) so that its effective width stays
# within 90 degrees.
DIST_LIMITS = dict(
    gaussian=30,
    rectangle=90/sqrt(3),
    uniform=90,
)
921
922
def run(model_name='parallelepiped', size=(10, 40, 100),
        view=(0, 0, 0), jitter=(0, 0, 0),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal

    *size* gives the dimensions (a, b, c) of the shape.

    *view* gives the initial view (theta, phi, psi) of the shape.

    *jitter* gives the initial jitter (dtheta, dphi, dpsi) of the shape.

    *dist* is the type of dispersion: gaussian, rectangle, or uniform.

    *mesh* is the number of points in the dispersion mesh.

    *projection* is the map projection to use for the mesh: equirectangular,
    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.

    The display is produced by the current plot engine, which is matplotlib
    unless changed with :func:`set_plotter`.
    """
    # projection number according to 1-order position in list, but
    # only 1 and 2 are implemented so far.
    from sasmodels import generate
    generate.PROJECTION = PROJECTIONS.index(projection) + 1
    if generate.PROJECTION > 2:
        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
        generate.PROJECTION = 2

    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)
    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped)
    #draw_shape = draw_fcc

    # Rescale the colour range independently for every view.  Delete this
    # line to keep the range fixed at the limits build_model computed for
    # orientation (0, 0, 0), which shows the pattern fading under rotation.
    calculator.limits = None

    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
964
def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the jitter explorer in an interactive matplotlib window.

    Draws the beam, the shape at its current orientation and jitter, the
    orientation mesh and the scattering pattern on a 3D axes.  Sliders
    control the view angles (theta, phi, psi) and their dispersity, and
    moving any slider redraws the scene.  Finishes by calling ``plt.show()``.

    *calculator* and *size* come from :func:`select_calculator`, and
    *draw_shape* draws the particle.  *view*, *jitter*, *dist*, *mesh* and
    *projection* are as described for :func:`run`.
    """
    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    ## create the plot window
    plt.subplots(num=None, figsize=(5.5, 5.5))
    plt.set_cmap('gist_earth')
    plt.clf()
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; the
    # figure manager provides the same call.
    manager = getattr(plt.gcf().canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(projection)
    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        axes.axis('square')
    except Exception:
        pass

    #axcolor = {'facecolor': 'lightgoldenrodyellow'}
    axcolor = {}

    ## add control widgets to plot: view angles on the left
    axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], **axcolor)
    axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], **axcolor)
    axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], **axcolor)
    stheta = Slider(axes_theta, u'Ξ', -90, 90, valinit=0)
    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0)
    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0)

    ## ... and angular dispersity on the right
    axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], **axcolor)
    axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], **axcolor)
    axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], **axcolor)
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axes_dtheta, u'Δξ', 0, 2*dlimit, valinit=0)
    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0)
    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0)

    ## initial view and jitter (set before the callbacks are bound)
    theta, phi, psi = view
    stheta.set_val(theta)
    sphi.set_val(phi)
    spsi.set_val(psi)
    dtheta, dphi, dpsi = jitter
    sdtheta.set_val(dtheta)
    sdphi.set_val(dphi)
    sdpsi.set_val(dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        axes.cla()

        ## Visualize as person on globe
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, -view[:2])
        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), alpha=1.)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.,
                    draw_shape=draw_shape, projection=projection, views=3)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
1056
1057
def map_colors(z, kw):
    """
    Resolve matplotlib-style colour keywords in *kw* into a single 'color'.

    Pops 'cmap', 'alpha', 'vmin', 'vmax', 'c' and 'color' from *kw* and
    stores the result back in kw['color']:

    * if no colour is given, *z* is scaled to [0, 1] between *vmin* and
      *vmax* (defaulting to the range of *z*) and passed through the colour
      map.  A zero-width range maps every point to the bottom of the map;
    * a colour array with the same shape as *z* goes through the colour
      map unscaled;
    * anything else is used as given.

    Without *alpha*, array colours are cut to RGB.  With *alpha*, the colour
    becomes RGBA with that opacity; the caller's array is never modified.
    The colour map defaults to matplotlib's coolwarm.
    """
    cmap = kw.pop('cmap', None)
    alpha = kw.pop('alpha', None)
    vmin = kw.pop('vmin', z.min())
    vmax = kw.pop('vmax', z.max())
    c = kw.pop('c', None)
    color = kw.pop('color', c)

    def apply_cmap(values):
        # Only import matplotlib when the default colour map is needed.
        if cmap is None:
            from matplotlib import cm
            return cm.coolwarm(values)
        return cmap(values)

    if color is None:
        span = vmax - vmin
        if span:
            znorm = ((z - vmin) / span).clip(0, 1)
        else:
            # constant data: avoid 0/0, which would produce NaN colours
            znorm = np.zeros(np.shape(z))
        color = apply_cmap(znorm)
    elif isinstance(color, np.ndarray) and color.shape == z.shape:
        color = apply_cmap(color)

    if alpha is None:
        if isinstance(color, np.ndarray):
            color = color[..., :3]
    elif isinstance(color, np.ndarray):
        if color.shape[-1] == 3:
            opacity = np.full(color.shape[:-1] + (1,), alpha)
            color = np.concatenate([color, opacity], axis=-1)
        else:
            color = color.copy()  # don't overwrite the caller's array
            color[..., 3] = alpha
    else:
        # a single named colour or RGB(A) sequence
        from matplotlib.colors import to_rgba
        color = np.array(to_rgba(color, alpha))
    kw['color'] = color
1078
def make_vec(*args):
    """Return a list holding each argument as a float64 numpy array (shape kept)."""
    vectors = []
    for arg in args:
        vectors.append(np.asarray(arg, dtype=np.float64))
    return vectors
1082
def make_image(z, kw):
    """
    Render the 2-D array *z* as an RGB PIL image through a colour map.

    The colour map is popped from kw['cmap'] and defaults to matplotlib's
    coolwarm.  *z* is scaled linearly so that its minimum maps to the bottom
    of the colour map and its maximum to the top; a constant *z* maps
    entirely to the bottom.
    """
    import PIL.Image
    from matplotlib import cm

    cmap = kw.pop('cmap', cm.coolwarm)

    # np.ptp() rather than the z.ptp() method, which NumPy 2.0 removed.
    zrange = np.ptp(z)
    if zrange > 0:
        znorm = (z - z.min())/zrange
    else:
        # constant input: avoid 0/0, which would produce NaN colours
        znorm = np.zeros(np.shape(z))
    rgb = np.asarray(cmap(znorm)[..., :3]*255, 'u1')
    return PIL.Image.fromarray(rgb, mode='RGB')
1095
1096
1097_IPV_MARKERS = {
1098    'o': 'sphere',
1099}
1100_IPV_COLORS = {
1101    'w': 'white',
1102    'k': 'black',
1103    'c': 'cyan',
1104    'm': 'magenta',
1105    'y': 'yellow',
1106    'r': 'red',
1107    'g': 'green',
1108    'b': 'blue',
1109}
def ipv_fix_color(kw):
    """
    Translate matplotlib-style colour keywords in *kw* for ipyvolume.

    Single-letter matplotlib colour codes are expanded to ipyvolume colour
    names.  'alpha' is always removed from *kw*.  When a colour is present,
    alpha is folded into kw['color'] as the opacity channel:

    * an RGB sequence gains *alpha*, and an RGBA sequence has its opacity
      multiplied by *alpha*; the result is stored as a numpy array;
    * an RGB or RGBA array gets *alpha* as its opacity, in a copy, so the
      caller's array is left untouched.

    Named colours keep full opacity, and *alpha* is ignored when no colour
    is given.
    """
    alpha = kw.pop('alpha', None)
    color = kw.get('color', None)
    if isinstance(color, str):
        color = _IPV_COLORS.get(color, color)
        kw['color'] = color
    if alpha is None or color is None:
        return

    #TODO: convert named colours to [r, g, b, a] as well
    if isinstance(color, (tuple, list)):
        if len(color) == 3:
            rgba = (color[0], color[1], color[2], alpha)
        else:
            rgba = (color[0], color[1], color[2], alpha*color[3])
        kw['color'] = np.array(rgba)
    elif isinstance(color, np.ndarray):
        if color.shape[-1] == 3:
            opacity = np.full(color.shape[:-1] + (1,), alpha)
            kw['color'] = np.concatenate([color, opacity], axis=-1)
        elif color.shape[-1] == 4:
            color = color.copy()
            color[..., 3] = alpha
            kw['color'] = color
1128
def ipv_set_transparency(kw, obj):
    """
    Switch *obj*'s material to front-side transparency when needed.

    Transparency is turned on only if kw['color'] is an RGBA numpy array
    with at least one opacity value different from 1.  Otherwise *obj* is
    left unchanged.
    """
    color = kw.get('color', None)
    if not isinstance(color, np.ndarray):
        return
    if color.shape[-1] != 4:
        return
    if not (color[..., 3] != 1.0).any():
        return
    obj.material.transparent = True
    obj.material.side = "FrontSide"
1136
def ipv_axes():
    """
    Return an object mimicking a matplotlib 3D axes that draws with ipyvolume.

    The methods accept matplotlib-style arguments (colour codes, alpha,
    facecolors, marker codes) and convert them for ipyvolume before drawing.
    Text is not supported and is silently ignored.
    """
    import ipyvolume as ipv

    class Plotter:
        # transparency can be achieved by setting the following:
        #    mesh.color = [r, g, b, alpha]
        #    mesh.material.transparent = True
        #    mesh.material.side = "FrontSide"
        # smooth(ish) rotation can be achieved by setting:
        #    slide.continuous_update = True
        #    figure.animation = 0.
        #    mesh.material.x = x
        #    mesh.material.y = y
        #    mesh.material.z = z
        # maybe need to synchronize update of x/y/z to avoid shimmy when moving
        def plot(self, x, y, z, **kw):
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            ipv.plot(x, y, z, **kw)
        def plot_surface(self, x, y, z, **kw):
            facecolors = kw.pop('facecolors', None)
            if facecolors is not None:
                kw['color'] = facecolors
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            h = ipv.plot_surface(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            #h.material.side = "DoubleSide"
            return h
        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
            kw.pop('linewidth', None)
            ipv_fix_color(kw)
            x, y, z = make_vec(x, y, Z)
            if triangles is not None:
                triangles = np.asarray(triangles)
            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
            ipv_set_transparency(kw, h)
            return h
        def scatter(self, x, y, z, **kw):
            x, y, z = make_vec(x, y, z)
            map_colors(z, kw)
            marker = kw.get('marker', None)
            if marker is not None:
                # Translate matplotlib codes; without a marker, leave
                # ipyvolume's default rather than passing marker=None.
                kw['marker'] = _IPV_MARKERS.get(marker, marker)
            h = ipv.scatter(x, y, z, **kw)
            ipv_set_transparency(kw, h)
            return h
        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
            # Don't use contour for now (although we might want to later);
            # the data are drawn as an image, so *levels* is ignored.
            return self.pcolor(x, y, v, zdir=zdir, offset=offset, **kw)
        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
            # Only a constant-z plane at *offset* is supported; *zdir* is ignored.
            x, y, v = make_vec(x, y, v)
            image = make_image(v, kw)
            xmin, xmax = x.min(), x.max()
            ymin, ymax = y.min(), y.max()
            x = np.array([[xmin, xmax], [xmin, xmax]])
            y = np.array([[ymin, ymin], [ymax, ymax]])
            z = x*0 + offset
            u = np.array([[0., 1], [0, 1]])
            v = np.array([[0., 0], [1, 1]])
            h = ipv.plot_mesh(x, y, z, u=u, v=v, texture=image, wireframe=False)
            ipv_set_transparency(kw, h)
            h.material.side = "DoubleSide"
            return h
        def text(self, *args, **kw):
            pass
        def set_xlim(self, limits):
            ipv.xlim(*limits)
        def set_ylim(self, limits):
            ipv.ylim(*limits)
        def set_zlim(self, limits):
            ipv.zlim(*limits)
        def set_axis_on(self):
            # ipyvolume's style API is axes_on/axes_off (there is no axis_on).
            ipv.style.axes_on()
        # CRUFT: keep the previous (misspelled) method name working.
        set_axes_on = set_axis_on
        def set_axis_off(self):
            ipv.style.axes_off()
    return Plotter()
1213
def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the jitter explorer in a Jupyter notebook using ipyvolume.

    Displays two columns of sliders, one for the view angles and one for
    their dispersity.  Each change redraws the beam, the jittered shape, the
    orientation mesh and the scattering pattern.  Arguments are as for
    :func:`run`, with *calculator* and *size* from :func:`select_calculator`
    and *draw_shape* the particle drawing function.
    """
    import ipywidgets as widgets
    import ipyvolume as ipv
    # display() is not defined in this module; import it explicitly rather
    # than relying on the notebook's interactive namespace.
    from IPython.display import display

    axes = ipv_axes()

    def draw(view, jitter):
        camera = ipv.gcf().camera
        #print(ipv.gcf().__dict__.keys())
        #print(dir(ipv.gcf()))
        ipv.figure(animation=0.)  # no animation when updating object mesh

        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]

        ## Visualize as person on globe
        #draw_beam(axes, (0, 0))
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, view=(-view[0], -view[1]))
        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0))

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), steps=25)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0,
                    draw_shape=draw_shape, projection=projection)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95,
                  projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        draw_axes(axes, origin=(-1, -1, -1.1))
        ipv.style.box_off()
        ipv.style.axes_off()
        ipv.xyzlabel(" ", " ", " ")

        # reuse the camera captured before the new figure was created
        ipv.gcf().camera = camera
        ipv.show()

    trange, prange = (-180., 180., 1.), (-180., 180., 1.)
    dtrange, dprange = (0., 180., 1.), (0., 180., 1.)

    ## Super simple interface, but uses non-ascii variable names
    # Ξ φ ψ Δξ Δφ Δψ
    #def update(**kw):
    #    view = kw['Ξ'], kw['φ'], kw['ψ']
    #    jitter = kw['Δξ'], kw['Δφ'], kw['Δψ']
    #    draw(view, jitter)
    #widgets.interact(update, Ξ=trange, φ=prange, ψ=prange, Δξ=dtrange, Δφ=dprange, Δψ=dprange)

    def update(theta, phi, psi, dtheta, dphi, dpsi):
        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))

    def slider(name, limits, init=0.):
        # limits is (min, max, step)
        return widgets.FloatSlider(
            value=init,
            min=limits[0],
            max=limits[1],
            step=limits[2],
            description=name,
            disabled=False,
            #continuous_update=True,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            )
    theta = slider(u'Ξ', trange, view[0])
    phi = slider(u'φ', prange, view[1])
    psi = slider(u'ψ', prange, view[2])
    dtheta = slider(u'Δξ', dtrange, jitter[0])
    dphi = slider(u'Δφ', dprange, jitter[1])
    dpsi = slider(u'Δψ', dprange, jitter[2])
    fields = {
        'theta': theta, 'phi': phi, 'psi': psi,
        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi,
    }
    ui = widgets.HBox([
        widgets.VBox([theta, phi, psi]),
        widgets.VBox([dtheta, dphi, dpsi])
    ])

    out = widgets.interactive_output(update, fields)
    display(ui, out)
1302
1303
# Plotting back-ends for run(), keyed by every accepted name.
_ENGINES = {
    "matplotlib": mpl_plot,
    "mpl": mpl_plot,
    #"plotly": plotly_plot,
    "ipyvolume": ipv_plot,
    "ipvolume": ipv_plot,
    "ipv": ipv_plot,
}
PLOT_ENGINE = _ENGINES["matplotlib"]

def set_plotter(name):
    """
    Select the plotting back-end used by :func:`run`.

    *name* is one of "matplotlib"/"mpl" or "ipyvolume"/"ipv" (the older
    spelling "ipvolume" is also accepted).  Raises KeyError for any other
    name, listing the valid choices.
    """
    global PLOT_ENGINE
    if name not in _ENGINES:
        raise KeyError("unknown plotter %r; expected one of %s"
                       % (name, ", ".join(sorted(_ENGINES))))
    PLOT_ENGINE = _ENGINES[name]
1315
def _parse_triple(text):
    """
    Parse a comma-separated string such as '10,40,100' into three floats.

    Used as an argparse ``type`` so that malformed values are reported as
    usage errors, instead of failing later when the tuple is unpacked.
    """
    message = "expected three comma-separated numbers, got %r" % (text,)
    try:
        values = tuple(float(v) for v in text.split(','))
    except ValueError:
        raise argparse.ArgumentTypeError(message)
    if len(values) != 3:
        raise argparse.ArgumentTypeError(message)
    return values

def main():
    """Parse the command line and start the jitter explorer (see :func:`run`)."""
    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--projection', choices=PROJECTIONS,
                        default=PROJECTIONS[0],
                        help='coordinate projection')
    # string defaults are passed through the type converter by argparse
    parser.add_argument('-s', '--size', type=_parse_triple, default='10,40,100',
                        help='a,b,c lengths')
    parser.add_argument('-v', '--view', type=_parse_triple, default='0,0,0',
                        help='initial view angles')
    parser.add_argument('-j', '--jitter', type=_parse_triple, default='0,0,0',
                        help='initial angular dispersion')
    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS,
                        default=DISTRIBUTIONS[0],
                        help='jitter distribution')
    parser.add_argument('-m', '--mesh', type=int, default=30,
                        help='#points in theta-phi mesh')
    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0],
                        help='oriented shape')
    opts = parser.parse_args()
    run(opts.shape, size=opts.size, view=opts.view, jitter=opts.jitter,
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.