source: sasmodels/sasmodels/jitter.py @ 4e28511

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 4e28511 was 4e28511, checked in by Paul Kienzle <pkienzle@…>, 7 months ago

more lint

  • Property mode set to 100755
File size: 54.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Jitter Explorer
5===============
6
7Application to explore orientation angle and angular dispersity.
8
9From the command line::
10
11    # Show docs
12    python -m sasmodels.jitter --help
13
14    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi
15    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0
16
17From a jupyter cell::
18
19    import ipyvolume as ipv
20    from sasmodels import jitter
21    import importlib; importlib.reload(jitter)
22    jitter.set_plotter("ipv")
23
24    size = (10, 40, 100)
25    view = (20, 0, 0)
26
27    #size = (15, 15, 100)
28    #view = (60, 60, 0)
29
30    dview = (0, 0, 0)
31    #dview = (5, 5, 0)
32    #dview = (15, 180, 0)
33    #dview = (180, 15, 0)
34
35    projection = 'equirectangular'
36    #projection = 'azimuthal_equidistance'
37    #projection = 'guyou'
38    #projection = 'sinusoidal'
39    #projection = 'azimuthal_equal_area'
40
41    dist = 'uniform'
42    #dist = 'gaussian'
43
44    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection)
45    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '')
46    #ipv.savefig(filename+'.png')
47"""
48from __future__ import division, print_function
49
50import argparse
51
52import numpy as np
53from numpy import pi, cos, sin, sqrt, exp, log, degrees, radians, arccos, arctan2
54
55# Too many complaints about variable names from pylint:
56#    a, b, c, u, v, x, y, z, dx, dy, dz, px, py, pz, R, Rx, Ry, Rz, ...
57# pylint: disable=invalid-name
58
def draw_beam(axes, view=(0, 0), alpha=0.5, steps=25):
    """
    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1)

    The beam is rendered on *axes* as a thin yellow tube along z, rotated
    by the (theta, phi) angles in *view*.  *steps* is the number of points
    around the circumference of the tube and *alpha* is its opacity.
    """
    tube_radius = 0.02
    half_length = 1.3

    # Tube surface: a ring of points at each of the two ends of the beam.
    around = np.linspace(0, 2 * pi, steps)
    ends = np.linspace(-1, 1, 2)
    x = tube_radius*np.outer(cos(around), np.ones_like(ends))
    y = tube_radius*np.outer(sin(around), np.ones_like(ends))
    z = half_length*np.outer(np.ones_like(around), ends)

    # Rotate the tube into the requested view, then restore the mesh shape
    # expected by plot_surface.
    theta, phi = view
    mesh_shape = x.shape
    tube = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    tube = Rz(phi)*Ry(theta)*tube
    x, y, z = [row.reshape(mesh_shape) for row in tube]
    axes.plot_surface(x, y, z, color='yellow', alpha=alpha)

    # TODO: draw endcaps on beam
    # Drawing tiny spheres of radius 0.02 at (0, 0, +/-1.3) would work;
    # an earlier attempt with plot_trisurf on the end rings did not.
91
92
def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c).

    The surface is sent through the *jitter* and *view* transforms before
    being drawn on *axes*.  *steps* sets the mesh resolution in each angular
    direction and *alpha* the opacity.  Axis labels are added with
    :func:`draw_labels`.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * pi, steps)
    polar = np.linspace(0, pi, steps)
    surface = transform_xyz(
        view, jitter,
        a*np.outer(cos(azimuth), sin(polar)),
        b*np.outer(sin(azimuth), sin(polar)),
        c*np.outer(np.ones_like(azimuth), cos(polar)),
    )
    x, y, z = surface

    axes.plot_surface(x, y, z, color='w', alpha=alpha)

    # (label, location, text direction) for each end of each axis
    axis_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, axis_labels)
113
def draw_sc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for simple cubic paracrystal

    *steps* and *alpha* are accepted so that the signature matches the
    other shape drawing functions; they are not used.
    """
    # pylint: disable=unused-argument
    _draw_crystal(axes, size, view, jitter, atoms=_build_sc())
