source: sasmodels/explore/jitter.py @ 87a6591

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 87a6591 was 87a6591, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

explore different map projections for jitter

  • Property mode set to 100755
File size: 21.0 KB
Line 
1#!/usr/bin/env python
2"""
3Application to explore the difference between sasview 3.x orientation
4dispersity and possible replacement algorithms.
5"""
6from __future__ import division, print_function
7
8import sys, os
9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
10
11import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
12import matplotlib.pyplot as plt
13from matplotlib.widgets import Slider, CheckButtons
14from matplotlib import cm
15import numpy as np
16from numpy import pi, cos, sin, sqrt, exp, degrees, radians
17
def draw_beam(ax, view=(0, 0)):
    """
    Draw the beam as a thin, translucent yellow tube on the 3D axes *ax*.

    Before rotation the tube has radius 0.02 and runs along the z axis from
    z = -1.3 to z = +1.3.  *view* is a (theta, phi) pair in degrees: the tube
    is rotated by theta about y and then by phi about z.
    """
    steps = 25
    tube_radius, half_length = 0.02, 1.3
    angle = np.linspace(0, 2 * np.pi, steps)
    height = np.linspace(-1, 1, steps)

    # Tube surface sampled on an (angle, height) grid.
    x = tube_radius*np.outer(np.cos(angle), np.ones_like(height))
    y = tube_radius*np.outer(np.sin(angle), np.ones_like(height))
    z = half_length*np.outer(np.ones_like(angle), height)

    theta, phi = view
    grid_shape = x.shape
    rotation = Rz(phi)*Ry(theta)
    coords = rotation*np.matrix([x.flatten(), y.flatten(), z.flatten()])
    xs, ys, zs = [row.reshape(grid_shape) for row in coords]

    ax.plot_surface(xs, ys, zs, rstride=4, cstride=4, color='y', alpha=0.5)
41
def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)):
    """
    Represent orientation jitter as a bundle of semi-transparent shapes.

    *size* is first rescaled so that sqrt(a**2 + b**2 + c**2) == 0.95.  A
    reference shape is drawn at *view* with no jitter.  One more copy is then
    drawn for every offset on the 3x3x3 grid of multiples {-1, 0, +1} of
    (dtheta, dphi, dpsi) = *jitter*.  Grid axes whose jitter is exactly zero
    are collapsed to 0 so duplicate copies are not drawn.  For
    *dist* == 'rectangle' the offsets are additionally divided by sqrt(3).
    Finally the x, y and z limits are set to +/- the rescaled diagonal and
    the axes are labelled.
    """
    norm = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(norm*v for v in size)
    draw_shape = draw_parallelepiped   # alternative: draw_ellipsoid

    # Full 3x3x3 grid of unit offsets, ordered with the last index fastest.
    unit = (-1, 0, 1)
    cloud = [[i, j, k] for i in unit for j in unit for k in unit]

    dtheta, dphi, dpsi = jitter
    # Collapse each axis that has no jitter so copies are not duplicated.
    for axis, width in enumerate((dtheta, dphi, dpsi)):
        if width == 0:
            cloud = [p for p in cloud if p[axis] == 0]

    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)
    spread = 1/sqrt(3) if dist == 'rectangle' else 1
    for i, j, k in cloud:
        offset = [spread*dtheta*i, spread*dphi*j, spread*dpsi*k]
        draw_shape(ax, size, view, offset, alpha=0.8)

    a, b, c = size
    lim = np.sqrt(a**2+b**2+c**2)
    for name in 'xyz':
        getattr(ax, 'set_%slim' % name)([-lim, lim])
        getattr(ax, name + 'axis').label.set_text(name)
