Changeset aa6989b in sasmodels for explore/jitter.py


Ignore:
Timestamp:
Oct 18, 2017 10:41:12 PM (7 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
9b7b23f
Parents:
ef8e68c
Message:

add scattering pattern to jitter viewer

File:
1 edited

Legend:

Unmodified
Added
Removed
  • explore/jitter.py

    • Property mode changed from 100644 to 100755
    rd4c33d6 raa6989b  
     1#!/usr/bin/env python 
    12""" 
    23Application to explore the difference between sasview 3.x orientation 
    34dispersity and possible replacement algorithms. 
    45""" 
    5 import sys 
     6from __future__ import division, print_function 
     7 
     8import sys, os 
     9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 
    610 
    711import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot 
     
    1317 
    1418def draw_beam(ax, view=(0, 0)): 
     19    """ 
     20    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1) 
     21    """ 
    1522    #ax.plot([0,0],[0,0],[1,-1]) 
    1623    #ax.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8) 
     
    3340    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5) 
    3441 
    35 def draw_jitter(ax, view, jitter): 
    36     size = [0.1, 0.4, 1.0] 
     42def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)): 
     43    """ 
     44    Represent jitter as a set of shapes at different orientations. 
     45    """ 
     46    # set max diagonal to 0.95 
     47    scale = 0.95/sqrt(sum(v**2 for v in size)) 
     48    size = tuple(scale*v for v in size) 
    3749    draw_shape = draw_parallelepiped 
    3850    #draw_shape = draw_ellipsoid 
     
    7789        cloud = [v for v in cloud if v[2] == 0] 
    7890    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8) 
     91    scale = 1/sqrt(3) if dist == 'rectangle' else 1 
    7992    for point in cloud: 
    80         delta = [dtheta*point[0], dphi*point[1], dpsi*point[2]] 
     93        delta = [scale*dtheta*point[0], scale*dphi*point[1], scale*dpsi*point[2]] 
    8194        draw_shape(ax, size, view, delta, alpha=0.8) 
    8295    for v in 'xyz': 
     
    87100 
    88101def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1): 
     102    """Draw an ellipsoid.""" 
    89103    a,b,c = size 
    90104    u = np.linspace(0, 2 * np.pi, steps) 
     
    107121 
    108122def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1): 
     123    """Draw a parallelepiped.""" 
    109124    a,b,c = size 
    110125    x = a*np.array([ 1,-1, 1,-1, 1,-1, 1,-1]) 
     
    134149    ]) 
    135150 
    136 def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gauss'): 
     151def draw_sphere(ax, radius=10., steps=100): 
     152    """Draw a sphere""" 
     153    u = np.linspace(0, 2 * np.pi, steps) 
     154    v = np.linspace(0, np.pi, steps) 
     155 
     156    x = radius * np.outer(np.cos(u), np.sin(v)) 
     157    y = radius * np.outer(np.sin(u), np.sin(v)) 
     158    z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) 
     159    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 
     160 
     161def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian'): 
     162    """ 
     163    Draw the dispersion mesh showing the theta-phi orientations at which 
     164    the model will be evaluated. 
     165    """ 
    137166    theta, phi, psi = view 
    138167    dtheta, dphi, dpsi = jitter 
    139     if dist == 'gauss': 
     168 
     169    if dist == 'gaussian': 
    140170        t = np.linspace(-3, 3, n) 
    141171        weights = exp(-0.5*t**2) 
    142     elif dist == 'rect': 
    143         t = np.linspace(0, 1, n) 
     172    elif dist == 'rectangle': 
     173        # Note: uses sasmodels ridiculous definition of rectangle width 
     174        t = np.linspace(-1, 1, n)*sqrt(3) 
    144175        weights = np.ones_like(t) 
    145176    else: 
    146         raise ValueError("expected dist to be 'gauss' or 'rect'") 
     177        raise ValueError("expected dist to be 'gaussian' or 'rectangle'") 
    147178 
    148179    # mesh in theta, phi formed by rotating z 
     
    154185    points = orient_relative_to_beam(view, points) 
    155186 
    156     w = np.outer(weights, weights) 
     187    w = np.outer(weights*cos(radians(dtheta*t)), weights) 
    157188 
    158189    x, y, z = [np.array(v).flatten() for v in points] 
    159190    ax.scatter(x, y, z, c=w.flatten(), marker='o', vmin=0., vmax=1.) 
    160191 
     192def draw_labels(ax, view, jitter, text): 
     193    """ 
     194    Draw text at a particular location. 
     195    """ 
     196    labels, locations, orientations = zip(*text) 
     197    px, py, pz = zip(*locations) 
     198    dx, dy, dz = zip(*orientations) 
     199 
     200    px, py, pz = transform_xyz(view, jitter, px, py, pz) 
     201    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz) 
     202 
     203    # TODO: zdir for labels is broken, and labels aren't appearing. 
     204    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)): 
     205        zdir = np.asarray(zdir).flatten() 
     206        ax.text(p[0], p[1], p[2], label, zdir=zdir) 
     207 
     208# Definition of rotation matrices comes from wikipedia: 
     209#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations 
    161210def Rx(angle): 
     211    """Construct a matrix to rotate points about *x* by *angle* degrees.""" 
    162212    a = radians(angle) 
    163     R = [[1., 0., 0.], 
    164          [0.,  cos(a), sin(a)], 
    165          [0., -sin(a), cos(a)]] 
     213    R = [[1, 0, 0], 
     214         [0, +cos(a), -sin(a)], 
     215         [0, +sin(a), +cos(a)]] 
    166216    return np.matrix(R) 
    167217 
    168218def Ry(angle): 
     219    """Construct a matrix to rotate points about *y* by *angle* degrees.""" 
    169220    a = radians(angle) 
    170     R = [[cos(a), 0., -sin(a)], 
    171          [0., 1., 0.], 
    172          [sin(a), 0.,  cos(a)]] 
     221    R = [[+cos(a), 0, +sin(a)], 
     222         [0, 1, 0], 
     223         [-sin(a), 0, +cos(a)]] 
    173224    return np.matrix(R) 
    174225 
    175226def Rz(angle): 
     227    """Construct a matrix to rotate points about *z* by *angle* degrees.""" 
    176228    a = radians(angle) 
    177     R = [[cos(a), -sin(a), 0.], 
    178          [sin(a),  cos(a), 0.], 
    179          [0., 0., 1.]] 
     229    R = [[+cos(a), -sin(a), 0], 
     230         [+sin(a), +cos(a), 0], 
     231         [0, 0, 1]] 
    180232    return np.matrix(R) 
    181233 
    182234def transform_xyz(view, jitter, x, y, z): 
     235    """ 
     236    Send a set of (x,y,z) points through the jitter and view transforms. 
     237    """ 
    183238    x, y, z = [np.asarray(v) for v in (x, y, z)] 
    184239    shape = x.shape 
     
    190245 
    191246def apply_jitter(jitter, points): 
     247    """ 
     248    Apply the jitter transform to a set of points. 
     249 
     250    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
     251    """ 
    192252    dtheta, dphi, dpsi = jitter 
    193253    points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points 
     
    195255 
    196256def orient_relative_to_beam(view, points): 
     257    """ 
     258    Apply the view transform to a set of points. 
     259 
     260    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 
     261    """ 
    197262    theta, phi, psi = view 
    198263    points = Rz(phi)*Ry(theta)*Rz(psi)*points 
    199264    return points 
    200265 
    201 def draw_labels(ax, view, jitter, text): 
    202     labels, locations, orientations = zip(*text) 
    203     px, py, pz = zip(*locations) 
    204     dx, dy, dz = zip(*orientations) 
    205  
    206     px, py, pz = transform_xyz(view, jitter, px, py, pz) 
    207     dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz) 
    208  
    209     for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)): 
    210         zdir = np.asarray(zdir).flatten() 
    211         ax.text(p[0], p[1], p[2], label, zdir=zdir) 
    212  
    213 def draw_sphere(ax, radius=10., steps=100): 
    214     u = np.linspace(0, 2 * np.pi, steps) 
    215     v = np.linspace(0, np.pi, steps) 
    216  
    217     x = radius * np.outer(np.cos(u), np.sin(v)) 
    218     y = radius * np.outer(np.sin(u), np.sin(v)) 
    219     z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) 
    220     ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 
    221  
    222 def main(): 
     266# translate between number of dimension of dispersity and the number of 
     267# points along each dimension. 
     268PD_N_TABLE = { 
     269    (0, 0, 0): (0, 0, 0),     # 0 
     270    (1, 0, 0): (100, 0, 0),   # 100 
     271    (0, 1, 0): (0, 100, 0), 
     272    (0, 0, 1): (0, 0, 100), 
     273    (1, 1, 0): (30, 30, 0),   # 900 
     274    (1, 0, 1): (30, 0, 30), 
     275    (0, 1, 1): (0, 30, 30), 
     276    (1, 1, 1): (15, 15, 15),  # 3375 
     277} 
     278 
     279def clipped_range(data, portion=1.0, mode='central'): 
     280    """ 
     281    Determine range from data. 
     282 
     283    If *portion* is 1, use full range, otherwise use the center of the range 
     284    or the top of the range, depending on whether *mode* is 'central' or 'top'. 
     285    """ 
     286    if portion == 1.0: 
     287        return data.min(), data.max() 
     288    elif mode == 'central': 
     289        data = np.sort(data.flatten()) 
     290        offset = int(portion*len(data)/2 + 0.5) 
     291        return data[offset], data[-offset] 
     292    elif mode == 'top': 
     293        data = np.sort(data.flatten()) 
     294        offset = int(portion*len(data) + 0.5) 
     295        return data[offset], data[-1] 
     296 
     297def draw_scattering(calculator, ax, view, jitter, dist='gaussian'): 
     298    """ 
     299    Plot the scattering for the particular view. 
     300 
     301    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes 
     302    on which the data will be plotted.  *view* and *jitter* are the current 
     303    orientation and orientation dispersity.  *dist* is one of the sasmodels 
     304    weight distributions. 
     305    """ 
     306    ## Sasmodels use sqrt(3)*width for the rectangle range; scale to the 
     307    ## proper width for comparison. Commented out since now using the 
     308    ## sasmodels definition of width for rectangle. 
     309    #scale = 1/sqrt(3) if dist == 'rectangle' else 1 
     310    scale = 1 
     311 
     312    # add the orientation parameters to the model parameters 
     313    theta, phi, psi = view 
     314    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter] 
     315    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd>0, phi_pd>0, psi_pd>0)] 
     316    ## increase pd_n for testing jitter integration rather than simple viz 
     317    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)] 
     318 
     319    pars = dict( 
     320        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n, 
     321        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n, 
     322        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n, 
     323    ) 
     324    pars.update(calculator.pars) 
     325 
     326    # compute the pattern 
     327    qx, qy = calculator._data.x_bins, calculator._data.y_bins 
     328    Iqxy = calculator(**pars).reshape(len(qx), len(qy)) 
     329 
     330    # scale it and draw it 
     331    Iqxy = np.log(Iqxy) 
     332    if calculator.limits: 
     333        # use limits from orientation (0,0,0) 
     334        vmin, vmax = calculator.limits 
     335    else: 
     336        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top') 
     337    #print("range",(vmin,vmax)) 
     338    #qx, qy = np.meshgrid(qx, qy) 
     339    if 0: 
     340        level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i') 
     341        level[level<0] = 0 
     342        colors = plt.get_cmap()(level) 
     343        ax.plot_surface(qx, qy, -1.1, rstride=1, cstride=1, facecolors=colors) 
     344    elif 1: 
     345        ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1, 
     346                    levels=np.linspace(vmin, vmax, 24)) 
     347    else: 
     348        ax.pcolormesh(qx, qy, Iqxy) 
     349 
     350def build_model(model_name, n=150, qmax=0.5, **pars): 
     351    """ 
     352    Build a calculator for the given shape. 
     353 
     354    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh 
     355    on which to evaluate the model.  The remaining parameters are stored in 
     356    the returned calculator as *calculator.pars*.  They are used by 
     357    :func:`draw_scattering` to set the non-orientation parameters in the 
     358    calculation. 
     359 
     360    Returns a *calculator* function which takes a dictionary or parameters and 
     361    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix 
     362    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class 
     363    for details. 
     364    """ 
     365    from sasmodels.core import load_model_info, build_model 
     366    from sasmodels.data import empty_data2D 
     367    from sasmodels.direct_model import DirectModel 
     368 
     369    model_info = load_model_info(model_name) 
     370    model = build_model(model_info) #, dtype='double!') 
     371    q = np.linspace(-qmax, qmax, n) 
     372    data = empty_data2D(q, q) 
     373    calculator = DirectModel(data, model) 
     374 
     375    # stuff the values for non-orientation parameters into the calculator 
     376    calculator.pars = pars.copy() 
     377    calculator.pars.setdefault('backgound', 1e-3) 
     378 
     379    # fix the data limits so that we can see if the pattern fades 
     380    # under rotation or angular dispersion 
     381    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars) 
     382    Iqxy = np.log(Iqxy) 
     383    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top') 
     384    calculator.limits = vmin, vmax+1 
     385 
     386    return calculator 
     387 
     388def select_calculator(model_name, n=150): 
     389    """ 
     390    Create a model calculator for the given shape. 
     391 
     392    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid, 
     393    parallelepiped or bcc_paracrystal. *n* is the number of points to use 
     394    in the q range.  *qmax* is chosen based on model parameters for the 
     395    given model to show something intersting. 
     396 
     397    Returns *calculator* and tuple *size* (a,b,c) giving minor and major 
     398    equitorial axes and polar axis respectively.  See :func:`build_model` 
     399    for details on the returned calculator. 
     400    """ 
     401    a, b, c = 10, 40, 100 
     402    if model_name == 'sphere': 
     403        calculator = build_model('sphere', n=n, radius=c) 
     404        a = b = c 
     405    elif model_name == 'bcc_paracrystal': 
     406        calculator = build_model('bcc_paracrystal', n=n, dnn=c, 
     407                                  d_factor=0.06, radius=40) 
     408        a = b = c 
     409    elif model_name == 'cylinder': 
     410        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c) 
     411        a = b 
     412    elif model_name == 'ellipsoid': 
     413        calculator = build_model('ellipsoid', n=n, qmax=1.0, 
     414                                 radius_polar=c, radius_equatorial=b) 
     415        a = b 
     416    elif model_name == 'triaxial_ellipsoid': 
     417        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5, 
     418                                 radius_equat_minor=a, 
     419                                 radius_equat_major=b, 
     420                                 radius_polar=c) 
     421    elif model_name == 'parallelepiped': 
     422        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c) 
     423    else: 
     424        raise ValueError("unknown model %s"%model_name) 
     425 
     426    return calculator, (a, b, c) 
     427 
     428def main(model_name='parallelepiped'): 
     429    """ 
     430    Show an interactive orientation and jitter demo. 
     431 
     432    *model_name* is one of the models available in :func:`select_model`. 
     433    """ 
     434    # set up calculator 
     435    calculator, size = select_calculator(model_name, n=150) 
     436 
     437    ## uncomment to set an independent the colour range for every view 
     438    ## If left commented, the colour range is fixed for all views 
     439    calculator.limits = None 
     440 
     441    ## use gaussian distribution unless testing integration 
     442    #dist = 'rectangle' 
     443    dist = 'gaussian' 
     444 
     445    ## initial view 
     446    #theta, dtheta = 70., 10. 
     447    #phi, dphi = -45., 3. 
     448    #psi, dpsi = -45., 3. 
     449    theta, phi, psi = 0, 0, 0 
     450    dtheta, dphi, dpsi = 0, 0, 0 
     451 
     452    ## create the plot window 
    223453    #plt.hold(True) 
    224454    plt.set_cmap('gist_earth') 
     
    227457    #ax = plt.subplot(gs[0], projection='3d') 
    228458    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d') 
    229  
    230     theta, dtheta = 70., 10. 
    231     phi, dphi = -45., 3. 
    232     psi, dpsi = -45., 3. 
    233     theta, phi, psi = 0, 0, 0 
    234     dtheta, dphi, dpsi = 0, 0, 0 
    235     #dist = 'rect' 
    236     dist = 'gauss' 
     459    ax.axis('square') 
    237460 
    238461    axcolor = 'lightgoldenrodyellow' 
    239462 
     463    ## add control widgets to plot 
    240464    axtheta  = plt.axes([0.1, 0.15, 0.45, 0.04], axisbg=axcolor) 
    241465    axphi = plt.axes([0.1, 0.1, 0.45, 0.04], axisbg=axcolor) 
     
    248472    axdphi = plt.axes([0.75, 0.1, 0.15, 0.04], axisbg=axcolor) 
    249473    axdpsi= plt.axes([0.75, 0.05, 0.15, 0.04], axisbg=axcolor) 
    250     sdtheta = Slider(axdtheta, 'dTheta', 0, 30, valinit=dtheta) 
    251     sdphi = Slider(axdphi, 'dPhi', 0, 30, valinit=dphi) 
    252     sdpsi = Slider(axdpsi, 'dPsi', 0, 30, valinit=dpsi) 
    253  
     474    # Note: using ridiculous definition of rectangle distribution, whose width 
     475    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep 
     476    # the maximum width to 90. 
     477    dlimit = 30 if dist == 'gaussian' else 90/sqrt(3) 
     478    sdtheta = Slider(axdtheta, 'dTheta', 0, dlimit, valinit=dtheta) 
     479    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi) 
     480    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi) 
     481 
     482    ## callback to draw the new view 
    254483    def update(val, axis=None): 
    255484        view = stheta.val, sphi.val, spsi.val 
    256485        jitter = sdtheta.val, sdphi.val, sdpsi.val 
     486        # set small jitter as 0 if multiple pd dims 
     487        dims = sum(v > 0 for v in jitter) 
     488        limit = [0, 0, 2, 5][dims] 
     489        jitter = [0 if v < limit else v for v in jitter] 
    257490        ax.cla() 
    258491        draw_beam(ax, (0, 0)) 
    259         draw_jitter(ax, view, jitter) 
     492        draw_jitter(ax, view, jitter, dist=dist, size=size) 
    260493        #draw_jitter(ax, view, (0,0,0)) 
    261         draw_mesh(ax, view, jitter) 
     494        draw_mesh(ax, view, jitter, dist=dist) 
     495        draw_scattering(calculator, ax, view, jitter, dist=dist) 
    262496        plt.gcf().canvas.draw() 
    263497 
     498    ## bind control widgets to view updater 
    264499    stheta.on_changed(lambda v: update(v,'theta')) 
    265500    sphi.on_changed(lambda v: update(v, 'phi')) 
     
    269504    sdpsi.on_changed(lambda v: update(v, 'dpsi')) 
    270505 
     506    ## initialize view 
    271507    update(None, 'phi') 
    272508 
     509    ## go interactive 
    273510    plt.show() 
    274511 
    275512if __name__ == "__main__": 
    276     main() 
     513    model_name = sys.argv[1] if len(sys.argv) > 1 else 'parallelepiped' 
     514    main(model_name) 
Note: See TracChangeset for help on using the changeset viewer.