119
def draw_fcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for face-centered cubic paracrystal

    The lattice is the simple cubic lattice from :func:`_build_sc` plus the
    centre of every cell face.  *steps* and *alpha* are accepted so that the
    signature matches the other shape drawing functions; they are not used.
    """
    # pylint: disable=unused-argument
    # Build the simple cubic crystal
    atoms = _build_sc()
    # Define the centers for each face
    # x planes at -1, 0, 1 have four centers per plane, at +/- 0.5 in y and z
    # All three lists must have the same length (12) so that the cyclic
    # concatenations below stay aligned.
    x, y, z = (
        [-1]*4 + [0]*4 + [1]*4,
        ([-0.5]*2 + [0.5]*2)*3,
        [-0.5, 0.5]*6,
    )
    # y and z planes can be generated by substituting x for y and z respectively
    atoms.extend(zip(x+y+z, y+z+x, z+x+y))
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
135
def draw_bcc(axes, size, view, jitter, steps=None, alpha=1):
    """
    Draw points for body-centered cubic paracrystal

    The lattice is the simple cubic lattice from :func:`_build_sc` plus one
    point at the centre of each octant.  *steps* and *alpha* are accepted
    so that the signature matches the other shape drawing functions; they
    are not used.
    """
    # pylint: disable=unused-argument
    half = (-0.5, 0.5)
    atoms = _build_sc()
    # Centers of the eight octants, at +/- 0.5 in each of x, y and z,
    # with x varying slowest and z fastest.
    atoms.extend((cx, cy, cz) for cx in half for cy in half for cz in half)
    _draw_crystal(axes, size, view, jitter, atoms=atoms)
150
def _draw_crystal(axes, size, view, jitter, atoms=None):
    """
    Scatter-plot the lattice points in *atoms*, scaled by *size*.

    *atoms* is a sequence of (x, y, z) lattice positions.  The first point
    is drawn in yellow as a highlight and the remainder in red.
    """
    lattice = np.asarray(atoms, 'd').T
    cell = np.asarray(size, 'd')
    x, y, z = lattice*cell[:, None]
    x, y, z = transform_xyz(view, jitter, x, y, z)
    # highlighted point first, then everything else
    axes.scatter([x[0]], [y[0]], [z[0]], c='yellow', marker='o')
    axes.scatter(x[1:], y[1:], z[1:], c='r', marker='o')
157
158def _build_sc():
159    # three planes of 9 dots for x at -1, 0 and 1
160    x, y, z = (
161        [-1]*9 + [0]*9 + [1]*9,
162        ([-1]*3 + [0]*3 + [1]*3)*3,
163        [-1, 0, 1]*9,
164    )
165    atoms = list(zip(x, y, z))
166    #print(list(enumerate(atoms)))
167    # Pull the dot at (0, 0, 1) to the front of the list
168    # It will be highlighted in the view
169    index = 14
170    highlight = atoms[index]
171    del atoms[index]
172    atoms.insert(0, highlight)
173    return atoms
174
def draw_box(axes, size, view):
    """
    Draw a wireframe box at a particular view.

    *size* gives the half-lengths (a, b, c) of the box.  Only the view
    transform is applied; there is no jitter.  Three black segments are
    drawn from corner 0 and three from corner 7.
    """
    a, b, c = size
    x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    x, y, z = transform_xyz(view, None, x, y, z)

    # NOTE(review): (0, 3) and (7, 4) join corners that differ in both x and
    # y, so they are face diagonals rather than box edges -- confirm that
    # this is the intended picture.
    segments = [(0, 1), (0, 2), (0, 3), (7, 4), (7, 5), (7, 6)]
    for start, end in segments:
        axes.plot([x[start], x[end]], [y[start], y[end]], [z[start], z[end]],
                  color='black')
190
def draw_parallelepiped(axes, size, view, jitter, steps=None,
                        color=(0.6, 1.0, 0.6), alpha=1):
    """
    Draw a parallelepiped surface, with view and jitter.

    *size* gives the half-lengths (a, b, c) of the box.  The eight corners
    are sent through the *jitter* and *view* transforms and the faces are
    drawn on *axes* as triangles with the given *color* and *alpha*.
    *steps* is accepted so that the signature matches the other shape
    drawing functions; it is not used.
    """
    # pylint: disable=unused-argument
    a, b, c = size
    # Corner signs, one column per corner.
    sign_x = np.array([+1, -1, +1, -1, +1, -1, +1, -1])
    sign_y = np.array([+1, +1, -1, -1, +1, +1, -1, -1])
    sign_z = np.array([+1, +1, +1, +1, -1, -1, -1, -1])
    # Two counter clockwise triangles per face, indexing into the corners.
    # z: up/down, x: right/left, y: front/back
    faces = np.array([
        [0, 1, 2], [3, 2, 1], # top face
        [6, 5, 4], [5, 6, 7], # bottom face
        [0, 2, 6], [6, 4, 0], # right face
        [1, 5, 7], [7, 3, 1], # left face
        [2, 3, 6], [7, 6, 3], # front face
        [4, 1, 0], [5, 1, 4], # back face
    ])

    x, y, z = transform_xyz(view, jitter, a*sign_x, b*sign_y, c*sign_z)
    axes.plot_trisurf(x, y, triangles=faces, Z=z, color=color, alpha=alpha,
                      linewidth=0)

    # Note: individual face colours cannot be controlled here.  A disabled
    # experiment coloured the c+ face by drawing a thin pink (1, 0.6, 0.6)
    # box just in front of it, so that rotations about psi would be visible.

    # (label, location, text direction) for each end of each axis
    axis_labels = [
        ('c+', [+0, +0, +c], [+1, +0, +0]),
        ('c-', [+0, +0, -c], [+0, +0, -1]),
        ('a+', [+a, +0, +0], [+0, +0, +1]),
        ('a-', [-a, +0, +0], [+0, +0, -1]),
        ('b+', [+0, +b, +0], [-1, +0, +0]),
        ('b-', [+0, -b, +0], [-1, +0, +0]),
    ]
    draw_labels(axes, view, jitter, axis_labels)
234
def draw_sphere(axes, radius=1.0, steps=25,
                center=(0, 0, 0), color='w', alpha=1.):
    """
    Draw a sphere of the given *radius* about *center*.

    *steps* sets the mesh resolution in each angular direction; *color* and
    *alpha* are passed on to the surface plot.
    """
    azimuth = np.linspace(0, 2 * pi, steps)
    polar = np.linspace(0, pi, steps)
    ones = np.ones(np.size(azimuth))

    # unit sphere mesh, then scale and shift
    unit_x = np.outer(cos(azimuth), sin(polar))
    unit_y = np.outer(sin(azimuth), sin(polar))
    unit_z = np.outer(ones, cos(polar))
    axes.plot_surface(
        radius * unit_x + center[0],
        radius * unit_y + center[1],
        radius * unit_z + center[2],
        color=color, alpha=alpha)
246
def draw_axes(axes, origin=(-1, -1, -1), length=(2, 2, 2)):
    """
    Draw wireframe axes lines, with given origin and length

    One black line is drawn from *origin* along each of x, y and z, with
    the extent along each direction taken from *length*.
    """
    x0, y0, z0 = origin
    len_x, len_y, len_z = length
    lines = (
        ([x0, x0+len_x], [y0, y0], [z0, z0]),
        ([x0, x0], [y0, y0+len_y], [z0, z0]),
        ([x0, x0], [y0, y0], [z0, z0+len_z]),
    )
    for xs, ys, zs in lines:
        axes.plot(xs, ys, zs, color='black')
254
def draw_person_on_sphere(axes, view, height=0.5, radius=1.0):
    """
    Draw a person on the surface of a sphere.

    *view* indicates (latitude, longitude, orientation)

    The figure is a stick drawing of total height *height* in the y-z plane,
    raised by *radius* along z before the view transform is applied.  The
    plot limits are set to contain the sphere plus the figure and the axis
    decorations are turned off.
    """
    # Body proportions, all derived from the overall height.
    limb_offset = height * 0.05
    head_radius = height * 0.10
    head_height = height - head_radius
    neck_length = head_radius * 0.50
    shoulder_height = height - 2*head_radius - neck_length
    torso_length = shoulder_height * 0.55
    torso_radius = torso_length * 0.30
    leg_length = shoulder_height - torso_length
    arm_length = torso_length * 0.90

    # circle for head
    around = np.linspace(0, 2 * pi, 40)
    head = (head_radius * cos(around),
            head_radius * sin(around) + head_height)

    # rectangle for body
    torso = (
        np.array([-torso_radius, torso_radius, torso_radius, -torso_radius, -torso_radius]),
        np.array([0., 0, torso_length, torso_length, 0]) + leg_length,
    )

    # arms: one is drawn as given, the other mirrored in y
    arm_y = np.array([-torso_radius - limb_offset, -torso_radius - limb_offset, -torso_radius])
    arm_z = np.array([shoulder_height - arm_length, shoulder_height, shoulder_height])

    # legs: likewise mirrored in y
    leg_y = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset])
    leg_z = np.array([0, leg_length])

    outlines = [
        head,
        torso,
        (arm_y, arm_z), (-arm_y, arm_z),
        (leg_y, leg_z), (-leg_y, leg_z),
    ]
    for part_y, part_z in outlines:
        part_x = np.zeros_like(part_y)
        xp, yp, zp = transform_xyz(view, None, part_x, part_y, part_z + radius)
        axes.plot(xp, yp, zp, color='k')

    extent = [-radius-height, radius+height]
    for set_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
        set_limits(extent)
    axes.set_axis_off()
304
def draw_jitter(axes, view, jitter, dist='gaussian',
                size=(0.1, 0.4, 1.0),
                draw_shape=draw_parallelepiped,
                projection='equirectangular',
                alpha=0.8,
                views=None):
    """
    Represent jitter as a set of shapes at different orientations.

    *jitter* is (dtheta, dphi, dpsi) and *dist* is one of 'gaussian',
    'rectangle' or 'uniform'.  For each angle with a positive width, a set
    of evenly spaced offsets is generated; *views* fixes the number of
    offsets, otherwise it is chosen from the width.  *draw_shape* is called
    for every combination whose *projection* weight is positive, using
    *size* rescaled so that its diagonal is 0.95.
    """
    project, project_weight = get_projection(projection)

    # set max diagonal to 0.95
    scale = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(scale*v for v in size)

    # Extent of the distribution in units of its width.
    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist]

    def _steps(delta):
        # Number of orientations to show along this angle.
        if views is not None:
            count = views
        else:
            count = max(3, min(25, 2*int(base*delta/5)))
        if delta > 0:
            return base*delta*np.linspace(-1, 1, count)
        return [0.]

    dtheta, dphi, dpsi = jitter
    for theta in _steps(dtheta):
        for phi in _steps(dphi):
            for psi in _steps(dpsi):
                # skip orientations the projection gives no weight to
                if project_weight(theta, phi, 1.0, 1.0) > 0:
                    draw_shape(axes, size, view, project(theta, phi, psi),
                               alpha=alpha)

    # Same limits on every axis, large enough for the shape diagonal.
    a, b, c = size
    lim = sqrt(a**2 + b**2 + c**2)
    for name in 'xyz':
        getattr(axes, 'set_'+name+'lim')([-lim, lim])
340
# Names of the available jitter projections.  Each name is accepted by
# get_projection().
PROJECTIONS = [
    # in order of PROJECTION number; do not change without updating the
    # constants in kernel_iq.c
    'equirectangular', 'sinusoidal', 'guyou', 'azimuthal_equidistance',
    'azimuthal_equal_area',
]
def get_projection(projection):
    """
    Return the pair of functions *(project, weight)* for a jitter projection.

    *project(theta_i, phi_j, psi)* maps a mesh point to the tuple
    *(latitude, longitude, psi, order)*, where *order* is 'xyz' or 'zyz'
    according to the rotation sequence the angles are meant for.
    *weight(theta_i, phi_j, w_i, w_j)* returns the weight of that mesh point
    given the weights *w_i* and *w_j* of its two coordinates.

    Raises *ValueError* if *projection* is not a known projection name.

    jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangular.  Note that the poles
        will be slightly overweighted for theta > 90 with the circle
        from theta at 90+dt winding backwards around the pole, overlapping
        the circle from theta at 90-dt.
    Guyou (hemisphere-in-a-square) **not weighted**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        With tiling, allows rotation in phi or theta through +/- 180, with
        uniform spacing.  Both theta and phi allow free rotation, with wobble
        in the orthogonal direction reasonably well behaved (though not as
        good as equirectangular phi). The forward/reverse transformations
        rely on elliptic integrals that are somewhat expensive, so the
        behaviour has to be very good to justify the cost and complexity.
        The weighting function for each point has not yet been computed.
        Note: run the module *guyou.py* directly and it will show the forward
        and reverse mappings.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not that
        useful and not completely implemented
    Gauss-Kreuger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    # pylint: disable=unused-argument
    # TODO: try Kent distribution instead of a gaussian warped by projection

    if projection == 'equirectangular':  #define PROJECTION 1
        def _project(theta_i, phi_j, psi):
            latitude, longitude = theta_i, phi_j
            return latitude, longitude, psi, 'xyz'
            #return Rx(phi_j)*Ry(theta_i)
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j*abs(cos(radians(theta_i)))
    elif projection == 'sinusoidal':  #define PROJECTION 2
        def _project(theta_i, phi_j, psi):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Points outside the sinusoid are mapped to longitude 0; they
            # are given zero weight by _weight below.
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi, 'xyz'
            #return Rx(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = theta_i
            scale = cos(radians(latitude))
            # Only points inside the sinusoid contribute.
            active = 1 if abs(phi_j) < abs(scale)*180 else 0
            return active*w_i*w_j
    elif projection == 'guyou':  #define PROJECTION 3  (eventually?)
        def _project(theta_i, phi_j, psi):
            from .guyou import guyou_invert
            #latitude, longitude = guyou_invert([theta_i], [phi_j])
            longitude, latitude = guyou_invert([phi_j], [theta_i])
            return latitude, longitude, psi, 'xyz'
            #return Rx(longitude[0])*Ry(latitude[0])
        def _weight(theta_i, phi_j, w_i, w_j):
            return w_i*w_j
    elif projection == 'azimuthal_equidistance':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi-longitude, 'zyz'
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # We are doing a conformal mapping from disk to sphere, so we need
            # a change of variables g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dtheta dphi = det(J) dlat dlong
            # where J is the jacobian from the partials of g. Using
            #     R = sqrt(theta^2 + phi^2),
            # then
            #     J = [[x/R, y/R], [-y/R^2, x/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # Note that calculates angles for Rz Ry rather than Rx Ry
        def _project(theta_i, phi_j, psi):
            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*arccos(radius))
            longitude = degrees(arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return latitude, longitude, psi, 'zyz'
            #R = Rz(longitude)*Ry(latitude)*Rz(psi)
            #return R_to_xyz(R)
            #return Rz(longitude)*Ry(latitude)
        def _weight(theta_i, phi_j, w_i, w_j):
            latitude = sqrt(theta_i**2 + phi_j**2)
            weight = sin(radians(latitude))/latitude if latitude != 0 else 1
            return weight*w_i*w_j if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    return _project, _weight
496
def R_to_xyz(R):
    """
    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix.

    *R* is indexed as *R[row, column]* and the angles are returned in
    degrees.

    Extracting Euler Angles from a Rotation Matrix
    Mike Day, Insomniac Games
    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf
    Based on: Shoemake’s "Euler Angle Conversion", Graphics Gems IV, pp.  222-229
    """
    first_row_xy = sqrt(R[0, 0]**2 + R[0, 1]**2)
    angles = (
        arctan2(R[1, 2], R[2, 2]),       # phi
        arctan2(-R[0, 2], first_row_xy), # theta
        arctan2(R[0, 1], R[0, 0]),       # psi
    )
    phi, theta, psi = [degrees(angle) for angle in angles]
    return phi, theta, psi
510
def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    The mesh has *n* points in each of theta and phi, spread according to
    *dist* ('gaussian', 'rectangle' or 'uniform') and the widths in
    *jitter*.  Each point is placed at distance *radius* from the origin
    using the chosen *projection*, rotated by *view*, and coloured by its
    weight relative to the largest weight.  Points with zero weight are not
    drawn.

    Raises *ValueError* if *dist* is not one of the supported names.
    """
    _project, _weight = get_projection(projection)

    def _rotate(theta, phi, z):
        latitude, longitude, _, order = _project(theta, phi, 0.)
        # 'zyz' projections turn about z first, the others about x
        first = Rz if order == 'zyz' else Rx
        return first(longitude)*Ry(latitude)*z

    # Mesh positions in units of the distribution width, with their weights.
    dist_x = np.linspace(-1, 1, n)
    if dist == 'gaussian':
        dist_x *= 3
        weights = exp(-0.5*dist_x**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        dist_x *= sqrt(3)
        weights = np.ones_like(dist_x)
    elif dist == 'uniform':
        weights = np.ones_like(dist_x)
    else:
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # mesh in theta, phi formed by rotating z
    dtheta, dphi, _ = jitter
    pole = np.matrix([[0], [0], [radius]])
    theta_mesh, phi_mesh = dtheta*dist_x, dphi*dist_x
    points = np.hstack([_rotate(theta_i, phi_j, pole)
                        for theta_i in theta_mesh
                        for phi_j in phi_mesh])
    dist_w = np.array([_weight(theta_i, phi_j, w_i, w_j)
                       for w_i, theta_i in zip(weights, theta_mesh)
                       for w_j, phi_j in zip(weights, phi_mesh)])

    # keep only the points that contribute, normalized to the largest weight
    visible = dist_w > 0
    points = points[:, visible]
    dist_w = dist_w[visible]
    dist_w /= max(dist_w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).flatten() for row in points]
    axes.scatter(x, y, z, c=dist_w, marker='o', vmin=0., vmax=1.)
550    points = points[:, dist_w > 0]
551    dist_w = dist_w[dist_w > 0]
552    dist_w /= max(dist_w)
553
554    # rotate relative to beam
555    points = orient_relative_to_beam(view, points)
556
557    x, y, z = [np.array(v).flatten() for v in points]
558    #plt.figure(2); plt.clf(); plt.hist(z, bins=np.linspace(-1, 1, 51))
559    axes.scatter(x, y, z, c=dist_w, marker='o', vmin=0., vmax=1.)
560
561def draw_labels(axes, view, jitter, text):
562    """
563    Draw text at a particular location.
564    """
565    labels, locations, orientations = zip(*text)
566    px, py, pz = zip(*locations)
567    dx, dy, dz = zip(*orientations)
568
569    px, py, pz = transform_xyz(view, jitter, px, py, pz)
570    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)
571
572    # TODO: zdir for labels is broken, and labels aren't appearing.
573    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)):
574        zdir = np.asarray(zdir).flatten()
575        axes.text(p[0], p[1], p[2], label, zdir=zdir)
576
577# Definition of rotation matrices comes from wikipedia:
578#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """
    Construct a matrix to rotate points about *x* by *angle* degrees.

    The result is a 3x3 *numpy.matrix*.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
586
def Ry(angle):
    """
    Construct a matrix to rotate points about *y* by *angle* degrees.

    The result is a 3x3 *numpy.matrix*.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
594
def Rz(angle):
    """
    Construct a matrix to rotate points about *z* by *angle* degrees.

    The result is a 3x3 *numpy.matrix*.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
602
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    *x*, *y* and *z* may be any array-like values; the results are returned
    as arrays with the same shape as *x*.  The jitter transform is applied
    first, then the view transform.
    """
    coords = [np.asarray(v) for v in (x, y, z)]
    original_shape = coords[0].shape
    # Pack the points as the columns of a 3 x n matrix for the rotations.
    points = np.matrix([v.flatten() for v in coords])
    points = orient_relative_to_beam(view, apply_jitter(jitter, points))
    x, y, z = [np.array(row).reshape(original_shape) for row in points]
    return x, y, z
614
def apply_jitter(jitter, points):
    """
    Apply the jitter transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    *jitter* is None (points are returned unchanged), a triple
    (dtheta, dphi, dpsi) applied as Rx(dphi) Ry(dtheta) Rz(dpsi), or a
    4-tuple whose first three values are applied as
    Rz(dphi) Ry(dtheta) Rz(dpsi).
    """
    if jitter is None:
        return points
    # Hack to deal with the fact that azimuthal_equidistance uses euler angles
    if len(jitter) == 4:
        dtheta, dphi, dpsi, _ = jitter
        outer = Rz(dphi)
    else:
        dtheta, dphi, dpsi = jitter
        outer = Rx(dphi)
    return outer*Ry(dtheta)*Rz(dpsi)*points
631
def orient_relative_to_beam(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    *view* is (theta, phi, psi) in degrees, applied as the viewing angle
    rotation Rz(phi) Ry(theta) Rz(psi).
    """
    theta, phi, psi = view
    # Other conventions that were considered:
    #   Rz(psi)*Ry(pi/2-theta)*Rz(phi)   1-D integration angles
    #   Rx(phi)*Ry(theta)*Rz(psi)        angular dispersion angle
    viewing_angle = Rz(phi)*Ry(theta)*Rz(psi)
    return viewing_angle*points
643
def orient_relative_to_beam_quaternion(view, points):
    """
    Apply the view transform to a set of points.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.

    This variant uses quaternions rather than rotation matrices for the
    computation.  It works but it is not used because it doesn't solve
    any problems.  The challenge of mapping theta/phi/psi to SO(3) does
    not disappear by calculating the transform differently.

    The rotation is composed from theta about x, phi about y and psi
    about z, each multiplied on the right of the running quaternion.
    """
    theta, phi, psi = view
    rotations = (
        (theta, [1, 0, 0]),
        (phi, [0, 1, 0]),
        (psi, [0, 0, 1]),
    )
    # Start from the identity.  Rotating the unit vectors before applying
    # each rotation turns out to be equivalent to reversing the order of
    # application, so the quaternions are simply multiplied on the right.
    q = Quaternion(1, [0, 0, 0])
    for angle, axis in rotations:
        q = q * Quaternion.from_angle_axis(angle, axis)
    return q.rot(points)
674#orient_relative_to_beam = orient_relative_to_beam_quaternion
675
# === Quaternion class definition === BEGIN
# Simple stand-alone quaternion class

# Note: this code works but is unused since quaternions didn't solve the
# representation problem.  Leave it here in case we want to revisit this later.
681
682#import numpy as np
class Quaternion(object):
    r"""
    Quaternion(w, r) = w + ir[0] + jr[1] + kr[2]

    Quaternion.from_angle_axis(theta, r) for a rotation of angle theta about
    an axis oriented toward the direction r.  This defines a unit quaternion,
    normalizing $r$ to the unit vector $\hat r$, and setting quaternion
    $Q = \cos(\theta/2) + \sin(\theta/2) \hat r$

    Quaternion objects can be multiplied, which applies a rotation about the
    given axis, allowing composition of rotations without risk of gimbal lock.
    The resulting quaternion is applied to a set of points using *Q.rot(v)*.
    """
    def __init__(self, w, r):
        # scalar part
        self.w = w
        # vector part, stored as an array of three doubles
        self.r = np.asarray(r, dtype='d')

    @staticmethod
    def from_angle_axis(theta, r):
        """
        Build quaternion as rotation theta about axis r

        *theta* is in degrees.  *r* need not be a unit vector; it is
        normalized by its length.
        """
        half_angle = np.radians(theta)/2
        r = np.asarray(r, dtype='d')
        # Divide by the length of r, not its square, so that the result is
        # a unit quaternion whatever the length of the axis vector.
        unit = r/np.sqrt(np.dot(r, r))
        return Quaternion(np.cos(half_angle), np.sin(half_angle)*unit)

    def __mul__(self, other):
        """
        Multiply quaternions

        Raises *NotImplementedError* if *other* is not a Quaternion.
        """
        if isinstance(other, Quaternion):
            w = self.w*other.w - np.dot(self.r, other.r)
            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r)
            return Quaternion(w, r)
        raise NotImplementedError("Quaternion * non-quaternion not implemented")

    def rot(self, v):
        """
        Transform point *v* by quaternion

        *v* is a single point of length 3 or a 3 x n set of points; the
        result has the same layout.
        """
        v = np.asarray(v).T
        # np.cross needs the coordinates on the last axis; flip back if the
        # transpose above moved them off it.
        use_transpose = (v.shape[-1] != 3)
        if use_transpose:
            v = v.T
        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v)
        if use_transpose:
            v = v.T
        return v.T

    def conj(self):
        """Conjugate quaternion"""
        return Quaternion(self.w, -self.r)

    def inv(self):
        """Inverse quaternion, the conjugate divided by the squared norm"""
        # Quaternion does not define division, so scale the parts directly.
        norm_sq = self.w**2 + np.sum(self.r**2)
        return Quaternion(self.w/norm_sq, -self.r/norm_sq)

    def norm(self):
        """Quaternion length"""
        return np.sqrt(self.w**2 + np.sum(self.r**2))

    def __str__(self):
        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2])