100
def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
    """
    Draw a white ellipsoid with semi-axes *size* = (a, b, c) along x, y, z.

    The surface is sampled on a *steps* x *steps* (azimuth, polar) grid,
    passed through :func:`transform_xyz` with *view* and *jitter*, drawn
    with opacity *alpha*, and labelled at both ends of each axis.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)
    x, y, z = transform_xyz(
        view, jitter,
        a*np.outer(np.cos(azimuth), sin_polar),
        b*np.outer(np.sin(azimuth), sin_polar),
        c*np.outer(np.ones_like(azimuth), np.cos(polar)),
    )
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha)

    # (label, position on the untransformed shape, text direction)
    axis_labels = [
        ('c+', [ 0, 0, c], [ 1, 0, 0]),
        ('c-', [ 0, 0,-c], [ 0, 0,-1]),
        ('a+', [ a, 0, 0], [ 0, 0, 1]),
        ('a-', [-a, 0, 0], [ 0, 0,-1]),
        ('b+', [ 0, b, 0], [-1, 0, 0]),
        ('b-', [ 0,-b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
121
def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
    """
    Draw a white box with half-widths *size* = (a, b, c) along x, y, z.

    The 8 corners are passed through :func:`transform_xyz` with *view* and
    *jitter*, drawn as 12 triangles with opacity *alpha*, and labelled at
    both ends of each axis.  *steps* is accepted for signature compatibility
    with :func:`draw_ellipsoid` and is not used.
    """
    a, b, c = size
    # Corner sign patterns with x varying fastest, then y, then z, i.e.
    # corner 0 is (+,+,+), corner 1 is (-,+,+), ..., corner 7 is (-,-,-).
    corners = np.array([(sx, sy, sz)
                        for sz in (1, -1) for sy in (1, -1) for sx in (1, -1)])
    x, y, z = a*corners[:, 0], b*corners[:, 1], c*corners[:, 2]

    # Two counter-clockwise triangles per face.
    # z: up/down, x: right/left, y: front/back
    faces = np.array([
        [0, 1, 2], [3, 2, 1],  # top
        [6, 5, 4], [5, 6, 7],  # bottom
        [0, 2, 6], [6, 4, 0],  # right
        [1, 5, 7], [7, 3, 1],  # left
        [2, 3, 6], [7, 6, 3],  # front
        [4, 1, 0], [5, 1, 4],  # back
    ])

    x, y, z = transform_xyz(view, jitter, x, y, z)
    ax.plot_trisurf(x, y, triangles=faces, Z=z, color='w', alpha=alpha)

    # (label, position on the untransformed shape, text direction)
    axis_labels = [
        ('c+', [ 0, 0, c], [ 1, 0, 0]),
        ('c-', [ 0, 0,-c], [ 0, 0,-1]),
        ('a+', [ a, 0, 0], [ 0, 0, 1]),
        ('a-', [-a, 0, 0], [ 0, 0,-1]),
        ('b+', [ 0, b, 0], [-1, 0, 0]),
        ('b-', [ 0,-b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
150
def draw_sphere(ax, radius=10., steps=100):
    """
    Draw a white sphere of the given *radius* centred at the origin.

    The surface is sampled on a *steps* x *steps* (azimuth, polar) grid and
    handed to ``ax.plot_surface``.
    """
    # Column vector of azimuths and row vector of polar angles; products
    # broadcast to a (steps, steps) grid exactly like np.outer.
    azimuth = np.linspace(0, 2 * np.pi, steps)[:, np.newaxis]
    polar = np.linspace(0, np.pi, steps)[np.newaxis, :]

    x = radius * (np.cos(azimuth) * np.sin(polar))
    y = radius * (np.sin(azimuth) * np.sin(polar))
    z = radius * (np.ones_like(azimuth) * np.cos(polar))
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
160
def _mesh_projection(projection):
    """
    Return the (rotate, weight) callables for the named mesh *projection*.

    rotate(theta_i, phi_j) returns the rotation matrix for a mesh offset
    given in degrees.  weight(theta_i, phi_j, wi, wj) returns that point's
    weight from the 1-D distribution weights *wi* and *wj*; a weight of 0
    (or less) means the point is dropped from the mesh.

    Raises ValueError for an unknown *projection*.
    """
    if projection == 'stretched_phi':
        # Stretch phi by 1/cos(theta); drop points once |phi| reaches 180.
        def rotate(theta_i, phi_j):
            if theta_i != 90:
                phi_j /= cos(radians(theta_i))
            return Rx(phi_j)*Ry(theta_i)
        def weight(theta_i, phi_j, wi, wj):
            if theta_i != 90:
                phi_j /= cos(radians(theta_i))
            return wi*wj if abs(phi_j) < 180 else 0
    elif projection == 'azimuthal_equidistance':
        # https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection
        # The offset's length is the tilt away from z and its direction is
        # the azimuth; points tilted by 180 degrees or more are dropped.
        def rotate(theta_i, phi_j):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(np.arctan2(phi_j, theta_i))
            return Rz(longitude)*Ry(latitude)
        def weight(theta_i, phi_j, wi, wj):
            latitude = sqrt(theta_i**2 + phi_j**2)
            return wi*wj if latitude < 180 else 0
    elif projection == 'azimuthal_equal_area':
        # https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection
        def rotate(theta_i, phi_j):
            R = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*np.arccos(R))
            longitude = degrees(np.arctan2(phi_j, theta_i))
            return Rz(longitude)*Ry(latitude)
        def weight(theta_i, phi_j, wi, wj):
            R = sqrt(theta_i**2 + phi_j**2)
            return wi*wj if R <= 180 else 0
    elif projection == 'random_thrashing':
        # Ad hoc experiment: scale each angle by the other's cosine.
        def rotate(theta_i, phi_j):
            theta_i, phi_j = 2*theta_i/abs(cos(radians(phi_j))), 2*phi_j/cos(radians(theta_i))
            return Rx(phi_j)*Ry(theta_i)
        def weight(theta_i, phi_j, wi, wj):
            theta_i, phi_j = 2*theta_i/abs(cos(radians(phi_j))), 2*phi_j/cos(radians(theta_i))
            return wi*wj if abs(phi_j) < 180 else 0
    elif projection == 'cos_theta_weighted':
        # Plain theta/phi rotation with the weight scaled by cos(theta).
        def rotate(theta_i, phi_j):
            return Rx(phi_j)*Ry(theta_i)
        def weight(theta_i, phi_j, wi, wj):
            return wi*wj*cos(radians(theta_i))
    else:
        raise ValueError("unknown projection %r" % (projection,))
    return rotate, weight

def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='azimuthal_equidistance'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    An n x n grid of (theta, phi) offsets is formed from a 1-D grid *t*
    scaled by dtheta and dphi from *jitter* (dpsi does not affect the mesh).
    Each offset is turned into a rotation by *projection* and applied to
    the point (0, 0, *radius*).  Points whose weight is not positive are
    dropped.  The rest are oriented relative to the beam using *view* and
    scattered on *ax*, coloured by weight on a 0..1 scale.

    *dist* selects the 1-D grid:

    * 'gaussian': t spans [-3, 3] with weights exp(-t**2/2)
    * 'rectangle': t spans [-sqrt(3), sqrt(3)] with equal weights

    *projection* selects how the (theta, phi) offsets become orientations.
    It is one of 'stretched_phi', 'azimuthal_equidistance' (default),
    'azimuthal_equal_area', 'random_thrashing' or 'cos_theta_weighted'.

    Raises ValueError for an unknown *dist* or *projection*.
    """
    dtheta, dphi, dpsi = jitter

    if dist == 'gaussian':
        t = np.linspace(-3, 3, n)
        weights = exp(-0.5*t**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        t = np.linspace(-1, 1, n)*sqrt(3)
        weights = np.ones_like(t)
    else:
        raise ValueError("expected dist to be 'gaussian' or 'rectangle'")

    rotate, weight = _mesh_projection(projection)

    # mesh in theta, phi formed by rotating z
    z = np.matrix([[0], [0], [radius]])
    points = np.hstack([rotate(theta_i, phi_j)*z
                        for theta_i in dtheta*t
                        for phi_j in dphi*t])
    # keep only the points with positive weight; the projections zero the
    # weight of points that fall outside their valid range
    w = np.array([weight(theta_i, phi_j, wi, wj)
                  for wi, theta_i in zip(weights, dtheta*t)
                  for wj, phi_j in zip(weights, dphi*t)])
    points = points[:, w>0]
    w = w[w>0]

    if 0: # Kent distribution (disabled experiment)
        # NOTE(review): kappa = max(1e6, ...) is always >= 1e6, so
        # exp(kappa*zp) overflows; min() may have been intended -- confirm
        # before enabling.
        points = np.hstack([Rx(phi_j)*Ry(theta_i)*z for theta_i in 30*t for phi_j in 60*t])
        xp, yp, zp = [np.array(v).flatten() for v in points]
        kappa = max(1e6, radians(dtheta)/(2*pi))
        beta = 1/max(1e-6, radians(dphi)/(2*pi))/kappa
        w = exp(kappa*zp) #+ beta*(xp**2 + yp**2)
        print(kappa, dtheta, radians(dtheta), min(w), max(w), sum(w))

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(v).flatten() for v in points]
    ax.scatter(x, y, z, c=w, marker='o', vmin=0., vmax=1.)
