source: sasmodels/explore/jitter.py @ 36b3154

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 36b3154 was 36b3154, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

hack around axis('square') issue for matplotlib in explore/jitter.py

  • Property mode set to 100755
File size: 18.0 KB
Line 
1#!/usr/bin/env python
2"""
3Application to explore the difference between sasview 3.x orientation
4dispersity and possible replacement algorithms.
5"""
6from __future__ import division, print_function
7
8import sys, os
9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
10
11import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
12import matplotlib.pyplot as plt
13from matplotlib.widgets import Slider, CheckButtons
14from matplotlib import cm
15import numpy as np
16from numpy import pi, cos, sin, sqrt, exp, degrees, radians
17
def draw_beam(ax, view=(0, 0)):
    """
    Draw the beam as a thin translucent cylinder along the z axis.

    *ax* is the 3D axes to draw on.  *view* is the (theta, phi) rotation,
    in degrees, applied to the cylinder before plotting.  The cylinder has
    radius 0.02 and extends from z=-1.3 to z=1.3 before rotation.
    """
    n_steps = 25
    beam_radius = 0.02

    # Parametrize the cylinder wall by angle around the axis and by height.
    angle = np.linspace(0, 2 * np.pi, n_steps)
    height = np.linspace(-1, 1, n_steps)
    flat = np.ones_like(height)
    x = beam_radius*np.outer(np.cos(angle), flat)
    y = beam_radius*np.outer(np.sin(angle), flat)
    z = 1.3*np.outer(np.ones_like(angle), height)

    # Rotate all mesh points at once as a 3 x n matrix, then restore the
    # mesh shape expected by plot_surface.
    theta, phi = view
    grid_shape = x.shape
    points = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    points = Rz(phi)*Ry(theta)*points
    x, y, z = [row.reshape(grid_shape) for row in points]

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5)
41
def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)):
    """
    Represent jitter as a set of shapes at different orientations.

    *ax* is the 3D axes to draw on.  *view* is the orientation and *jitter*
    is the (dtheta, dphi, dpsi) orientation dispersity.  *dist* selects the
    distribution; for 'rectangle' the jitter offsets are scaled by
    1/sqrt(3).  *size* gives the (a, b, c) shape dimensions, which are
    rescaled so that sqrt(a^2 + b^2 + c^2) is 0.95.

    One shape is drawn at zero jitter, then one shape for every combination
    of -1, 0, +1 times the jitter in each angle that has non-zero jitter.
    """
    # set max diagonal to 0.95
    shrink = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(shrink*v for v in size)
    draw_shape = draw_parallelepiped
    #draw_shape = draw_ellipsoid

    # An angle without jitter contributes only the zero offset, so the cloud
    # collapses along that dimension.
    dtheta, dphi, dpsi = jitter
    theta_steps = (0,) if dtheta == 0 else (-1, 0, 1)
    phi_steps = (0,) if dphi == 0 else (-1, 0, 1)
    psi_steps = (0,) if dpsi == 0 else (-1, 0, 1)
    cloud = [[i, j, k]
             for i in theta_steps
             for j in phi_steps
             for k in psi_steps]

    # Reference shape at zero jitter.
    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)

    width = 1/sqrt(3) if dist == 'rectangle' else 1
    for i, j, k in cloud:
        delta = [width*dtheta*i, width*dphi*j, width*dpsi*k]
        draw_shape(ax, size, view, delta, alpha=0.8)

    # Use the same limits on every axis, large enough to hold the shape
    # diagonal, and label each axis with its name.
    a, b, c = size
    lim = np.sqrt(a**2+b**2+c**2)
    for name in 'xyz':
        getattr(ax, 'set_'+name+'lim')([-lim, lim])
        getattr(ax, name+'axis').label.set_text(name)
