source: sasmodels/explore/jitter.py @ 767dca8

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 767dca8 was b9578fc, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

alternative handling of phi point density: scale phi

  • Property mode set to 100755
File size: 18.6 KB
Line 
1#!/usr/bin/env python
2"""
3Application to explore the difference between sasview 3.x orientation
4dispersity and possible replacement algorithms.
5"""
6from __future__ import division, print_function
7
8import sys, os
9sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
10
11import mpl_toolkits.mplot3d   # Adds projection='3d' option to subplot
12import matplotlib.pyplot as plt
13from matplotlib.widgets import Slider, CheckButtons
14from matplotlib import cm
15import numpy as np
16from numpy import pi, cos, sin, sqrt, exp, degrees, radians
17
# Selects how draw_mesh handles the phi point density.  When True, each phi
# offset is divided by |cos(theta offset)| and the weights are the plain outer
# product; when False, phi offsets are used as given and the theta weights are
# multiplied by cos(theta offset) instead.
SCALED_PHI = True
19
def draw_beam(ax, view=(0, 0)):
    """
    Draw the beam as a thin translucent cylinder on the 3D axes *ax*.

    The cylinder is built along the z axis (source on the +z side, detector
    on the -z side) and then rotated by *view* = (theta, phi), in degrees,
    applying Ry(theta) followed by Rz(phi).
    """
    n_steps = 25
    beam_radius = 0.02
    half_length = 1.3

    angle = np.linspace(0, 2 * np.pi, n_steps)
    height = np.linspace(-1, 1, n_steps)
    unit_height = np.ones_like(height)

    # Cylinder surface sampled on an n_steps x n_steps grid.
    x = beam_radius*np.outer(np.cos(angle), unit_height)
    y = beam_radius*np.outer(np.sin(angle), unit_height)
    z = half_length*np.outer(np.ones_like(angle), height)
    grid_shape = x.shape

    # Rotate all surface points at once as a 3 x N matrix, then restore
    # the grid layout expected by plot_surface.
    theta, phi = view
    surface = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    surface = Rz(phi)*Ry(theta)*surface
    x, y, z = [row.reshape(grid_shape) for row in surface]

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5)
43
def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)):
    """
    Represent jitter as a set of shapes at different orientations.

    *ax* is the 3D axes to draw on.  *view* is the (theta, phi, psi)
    orientation and *jitter* the (dtheta, dphi, dpsi) widths.  *dist* is the
    distribution name; for 'rectangle' the offsets are scaled by 1/sqrt(3).
    *size* is the (a, b, c) shape, rescaled so its diagonal is 0.95.
    """
    # Rescale the shape so that sqrt(a^2 + b^2 + c^2) is 0.95.
    shrink = 0.95/sqrt(sum(edge**2 for edge in size))
    size = tuple(shrink*edge for edge in size)
    draw_shape = draw_parallelepiped  # draw_ellipsoid takes the same arguments

    dtheta, dphi, dpsi = jitter

    # Offsets of -1, 0, +1 widths along each angle; an angle with zero
    # jitter only contributes its centre value.
    spread = (-1, 0, 1)
    centre = (0,)
    cloud = [
        [i, j, k]
        for i in (centre if dtheta == 0 else spread)
        for j in (centre if dphi == 0 else spread)
        for k in (centre if dpsi == 0 else spread)
    ]

    # Unperturbed shape first, then one shape per jitter offset.
    draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8)
    width = 1/sqrt(3) if dist == 'rectangle' else 1
    for point in cloud:
        delta = [width*dtheta*point[0], width*dphi*point[1], width*dpsi*point[2]]
        draw_shape(ax, size, view, delta, alpha=0.8)

    # Same symmetric limits on all three axes, sized to the shape diagonal.
    a, b, c = size
    lim = np.sqrt(a**2+b**2+c**2)
    for axis_name in 'xyz':
        getattr(ax, 'set_'+axis_name+'lim')([-lim, lim])
        getattr(ax, axis_name+'axis').label.set_text(axis_name)