255
def draw_labels(ax, view, jitter, text):
    """
    Place text labels on the 3D axes *ax*.

    *text* is a sequence of (label, position, direction) triples, where
    position and direction are (x, y, z) triples on the untransformed shape.
    Both are sent through :func:`transform_xyz` with *view* and *jitter*
    before ``ax.text`` is called with the direction as *zdir*.
    """
    labels, positions, directions = zip(*text)

    xs, ys, zs = zip(*positions)
    px, py, pz = transform_xyz(view, jitter, xs, ys, zs)
    ux, uy, uz = zip(*directions)
    dx, dy, dz = transform_xyz(view, jitter, ux, uy, uz)

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for k, label in enumerate(labels):
        zdir = np.asarray((dx[k], dy[k], dz[k])).flatten()
        ax.text(px[k], py[k], pz[k], label, zdir=zdir)
271
272# Definition of rotation matrices comes from wikipedia:
273#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """Return the 3x3 ``np.matrix`` rotating points about *x* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([[1, 0, 0],
                      [0, c, -s],
                      [0, s, c]])
281
def Ry(angle):
    """Return the 3x3 ``np.matrix`` rotating points about *y* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([[c, 0, s],
                      [0, 1, 0],
                      [-s, 0, c]])
289
def Rz(angle):
    """Return the 3x3 ``np.matrix`` rotating points about *z* by *angle* degrees."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([[c, -s, 0],
                      [s, c, 0],
                      [0, 0, 1]])