743
def test_qrot():
    """
    Quaternion checks

    Rotate the point [1, -1, 2] by 60 degrees about an axis in the y-z
    plane and compare with the expected position.
    """
    # The rotation axis is 60 degrees from y in y-z, found by rotating
    # the point [0, 1, 0] about x.
    axis = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0])
    rotation = Quaternion.from_angle_axis(60, axis)
    # The point to be rotated, and its expected rotated position.
    point = [1, -1, 2]
    expected = [(10+4*np.sqrt(3))/8, (1+2*np.sqrt(3))/8, (14-3*np.sqrt(3))/8]
    error = abs(rotation.rot(point) - expected)
    assert max(error) < 1e-14
756#test_qrot()
757#import sys; sys.exit()
# === Quaternion class definition === END
759
# Translate between the dispersity dimensions that are active, given as
# (theta, phi, psi) flags, and the number of points along each dimension.
# Each row is (points per active dimension, flags); the trailing comment
# gives the total number of points.
PD_N_TABLE = {
    active: tuple(n*flag for flag in active)
    for n, active in (
        (0, (0, 0, 0)),     # 0
        (100, (1, 0, 0)),   # 100
        (100, (0, 1, 0)),
        (100, (0, 0, 1)),
        (30, (1, 1, 0)),    # 900
        (30, (1, 0, 1)),
        (30, (0, 1, 1)),
        (15, (1, 1, 1)),    # 3375
    )
}
772
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine range from data.

    If *portion* is 1 or more, or if *mode* is not recognized, return the
    full range *(data.min(), data.max())*.  Otherwise the flattened data is
    sorted and clipped according to *mode*:

    * 'central': with *offset = round(portion*n/2)*, return the values at
      positions *offset* from the bottom and *offset* from the top.
    * 'top': with *offset = round(portion*n)*, return the value at position
      *offset* and the maximum value.

    Returns the pair *(low, high)*.
    """
    data = np.asarray(data)
    if portion < 1.0 and mode in ('central', 'top'):
        ordered = np.sort(data.flatten())
        n = len(ordered)
        if mode == 'central':
            offset = int(portion*n/2 + 0.5)
            # ordered[-0] is ordered[0], so a zero offset has to pick the
            # last element explicitly or the range collapses to (min, min).
            high = ordered[-offset] if offset > 0 else ordered[-1]
            return ordered[offset], high
        # mode == 'top': clamp the offset since rounding can push it to n,
        # which is one past the end of the data.
        offset = min(int(portion*n + 0.5), n - 1)
        return ordered[offset], ordered[-1]
    # Default: full range
    return data.min(), data.max()
791
def draw_scattering(calculator, axes, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *axes* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The pattern is drawn on a log scale as filled contours on the plane
    z = -1.1.  The colour range is *calculator.limits* if it is set, otherwise
    it covers seven decades below the peak intensity of the current view.
    """
    if dist == 'uniform':  # uniform is not yet in this branch
        # Emulate uniform with a rectangle distribution, whose width in
        # sasmodels is sqrt(3) times the given width.
        dist, scale = 'rectangle', 1/sqrt(3)
    else:
        scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd > 0, phi_pd > 0, psi_pd > 0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    # values stored on the calculator take precedence over the above
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator.qxqy
    Iqxy = calculator(**pars).reshape(len(qx), len(qy))

    # scale it and draw it
    Iqxy = log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        # Iqxy is already on a log scale, so seven decades below the peak is
        # an offset of 7*log(10).  Multiplying the log value by 10**-7 gives
        # vmin > vmax whenever the peak log intensity is negative, which
        # makes the contour levels decrease.
        vmax = Iqxy.max()
        vmin = vmax - 7*log(10)
        #vmin, vmax = clipped_range(Iqxy, portion=portion, mode='top')
    #vmin, vmax = Iqxy.min(), Iqxy.max()
    #print("range",(vmin,vmax))
    #qx, qy = np.meshgrid(qx, qy)
    if 0:  # pylint: disable=using-constant-test
        level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i')
        level[level < 0] = 0
        from matplotlib import pylab as plt
        colors = plt.get_cmap()(level)
        #from matplotlib import cm
        #colors = cm.coolwarm(level)
        #colors = cm.gist_yarg(level)
        #colors = cm.Wistia(level)
        colors[level <= 0, 3] = 0.  # set floor to transparent
        x, y = np.meshgrid(qx/qx.max(), qy/qy.max())
        axes.plot_surface(x, y, -1.1*np.ones_like(x), facecolors=colors)
    elif 1:  # pylint: disable=using-constant-test
        axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                      levels=np.linspace(vmin, vmax, 24))
    else:
        axes.pcolormesh(qx, qy, Iqxy)