102
def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1):
    """
    Draw an ellipsoid with semi-axes *size* = (a, b, c) on the 3D axes *ax*.

    The surface is sampled on a *steps* x *steps* grid, sent through the
    *jitter* and *view* transforms, and drawn with opacity *alpha*.  Labels
    are placed at the ends of each axis.
    """
    a, b, c = size
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)

    x = a*np.outer(np.cos(azimuth), sin_polar)
    y = b*np.outer(np.sin(azimuth), sin_polar)
    z = c*np.outer(np.ones_like(azimuth), np.cos(polar))
    x, y, z = transform_xyz(view, jitter, x, y, z)

    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w', alpha=alpha)

    # (label, location, text direction) for the six axis end points.
    axis_labels = [
        ('c+', [0, 0, c], [1, 0, 0]),
        ('c-', [0, 0, -c], [0, 0, -1]),
        ('a+', [a, 0, 0], [0, 0, 1]),
        ('a-', [-a, 0, 0], [0, 0, -1]),
        ('b+', [0, b, 0], [-1, 0, 0]),
        ('b-', [0, -b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, axis_labels)
123
def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1):
    """
    Draw a box with half-widths *size* = (a, b, c) on the 3D axes *ax*.

    The eight corners are sent through the *jitter* and *view* transforms and
    the faces are drawn as triangles with opacity *alpha*.  *steps* is
    accepted for call compatibility with :func:`draw_ellipsoid` and ignored.
    """
    a, b, c = size

    # Sign of (x, y, z) for each of the eight corners, in vertex order.
    corner_signs = np.array([
        [1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1],
        [1, 1, -1], [-1, 1, -1], [1, -1, -1], [-1, -1, -1],
    ])
    x = a*corner_signs[:, 0]
    y = b*corner_signs[:, 1]
    z = c*corner_signs[:, 2]

    # Two counter clockwise triangles per face, as indices into the corners.
    # z: up/down, x: right/left, y: front/back
    top = [[0, 1, 2], [3, 2, 1]]
    bottom = [[6, 5, 4], [5, 6, 7]]
    right = [[0, 2, 6], [6, 4, 0]]
    left = [[1, 5, 7], [7, 3, 1]]
    front = [[2, 3, 6], [7, 6, 3]]
    back = [[4, 1, 0], [5, 1, 4]]
    tri = np.array(top + bottom + right + left + front + back)

    x, y, z = transform_xyz(view, jitter, x, y, z)
    ax.plot_trisurf(x, y, triangles=tri, Z=z, color='w', alpha=alpha)

    # (label, location, text direction) for the six face centres.
    face_labels = [
        ('c+', [0, 0, c], [1, 0, 0]),
        ('c-', [0, 0, -c], [0, 0, -1]),
        ('a+', [a, 0, 0], [0, 0, 1]),
        ('a-', [-a, 0, 0], [0, 0, -1]),
        ('b+', [0, b, 0], [-1, 0, 0]),
        ('b-', [0, -b, 0], [-1, 0, 0]),
    ]
    draw_labels(ax, view, jitter, face_labels)
152
def draw_sphere(ax, radius=10., steps=100):
    """
    Draw a white sphere of the given *radius* on the 3D axes *ax*.

    The surface is sampled on a *steps* x *steps* grid of azimuthal and
    polar angles.
    """
    azimuth = np.linspace(0, 2 * np.pi, steps)
    polar = np.linspace(0, np.pi, steps)
    sin_polar = np.sin(polar)

    x = radius * np.outer(np.cos(azimuth), sin_polar)
    y = radius * np.outer(np.sin(azimuth), sin_polar)
    z = radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar))
    ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w')