297
def transform_xyz(view, jitter, x, y, z):
    """
    Rotate the points (*x*, *y*, *z*) by *jitter* and then by *view*.

    *x*, *y* and *z* are array-likes assumed to share the shape of *x*.
    The result is a tuple of three ndarrays with that same shape.
    """
    coords = [np.asarray(v) for v in (x, y, z)]
    shape = coords[0].shape
    points = np.matrix([c.flatten() for c in coords])
    rotated = orient_relative_to_beam(view, apply_jitter(jitter, points))
    return tuple(np.array(row).reshape(shape) for row in rotated)
309
def apply_jitter(jitter, points):
    """
    Rotate *points* by the orientation jitter (dtheta, dphi, dpsi) in degrees.

    The combined rotation is ``Rx(dphi)*Ry(dtheta)*Rz(dpsi)``.  *points* is
    a 3 x n ``np.matrix``, so ``*`` is a matrix product.
    """
    dtheta, dphi, dpsi = jitter
    rotation = Rx(dphi)*Ry(dtheta)*Rz(dpsi)
    return rotation*points
319
def orient_relative_to_beam(view, points):
    """
    Rotate *points* into the view orientation (theta, phi, psi) in degrees.

    The combined rotation is ``Rz(phi)*Ry(theta)*Rz(psi)``.  *points* is a
    3 x n ``np.matrix``, so ``*`` is a matrix product.
    """
    theta, phi, psi = view
    rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return rotation*points