853
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  A *background* of 1e-3 is added to them if none is given.

    Returns a *calculator* function which takes a dictionary of parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.

    The calculator also carries *calculator.qxqy*, the (qx, qy) axes of the
    mesh, and *calculator.limits*, a (vmin, vmax) range for log intensity
    computed from the pattern at orientation (0, 0, 0).
    """
    from sasmodels.core import load_model_info, build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # Remember the data axes so we can plot the results
    calculator.qxqy = (q, q)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # Note: the key must be spelled 'background' or the default lands under
    # a name that no model parameter uses.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
894
def select_calculator(model_name, n=150, size=(10, 40, 100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or sc/fcc/bcc_paracrystal. *n* is the number of points
    to use in the q range.  *qmax* is chosen based on model parameters for
    the given model to show something interesting.

    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
    equatorial axes and polar axis respectively.  See :func:`build_model`
    for details on the returned calculator.

    Raises *ValueError* if *model_name* is not one of the supported shapes.
    """
    a, b, c = size
    d_factor = 0.06  # for paracrystal models
    # (dnn, radius) for each lattice type.
    # For fcc and bcc the nearest neighbour distance dnn should be 2 radius,
    # but I think the model uses lattice spacing rather than dnn in its
    # calculations.
    lattices = {
        'sc_paracrystal': (c, 0.5*c),
        'fcc_paracrystal': (0.5*c, sqrt(2)/4 * c),
        'bcc_paracrystal': (0.5*c, sqrt(3)/2 * c),
    }

    if model_name in lattices:
        dnn, radius = lattices[model_name]
        calculator = build_model(model_name, n=n, dnn=dnn,
                                 d_factor=d_factor, radius=(1-d_factor)*radius,
                                 background=0)
        a = b = c
    elif model_name == 'sphere':
        calculator = build_model('sphere', n=n, radius=c)
        a = b = c
    elif model_name == 'cylinder':
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
        a = b
    elif model_name == 'ellipsoid':
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
        a = b
    elif model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
    elif model_name == 'parallelepiped':
        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    return calculator, (a, b, c)