100
def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c) on the 3D axes *ax*.

    The surface is sampled on a *steps* x *steps* mesh, sent through the
    *jitter* and *view* transforms, and plotted in white with transparency
    *alpha*.  Labels are requested at the ends of each axis.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)

    x = a*np.outer(np.cos(azimuth), sin_polar)
    y = b*np.outer(np.sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), np.cos(polar))
    x, y, z = transform_xyz(view, jitter, x, y, z)

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha)

    # Each entry is (text, location, direction); the direction is forwarded
    # to draw_labels as the text orientation.
    axis_labels = [
        ('c+', [ 0, 0, c], [ 1, 0, 0]),
        ('c-', [ 0, 0,-c], [ 0, 0,-1]),
        ('a+', [ a, 0, 0], [ 0, 0, 1]),
        ('a-', [-a, 0, 0], [ 0, 0,-1]),
        ('b+', [ 0, b, 0], [-1, 0, 0]),
        ('b-', [ 0,-b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
121
def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
    """
    Draw a box with half-widths *size* = (a, b, c) on the 3D axes *ax*.

    The eight corners are sent through the *jitter* and *view* transforms
    and the faces are plotted as twelve white triangles with transparency
    *alpha*.  *steps* is accepted so the signature matches
    :func:`draw_ellipsoid`, but it is not used.
    """
    a, b, c = size

    # Corner k has signs (sx[k], sy[k], sz[k]); corners 0-3 are on the +z
    # face and corners 4-7 are on the -z face.
    sx = np.array([ 1,-1, 1,-1, 1,-1, 1,-1])
    sy = np.array([ 1, 1,-1,-1, 1, 1,-1,-1])
    sz = np.array([ 1, 1, 1, 1,-1,-1,-1,-1])
    corners_x, corners_y, corners_z = a*sx, b*sy, c*sz

    faces = np.array([
        # counter clockwise triangles
        # z: up/down, x: right/left, y: front/back
        [0,1,2], [3,2,1], # top face
        [6,5,4], [5,6,7], # bottom face
        [0,2,6], [6,4,0], # right face
        [1,5,7], [7,3,1], # left face
        [2,3,6], [7,6,3], # front face
        [4,1,0], [5,1,4], # back face
    ])

    px, py, pz = transform_xyz(view, jitter, corners_x, corners_y, corners_z)
    ax.plot_trisurf(px, py, triangles=faces, Z=pz, color='w', alpha=alpha)

    # Each entry is (text, location, direction); the direction is forwarded
    # to draw_labels as the text orientation.
    face_labels = [
        ('c+', [ 0, 0, c], [ 1, 0, 0]),
        ('c-', [ 0, 0,-c], [ 0, 0,-1]),
        ('a+', [ a, 0, 0], [ 0, 0, 1]),
        ('a-', [-a, 0, 0], [ 0, 0,-1]),
        ('b+', [ 0, b, 0], [-1, 0, 0]),
        ('b-', [ 0,-b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, face_labels)
150
def draw_sphere(ax, radius=10., steps=100):
    """
    Draw a white sphere of the given *radius* centered on the origin.

    The surface is sampled on a *steps* x *steps* mesh in the azimuthal and
    polar angles and plotted on the 3D axes *ax*.
    """
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)

    x = radius * np.outer(np.cos(azimuth), sin_polar)
    y = radius * np.outer(np.sin(azimuth), sin_polar)
    z = radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar))
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
160
def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    *ax* is the 3D axes to draw on.  *view* is (theta, phi, psi) and
    *jitter* is (dtheta, dphi, dpsi), both in degrees.  The mesh has *n*
    points in each of theta and phi and is drawn at distance *radius* from
    the origin.  *dist* is 'gaussian' (points span +/- 3 widths) or
    'rectangle' (points span +/- sqrt(3) widths).

    Raises ValueError for any other *dist*.
    """
    # Only the theta and phi widths are used here; unpacking all three
    # components still requires view and jitter to be triples.
    _theta, _phi, _psi = view
    dtheta, dphi, _dpsi = jitter

    if dist == 'gaussian':
        t = np.linspace(-3, 3, n)
        weights = exp(-0.5*t**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        t = np.linspace(-1, 1, n)*sqrt(3)
        weights = np.ones_like(t)
    else:
        raise ValueError("expected dist to be 'gaussian' or 'rectangle'")

    # mesh in theta, phi formed by rotating z
    pole = np.matrix([[0], [0], [radius]])
    theta_offsets = dtheta*t
    phi_offsets = dphi*t
    columns = [Rx(phi_i)*Ry(theta_i)*pole
               for theta_i in theta_offsets
               for phi_i in phi_offsets]
    points = np.hstack(columns)
    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    # Point colour is the product of the theta and phi weights, with the
    # theta weight multiplied by cos(theta offset).
    w = np.outer(weights*cos(radians(theta_offsets)), weights)

    x, y, z = [np.array(row).flatten() for row in points]
    ax.scatter(x, y, z, c=w.flatten(), marker='o', vmin=0., vmax=1.)
191
def draw_labels(ax, view, jitter, text):
    """
    Draw text labels at particular locations on the 3D axes *ax*.

    *text* is a sequence of (label, location, orientation) triples, where
    location and orientation are (x, y, z) vectors.  Both vectors are sent
    through the *jitter* and *view* transforms; the transformed orientation
    is passed to the axes as the text direction.
    """
    labels, locations, orientations = zip(*text)

    # Split the per-label vectors into coordinate sequences and transform
    # the positions and the directions the same way.
    loc_x, loc_y, loc_z = zip(*locations)
    dir_x, dir_y, dir_z = zip(*orientations)
    loc_x, loc_y, loc_z = transform_xyz(view, jitter, loc_x, loc_y, loc_z)
    dir_x, dir_y, dir_z = transform_xyz(view, jitter, dir_x, dir_y, dir_z)

    positions = zip(loc_x, loc_y, loc_z)
    directions = zip(dir_x, dir_y, dir_z)

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for label, position, direction in zip(labels, positions, directions):
        zdir = np.asarray(direction).flatten()
        ax.text(position[0], position[1], position[2], label, zdir=zdir)
207
# The rotation matrices follow the definitions given on Wikipedia:
#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """
    Return the 3x3 numpy matrix which rotates points about the *x* axis
    by *angle* degrees.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
217
def Ry(angle):
    """
    Return the 3x3 numpy matrix which rotates points about the *y* axis
    by *angle* degrees.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
225
def Rz(angle):
    """
    Return the 3x3 numpy matrix which rotates points about the *z* axis
    by *angle* degrees.
    """
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
233
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    *x*, *y* and *z* are array-likes; the returned (x, y, z) arrays have
    the shape of the input *x*.
    """
    coords = [np.asarray(axis) for axis in (x, y, z)]
    target_shape = coords[0].shape

    # The transforms operate on a 3 x n matrix with one flattened
    # coordinate per row.
    points = np.matrix([axis.flatten() for axis in coords])
    points = apply_jitter(jitter, points)
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).reshape(target_shape) for row in points]
    return x, y, z
