source: sasmodels/sasmodels/jitter.py @ b297ba9

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since b297ba9 was b297ba9, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

lint

  • Property mode set to 100755
File size: 53.9 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Jitter Explorer
5===============
6
7Application to explore orientation angle and angular dispersity.
8
9From the command line::
10
11    # Show docs
12    python -m sasmodels.jitter --help
13
14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
16
17From a jupyter cell::
18
19    import ipyvolume as ipv
20    from sasmodels import jitter
21    import importlib; importlib.reload(jitter)
22    jitter.set_plotter("ipv")
23
24    size = (10, 40, 100)
25    view = (20, 0, 0)
26
27    #size = (15, 15, 100)
28    #view = (60, 60, 0)
29
30    dview = (0, 0, 0)
31    #dview = (5, 5, 0)
32    #dview = (15, 180, 0)
33    #dview = (180, 15, 0)
34
35    projection = 'equirectangular'
36    #projection = 'azimuthal_equidistance'
37    #projection = 'guyou'
38    #projection = 'sinusoidal'
39    #projection = 'azimuthal_equal_area'
40
41    dist = 'uniform'
42    #dist = 'gaussian'
43
44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
46    #ipv.savefig(filename+'.png')
47"""
48from __future__ import division, print_function
49
50import argparse
51
52import numpy as np
53from numpy import pi, cos, sin, sqrt, exp, log, degrees, radians, arccos, arctan2
54
55# Too many complaints about variable names from pylint:
56#    a, b, c, u, v, x, y, z, dx, dy, dz, px, py, pz, R, Rx, Ry, Rz, ...
57# pylint: disable=invalid-name
58
def draw_beam(axes, view=(0, 0), alpha=0.5, steps=25):
    """
    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1).

    The beam is a thin yellow tube (radius 0.02, extending to z = +/- 1.3)
    rotated by Rz(phi)*Ry(theta) for *view* = (theta, phi) in degrees.
    *steps* is the number of points around the tube and *alpha* its
    transparency.
    """
    tube_radius = 0.02
    half_length = 1.3
    around = np.linspace(0, 2 * pi, steps)
    along = np.linspace(-1, 1, 2)

    # One row per angle around the tube, one column per tube end.
    x = tube_radius*np.outer(cos(around), np.ones_like(along))
    y = tube_radius*np.outer(sin(around), np.ones_like(along))
    z = half_length*np.outer(np.ones_like(around), along)

    theta, phi = view
    grid_shape = x.shape
    coords = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    coords = Rz(phi)*Ry(theta)*coords
    x, y, z = (row.reshape(grid_shape) for row in coords)
    axes.plot_surface(x, y, z, color='yellow', alpha=alpha)

    # TODO: draw endcaps on beam.
    # Tiny yellow spheres at (0, 0, +/-1.3) via draw_sphere would work;
    # capping the tube ends with plot_trisurf fan triangles did not.
91
92
def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c) along x, y and z.

    The surface is sampled on a *steps* x *steps* azimuth/polar grid, sent
    through *jitter* and *view*, drawn in white with transparency *alpha*,
    and the six axis tips are labelled a+/a-, b+/b-, c+/c-.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * pi, steps)
    polar = np.linspace(0, pi, steps)
    sin_polar = sin(polar)
    x = a*np.outer(cos(azimuth), sin_polar)
    y = b*np.outer(sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), cos(polar))
    x, y, z = transform_xyz(view, jitter, x, y, z)
    axes.plot_surface(x, y, z, color='w', alpha=alpha)

    # (label, position, text direction) for each axis tip
    tips = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, tips)
113
def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for simple cubic paracrystal.

    *steps* and *alpha* are accepted for compatibility with the other
    draw_* shape functions and are not used.
    """
    _draw_crystal(axes, size, view, jitter, atoms=_build_sc())
118
def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for face-centered cubic paracrystal.

    The 27 simple cubic lattice points are extended with the face centers
    of the eight unit cells: points with one coordinate in {-1, 0, 1} and
    the other two at +/- 0.5.  *steps* and *alpha* are accepted for
    compatibility with the other draw_* shape functions and are not used.
    """
    # Build the simple cubic crystal
    atoms = _build_sc()
    # Face centers on the x planes at -1, 0, 1: four per plane, at +/- 0.5
    # in y and z.  All three lists must have the same length (12) so that
    # the cyclic permutation below lines up point by point.
    x = [-1]*4 + [0]*4 + [1]*4
    y = ([-0.5]*2 + [0.5]*2)*3
    z = [-0.5, 0.5]*6
    # Cyclically permuting the coordinates yields the centers on the z and
    # y planes as well, for 36 face centers in total.
    atoms.extend(zip(x+y+z, y+z+x, z+x+y))
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
133
def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for body-centered cubic paracrystal.

    The 27 simple cubic lattice points are extended with the center of each
    of the eight unit cells, i.e. every combination of +/- 0.5 in x, y, z.
    *steps* and *alpha* are accepted for compatibility with the other
    draw_* shape functions and are not used.
    """
    # Build the simple cubic crystal
    atoms = _build_sc()
    # One center per octant.  The lists all have 8 entries so zip does not
    # silently drop any values.
    x = [-0.5]*4 + [0.5]*4
    y = ([-0.5]*2 + [0.5]*2)*2
    z = [-0.5, 0.5]*4
    atoms.extend(zip(x, y, z))
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
147
def _draw_crystal(axes, size, view, jitter, atoms=None):
    """
    Scatter-plot lattice points scaled by *size* under *jitter* and *view*.

    *atoms* is a sequence of (x, y, z) lattice coordinates which are
    multiplied elementwise by *size*.  The first atom is drawn in yellow as
    the highlighted reference point; the rest are drawn in red.
    """
    coords = np.asarray(atoms, 'd').T
    scale = np.asarray(size, 'd')[:, None]
    x, y, z = transform_xyz(view, jitter, *(coords*scale))
    axes.scatter([x[0]], [y[0]], [z[0]], c='yellow', marker='o')
    axes.scatter(x[1:], y[1:], z[1:], c='r', marker='o')
154
155def _build_sc():
156    # three planes of 9 dots for x at -1, 0 and 1
157    x, y, z = (
158        [-1]*9 + [0]*9 + [1]*9,
159        ([-1]*3 + [0]*3 + [1]*3)*3,
160        [-1, 0, 1]*9,
161    )
162    atoms = list(zip(x, y, z))
163    #print(list(enumerate(atoms)))
164    # Pull the dot at (0, 0, 1) to the front of the list
165    # It will be highlighted in the view
166    index = 14
167    highlight = atoms[index]
168    del atoms[index]
169    atoms.insert(0, highlight)
170    return atoms
171
def draw_box(axes, size, view):
    """
    Draw a wireframe box at a particular view.

    *size* = (a, b, c) gives the half-widths along x, y and z, and *view*
    (theta, phi, psi) orients the box; no jitter is applied.  All 12 edges
    are drawn in black.
    """
    a, b, c = size
    # Corner k takes its x sign from bit 0, its y sign from bit 1 and its
    # z sign from bit 2 (bit clear -> +, bit set -> -).
    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    x, y, z = transform_xyz(view, None, x, y, z)
    # Two corners share an edge exactly when their indices differ in one
    # bit; pairs differing in two bits (e.g. 0-3) are face diagonals.
    edges = [(i, i ^ bit) for i in range(8) for bit in (1, 2, 4) if i < i ^ bit]
    for i, j in edges:
        axes.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], color='black')