956
# Model names offered as choices for the command line 'shape' argument.  The
# first entry is the default shape.
SHAPES = [
    'parallelepiped',
    'sphere', 'ellipsoid', 'triaxial_ellipsoid',
    'cylinder',
    'fcc_paracrystal', 'bcc_paracrystal', 'sc_paracrystal',
]

# Function used to draw each shape.  Shapes that are not listed here are
# drawn with draw_parallelepiped (see run()).
DRAW_SHAPES = {
    'fcc_paracrystal': draw_fcc,
    'bcc_paracrystal': draw_bcc,
    'sc_paracrystal': draw_sc,
    'parallelepiped': draw_parallelepiped,
}

# Jitter distributions offered as choices for the command line
# '--distribution' argument.  The first entry is the default.
DISTRIBUTIONS = [
    'gaussian', 'rectangle', 'uniform',
]
# Width limit for each distribution.  The matplotlib jitter sliders run from
# zero to twice this value.
DIST_LIMITS = {
    'gaussian': 30,
    'rectangle': 90/sqrt(3),
    'uniform': 90,
}
979
980
def run(model_name='parallelepiped', size=(10, 40, 100),
        view=(0, 0, 0), jitter=(0, 0, 0),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of: sphere, ellipsoid, triaxial_ellipsoid,
    parallelepiped, cylinder, or sc/fcc/bcc_paracrystal

    *size* gives the dimensions (a, b, c) of the shape.

    *view* gives the initial view (theta, phi, psi) of the shape.

    *jitter* gives the initial jitter (dtheta, dphi, dpsi) of the shape.

    *dist* is the type of dispersion: gaussian, rectangle, or uniform.

    *mesh* is the number of points in the dispersion mesh.

    *projection* is the map projection to use for the mesh: equirectangular,
    sinusoidal, guyou, azimuthal_equidistance, or azimuthal_equal_area.
    """
    from sasmodels import generate

    # Projections are numbered by their 1-origin position in PROJECTIONS, but
    # only 1 and 2 are implemented in the scattering function so far.  Warn
    # and fall back to projection 2 for the rest.
    projection_id = PROJECTIONS.index(projection) + 1
    if projection_id > 2:
        print("*** PROJECTION %s not implemented in scattering function ***"%projection)
        projection_id = 2
    generate.PROJECTION = projection_id

    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)
    if model_name in DRAW_SHAPES:
        draw_shape = DRAW_SHAPES[model_name]
    else:
        draw_shape = draw_parallelepiped
    #draw_shape = draw_fcc

    ## Clearing the limits gives every view its own colour range.  Comment
    ## out the next line to keep the same colour range for all views.
    calculator.limits = None

    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection)
1022
def _mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the interactive jitter explorer in a matplotlib window.

    The window holds a 3D view of the shape, its jitter mesh and the
    scattering pattern, with sliders below it for the view angles
    (theta, phi, psi) and the jitter widths (dtheta, dphi, dpsi).  Moving a
    slider redraws the 3D view.  The call blocks in *plt.show()* until the
    window is closed.
    """
    # Note: travis-ci does not support mpl_toolkits.mplot3d, but this shouldn't be
    # an issue since we are lazy-loading the package on a path that isn't tested.
    # Importing mplot3d adds projection='3d' option to subplot
    import mpl_toolkits.mplot3d  # pylint: disable=unused-variable
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider

    ## create the plot window
    #plt.hold(True)
    plt.subplots(num=None, figsize=(5.5, 5.5))
    plt.set_cmap('gist_earth')
    plt.clf()
    # CRUFT: older matplotlib sets the window title on the canvas; newer
    # versions only provide the method on the figure manager.
    canvas = plt.gcf().canvas
    manager = getattr(canvas, 'manager', None)
    if hasattr(manager, 'set_window_title'):
        manager.set_window_title(projection)
    elif hasattr(canvas, 'set_window_title'):
        canvas.set_window_title(projection)
    #gs = gridspec.GridSpec(2,1,height_ratios=[4,1])
    #axes = plt.subplot(gs[0], projection='3d')
    axes = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        axes.axis('square')
    except Exception:
        pass

    # CRUFT: use axisbg instead of facecolor for matplotlib<2
    # Compare the major version as a number so that '10.0' sorts after '2'.
    mpl_major = int(mpl.__version__.split('.')[0])
    facecolor_prop = 'facecolor' if mpl_major >= 2 else 'axisbg'
    props = {facecolor_prop: 'lightgoldenrodyellow'}

    ## add control widgets to plot
    axes_theta = plt.axes([0.05, 0.15, 0.50, 0.04], **props)
    axes_phi = plt.axes([0.05, 0.10, 0.50, 0.04], **props)
    axes_psi = plt.axes([0.05, 0.05, 0.50, 0.04], **props)
    stheta = Slider(axes_theta, u'θ', -90, 90, valinit=0)
    sphi = Slider(axes_phi, u'φ', -180, 180, valinit=0)
    spsi = Slider(axes_psi, u'ψ', -180, 180, valinit=0)

    axes_dtheta = plt.axes([0.70, 0.15, 0.20, 0.04], **props)
    axes_dphi = plt.axes([0.70, 0.1, 0.20, 0.04], **props)
    axes_dpsi = plt.axes([0.70, 0.05, 0.20, 0.04], **props)

    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axes_dtheta, u'Δθ', 0, 2*dlimit, valinit=0)
    sdphi = Slider(axes_dphi, u'Δφ', 0, 2*dlimit, valinit=0)
    sdpsi = Slider(axes_dpsi, u'Δψ', 0, 2*dlimit, valinit=0)

    ## initial view and jitter
    theta, phi, psi = view
    stheta.set_val(theta)
    sphi.set_val(phi)
    spsi.set_val(psi)
    dtheta, dphi, dpsi = jitter
    sdtheta.set_val(dtheta)
    sdphi.set_val(dphi)
    sdpsi.set_val(dpsi)

    ## callback to draw the new view
    def _update(val, axis=None):
        # pylint: disable=unused-argument
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        axes.cla()

        ## Visualize as person on globe
        #draw_sphere(axes, radius=0.5)
        #draw_person_on_sphere(axes, view, radius=0.5)

        ## Move beam instead of shape
        #draw_beam(axes, -view[:2])
        #draw_jitter(axes, (0,0,0), (0,0,0), views=3)

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), alpha=1.)
        #draw_person_on_sphere(axes, view, radius=1.2, height=0.5)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.,
                    draw_shape=draw_shape, projection=projection, views=3)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: _update(v, 'theta'))
    sphi.on_changed(lambda v: _update(v, 'phi'))
    spsi.on_changed(lambda v: _update(v, 'psi'))
    sdtheta.on_changed(lambda v: _update(v, 'dtheta'))
    sdphi.on_changed(lambda v: _update(v, 'dphi'))
    sdpsi.on_changed(lambda v: _update(v, 'dpsi'))

    ## initialize view
    _update(None, 'phi')

    ## go interactive
    plt.show()
