Changeset f18ddc8 in sasmodels for explore


Ignore:
Timestamp:
Oct 26, 2017 9:41:54 AM (7 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
3bfd924
Parents:
a7db3c05 (diff), 6db17bd (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge branch 'master' into ticket-776-orientation

Location:
explore
Files:
7 added
2 edited

Legend:

Unmodified
Added
Removed
  • explore/jitter.py

    • Property mode changed from 100644 to 100755
    r85190c2 r36b3154  
     1#!/usr/bin/env python 
    12""" 
    23Application to explore the difference between sasview 3.x orientation 
    34dispersity and possible replacement algorithms. 
    45""" 
     6from __future__ import division, print_function 
     7 
     8import sys, os 
     9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 
     10 
    511import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot 
    612import matplotlib.pyplot as plt 
    713from matplotlib.widgets import Slider, CheckButtons 
    814from matplotlib import cm 
    9  
    1015import numpy as np 
    1116from numpy import pi, cos, sin, sqrt, exp, degrees, radians 
    1217 
    13 def draw_beam(ax): 
     18def draw_beam(ax, view=(0, 0)): 
     19    """ 
     20    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1) 
     21    """ 
    1422    #ax.plot([0,0],[0,0],[1,-1]) 
    1523    #ax.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8) 
     
    2230    x = r*np.outer(np.cos(u), np.ones_like(v)) 
    2331    y = r*np.outer(np.sin(u), np.ones_like(v)) 
    24     z = np.outer(np.ones_like(u), v) 
     32    z = 1.3*np.outer(np.ones_like(u), v) 
     33 
     34    theta, phi = view 
     35    shape = x.shape 
     36    points = np.matrix([x.flatten(), y.flatten(), z.flatten()]) 
     37    points = Rz(phi)*Ry(theta)*points 
     38    x, y, z = [v.reshape(shape) for v in points] 
    2539 
    2640    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5) 
    2741 
    28 def draw_shimmy(ax, theta, phi, psi, dtheta, dphi, dpsi): 
    29     size=[0.1, 0.4, 1.0] 
    30     view=[theta, phi, psi] 
    31     shimmy=[0,0,0] 
    32     #draw_shape = draw_parallelepiped 
    33     draw_shape = draw_ellipsoid 
     42def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)): 
     43    """ 
     44    Represent jitter as a set of shapes at different orientations. 
     45    """ 
     46    # set max diagonal to 0.95 
     47    scale = 0.95/sqrt(sum(v**2 for v in size)) 
     48    size = tuple(scale*v for v in size) 
     49    draw_shape = draw_parallelepiped 
     50    #draw_shape = draw_ellipsoid 
    3451 
    3552    #np.random.seed(10) 
     
    6481        [ 1,  1,  1], 
    6582    ] 
     83    dtheta, dphi, dpsi = jitter 
    6684    if dtheta == 0: 
    6785        cloud = [v for v in cloud if v[0] == 0] 
     
    7088    if dpsi == 0: 
    7189        cloud = [v for v in cloud if v[2] == 0] 
    72     draw_shape(ax, size, view, shimmy, steps=100, alpha=0.8) 
     90    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8) 
     91    scale = 1/sqrt(3) if dist == 'rectangle' else 1 
    7392    for point in cloud: 
    74         shimmy=[dtheta*point[0], dphi*point[1], dpsi*point[2]] 
    75         draw_shape(ax, size, view, shimmy, alpha=0.8) 
     93        delta = [scale*dtheta*point[0], scale*dphi*point[1], scale*dpsi*point[2]] 
     94        draw_shape(ax, size, view, delta, alpha=0.8) 
    7695    for v in 'xyz': 
    7796        a, b, c = size 
     
    8099        getattr(ax, v+'axis').label.set_text(v) 
    81100 
    82 def draw_ellipsoid(ax, size, view, shimmy, steps=25, alpha=1): 
     101def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1): 
     102    """Draw an ellipsoid.""" 
    83103    a,b,c = size 
    84     theta, phi, psi = view 
    85     dtheta, dphi, dpsi = shimmy 
    86  
    87104    u = np.linspace(0, 2 * np.pi, steps) 
    88105    v = np.linspace(0, np.pi, steps) 
     
    90107    y = b*np.outer(np.sin(u), np.sin(v)) 
    91108    z = c*np.outer(np.ones_like(u), np.cos(v)) 
    92  
    93     shape = x.shape 
    94     points = np.matrix([x.flatten(),y.flatten(),z.flatten()]) 
    95     points = Rz(dpsi)*Ry(dtheta)*Rx(dphi)*points 
    96     points = Rz(phi)*Ry(theta)*Rz(psi)*points 
    97     x,y,z = [v.reshape(shape) for v in points] 
     109    x, y, z = transform_xyz(view, jitter, x, y, z) 
    98110 
    99111    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha) 
    100112 
    101 def draw_parallelepiped(ax, size, view, shimmy, alpha=1): 
     113    draw_labels(ax, view, jitter, [ 
     114         ('c+', [ 0, 0, c], [ 1, 0, 0]), 
     115         ('c-', [ 0, 0,-c], [ 0, 0,-1]), 
     116         ('a+', [ a, 0, 0], [ 0, 0, 1]), 
     117         ('a-', [-a, 0, 0], [ 0, 0,-1]), 
     118         ('b+', [ 0, b, 0], [-1, 0, 0]), 
     119         ('b-', [ 0,-b, 0], [-1, 0, 0]), 
     120    ]) 
     121 
     122def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1): 
     123    """Draw a parallelepiped.""" 
    102124    a,b,c = size 
    103     theta, phi, psi = view 
    104     dtheta, dphi, dpsi = shimmy 
    105  
    106125    x = a*np.array([ 1,-1, 1,-1, 1,-1, 1,-1]) 
    107126    y = b*np.array([ 1, 1,-1,-1, 1, 1,-1,-1]) 
     
    118137    ]) 
    119138 
    120     points = np.matrix([x,y,z]) 
    121     points = Rz(dpsi)*Ry(dtheta)*Rx(dphi)*points 
    122     points = Rz(phi)*Ry(theta)*Rz(psi)*points 
    123  
    124     x,y,z = [np.array(v).flatten() for v in points] 
     139    x, y, z = transform_xyz(view, jitter, x, y, z) 
    125140    ax.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha) 
    126141 
     142    draw_labels(ax, view, jitter, [ 
     143         ('c+', [ 0, 0, c], [ 1, 0, 0]), 
     144         ('c-', [ 0, 0,-c], [ 0, 0,-1]), 
     145         ('a+', [ a, 0, 0], [ 0, 0, 1]), 
     146         ('a-', [-a, 0, 0], [ 0, 0,-1]), 
     147         ('b+', [ 0, b, 0], [-1, 0, 0]), 
     148         ('b-', [ 0,-b, 0], [-1, 0, 0]), 
     149    ]) 
     150 
    127151def draw_sphere(ax, radius=10., steps=100): 
     152    """Draw a sphere""" 
    128153    u = np.linspace(0, 2 * np.pi, steps) 
    129154    v = np.linspace(0, np.pi, steps) 
     
    134159    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 
    135160 
    136 def draw_mesh_new(ax, theta, dtheta, phi, dphi, flow, radius=10., dist='gauss'): 
    137     theta_center = radians(theta) 
    138     phi_center = radians(phi) 
    139     flow_center = radians(flow) 
    140     dtheta = radians(dtheta) 
    141     dphi = radians(dphi) 
    142  
    143     # 10 point 3-sigma gaussian weights 
    144     t = np.linspace(-3., 3., 11) 
    145     if dist == 'gauss': 
     161def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian'): 
     162    """ 
     163    Draw the dispersion mesh showing the theta-phi orientations at which 
     164    the model will be evaluated. 
     165    """ 
     166    theta, phi, psi = view 
     167    dtheta, dphi, dpsi = jitter 
     168 
     169    if dist == 'gaussian': 
     170        t = np.linspace(-3, 3, n) 
    146171        weights = exp(-0.5*t**2) 
    147     elif dist == 'rect': 
     172    elif dist == 'rectangle': 
     173        # Note: uses sasmodels ridiculous definition of rectangle width 
     174        t = np.linspace(-1, 1, n)*sqrt(3) 
    148175        weights = np.ones_like(t) 
    149176    else: 
    150         raise ValueError("expected dist to be 'gauss' or 'rect'") 
    151     theta = dtheta*t 
    152     phi = dphi*t 
    153  
    154     x = radius * np.outer(cos(phi), cos(theta)) 
    155     y = radius * np.outer(sin(phi), cos(theta)) 
    156     z = radius * np.outer(np.ones_like(phi), sin(theta)) 
    157     #w = np.outer(weights, weights*abs(cos(dtheta*t))) 
    158     w = np.outer(weights, weights*abs(cos(theta))) 
    159  
    160     x, y, z, w = [v.flatten() for v in (x,y,z,w)] 
    161     x, y, z = rotate(x, y, z, phi_center, theta_center, flow_center) 
    162  
    163     ax.scatter(x, y, z, c=w, marker='o', vmin=0., vmax=1.) 
    164  
    165 def rotate(x, y, z, phi, theta, psi): 
    166     R = Rz(phi)*Ry(theta)*Rz(psi) 
    167     p = np.vstack([x,y,z]) 
    168     return R*p 
    169  
     177        raise ValueError("expected dist to be 'gaussian' or 'rectangle'") 
     178 
     179    # mesh in theta, phi formed by rotating z 
     180    z = np.matrix([[0], [0], [radius]]) 
     181    points = np.hstack([Rx(phi_i)*Ry(theta_i)*z 
     182                        for theta_i in dtheta*t 
     183                        for phi_i in dphi*t]) 
     184    # rotate relative to beam 
     185    points = orient_relative_to_beam(view, points) 
     186 
     187    w = np.outer(weights*cos(radians(dtheta*t)), weights) 
     188 
     189    x, y, z = [np.array(v).flatten() for v in points] 
     190    ax.scatter(x, y, z, c=w.flatten(), marker='o', vmin=0., vmax=1.) 
     191 
     192def draw_labels(ax, view, jitter, text): 
     193    """ 
     194    Draw text at a particular location. 
     195    """ 
     196    labels, locations, orientations = zip(*text) 
     197    px, py, pz = zip(*locations) 
     198    dx, dy, dz = zip(*orientations) 
     199 
     200    px, py, pz = transform_xyz(view, jitter, px, py, pz) 
     201    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz) 
     202 
     203    # TODO: zdir for labels is broken, and labels aren't appearing. 
     204    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)): 
     205        zdir = np.asarray(zdir).flatten() 
     206        ax.text(p[0], p[1], p[2], label, zdir=zdir) 
     207 
     208# Definition of rotation matrices comes from wikipedia: 
     209#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations 
    170210def Rx(angle): 
     211    """Construct a matrix to rotate points about *x* by *angle* degrees.""" 
    171212    a = radians(angle) 
    172     R = [[1., 0., 0.], 
    173          [0.,  cos(a), sin(a)], 
    174          [0., -sin(a), cos(a)]] 
     213    R = [[1, 0, 0], 
     214         [0, +cos(a), -sin(a)], 
     215         [0, +sin(a), +cos(a)]] 
    175216    return np.matrix(R) 
    176217 
    177218def Ry(angle): 
     219    """Construct a matrix to rotate points about *y* by *angle* degrees.""" 
    178220    a = radians(angle) 
    179     R = [[cos(a), 0., -sin(a)], 
    180          [0., 1., 0.], 
    181          [sin(a), 0.,  cos(a)]] 
     221    R = [[+cos(a), 0, +sin(a)], 
     222         [0, 1, 0], 
     223         [-sin(a), 0, +cos(a)]] 
    182224    return np.matrix(R) 
    183225 
    184226def Rz(angle): 
     227    """Construct a matrix to rotate points about *z* by *angle* degrees.""" 
    185228    a = radians(angle) 
    186     R = [[cos(a), -sin(a), 0.], 
    187          [sin(a),  cos(a), 0.], 
    188          [0., 0., 1.]] 
     229    R = [[+cos(a), -sin(a), 0], 
     230         [+sin(a), +cos(a), 0], 
     231         [0, 0, 1]] 
    189232    return np.matrix(R) 
    190233 
    191 def main(): 
     234def transform_xyz(view, jitter, x, y, z): 
     235    """ 
     236    Send a set of (x,y,z) points through the jitter and view transforms. 
     237    """ 
     238    x, y, z = [np.asarray(v) for v in (x, y, z)] 
     239    shape = x.shape 
     240    points = np.matrix([x.flatten(),y.flatten(),z.flatten()]) 
     241    points = apply_jitter(jitter, points) 
     242    points = orient_relative_to_beam(view, points) 
     243    x, y, z = [np.array(v).reshape(shape) for v in points] 
     244    return x, y, z 
     245 
     246def apply_jitter(jitter, points): 
     247    """ 
     248    Apply the jitter transform to a set of points. 
     249 
     250    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
     251    """ 
     252    dtheta, dphi, dpsi = jitter 
     253    points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     254    return points 
     255 
     256def orient_relative_to_beam(view, points): 
     257    """ 
     258    Apply the view transform to a set of points. 
     259 
     260    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
     261    """ 
     262    theta, phi, psi = view 
     263    points = Rz(phi)*Ry(theta)*Rz(psi)*points 
     264    return points 
     265 
     266# translate between number of dimension of dispersity and the number of 
     267# points along each dimension. 
     268PD_N_TABLE = { 
     269    (0, 0, 0): (0, 0, 0),     # 0 
     270    (1, 0, 0): (100, 0, 0),   # 100 
     271    (0, 1, 0): (0, 100, 0), 
     272    (0, 0, 1): (0, 0, 100), 
     273    (1, 1, 0): (30, 30, 0),   # 900 
     274    (1, 0, 1): (30, 0, 30), 
     275    (0, 1, 1): (0, 30, 30), 
     276    (1, 1, 1): (15, 15, 15),  # 3375 
     277} 
     278 
     279def clipped_range(data, portion=1.0, mode='central'): 
     280    """ 
     281    Determine range from data. 
     282 
     283    If *portion* is 1, use full range, otherwise use the center of the range 
     284    or the top of the range, depending on whether *mode* is 'central' or 'top'. 
     285    """ 
     286    if portion == 1.0: 
     287        return data.min(), data.max() 
     288    elif mode == 'central': 
     289        data = np.sort(data.flatten()) 
     290        offset = int(portion*len(data)/2 + 0.5) 
     291        return data[offset], data[-offset] 
     292    elif mode == 'top': 
     293        data = np.sort(data.flatten()) 
     294        offset = int(portion*len(data) + 0.5) 
     295        return data[offset], data[-1] 
     296 
     297def draw_scattering(calculator, ax, view, jitter, dist='gaussian'): 
     298    """ 
     299    Plot the scattering for the particular view. 
     300 
     301    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes 
     302    on which the data will be plotted.  *view* and *jitter* are the current 
     303    orientation and orientation dispersity.  *dist* is one of the sasmodels 
     304    weight distributions. 
     305    """ 
     306    ## Sasmodels use sqrt(3)*width for the rectangle range; scale to the 
     307    ## proper width for comparison. Commented out since now using the 
     308    ## sasmodels definition of width for rectangle. 
     309    #scale = 1/sqrt(3) if dist == 'rectangle' else 1 
     310    scale = 1 
     311 
     312    # add the orientation parameters to the model parameters 
     313    theta, phi, psi = view 
     314    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter] 
     315    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd>0, phi_pd>0, psi_pd>0)] 
     316    ## increase pd_n for testing jitter integration rather than simple viz 
     317    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)] 
     318 
     319    pars = dict( 
     320        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n, 
     321        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n, 
     322        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n, 
     323    ) 
     324    pars.update(calculator.pars) 
     325 
     326    # compute the pattern 
     327    qx, qy = calculator._data.x_bins, calculator._data.y_bins 
     328    Iqxy = calculator(**pars).reshape(len(qx), len(qy)) 
     329 
     330    # scale it and draw it 
     331    Iqxy = np.log(Iqxy) 
     332    if calculator.limits: 
     333        # use limits from orientation (0,0,0) 
     334        vmin, vmax = calculator.limits 
     335    else: 
     336        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top') 
     337    #print("range",(vmin,vmax)) 
     338    #qx, qy = np.meshgrid(qx, qy) 
     339    if 0: 
     340        level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i') 
     341        level[level<0] = 0 
     342        colors = plt.get_cmap()(level) 
     343        ax.plot_surface(qx, qy, -1.1, rstride=1, cstride=1, facecolors=colors) 
     344    elif 1: 
     345        ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1, 
     346                    levels=np.linspace(vmin, vmax, 24)) 
     347    else: 
     348        ax.pcolormesh(qx, qy, Iqxy) 
     349 
     350def build_model(model_name, n=150, qmax=0.5, **pars): 
     351    """ 
     352    Build a calculator for the given shape. 
     353 
     354    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh 
     355    on which to evaluate the model.  The remaining parameters are stored in 
     356    the returned calculator as *calculator.pars*.  They are used by 
     357    :func:`draw_scattering` to set the non-orientation parameters in the 
     358    calculation. 
     359 
     360    Returns a *calculator* function which takes a dictionary or parameters and 
     361    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix 
     362    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class 
     363    for details. 
     364    """ 
     365    from sasmodels.core import load_model_info, build_model 
     366    from sasmodels.data import empty_data2D 
     367    from sasmodels.direct_model import DirectModel 
     368 
     369    model_info = load_model_info(model_name) 
     370    model = build_model(model_info) #, dtype='double!') 
     371    q = np.linspace(-qmax, qmax, n) 
     372    data = empty_data2D(q, q) 
     373    calculator = DirectModel(data, model) 
     374 
     375    # stuff the values for non-orientation parameters into the calculator 
     376    calculator.pars = pars.copy() 
     377    calculator.pars.setdefault('backgound', 1e-3) 
     378 
     379    # fix the data limits so that we can see if the pattern fades 
     380    # under rotation or angular dispersion 
     381    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars) 
     382    Iqxy = np.log(Iqxy) 
     383    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top') 
     384    calculator.limits = vmin, vmax+1 
     385 
     386    return calculator 
     387 
     388def select_calculator(model_name, n=150, size=(10,40,100)): 
     389    """ 
     390    Create a model calculator for the given shape. 
     391 
     392    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid, 
     393    parallelepiped or bcc_paracrystal. *n* is the number of points to use 
     394    in the q range.  *qmax* is chosen based on model parameters for the 
     395    given model to show something intersting. 
     396 
     397    Returns *calculator* and tuple *size* (a,b,c) giving minor and major 
     398    equitorial axes and polar axis respectively.  See :func:`build_model` 
     399    for details on the returned calculator. 
     400    """ 
     401    a, b, c = size 
     402    if model_name == 'sphere': 
     403        calculator = build_model('sphere', n=n, radius=c) 
     404        a = b = c 
     405    elif model_name == 'bcc_paracrystal': 
     406        calculator = build_model('bcc_paracrystal', n=n, dnn=c, 
     407                                  d_factor=0.06, radius=40) 
     408        a = b = c 
     409    elif model_name == 'cylinder': 
     410        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c) 
     411        a = b 
     412    elif model_name == 'ellipsoid': 
     413        calculator = build_model('ellipsoid', n=n, qmax=1.0, 
     414                                 radius_polar=c, radius_equatorial=b) 
     415        a = b 
     416    elif model_name == 'triaxial_ellipsoid': 
     417        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5, 
     418                                 radius_equat_minor=a, 
     419                                 radius_equat_major=b, 
     420                                 radius_polar=c) 
     421    elif model_name == 'parallelepiped': 
     422        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c) 
     423    else: 
     424        raise ValueError("unknown model %s"%model_name) 
     425 
     426    return calculator, (a, b, c) 
     427 
     428def main(model_name='parallelepiped', size=(10, 40, 100)): 
     429    """ 
     430    Show an interactive orientation and jitter demo. 
     431 
     432    *model_name* is one of the models available in :func:`select_model`. 
     433    """ 
     434    # set up calculator 
     435    calculator, size = select_calculator(model_name, n=150, size=size) 
     436 
     437    ## uncomment to set an independent the colour range for every view 
     438    ## If left commented, the colour range is fixed for all views 
     439    calculator.limits = None 
     440 
     441    ## use gaussian distribution unless testing integration 
     442    #dist = 'rectangle' 
     443    dist = 'gaussian' 
     444 
     445    ## initial view 
     446    #theta, dtheta = 70., 10. 
     447    #phi, dphi = -45., 3. 
     448    #psi, dpsi = -45., 3. 
     449    theta, phi, psi = 0, 0, 0 
     450    dtheta, dphi, dpsi = 0, 0, 0 
     451 
     452    ## create the plot window 
    192453    #plt.hold(True) 
    193454    plt.set_cmap('gist_earth') 
     
    196457    #ax = plt.subplot(gs[0], projection='3d') 
    197458    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d') 
    198  
    199     theta, dtheta = 70., 10. 
    200     phi, dphi = -45., 3. 
    201     psi, dpsi = -45., 3. 
    202     theta, phi, psi = 0, 0, 0 
    203     dtheta, dphi, dpsi = 0, 0, 0 
    204     #dist = 'rect' 
    205     dist = 'gauss' 
     459    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection 
     460        ax.axis('square') 
     461    except Exception: 
     462        pass 
    206463 
    207464    axcolor = 'lightgoldenrodyellow' 
     465 
     466    ## add control widgets to plot 
    208467    axtheta  = plt.axes([0.1, 0.15, 0.45, 0.04], axisbg=axcolor) 
    209468    axphi = plt.axes([0.1, 0.1, 0.45, 0.04], axisbg=axcolor) 
     
    212471    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi) 
    213472    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi) 
     473 
    214474    axdtheta  = plt.axes([0.75, 0.15, 0.15, 0.04], axisbg=axcolor) 
    215475    axdphi = plt.axes([0.75, 0.1, 0.15, 0.04], axisbg=axcolor) 
    216476    axdpsi= plt.axes([0.75, 0.05, 0.15, 0.04], axisbg=axcolor) 
    217     sdtheta = Slider(axdtheta, 'dTheta', 0, 30, valinit=dtheta) 
    218     sdphi = Slider(axdphi, 'dPhi', 0, 30, valinit=dphi) 
    219     sdpsi = Slider(axdpsi, 'dPsi', 0, 30, valinit=dphi) 
    220  
     477    # Note: using ridiculous definition of rectangle distribution, whose width 
     478    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep 
     479    # the maximum width to 90. 
     480    dlimit = 30 if dist == 'gaussian' else 90/sqrt(3) 
     481    sdtheta = Slider(axdtheta, 'dTheta', 0, dlimit, valinit=dtheta) 
     482    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi) 
     483    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi) 
     484 
     485    ## callback to draw the new view 
    221486    def update(val, axis=None): 
    222         theta, phi, psi = stheta.val, sphi.val, spsi.val 
    223         dtheta, dphi, dpsi = sdtheta.val, sdphi.val, sdpsi.val 
     487        view = stheta.val, sphi.val, spsi.val 
     488        jitter = sdtheta.val, sdphi.val, sdpsi.val 
     489        # set small jitter as 0 if multiple pd dims 
     490        dims = sum(v > 0 for v in jitter) 
     491        limit = [0, 0, 2, 5][dims] 
     492        jitter = [0 if v < limit else v for v in jitter] 
    224493        ax.cla() 
    225         draw_beam(ax) 
    226         draw_shimmy(ax, theta, phi, psi, dtheta, dphi, dpsi) 
    227         #if not axis.startswith('d'): 
    228         #    ax.view_init(elev=theta, azim=phi) 
     494        draw_beam(ax, (0, 0)) 
     495        draw_jitter(ax, view, jitter, dist=dist, size=size) 
     496        #draw_jitter(ax, view, (0,0,0)) 
     497        draw_mesh(ax, view, jitter, dist=dist) 
     498        draw_scattering(calculator, ax, view, jitter, dist=dist) 
    229499        plt.gcf().canvas.draw() 
    230500 
     501    ## bind control widgets to view updater 
    231502    stheta.on_changed(lambda v: update(v,'theta')) 
    232503    sphi.on_changed(lambda v: update(v, 'phi')) 
     
    236507    sdpsi.on_changed(lambda v: update(v, 'dpsi')) 
    237508 
     509    ## initialize view 
    238510    update(None, 'phi') 
    239511 
     512    ## go interactive 
    240513    plt.show() 
    241514 
    242515if __name__ == "__main__": 
    243     main() 
     516    model_name = sys.argv[1] if len(sys.argv) > 1 else 'parallelepiped' 
     517    size = tuple(int(v) for v in sys.argv[2].split(',')) if len(sys.argv) > 2 else (10, 40, 100) 
     518    main(model_name, size) 
  • explore/precision.py

    r237c9cf ra1c5758  
    11#!/usr/bin/env python 
    22r""" 
    3 Show numerical precision of $2 J_1(x)/x$. 
     3Show numerical precision of various expressions. 
     4 
     5Evaluates the same function(s) in single and double precision and compares 
     6the results to 500 digit mpmath evaluation of the same function. 
     7 
     8Note: a quick way to generation C and python code for taylor series 
     9expansions from sympy: 
     10 
     11    import sympy as sp 
     12    x = sp.var("x") 
     13    f = sp.sin(x)/x 
     14    t = sp.series(f, n=12).removeO()  # taylor series with no O(x^n) term 
     15    p = sp.horner(t)   # Horner representation 
     16    p = p.replace(x**2, sp.var("xsq")  # simplify if alternate terms are zero 
     17    p = p.n(15)  # evaluate coefficients to 15 digits (optional) 
     18    c_code = sp.ccode(p, assign_to=sp.var("p"))  # convert to c code 
     19    py_code = c[:-1]  # strip semicolon to convert c to python 
     20 
     21    # mpmath has pade() rational function approximation, which might work 
     22    # better than the taylor series for some functions: 
     23    P, Q = mp.pade(sp.Poly(t.n(15),x).coeffs(), L, M) 
     24    P = sum(a*x**n for n,a in enumerate(reversed(P))) 
     25    Q = sum(a*x**n for n,a in enumerate(reversed(Q))) 
     26    c_code = sp.ccode(sp.horner(P)/sp.horner(Q), assign_to=sp.var("p")) 
     27 
     28    # There are richardson and shanks series accelerators in both sympy 
     29    # and mpmath that may be helpful. 
    430""" 
    531from __future__ import division, print_function 
     
    284310    np_function=scipy.special.erfc, 
    285311    ocl_function=make_ocl("return sas_erfc(q);", "sas_erfc", ["lib/polevl.c", "lib/sas_erf.c"]), 
     312    limits=(-5., 5.), 
     313) 
     314add_function( 
     315    name="expm1(x)", 
     316    mp_function=mp.expm1, 
     317    np_function=np.expm1, 
     318    ocl_function=make_ocl("return expm1(q);", "sas_expm1"), 
    286319    limits=(-5., 5.), 
    287320) 
     
    448481) 
    449482 
     483replacement_expm1 = """\ 
     484      double x = (double)q;  // go back to float for single precision kernels 
     485      // Adapted from the cephes math library. 
     486      // Copyright 1984 - 1992 by Stephen L. Moshier 
     487      if (x != x || x == 0.0) { 
     488         return x; // NaN and +/- 0 
     489      } else if (x < -0.5 || x > 0.5) { 
     490         return exp(x) - 1.0; 
     491      } else { 
     492         const double xsq = x*x; 
     493         const double p = ((( 
     494            +1.2617719307481059087798E-4)*xsq 
     495            +3.0299440770744196129956E-2)*xsq 
     496            +9.9999999999999999991025E-1); 
     497         const double q = (((( 
     498            +3.0019850513866445504159E-6)*xsq 
     499            +2.5244834034968410419224E-3)*xsq 
     500            +2.2726554820815502876593E-1)*xsq 
     501            +2.0000000000000000000897E0); 
     502         double r = x * p; 
     503         r =  r / (q - r); 
     504         return r+r; 
     505       } 
     506""" 
     507add_function( 
     508    name="sas_expm1(x)", 
     509    mp_function=mp.expm1, 
     510    np_function=np.expm1, 
     511    ocl_function=make_ocl(replacement_expm1, "sas_expm1"), 
     512) 
     513 
    450514# Alternate versions of 3 j1(x)/x, for posterity 
    451515def taylor_3j1x_x(x): 
Note: See TracChangeset for help on using the changeset viewer.