187
def draw_parallelepiped(axes, size, view, jitter, steps=None,
                        color=(0.6, 1.0, 0.6), alpha=1):
    """
    Draw a parallelepiped surface, with view and jitter.

    *size* = (a, b, c) gives the half-widths along x, y and z.  The surface
    is rendered with plot_trisurf as two triangles per face in *color* with
    transparency *alpha*, and the face centers are labelled a+/a-, b+/b-
    and c+/c-.  *steps* is accepted for compatibility with the other
    draw_* shape functions and is not used.
    """
    a, b, c = size
    # Corner k: x sign from bit 0, y sign from bit 1, z sign from bit 2
    # (bit clear -> +1, bit set -> -1).
    x_sign = np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y_sign = np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z_sign = np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    faces = np.array([
        # counter clockwise triangles, two per face
        # z: up/down, x: right/left, y: front/back
        [0, 1, 2], [3, 2, 1],  # top face    (z = +c)
        [6, 5, 4], [5, 6, 7],  # bottom face (z = -c)
        [0, 2, 6], [6, 4, 0],  # right face  (x = +a)
        [1, 5, 7], [7, 3, 1],  # left face   (x = -a)
        [2, 3, 6], [7, 6, 3],  # front face  (y = -b)
        [4, 1, 0], [5, 1, 4],  # back face   (y = +b)
    ])

    px, py, pz = transform_xyz(view, jitter, a*x_sign, b*y_sign, c*z_sign)
    axes.plot_trisurf(px, py, triangles=faces, Z=pz, color=color, alpha=alpha,
                      linewidth=0)

    # Disabled experiment: colour the c+ face.  Since face colour can't be
    # controlled, this draws a thin pink box just in front of the "c+" face
    # (using the c face so that rotations about psi rotate it too).
    highlight_c_face = False
    if highlight_c_face:
        tint = (1, 0.6, 0.6)  # pink
        hx, hy, hz = transform_xyz(view, jitter, a*x_sign, b*y_sign,
                                   abs(c*z_sign)+0.001)
        axes.plot_trisurf(hx, hy, triangles=faces, Z=hz, color=tint,
                          alpha=alpha)

    # (label, position, text direction) for each face center
    face_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, face_labels)
230
def draw_sphere(axes, radius=0.5, steps=25,
                center=(0, 0, 0), color='w', alpha=1.):
    """
    Draw a sphere of the given *radius* centered at *center*.

    The surface is sampled on a *steps* x *steps* azimuth/polar grid and
    drawn with plot_surface in *color* with transparency *alpha*.
    """
    azimuth = np.linspace(0, 2 * pi, steps)
    polar = np.linspace(0, pi, steps)

    sin_polar = sin(polar)
    x = radius * np.outer(cos(azimuth), sin_polar) + center[0]
    y = radius * np.outer(sin(azimuth), sin_polar) + center[1]
    z = radius * np.outer(np.ones_like(azimuth), cos(polar)) + center[2]
    axes.plot_surface(x, y, z, color=color, alpha=alpha)
242
def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
    """
    Draw wireframe axes lines, with given origin and length.

    One black line is drawn from *origin* along each of x, y and z, with
    lengths taken from *length* = (dx, dy, dz).
    """
    x0, y0, z0 = origin
    dx, dy, dz = length
    tips = ((x0 + dx, y0, z0), (x0, y0 + dy, z0), (x0, y0, z0 + dz))
    for tx, ty, tz in tips:
        axes.plot([x0, tx], [y0, ty], [z0, tz], color='black')
250
def draw_person_on_sphere(axes, view, height=0.5, radius=0.5):
    """
    Draw a person on the surface of a sphere.

    *view* indicates (latitude, longitude, orientation)

    The stick figure is laid out in the y = 0 plane with its feet at z = 0,
    lifted to z = *radius*, then rotated by *view*.  *height* is the figure
    height.  The axes limits are set to +/- (radius + height) and the axes
    are hidden.
    """
    # Body proportions derived from the overall height.
    limb_offset = height * 0.05
    head_radius = height * 0.10
    head_height = height - head_radius
    neck_length = head_radius * 0.50
    shoulder_height = height - 2*head_radius - neck_length
    torso_length = shoulder_height * 0.55
    torso_radius = torso_length * 0.30
    leg_length = shoulder_height - torso_length
    arm_length = torso_length * 0.90

    def _stroke(x, z):
        # Plot one black polyline of the figure, raised onto the sphere.
        xs, ys, zs = transform_xyz(view, None, x, np.zeros_like(x), z + radius)
        axes.plot(xs, ys, zs, color='k')

    # head: a circle
    angle = np.linspace(0, 2 * pi, 40)
    _stroke(head_radius * cos(angle), head_radius * sin(angle) + head_height)

    # body: a closed rectangle sitting on top of the legs
    body_x = np.array([-torso_radius, torso_radius, torso_radius,
                       -torso_radius, -torso_radius])
    body_z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length
    _stroke(body_x, body_z)

    # arms, mirrored left/right
    arm_x = np.array([-torso_radius - limb_offset,
                      -torso_radius - limb_offset,
                      -torso_radius])
    arm_z = np.array([shoulder_height - arm_length,
                      shoulder_height,
                      shoulder_height])
    _stroke(arm_x, arm_z)
    _stroke(-arm_x, arm_z)

    # legs, mirrored left/right
    leg_x = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset])
    leg_z = np.array([0, leg_length])
    _stroke(leg_x, leg_z)
    _stroke(-leg_x, leg_z)

    bounds = [-radius-height, radius+height]
    for set_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
        set_limits(bounds)
    axes.set_axis_off()