245
def apply_jitter(jitter, points):
    """
    Apply the jitter transform to a set of points.

    *jitter* is (dtheta, dphi, dpsi) in degrees.  The points are rotated by
    dpsi about z, then dtheta about y, then dphi about x.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    dtheta, dphi, dpsi = jitter
    rotation = Rx(dphi)*Ry(dtheta)*Rz(dpsi)
    return rotation*points
255
def orient_relative_to_beam(view, points):
    """
    Apply the view transform to a set of points.

    *view* is (theta, phi, psi) in degrees.  The points are rotated by psi
    about z, then theta about y, then phi about z.

    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    theta, phi, psi = view
    rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return rotation*points
265
# Translate between the dimensions which have dispersity and the number of
# points to use along each dimension.  The key is a (theta, phi, psi) tuple
# of 0/1 flags marking the dispersed dimensions; the value gives the number
# of points for each of theta, phi and psi.  The point count depends only on
# how many dimensions are dispersed:
#    0 dims: 0 points
#    1 dim:  100 points            (100 evaluations)
#    2 dims: 30 points in each     (900 evaluations)
#    3 dims: 15 points in each     (3375 evaluations)
PD_N_TABLE = {
    (t, p, s): tuple(flag*(0, 100, 30, 15)[t + p + s] for flag in (t, p, s))
    for t in (0, 1)
    for p in (0, 1)
    for s in (0, 1)
}
278
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine range from data.

    If *portion* is 1, return the full (min, max) range of *data*, whatever
    the *mode*.  Otherwise sort the flattened data and clip it:

    * *mode* 'central' skips ``int(portion*n/2 + 0.5)`` values at each end
      of the sorted data.
    * *mode* 'top' skips ``int(portion*n + 0.5)`` values at the bottom of
      the sorted data and keeps the maximum as the upper limit.

    *data* is any array-like.  Returns the pair (lower, upper).

    Raises ValueError if *mode* is not 'central' or 'top' (when clipping is
    requested), or if there is no data to clip.
    """
    data = np.asarray(data)
    if portion == 1.0:
        return data.min(), data.max()

    if mode not in ('central', 'top'):
        raise ValueError("expected mode to be 'central' or 'top'")

    data = np.sort(data.flatten())
    n = len(data)
    if n == 0:
        raise ValueError("cannot determine range of empty data")
    last = n - 1

    if mode == 'central':
        offset = int(portion*n/2 + 0.5)
        lower = min(max(offset, 0), last)
        # data[-offset] is data[0] when offset is 0, which would collapse
        # the range to (min, min); skipping nothing must keep the maximum.
        upper = max(n - max(offset, 1), 0)
        return data[lower], data[upper]

    # mode == 'top'
    offset = int(portion*n + 0.5)
    # For small data sets the offset can round up to n; clamp it so that
    # the lower limit is at worst the maximum value.
    lower = min(max(offset, 0), last)
    return data[lower], data[last]
296
def draw_scattering(calculator, ax, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The log of the scattering pattern is drawn as filled contours on the
    plane z=-1.1, with qx and qy each divided by their maximum value.
    """
    # Note: jitter widths are passed through unscaled, so a 'rectangle'
    # distribution uses the sasmodels definition of width (the range is
    # sqrt(3) times the given width).

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = jitter
    # PD_N_TABLE is keyed by 0/1 flags; convert explicitly instead of
    # relying on booleans hashing like integers.
    active = tuple(int(v > 0) for v in (theta_pd, phi_pd, psi_pd))
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[active]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    # Parameters stored on the calculator take precedence over the
    # orientation parameters set above.
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    Iqxy = calculator(**pars).reshape(len(qx), len(qy))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use the fixed limits stored on the calculator
        vmin, vmax = calculator.limits
    else:
        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top')

    ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                levels=np.linspace(vmin, vmax, 24))