162
def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian'):
    """
    Draw the dispersion mesh showing the theta-phi orientations at which
    the model will be evaluated.

    *ax* is the 3D axes to draw on.  *view* is the (theta, phi, psi)
    orientation and *jitter* the (dtheta, dphi, dpsi) widths, in degrees;
    dpsi is not used.  The mesh has *n* x *n* points at distance *radius*
    from the origin, coloured by weight.  *dist* is 'gaussian' (points span
    -3 to 3 widths) or 'rectangle' (points span -sqrt(3) to sqrt(3) widths).

    Raises ValueError if *dist* is not one of those two names.
    """
    dtheta, dphi, _ = jitter

    if dist == 'gaussian':
        t = np.linspace(-3, 3, n)
        weights = exp(-0.5*t**2)
    elif dist == 'rectangle':
        # Note: uses sasmodels ridiculous definition of rectangle width
        t = np.linspace(-1, 1, n)*sqrt(3)
        weights = np.ones_like(t)
    else:
        raise ValueError("expected dist to be 'gaussian' or 'rectangle'")

    if SCALED_PHI:
        def scale_phi(theta_i, phi_j):
            """Stretch the phi offset by 1/|cos(theta offset)|."""
            if theta_i != 90:
                return phi_j/abs(cos(radians(theta_i)))
            # cos is zero at exactly 90 degrees, so avoid the division.
            return 0 if phi_j == 0 else 4*pi
        w = np.outer(weights, weights)
    else:
        def scale_phi(theta_i, phi_j):
            """Use the phi offset of the mesh point unchanged."""
            # The original returned the enclosing jitter width for every
            # point because of a misspelled parameter name.
            return phi_j
        w = np.outer(weights*cos(radians(dtheta*t)), weights)

    # (theta offset, scaled phi offset) for every mesh point, theta-major so
    # that the order matches the flattened weight matrix.
    offsets = [(theta_i, scale_phi(theta_i, phi_j))
               for theta_i in dtheta*t
               for phi_j in dphi*t]

    # mesh in theta, phi formed by rotating z
    z = np.matrix([[0], [0], [radius]])
    points = np.hstack([Rx(phi_ij)*Ry(theta_i)*z
                        for theta_i, phi_ij in offsets])
    # select just the active points (i.e., those with |phi| < 180)
    active = np.array([abs(phi_ij) < 180 for theta_i, phi_ij in offsets])
    points = points[:, active]
    w = w.flatten()[active]

    # rotate relative to beam
    points = orient_relative_to_beam(view, points)

    x, y, z = [np.array(row).flatten() for row in points]
    ax.scatter(x, y, z, c=w, marker='o', vmin=0., vmax=1.)
208
def draw_labels(ax, view, jitter, text):
    """
    Draw text labels on the 3D axes *ax*.

    *text* is a sequence of (label, location, orientation) triples, where
    location and orientation are (x, y, z) triples.  Both are sent through
    the *jitter* and *view* transforms before drawing; the transformed
    orientation is passed to the axes as the text *zdir*.
    """
    labels, locations, orientations = zip(*text)

    # Split the triples into coordinate columns and transform them together.
    px, py, pz = transform_xyz(view, jitter, *zip(*locations))
    dx, dy, dz = transform_xyz(view, jitter, *zip(*orientations))

    # TODO: zdir for labels is broken, and labels aren't appearing.
    for index, label in enumerate(labels):
        zdir = np.asarray((dx[index], dy[index], dz[index])).flatten()
        ax.text(px[index], py[index], pz[index], label, zdir=zdir)
224
225# Definition of rotation matrices comes from wikipedia:
226#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
def Rx(angle):
    """
    Return the 3x3 matrix which rotates points about *x* by *angle* degrees.

    The result is a numpy matrix, so ``*`` performs matrix multiplication.
    """
    a = radians(angle)
    c, s = cos(a), sin(a)
    return np.matrix([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])