300
def draw_jitter(axes, view, jitter, dist='gaussian',
                size=(0.1, 0.4, 1.0),
                draw_shape=draw_parallelepiped,
                projection='equirectangular',
                alpha=0.8,
                views=None):
    """
    Represent jitter as a set of shapes at different orientations.

    *draw_shape* is called once for every sampled (theta, phi, psi) jitter
    offset whose *projection* weight is positive.  The shape *size* is
    first rescaled so that sqrt(a^2 + b^2 + c^2) is 0.95.

    For each jitter component delta, the offsets span +/- base*delta, where
    base is 3, sqrt(3) or 1 for *dist* 'gaussian', 'rectangle' or
    'uniform'.  *views* fixes the number of offsets per component.  The
    default is 2*int(base*delta/5) clipped to the range [3, 25].  A
    component that is not positive gives the single offset 0.

    Finally the x, y and z limits are set to +/- 0.95.
    """
    project, project_weight = get_projection(projection)

    # set max diagonal to 0.95
    scale = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(scale*v for v in size)

    dtheta, dphi, dpsi = jitter
    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]

    def _offsets(delta):
        # Sampled angles along one jitter axis.
        if views is not None:
            count = views
        else:
            count = max(3, min(25, 2*int(base*delta/5)))
        if delta > 0:
            return base*delta*np.linspace(-1, 1, count)
        return [0.]

    orientations = ((theta, phi, psi)
                    for theta in _offsets(dtheta)
                    for phi in _offsets(dphi)
                    for psi in _offsets(dpsi))
    for theta, phi, psi in orientations:
        if project_weight(theta, phi, 1.0, 1.0) > 0:
            draw_shape(axes, size, view, project(theta, phi, psi), alpha=alpha)

    a, b, c = size
    lim = sqrt(a**2 + b**2 + c**2)
    for set_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
        set_limits([-lim, lim])
336
PROJECTIONS = [
    # Listed in order of PROJECTION number.  Do not reorder or change these
    # without also updating the matching constants in kernel_iq.c.
    'equirectangular',
    'sinusoidal',
    'guyou',
    'azimuthal_equidistance',
    'azimuthal_equal_area',
]
def get_projection(projection):
    """
    jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>

    Returns the pair *(project, weight)* for the named *projection*.

    *project(theta_i, phi_j, psi)* maps a jitter offset in degrees to
    *(latitude, longitude, psi, order)*.  *order* is 'xyz' when the angles
    are intended for Rx(longitude)*Ry(latitude) and 'zyz' when they are
    euler angles for Rz(longitude)*Ry(latitude).  The third entry is the
    psi to use for that point.

    *weight(theta_i, phi_j, w_i, w_j)* returns the weight of the point given
    the distribution weights *w_i* and *w_j*.  A zero weight marks points
    outside the projection.

    Raises *ValueError* for an unknown projection name.

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangular.  Note that the poles
        will be slightly overweighted for theta > 90 with the circle
        from theta at 90+dt winding backwards around the pole, overlapping
        the circle from theta at 90-dt.
    Guyou (hemisphere-in-a-square) **not weighted**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        With tiling, allows rotation in phi or theta through +/- 180, with
        uniform spacing.  Both theta and phi allow free rotation, with wobble
        in the orthogonal direction reasonably well behaved (though not as
        good as equirectangular phi). The forward/reverse transformations
        rely on elliptic integrals that are somewhat expensive, so the
        behaviour has to be very good to justify the cost and complexity.
        The weighting function for each point has not yet been computed.
        Note: run the module *guyou.py* directly and it will show the forward
        and reverse mappings.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not that
        useful and not completely implemented
    Gauss-Krueger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    # TODO: try Kent distribution instead of a gaussian warped by projection

    if projection == 'equirectangular':  #define PROJECTION 1
        def _project(theta_i, phi_j, psi):
            latitude, longitude = theta_i, phi_j
            return latitude, longitude, psi, 'xyz'
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j*abs(cos(radians(theta_i)))
    elif projection == 'sinusoidal':  #define PROJECTION 2
        def _project(theta_i, phi_j, psi):
            latitude = theta_i
            scale = cos(radians(latitude))
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            return latitude, longitude, psi, 'xyz'
        def _weight(theta_i, phi_j, w_i, w_j):
            # Points outside the sinusoid |phi| < 180 cos(theta) get no weight.
            latitude = theta_i
            scale = cos(radians(latitude))
            active = 1 if abs(phi_j) < abs(scale)*180 else 0
            return active*w_i*w_j
    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
        def _project(theta_i, phi_j, psi):
            from .guyou import guyou_invert
            longitude, latitude = guyou_invert([phi_j], [theta_i])
            # guyou_invert is given one-element sequences; unwrap the single
            # point so scalar angles are returned like the other projections.
            # NOTE(review): assumes each result holds exactly one value
            # (a scalar or a length-1 sequence).
            return np.ravel(latitude)[0], np.ravel(longitude)[0], psi, 'xyz'
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j
    elif projection == 'azimuthal_equidistance':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(arctan2(phi_j, theta_i))
            return latitude, longitude, psi-longitude, 'zyz'
        def _weight(theta_i, phi_j, w_i, w_j):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # We are mapping from disk to sphere, so we need a change of
            # variables g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dlat dlong = det(J) dtheta dphi
            # where J is the jacobian from the partials of g. Using
            #     R = sqrt(theta^2 + phi^2),
            # then
            #     J = [[theta/R, phi/R], [-phi/R^2, theta/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*arccos(radius))
            longitude = degrees(arctan2(phi_j, theta_i))
            return latitude, longitude, psi, 'zyz'
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    return _project, _weight
491
def R_to_xyz(R):
    """
    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix.

    Angles are returned in degrees.

    Extracting Euler Angles from a Rotation Matrix
    Mike Day, Insomniac Games
    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
    Based on: Shoemake’s "Euler Angle Conversion", Graphics Gems IV, pp.  222-229

    Note: for a matrix built as Rx(a)*Ry(b)*Rz(c) with the Rx/Ry/Rz of this
    module and |b| < 90, the result is (-a, -b, -c).
    """
    r00, r01, r02 = R[0, 0], R[0, 1], R[0, 2]
    r12, r22 = R[1, 2], R[2, 2]
    phi = arctan2(r12, r22)
    theta = arctan2(-r02, sqrt(r00**2 + r01**2))
    psi = arctan2(r01, r00)
    return tuple(degrees(angle) for angle in (phi, theta, psi))
505
def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    Each mesh point is the tip of a vector of length *radius* along z after
    it has been rotated by one (theta, phi) jitter offset and then by *view*.
    *n* offsets are taken along each of theta and phi, spread according to
    *dist* ('gaussian', 'rectangle' or 'uniform', otherwise ValueError).
    Points are coloured by their *projection* weight normalized to a
    maximum of 1, and points with zero weight are dropped.  The psi
    component of *jitter* does not affect the mesh.
    """
    project, weight = get_projection(projection)

    def _tilt(theta, phi, vec):
        # Rotate vec by the projected jitter angles (psi = 0), honouring
        # whether the projection produced zyz euler angles or xyz angles.
        dview = project(theta, phi, 0.)
        outer = Rz if dview[3] == 'zyz' else Rx
        return outer(dview[1])*Ry(dview[0])*vec

    offsets = np.linspace(-1, 1, n)
    dist_weights = np.ones_like(offsets)
    if dist == 'gaussian':
        offsets *= 3
        dist_weights = exp(-0.5*offsets**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        offsets *= sqrt(3)
    elif dist != 'uniform':
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # mesh in theta, phi formed by rotating z
    dtheta, dphi, _ = jitter
    theta_values = dtheta*offsets
    phi_values = dphi*offsets
    z_axis = np.matrix([[0], [0], [radius]])
    points = np.hstack([_tilt(theta_i, phi_j, z_axis)
                        for theta_i in theta_values
                        for phi_j in phi_values])
    mesh_w = np.array([weight(theta_i, phi_j, w_i, w_j)
                       for w_i, theta_i in zip(dist_weights, theta_values)
                       for w_j, phi_j in zip(dist_weights, phi_values)])
    keep = mesh_w > 0
    points = points[:, keep]
    mesh_w = mesh_w[keep]
    mesh_w /= max(mesh_w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).flatten() for row in points]
    axes.scatter(x, y, z, c=mesh_w, marker='o', vmin=0., vmax=1.)
