source: sasmodels/explore/jitter.py @ aa6989b

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since aa6989b was aa6989b, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

add scattering pattern to jitter viewer

  • Property mode set to 100755
File size: 17.8 KB
Line 
1#!/usr/bin/env python
2"""
3Application to explore the difference between sasview 3.x orientation
4dispersity and possible replacement algorithms.
5"""
6from __future__ import division, print_function
7
8import sys, os
9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
10
11import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
12import matplotlib.pyplot as plt
13from matplotlib.widgets import Slider, CheckButtons
14from matplotlib import cm
15import numpy as np
16from numpy import pi, cos, sin, sqrt, exp, degrees, radians
17
def draw_beam(ax, view=(0, 0)):
    """
    Draw the beam as a thin yellow tube along the z axis.

    The beam travels from the source (+z) toward the detector (-z); the tube
    has radius 0.02 and spans z in [-1.3, 1.3].  *view* is a (theta, phi)
    pair in degrees: the tube is rotated by *theta* about y and then by
    *phi* about z.
    """
    n = 25
    radius, half_length = 0.02, 1.3
    angle = np.linspace(0, 2 * np.pi, n)
    height = np.linspace(-1, 1, n)

    # Tube surface: rows follow the angle around the axis, columns the height.
    x = radius*np.outer(np.cos(angle), np.ones_like(height))
    y = radius*np.outer(np.sin(angle), np.ones_like(height))
    z = half_length*np.outer(np.ones_like(angle), height)

    theta, phi = view
    grid_shape = x.shape
    xyz = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    xyz = Rz(phi)*Ry(theta)*xyz
    x, y, z = (row.reshape(grid_shape) for row in xyz)

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5)
41
def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0),
                draw_shape=None):
    """
    Represent jitter as a set of shapes at different orientations.

    The shape is drawn once at the *view* orientation.  It is then drawn
    again at each point of a 3 x 3 x 3 grid of orientation offsets, spanning
    -1, 0 and +1 times the (dtheta, dphi, dpsi) values in *jitter*.  Any grid
    axis whose jitter is zero collapses to the single offset 0.  For
    ``dist='rectangle'`` the offsets are multiplied by 1/sqrt(3).

    *size* gives the (a, b, c) shape dimensions.  They are rescaled so that
    sqrt(a**2 + b**2 + c**2) is 0.95, and the x, y and z axis limits are set
    to +/- that value.

    *draw_shape* is the function used to draw each copy.  It is called as
    ``draw_shape(ax, size, view, delta, ...)`` and defaults to
    :func:`draw_parallelepiped`.
    """
    if draw_shape is None:
        # Resolved at call time: draw_parallelepiped is defined later in the
        # module, so it cannot be the default value itself.
        draw_shape = draw_parallelepiped
    #draw_shape = draw_ellipsoid

    # set max diagonal to 0.95
    size_scale = 0.95/sqrt(sum(v**2 for v in size))
    size = tuple(size_scale*v for v in size)

    dtheta, dphi, dpsi = jitter
    # Offsets of -1, 0, +1 along each jitter axis; an axis without jitter
    # only needs the 0 offset.  Theta varies slowest, psi fastest.
    theta_steps = (0,) if dtheta == 0 else (-1, 0, 1)
    phi_steps = (0,) if dphi == 0 else (-1, 0, 1)
    psi_steps = (0,) if dpsi == 0 else (-1, 0, 1)
    cloud = [(i, j, k)
             for i in theta_steps
             for j in phi_steps
             for k in psi_steps]

    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)
    # NOTE(review): rectangle offsets are shrunk by 1/sqrt(3) here, while
    # draw_mesh stretches rectangle nodes by sqrt(3); confirm the intent.
    offset_scale = 1/sqrt(3) if dist == 'rectangle' else 1
    for i, j, k in cloud:
        delta = [offset_scale*dtheta*i, offset_scale*dphi*j,
                 offset_scale*dpsi*k]
        draw_shape(ax, size, view, delta, alpha=0.8)

    # Same limit on every axis, so compute it once.
    a, b, c = size
    lim = np.sqrt(a**2 + b**2 + c**2)
    for v in 'xyz':
        getattr(ax, 'set_'+v+'lim')([-lim, lim])
        getattr(ax, v+'axis').label.set_text(v)
