source: sasmodels/explore/jitter.py @ 5b5ea20

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 5b5ea20 was 5b5ea20, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

[possibly] correct weighting for the azimuthal equidistant projection

  • Property mode set to 100755
File size: 22.7 KB
Line 
1#!/usr/bin/env python
2"""
3Application to explore the difference between sasview 3.x orientation
4dispersity and possible replacement algorithms.
5"""
6from __future__ import division, print_function
7
8import sys, os
9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
10
11import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
12import matplotlib.pyplot as plt
13from matplotlib.widgets import Slider, CheckButtons
14from matplotlib import cm
15import numpy as np
16from numpy import pi, cos, sin, sqrt, exp, degrees, radians
17
def draw_beam(ax, view=(0, 0)):
    """
    Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1).

    The beam is drawn as a thin translucent yellow tube of radius 0.02 that
    extends from z=-1.3 to z=+1.3.  *view* is a (theta, phi) pair in degrees;
    the tube is tilted by Rz(phi)*Ry(theta).
    """
    n_steps = 25
    radius = 0.02
    half_length = 1.3
    angle = np.linspace(0, 2 * np.pi, n_steps)
    height = np.linspace(-1, 1, n_steps)

    # Cylinder surface sampled on an (angle x height) grid.
    ring = np.ones_like(height)
    x = radius*np.outer(np.cos(angle), ring)
    y = radius*np.outer(np.sin(angle), ring)
    z = half_length*np.outer(np.ones_like(angle), height)

    theta, phi = view
    grid_shape = x.shape
    xyz = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    xyz = Rz(phi)*Ry(theta)*xyz
    x, y, z = (row.reshape(grid_shape) for row in xyz)

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5)
41
def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)):
    """
    Represent jitter as a set of shapes at different orientations.

    *size* is first rescaled so that sqrt(a**2 + b**2 + c**2) == 0.95.  The
    shape is then drawn at the nominal orientation, and again for every
    combination of offsets (-1, 0, +1) times (dtheta, dphi, dpsi).  An axis
    with zero jitter only uses the 0 offset.  For *dist* 'rectangle' the
    offsets are additionally scaled by 1/sqrt(3).
    """
    # set max diagonal to 0.95
    diag_scale = 0.95/sqrt(sum(s**2 for s in size))
    size = tuple(diag_scale*s for s in size)
    draw_shape = draw_parallelepiped
    #draw_shape = draw_ellipsoid

    dtheta, dphi, dpsi = jitter
    # An axis without jitter contributes only the central (0) offset.
    theta_steps = (0,) if dtheta == 0 else (-1, 0, 1)
    phi_steps = (0,) if dphi == 0 else (-1, 0, 1)
    psi_steps = (0,) if dpsi == 0 else (-1, 0, 1)

    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)
    step = 1/sqrt(3) if dist == 'rectangle' else 1
    for i in theta_steps:
        for j in phi_steps:
            for k in psi_steps:
                delta = [step*dtheta*i, step*dphi*j, step*dpsi*k]
                draw_shape(ax, size, view, delta, alpha=0.8)

    a, b, c = size
    lim = np.sqrt(a**2 + b**2 + c**2)
    for axis in 'xyz':
        getattr(ax, 'set_' + axis + 'lim')([-lim, lim])
        getattr(ax, axis + 'axis').label.set_text(axis)