555
def draw_labels(axes, view, jitter, text):
    """
    Draw text at a particular location.

    *text* is a sequence of (label, position, direction) triples, where
    position and direction are (x, y, z) vectors in the shape frame.  Both
    are sent through *jitter* and *view*, and the label is written at the
    transformed position with the transformed direction as its zdir.
    """
    names, positions, directions = zip(*text)
    px, py, pz = transform_xyz(view, jitter, *zip(*positions))
    dx, dy, dz = transform_xyz(view, jitter, *zip(*directions))

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for name, x, y, z, ux, uy, uz in zip(names, px, py, pz, dx, dy, dz):
        zdir = np.asarray((ux, uy, uz)).flatten()
        axes.text(x, y, z, name, zdir=zdir)
571
# Rotation matrices follow the basic rotations listed on wikipedia:
#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """Return an np.matrix rotating points about *x* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([[1, 0, 0],
                      [0, c, -s],
                      [0, s, c]])
581
def Ry(angle):
    """Return an np.matrix rotating points about *y* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([[c, 0, s],
                      [0, 1, 0],
                      [-s, 0, c]])
589
def Rz(angle):
    """Return an np.matrix rotating points about *z* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([[c, -s, 0],
                      [s, c, 0],
                      [0, 0, 1]])
597
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    *x*, *y* and *z* may be scalars, sequences or arrays with the same
    number of elements; the returned (x, y, z) arrays have the shape of *x*.
    *jitter* (None for no jitter) is applied first, then *view*.
    """
    x, y, z = (np.asarray(coord) for coord in (x, y, z))
    shape = x.shape
    pts = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    pts = orient_relative_to_beam(view, apply_jitter(jitter, pts))
    return tuple(np.array(row).reshape(shape) for row in pts)
609
def apply_jitter(jitter, points):
    """
    Apply the jitter transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    *jitter* is one of:

    * None: the points are returned unchanged;
    * (dtheta, dphi, dpsi): rotate by Rx(dphi)*Ry(dtheta)*Rz(dpsi);
    * (dtheta, dphi, dpsi, order), as produced by the projection functions
      of get_projection: *order* 'zyz' selects the euler-angle rotation
      Rz(dphi)*Ry(dtheta)*Rz(dpsi), and anything else (e.g. 'xyz') selects
      Rx(dphi)*Ry(dtheta)*Rz(dpsi).
    """
    if jitter is None:
        return points
    if len(jitter) == 4:
        dtheta, dphi, dpsi, order = jitter
    else:
        dtheta, dphi, dpsi = jitter
        order = 'xyz'
    # Every projection returns a tagged 4-tuple, but only the azimuthal ones
    # use zyz euler angles, so the tag (not the tuple length) must decide.
    # This matches the choice made when drawing the dispersion mesh.
    outer = Rz if order == 'zyz' else Rx
    return outer(dphi)*Ry(dtheta)*Rz(dpsi)*points
626
def orient_relative_to_beam(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    For *view* = (theta, phi, psi) in degrees, the points are rotated by
    Rz(phi)*Ry(theta)*Rz(psi).
    """
    theta, phi, psi = view
    # Alternatives that were considered:
    #   Rz(psi)*Ry(pi/2-theta)*Rz(phi)  -- 1-D integration angles
    #   Rx(phi)*Ry(theta)*Rz(psi)       -- angular dispersion angles
    viewing = Rz(phi)*Ry(theta)*Rz(psi)
    return viewing*points