329
# Map which orientation angles are dispersed, given as (theta, phi, psi)
# flags, to the number of dispersity points used along each angle.  The
# flags may be booleans such as (True, False, False); these hash and compare
# equal to the integer keys below.
PD_N_TABLE = {
    # nothing dispersed: 0 points
    (0, 0, 0): (0, 0, 0),
    # one angle dispersed: 100 points
    (1, 0, 0): (100, 0, 0),
    (0, 1, 0): (0, 100, 0),
    (0, 0, 1): (0, 0, 100),
    # two angles dispersed: 30 x 30 = 900 points
    (1, 1, 0): (30, 30, 0),
    (1, 0, 1): (30, 0, 30),
    (0, 1, 1): (0, 30, 30),
    # three angles dispersed: 15 x 15 x 15 = 3375 points
    (1, 1, 1): (15, 15, 15),
}
342
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine a (low, high) display range from the values in the array *data*.

    If *portion* is 1, return the full range (min, max) regardless of *mode*.
    Otherwise *data* is flattened and sorted, and:

    * mode 'central': return the values a fraction portion/2 in from each
      end, i.e. the central (1 - portion) fraction of the data;
    * mode 'top': return the value a fraction *portion* of the way up,
      together with the maximum, i.e. the top (1 - portion) fraction.

    NOTE(review): for portion < 1, *portion* acts as the fraction of data
    excluded, whereas portion == 1 returns the full range.  The 'top' mode
    callers in this file pass 0.95 and rely on that behaviour.

    Indices are clamped so that tiny or near-1 portions give a valid range
    instead of (min, min) or an IndexError.

    Raises ValueError for an unknown *mode* when *portion* is not 1.
    """
    if portion == 1.0:
        return data.min(), data.max()
    if mode not in ('central', 'top'):
        raise ValueError("expected mode to be 'central' or 'top'")

    data = np.sort(data.flatten())
    if mode == 'central':
        offset = int(portion*len(data)/2 + 0.5)
        # data[-0] is data[0]; always take the upper end from the real tail
        return data[offset], data[-max(offset, 1)]
    # mode == 'top': keep the lower index inside the array
    offset = min(int(portion*len(data) + 0.5), len(data) - 1)
    return data[offset], data[-1]
360
def draw_scattering(calculator, ax, view, jitter, dist='gaussian'):
    """
    Plot log(I(qx, qy)) for the current view as filled contours at z = -1.1.

    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes
    on which the data will be plotted.  *view* = (theta, phi, psi) and
    *jitter* = (dtheta, dphi, dpsi) are the current orientation and
    orientation dispersity.  *dist* is the sasmodels weight distribution
    used for all three angles.  The colour range is taken from
    ``calculator.limits`` when set, and otherwise from the top 5% of the
    current pattern (see :func:`clipped_range`).
    """
    # Jitter widths are passed unscaled: the sasmodels definition of the
    # rectangle width is used as-is.
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = jitter
    # Number of points per angle depends on which angles are dispersed.
    active = (theta_pd > 0, phi_pd > 0, psi_pd > 0)
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[active]

    pars = {
        'theta': theta, 'theta_pd': theta_pd,
        'theta_pd_type': dist, 'theta_pd_n': theta_pd_n,
        'phi': phi, 'phi_pd': phi_pd,
        'phi_pd_type': dist, 'phi_pd_n': phi_pd_n,
        'psi': psi, 'psi_pd': psi_pd,
        'psi_pd_type': dist, 'psi_pd_n': psi_pd_n,
    }
    # Stored model parameters take precedence over the orientation ones.
    pars.update(calculator.pars)

    # Evaluate on the calculator's q grid (read from its private _data).
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    Iqxy = np.log(calculator(**pars).reshape(len(qx), len(qy)))

    if not calculator.limits:
        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top')
    else:
        # fixed limits, e.g. from orientation (0, 0, 0)
        vmin, vmax = calculator.limits

    ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                levels=np.linspace(vmin, vmax, 24))
413
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*, with 'background'
    defaulting to 1e-3.  They are used by :func:`draw_scattering` to set the
    non-orientation parameters in the calculation.

    The calculator also gets *calculator.limits* = (vmin, vmax + 1), where
    vmin..vmax is the 'top' :func:`clipped_range` (portion 0.95) of
    log(Iqxy) at orientation (0, 0, 0).

    Returns a *calculator* function which takes a dictionary of parameters
    and produces Iqxy.  The Iqxy value needs to be reshaped to an n x n
    matrix for plotting.  See the :class:`sasmodels.direct_model.DirectModel`
    class for details.
    """
    # Alias sasmodels' build_model so it does not shadow this function.
    from sasmodels.core import load_model_info
    from sasmodels.core import build_model as build_sasmodel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_sasmodel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # Previously misspelled 'backgound', so this default was never applied.
    # A positive background floor helps keep np.log(Iqxy) below finite.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
