Changeset 5f12750 in sasmodels


Ignore:
Timestamp:
Jan 30, 2019 11:15:24 AM (3 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
1511a60c
Parents:
119073a
Message:

support ipyvolume for jitter in jupyter notebooks

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/jitter.py

    r1198f90 r5f12750  
    55 
    66Application to explore orientation angle and angular dispersity. 
     7 
     8From the command line:: 
     9 
     10    # Show docs 
     11    python -m sasmodels.jitter --help 
     12 
     13    # Guyou projection jitter, uniform over 20 degree theta and 10 in phi 
     14    python -m sasmodels.jitter --projection=guyou --dist=uniform --jitter=20,10,0 
     15 
     16From a jupyter cell:: 
     17 
     18    import ipyvolume as ipv 
     19    from sasmodels import jitter 
     20    import importlib; importlib.reload(jitter) 
     21    jitter.set_plotter("ipv") 
     22 
     23    size = (10, 40, 100) 
     24    view = (20, 0, 0) 
     25 
     26    #size = (15, 15, 100) 
     27    #view = (60, 60, 0) 
     28 
     29    dview = (0, 0, 0) 
     30    #dview = (5, 5, 0) 
     31    #dview = (15, 180, 0) 
     32    #dview = (180, 15, 0) 
     33 
     34    projection = 'equirectangular' 
     35    #projection = 'azimuthal_equidistance' 
     36    #projection = 'guyou' 
     37    #projection = 'sinusoidal' 
     38    #projection = 'azimuthal_equal_area' 
     39 
     40    dist = 'uniform' 
     41    #dist = 'gaussian' 
     42 
     43    jitter.run(size=size, view=view, jitter=dview, dist=dist, projection=projection) 
     44    #filename = projection+('_theta' if dview[0] == 180 else '_phi' if dview[1] == 180 else '') 
     45    #ipv.savefig(filename+'.png') 
    746""" 
    847from __future__ import division, print_function 
     
    1049import argparse 
    1150 
    12 try: # CRUFT: travis-ci does not support mpl_toolkits.mplot3d 
    13     import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot 
    14 except ImportError: 
    15     pass 
    16  
    17 import matplotlib.pyplot as plt 
    18 from matplotlib.widgets import Slider 
    1951import numpy as np 
    2052from numpy import pi, cos, sin, sqrt, exp, degrees, radians 
    2153 
    22 def draw_beam(axes, view=(0, 0)): 
     54def draw_beam(axes, view=(0, 0), alpha=0.5): 
    2355    """ 
    2456    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1) 
    2557    """ 
    2658    #axes.plot([0,0],[0,0],[1,-1]) 
    27     #axes.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8) 
    28  
    29     steps = 25 
    30     u = np.linspace(0, 2 * np.pi, steps) 
    31     v = np.linspace(-1, 1, steps) 
     59    #axes.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=alpha) 
     60 
     61    steps = [6, 6] 
     62    u = np.linspace(0, 2 * np.pi, steps[0]) 
     63    v = np.linspace(-1, 1, steps[1]) 
    3264 
    3365    r = 0.02 
     
    4274    x, y, z = [v.reshape(shape) for v in points] 
    4375 
    44     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5) 
     76    axes.plot_surface(x, y, z, color='yellow', alpha=alpha) 
    4577 
    4678def draw_ellipsoid(axes, size, view, jitter, steps=25, alpha=1): 
     
    5486    x, y, z = transform_xyz(view, jitter, x, y, z) 
    5587 
    56     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha) 
     88    axes.plot_surface(x, y, z, color='w', alpha=alpha) 
    5789 
    5890    draw_labels(axes, view, jitter, [ 
     
    141173 
    142174    x, y, z = transform_xyz(view, jitter, x, y, z) 
    143     axes.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha) 
    144  
    145     # Draw pink face on box. 
     175    color = [0.6, 1, 0.6]  # pale green 
     176    axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha) 
     177 
     178    # Colour the c+ face of the box. 
    146179    # Since I can't control face color, instead draw a thin box situated just 
    147180    # in front of the "c+" face.  Use the c face so that rotations about psi 
    148181    # rotate that face. 
    149     if 1: 
     182    if 0: 
     183        #color = [1, 0.6, 0.6]  # pink 
    150184        x = a*np.array([+1, -1, +1, -1, +1, -1, +1, -1]) 
    151185        y = b*np.array([+1, +1, -1, -1, +1, +1, -1, -1]) 
    152186        z = c*np.array([+1, +1, +1, +1, -1, -1, -1, -1]) 
    153187        x, y, z = transform_xyz(view, jitter, x, y, abs(z)+0.001) 
    154         axes.plot_trisurf(x, y, triangles=tri, Z=z, color=[1, 0.6, 0.6], alpha=alpha) 
     188        axes.plot_trisurf(x, y, triangles=tri, Z=z, color=color, alpha=alpha) 
    155189 
    156190    draw_labels(axes, view, jitter, [ 
     
    163197    ]) 
    164198 
    165 def draw_sphere(axes, radius=10., steps=100): 
     199def draw_sphere(axes, radius=0.5, steps=25): 
    166200    """Draw a sphere""" 
    167201    u = np.linspace(0, 2 * np.pi, steps) 
     
    171205    y = radius * np.outer(np.sin(u), np.sin(v)) 
    172206    z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) 
    173     axes.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 
    174  
    175 def draw_jitter(axes, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0), 
    176                 draw_shape=draw_parallelepiped): 
     207    axes.plot_surface(x, y, z, color='w') 
     208    #axes.plot_wireframe(x, y, z) 
     209 
     210def draw_person_on_sphere(axes, view, height=0.5, radius=0.5): 
     211    limb_offset = height * 0.05 
     212    head_radius = height * 0.10 
     213    head_height = height - head_radius 
     214    neck_length = head_radius * 0.50 
     215    shoulder_height = height - 2*head_radius - neck_length 
     216    torso_length = shoulder_height * 0.55 
     217    torso_radius = torso_length * 0.30 
     218    leg_length = shoulder_height - torso_length 
     219    arm_length = torso_length * 0.90 
     220 
     221    def _draw_part(x, z): 
     222        y = np.zeros_like(x) 
     223        xp, yp, zp = transform_xyz(view, None, x, y, z + radius) 
     224        axes.plot(xp, yp, zp, color='k') 
     225 
     226    # circle for head 
     227    u = np.linspace(0, 2 * np.pi, 40) 
     228    x = head_radius * np.cos(u) 
     229    z = head_radius * np.sin(u) + head_height 
     230    _draw_part(x, z) 
     231 
     232    # rectangle for body 
     233    x = np.array([-torso_radius, torso_radius, torso_radius, -torso_radius, -torso_radius]) 
     234    z = np.array([0., 0, torso_length, torso_length, 0]) + leg_length 
     235    _draw_part(x, z) 
     236 
     237    # arms 
     238    x = np.array([-torso_radius - limb_offset, -torso_radius - limb_offset, -torso_radius]) 
     239    z = np.array([shoulder_height - arm_length, shoulder_height, shoulder_height]) 
     240    _draw_part(x, z) 
     241    _draw_part(-x, z) 
     242 
     243    # legs 
     244    x = np.array([-torso_radius + limb_offset, -torso_radius + limb_offset]) 
     245    z = np.array([0, leg_length]) 
     246    _draw_part(x, z) 
     247    _draw_part(-x, z) 
     248 
     249    limits = [-radius-height, radius+height] 
     250    axes.set_xlim(limits) 
     251    axes.set_ylim(limits) 
     252    axes.set_zlim(limits) 
     253    axes.set_axis_off() 
     254 
     255def draw_jitter(axes, view, jitter, dist='gaussian',  
     256                size=(0.1, 0.4, 1.0), 
     257                draw_shape=draw_parallelepiped,  
     258                projection='equirectangular', 
     259                alpha=0.8, 
     260                views=None): 
    177261    """ 
    178262    Represent jitter as a set of shapes at different orientations. 
    179263    """ 
     264    project, weight = get_projection(projection) 
     265 
    180266    # set max diagonal to 0.95 
    181267    scale = 0.95/sqrt(sum(v**2 for v in size)) 
    182268    size = tuple(scale*v for v in size) 
    183269 
    184     #np.random.seed(10) 
    185     #cloud = np.random.randn(10,3) 
    186     cloud = [ 
    187         [-1, -1, -1], 
    188         [-1, -1, +0], 
    189         [-1, -1, +1], 
    190         [-1, +0, -1], 
    191         [-1, +0, +0], 
    192         [-1, +0, +1], 
    193         [-1, +1, -1], 
    194         [-1, +1, +0], 
    195         [-1, +1, +1], 
    196         [+0, -1, -1], 
    197         [+0, -1, +0], 
    198         [+0, -1, +1], 
    199         [+0, +0, -1], 
    200         [+0, +0, +0], 
    201         [+0, +0, +1], 
    202         [+0, +1, -1], 
    203         [+0, +1, +0], 
    204         [+0, +1, +1], 
    205         [+1, -1, -1], 
    206         [+1, -1, +0], 
    207         [+1, -1, +1], 
    208         [+1, +0, -1], 
    209         [+1, +0, +0], 
    210         [+1, +0, +1], 
    211         [+1, +1, -1], 
    212         [+1, +1, +0], 
    213         [+1, +1, +1], 
    214     ] 
    215270    dtheta, dphi, dpsi = jitter 
    216     if dtheta == 0: 
    217         cloud = [v for v in cloud if v[0] == 0] 
    218     if dphi == 0: 
    219         cloud = [v for v in cloud if v[1] == 0] 
    220     if dpsi == 0: 
    221         cloud = [v for v in cloud if v[2] == 0] 
    222     draw_shape(axes, size, view, [0, 0, 0], steps=100, alpha=0.8) 
    223     scale = {'gaussian':1, 'rectangle':1/sqrt(3), 'uniform':1/3}[dist] 
    224     for point in cloud: 
    225         delta = [scale*dtheta*point[0], scale*dphi*point[1], scale*dpsi*point[2]] 
    226         draw_shape(axes, size, view, delta, alpha=0.8) 
     271    base = {'gaussian':3, 'rectangle':sqrt(3), 'uniform':1}[dist] 
     272    def steps(delta): 
     273        limit = base*delta 
     274        if views is None: 
     275            n = max(3, min(25, 2*int(base*delta/15))) 
     276        else: 
     277            n = views 
     278        s = base*delta*np.linspace(-1, 1, n) if delta > 0 else [0] 
     279        return s 
     280    stheta = steps(dtheta) 
     281    sphi = steps(dphi) 
     282    spsi = steps(dpsi) 
     283    #print(stheta, sphi, spsi) 
     284    for theta in stheta: 
     285        for phi in sphi: 
     286            for psi in spsi: 
     287                w = weight(theta, phi, 1.0, 1.0) 
     288                if w > 0: 
     289                    dview = project(theta, phi, psi) 
     290                    draw_shape(axes, size, view, dview, alpha=alpha) 
    227291    for v in 'xyz': 
    228292        a, b, c = size 
    229293        lim = np.sqrt(a**2 + b**2 + c**2) 
    230294        getattr(axes, 'set_'+v+'lim')([-lim, lim]) 
    231         getattr(axes, v+'axis').label.set_text(v) 
     295        #getattr(axes, v+'axis').label.set_text(v) 
    232296 
    233297PROJECTIONS = [ 
     
    237301    'azimuthal_equal_area', 
    238302] 
    239 def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian', 
    240               projection='equirectangular'): 
    241     """ 
    242     Draw the dispersion mesh showing the theta-phi orientations at which 
    243     the model will be evaluated. 
    244  
     303def get_projection(projection): 
     304 
     305    """ 
    245306    jitter projections 
    246307    <https://en.wikipedia.org/wiki/List_of_map_projections> 
     
    298359    # TODO: try Kent distribution instead of a gaussian warped by projection 
    299360 
    300     dist_x = np.linspace(-1, 1, n) 
    301     weights = np.ones_like(dist_x) 
    302     if dist == 'gaussian': 
    303         dist_x *= 3 
    304         weights = exp(-0.5*dist_x**2) 
    305     elif dist == 'rectangle': 
    306         # Note: uses sasmodels ridiculous definition of rectangle width 
    307         dist_x *= sqrt(3) 
    308     elif dist == 'uniform': 
    309         pass 
    310     else: 
    311         raise ValueError("expected dist to be gaussian, rectangle or uniform") 
    312  
    313361    if projection == 'equirectangular':  #define PROJECTION 1 
    314         def _rotate(theta_i, phi_j): 
    315             return Rx(phi_j)*Ry(theta_i) 
     362        def _project(theta_i, phi_j, psi): 
     363            latitude, longitude = theta_i, phi_j 
     364            return latitude, longitude, psi 
     365            #return Rx(phi_j)*Ry(theta_i) 
    316366        def _weight(theta_i, phi_j, w_i, w_j): 
    317367            return w_i*w_j*abs(cos(radians(theta_i))) 
    318368    elif projection == 'sinusoidal':  #define PROJECTION 2 
    319         def _rotate(theta_i, phi_j): 
     369        def _project(theta_i, phi_j, psi): 
    320370            latitude = theta_i 
    321371            scale = cos(radians(latitude)) 
    322372            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0 
    323373            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    324             return Rx(longitude)*Ry(latitude) 
    325         def _weight(theta_i, phi_j, w_i, w_j): 
     374            return latitude, longitude, psi 
     375            #return Rx(longitude)*Ry(latitude) 
     376        def _project(theta_i, phi_j, w_i, w_j): 
    326377            latitude = theta_i 
    327378            scale = cos(radians(latitude)) 
     
    329380            return active*w_i*w_j 
    330381    elif projection == 'guyou':  #define PROJECTION 3  (eventually?) 
    331         def _rotate(theta_i, phi_j): 
     382        def _project(theta_i, phi_j, psi): 
    332383            from .guyou import guyou_invert 
    333384            #latitude, longitude = guyou_invert([theta_i], [phi_j]) 
    334385            longitude, latitude = guyou_invert([phi_j], [theta_i]) 
    335             return Rx(longitude[0])*Ry(latitude[0]) 
     386            return latitude, longitude, psi 
     387            #return Rx(longitude[0])*Ry(latitude[0]) 
    336388        def _weight(theta_i, phi_j, w_i, w_j): 
    337389            return w_i*w_j 
    338     elif projection == 'azimuthal_equidistance':  # Note: Rz Ry, not Rx Ry 
    339         def _rotate(theta_i, phi_j): 
     390    elif projection == 'azimuthal_equidistance':  
     391        # Note that calculates angles for Rz Ry rather than Rx Ry 
     392        def _project(theta_i, phi_j, psi): 
    340393            latitude = sqrt(theta_i**2 + phi_j**2) 
    341394            longitude = degrees(np.arctan2(phi_j, theta_i)) 
    342395            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    343             return Rz(longitude)*Ry(latitude) 
     396            return latitude, longitude, psi-longitude, 'zyz' 
     397            #R = Rz(longitude)*Ry(latitude)*Rz(psi) 
     398            #return R_to_xyz(R) 
     399            #return Rz(longitude)*Ry(latitude) 
    344400        def _weight(theta_i, phi_j, w_i, w_j): 
    345401            # Weighting for each point comes from the integral: 
     
    375431            return weight*w_i*w_j if latitude < 180 else 0 
    376432    elif projection == 'azimuthal_equal_area': 
    377         def _rotate(theta_i, phi_j): 
     433        # Note that calculates angles for Rz Ry rather than Rx Ry 
     434        def _project(theta_i, phi_j, psi): 
    378435            radius = min(1, sqrt(theta_i**2 + phi_j**2)/180) 
    379436            latitude = 180-degrees(2*np.arccos(radius)) 
    380437            longitude = degrees(np.arctan2(phi_j, theta_i)) 
    381438            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude)) 
    382             return Rz(longitude)*Ry(latitude) 
     439            return latitude, longitude, psi, "zyz" 
     440            #R = Rz(longitude)*Ry(latitude)*Rz(psi) 
     441            #return R_to_xyz(R) 
     442            #return Rz(longitude)*Ry(latitude) 
    383443        def _weight(theta_i, phi_j, w_i, w_j): 
    384444            latitude = sqrt(theta_i**2 + phi_j**2) 
     
    388448        raise ValueError("unknown projection %r"%projection) 
    389449 
     450    return _project, _weight 
     451 
     452def R_to_xyz(R): 
     453    """ 
     454    Return phi, theta, psi Tait-Bryan angles corresponding to the given rotation matrix. 
     455 
     456    Extracting Euler Angles from a Rotation Matrix 
     457    Mike Day, Insomniac Games 
     458    https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2012/07/euler-angles1.pdf 
     459    Based on: Shoemake’s “Euler Angle Conversion”, Graphics Gems IV, pp.  222-229 
     460    """ 
     461    phi = np.arctan2(R[1, 2], R[2, 2]) 
     462    theta = np.arctan2(-R[0, 2], np.sqrt(R[0, 0]**2 + R[0, 1]**2)) 
     463    psi = np.arctan2(R[0, 1], R[0, 0]) 
     464    return np.degrees(phi), np.degrees(theta), np.degrees(psi) 
     465 
     466def draw_mesh(axes, view, jitter, radius=1.2, n=11, dist='gaussian', 
     467              projection='equirectangular'): 
     468    """ 
     469    Draw the dispersion mesh showing the theta-phi orientations at which 
     470    the model will be evaluated. 
     471    """ 
     472 
     473    _project, _weight = get_projection(projection) 
     474    def _rotate(theta, phi, z): 
     475        dview = _project(theta, phi, 0.) 
     476        if len(dview) == 4: # hack for zyz coords 
     477            return Rz(dview[1])*Ry(dview[0])*z 
     478        else: 
     479            return Rx(dview[1])*Ry(dview[0])*z 
     480 
     481 
     482    dist_x = np.linspace(-1, 1, n) 
     483    weights = np.ones_like(dist_x) 
     484    if dist == 'gaussian': 
     485        dist_x *= 3 
     486        weights = exp(-0.5*dist_x**2) 
     487    elif dist == 'rectangle': 
     488        # Note: uses sasmodels ridiculous definition of rectangle width 
     489        dist_x *= sqrt(3) 
     490    elif dist == 'uniform': 
     491        pass 
     492    else: 
     493        raise ValueError("expected dist to be gaussian, rectangle or uniform") 
     494 
    390495    # mesh in theta, phi formed by rotating z 
    391496    dtheta, dphi, dpsi = jitter 
    392497    z = np.matrix([[0], [0], [radius]]) 
    393     points = np.hstack([_rotate(theta_i, phi_j)*z 
     498    points = np.hstack([_rotate(theta_i, phi_j, z) 
    394499                        for theta_i in dtheta*dist_x 
    395500                        for phi_j in dphi*dist_x]) 
     
    469574    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
    470575    """ 
    471     dtheta, dphi, dpsi = jitter 
    472     points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     576    if jitter is None: 
     577        return points 
     578    # Hack to deal with the fact that azimuthal_equidistance uses euler angles 
     579    if len(jitter) == 4: 
     580        dtheta, dphi, dpsi, _ = jitter 
     581        points = Rz(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     582    else: 
     583        dtheta, dphi, dpsi = jitter 
     584        points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
    473585    return points 
    474586 
     
    480592    """ 
    481593    theta, phi, psi = view 
    482     points = Rz(phi)*Ry(theta)*Rz(psi)*points 
     594    points = Rz(phi)*Ry(theta)*Rz(psi)*points # viewing angle 
     595    #points = Rz(psi)*Ry(pi/2-theta)*Rz(phi)*points # 1-D integration angles 
     596    #points = Rx(phi)*Ry(theta)*Rz(psi)*points  # angular dispersion angle 
    483597    return points 
     598 
     599def orient_relative_to_beam_quaternion(view, points): 
     600    """ 
     601    Apply the view transform to a set of points. 
     602 
     603    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
     604 
     605    This variant uses quaternions rather than rotation matrices for the 
     606    computation.  It works but it is not used because it doesn't solve  
     607    any problems.  The challenge of mapping theta/phi/psi to SO(3) does  
     608    not disappear by calculating the transform differently. 
     609    """ 
     610    theta, phi, psi = view 
     611    x, y, z = [1, 0, 0], [0, 1, 0], [0, 0, 1] 
     612    q = Quaternion(1, [0, 0, 0]) 
     613    ## Compose a rotation about the three axes by rotating 
     614    ## the unit vectors before applying the rotation. 
     615    #q = Quaternion.from_angle_axis(theta, q.rot(x)) * q 
     616    #q = Quaternion.from_angle_axis(phi, q.rot(y)) * q 
     617    #q = Quaternion.from_angle_axis(psi, q.rot(z)) * q 
     618    ## The above turns out to be equivalent to reversing 
     619    ## the order of application, so ignore it and use below. 
     620    q = q * Quaternion.from_angle_axis(theta, x) 
     621    q = q * Quaternion.from_angle_axis(phi, y) 
     622    q = q * Quaternion.from_angle_axis(psi, z) 
     623    ## Reverse the order by post-multiply rather than pre-multiply 
     624    #q = Quaternion.from_angle_axis(theta, x) * q 
     625    #q = Quaternion.from_angle_axis(phi, y) * q 
     626    #q = Quaternion.from_angle_axis(psi, z) * q 
     627    #print("axes psi", q.rot(np.matrix([x, y, z]).T)) 
     628    return q.rot(points) 
     629#orient_relative_to_beam = orient_relative_to_beam_quaternion 
     630  
     631# Simple stand-alone quaternion class 
     632import numpy as np 
     633from copy import copy 
     634class Quaternion(object): 
     635    def __init__(self, w, r): 
     636         self.w = w 
     637         self.r = np.asarray(r, dtype='d') 
     638    @staticmethod 
     639    def from_angle_axis(theta, r): 
     640         theta = np.radians(theta)/2 
     641         r = np.asarray(r) 
     642         w = np.cos(theta) 
     643         r = np.sin(theta)*r/np.dot(r,r) 
     644         return Quaternion(w, r) 
     645    def __mul__(self, other): 
     646        if isinstance(other, Quaternion): 
     647            w = self.w*other.w - np.dot(self.r, other.r) 
     648            r = self.w*other.r + other.w*self.r + np.cross(self.r, other.r) 
     649            return Quaternion(w, r) 
     650    def rot(self, v): 
     651        v = np.asarray(v).T 
     652        use_transpose = (v.shape[-1] != 3) 
     653        if use_transpose: v = v.T 
     654        v = v + np.cross(2*self.r, np.cross(self.r, v) + self.w*v) 
     655        #v = v + 2*self.w*np.cross(self.r, v) + np.cross(2*self.r, np.cross(self.r, v)) 
     656        if use_transpose: v = v.T 
     657        return v.T 
     658    def conj(self): 
     659        return Quaternion(self.w, -self.r) 
     660    def inv(self): 
     661        return self.conj()/self.norm()**2 
     662    def norm(self): 
     663        return np.sqrt(self.w**2 + np.sum(self.r**2)) 
     664    def __str__(self): 
     665        return "%g%+gi%+gj%+gk"%(self.w, self.r[0], self.r[1], self.r[2]) 
     666def test_qrot(): 
     667    # Define rotation of 60 degrees around an axis in y-z that is 60 degrees from y. 
     668    # The rotation axis is determined by rotating the point [0, 1, 0] about x. 
     669    ax = Quaternion.from_angle_axis(60, [1, 0, 0]).rot([0, 1, 0]) 
     670    q = Quaternion.from_angle_axis(60, ax) 
     671    # Set the point to be rotated, and its expected rotated position. 
     672    p = [1, -1, 2] 
     673    target = [(10+4*np.sqrt(3))/8, (1+2*np.sqrt(3))/8, (14-3*np.sqrt(3))/8] 
     674    #print(q, q.rot(p) - target) 
     675    assert max(abs(q.rot(p) - target)) < 1e-14  
     676#test_qrot() 
     677#import sys; sys.exit() 
    484678 
    485679# translate between number of dimension of dispersity and the number of 
     
    561755        level[level < 0] = 0 
    562756        colors = plt.get_cmap()(level) 
    563         axes.plot_surface(qx, qy, -1.1, rstride=1, cstride=1, facecolors=colors) 
     757        axes.plot_surface(qx, qy, -1.1, facecolors=colors) 
    564758    elif 1: 
    565759        axes.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1, 
     
    691885} 
    692886 
    693 def run(model_name='parallelepiped', size=(10, 40, 100), 
     887 
     888def run(model_name='parallelepiped', size=(10, 40, 100),  
     889        view=(0, 0, 0), jitter=(0, 0, 0), 
    694890        dist='gaussian', mesh=30, 
    695891        projection='equirectangular'): 
     
    701897 
    702898    *size* gives the dimensions (a, b, c) of the shape. 
     899 
     900    *view* gives the initial view (theta, phi, psi) of the shape. 
     901 
     902    *view* gives the initial jitter (dtheta, dphi, dpsi) of the shape. 
    703903 
    704904    *dist* is the type of dispersition: gaussian, rectangle, or uniform. 
     
    720920    calculator, size = select_calculator(model_name, n=150, size=size) 
    721921    draw_shape = DRAW_SHAPES.get(model_name, draw_parallelepiped) 
     922    #draw_shape = draw_fcc 
    722923 
    723924    ## uncomment to set an independent the colour range for every view 
     
    725926    calculator.limits = None 
    726927 
    727     ## initial view 
    728     #theta, dtheta = 70., 10. 
    729     #phi, dphi = -45., 3. 
    730     #psi, dpsi = -45., 3. 
    731     theta, phi, psi = 0, 0, 0 
    732     dtheta, dphi, dpsi = 0, 0, 0 
     928    PLOT_ENGINE(calculator, draw_shape, size, view, jitter, dist, mesh, projection) 
     929 
     930def mpl_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection): 
     931    import mpl_toolkits.mplot3d  # Adds projection='3d' option to subplot 
     932    import matplotlib.pyplot as plt 
     933    from matplotlib.widgets import Slider 
    733934 
    734935    ## create the plot window 
     
    749950 
    750951    ## add control widgets to plot 
    751     axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], axisbg=axcolor) 
    752     axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], axisbg=axcolor) 
    753     axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], axisbg=axcolor) 
    754     stheta = Slider(axes_theta, 'Theta', -90, 90, valinit=theta) 
    755     sphi = Slider(axes_phi, 'Phi', -180, 180, valinit=phi) 
    756     spsi = Slider(axes_psi, 'Psi', -180, 180, valinit=psi) 
    757  
    758     axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], axisbg=axcolor) 
    759     axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], axisbg=axcolor) 
    760     axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], axisbg=axcolor) 
     952    axes_theta = plt.axes([0.1, 0.15, 0.45, 0.04], facecolor=axcolor) 
     953    axes_phi = plt.axes([0.1, 0.1, 0.45, 0.04], facecolor=axcolor) 
     954    axes_psi = plt.axes([0.1, 0.05, 0.45, 0.04], facecolor=axcolor) 
     955    stheta = Slider(axes_theta, 'Theta', -90, 90, valinit=0) 
     956    sphi = Slider(axes_phi, 'Phi', -180, 180, valinit=0) 
     957    spsi = Slider(axes_psi, 'Psi', -180, 180, valinit=0) 
     958 
     959    axes_dtheta = plt.axes([0.75, 0.15, 0.15, 0.04], facecolor=axcolor) 
     960    axes_dphi = plt.axes([0.75, 0.1, 0.15, 0.04], facecolor=axcolor) 
     961    axes_dpsi = plt.axes([0.75, 0.05, 0.15, 0.04], facecolor=axcolor) 
    761962    # Note: using ridiculous definition of rectangle distribution, whose width 
    762963    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep 
    763964    # the maximum width to 90. 
    764965    dlimit = DIST_LIMITS[dist] 
    765     sdtheta = Slider(axes_dtheta, 'dTheta', 0, 2*dlimit, valinit=dtheta) 
    766     sdphi = Slider(axes_dphi, 'dPhi', 0, 2*dlimit, valinit=dphi) 
    767     sdpsi = Slider(axes_dpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi) 
    768  
     966    sdtheta = Slider(axes_dtheta, 'dTheta', 0, 2*dlimit, valinit=0) 
     967    sdphi = Slider(axes_dphi, 'dPhi', 0, 2*dlimit, valinit=0) 
     968    sdpsi = Slider(axes_dpsi, 'dPsi', 0, 2*dlimit, valinit=0) 
     969 
     970    ## initial view and jitter 
     971    theta, phi, psi = view 
     972    stheta.set_val(theta) 
     973    sphi.set_val(phi) 
     974    spsi.set_val(psi) 
     975    dtheta, dphi, dpsi = jitter 
     976    sdtheta.set_val(dtheta) 
     977    sdphi.set_val(dphi) 
     978    sdpsi.set_val(dpsi) 
    769979 
    770980    ## callback to draw the new view 
     
    774984        # set small jitter as 0 if multiple pd dims 
    775985        dims = sum(v > 0 for v in jitter) 
    776         limit = [0, 0.5, 5][dims] 
     986        limit = [0, 0.5, 5, 5][dims] 
    777987        jitter = [0 if v < limit else v for v in jitter] 
    778988        axes.cla() 
     989 
     990        ## Visualize as person on globe 
     991        #draw_sphere(axes) 
     992        #draw_person_on_sphere(axes, view) 
     993 
     994        ## Move beam instead of shape 
     995        #draw_beam(axes, -view[:2]) 
     996        #draw_jitter(axes, (0,0,0), (0,0,0), views=3) 
     997 
     998        ## Move shape and draw scattering 
    779999        draw_beam(axes, (0, 0)) 
    780         draw_jitter(axes, view, jitter, dist=dist, size=size, draw_shape=draw_shape) 
    781         #draw_jitter(axes, view, (0,0,0)) 
     1000        draw_jitter(axes, view, jitter, dist=dist, size=size,  
     1001                    draw_shape=draw_shape, projection=projection, views=3) 
    7821002        draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection) 
    7831003        draw_scattering(calculator, axes, view, jitter, dist=dist) 
     1004 
    7841005        plt.gcf().canvas.draw() 
    7851006 
     
    7971018    ## go interactive 
    7981019    plt.show() 
     1020 
     1021 
     1022def map_colors(z, kw): 
     1023    from matplotlib import cm 
     1024 
     1025    cmap = kw.pop('cmap', cm.coolwarm) 
     1026    alpha = kw.pop('alpha', None) 
     1027    vmin = kw.pop('vmin', z.min()) 
     1028    vmax = kw.pop('vmax', z.max()) 
     1029    c = kw.pop('c', None) 
     1030    color = kw.pop('color', c) 
     1031    if color is None: 
     1032        znorm = ((z - vmin) / (vmax - vmin)).clip(0, 1) 
     1033        color = cmap(znorm) 
     1034    elif isinstance(color, np.ndarray) and color.shape == z.shape: 
     1035        color = cmap(color) 
     1036    if alpha is None: 
     1037        if isinstance(color, np.ndarray): 
     1038            color = color[..., :3] 
     1039    else: 
     1040        color[..., 3] = alpha 
     1041    kw['color'] = color 
     1042 
     1043def make_vec(*args, flat=False): 
     1044    if flat: 
     1045        return [np.asarray(v, 'd').flatten() for v in args] 
     1046    else: 
     1047        return [np.asarray(v, 'd') for v in args] 
     1048 
     1049def make_image(z, kw): 
     1050    import PIL.Image 
     1051    from matplotlib import cm 
     1052 
     1053    cmap = kw.pop('cmap', cm.coolwarm) 
     1054 
     1055    znorm = (z-z.min())/z.ptp() 
     1056    c = cmap(znorm) 
     1057    c = c[..., :3] 
     1058    rgb = np.asarray(c*255, 'u1') 
     1059    image = PIL.Image.fromarray(rgb, mode='RGB') 
     1060    return image 
     1061 
     1062 
     1063_IPV_MARKERS = { 
     1064    'o': 'sphere', 
     1065} 
     1066_IPV_COLORS = { 
     1067    'w': 'white', 
     1068    'k': 'black', 
     1069    'c': 'cyan', 
     1070    'm': 'magenta', 
     1071    'y': 'yellow', 
     1072    'r': 'red', 
     1073    'g': 'green', 
     1074    'b': 'blue', 
     1075} 
     1076def ipv_fix_color(kw): 
     1077    kw.pop('alpha', None) 
     1078    color = kw.get('color', None) 
     1079    if isinstance(color, str): 
     1080        color = _IPV_COLORS.get(color, color) 
     1081        kw['color'] = color 
     1082 
     1083 
     1084def ipv_plot(calculator, draw_shape, size, view, jitter, dist, mesh, projection): 
     1085    import ipywidgets as widgets 
     1086    import ipyvolume as ipv 
     1087 
     1088    class Plotter: 
     1089        def plot(self, x, y, z, **kw): 
     1090            ipv_fix_color(kw) 
     1091            x, y, z = make_vec(x, y, z) 
     1092            ipv.plot(x, y, z, **kw) 
     1093        def plot_surface(self, x, y, z, **kw): 
     1094            ipv_fix_color(kw) 
     1095            x, y, z = make_vec(x, y, z) 
     1096            ipv.plot_surface(x, y, z, **kw) 
     1097        def plot_trisurf(self, x, y, triangles=None, Z=None, **kw): 
     1098            ipv_fix_color(kw) 
     1099            x, y, z = make_vec(x, y, Z) 
     1100            if triangles is not None: 
     1101                triangles = np.asarray(triangles) 
     1102            ipv.plot_trisurf(x, y, z, triangles=triangles, **kw) 
     1103        def scatter(self, x, y, z, **kw): 
     1104            x, y, z = make_vec(x, y, z) 
     1105            map_colors(z, kw) 
     1106            marker = kw.get('marker', None) 
     1107            kw['marker'] = _IPV_MARKERS.get(marker, marker) 
     1108            ipv.scatter(x, y, z, **kw) 
     1109        def contourf(self, x, y, v, zdir='z', offset=0, levels=None, **kw): 
     1110            # Don't use contour for now (although we might want to later) 
     1111            self.pcolor(x, y, v, zdir='z', offset=offset, **kw) 
     1112        def pcolor(self, x, y, v, zdir='z', offset=0, **kw): 
     1113            x, y, v = make_vec(x, y, v) 
     1114            image = make_image(v, kw) 
     1115            xmin, xmax = x.min(), x.max() 
     1116            ymin, ymax = y.min(), y.max() 
     1117            x = np.array([[xmin, xmax], [xmin, xmax]]) 
     1118            y = np.array([[ymin, ymin], [ymax, ymax]]) 
     1119            z = x*0 + offset 
     1120            u = np.array([[0., 1], [0, 1]]) 
     1121            v = np.array([[0., 0], [1, 1]]) 
     1122            ipv.plot_mesh(x, y, z, u=u, v=v, texture=image, wireframe=False)     
     1123        def text(self, *args, **kw): 
     1124            pass 
     1125        def set_xlim(self, limits): 
     1126            ipv.xlim(*limits) 
     1127        def set_ylim(self, limits): 
     1128            ipv.ylim(*limits) 
     1129        def set_zlim(self, limits): 
     1130            ipv.zlim(*limits) 
     1131        def set_axes_on(self): 
     1132            ipv.style.axis_on() 
     1133        def set_axis_off(self): 
     1134            ipv.style.axes_off() 
     1135    axes = Plotter() 
     1136 
     1137 
     1138    def draw(view, jitter): 
     1139        camera = ipv.gcf().camera 
     1140        #print(ipv.gcf().__dict__.keys()) 
     1141        #print(dir(ipv.gcf())) 
     1142        ipv.figure() 
     1143 
     1144        # set small jitter as 0 if multiple pd dims 
     1145        dims = sum(v > 0 for v in jitter) 
     1146        limit = [0, 0.5, 5, 5][dims] 
     1147        jitter = [0 if v < limit else v for v in jitter] 
     1148 
     1149        ## Visualize as person on globe 
     1150        #draw_beam(axes, (0, 0)) 
     1151        #draw_sphere(axes) 
     1152        #draw_person_on_sphere(axes, view) 
     1153 
     1154        ## Move beam instead of shape 
     1155        #draw_beam(axes, view=(-view[0], -view[1])) 
     1156        #draw_jitter(axes, view=(0,0,0), jitter=(0,0,0)) 
     1157 
     1158        ## Move shape and draw scattering 
     1159        draw_beam(axes, (0, 0)) 
     1160        draw_jitter(axes, view, jitter, dist=dist, size=size,  
     1161                    draw_shape=draw_shape, projection=projection) 
     1162        #draw_mesh(axes, view, jitter, dist=dist, n=mesh, projection=projection) 
     1163        #draw_scattering(calculator, axes, view, jitter, dist=dist) 
     1164     
     1165        ipv.style.box_off() 
     1166        #ipv.style.axes_off() 
     1167        ipv.xyzlabel(" ", " ", " ") 
     1168 
     1169        ipv.gcf().camera = camera 
     1170        ipv.show() 
     1171 
     1172 
     1173    trange, prange = (-180., 180., 1.), (-180., 180., 1.) 
     1174    dtrange, dprange = (0., 180., 1.), (0., 180., 1.) 
     1175 
     1176    ## Super simple interfaca, but uses non-ascii variable namese 
     1177    # Ξ φ ψ Δξ Δφ Δψ 
     1178    #def update(**kw): 
     1179    #    view = kw['Ξ'], kw['φ'], kw['ψ'] 
     1180    #    jitter = kw['Δξ'], kw['Δφ'], kw['Δψ'] 
     1181    #    draw(view, jitter) 
     1182    #widgets.interact(update, Ξ=trange, φ=prange, ψ=prange, Δξ=dtrange, Δφ=dprange, Δψ=dprange) 
     1183 
     1184    def update(theta, phi, psi, dtheta, dphi, dpsi): 
     1185        draw(view=(theta, phi, psi), jitter=(dtheta, dphi, dpsi)) 
     1186 
     1187    def slider(name, slice, init=0.): 
     1188        return widgets.FloatSlider( 
     1189            value=init, 
     1190            min=slice[0], 
     1191            max=slice[1], 
     1192            step=slice[2], 
     1193            description=name, 
     1194            disabled=False, 
     1195            continuous_update=False, 
     1196            orientation='horizontal', 
     1197            readout=True, 
     1198            readout_format='.1f', 
     1199            ) 
     1200    theta = slider('Ξ', trange, view[0]) 
     1201    phi = slider('φ', prange, view[1]) 
     1202    psi = slider('ψ', prange, view[2]) 
     1203    dtheta = slider('Δξ', dtrange, jitter[0]) 
     1204    dphi = slider('Δφ', dprange, jitter[1]) 
     1205    dpsi = slider('Δψ', dprange, jitter[2]) 
     1206    fields = { 
     1207        'theta': theta, 'phi': phi, 'psi': psi, 
     1208        'dtheta': dtheta, 'dphi': dphi, 'dpsi': dpsi, 
     1209    } 
     1210    ui = widgets.HBox([ 
     1211        widgets.VBox([theta, phi, psi]), 
     1212        widgets.VBox([dtheta, dphi, dpsi]) 
     1213    ]) 
     1214 
     1215    out = widgets.interactive_output(update, fields) 
     1216    display(ui, out) 
     1217 
     1218 
     1219_ENGINES = { 
     1220    "matplotlib": mpl_plot, 
     1221    "mpl": mpl_plot, 
     1222    #"plotly": plotly_plot, 
     1223    "ipvolume": ipv_plot, 
     1224    "ipv": ipv_plot, 
     1225} 
     1226PLOT_ENGINE = _ENGINES["matplotlib"] 
     1227def set_plotter(name): 
     1228    global PLOT_ENGINE 
     1229    PLOT_ENGINE = _ENGINES[name] 
    7991230 
    8001231def main(): 
     
    8081239    parser.add_argument('-s', '--size', type=str, default='10,40,100', 
    8091240                        help='a,b,c lengths') 
     1241    parser.add_argument('-v', '--view', type=str, default='0,0,0', 
     1242                        help='initial view angles') 
     1243    parser.add_argument('-j', '--jitter', type=str, default='0,0,0', 
     1244                        help='initial angular dispersion') 
    8101245    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS, 
    8111246                        default=DISTRIBUTIONS[0], 
     
    8161251                        help='oriented shape') 
    8171252    opts = parser.parse_args() 
    818     size = tuple(int(v) for v in opts.size.split(',')) 
    819     run(opts.shape, size=size, 
     1253    size = tuple(float(v) for v in opts.size.split(',')) 
     1254    view = tuple(float(v) for v in opts.view.split(',')) 
     1255    jitter = tuple(float(v) for v in opts.jitter.split(',')) 
     1256    run(opts.shape, size=size, view=view, jitter=jitter, 
    8201257        mesh=opts.mesh, dist=opts.distribution, 
    8211258        projection=opts.projection) 
Note: See TracChangeset for help on using the changeset viewer.