1122
1123
def map_colors(z, kw):
    """
    Process matplotlib-style colour arguments.

    Pulls 'cmap', 'alpha', 'vmin', 'vmax', 'c' and 'color' from the *kw*
    dictionary, setting *kw['color']* to the resulting colour.

    If neither 'c' nor 'color' is given, *z* is scaled to [0, 1] using
    *vmin* and *vmax* (default: the range of *z*) and converted to colours
    with *cmap* (default: matplotlib coolwarm).  A colour array with the same
    shape as *z* is passed through *cmap* directly.  Any other colour is
    used as given.

    Without *alpha*, colour arrays are trimmed to RGB.  With *alpha*, the
    colour is converted to RGBA with the alpha channel set to *alpha*.  A
    colour name or plain sequence is converted with matplotlib *to_rgba*,
    so it must be a single colour.  Arrays supplied by the caller are not
    modified.
    """
    cmap = kw.pop('cmap', None)
    if cmap is None:
        # Only need matplotlib if the caller did not provide a colour map.
        from matplotlib import cm
        cmap = cm.coolwarm
    alpha = kw.pop('alpha', None)
    vmin = kw.pop('vmin', z.min())
    vmax = kw.pop('vmax', z.max())
    c = kw.pop('c', None)
    color = kw.pop('color', c)
    if color is None:
        span = vmax - vmin
        if span == 0:
            # constant data: avoid 0/0 and use the bottom of the colour map
            znorm = np.zeros(np.shape(z))
        else:
            znorm = ((z - vmin) / span).clip(0, 1)
        color = cmap(znorm)
    elif isinstance(color, np.ndarray) and color.shape == z.shape:
        color = cmap(color)
    if alpha is None:
        if isinstance(color, np.ndarray):
            color = color[..., :3]
    else:
        if isinstance(color, np.ndarray):
            # work on a float copy so the caller's array is left alone
            color = np.array(color, dtype='d')
        else:
            from matplotlib.colors import to_rgba
            color = np.array(to_rgba(color))
        if color.shape[-1] == 3:
            # RGB: append the alpha channel
            alpha_channel = np.full(color.shape[:-1] + (1,), alpha, dtype='d')
            color = np.concatenate((color, alpha_channel), axis=-1)
        else:
            color[..., 3] = alpha
    kw['color'] = color