234
def Ry(angle):
    """
    Return the 3x3 matrix which rotates points about *y* by *angle* degrees.

    The result is a numpy matrix, so ``*`` performs matrix multiplication.
    """
    a = radians(angle)
    c, s = cos(a), sin(a)
    return np.matrix([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
242
def Rz(angle):
    """
    Return the 3x3 matrix which rotates points about *z* by *angle* degrees.

    The result is a numpy matrix, so ``*`` performs matrix multiplication.
    """
    a = radians(angle)
    c, s = cos(a), sin(a)
    return np.matrix([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
250
def transform_xyz(view, jitter, x, y, z):
    """
    Send a set of (x,y,z) points through the jitter and view transforms.

    *x*, *y* and *z* are array-likes of a common shape.  Returns the
    transformed coordinates as three arrays with the shape of *x*.
    """
    x, y, z = np.asarray(x), np.asarray(y), np.asarray(z)
    original_shape = x.shape

    # The rotation helpers operate on a 3 x n matrix of column vectors.
    stacked = np.matrix([x.flatten(), y.flatten(), z.flatten()])
    jittered = apply_jitter(jitter, stacked)
    oriented = orient_relative_to_beam(view, jittered)

    x, y, z = [np.array(row).reshape(original_shape) for row in oriented]
    return x, y, z
262
def apply_jitter(jitter, points):
    """
    Rotate *points* by the jitter angles (dtheta, dphi, dpsi), in degrees.

    The rotation applied is Rx(dphi)*Ry(dtheta)*Rz(dpsi).  Points are stored
    in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    dtheta, dphi, dpsi = jitter
    rotation = Rx(dphi)*Ry(dtheta)*Rz(dpsi)
    return rotation*points
272
def orient_relative_to_beam(view, points):
    """
    Rotate *points* by the view angles (theta, phi, psi), in degrees.

    The rotation applied is Rz(phi)*Ry(theta)*Rz(psi).  Points are stored
    in a 3 x n numpy matrix, not a numpy array or tuple.
    """
    theta, phi, psi = view
    rotation = Rz(phi)*Ry(theta)*Rz(psi)
    return rotation*points
282
# Number of dispersity points to use along (theta, phi, psi), keyed by which
# of the three angles have non-zero jitter.  The trailing comments give the
# total number of orientations evaluated for that number of active angles.
PD_N_TABLE = {
    # no jitter
    (0, 0, 0): (0, 0, 0),     # 0
    # one jittered angle
    (1, 0, 0): (100, 0, 0),   # 100
    (0, 1, 0): (0, 100, 0),
    (0, 0, 1): (0, 0, 100),
    # two jittered angles
    (1, 1, 0): (30, 30, 0),   # 900
    (1, 0, 1): (30, 0, 30),
    (0, 1, 1): (0, 30, 30),
    # all three angles jittered
    (1, 1, 1): (15, 15, 15),  # 3375
}
295
def clipped_range(data, portion=1.0, mode='central'):
    """
    Determine range from data.

    *data* is a numpy array of any shape.  If *portion* is 1, return the
    full (min, max) range regardless of *mode*.  Otherwise the values are
    sorted and clipped according to *mode*:

    * 'central': skip ``portion*len(data)/2`` values (rounded) from each end
      of the sorted data and return the values found there.
    * 'top': skip the lowest ``portion*len(data)`` values (rounded) and
      return the range from there up to the maximum.

    Raises ValueError if *portion* is not 1 and *mode* is not one of
    'central' or 'top'.
    """
    if portion == 1.0:
        return data.min(), data.max()
    if mode not in ('central', 'top'):
        # Previously fell off the end and returned None, which only failed
        # later when the caller unpacked the result.
        raise ValueError("expected mode to be 'central' or 'top'")

    data = np.sort(data.flatten())
    last = len(data) - 1

    if mode == 'central':
        offset = min(int(portion*len(data)/2 + 0.5), last)
        # data[-0] is data[0], so with nothing to trim use the maximum.
        upper = data[-offset] if offset > 0 else data[-1]
        return data[offset], upper

    # mode == 'top'; keep the offset inside the array for portion near 1.
    offset = min(int(portion*len(data) + 0.5), last)
    return data[offset], data[-1]
313
def draw_scattering(calculator, ax, view, jitter, dist='gaussian'):
    """
    Plot the scattering for the particular view.

    *calculator* is returned from :func:`build_model`.  *ax* are the 3D axes
    on which the data will be plotted.  *view* and *jitter* are the current
    orientation and orientation dispersity.  *dist* is one of the sasmodels
    weight distributions.

    The log of the pattern is drawn as filled contours on the plane z=-1.1.
    """
    # Jitter widths are passed through unscaled, i.e. using the sasmodels
    # definition of width for the rectangle distribution (sqrt(3)*width).
    scale = 1
    theta_pd, phi_pd, psi_pd = [scale*width for width in jitter]
    theta, phi, psi = view

    # The number of points per angle depends on which angles are jittered.
    active = (theta_pd > 0, phi_pd > 0, psi_pd > 0)
    theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[active]

    # Orientation parameters first; the stored model parameters take
    # precedence if they use the same names.
    pars = {
        'theta': theta, 'theta_pd': theta_pd,
        'theta_pd_type': dist, 'theta_pd_n': theta_pd_n,
        'phi': phi, 'phi_pd': phi_pd,
        'phi_pd_type': dist, 'phi_pd_n': phi_pd_n,
        'psi': psi, 'psi_pd': psi_pd,
        'psi_pd_type': dist, 'psi_pd_n': psi_pd_n,
    }
    pars.update(calculator.pars)

    # Compute the pattern on the calculator's q mesh and take its log.
    qx, qy = calculator._data.x_bins, calculator._data.y_bins
    Iqxy = np.log(calculator(**pars).reshape(len(qx), len(qy)))

    # Use the fixed limits stored on the calculator when present, otherwise
    # derive the colour range from this pattern.
    if calculator.limits:
        vmin, vmax = calculator.limits
    else:
        vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top')

    # Filled contours on normalized q axes, placed below the shape.
    ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1,
                levels=np.linspace(vmin, vmax, 24))
366
def build_model(model_name, n=150, qmax=0.5, **pars):
    """
    Build a calculator for the given shape.

    *model_name* is any sasmodels model.  *n* and *qmax* define an n x n mesh
    on which to evaluate the model.  The remaining parameters are stored in
    the returned calculator as *calculator.pars*, with *background* defaulting
    to 1e-3 if not given.  They are used by :func:`draw_scattering` to set
    the non-orientation parameters in the calculation.

    Returns a *calculator* function which takes a dictionary of parameters and
    produces Iqxy.  The Iqxy value needs to be reshaped to an n x n matrix
    for plotting.  See the :class:`sasmodels.direct_model.DirectModel` class
    for details.  The colour limits computed for orientation (0, 0, 0) are
    stored as *calculator.limits*.
    """
    # Alias the sasmodels builder so it does not shadow this function's name.
    from sasmodels.core import load_model_info, build_model as build_kernel
    from sasmodels.data import empty_data2D
    from sasmodels.direct_model import DirectModel

    model_info = load_model_info(model_name)
    model = build_kernel(model_info) #, dtype='double!')
    q = np.linspace(-qmax, qmax, n)
    data = empty_data2D(q, q)
    calculator = DirectModel(data, model)

    # stuff the values for non-orientation parameters into the calculator
    calculator.pars = pars.copy()
    # Was misspelled 'backgound', so the default never reached the model.
    calculator.pars.setdefault('background', 1e-3)

    # fix the data limits so that we can see if the pattern fades
    # under rotation or angular dispersion
    Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars)
    Iqxy = np.log(Iqxy)
    vmin, vmax = clipped_range(Iqxy, 0.95, mode='top')
    calculator.limits = vmin, vmax+1

    return calculator
404
def select_calculator(model_name, n=150, size=(10,40,100)):
    """
    Create a model calculator for the given shape.

    *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid,
    parallelepiped or bcc_paracrystal. *n* is the number of points to use
    in the q range.  *qmax* is set separately for each model; models which
    do not set it use the :func:`build_model` default.

    Returns *calculator* and tuple *size* (a,b,c) giving minor and major
    equatorial axes and polar axis respectively, after collapsing the axes
    which the chosen model treats as equal.  See :func:`build_model`
    for details on the returned calculator.

    Raises ValueError if *model_name* is not one of the listed models.
    """
    a, b, c = size

    # Pick the model parameters and collapse the axes the model cannot
    # distinguish; the model itself is built once at the end.
    if model_name == 'sphere':
        a = b = c
        model_pars = dict(radius=c)
    elif model_name == 'bcc_paracrystal':
        a = b = c
        model_pars = dict(dnn=c, d_factor=0.06, radius=40)
    elif model_name == 'cylinder':
        a = b
        model_pars = dict(qmax=0.3, radius=b, length=c)
    elif model_name == 'ellipsoid':
        a = b
        model_pars = dict(qmax=1.0, radius_polar=c, radius_equatorial=b)
    elif model_name == 'triaxial_ellipsoid':
        model_pars = dict(qmax=0.5,
                          radius_equat_minor=a,
                          radius_equat_major=b,
                          radius_polar=c)
    elif model_name == 'parallelepiped':
        model_pars = dict(a=a, b=b, c=c)
    else:
        raise ValueError("unknown model %s"%model_name)

    calculator = build_model(model_name, n=n, **model_pars)
    return calculator, (a, b, c)
444
def main(model_name='parallelepiped', size=(10, 40, 100)):
    """
    Show an interactive orientation and jitter demo.

    *model_name* is one of the models available in :func:`select_calculator`.
    *size* is the (a, b, c) shape passed on to it.

    Opens a matplotlib window with sliders for the view angles (theta, phi,
    psi) and the jitter widths (dtheta, dphi, dpsi), and blocks until the
    window is closed.
    """
    # Imported locally; only needed here to check the installed version.
    import matplotlib

    # set up calculator
    calculator, size = select_calculator(model_name, n=150, size=size)

    ## Clearing the limits gives every view its own colour range.  Comment
    ## out this line to keep the colour range fixed for all views.
    calculator.limits = None

    ## use gaussian distribution unless testing integration
    #dist = 'rectangle'
    dist = 'gaussian'

    ## initial view
    #theta, dtheta = 70., 10.
    #phi, dphi = -45., 3.
    #psi, dpsi = -45., 3.
    theta, phi, psi = 0, 0, 0
    dtheta, dphi, dpsi = 0, 0, 0

    ## create the plot window
    plt.set_cmap('gist_earth')
    plt.clf()
    ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d')
    try:  # CRUFT: not all versions of matplotlib accept 'square' 3d projection
        ax.axis('square')
    except Exception:
        pass

    axcolor = 'lightgoldenrodyellow'
    # CRUFT: matplotlib 2.0 renamed the axes background keyword from
    # 'axisbg' to 'facecolor' and later removed 'axisbg' entirely.
    mpl_major = int(matplotlib.__version__.split('.')[0])
    props = {('facecolor' if mpl_major >= 2 else 'axisbg'): axcolor}

    ## add control widgets to plot
    axtheta = plt.axes([0.1, 0.15, 0.45, 0.04], **props)
    axphi = plt.axes([0.1, 0.1, 0.45, 0.04], **props)
    axpsi = plt.axes([0.1, 0.05, 0.45, 0.04], **props)
    stheta = Slider(axtheta, 'Theta', -90, 90, valinit=theta)
    sphi = Slider(axphi, 'Phi', -180, 180, valinit=phi)
    spsi = Slider(axpsi, 'Psi', -180, 180, valinit=psi)

    axdtheta = plt.axes([0.75, 0.15, 0.15, 0.04], **props)
    axdphi = plt.axes([0.75, 0.1, 0.15, 0.04], **props)
    axdpsi = plt.axes([0.75, 0.05, 0.15, 0.04], **props)
    # Note: using ridiculous definition of rectangle distribution, whose width
    # in sasmodels is sqrt(3) times the given width.  Divide by sqrt(3) to keep
    # the maximum width to 90.
    dlimit = 30 if dist == 'gaussian' else 90/sqrt(3)
    sdtheta = Slider(axdtheta, 'dTheta', 0, dlimit, valinit=dtheta)
    sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi)
    sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi)

    ## callback to draw the new view
    def update(val, axis=None):
        """Redraw everything from the current slider values."""
        view = stheta.val, sphi.val, spsi.val
        jitter = sdtheta.val, sdphi.val, sdpsi.val
        # set small jitter as 0 if multiple pd dims
        dims = sum(v > 0 for v in jitter)
        limit = [0, 0, 2, 5][dims]
        jitter = [0 if v < limit else v for v in jitter]
        ax.cla()
        draw_beam(ax, (0, 0))
        draw_jitter(ax, view, jitter, dist=dist, size=size)
        #draw_jitter(ax, view, (0,0,0))
        draw_mesh(ax, view, jitter, dist=dist)
        draw_scattering(calculator, ax, view, jitter, dist=dist)
        plt.gcf().canvas.draw()

    ## bind control widgets to view updater
    stheta.on_changed(lambda v: update(v, 'theta'))
    sphi.on_changed(lambda v: update(v, 'phi'))
    spsi.on_changed(lambda v: update(v, 'psi'))
    sdtheta.on_changed(lambda v: update(v, 'dtheta'))
    sdphi.on_changed(lambda v: update(v, 'dphi'))
    sdpsi.on_changed(lambda v: update(v, 'dpsi'))

    ## initialize view
    update(None, 'phi')

    ## go interactive
    plt.show()
531
if __name__ == "__main__":
    # Usage: jitter.py [model_name [a,b,c]]
    # Defaults apply for any argument which is not given.
    model_name = 'parallelepiped'
    size = (10, 40, 100)
    if len(sys.argv) > 1:
        model_name = sys.argv[1]
    if len(sys.argv) > 2:
        size = tuple(int(v) for v in sys.argv[2].split(','))
    main(model_name, size)
Note: See TracBrowser for help on using the repository browser.