100
def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c) along x, y and z.

    The surface is sampled on a *steps* x *steps* grid of azimuthal and
    polar angles, and rotated by *jitter* and then *view* (see
    :func:`transform_xyz`).  It is drawn in white with transparency *alpha*.
    The ends of each axis are labelled a+/a-, b+/b- and c+/c-.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)
    x = a*np.outer(np.cos(azimuth), sin_polar)
    y = b*np.outer(np.sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), np.cos(polar))
    x, y, z = transform_xyz(view, jitter, x, y, z)

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha)

    axis_labels = [
        ('c+', [0, 0, c], [1, 0, 0]),
        ('c-', [0, 0, -c], [0, 0, -1]),
        ('a+', [a, 0, 0], [0, 0, 1]),
        ('a-', [-a, 0, 0], [0, 0, -1]),
        ('b+', [0, b, 0], [-1, 0, 0]),
        ('b-', [0, -b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
121
def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
    """
    Draw a box with half-widths *size* = (a, b, c) along x, y and z.

    The eight corners are rotated by *jitter* and then *view* (see
    :func:`transform_xyz`), and the twelve triangles are drawn in white with
    transparency *alpha*.  *steps* is accepted so the signature matches
    :func:`draw_ellipsoid`, but it is not used.
    """
    a, b, c = size
    # Corner k has x sign flipping fastest, then y, then z.
    corner_x = a*np.array([1, -1, 1, -1, 1, -1, 1, -1])
    corner_y = b*np.array([1, 1, -1, -1, 1, 1, -1, -1])
    corner_z = c*np.array([1, 1, 1, 1, -1, -1, -1, -1])
    # counter clockwise triangles, two per face
    # z: up/down, x: right/left, y: front/back
    faces = np.array([
        [0, 1, 2], [3, 2, 1],  # top face
        [6, 5, 4], [5, 6, 7],  # bottom face
        [0, 2, 6], [6, 4, 0],  # right face
        [1, 5, 7], [7, 3, 1],  # left face
        [2, 3, 6], [7, 6, 3],  # front face
        [4, 1, 0], [5, 1, 4],  # back face
    ])

    x, y, z = transform_xyz(view, jitter, corner_x, corner_y, corner_z)
    ax.plot_trisurf(x, y, triangles=faces, Z=z, color='w', alpha=alpha)

    axis_labels = [
        ('c+', [0, 0, c], [1, 0, 0]),
        ('c-', [0, 0, -c], [0, 0, -1]),
        ('a+', [a, 0, 0], [0, 0, 1]),
        ('a-', [-a, 0, 0], [0, 0, -1]),
        ('b+', [0, b, 0], [-1, 0, 0]),
        ('b-', [0, -b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
150
def draw_sphere(ax, radius=10., steps=100):
    """Draw a white sphere of the given *radius* centred on the origin."""
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar, cos_polar = np.sin(polar), np.cos(polar)

    x = radius * np.outer(np.cos(azimuth), sin_polar)
    y = radius * np.outer(np.sin(azimuth), sin_polar)
    z = radius * np.outer(np.ones(np.size(azimuth)), cos_polar)
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
160
161def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian'):
162    """
163    Draw the dispersion mesh showing the theta-phi orientations at which
164    the model will be evaluated.
165    """
166    theta, phi, psi = view
167    dtheta, dphi, dpsi = jitter
168
169    if dist == 'gaussian':
170        t = np.linspace(-3, 3, n)
171        weights = exp(-0.5*t**2)
172    elif dist == 'rectangle':
173        # Note: uses sasmodels ridiculous definition of rectangle width
174        t = np.linspace(-1, 1, n)*sqrt(3)
175        weights = np.ones_like(t)
176    else:
177        raise ValueError("expected dist to be 'gaussian' or 'rectangle'")
178
179    #PROJECTION = 'stretched_phi'
180    PROJECTION = 'azimuthal_equidistance'
181    #PROJECTION = 'azimuthal_equal_area'
182    if PROJECTION == 'stretced_phi':
183        def rotate(theta_i, phi_j):
184            if theta_i != 90:
185                phi_j /= cos(radians(theta_i))
186            return Rx(phi_j)*Ry(theta_i)
187        def weight(theta_i, phi_j, wi, wj):
188            if theta_i != 90:
189                phi_j /= cos(radians(theta_i))
190            return wi*wj if abs(phi_j) < 180 else 0
191    elif PROJECTION == 'azimuthal_equidistance':
192        # https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection
193        def rotate(theta_i, phi_j):
194            latitude = sqrt(theta_i**2 + phi_j**2)
195            longitude = degrees(np.arctan2(phi_j, theta_i))
196            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
197            return Rz(longitude)*Ry(latitude)
198        def weight(theta_i, phi_j, wi, wj):
199            # Weighting for each point comes from the integral:
200            #     \int\int I(q, lat, log) sin(lat) dlat dlog
201            # We are doing a conformal mapping from disk to sphere, so we need
202            # a change of variables g(theta, phi) -> (lat, long):
203            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
204            # giving:
205            #     dtheta dphi = det(J) dlat dlong
206            # where J is the jacobian from the partials of g. Using
207            #     R = sqrt(theta^2 + phi^2),
208            # then
209            #     J = [[x/R, Y/R], -y/R^2, x/R^2]]
210            # and
211            #     det(J) = 1/R
212            # with the final integral being:
213            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
214            #
215            # This does approximately the right thing, decreasing the weight
216            # of each point as you go farther out on the disk, but it hasn't
217            # yet been checked against the 1D integral results. Prior
218            # to declaring this "good enough" and checking that integrals
219            # work in practice, we will examine alternative mappings.
220            #
221            # The issue is that the mapping does not support the case of free
222            # rotation about a single axis correctly, with a small deviation
223            # in the orthogonal axis independent of the first axis.  Like the
224            # usual polar coordiates integration, the integrated sections
225            # form wedges, though at least in this case the wedge cuts through
226            # the entire sphere, and treats theta and phi identically.
227            latitude = sqrt(theta_i**2 + phi_j**2)
228            w = sin(radians(latitude))/latitude if latitude != 0 else 1
229            return w*wi*wj if latitude < 180 else 0
230    elif PROJECTION == 'azimuthal_equal_area':
231        # https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection
232        def rotate(theta_i, phi_j):
233            R = min(1, sqrt(theta_i**2 + phi_j**2)/180)
234            latitude = 180-degrees(2*np.arccos(R))
235            longitude = degrees(np.arctan2(phi_j, theta_i))
236            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
237            return Rz(longitude)*Ry(latitude)
238        def weight(theta_i, phi_j, wi, wj):
239            latitude = sqrt(theta_i**2 + phi_j**2)
240            w = sin(radians(latitude))/latitude if latitude != 0 else 1
241            return w*wi*wj if latitude < 180 else 0
242    elif SCALED_PHI == 10:  # random thrashing
243        def rotate(theta_i, phi_j):
244            theta_i, phi_j = 2*theta_i/abs(cos(radians(phi_j))), 2*phi_j/cos(radians(theta_i))
245            return Rx(phi_j)*Ry(theta_i)
246        def weight(theta_i, phi_j, wi, wj):
247            theta_i, phi_j = 2*theta_i/abs(cos(radians(phi_j))), 2*phi_j/cos(radians(theta_i))
248            return wi*wj if abs(phi_j) < 180 else 0
249    else:
250        def rotate(theta_i, phi_j):
251            return Rx(phi_j)*Ry(theta_i)
252        def weight(theta_i, phi_j, wi, wj):
253            return wi*wj*cos(radians(theta_i))
254
255    # mesh in theta, phi formed by rotating z
256    z = np.matrix([[0], [0], [radius]])
257    points = np.hstack([rotate(theta_i, phi_j)*z
258                        for theta_i in dtheta*t
259                        for phi_j in dphi*t])
260    # select just the active points (i.e., those with phi < 180
261    w = np.array([weight(theta_i, phi_j, wi, wj)
262                  for wi, theta_i in zip(weights, dtheta*t)
263                  for wj, phi_j in zip(weights, dphi*t)])
264    #print(max(w), min(w), min(w[w>0]))
265    points = points[:, w>0]
266    w = w[w>0]
267    w /= max(w)
268
269    if 0: # Kent distribution
270        points = np.hstack([Rx(phi_j)*Ry(theta_i)*z for theta_i in 30*t for phi_j in 60*t])
271        xp, yp, zp = [np.array(v).flatten() for v in points]
272        kappa = max(1e6, radians(dtheta)/(2*pi))
273        beta = 1/max(1e-6, radians(dphi)/(2*pi))/kappa
274        w = exp(kappa*zp) #+ beta*(xp**2 + yp**2)
275        print(kappa, dtheta, radians(dtheta), min(w), max(w), sum(w))
276        #w /= abs(cos(radians(
277        #w /= sum(w)
278
279    # rotate relative to beam
280    points = orient_relative_to_beam(view, points)
281
282    x, y, z = [np.array(v).flatten() for v in points]
283    #plt.figure(2); plt.clf(); plt.hist(z, bins=np.linspace(-1, 1, 51))
284    ax.scatter(x, y, z, c=w, marker='o', vmin=0., vmax=1.)
285
def draw_labels(ax, view, jitter, text):
    """
    Draw text at a particular location.

    *text* is a sequence of (label, location, orientation) triples.  Both the
    location and the orientation vector are rotated by *jitter* and then
    *view* before each label is placed.
    """
    labels, locations, orientations = zip(*text)
    pos = transform_xyz(view, jitter, *zip(*locations))
    direction = transform_xyz(view, jitter, *zip(*orientations))

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for k, label in enumerate(labels):
        x, y, z = (coord[k] for coord in pos)
        zdir = np.asarray([d[k] for d in direction]).flatten()
        ax.text(x, y, z, label, zdir=zdir)
301
# Definition of rotation matrices comes from wikipedia:
#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """
    Return the 3x3 :class:`numpy.matrix` rotating points about *x* by
    *angle* degrees; positive angles rotate +y toward +z.
    """
    c, s = cos(radians(angle)), sin(radians(angle))
    return np.matrix([[1, 0, 0],
                      [0, c, -s],
                      [0, s, c]])
311
def Ry(angle):
    """
    Return the 3x3 :class:`numpy.matrix` rotating points about *y* by
    *angle* degrees; positive angles rotate +z toward +x.
    """
    c, s = cos(radians(angle)), sin(radians(angle))
    return np.matrix([[c, 0, s],
                      [0, 1, 0],
                      [-s, 0, c]])
319
def Rz(angle):
    """
    Return the 3x3 :class:`numpy.matrix` rotating points about *z* by
    *angle* degrees; positive angles rotate +x toward +y.
    """
    c, s = cos(radians(angle)), sin(radians(angle))
    return np.matrix([[c, -s, 0],
                      [s, c, 0],
                      [0, 0, 1]])
327
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    The points are rotated by :func:`apply_jitter` and then by
    :func:`orient_relative_to_beam`.  *x*, *y* and *z* are array-likes,
    presumably with the same number of elements.  The returned arrays take
    the shape of *x*.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    out_shape = x.shape
    xyz = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    xyz = orient_relative_to_beam(view, apply_jitter(jitter, xyz))
    return tuple(np.array(row).reshape(out_shape) for row in xyz)
339
def apply_jitter(jitter, points):
    """
    Apply the jitter transform to a set of points.

    *jitter* is (dtheta, dphi, dpsi) in degrees.  The combined rotation is
    Rx(dphi)*Ry(dtheta)*Rz(dpsi), so psi is applied to the points first.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    dtheta, dphi, dpsi = jitter
    rotation = Rx(dphi)*Ry(dtheta)*Rz(dpsi)
    return rotation*points
349
def orient_relative_to_beam(view, points):
    """
    Apply the view transform to a set of points.

    *view* is (theta, phi, psi) in degrees.  The combined rotation is
    Rz(phi)*Ry(theta)*Rz(psi), so psi is applied to the points first.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    theta, phi, psi = view
    rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return rotation*points
359
# Translate between the number of dimensions of dispersity and the number of
# points along each dimension.  Keys are (theta, phi, psi) flags marking which
# orientation angles have non-zero dispersity.  The per-axis count depends
# only on how many axes are active, so the totals are 0 (none),
# 100 (one axis), 30**2 = 900 (two axes) and 15**3 = 3375 (three axes).
_PD_N_PER_AXIS = (0, 100, 30, 15)
PD_N_TABLE = {
    (i, j, k): tuple(_PD_N_PER_AXIS[i + j + k]*flag for flag in (i, j, k))
    for i in (0, 1)
    for j in (0, 1)
    for k in (0, 1)
}
372
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine the (low, high) range of *data*.

    If *portion* is 1, use the full range.  Otherwise the values are sorted
    and the result depends on *mode*:

    - 'central': run from sorted index ``portion*len/2`` to the mirrored
      index counted from the end.
    - 'top': run from sorted index ``portion*len`` to the maximum.  With
      ``portion=0.95`` this spans the highest 5% of the values.

    Indices are clamped to the data, so a portion close to 1 cannot run off
    the end of the array.  A zero central offset yields the full range.

    Raises ValueError if *mode* is not 'central' or 'top' (when *portion*
    is not 1).
    """
    if portion == 1.0:
        return np.min(data), np.max(data)
    if mode not in ('central', 'top'):
        raise ValueError("expected mode to be 'central' or 'top'")

    data = np.sort(np.asarray(data).flatten())
    n = len(data)
    if mode == 'central':
        offset = min(int(portion*n/2 + 0.5), n - 1)
        # data[-0] would be data[0]; use the true maximum instead.
        return data[offset], (data[-offset] if offset > 0 else data[-1])
    # mode == 'top'
    offset = min(int(portion*n + 0.5), n - 1)
    return data[offset], data[-1]
390
def draw_scattering(calculator, ax, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The log intensity is drawn as filled contours in the plane z=-1.1, with
    qx and qy each scaled by their maximum.  The colour range comes from
    *calculator.limits* when set; otherwise it comes from
    :func:`clipped_range`.
    """
    ## Sasmodels use sqrt(3)*width for the rectangle range.  Earlier versions
    ## of this demo rescaled the jitter by 1/sqrt(3) for comparison with the
    ## proper width; the sasmodels definition of width for rectangle is now
    ## used, so the jitter is passed through unscaled.

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = jitter
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd>0, phi_pd>0, psi_pd>0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    # empty_data2D lays the points out as meshgrid(qx, qy): rows follow qy
    # and columns follow qx, which is also the Z layout contourf expects.
    Iqxy = calculator(**pars).reshape(len(qy), len(qx))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top')
    #print("range",(vmin,vmax))
    if 0:
        # alternative: coloured surface at z=-1.1 (disabled)
        level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i')
        level[level<0] = 0
        colors = plt.get_cmap()(level)
        ax.plot_surface(qx, qy, -1.1, rstride=1, cstride=1, facecolors=colors)
    elif 1:
        ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                    levels=np.linspace(vmin, vmax, 24))
    else:
        # alternative: flat 2D colour mesh (disabled)
        ax.pcolormesh(qx, qy, Iqxy)
443
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  A *background* of 1e-3 is supplied unless one is given.

    Returns a *calculator* function which takes a dictionary or parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.  *calculator.limits* holds the log-intensity colour range
    for orientation (0, 0, 0).
    """
    from sasmodels.core import load_model_info
    # Renamed so the sasmodels builder does not shadow this function.
    from sasmodels.core import build_model as sasmodels_build_model
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = sasmodels_build_model(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # A small positive background keeps np.log(Iqxy) finite.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
481
def select_calculator(model_name, n=150, size=(10,40,100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or bcc_paracrystal. *n* is the number of points to use
    in the q range.  *qmax* is chosen based on model parameters for the
    given model to show something interesting.

    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
    equatorial axes and polar axis respectively.  See :func:`build_model`
    for details on the returned calculator.

    Raises ValueError if *model_name* is not one of the models above.
    """
    a, b, c = size
    # Model-specific parameters; the displayed size is collapsed to match
    # the symmetry of the model.
    if model_name == 'sphere':
        model_pars = dict(radius=c)
        a = b = c
    elif model_name == 'bcc_paracrystal':
        model_pars = dict(dnn=c, d_factor=0.06, radius=40)
        a = b = c
    elif model_name == 'cylinder':
        model_pars = dict(qmax=0.3, radius=b, length=c)
        a = b
    elif model_name == 'ellipsoid':
        model_pars = dict(qmax=1.0, radius_polar=c, radius_equatorial=b)
        a = b
    elif model_name == 'triaxial_ellipsoid':
        model_pars = dict(qmax=0.5, radius_equat_minor=a,
                          radius_equat_major=b, radius_polar=c)
    elif model_name == 'parallelepiped':
        model_pars = dict(a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    calculator = build_model(model_name, n=n, **model_pars)
    return calculator, (a, b, c)
521
def main(model_name='parallelepiped', size=(10, 40, 100)):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of the models available in
    :func:`select_calculator`, and *size* is the (a, b, c) shape size passed
    to it.  Sliders control the view (theta, phi, psi) and the jitter
    (dtheta, dphi, dpsi).  Each change redraws the beam, the jittered
    shapes, the dispersion mesh and the scattering pattern.
    """
    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)

    ## Use an independent colour range for every view.  Comment out this
    ## line to keep the colour range computed by build_model for the
    ## (0, 0, 0) orientation fixed across all views.
    calculator.limits = None

    ## use gaussian distribution unless testing integration
    dist = 'rectangle'
    #dist = 'gaussian'

    ## initial view
    #theta, dtheta = 70., 10.
    #phi, dphi = -45., 3.
    #psi, dpsi = -45., 3.
    theta, phi, psi = 0, 0, 0
    dtheta, dphi, dpsi = 0, 0, 0

    ## create the plot window
    plt.set_cmap('gist_earth')
    plt.clf()
    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        ax.axis('square')
    except Exception:
        pass

    axcolor = 'lightgoldenrodyellow'

    def control_axes(rect):
        """Create axes at *rect* for a slider, using the control colour."""
        try:
            return plt.axes(rect, facecolor=axcolor)
        except (AttributeError, TypeError):
            # CRUFT: old matplotlib only knows 'axisbg', which newer
            # versions deprecated (2.0) and removed in favour of 'facecolor'.
            return plt.axes(rect, axisbg=axcolor)

    ## add control widgets to plot
    axtheta = control_axes([0.1, 0.15, 0.45, 0.04])
    axphi = control_axes([0.1, 0.1, 0.45, 0.04])
    axpsi = control_axes([0.1, 0.05, 0.45, 0.04])
    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)

    axdtheta = control_axes([0.75, 0.15, 0.15, 0.04])
    axdphi = control_axes([0.75, 0.1, 0.15, 0.04])
    axdpsi = control_axes([0.75, 0.05, 0.15, 0.04])
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = 30 if dist == 'gaussian' else 90/sqrt(3)
    sdtheta = Slider(axdtheta, 'dTheta', 0, 2*dlimit, valinit=dtheta)
    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi)
    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        """Redraw the scene from the current slider values (args unused)."""
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0, 2, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        ax.cla()
        draw_beam(ax, (0, 0))
        draw_jitter(ax, view, jitter, dist=dist, size=size)
        #draw_jitter(ax, view, (0,0,0))
        draw_mesh(ax, view, jitter, dist=dist, n=30)
        draw_scattering(calculator, ax, view, jitter, dist=dist)
        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
608
if __name__ == "__main__":
    # usage: jitter.py [model_name [a,b,c]]
    args = sys.argv[1:]
    model_name = args[0] if args else 'parallelepiped'
    if len(args) > 1:
        size = tuple(int(v) for v in args[1].split(','))
    else:
        size = (10, 40, 100)
    main(model_name, size)
Note: See TracBrowser for help on using the repository browser.