638
def orient_relative_to_beam_quaternion(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    This variant uses quaternions rather than rotation matrices for the
    computation.  It works but it is not used because it doesn't solve
    any problems.  The challenge of mapping theta/phi/psi to SO(3) does
    not disappear by calculating the transform differently.
    """
    theta, phi, psi = view
    q = Quaternion(1, [0, 0, 0])
    # Rotations about the fixed x, y and z unit vectors, composed by
    # post-multiplication.  The author notes that pre-multiplying by
    # rotations about the already-rotated unit vectors (q.rot(x), ...)
    # turned out to be equivalent, so only this form is kept.
    for angle, axis in ((theta, [1, 0, 0]), (phi, [0, 1, 0]), (psi, [0, 0, 1])):
        q = q * Quaternion.from_angle_axis(angle, axis)
    return q.rot(points)
669#orient_relative_to_beam = orient_relative_to_beam_quaternion
670
# === Quaternion class definition === BEGIN
672# Simple stand-alone quaternion class
673
# Note: this code works but is unused since quaternions didn't solve the
675# representation problem.  Leave it here in case we want to revisit this later.
676
677#import numpy as np
class Quaternion(object):
    r"""
    Quaternion(w, r) = w + ir[0] + jr[1] + kr[2]

    Quaternion.from_angle_axis(theta, r) for a rotation of angle theta
    (degrees) about an axis oriented toward the direction r.  This defines a
    unit quaternion, normalizing $r$ to the unit vector $\hat r$, and setting
    quaternion $Q = \cos(\theta/2) + \sin(\theta/2) \hat r$

    Quaternion objects can be multiplied, which applies a rotation about the
    given axis, allowing composition of rotations without risk of gimbal lock.
    The resulting quaternion is applied to a set of points using *Q.rot(v)*.
    """
    def __init__(self, w, r):
        self.w = w
        self.r = np.asarray(r, dtype='d')

    @staticmethod
    def from_angle_axis(theta, r):
        """Build quaternion as rotation theta (degrees) about axis r"""
        half_angle = np.radians(theta)/2
        axis = np.asarray(r, dtype='d')
        # Normalize by |r| (not |r|^2) so the result is a unit quaternion
        # for any nonzero axis length.
        axis = axis/np.sqrt(np.dot(axis, axis))
        return Quaternion(np.cos(half_angle), np.sin(half_angle)*axis)

    def __mul__(self, other):
        """Multiply quaternions (Hamilton product)"""
        if isinstance(other, Quaternion):
            w = self.w*other.w - np.dot(self.r, other.r)
            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
            return Quaternion(w, r)
        raise NotImplementedError("Quaternion * non-quaternion not implemented")

    def rot(self, v):
        """
        Transform point *v* by quaternion.

        *v* is a single 3-vector, or a set of points given either as rows
        (n x 3) or as columns (3 x n); the result has the same layout.
        Assumes the quaternion is a unit quaternion.
        """
        v = np.asarray(v).T
        use_transpose = (v.shape[-1] != 3)
        if use_transpose:
            v = v.T
        # v' = v + 2r x (r x v + w v) for unit quaternion (w, r)
        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
        if use_transpose:
            v = v.T
        return v.T

    def conj(self):
        """Conjugate quaternion"""
        return Quaternion(self.w, -self.r)

    def inv(self):
        """Inverse quaternion: conj(q) / |q|^2"""
        norm_sq = self.norm()**2
        # Quaternion has no division operator, so scale the parts directly.
        return Quaternion(self.w/norm_sq, -self.r/norm_sq)

    def norm(self):
        """Quaternion length"""
        return np.sqrt(self.w**2 + np.sum(self.r**2))

    def __str__(self):
        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
738
def test_qrot():
    """Quaternion checks"""
    # Rotate by 60 degrees about an axis in the y-z plane lying 60 degrees
    # from y.  That axis is itself found by rotating [0, 1, 0] by 60
    # degrees about x.
    axis = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0])
    rotation = Quaternion.from_angle_axis(60, axis)
    # Point to rotate and its expected image.
    point = [1, -1, 2]
    root3 = np.sqrt(3)
    expected = [(10+4*root3)/8, (1+2*root3)/8, (14-3*root3)/8]
    error = abs(rotation.rot(point) - expected)
    assert max(error) < 1e-14
751#test_qrot()
752#import sys; sys.exit()
# === Quaternion class definition === END
754
# Translate which orientation angles (theta, phi, psi) have non-zero
# dispersity into the number of points used along each of them.  Fewer points
# per dimension are used as more dimensions become active, keeping the total
# manageable: 100 for one dimension, 30x30 = 900 for two, and
# 15x15x15 = 3375 for three.  Keys are 0/1 flags, so tuples of booleans work
# as lookup keys too.
PD_N_TABLE = {
    active: tuple((0, 100, 30, 15)[sum(active)] if flag else 0
                  for flag in active)
    for active in (
        (0, 0, 0),
        (1, 0, 0), (0, 1, 0), (0, 0, 1),
        (1, 1, 0), (1, 0, 1), (0, 1, 1),
        (1, 1, 1),
    )
}
767
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine range from data.

    If *portion* is 1, use the full range.  Otherwise keep that fraction of
    the sorted values and return the range of what remains:

    * 'central' drops equal numbers of values from each end (at least the
      median is always kept);
    * 'top' drops values from the bottom only, so the maximum is kept.

    Any other *mode* uses the full range.  *data* may be any array-like and
    is flattened.  Returns (vmin, vmax).
    """
    data = np.asarray(data)
    if portion < 1.0 and data.size > 0 and mode in ('central', 'top'):
        values = np.sort(data, axis=None)
        n = len(values)
        if mode == 'central':
            # Drop (1-portion)/2 of the values from each end.  The previous
            # code dropped portion/2 instead, returning a near-median range.
            k = min(int((1.0 - portion)*n/2 + 0.5), (n - 1)//2)
            return values[k], values[n - 1 - k]
        # mode == 'top': drop (1-portion) of the values from the bottom,
        # clamped so that portion near 1 cannot index past the end.
        k = min(int((1.0 - portion)*n + 0.5), n - 1)
        return values[k], values[-1]
    # Default: full range
    return data.min(), data.max()
786
def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions, or 'uniform'.

    The log intensity is drawn as 24 filled contour levels on the plane
    z=-1.1, with qx and qy scaled so their maximum is 1.  If
    *calculator.limits* is set it gives the (vmin, vmax) colour range in log
    intensity; otherwise the range covers seven decades of intensity below
    the peak.
    """
    if dist == 'uniform':  # uniform is not yet in this branch
        # The rectangle width in sasmodels is sqrt(3) times the given width,
        # so shrink the jitter to get a uniform distribution of that width.
        dist, scale = 'rectangle', 1/sqrt(3)
    else:
        scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    # points per dispersity dimension depend on how many dimensions are active
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    # values stored on the calculator take precedence over those above
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator.qxqy
    Iqxy = calculator(**pars).reshape(len(qx), len(qy))

    # scale it and draw it
    Iqxy = log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        # Iqxy already holds log(I), so seven decades below the peak
        # (Imax*1e-7) is vmax + log(1e-7).  Multiplying the log value by 1e-7
        # gave a range unrelated to the data, and vmin > vmax (decreasing,
        # invalid contour levels) whenever the peak intensity was below 1.
        vmax = Iqxy.max()
        vmin = vmax + log(1e-7)
    axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                  levels=np.linspace(vmin, vmax, 24))
848
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  If *background* is not given it defaults to 1e-3.

    Returns a *calculator* function which takes the model parameters as
    keyword arguments and produces Iqxy.  The Iqxy value needs to be reshaped
    to an n x n matrix for plotting.  See the
    :class:`sasmodels.direct_model.DirectModel` class for details.  The
    calculator also carries *qxqy*, the (qx, qy) axes of the mesh, and
    *limits*, a (vmin, vmax) log-intensity range computed at orientation
    (0, 0, 0).
    """
    from sasmodels.core import load_model_info, build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # Remember the data axes so we can plot the results
    calculator.qxqy = (q, q)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # Key was misspelled 'backgound', so this default never reached the
    # model.  Presumably a small background is wanted so that log(Iqxy)
    # stays finite where the model intensity vanishes.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
889
def select_calculator(model_name, n=150, size=(10, 40, 100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped, sc_paracrystal, fcc_paracrystal or bcc_paracrystal.
    *n* is the number of points to use in the q range.  *qmax* is chosen
    based on model parameters for the given model to show something
    interesting.

    Returns *calculator* and tuple *size* (a, b, c) giving minor and major
    equatorial axes and polar axis respectively.  Shapes with fewer free
    dimensions tie their sizes together: spheres and lattices use c for all
    three, cylinders and ellipsoids use b for a.  See :func:`build_model`
    for details on the returned calculator.

    Raises ValueError for an unknown *model_name*.
    """
    a, b, c = size
    d_factor = 0.06  # lattice distortion shared by the paracrystal models

    if model_name == 'sphere':
        a = b = c
        calculator = build_model('sphere', n=n, radius=c)
    elif model_name in ('sc_paracrystal', 'fcc_paracrystal', 'bcc_paracrystal'):
        a = b = c
        # (dnn, sphere radius) for each lattice.  For fcc and bcc the nearest
        # neighbour distance dnn should be 2 radius, but the model appears to
        # use the lattice spacing rather than dnn in its calculations.
        if model_name == 'sc_paracrystal':
            dnn, radius = c, 0.5*c
        elif model_name == 'fcc_paracrystal':
            dnn, radius = 0.5*c, sqrt(2)/4 * c
        else:
            dnn, radius = 0.5*c, sqrt(3)/2 * c
        calculator = build_model(model_name, n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
    elif model_name == 'cylinder':
        a = b
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
    elif model_name == 'ellipsoid':
        a = b
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
    elif model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
    elif model_name == 'parallelepiped':
        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    return calculator, (a, b, c)
951
# Shapes selectable from the command line; the first entry is the default.
SHAPES = [
    'parallelepiped', 'sphere', 'ellipsoid', 'triaxial_ellipsoid', 'cylinder',
    'fcc_paracrystal', 'bcc_paracrystal', 'sc_paracrystal',
]
958
# Shape renderers keyed by model name; run() falls back to
# draw_parallelepiped for models that are not listed here.
DRAW_SHAPES = dict(
    fcc_paracrystal=draw_fcc,
    bcc_paracrystal=draw_bcc,
    sc_paracrystal=draw_sc,
    parallelepiped=draw_parallelepiped,
)
965
# Orientation dispersity distributions offered on the command line; the
# first entry is the default.
DISTRIBUTIONS = ['gaussian', 'rectangle', 'uniform']
# Jitter slider scale for each distribution.  The rectangle width in
# sasmodels is sqrt(3) times the given width, so its limit is divided by
# sqrt(3) to keep the maximum width at 90 degrees.
DIST_LIMITS = dict(
    gaussian=30,
    rectangle=90/sqrt(3),
    uniform=90,
)
974
975
def run(model_name='parallelepiped', size=(10, 40, 100),
        view=(0, 0, 0), jitter=(0, 0, 0),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal

    *size* gives the dimensions (a, b, c) of the shape.

    *view* gives the initial view (theta, phi, psi) of the shape.

    *jitter* gives the initial jitter (dtheta, dphi, dpsi) of the shape.

    *dist* is the type of dispersion: gaussian, rectangle, or uniform.

    *mesh* is the number of points in the dispersion mesh.

    *projection* is the map projection to use for the mesh: equirectangular,
    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.
    Only projections 1 and 2 (by 1-based position in PROJECTIONS) are
    implemented in the scattering calculation; others print a warning and the
    calculation uses projection 2.  Raises ValueError if *projection* is not
    in PROJECTIONS.

    The plot is drawn by the current PLOT_ENGINE (see :func:`set_plotter`).
    """
    if projection not in PROJECTIONS:
        raise ValueError("unknown projection %r; use one of %s"
                         % (projection, ", ".join(PROJECTIONS)))

    # generate.PROJECTION selects the projection used by the scattering
    # calculation, numbered by 1-based position in PROJECTIONS; only 1 and 2
    # are implemented so far.
    from sasmodels import generate
    generate.PROJECTION = PROJECTIONS.index(projection) + 1
    if generate.PROJECTION > 2:
        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
        generate.PROJECTION = 2

    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)
    # shapes without a dedicated renderer are drawn as parallelepipeds
    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped)

    # Discard the fixed colour range computed by build_model so that every
    # view gets its own colour range.  Comment out this line to keep a single
    # range, taken from orientation (0, 0, 0), for all views.
    calculator.limits = None

    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
1017
def _mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Interactive matplotlib viewer used by :func:`run`.

    Draws the beam, the jittered shape, the dispersity mesh and the
    scattering pattern on 3D axes.  Sliders below the plot control the view
    angles (theta, phi, psi) and the jitter widths (dtheta, dphi, dpsi), and
    any slider change redraws the scene.  Ends by calling plt.show().
    """
    # Note: travis-ci does not support mpl_toolkits.mplot3d, but this shouldn't be
    # an issue since we are lazy-loading the package on a path that isn't tested.
    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    ## create the plot window
    plt.subplots(num=None, figsize=(5.5, 5.5))
    plt.set_cmap('gist_earth')
    plt.clf()
    fig = plt.gcf()
    # CRUFT: FigureCanvasBase.set_window_title was deprecated in matplotlib
    # 3.4 and removed in 3.6; the figure manager provides it in all versions.
    # Fall back to the old canvas method if the canvas has no manager.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(projection)
    elif hasattr(fig.canvas, 'set_window_title'):
        fig.canvas.set_window_title(projection)
    #gs = gridspec.GridSpec(2,1,height_ratios=[4,1])
    #axes = plt.subplot(gs[0], projection='3d')
    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        axes.axis('square')
    except Exception:
        pass

    # CRUFT: use axisbg instead of facecolor for matplotlib<2
    # Compare the numeric major version; a string comparison would treat a
    # future "10.x" as older than "2".
    mpl_major = int(mpl.__version__.split('.')[0])
    facecolor_prop = 'facecolor' if mpl_major >= 2 else 'axisbg'
    props = {facecolor_prop: 'lightgoldenrodyellow'}

    ## add control widgets to plot
    axes_theta = plt.axes([0.05, 0.15, 0.50, 0.04], **props)
    axes_phi = plt.axes([0.05, 0.10, 0.50, 0.04], **props)
    axes_psi = plt.axes([0.05, 0.05, 0.50, 0.04], **props)
    stheta = Slider(axes_theta, u'Ξ', -90, 90, valinit=0)
    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0)
    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0)

    axes_dtheta = plt.axes([0.70, 0.15, 0.20, 0.04], **props)
    axes_dphi = plt.axes([0.70, 0.1, 0.20, 0.04], **props)
    axes_dpsi = plt.axes([0.70, 0.05, 0.20, 0.04], **props)

    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axes_dtheta, u'Δξ', 0, 2*dlimit, valinit=0)
    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0)
    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0)

    ## initial view and jitter
    theta, phi, psi = view
    stheta.set_val(theta)
    sphi.set_val(phi)
    spsi.set_val(psi)
    dtheta, dphi, dpsi = jitter
    sdtheta.set_val(dtheta)
    sdphi.set_val(dphi)
    sdpsi.set_val(dpsi)

    ## callback to draw the new view
    def _update(val, axis=None):
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        axes.cla()

        ## Visualize as person on globe
        #draw_sphere(axes)
        #draw_person_on_sphere(axes, view)

        ## Move beam instead of shape
        #draw_beam(axes, -view[:2])
        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), alpha=1.)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.,
                    draw_shape=draw_shape, projection=projection, views=3)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: _update(v, 'theta'))
    sphi.on_changed(lambda v: _update(v, 'phi'))
    spsi.on_changed(lambda v: _update(v, 'psi'))
    sdtheta.on_changed(lambda v: _update(v, 'dtheta'))
    sdphi.on_changed(lambda v: _update(v, 'dphi'))
    sdpsi.on_changed(lambda v: _update(v, 'dpsi'))

    ## initialize view
    _update(None, 'phi')

    ## go interactive
    plt.show()
1114
1115
def map_colors(z, kw):
    """
    Process matplotlib-style colour arguments in place.

    Removes 'cmap', 'alpha', 'vmin', 'vmax', 'c' and 'color' from *kw* and
    stores the resulting colour back in *kw['color']*.  Without 'c' or
    'color', the values *z* are scaled to [0, 1] between vmin and vmax
    (default: the range of *z*) and mapped through the colour map (default
    matplotlib's coolwarm).  An explicit colour array with the same shape as
    *z* is also passed through the colour map.  With *alpha* given, the alpha
    channel is set to it.  Otherwise an array colour is reduced to RGB.
    """
    from matplotlib import cm

    colormap = kw.pop('cmap', cm.coolwarm)
    alpha = kw.pop('alpha', None)
    low = kw.pop('vmin', z.min())
    high = kw.pop('vmax', z.max())
    # 'color' takes precedence over 'c'; both keys are removed
    color = kw.pop('color', kw.pop('c', None))

    if color is None:
        color = colormap(((z - low) / (high - low)).clip(0, 1))
    elif isinstance(color, np.ndarray) and color.shape == z.shape:
        color = colormap(color)

    if alpha is not None:
        color[..., 3] = alpha
    elif isinstance(color, np.ndarray):
        color = color[..., :3]
    kw['color'] = color
1143
def make_vec(*args):
    """Return a list holding each of *args* converted to a float64 numpy array."""
    return [np.asarray(arg, dtype='d') for arg in args]
1148
def make_image(z, kw):
    """
    Convert numpy array *z* into a *PIL* RGB image.

    *z* is scaled linearly so that its minimum maps to the bottom of the
    colour map and its maximum to the top; a constant *z* maps entirely to
    the bottom.  The colour map is taken from *kw['cmap']* (removed from
    *kw*), defaulting to matplotlib's coolwarm.  Any alpha channel produced by
    the colour map is dropped.
    """
    import PIL.Image
    from matplotlib import cm

    cmap = kw.pop('cmap', cm.coolwarm)

    z = np.asarray(z)
    # np.ptp rather than z.ptp(): the ndarray method was removed in NumPy 2.0.
    span = np.ptp(z)
    if span == 0:
        # avoid 0/0; a flat image gets the lowest colour
        znorm = np.zeros(z.shape)
    else:
        znorm = (z - z.min())/span
    c = cmap(znorm)
    c = c[..., :3]
    rgb = np.asarray(c*255, 'u1')
    image = PIL.Image.fromarray(rgb, mode='RGB')
    return image
1162
1163
1164_IPV_MARKERS = {
1165    'o': 'sphere',
1166}
1167_IPV_COLORS = {
1168    'w': 'white',
1169    'k': 'black',
1170    'c': 'cyan',
1171    'm': 'magenta',
1172    'y': 'yellow',
1173    'r': 'red',
1174    'g': 'green',
1175    'b': 'blue',
1176}
1177def _ipv_fix_color(kw):
1178    alpha = kw.pop('alpha', None)
1179    color = kw.get('color', None)
1180    if isinstance(color, str):
1181        color = _IPV_COLORS.get(color, color)
1182        kw['color'] = color
1183    if alpha is not None:
1184        color = kw['color']
1185        #TODO: convert color to [r, g, b, a] if not already
1186        if isinstance(color, (tuple, list)):
1187            if len(color) == 3:
1188                color = (color[0], color[1], color[2], alpha)
1189            else:
1190                color = (color[0], color[1], color[2], alpha*color[3])
1191            color = np.array(color)
1192        if isinstance(color, np.ndarray) and color.shape[-1] == 4:
1193            color[..., 3] = alpha
1194            kw['color'] = color
1195
1196def _ipv_set_transparency(kw, obj):
1197    color = kw.get('color', None)
1198    if (isinstance(color, np.ndarray)
1199            and color.shape[-1] == 4
1200            and (color[..., 3] != 1.0).any()):
1201        obj.material.transparent = True
1202        obj.material.side = "FrontSide"
1203
def ipv_axes():
    """
    Build a matplotlib style Axes interface for ipyvolume.

    Returns an object providing the subset of matplotlib Axes3D drawing
    methods used in this module (plot, plot_surface, plot_trisurf, scatter,
    contourf, pcolor, text, set_xlim/set_ylim/set_zlim, set_axis_on and
    set_axis_off).  Each forwards to the corresponding ipyvolume function;
    text() does nothing.
    """
    import ipyvolume as ipv

    class Axes:
        """
        Matplotlib Axes3D style interface to ipyvolume renderer.
        """
        # transparency can be achieved by setting the following:
        #    mesh.color = [r, g, b, alpha]
        #    mesh.material.transparent = True
        #    mesh.material.side = "FrontSide"
        # smooth(ish) rotation can be achieved by setting:
        #    slide.continuous_update = True
        #    figure.animation = 0.
        #    mesh.material.x = x
        #    mesh.material.y = y
        #    mesh.material.z = z
        # maybe need to synchronize update of x/y/z to avoid shimmy when moving
        def plot(self, x, y, z, **kw):
            """mpl style plot interface for ipyvolume"""
            _ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            ipv.plot(x, y, z, **kw)
        def plot_surface(self, x, y, z, **kw):
            """mpl style plot_surface interface for ipyvolume"""
            facecolors = kw.pop('facecolors', None)
            if facecolors is not None:
                kw['color'] = facecolors
            _ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            h = ipv.plot_surface(x, y, z, **kw)
            _ipv_set_transparency(kw, h)
            #h.material.side = "DoubleSide"
            return h
        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
            """mpl style plot_trisurf interface for ipyvolume"""
            kw.pop('linewidth', None)
            _ipv_fix_color(kw)
            x, y, z = make_vec(x, y, Z)
            if triangles is not None:
                triangles = np.asarray(triangles)
            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
            _ipv_set_transparency(kw, h)
            return h
        def scatter(self, x, y, z, **kw):
            """mpl style scatter interface for ipyvolume"""
            x, y, z = make_vec(x, y, z)
            map_colors(z, kw)
            marker = kw.get('marker', None)
            kw['marker'] = _IPV_MARKERS.get(marker, marker)
            h = ipv.scatter(x, y, z, **kw)
            _ipv_set_transparency(kw, h)
            return h
        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
            """mpl style contourf interface for ipyvolume"""
            # Don't use contour for now (although we might want to later);
            # *zdir* and *levels* are accepted for compatibility but ignored.
            self.pcolor(x, y, v, zdir='z', offset=offset, **kw)
        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
            """mpl style pcolor interface for ipyvolume"""
            x, y, v = make_vec(x, y, v)
            image = make_image(v, kw)
            xmin, xmax = x.min(), x.max()
            ymin, ymax = y.min(), y.max()
            x = np.array([[xmin, xmax], [xmin, xmax]])
            y = np.array([[ymin, ymin], [ymax, ymax]])
            z = x*0 + offset
            u = np.array([[0., 1], [0, 1]])
            v = np.array([[0., 0], [1, 1]])
            h = ipv.plot_mesh(x, y, z, u=u, v=v, texture=image, wireframe=False)
            _ipv_set_transparency(kw, h)
            h.material.side = "DoubleSide"
            return h
        def text(self, *args, **kw):
            """mpl style text interface for ipyvolume (not implemented)"""
            pass
        def set_xlim(self, limits):
            """mpl style set_xlim interface for ipyvolume"""
            ipv.xlim(*limits)
        def set_ylim(self, limits):
            """mpl style set_ylim interface for ipyvolume"""
            ipv.ylim(*limits)
        def set_zlim(self, limits):
            """mpl style set_zlim interface for ipyvolume"""
            ipv.zlim(*limits)
        def set_axis_on(self):
            """mpl style set_axis_on interface for ipyvolume"""
            # ipyvolume spells it axes_on (the counterpart of axes_off below);
            # the previous call to ipv.style.axis_on does not exist.
            ipv.style.axes_on()
        # Older method name, kept so existing callers continue to work.
        set_axes_on = set_axis_on
        def set_axis_off(self):
            """mpl style set_axis_off interface for ipyvolume"""
            ipv.style.axes_off()
    return Axes()
1298
def _ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Interactive ipyvolume viewer used by :func:`run` inside a notebook.

    Displays sliders for the view angles (theta, phi, psi) and jitter widths
    (dtheta, dphi, dpsi).  Each change draws the beam, shape, dispersity mesh
    and scattering pattern into a new ipyvolume figure that reuses the
    previous figure's camera.
    """
    from IPython.display import display
    import ipywidgets as widgets
    import ipyvolume as ipv

    axes = ipv_axes()

    def _draw(view, jitter):
        # keep the camera of the current figure for the replacement figure
        saved_camera = ipv.gcf().camera
        ipv.figure(animation=0.)  # no animation when updating object mesh

        # set small jitter as 0 if multiple pd dims
        active = sum(v > 0 for v in jitter)
        threshold = [0, 0.5, 5, 5][active]
        jitter = [0 if v < threshold else v for v in jitter]

        # Alternatives: visualize as a person on a globe (draw_sphere,
        # draw_person_on_sphere), or move the beam instead of the shape
        # (draw_beam with the negated view, draw_jitter with no rotation).

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), steps=25)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0,
                    draw_shape=draw_shape, projection=projection)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95,
                  projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        draw_axes(axes, origin=(-1, -1, -1.1))
        ipv.style.box_off()
        ipv.style.axes_off()
        ipv.xyzlabel(" ", " ", " ")

        ipv.gcf().camera = saved_camera
        ipv.show()

    # (min, max, step) for the angle and jitter-width sliders
    angle_span = (-180., 180., 1.)
    width_span = (0., 180., 1.)

    def _update(theta, phi, psi, dtheta, dphi, dpsi):
        _draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))

    def _slider(name, span, init=0.):
        return widgets.FloatSlider(
            value=init,
            min=span[0],
            max=span[1],
            step=span[2],
            description=name,
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            )

    # (argument name for _update, slider label, range, initial value)
    specs = [
        ('theta', u'Ξ', angle_span, view[0]),
        ('phi', u'φ', angle_span, view[1]),
        ('psi', u'ψ', angle_span, view[2]),
        ('dtheta', u'Δξ', width_span, jitter[0]),
        ('dphi', u'Δφ', width_span, jitter[1]),
        ('dpsi', u'Δψ', width_span, jitter[2]),
    ]
    fields = {key: _slider(label, span, init) for key, label, span, init in specs}
    ui = widgets.HBox([
        widgets.VBox([fields['theta'], fields['phi'], fields['psi']]),
        widgets.VBox([fields['dtheta'], fields['dphi'], fields['dpsi']])
    ])

    out = widgets.interactive_output(_update, fields)
    display(ui, out)