349
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  If no *background* is given it is set to 1e-3.

    Returns a *calculator* function which takes a dictionary of parameters
    and produces Iqxy.  The Iqxy value needs to be reshaped to an n x n
    matrix for plotting.  See the :class:`sasmodels.direct_model.DirectModel`
    class for details.

    The calculator also carries *calculator.limits*, the (vmin, vmax) range
    of log(Iqxy) computed at orientation (0, 0, 0).
    """
    # Import under a different name so that the sasmodels function does not
    # shadow this function inside its own body.
    from sasmodels.core import load_model_info
    from sasmodels.core import build_model as build_kernel_model
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_kernel_model(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
387
def select_calculator(model_name, n=150, size=(10,40,100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or bcc_paracrystal. *n* is the number of points to use
    in the q range.  The q range itself is chosen per model, using the
    :func:`build_model` default when none is given here.

    *size* is the tuple (a, b, c).  Models with fewer than three distinct
    dimensions use only some of these values, and the returned size repeats
    the values that were used so that it describes the shape that was
    actually built.

    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
    equatorial axes and polar axis respectively.  See :func:`build_model`
    for details on the returned calculator.

    Raises ValueError if *model_name* is not one of the supported models.
    """
    a, b, c = size

    if model_name == 'parallelepiped':
        calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c)
        shape = (a, b, c)
    elif model_name == 'triaxial_ellipsoid':
        calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5,
                                 radius_equat_minor=a,
                                 radius_equat_major=b,
                                 radius_polar=c)
        shape = (a, b, c)
    elif model_name == 'ellipsoid':
        # single equatorial radius
        calculator = build_model('ellipsoid', n=n, qmax=1.0,
                                 radius_polar=c, radius_equatorial=b)
        shape = (b, b, c)
    elif model_name == 'cylinder':
        # single radius
        calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c)
        shape = (b, b, c)
    elif model_name == 'bcc_paracrystal':
        calculator = build_model('bcc_paracrystal', n=n, dnn=c,
                                  d_factor=0.06, radius=40)
        shape = (c, c, c)
    elif model_name == 'sphere':
        calculator = build_model('sphere', n=n, radius=c)
        shape = (c, c, c)
    else:
        raise ValueError("unknown model %s"%model_name)

    return calculator, shape