100
def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c) along x, y and z.

    The surface is rotated by the *jitter* offsets and then by the *view*
    orientation (via :func:`transform_xyz`).  It is drawn in white with the
    given *alpha*, and the ends of each axis are labelled.  *steps* is the
    number of samples in each angular direction of the surface mesh.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)
    x = a*np.outer(np.cos(azimuth), sin_polar)
    y = b*np.outer(np.sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), np.cos(polar))
    x, y, z = transform_xyz(view, jitter, x, y, z)

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha)

    # (label, anchor point, text direction) in the shape's own frame
    axis_labels = [
        ('c+', [0, 0, c], [1, 0, 0]),
        ('c-', [0, 0, -c], [0, 0, -1]),
        ('a+', [a, 0, 0], [0, 0, 1]),
        ('a-', [-a, 0, 0], [0, 0, -1]),
        ('b+', [0, b, 0], [-1, 0, 0]),
        ('b-', [0, -b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
121
def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
    """
    Draw a rectangular box with half-widths *size* = (a, b, c) along x, y, z.

    The eight corners are rotated by the *jitter* offsets and then by the
    *view* orientation.  The twelve triangles making up the six faces are
    drawn in white with the given *alpha*, and the ends of each axis are
    labelled.  *steps* is unused; it is accepted so this function can be
    called interchangeably with :func:`draw_ellipsoid`.
    """
    a, b, c = size
    # Corner sign patterns (x, y, z); row k is vertex k of the triangle table.
    signs = np.array([
        [ 1,  1,  1],
        [-1,  1,  1],
        [ 1, -1,  1],
        [-1, -1,  1],
        [ 1,  1, -1],
        [-1,  1, -1],
        [ 1, -1, -1],
        [-1, -1, -1],
    ])
    x = a*signs[:, 0]
    y = b*signs[:, 1]
    z = c*signs[:, 2]
    faces = np.array([
        # counter clockwise triangles
        # z: up/down, x: right/left, y: front/back
        [0, 1, 2], [3, 2, 1],  # top face (z = +c)
        [6, 5, 4], [5, 6, 7],  # bottom face (z = -c)
        [0, 2, 6], [6, 4, 0],  # right face (x = +a)
        [1, 5, 7], [7, 3, 1],  # left face (x = -a)
        [2, 3, 6], [7, 6, 3],  # front face (y = -b)
        [4, 1, 0], [5, 1, 4],  # back face (y = +b)
    ])

    x, y, z = transform_xyz(view, jitter, x, y, z)
    ax.plot_trisurf(x, y, triangles=faces, Z=z, color='w', alpha=alpha)

    # (label, anchor point, text direction) in the shape's own frame
    axis_labels = [
        ('c+', [0, 0, c], [1, 0, 0]),
        ('c-', [0, 0, -c], [0, 0, -1]),
        ('a+', [a, 0, 0], [0, 0, 1]),
        ('a-', [-a, 0, 0], [0, 0, -1]),
        ('b+', [0, b, 0], [-1, 0, 0]),
        ('b-', [0, -b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
150
def draw_sphere(ax, radius=10., steps=100):
    """
    Draw a white sphere of the given *radius* centred on the origin.

    *steps* is the number of azimuthal and polar samples in the surface mesh.
    """
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    ring = np.sin(polar)

    x = radius * np.outer(np.cos(azimuth), ring)
    y = radius * np.outer(np.sin(azimuth), ring)
    z = radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar))
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
160
def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    The point (0, 0, *radius*) is rotated through an *n* x *n* grid of
    (dtheta, dphi) offsets and then through the *view* orientation.  Each
    node is drawn as a marker coloured by its weight on a 0 to 1 colour
    scale.  For ``dist='gaussian'`` the grid spans +/- 3 widths with
    gaussian weights.  For ``dist='rectangle'`` it spans +/- sqrt(3) widths
    with equal weights.

    Raises ValueError if *dist* is neither 'gaussian' nor 'rectangle'.
    """
    # psi jitter is not shown: a rotation about z leaves this z-axis point
    # unchanged.
    dtheta, dphi, _ = jitter

    if dist == 'gaussian':
        t = np.linspace(-3, 3, n)
        weights = exp(-0.5*t**2)
    elif dist == 'rectangle':
        # Note: sasmodels treats the rectangle half-width as sqrt(3) times
        # the given width.
        t = np.linspace(-1, 1, n)*sqrt(3)
        weights = np.ones_like(t)
    else:
        raise ValueError("expected dist to be 'gaussian' or 'rectangle'")

    # mesh in theta, phi formed by rotating z; theta varies slowest so the
    # point order matches the row-major flattening of the weights below
    z = np.matrix([[0], [0], [radius]])
    points = np.hstack([Rx(phi_i)*Ry(theta_i)*z
                        for theta_i in dtheta*t
                        for phi_i in dphi*t])
    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    # Node (i, j) gets weight w_theta[i]*w_phi[j].  The theta weights also
    # carry a cos(theta offset) factor.  NOTE(review): presumably the
    # spherical area element of the jitter integral -- confirm.
    w = np.outer(weights*cos(radians(dtheta*t)), weights)

    x, y, z = [np.array(v).flatten() for v in points]
    ax.scatter(x, y, z, c=w.flatten(), marker='o', vmin=0., vmax=1.)
191
def draw_labels(ax, view, jitter, text):
    """
    Draw text labels at points attached to the shape.

    *text* is a sequence of (label, location, orientation) triples.
    *location* is the (x, y, z) anchor point and *orientation* a direction
    vector, both in the shape's own (unrotated) frame.  Both are sent
    through the jitter and view rotations before drawing, and the rotated
    orientation is passed to matplotlib as the text's *zdir*.
    """
    labels, locations, orientations = zip(*text)
    px, py, pz = zip(*locations)
    dx, dy, dz = zip(*orientations)

    px, py, pz = transform_xyz(view, jitter, px, py, pz)
    dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz)

    # TODO: labels aren't appearing; check whether the zdir fix below helps.
    for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)):
        # Pass zdir as a plain tuple of floats.  matplotlib's 3D text
        # compares zdir against the strings 'x'/'y'/'z', and with an ndarray
        # that comparison can be elementwise and hence ambiguous.
        zdir = tuple(float(v) for v in np.asarray(zdir).flatten())
        ax.text(p[0], p[1], p[2], label, zdir=zdir)
207
208# Definition of rotation matrices comes from wikipedia:
209#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """Return the 3x3 matrix that rotates points by *angle* degrees about *x*."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
217
def Ry(angle):
    """Return the 3x3 matrix that rotates points by *angle* degrees about *y*."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
225
def Rz(angle):
    """Return the 3x3 matrix that rotates points by *angle* degrees about *z*."""
    rad = radians(angle)
    c, s = cos(rad), sin(rad)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
233
def transform_xyz(view, jitter, x, y, z):
    """
    Rotate the points (*x*, *y*, *z*) by *jitter* and then by *view*.

    The coordinates are converted to arrays and stacked into a 3 x n matrix
    for the rotations.  The rotated coordinates are returned as a tuple of
    plain arrays reshaped to the shape of *x*.
    """
    coords = [np.asarray(v) for v in (x, y, z)]
    shape = coords[0].shape
    stacked = np.matrix([c.flatten() for c in coords])
    rotated = orient_relative_to_beam(view, apply_jitter(jitter, stacked))
    return tuple(np.array(row).reshape(shape) for row in rotated)
245
def apply_jitter(jitter, points):
    """
    Rotate *points* by the orientation offsets in *jitter*.

    *jitter* is (dtheta, dphi, dpsi) in degrees.  *points* is a 3 x n numpy
    matrix (not a numpy array or tuple), so ``*`` is a matrix product.  The
    rotations apply dpsi about z first, then dtheta about y, then dphi
    about x.
    """
    dtheta, dphi, dpsi = jitter
    rotation = Rx(dphi)*Ry(dtheta)*Rz(dpsi)
    return rotation*points
255
def orient_relative_to_beam(view, points):
    """
    Rotate *points* into the *view* orientation relative to the beam.

    *view* is (theta, phi, psi) in degrees.  *points* is a 3 x n numpy
    matrix (not a numpy array or tuple), so ``*`` is a matrix product.  The
    rotations apply psi about z first, then theta about y, then phi
    about z.
    """
    theta, phi, psi = view
    rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return rotation*points
265
# Number of dispersity points along (theta, phi, psi), keyed by which of those
# dimensions have nonzero dispersity.  Fewer points are used per dimension as
# more dimensions become active; totals are 0, 100, 900 and 3375 points.
PD_N_TABLE = {
    # no dispersity
    (0, 0, 0): (0, 0, 0),
    # one active dimension: 100 points
    (1, 0, 0): (100, 0, 0),
    (0, 1, 0): (0, 100, 0),
    (0, 0, 1): (0, 0, 100),
    # two active dimensions: 30 x 30 = 900 points
    (1, 1, 0): (30, 30, 0),
    (1, 0, 1): (30, 0, 30),
    (0, 1, 1): (0, 30, 30),
    # three active dimensions: 15 x 15 x 15 = 3375 points
    (1, 1, 1): (15, 15, 15),
}
278
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine a (low, high) display range from data.

    *portion* is the fraction of the sorted values kept inside the range.
    If *portion* is 1 (or more), the full range ``(min, max)`` is returned.
    Otherwise *mode* selects which values are discarded:

    * ``'central'`` trims ``(1 - portion)/2`` of the values from each end,
      keeping the middle of the distribution;
    * ``'top'`` trims ``1 - portion`` of the values from the bottom only,
      keeping the top of the distribution up to the maximum.

    Raises ValueError if *mode* is not 'central' or 'top'.
    """
    data = np.asarray(data)
    if portion >= 1.0:
        return data.min(), data.max()
    if mode not in ('central', 'top'):
        raise ValueError("expected mode to be 'central' or 'top'")

    data = np.sort(data.flatten())
    last = len(data) - 1
    if mode == 'central':
        # Never cross the midpoint, so low <= high even for tiny portions.
        cut = min(int((1.0 - portion)*len(data)/2 + 0.5), last//2)
        return data[cut], data[last - cut]
    # mode == 'top': clamp so a portion near 0 cannot index past the end.
    cut = min(int((1.0 - portion)*len(data) + 0.5), last)
    return data[cut], data[-1]
296
def draw_scattering(calculator, ax, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The log of the pattern is drawn as filled contours on the plane
    z = -1.1, with qx and qy normalized by their maxima.  The colour range
    is *calculator.limits* when set; otherwise it is derived from this
    pattern with :func:`clipped_range`.
    """
    ## Sasmodels use sqrt(3)*width for the rectangle range; scale to the
    ## proper width for comparison. Commented out since now using the
    ## sasmodels definition of width for rectangle.
    #scale = 1/sqrt(3) if dist == 'rectangle' else 1
    scale = 1

    # add the orientation parameters to the model parameters
    theta, phi, psi = view
    theta_pd, phi_pd, psi_pd = [scale*v for v in jitter]
    # the number of dispersity points depends on which dimensions are active
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd>0, phi_pd>0, psi_pd>0)]
    ## increase pd_n for testing jitter integration rather than simple viz
    #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)]

    pars = dict(
        theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n,
        phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n,
        psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n,
    )
    # parameters stored by build_model take precedence over the ones above
    pars.update(calculator.pars)

    # compute the pattern
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    # contourf(X, Y, Z) needs Z.shape == (len(Y), len(X)): one row per qy and
    # one column per qx.  This assumes the calculator returns values in
    # row-major order with qx varying fastest (as np.meshgrid(qx, qy) gives)
    # -- TODO confirm against sasmodels.data.empty_data2D.
    Iqxy = calculator(**pars).reshape(len(qy), len(qx))

    # scale it and draw it
    Iqxy = np.log(Iqxy)
    if calculator.limits:
        # use limits from orientation (0,0,0)
        vmin, vmax = calculator.limits
    else:
        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top')
    ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                levels=np.linspace(vmin, vmax, 24))
349
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*.  They are used by
    :func:`draw_scattering` to set the non-orientation parameters in the
    calculation.  If not given, *background* defaults to 1e-3.

    Returns a *calculator* function which takes keyword parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.  *calculator.limits* holds a (vmin, vmax) log-intensity
    colour range computed at orientation (0, 0, 0).
    """
    # Alias the sasmodels builder so it does not shadow this function's name.
    from sasmodels.core import load_model_info
    from sasmodels.core import build_model as build_kernel_model
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_kernel_model(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # Fixed: the key was misspelled 'backgound', so no default was applied.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    # one extra unit of log intensity above the peak
    calculator.limits = vmin, vmax+1

    return calculator
387
def select_calculator(model_name, n=150):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or bcc_paracrystal.  *n* is the number of points to use
    in the q range.  *qmax* and the shape parameters are fixed per model so
    that the pattern shows something interesting.

    Returns *calculator* and a tuple *size* = (a, b, c) giving the minor and
    major equatorial axes and the polar axis respectively.  See
    :func:`build_model` for details on the returned calculator.

    Raises ValueError for an unknown *model_name*.
    """
    a, b, c = 10, 40, 100
    # model name -> (extra build_model arguments, displayed (a, b, c) size)
    choices = {
        'sphere': (
            dict(radius=c),
            (c, c, c)),
        'bcc_paracrystal': (
            dict(dnn=c, d_factor=0.06, radius=40),
            (c, c, c)),
        'cylinder': (
            dict(qmax=0.3, radius=b, length=c),
            (b, b, c)),
        'ellipsoid': (
            dict(qmax=1.0, radius_polar=c, radius_equatorial=b),
            (b, b, c)),
        'triaxial_ellipsoid': (
            dict(qmax=0.5, radius_equat_minor=a, radius_equat_major=b,
                 radius_polar=c),
            (a, b, c)),
        'parallelepiped': (
            dict(a=a, b=b, c=c),
            (a, b, c)),
    }
    if model_name not in choices:
        raise ValueError("unknown model %s"%model_name)

    options, size = choices[model_name]
    calculator = build_model(model_name, n=n, **options)
    return calculator, size
427
def main(model_name='parallelepiped'):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of the models available in
    :func:`select_calculator`.  The window shows the following:

    * the beam;
    * the shape at the selected orientation, with copies at the jittered
      orientations;
    * the dispersion mesh;
    * the 2D scattering pattern.

    Sliders control the view angles (theta, phi, psi) and the jitter widths
    (dtheta, dphi, dpsi).
    """
    # set up calculator
    calculator, size = select_calculator(model_name, n=150)

    ## Use an independent colour range for every view.  Remove this line to
    ## keep the colour range fixed at the limits build_model computed for
    ## orientation (0, 0, 0).
    calculator.limits = None

    ## use gaussian distribution unless testing integration
    #dist = 'rectangle'
    dist = 'gaussian'

    ## initial view
    #theta, dtheta = 70., 10.
    #phi, dphi = -45., 3.
    #psi, dpsi = -45., 3.
    theta, phi, psi = 0, 0, 0
    dtheta, dphi, dpsi = 0, 0, 0

    ## create the plot window
    plt.set_cmap('gist_earth')
    plt.clf()
    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:
        ax.axis('square')
    except NotImplementedError:
        # matplotlib 3.1-3.5 refuse to set an equal aspect on 3D axes; the
        # viewer still works without the forced square scaling.
        pass

    axcolor = 'lightgoldenrodyellow'

    ## add control widgets to plot
    # 'facecolor' replaces the 'axisbg' keyword, which matplotlib deprecated
    # in 2.0 and later removed.
    axtheta = plt.axes([0.1, 0.15, 0.45, 0.04], facecolor=axcolor)
    axphi = plt.axes([0.1, 0.1, 0.45, 0.04], facecolor=axcolor)
    axpsi = plt.axes([0.1, 0.05, 0.45, 0.04], facecolor=axcolor)
    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)

    axdtheta = plt.axes([0.75, 0.15, 0.15, 0.04], facecolor=axcolor)
    axdphi = plt.axes([0.75, 0.1, 0.15, 0.04], facecolor=axcolor)
    axdpsi = plt.axes([0.75, 0.05, 0.15, 0.04], facecolor=axcolor)
    # Note: sasmodels takes the rectangle half-width as sqrt(3) times the
    # given width.  Divide by sqrt(3) to keep the maximum width to 90.
    dlimit = 30 if dist == 'gaussian' else 90/sqrt(3)
    sdtheta = Slider(axdtheta, 'dTheta', 0, dlimit, valinit=dtheta)
    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi)
    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0, 2, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        ax.cla()
        draw_beam(ax, (0, 0))
        draw_jitter(ax, view, jitter, dist=dist, size=size)
        #draw_jitter(ax, view, (0,0,0))
        draw_mesh(ax, view, jitter, dist=dist)
        draw_scattering(calculator, ax, view, jitter, dist=dist)
        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
511
if __name__ == "__main__":
    # The optional first command-line argument selects the model to show.
    cli_args = sys.argv[1:]
    model_name = cli_args[0] if cli_args else 'parallelepiped'
    main(model_name)
Note: See TracBrowser for help on using the repository browser.