1151
def make_vec(*args):
    """Return a list with each of *args* converted to a double precision array"""
    # Note: the shape of each argument is preserved; nothing is flattened.
    vectors = []
    for value in args:
        vectors.append(np.asarray(value, 'd'))
    return vectors
1156
def make_image(z, kw):
    """
    Convert numpy array *z* into a *PIL* RGB image.

    *z* is scaled to [0, 1] over its own range then converted to colours
    using *kw['cmap']*, which is removed from *kw*.  The default colour map
    is matplotlib coolwarm.  A constant *z* maps to the bottom of the colour
    map.
    """
    import PIL.Image
    from matplotlib import cm

    cmap = kw.pop('cmap', cm.coolwarm)

    z = np.asarray(z)
    # Use np.ptp(z) since the z.ptp() method is not available in numpy 2.
    span = np.ptp(z)
    if span == 0:
        # constant image: avoid 0/0
        znorm = np.zeros(z.shape)
    else:
        znorm = (z - z.min())/span
    c = cmap(znorm)
    c = c[..., :3]  # drop the alpha channel
    rgb = np.asarray(c*255, 'u1')
    image = PIL.Image.fromarray(rgb, mode='RGB')
    return image
1170
1171
# Translation from matplotlib marker codes to ipyvolume marker names.  Markers
# that are not listed are passed through unchanged (see Axes.scatter).
_IPV_MARKERS = {
    'o': 'sphere',
}
# Translation from matplotlib single letter colour codes to colour names.
# Colours that are not listed are passed through unchanged (see
# _ipv_fix_color).
_IPV_COLORS = {
    'w': 'white',
    'k': 'black',
    'c': 'cyan',
    'm': 'magenta',
    'y': 'yellow',
    'r': 'red',
    'g': 'green',
    'b': 'blue',
}
1185def _ipv_fix_color(kw):
1186    alpha = kw.pop('alpha', None)
1187    color = kw.get('color', None)
1188    if isinstance(color, str):
1189        color = _IPV_COLORS.get(color, color)
1190        kw['color'] = color
1191    if alpha is not None:
1192        color = kw['color']
1193        #TODO: convert color to [r, g, b, a] if not already
1194        if isinstance(color, (tuple, list)):
1195            if len(color) == 3:
1196                color = (color[0], color[1], color[2], alpha)
1197            else:
1198                color = (color[0], color[1], color[2], alpha*color[3])
1199            color = np.array(color)
1200        if isinstance(color, np.ndarray) and color.shape[-1] == 4:
1201            color[..., 3] = alpha
1202            kw['color'] = color
1203
1204def _ipv_set_transparency(kw, obj):
1205    color = kw.get('color', None)
1206    if (isinstance(color, np.ndarray)
1207            and color.shape[-1] == 4
1208            and (color[..., 3] != 1.0).any()):
1209        obj.material.transparent = True
1210        obj.material.side = "FrontSide"
1211
def ipv_axes():
    """
    Build a matplotlib style Axes interface for ipyvolume

    Returns an object with a subset of the matplotlib Axes3D methods, each
    of which draws on the current ipyvolume figure.
    """
    import ipyvolume as ipv

    class Axes(object):
        """
        Matplotlib Axes3D style interface to ipyvolume renderer.
        """
        # pylint: disable=no-self-use,no-init
        # transparency can be achieved by setting the following:
        #    mesh.color = [r, g, b, alpha]
        #    mesh.material.transparent = True
        #    mesh.material.side = "FrontSide"
        # smooth(ish) rotation can be achieved by setting:
        #    slide.continuous_update = True
        #    figure.animation = 0.
        #    mesh.material.x = x
        #    mesh.material.y = y
        #    mesh.material.z = z
        # maybe need to synchronize update of x/y/z to avoid shimmy when moving
        def plot(self, x, y, z, **kw):
            """mpl style plot interface for ipyvolume"""
            _ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            ipv.plot(x, y, z, **kw)
        def plot_surface(self, x, y, z, **kw):
            """mpl style plot_surface interface for ipyvolume"""
            facecolors = kw.pop('facecolors', None)
            if facecolors is not None:
                kw['color'] = facecolors
            _ipv_fix_color(kw)
            x, y, z = make_vec(x, y, z)
            h = ipv.plot_surface(x, y, z, **kw)
            _ipv_set_transparency(kw, h)
            #h.material.side = "DoubleSide"
            return h
        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw):
            """mpl style plot_trisurf interface for ipyvolume"""
            kw.pop('linewidth', None)  # not passed on to ipyvolume
            _ipv_fix_color(kw)
            x, y, z = make_vec(x, y, Z)
            if triangles is not None:
                triangles = np.asarray(triangles)
            h = ipv.plot_trisurf(x, y, z, triangles=triangles, **kw)
            _ipv_set_transparency(kw, h)
            return h
        def scatter(self, x, y, z, **kw):
            """mpl style scatter interface for ipyvolume"""
            x, y, z = make_vec(x, y, z)
            map_colors(z, kw)
            # Only pass a marker if one was given so that ipyvolume uses its
            # own default rather than receiving marker=None.
            marker = kw.pop('marker', None)
            if marker is not None:
                kw['marker'] = _IPV_MARKERS.get(marker, marker)
            h = ipv.scatter(x, y, z, **kw)
            _ipv_set_transparency(kw, h)
            return h
        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw):
            """mpl style contourf interface for ipyvolume"""
            # pylint: disable=unused-argument
            # Don't use contour for now (although we might want to later)
            self.pcolor(x, y, v, zdir='z', offset=offset, **kw)
        def pcolor(self, x, y, v, zdir='z', offset=0, **kw):
            """mpl style pcolor interface for ipyvolume"""
            # pylint: disable=unused-argument
            x, y, v = make_vec(x, y, v)
            image = make_image(v, kw)
            # Show the image as a texture on a rectangle spanning the x-y
            # range at height z=offset.
            xmin, xmax = x.min(), x.max()
            ymin, ymax = y.min(), y.max()
            x = np.array([[xmin, xmax], [xmin, xmax]])
            y = np.array([[ymin, ymin], [ymax, ymax]])
            z = x*0 + offset
            u = np.array([[0., 1], [0, 1]])
            v = np.array([[0., 0], [1, 1]])
            h = ipv.plot_mesh(x, y, z, u=u, v=v, texture=image, wireframe=False)
            _ipv_set_transparency(kw, h)
            h.material.side = "DoubleSide"
            return h
        def text(self, *args, **kw):
            """mpl style text interface for ipyvolume (text is not drawn)"""
            pass
        def set_xlim(self, limits):
            """mpl style set_xlim interface for ipyvolume"""
            ipv.xlim(*limits)
        def set_ylim(self, limits):
            """mpl style set_ylim interface for ipyvolume"""
            ipv.ylim(*limits)
        def set_zlim(self, limits):
            """mpl style set_zlim interface for ipyvolume"""
            ipv.zlim(*limits)
        def set_axis_on(self):
            """mpl style set_axis_on interface for ipyvolume"""
            # NOTE(review): the ipyvolume style helpers appear to be named
            # axes_on/axes_off; confirm against the installed version.
            ipv.style.axes_on()
        # CRUFT: original name for set_axis_on, kept for existing callers
        set_axes_on = set_axis_on
        def set_axis_off(self):
            """mpl style set_axis_off interface for ipyvolume"""
            ipv.style.axes_off()
    return Axes()
