source: sasmodels/explore/jitter.py @ bcb5594

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since bcb5594 was bcb5594, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

allow selection of jitter mesh behaviour from command line

  • Property mode set to 100755
File size: 26.2 KB
Line 
1#!/usr/bin/env python
2"""
3Application to explore the difference between sasview 3.x orientation
4dispersity and possible replacement algorithms.
5"""
6from __future__ import division, print_function
7
8import sys, os
9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
10
11import argparse
12
13import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
14import matplotlib.pyplot as plt
15from matplotlib.widgets import Slider, CheckButtons
16from matplotlib import cm
17import numpy as np
18from numpy import pi, cos, sin, sqrt, exp, degrees, radians
19
def draw_beam(ax, view=(0, 0)):
    """
    Draw the beam as a thin translucent tube on the 3D axes *ax*.

    The tube is constructed about the z axis (source on the +z side, detector
    on the -z side) and then rotated by *view* = (theta, phi), in degrees,
    using ``Rz(phi)*Ry(theta)``.
    """
    n_steps = 25
    tube_radius = 0.02
    angle = np.linspace(0, 2 * np.pi, n_steps)
    height = np.linspace(-1, 1, n_steps)
    along = np.ones_like(height)

    # Cylinder surface: a small circle in x-y swept along z.
    x = tube_radius*np.outer(np.cos(angle), along)
    y = tube_radius*np.outer(np.sin(angle), along)
    z = 1.3*np.outer(np.ones_like(angle), height)

    # Rotate the flattened surface points, then restore the grid shape
    # expected by plot_surface.
    theta, phi = view
    grid_shape = x.shape
    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    points = Rz(phi)*Ry(theta)*points
    x, y, z = [row.reshape(grid_shape) for row in points]

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5)
43
def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)):
    """
    Represent jitter as a set of shapes at different orientations.

    The shape with dimensions *size* is drawn on *ax* at the central *view*
    and again at every combination of -1, 0, +1 steps in each active jitter
    dimension.  *jitter* is (dtheta, dphi, dpsi); a dimension whose value is
    zero contributes only its central step.  *dist* selects the factor
    applied to the jitter widths when placing the displaced shapes.
    """
    # Rescale the shape so that sqrt(a^2 + b^2 + c^2) is 0.95.
    norm = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(norm*v for v in size)
    draw_shape = draw_parallelepiped
    #draw_shape = draw_ellipsoid

    # Steps along each jitter dimension; inactive dimensions stay centred.
    dtheta, dphi, dpsi = jitter
    theta_steps = (-1, 0, 1) if dtheta != 0 else (0,)
    phi_steps = (-1, 0, 1) if dphi != 0 else (0,)
    psi_steps = (-1, 0, 1) if dpsi != 0 else (0,)
    cloud = [[i, j, k]
             for i in theta_steps
             for j in phi_steps
             for k in psi_steps]

    # Central orientation first, then the displaced copies.
    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)
    scale = {'gaussian':1, 'rectangle':1/sqrt(3), 'uniform':1/3}[dist]
    for point in cloud:
        delta = [scale*width*step for width, step in zip(jitter, point)]
        draw_shape(ax, size, view, delta, alpha=0.8)

    # Use the same symmetric limits and a plain axis name on every axis.
    a, b, c = size
    lim = np.sqrt(a**2+b**2+c**2)
    for name in 'xyz':
        getattr(ax, 'set_'+name+'lim')([-lim, lim])
        getattr(ax, name+'axis').label.set_text(name)