1388
1389
# Plotting back ends for run(), keyed by name; "mpl" and "ipv" are short
# aliases.  "ipyvolume" is the package name as documented by set_plotter();
# the misspelled "ipvolume" is kept for backward compatibility.
_ENGINES = {
    "matplotlib": _mpl_plot,
    "mpl": _mpl_plot,
    #"plotly": _plotly_plot,
    "ipyvolume": _ipv_plot,
    "ipvolume": _ipv_plot,
    "ipv": _ipv_plot,
}
# Current plotting engine; change it with set_plotter().
PLOT_ENGINE = _ENGINES["matplotlib"]
def set_plotter(name):
    """
    Select the plotting engine used by :func:`run`.

    *name* is a key of _ENGINES, such as 'matplotlib' or 'mpl' for matplotlib
    and 'ipv' for ipyvolume.  Raises KeyError for an unknown name, leaving the
    current engine unchanged.
    """
    global PLOT_ENGINE
    engine = _ENGINES[name]
    PLOT_ENGINE = engine
1404
def main():
    """
    Command line interface to the jitter viewer.

    Parses the command line and calls :func:`run`.  The --size, --view and
    --jitter options each take three comma-separated numbers.  Malformed
    values are reported by argparse with a usage message instead of failing
    later inside the viewer.
    """
    def _three_floats(text):
        """argparse type: parse 'a,b,c' into a tuple of three floats."""
        try:
            values = tuple(float(v) for v in text.split(','))
        except ValueError:
            values = ()
        if len(values) != 3:
            raise argparse.ArgumentTypeError(
                "expected three comma-separated numbers, got %r" % text)
        return values

    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--projection', choices=PROJECTIONS,
                        default=PROJECTIONS[0],
                        help='coordinate projection')
    # string defaults are passed through the type converter by argparse
    parser.add_argument('-s', '--size', type=_three_floats, default='10,40,100',
                        help='a,b,c lengths')
    parser.add_argument('-v', '--view', type=_three_floats, default='0,0,0',
                        help='initial view angles')
    parser.add_argument('-j', '--jitter', type=_three_floats, default='0,0,0',
                        help='initial angular dispersion')
    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS,
                        default=DISTRIBUTIONS[0],
                        help='jitter distribution')
    parser.add_argument('-m', '--mesh', type=int, default=30,
                        help='#points in theta-phi mesh')
    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0],
                        help='oriented shape')
    opts = parser.parse_args()
    run(opts.shape, size=opts.size, view=opts.view, jitter=opts.jitter,
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)
1436
# Launch the command-line viewer when this module is executed as a script,
# e.g. ``python -m sasmodels.jitter --help``.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.