451
def select_calculator(model_name, n=150, size=(10,40,100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or bcc_paracrystal.  *n* is the number of points to use
    in the q range.  *qmax* is chosen per model to show something
    interesting.  *size* = (a, b, c) gives the requested dimensions.

    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
    equatorial axes and polar axis respectively.  Axes that the model does
    not distinguish are made equal (e.g. a = b = c for a sphere).  See
    :func:`build_model` for details on the returned calculator.

    Raises ValueError for an unknown *model_name*.
    """
    a, b, c = size
    # Pick the model-specific build options and the shape to report back.
    if model_name == 'sphere':
        options = dict(radius=c)
        shape = (c, c, c)
    elif model_name == 'bcc_paracrystal':
        options = dict(dnn=c, d_factor=0.06, radius=40)
        shape = (c, c, c)
    elif model_name == 'cylinder':
        options = dict(qmax=0.3, radius=b, length=c)
        shape = (b, b, c)
    elif model_name == 'ellipsoid':
        options = dict(qmax=1.0, radius_polar=c, radius_equatorial=b)
        shape = (b, b, c)
    elif model_name == 'triaxial_ellipsoid':
        options = dict(qmax=0.5, radius_equat_minor=a,
                       radius_equat_major=b, radius_polar=c)
        shape = (a, b, c)
    elif model_name == 'parallelepiped':
        options = dict(a=a, b=b, c=c)
        shape = (a, b, c)
    else:
        raise ValueError("unknown model %s"%model_name)

    calculator = build_model(model_name, n=n, **options)
    return calculator, shape
491
def main(model_name='parallelepiped', size=(10, 40, 100)):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of the models available in :func:`select_calculator`
    and *size* is the (a, b, c) shape passed to it.  Sliders set the view
    angles (theta, phi, psi) and the jitter widths (dtheta, dphi, dpsi).
    Each change redraws the beam, the jittered shapes, the dispersion mesh
    and the scattering pattern.  Blocks in ``plt.show()``.
    """
    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)

    ## Use an independent colour range for every view.  Comment this out to
    ## keep the range build_model computed at orientation (0, 0, 0) fixed
    ## for all views.
    calculator.limits = None

    ## use gaussian distribution unless testing integration
    #dist = 'rectangle'
    dist = 'gaussian'

    ## initial view
    #theta, dtheta = 70., 10.
    #phi, dphi = -45., 3.
    #psi, dpsi = -45., 3.
    theta, phi, psi = 0, 0, 0
    dtheta, dphi, dpsi = 0, 0, 0

    ## create the plot window
    plt.set_cmap('gist_earth')
    plt.clf()
    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        ax.axis('square')
    except Exception:
        pass

    axcolor = 'lightgoldenrodyellow'

    def control_axes(rect):
        """Create slider axes at *rect* with the control background colour."""
        try:
            return plt.axes(rect, facecolor=axcolor)
        except (AttributeError, TypeError):
            # CRUFT: matplotlib < 2.0 spells the background colour 'axisbg';
            # later releases removed 'axisbg', so 'facecolor' is tried first.
            return plt.axes(rect, axisbg=axcolor)

    ## add control widgets to plot
    axtheta = control_axes([0.1, 0.15, 0.45, 0.04])
    axphi = control_axes([0.1, 0.1, 0.45, 0.04])
    axpsi = control_axes([0.1, 0.05, 0.45, 0.04])
    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)

    axdtheta = control_axes([0.75, 0.15, 0.15, 0.04])
    axdphi = control_axes([0.75, 0.1, 0.15, 0.04])
    axdpsi = control_axes([0.75, 0.05, 0.15, 0.04])
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = 30 if dist == 'gaussian' else 90/sqrt(3)
    sdtheta = Slider(axdtheta, 'dTheta', 0, 2*dlimit, valinit=dtheta)
    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi)
    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0, 2, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        ax.cla()
        draw_beam(ax, (0, 0))
        draw_jitter(ax, view, jitter, dist=dist, size=size)
        #draw_jitter(ax, view, (0,0,0))
        draw_mesh(ax, view, jitter, dist=dist, n=30)
        draw_scattering(calculator, ax, view, jitter, dist=dist)
        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v,'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
578
if __name__ == "__main__":
    # usage: jitter.py [model_name [a,b,c]]
    if len(sys.argv) > 1:
        model_name = sys.argv[1]
    else:
        model_name = 'parallelepiped'
    if len(sys.argv) > 2:
        size = tuple(int(part) for part in sys.argv[2].split(','))
    else:
        size = (10, 40, 100)
    main(model_name, size)
Note: See TracBrowser for help on using the repository browser.