427
def main(model_name='parallelepiped', size=(10, 40, 100)):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of the models available in
    :func:`select_calculator`.  *size* is the (a, b, c) shape passed to
    :func:`select_calculator`.

    Opens a matplotlib window with a 3D view of the shape, the jitter mesh
    and the scattering pattern, plus sliders for the view angles and their
    dispersity.  Blocks until the window is closed.
    """
    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)

    ## Clearing the limits gives every view an independent colour range.
    ## Comment out this line to keep the colour range fixed for all views.
    calculator.limits = None

    ## use gaussian distribution unless testing integration
    #dist = 'rectangle'
    dist = 'gaussian'

    ## initial view
    #theta, dtheta = 70., 10.
    #phi, dphi = -45., 3.
    #psi, dpsi = -45., 3.
    theta, phi, psi = 0, 0, 0
    dtheta, dphi, dpsi = 0, 0, 0

    ## create the plot window
    plt.set_cmap('gist_earth')
    plt.clf()
    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        ax.axis('square')
    except Exception:
        # Deliberately best-effort: the aspect setting is cosmetic.
        pass

    axcolor = 'lightgoldenrodyellow'

    def control_axes(rect):
        """Create the axes which hold a slider, with a tinted background."""
        # CRUFT: the background colour keyword is 'facecolor' in current
        # matplotlib and 'axisbg' in old releases; neither is accepted by
        # all versions, so try the current name first.
        try:
            return plt.axes(rect, facecolor=axcolor)
        except (AttributeError, TypeError):
            return plt.axes(rect, axisbg=axcolor)

    ## add control widgets to plot
    axtheta = control_axes([0.1, 0.15, 0.45, 0.04])
    axphi = control_axes([0.1, 0.1, 0.45, 0.04])
    axpsi = control_axes([0.1, 0.05, 0.45, 0.04])
    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)

    axdtheta = control_axes([0.75, 0.15, 0.15, 0.04])
    axdphi = control_axes([0.75, 0.1, 0.15, 0.04])
    axdpsi = control_axes([0.75, 0.05, 0.15, 0.04])
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = 30 if dist == 'gaussian' else 90/sqrt(3)
    sdtheta = Slider(axdtheta, 'dTheta', 0, dlimit, valinit=dtheta)
    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi)
    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        """Redraw everything from the current slider values."""
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0, 2, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        ax.cla()
        draw_beam(ax, (0, 0))
        draw_jitter(ax, view, jitter, dist=dist, size=size)
        #draw_jitter(ax, view, (0,0,0))
        draw_mesh(ax, view, jitter, dist=dist)
        draw_scattering(calculator, ax, view, jitter, dist=dist)
        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
514
if __name__ == "__main__":
    # usage: jitter.py [model_name [a,b,c]]
    if len(sys.argv) > 1:
        model_name = sys.argv[1]
    else:
        model_name = 'parallelepiped'
    # The size is given as comma-separated integers.
    if len(sys.argv) > 2:
        size = tuple(int(v) for v in sys.argv[2].split(','))
    else:
        size = (10, 40, 100)
    main(model_name, size)
Note: See TracBrowser for help on using the repository browser.