1309
def _ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection):
    """
    Show the interactive jitter explorer in a jupyter cell using ipyvolume.

    Displays sliders for the view angles (theta, phi, psi) and the jitter
    widths (dtheta, dphi, dpsi), initialized from *view* and *jitter*.  The
    3D scene is redrawn whenever a slider value changes.
    """
    from IPython.display import display
    import ipywidgets as widgets
    import ipyvolume as ipv

    axes = ipv_axes()

    def _draw(view, jitter):
        # remember the camera so the redraw keeps the current viewpoint
        camera = ipv.gcf().camera
        #print(ipv.gcf().__dict__.keys())
        #print(dir(ipv.gcf()))
        ipv.figure(animation=0.)  # no animation when updating object mesh

        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0.5, 5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]

        ## Visualize as person on globe
        #draw_beam(axes, (0, 0))
        #draw_sphere(axes, radius=0.5)
        #draw_person_on_sphere(axes, view, radius=0.5)

        ## Move beam instead of shape
        #draw_beam(axes, view=(-view[0], -view[1]))
        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0))

        ## Move shape and draw scattering
        draw_beam(axes, (0, 0), steps=25)
        draw_jitter(axes, view, jitter, dist=dist, size=size, alpha=1.0,
                    draw_shape=draw_shape, projection=projection)
        draw_mesh(axes, view, jitter, dist=dist, n=mesh, radius=0.95,
                  projection=projection)
        draw_scattering(calculator, axes, view, jitter, dist=dist)

        draw_axes(axes, origin=(-1, -1, -1.1))
        ipv.style.box_off()
        ipv.style.axes_off()
        ipv.xyzlabel(" ", " ", " ")

        ipv.gcf().camera = camera
        ipv.show()


    # (min, max, step) for the view and jitter sliders
    trange, prange = (-180., 180., 1.), (-180., 180., 1.)
    dtrange, dprange = (0., 180., 1.), (0., 180., 1.)

    ## Super simple interface, but uses non-ascii variable names
    # θ φ ψ Δθ Δφ Δψ
    #def update(**kw):
    #    view = kw['θ'], kw['φ'], kw['ψ']
    #    jitter = kw['Δθ'], kw['Δφ'], kw['Δψ']
    #    draw(view, jitter)
    #widgets.interact(update, θ=trange, φ=prange, ψ=prange, Δθ=dtrange, Δφ=dprange, Δψ=dprange)

    def _update(theta, phi, psi, dtheta, dphi, dpsi):
        _draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi))

    def _slider(name, limits, init=0.):
        # limits is (min, max, step)
        return widgets.FloatSlider(
            value=init,
            min=limits[0],
            max=limits[1],
            step=limits[2],
            description=name,
            disabled=False,
            #continuous_update=True,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            )
    theta = _slider(u'θ', trange, view[0])
    phi = _slider(u'φ', prange, view[1])
    psi = _slider(u'ψ', prange, view[2])
    dtheta = _slider(u'Δθ', dtrange, jitter[0])
    dphi = _slider(u'Δφ', dprange, jitter[1])
    dpsi = _slider(u'Δψ', dprange, jitter[2])
    # slider for each argument of _update
    fields = {
        'theta': theta, 'phi': phi, 'psi': psi,
        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi,
    }
    ui = widgets.HBox([
        widgets.VBox([theta, phi, psi]),
        widgets.VBox([dtheta, dphi, dpsi])
    ])

    out = widgets.interactive_output(_update, fields)
    display(ui, out)
1399
1400
# Plotting function for each engine name accepted by set_plotter().
_ENGINES = {
    "matplotlib": _mpl_plot,
    "mpl": _mpl_plot,
    #"plotly": _plotly_plot,
    "ipyvolume": _ipv_plot,
    "ipvolume": _ipv_plot,  # CRUFT: misspelling kept for existing callers
    "ipv": _ipv_plot,
}
# Plotting function used by run().  Change it with set_plotter().
PLOT_ENGINE = _ENGINES["matplotlib"]
def set_plotter(name):
    """
    Setting the plotting engine to matplotlib/ipyvolume or equivalently mpl/ipv.

    Raises *KeyError* if *name* is not a known plotting engine.
    """
    global PLOT_ENGINE
    PLOT_ENGINE = _ENGINES[name]
1415
def main():
    """
    Command line interface to the jitter viewer.

    Parses the projection, size, view, jitter, distribution, mesh and shape
    options from the command line then calls :func:`run`.
    """
    def _floats(text):
        # "a,b,c" => (a, b, c)
        return tuple(float(part) for part in text.split(','))

    # (flags, settings) for each argument in the order shown by --help
    arguments = (
        (('-p', '--projection'),
         dict(choices=PROJECTIONS, default=PROJECTIONS[0],
              help='coordinate projection')),
        (('-s', '--size'),
         dict(type=str, default='10,40,100', help='a,b,c lengths')),
        (('-v', '--view'),
         dict(type=str, default='0,0,0', help='initial view angles')),
        (('-j', '--jitter'),
         dict(type=str, default='0,0,0', help='initial angular dispersion')),
        (('-d', '--distribution'),
         dict(choices=DISTRIBUTIONS, default=DISTRIBUTIONS[0],
              help='jitter distribution')),
        (('-m', '--mesh'),
         dict(type=int, default=30, help='#points in theta-phi mesh')),
        (('shape',),
         dict(choices=SHAPES, nargs='?', default=SHAPES[0],
              help='oriented shape')),
    )
    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    for flags, settings in arguments:
        parser.add_argument(*flags, **settings)
    opts = parser.parse_args()

    run(opts.shape,
        size=_floats(opts.size),
        view=_floats(opts.view),
        jitter=_floats(opts.jitter),
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)
1447
# Start the command line interface when the module is run as a script.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.