102
def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c) on the 3D axes *ax*.

    The surface is sampled on a *steps* x *steps* grid, sent through the
    *jitter* and *view* transforms, and drawn with transparency *alpha*.
    Labels for the ends of each axis are added with :func:`draw_labels`.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)

    x = a*np.outer(np.cos(azimuth), sin_polar)
    y = b*np.outer(np.sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), np.cos(polar))
    surface = transform_xyz(view, jitter, x, y, z)

    ax.plot_surface(*surface, rstride=4, cstride=4, color='w', alpha=alpha)

    # (label, location, text direction) for the two ends of each axis
    axis_labels = [
        ('c+', [ 0, 0, c], [ 1, 0, 0]),
        ('c-', [ 0, 0,-c], [ 0, 0,-1]),
        ('a+', [ a, 0, 0], [ 0, 0, 1]),
        ('a-', [-a, 0, 0], [ 0, 0,-1]),
        ('b+', [ 0, b, 0], [-1, 0, 0]),
        ('b-', [ 0,-b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
123
def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
    """
    Draw a parallelepiped with half-lengths *size* = (a, b, c) on *ax*.

    The eight corners are sent through the *jitter* and *view* transforms and
    the faces are drawn as triangles with transparency *alpha*.  *steps* is
    accepted so the signature matches :func:`draw_ellipsoid`, but is unused.
    """
    a, b, c = size

    # Corner signs, one row per coordinate; x alternates fastest, z slowest.
    signs = np.array([
        [ 1,-1, 1,-1, 1,-1, 1,-1],
        [ 1, 1,-1,-1, 1, 1,-1,-1],
        [ 1, 1, 1, 1,-1,-1,-1,-1],
    ])
    x, y, z = a*signs[0], b*signs[1], c*signs[2]

    # Two triangles per face, indexing into the corner list above.
    # z: up/down, x: right/left, y: front/back
    top = [[0,1,2], [3,2,1]]
    bottom = [[6,5,4], [5,6,7]]
    right = [[0,2,6], [6,4,0]]
    left = [[1,5,7], [7,3,1]]
    front = [[2,3,6], [7,6,3]]
    back = [[4,1,0], [5,1,4]]
    tri = np.array(top + bottom + right + left + front + back)

    x, y, z = transform_xyz(view, jitter, x, y, z)
    ax.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha)

    # (label, location, text direction) for the centre of each face
    face_labels = [
        ('c+', [ 0, 0, c], [ 1, 0, 0]),
        ('c-', [ 0, 0,-c], [ 0, 0,-1]),
        ('a+', [ a, 0, 0], [ 0, 0, 1]),
        ('a-', [-a, 0, 0], [ 0, 0,-1]),
        ('b+', [ 0, b, 0], [-1, 0, 0]),
        ('b-', [ 0,-b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, face_labels)
152
def draw_sphere(ax, radius=10., steps=100):
    """
    Draw a white sphere of the given *radius* centred on the origin of *ax*.

    The surface is sampled on a *steps* x *steps* grid of azimuthal and
    polar angles.
    """
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar, cos_polar = np.sin(polar), np.cos(polar)

    x = radius * np.outer(np.cos(azimuth), sin_polar)
    y = radius * np.outer(np.sin(azimuth), sin_polar)
    z = radius * np.outer(np.ones(np.size(azimuth)), cos_polar)
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
162
# Jitter mesh projections offered on the command line; the first entry is
# the default.
PROJECTIONS = [
    'equirectangular',
    'azimuthal_equidistance',
    'sinusoidal',
]
def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian',
              projection='equirectangular'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    *ax* are the 3D axes on which the mesh points are scattered.  *view* is
    (theta, phi, psi) and *jitter* is (dtheta, dphi, dpsi), all in degrees;
    only dtheta and dphi are used to build the mesh.  *radius* is the distance
    of the mesh points from the origin.  *n* is the number of points along
    each of theta and phi.  *dist* is one of 'gaussian', 'rectangle' or
    'uniform' and sets the extent of the mesh and the per-point weights.
    *projection* selects the mapping from the theta-phi mesh to the sphere.

    Raises *ValueError* if *dist* or *projection* is not recognized.

    jitter projections
    <https://en.wikipedia.org/wiki/List_of_map_projections>

    equirectangular (standard latitude-longitude mesh)
        <https://en.wikipedia.org/wiki/Equirectangular_projection>
        Allows free movement in phi (around the equator), but theta is
        limited to +/- 90, and points are cos-weighted. Jitter in phi is
        uniform in weight along a line of latitude.  With small theta and
        phi ranging over +/- 180 this forms a wobbling disk.  With small
        phi and theta ranging over +/- 90 this forms a wedge like a slice
        of an orange.
    azimuthal_equidistance (Postel)
        <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>
        Preserves distance from center, and so is an excellent map for
        representing a bivariate gaussian on the surface.  Theta and phi
        operate identically, cutting wedges from the antipode of the viewing
        angle.  This unfortunately does not allow free movement in either
        theta or phi since the orthogonal wobble decreases to 0 as the body
        rotates through 180 degrees.
    sinusoidal (Sanson-Flamsteed, Mercator equal-area)
        <https://en.wikipedia.org/wiki/Sinusoidal_projection>
        Preserves arc length with latitude, giving bad behaviour at
        theta near +/- 90.  Theta and phi operate somewhat differently,
        so a system with a-b-c dtheta-dphi-dpsi will not give the same
        value as one with b-a-c dphi-dtheta-dpsi, as would be the case
        for azimuthal equidistance.  Free movement using theta or phi
        uniform over +/- 180 will work, but not as well as equirectangular
        phi, with theta being slightly worse.  Computationally it is much
        cheaper for wide theta-phi meshes since it excludes points which
        lie outside the sinusoid near theta +/- 90 rather than packing
        them close together as in equirectangular.
    Guyou (hemisphere-in-a-square)  **not implemented**
        <https://en.wikipedia.org/wiki/Guyou_hemisphere-in-a-square_projection>
        Promising.  With tiling should allow rotation in phi or theta
        through +/- 180, preserving almost disk-like behaviour in either
        direction (phi rotation will not be as uniform as it is in
        equirectangular; not sure about theta).  Unfortunately, distortion
        is not restricted to the corners of the theta-phi mesh, so this will
        not be as good as the azimuthal equidistance projection for gaussian
        distributions.
    azimuthal_equal_area  **incomplete**
        <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>
        Preserves the relative density of the surface patches.  Not that
        useful and not completely implemented.
    Gauss-Kruger **not implemented**
        <https://en.wikipedia.org/wiki/Transverse_Mercator_projection#Ellipsoidal_transverse_Mercator>
        Should allow free movement in theta, but phi is distorted.
    """
    theta, phi, psi = view
    dtheta, dphi, dpsi = jitter

    # Unit mesh in [-1, 1]; it is stretched according to the distribution
    # and later multiplied by the jitter widths dtheta and dphi.
    t = np.linspace(-1, 1, n)
    weights = np.ones_like(t)
    if dist == 'gaussian':
        # sample out to +/- 3 widths with gaussian weights
        t *= 3
        weights = exp(-0.5*t**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        t *= sqrt(3)
    elif dist == 'uniform':
        pass
    else:
        raise ValueError("expected dist to be gaussian, rectangle or uniform")

    # Each projection defines rotate(theta_i, phi_j), the rotation taking the
    # pole to the mesh point, and weight(theta_i, phi_j, wi, wj), the weight
    # of that point given the 1D distribution weights wi and wj.
    if projection == 'equirectangular':
        def rotate(theta_i, phi_j):
            return Rx(phi_j)*Ry(theta_i)
        def weight(theta_i, phi_j, wi, wj):
            return wi*wj*cos(radians(theta_i))
    elif projection == 'azimuthal_equidistance':
        def rotate(theta_i, phi_j):
            latitude = sqrt(theta_i**2 + phi_j**2)
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return Rz(longitude)*Ry(latitude)
        def weight(theta_i, phi_j, wi, wj):
            # Weighting for each point comes from the integral:
            #     \int\int I(q, lat, long) sin(lat) dlat dlong
            # We are doing a conformal mapping from disk to sphere, so we need
            # a change of variables g(theta, phi) -> (lat, long):
            #     lat, long = sqrt(theta^2 + phi^2), arctan(phi/theta)
            # giving:
            #     dtheta dphi = det(J) dlat dlong
            # where J is the jacobian from the partials of g. Using
            #     R = sqrt(theta^2 + phi^2),
            # and writing x, y for theta, phi, then
            #     J = [[x/R, y/R], [-y/R^2, x/R^2]]
            # and
            #     det(J) = 1/R
            # with the final integral being:
            #    \int\int I(q, theta, phi) sin(R)/R dtheta dphi
            #
            # This does approximately the right thing, decreasing the weight
            # of each point as you go farther out on the disk, but it hasn't
            # yet been checked against the 1D integral results. Prior
            # to declaring this "good enough" and checking that integrals
            # work in practice, we will examine alternative mappings.
            #
            # The issue is that the mapping does not support the case of free
            # rotation about a single axis correctly, with a small deviation
            # in the orthogonal axis independent of the first axis.  Like the
            # usual polar coordinates integration, the integrated sections
            # form wedges, though at least in this case the wedge cuts through
            # the entire sphere, and treats theta and phi identically.
            latitude = sqrt(theta_i**2 + phi_j**2)
            w = sin(radians(latitude))/latitude if latitude != 0 else 1
            return w*wi*wj if latitude < 180 else 0
    elif projection == 'sinusoidal':
        def rotate(theta_i, phi_j):
            latitude = theta_i
            scale = cos(radians(latitude))
            # points outside the sinusoid are sent to longitude 0; they get
            # weight 0 below and are dropped before plotting
            longitude = phi_j/scale if abs(phi_j) < abs(scale)*180 else 0
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return Rx(longitude)*Ry(latitude)
        def weight(theta_i, phi_j, wi, wj):
            latitude = theta_i
            scale = cos(radians(latitude))
            w = 1 if abs(phi_j) < abs(scale)*180 else 0
            return w*wi*wj
    elif projection == 'azimuthal_equal_area':
        def rotate(theta_i, phi_j):
            R = min(1, sqrt(theta_i**2 + phi_j**2)/180)
            latitude = 180-degrees(2*np.arccos(R))
            longitude = degrees(np.arctan2(phi_j, theta_i))
            #print("(%+7.2f, %+7.2f) => (%+7.2f, %+7.2f)"%(theta_i, phi_j, latitude, longitude))
            return Rz(longitude)*Ry(latitude)
        def weight(theta_i, phi_j, wi, wj):
            # same weighting as azimuthal_equidistance
            latitude = sqrt(theta_i**2 + phi_j**2)
            w = sin(radians(latitude))/latitude if latitude != 0 else 1
            return w*wi*wj if latitude < 180 else 0
    else:
        raise ValueError("unknown projection %r"%projection)

    # mesh in theta, phi formed by rotating z
    z = np.matrix([[0], [0], [radius]])
    points = np.hstack([rotate(theta_i, phi_j)*z
                        for theta_i in dtheta*t
                        for phi_j in dphi*t])
    # weight for each mesh point, computed in the same order as the points
    w = np.array([weight(theta_i, phi_j, wi, wj)
                  for wi, theta_i in zip(weights, dtheta*t)
                  for wj, phi_j in zip(weights, dphi*t)])
    #print(max(w), min(w), min(w[w>0]))
    # keep just the active points (i.e., those with positive weight) and
    # normalize so the largest weight is 1
    points = points[:, w>0]
    w = w[w>0]
    w /= max(w)

    # Disabled experiment; the 'if 0' guard means this block never runs.
    if 0: # Kent distribution
        points = np.hstack([Rx(phi_j)*Ry(theta_i)*z for theta_i in 30*t for phi_j in 60*t])
        xp, yp, zp = [np.array(v).flatten() for v in points]
        kappa = max(1e6, radians(dtheta)/(2*pi))
        beta = 1/max(1e-6, radians(dphi)/(2*pi))/kappa
        w = exp(kappa*zp) #+ beta*(xp**2 + yp**2)
        print(kappa, dtheta, radians(dtheta), min(w), max(w), sum(w))
        #w /= abs(cos(radians(
        #w /= sum(w)

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    # colour each point by its normalized weight
    x, y, z = [np.array(v).flatten() for v in points]
    #plt.figure(2); plt.clf(); plt.hist(z, bins=np.linspace(-1, 1, 51))
    ax.scatter(x, y, z, c=w, marker='o', vmin=0., vmax=1.)
335
def draw_labels(ax, view, jitter, text):
    """
    Draw text labels at locations that follow the jitter and view transforms.

    *text* is a non-empty sequence of (label, location, orientation) triples,
    where location and orientation are (x, y, z) vectors.  Both are sent
    through :func:`transform_xyz` before the label is placed on *ax*.
    """
    labels, locations, orientations = zip(*text)
    loc_x, loc_y, loc_z = zip(*locations)
    dir_x, dir_y, dir_z = zip(*orientations)

    loc_x, loc_y, loc_z = transform_xyz(view, jitter, loc_x, loc_y, loc_z)
    dir_x, dir_y, dir_z = transform_xyz(view, jitter, dir_x, dir_y, dir_z)

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for k, label in enumerate(labels):
        direction = np.asarray((dir_x[k], dir_y[k], dir_z[k])).flatten()
        ax.text(loc_x[k], loc_y[k], loc_z[k], label, zdir=direction)
351
352# Definition of rotation matrices comes from wikipedia:
353#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """
    Return the 3x3 matrix which rotates points about *x* by *angle* degrees.

    The result is a numpy matrix, so ``*`` composes rotations.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
361
def Ry(angle):
    """
    Return the 3x3 matrix which rotates points about *y* by *angle* degrees.

    The result is a numpy matrix, so ``*`` composes rotations.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
369
def Rz(angle):
    """
    Return the 3x3 matrix which rotates points about *z* by *angle* degrees.

    The result is a numpy matrix, so ``*`` composes rotations.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
377
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    *x*, *y* and *z* are array-like.  The returned x, y, z arrays have the
    same shape as the input *x*.
    """
    x, y, z = np.asarray(x), np.asarray(y), np.asarray(z)
    original_shape = x.shape

    # The rotations operate on a 3 x n matrix of column vectors.
    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    points = orient_relative_to_beam(view, apply_jitter(jitter, points))

    rows = np.array(points)
    x = rows[0].reshape(original_shape)
    y = rows[1].reshape(original_shape)
    z = rows[2].reshape(original_shape)
    return x, y, z
389
def apply_jitter(jitter, points):
    """
    Rotate *points* by the jitter angles (dtheta, dphi, dpsi), in degrees.

    The rotation about z by dpsi is applied first, then about y by dtheta,
    then about x by dphi.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    dtheta, dphi, dpsi = jitter
    rotation = Rx(dphi)*Ry(dtheta)*Rz(dpsi)
    return rotation*points
399
def orient_relative_to_beam(view, points):
    """
    Rotate *points* by the view angles (theta, phi, psi), in degrees.

    The rotation about z by psi is applied first, then about y by theta,
    then about z by phi.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    theta, phi, psi = view
    rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return rotation*points
409
# Translate between the dimensions which have dispersity and the number of
# points to use along each dimension.  Keys are (theta, phi, psi) flags saying
# which dimensions are active; values give the point count per dimension.
# The count depends only on how many dimensions are active, which gives
# 0, 100, 900 and 3375 evaluations for zero to three active dimensions.
_PD_POINTS_PER_DIM = {0: 0, 1: 100, 2: 30, 3: 15}
PD_N_TABLE = {
    (i, j, k): tuple(_PD_POINTS_PER_DIM[i + j + k]*flag for flag in (i, j, k))
    for i in (0, 1)
    for j in (0, 1)
    for k in (0, 1)
}
422
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine range from data.

    If *portion* is 1 or more, use full range, otherwise use the center of
    the range or the top of the range, depending on whether *mode* is
    'central' or 'top'.  Any other *mode* gives the full range.

    *data* is array-like and must not be empty.  Returns the pair
    (lower, upper).
    """
    data = np.asarray(data)
    if portion < 1.0:
        if mode == 'central':
            ordered = np.sort(data.flatten())
            offset = int(portion*len(ordered)/2 + 0.5)
            # ordered[-0] is the first element, not the last, so a zero
            # offset must select the final element explicitly.
            upper = ordered[-offset] if offset > 0 else ordered[-1]
            return ordered[offset], upper
        if mode == 'top':
            ordered = np.sort(data.flatten())
            # Rounding can push the offset to len(ordered) for small data
            # sets; clamp it so the lower bound is at most the last element.
            offset = min(int(portion*len(ordered) + 0.5), len(ordered) - 1)
            return ordered[offset], ordered[-1]
    # Default: full range
    return data.min(), data.max()
440
def draw_scattering(calculator, ax, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The log of the computed pattern is drawn as filled contours on the plane
    z = -1.1, with both q axes normalized by their maximum value.
    """
    # uniform is not yet in this branch, so use a rescaled rectangle instead
    scale = 1
    if dist == 'uniform':
        dist, scale = 'rectangle', 1/sqrt(3)

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    pd_n = PD_N_TABLE[(theta_pd>0, phi_pd>0, psi_pd>0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #pd_n = [5*v for v in pd_n]

    orientation = (
        ('theta', theta, theta_pd),
        ('phi', phi, phi_pd),
        ('psi', psi, psi_pd),
    )
    pars = {}
    for (name, value, width), count in zip(orientation, pd_n):
        pars[name] = value
        pars[name+'_pd'] = width
        pars[name+'_pd_type'] = dist
        pars[name+'_pd_n'] = count
    # values stored on the calculator take precedence
    pars.update(calculator.pars)

    # compute the pattern on a log scale
    data = calculator._data
    qx, qy = data.x_bins, data.y_bins
    Iqxy = np.log(calculator(**pars).reshape(len(qx), len(qy)))

    # use limits from orientation (0,0,0) if available, otherwise take
    # them from the current pattern
    if calculator.limits:
        vmin, vmax = calculator.limits
    else:
        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top')
    #print("range",(vmin,vmax))

    # draw it below the shape
    ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                levels=np.linspace(vmin, vmax, 24))
492
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  If no *background* is given it defaults to 1e-3.

    Returns a *calculator* function which takes a dictionary of parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.  The log-scale colour limits computed for orientation
    (0, 0, 0) are stored as *calculator.limits*.
    """
    # Import locally so the module loads without sasmodels; the core
    # build_model is aliased so it does not shadow this function's name.
    from sasmodels.core import load_model_info
    from sasmodels.core import build_model as build_kernel_model
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_kernel_model(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
530
def select_calculator(model_name, n=150, size=(10,40,100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or bcc_paracrystal. *n* is the number of points to use
    in the q range.  For cylinder, ellipsoid and triaxial_ellipsoid an
    explicit *qmax* is passed to :func:`build_model`; the other models use
    its default.

    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
    equatorial axes and polar axis respectively.  Axes that the model does
    not distinguish are set equal in the returned size.  See
    :func:`build_model` for details on the returned calculator.

    Raises *ValueError* if *model_name* is not one of the supported shapes.
    """
    a, b, c = size

    # For each shape: the effective (a, b, c) used for drawing and the
    # keyword arguments handed to build_model.
    if model_name == 'sphere':
        dims = (c, c, c)
        pars = dict(radius=c)
    elif model_name == 'bcc_paracrystal':
        dims = (c, c, c)
        pars = dict(dnn=c, d_factor=0.06, radius=40)
    elif model_name == 'cylinder':
        dims = (b, b, c)
        pars = dict(qmax=0.3, radius=b, length=c)
    elif model_name == 'ellipsoid':
        dims = (b, b, c)
        pars = dict(qmax=1.0, radius_polar=c, radius_equatorial=b)
    elif model_name == 'triaxial_ellipsoid':
        dims = (a, b, c)
        pars = dict(qmax=0.5, radius_equat_minor=a, radius_equat_major=b,
                    radius_polar=c)
    elif model_name == 'parallelepiped':
        dims = (a, b, c)
        pars = dict(a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    calculator = build_model(model_name, n=n, **pars)
    return calculator, dims
570
# Shapes offered on the command line; the first entry is the default.
SHAPES = [
    'parallelepiped',
    'triaxial_ellipsoid',
    'bcc_paracrystal',
    'cylinder',
    'ellipsoid',
    'sphere',
]

# Jitter distributions offered on the command line; the first entry is the
# default.
DISTRIBUTIONS = [
    'gaussian',
    'rectangle',
    'uniform',
]

# Width limit for each distribution; the jitter sliders run from zero to
# twice this value.
DIST_LIMITS = {
    'gaussian': 30,
    'rectangle': 90/sqrt(3),
    'uniform': 90,
}
585
def run(model_name='parallelepiped', size=(10, 40, 100),
        dist='gaussian', mesh=30,
        projection='equirectangular'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of the models available in :func:`select_calculator`.
    *size* gives the (a, b, c) dimensions of the shape.  *dist* is the jitter
    distribution, one of the keys of DIST_LIMITS.  *mesh* is the number of
    points along each dimension of the theta-phi mesh.  *projection* is the
    jitter projection passed to :func:`draw_mesh`.

    Builds the plot window with sliders for the view and jitter angles, then
    calls ``plt.show()``.
    """
    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)

    ## Use an independent colour range for every view.  Comment out the
    ## following line to keep the colour range fixed for all views.
    calculator.limits = None

    ## initial view
    #theta, dtheta = 70., 10.
    #phi, dphi = -45., 3.
    #psi, dpsi = -45., 3.
    theta, phi, psi = 0, 0, 0
    dtheta, dphi, dpsi = 0, 0, 0

    ## create the plot window
    #plt.hold(True)
    plt.set_cmap('gist_earth')
    plt.clf()
    #gs = gridspec.GridSpec(2,1,height_ratios=[4,1])
    #ax = plt.subplot(gs[0], projection='3d')
    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        ax.axis('square')
    except Exception:
        pass

    axcolor = 'lightgoldenrodyellow'

    def control_axes(rect):
        """Create the axes holding one slider, with a coloured background."""
        # CRUFT: the background colour keyword was renamed from 'axisbg' to
        # 'facecolor'; newer matplotlib rejects the old name and older
        # matplotlib rejects the new one.
        try:
            return plt.axes(rect, facecolor=axcolor)
        except (AttributeError, TypeError):
            return plt.axes(rect, axisbg=axcolor)

    ## add control widgets to plot
    axtheta = control_axes([0.1, 0.15, 0.45, 0.04])
    axphi = control_axes([0.1, 0.1, 0.45, 0.04])
    axpsi = control_axes([0.1, 0.05, 0.45, 0.04])
    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)

    axdtheta = control_axes([0.75, 0.15, 0.15, 0.04])
    axdphi = control_axes([0.75, 0.1, 0.15, 0.04])
    axdpsi = control_axes([0.75, 0.05, 0.15, 0.04])
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width, so its limit in
    # DIST_LIMITS is divided by sqrt(3).  The sliders run to twice the limit.
    dlimit = DIST_LIMITS[dist]
    sdtheta = Slider(axdtheta, 'dTheta', 0, 2*dlimit, valinit=dtheta)
    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi)
    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        """Redraw the scene from the current slider values."""
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0, 0.5, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        ax.cla()
        draw_beam(ax, (0, 0))
        draw_jitter(ax, view, jitter, dist=dist, size=size)
        #draw_jitter(ax, view, (0,0,0))
        draw_mesh(ax, view, jitter, dist=dist, n=mesh, projection=projection)
        draw_scattering(calculator, ax, view, jitter, dist=dist)
        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
670
def main():
    """
    Parse the command line and start the interactive jitter demo.

    The size option is a comma-separated list of three lengths a,b,c.  A
    malformed size is reported as a command line usage error.
    """
    parser = argparse.ArgumentParser(
        description="Display jitter",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--projection', choices=PROJECTIONS, default=PROJECTIONS[0], help='coordinate projection')
    parser.add_argument('-s', '--size', type=str, default='10,40,100', help='a,b,c lengths')
    parser.add_argument('-d', '--distribution', choices=DISTRIBUTIONS, default=DISTRIBUTIONS[0], help='jitter distribution')
    parser.add_argument('-m', '--mesh', type=int, default=30, help='#points in theta-phi mesh')
    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0], help='oriented shape')
    opts = parser.parse_args()

    # Accept fractional lengths, and reject bad input here rather than
    # letting it fail later when the size is unpacked.
    try:
        size = tuple(float(v) for v in opts.size.split(','))
    except ValueError:
        parser.error("size must be three comma-separated numbers a,b,c; got %r"
                     % opts.size)
    if len(size) != 3:
        parser.error("size must be three comma-separated numbers a,b,c; got %r"
                     % opts.size)

    run(opts.shape, size=size,
        mesh=opts.mesh, dist=opts.distribution,
        projection=opts.projection)

if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.