source:sasmodels/explore/realspace.py@5778279

ticket_1156
Last change on this file since 5778279 was 5778279, checked in by Paul Kienzle <pkienzle@…>, 18 months ago

faster lattice sampling for realspace code

• Property mode set to `100644`
File size: 39.4 KB
Line
1from __future__ import division, print_function
2
3import time
4from copy import copy
5import os
6import argparse
7from collections import OrderedDict
8
9import numpy as np
10from numpy import pi, radians, sin, cos, sqrt
11from numpy.random import poisson, uniform, randn, rand, randint
12from numpy.polynomial.legendre import leggauss
13from scipy.integrate import simps
14from scipy.special import j1 as J1
15
16try:
17    import numba
18    USE_NUMBA = True
19except ImportError:
20    USE_NUMBA = False
21
22# Definition of rotation matrices comes from wikipedia:
23#    https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
24def Rx(angle):
25    """Construct a matrix to rotate points about *x* by *angle* degrees."""
27    R = [[1, 0, 0],
28         [0, +cos(a), -sin(a)],
29         [0, +sin(a), +cos(a)]]
30    return np.matrix(R)
31
32def Ry(angle):
33    """Construct a matrix to rotate points about *y* by *angle* degrees."""
35    R = [[+cos(a), 0, +sin(a)],
36         [0, 1, 0],
37         [-sin(a), 0, +cos(a)]]
38    return np.matrix(R)
39
40def Rz(angle):
41    """Construct a matrix to rotate points about *z* by *angle* degrees."""
43    R = [[+cos(a), -sin(a), 0],
44         [+sin(a), +cos(a), 0],
45         [0, 0, 1]]
46    return np.matrix(R)
47
48def rotation(theta, phi, psi):
49    """
50    Apply the jitter transform to a set of points.
51
52    Points are stored in a 3 x n numpy matrix, not a numpy array or tuple.
53    """
54    return Rx(phi)*Ry(theta)*Rz(psi)
55
56def apply_view(points, view):
57    """
58    Apply the view transform (theta, phi, psi) to a set of points.
59
60    Points are stored in a 3 x n numpy array.
61
62    View angles are in degrees.
63    """
64    theta, phi, psi = view
65    return np.asarray((Rz(phi)*Ry(theta)*Rz(psi))*np.matrix(points.T)).T
66
67
68def invert_view(qx, qy, view):
69    """
70    Return (qa, qb, qc) for the (theta, phi, psi) view angle at detector
71    pixel (qx, qy).
72
73    View angles are in degrees.
74    """
75    theta, phi, psi = view
76    q = np.vstack((qx.flatten(), qy.flatten(), 0*qx.flatten()))
77    return np.asarray((Rz(-psi)*Ry(-theta)*Rz(-phi))*np.matrix(q))
78
79
80I3 = np.matrix([[1., 0, 0], [0, 1, 0], [0, 0, 1]])
81
82class Shape:
83    rotation = I3
84    center = np.array([0., 0., 0.])[:, None]
85    r_max = None
86    lattice_size = np.array((1, 1, 1))
87    lattice_spacing = np.array((1., 1., 1.))
88    lattice_distortion = 0.0
89    lattice_rotation = 0.0
90    lattice_type = ""
91
92    def volume(self):
93        # type: () -> float
94        raise NotImplementedError()
95
96    def sample(self, density):
97        # type: (float) -> np.ndarray[N], np.ndarray[N, 3]
98        raise NotImplementedError()
99
100    def dims(self):
101        # type: () -> float, float, float
102        raise NotImplementedError()
103
104    def rotate(self, theta, phi, psi):
105        if theta != 0. or phi != 0. or psi != 0.:
106            self.rotation = rotation(theta, phi, psi) * self.rotation
107        return self
108
109    def shift(self, x, y, z):
110        self.center = self.center + np.array([x, y, z])[:, None]
111        return self
112
113    def lattice(self, size=(1, 1, 1), spacing=(2, 2, 2), type="sc",
114                distortion=0.0, rotation=0.0):
115        self.lattice_size = np.asarray(size, 'i')
116        self.lattice_spacing = np.asarray(spacing, 'd')
117        self.lattice_type = type
118        self.lattice_distortion = distortion
119        self.lattice_rotation = rotation
120
122        if self.rotation is I3:
123            points = points.T + self.center
124        else:
125            points = np.asarray(self.rotation * np.matrix(points.T)) + self.center
126        if self.lattice_type:
127            points = self._apply_lattice(points)
128        return points.T
129
130    def r_bins(self, q, over_sampling=10, r_step=0.):
131        if self.lattice_type:
132            r_max = np.sqrt(np.sum(self.lattice_size*self.lattice_spacing*self.dims)**2)/2
133        else:
134            r_max = self.r_max
135        #r_max = min(2 * pi / q[0], r_max)
136        if r_step == 0.:
137            r_step = 2 * pi / q[-1] / over_sampling
138        #r_step = 0.01
139        return np.arange(r_step, r_max, r_step)
140
141    def _apply_lattice(self, points):
142        """Spread points to different lattice positions"""
143        size = self.lattice_size
144        spacing = self.lattice_spacing
145        shuffle = self.lattice_distortion
146        rotate = self.lattice_rotation
147        lattice = self.lattice_type
148
149        if rotate != 0:
150            # To vectorize the rotations we will need to unwrap the matrix multiply
151            raise NotImplementedError("don't handle rotations yet")
152
153        # Determine the number of lattice points in the lattice
154        shapes_per_cell = 2 if lattice == "bcc" else 4 if lattice == "fcc" else 1
155        number_of_lattice_points = np.prod(size) * shapes_per_cell
156
157        # For each point in the original shape, figure out which lattice point
158        # to translate it to.  This is both cell index (i*ny*nz + j*nz  + k) as
159        # well as the point in the cell (corner, body center or face center).
160        nsamples = points.shape[1]
161        lattice_point = randint(number_of_lattice_points, size=nsamples)
162
163        # Translate the cell index into the i,j,k coordinates of the senter
164        cell_index = lattice_point // shapes_per_cell
165        center = np.vstack((cell_index//(size[1]*size[2]),
166                            (cell_index%(size[1]*size[2]))//size[2],
167                            cell_index%size[2]))
168        center = np.asarray(center, dtype='d')
169        if lattice == "bcc":
170            center[:, lattice_point % shapes_per_cell == 1] += [[0.5], [0.5], [0.5]]
171        elif lattice == "fcc":
172            center[:, lattice_point % shapes_per_cell == 1] += [[0.0], [0.5], [0.5]]
173            center[:, lattice_point % shapes_per_cell == 2] += [[0.5], [0.0], [0.5]]
174            center[:, lattice_point % shapes_per_cell == 3] += [[0.5], [0.5], [0.0]]
175
176        # Each lattice point has its own displacement from the ideal position.
177        # Not checking that shapes do not overlap if displacement is too large.
178        offset = shuffle*(randn(3, number_of_lattice_points) if shuffle < 0.3
179                          else rand(3, number_of_lattice_points))
180        center += offset[:, cell_index]
181
182        # Each lattice point has its own rotation.  Rotate the point prior to
183        # applying any displacement.
184        # rotation = rotate*(randn(size=(shapes, 3)) if shuffle < 30 else rand(size=(nsamples, 3)))
185        # for k in shapes: points[k] = rotation[k]*points[k]
186        points += center*(np.array([spacing])*np.array(self.dims)).T
187        return points
188
189class Composite(Shape):
190    def __init__(self, shapes, center=(0, 0, 0), orientation=(0, 0, 0)):
191        self.shapes = shapes
192        self.rotate(*orientation)
193        self.shift(*center)
194
195        # Find the worst case distance between any two points amongst a set
196        # of shapes independent of orientation.  This could easily be a
197        # factor of two worse than necessary, e.g., a pair of thin rods
198        # end-to-end vs the same pair side-by-side.
199        distances = [((s1.r_max + s2.r_max)/2
200                      + sqrt(np.sum((s1.center - s2.center)**2)))
201                     for s1 in shapes
202                     for s2 in shapes]
203        self.r_max = max(distances + [s.r_max for s in shapes])
204        self.volume = sum(shape.volume for shape in self.shapes)
205
206    def sample(self, density):
207        values, points = zip(*(shape.sample(density) for shape in self.shapes))
209
210class Box(Shape):
211    def __init__(self, a, b, c,
212                 value, center=(0, 0, 0), orientation=(0, 0, 0)):
213        self.value = np.asarray(value)
214        self.rotate(*orientation)
215        self.shift(*center)
216        self.a, self.b, self.c = a, b, c
217        self._scale = np.array([a/2, b/2, c/2])[None, :]
218        self.r_max = sqrt(a**2 + b**2 + c**2)
219        self.dims = a, b, c
220        self.volume = a*b*c
221
222    def sample(self, density):
223        num_points = poisson(density*self.volume)
224        points = self._scale*uniform(-1, 1, size=(num_points, 3))
225        values = self.value.repeat(points.shape[0])
227
228class EllipticalCylinder(Shape):
229    def __init__(self, ra, rb, length,
230                 value, center=(0, 0, 0), orientation=(0, 0, 0)):
231        self.value = np.asarray(value)
232        self.rotate(*orientation)
233        self.shift(*center)
234        self.ra, self.rb, self.length = ra, rb, length
235        self._scale = np.array([ra, rb, length/2])[None, :]
236        self.r_max = sqrt(4*max(ra, rb)**2 + length**2)
237        self.dims = 2*ra, 2*rb, length
238        self.volume = pi*ra*rb*length
239
240    def sample(self, density):
241        # randomly sample from a box of side length 2*r, excluding anything
242        # not in the cylinder
243        num_points = poisson(density*4*self.ra*self.rb*self.length)
244        points = uniform(-1, 1, size=(num_points, 3))
245        radius = points[:, 0]**2 + points[:, 1]**2
246        points = points[radius <= 1]
247        values = self.value.repeat(points.shape[0])
249
250class EllipticalBicelle(Shape):
251    def __init__(self, ra, rb, length,
252                 thick_rim, thick_face,
253                 value_core, value_rim, value_face,
254                 center=(0, 0, 0), orientation=(0, 0, 0)):
255        self.rotate(*orientation)
256        self.shift(*center)
257        self.value = value_core
258        self.ra, self.rb, self.length = ra, rb, length
259        self.thick_rim, self.thick_face = thick_rim, thick_face
260        self.value_rim, self.value_face = value_rim, value_face
261
262        # reset cylinder to outer dimensions for calculating scale, etc.
263        ra = self.ra + self.thick_rim
264        rb = self.rb + self.thick_rim
265        length = self.length + 2*self.thick_face
266        self._scale = np.array([ra, rb, length/2])[None, :]
267        self.r_max = sqrt(4*max(ra, rb)**2 + length**2)
268        self.dims = 2*ra, 2*rb, length
269        self.volume = pi*ra*rb*length
270
271    def sample(self, density):
272        # randomly sample from a box of side length 2*r, excluding anything
273        # not in the cylinder
274        ra = self.ra + self.thick_rim
275        rb = self.rb + self.thick_rim
276        length = self.length + 2*self.thick_face
277        num_points = poisson(density*4*ra*rb*length)
278        points = uniform(-1, 1, size=(num_points, 3))
279        radius = points[:, 0]**2 + points[:, 1]**2
280        points = points[radius <= 1]
281        # set all to core value first
282        values = np.ones_like(points[:, 0])*self.value
283        # then set value to face value if |z| > face/(length/2))
284        values[abs(points[:, 2]) > self.length/(self.length + 2*self.thick_face)] = self.value_face
285        # finally set value to rim value if outside the core ellipse
286        radius = (points[:, 0]**2*(1 + self.thick_rim/self.ra)**2
287                  + points[:, 1]**2*(1 + self.thick_rim/self.rb)**2)
290
291class TriaxialEllipsoid(Shape):
292    def __init__(self, ra, rb, rc,
293                 value, center=(0, 0, 0), orientation=(0, 0, 0)):
294        self.value = np.asarray(value)
295        self.rotate(*orientation)
296        self.shift(*center)
297        self.ra, self.rb, self.rc = ra, rb, rc
298        self._scale = np.array([ra, rb, rc])[None, :]
299        self.r_max = 2*max(ra, rb, rc)
300        self.dims = 2*ra, 2*rb, 2*rc
301        self.volume = 4*pi/3 * ra * rb * rc
302
303    def sample(self, density):
304        # randomly sample from a box of side length 2*r, excluding anything
305        # not in the ellipsoid
306        num_points = poisson(density*8*self.ra*self.rb*self.rc)
307        points = uniform(-1, 1, size=(num_points, 3))
309        points = self._scale*points[radius <= 1]
310        values = self.value.repeat(points.shape[0])
312
313class Helix(Shape):
315                 value, center=(0, 0, 0), orientation=(0, 0, 0)):
316        self.value = np.asarray(value)
317        self.rotate(*orientation)
318        self.shift(*center)
319        helix_length = helix_pitch * tube_length/sqrt(helix_radius**2 + helix_pitch**2)
325        # small tube radius approximation; for larger tubes need to account
326        # for the fact that the inner length is much shorter than the outer
327        # length
329
330    def points(self, density):
332        points = uniform(-1, 1, size=(num_points, 3))
333        radius = points[:, 0]**2 + points[:, 1]**2
334        points = points[radius <= 1]
335
336        # Based on math stackexchange answer by Jyrki Lahtonen
337        #     https://math.stackexchange.com/a/461637
338        # with helix along z rather than x [so tuples in answer are (z, x, y)]
339        # and with random points in the cross section (p1, p2) rather than
340        # uniform points on the surface (cos u, sin u).
342        h = self.helix_pitch
343        scale = 1/sqrt(R**2 + h**2)
344        t = points[:, 3] * (self.tube_length * scale/2)
345        cos_t, sin_t = cos(t), sin(t)
346
347        # rx = R*cos_t
348        # ry = R*sin_t
349        # rz = h*t
350        # nx = -a * cos_t * points[:, 1]
351        # ny = -a * sin_t * points[:, 1]
352        # nz = 0
353        # bx = (a * h/scale) * sin_t * points[:, 2]
354        # by = (-a * h/scale) * cos_t * points[:, 2]
355        # bz = a*R/scale
356        # x = rx + nx + bx
357        # y = ry + ny + by
358        # z = rz + nz + bz
359        u, v = (R - a*points[:, 1]), (a * h/scale)*points[:, 2]
360        x = u * cos_t + v * sin_t
361        y = u * sin_t - v * cos_t
362        z = a*R/scale + h * t
363
364        points = np.hstack((x, y, z))
365        values = self.value.repeat(points.shape[0])
367
368def csbox(a=10, b=20, c=30, da=1, db=2, dc=3, slda=1, sldb=2, sldc=3, sld_core=4):
369    core = Box(a, b, c, sld_core)
370    side_a = Box(da, b, c, slda, center=((a+da)/2, 0, 0))
371    side_b = Box(a, db, c, sldb, center=(0, (b+db)/2, 0))
372    side_c = Box(a, b, dc, sldc, center=(0, 0, (c+dc)/2))
373    side_a2 = copy(side_a).shift(-a-da, 0, 0)
374    side_b2 = copy(side_b).shift(0, -b-db, 0)
375    side_c2 = copy(side_c).shift(0, 0, -c-dc)
376    shape = Composite((core, side_a, side_b, side_c, side_a2, side_b2, side_c2))
377    shape.dims = 2*da+a, 2*db+b, 2*dc+c
378    return shape
379
380def _Iqxy(values, x, y, z, qa, qb, qc):
381    """I(q) = |sum V(r) rho(r) e^(1j q.r)|^2 / sum V(r)"""
382    Iq = [abs(np.sum(values*np.exp(1j*(qa_k*x + qb_k*y + qc_k*z))))**2
383            for qa_k, qb_k, qc_k in zip(qa.flat, qb.flat, qc.flat)]
384    return Iq
385
386if USE_NUMBA:
387    # Override simple numpy solution with numba if available
388    from numba import njit
389    @njit("f8[:](f8[:],f8[:],f8[:],f8[:],f8[:],f8[:],f8[:])")
390    def _Iqxy(values, x, y, z, qa, qb, qc):
391        Iq = np.zeros_like(qa)
392        for j in range(len(Iq)):
393            total = 0. + 0j
394            for k in range(len(values)):
395                total += values[k]*np.exp(1j*(qa[j]*x[k] + qb[j]*y[k] + qc[j]*z[k]))
396            Iq[j] = abs(total)**2
397        return Iq
398
399def calc_Iqxy(qx, qy, rho, points, volume=1.0, view=(0, 0, 0)):
400    qx, qy = np.broadcast_arrays(qx, qy)
401    qa, qb, qc = invert_view(qx, qy, view)
402    rho, volume = np.broadcast_arrays(rho, volume)
403    values = rho*volume
404    x, y, z = points.T
405    values, x, y, z, qa, qb, qc = [np.asarray(v, 'd')
406                                   for v in (values, x, y, z, qa, qb, qc)]
407
408    # I(q) = |sum V(r) rho(r) e^(1j q.r)|^2 / sum V(r)
409    Iq = _Iqxy(values, x, y, z, qa.flatten(), qb.flatten(), qc.flatten())
410    return np.asarray(Iq).reshape(qx.shape) / np.sum(volume)
411
412def _calc_Pr_nonuniform(r, rho, points):
413    # Make Pr a little be bigger than necessary so that only distances
414    # min < d < max end up in Pr
415    n_max = len(r)+1
416    extended_Pr = np.zeros(n_max+1, 'd')
417    # r refers to bin centers; find corresponding bin edges
418    bins = bin_edges(r)
419    t_next = time.clock() + 3
420    for k, rho_k in enumerate(rho[:-1]):
421        distance = sqrt(np.sum((points[k] - points[k+1:])**2, axis=1))
422        weights = rho_k * rho[k+1:]
423        index = np.searchsorted(bins, distance)
424        # Note: indices may be duplicated, so "Pr[index] += w" will not work!!
425        extended_Pr += np.bincount(index, weights, n_max+1)
426        t = time.clock()
427        if t > t_next:
428            t_next = t + 3
429            print("processing %d of %d"%(k, len(rho)-1))
430    Pr = extended_Pr[1:-1]
431    return Pr
432
433def _calc_Pr_uniform(r, rho, points):
434    # Make Pr a little be bigger than necessary so that only distances
435    # min < d < max end up in Pr
436    dr, n_max = r[0], len(r)
437    extended_Pr = np.zeros(n_max+1, 'd')
438    t0 = time.clock()
439    t_next = t0 + 3
440    for k, rho_k in enumerate(rho[:-1]):
441        distances = sqrt(np.sum((points[k] - points[k+1:])**2, axis=1))
442        weights = rho_k * rho[k+1:]
443        index = np.minimum(np.asarray(distances/dr, 'i'), n_max)
444        # Note: indices may be duplicated, so "Pr[index] += w" will not work!!
445        extended_Pr += np.bincount(index, weights, n_max+1)
446        t = time.clock()
447        if t > t_next:
448            t_next = t + 3
449            print("processing %d of %d"%(k, len(rho)-1))
450    #print("time py:", time.clock() - t0)
451    Pr = extended_Pr[:-1]
452    # Make Pr independent of sampling density.  The factor of 2 comes because
453    # we are only accumulating the upper triangular distances.
454    #Pr = Pr * 2 / len(rho)**2
455    return Pr
456
457    # Can get an additional 2x by going to C.  Cuda/OpenCL will allow even
458    # more speedup, though still bounded by the n^2 cose.
459    """
460void pdfcalc(int n, const double *pts, const double *rho,
461             int nPr, double *Pr, double rstep)
462{
463  int i,j;
464
465  for (i=0; i<n-2; i++) {
466    for (j=i+1; j<=n-1; j++) {
467      const double dxx=pts[3*i]-pts[3*j];
468      const double dyy=pts[3*i+1]-pts[3*j+1];
469      const double dzz=pts[3*i+2]-pts[3*j+2];
470      const double d=sqrt(dxx*dxx+dyy*dyy+dzz*dzz);
471      const int k=rint(d/rstep);
472      if (k < nPr) Pr[k]+=rho[i]*rho[j];
473    }
474  }
475}
476"""
477
478if USE_NUMBA:
479    # Override simple numpy solution with numba if available
480    @njit("f8[:](f8[:], f8[:], f8[:,:])")
481    def _calc_Pr_uniform(r, rho, points):
482        dr = r[0]
483        n_max = len(r)
484        Pr = np.zeros_like(r)
485        for j in range(len(rho) - 1):
486            x, y, z = points[j, 0], points[j, 1], points[j, 2]
487            for k in range(j+1, len(rho)):
488                distance = sqrt((x - points[k, 0])**2
489                                + (y - points[k, 1])**2
490                                + (z - points[k, 2])**2)
491                index = int(distance/dr)
492                if index < n_max:
493                    Pr[index] += rho[j] * rho[k]
494        # Make Pr independent of sampling density.  The factor of 2 comes because
495        # we are only accumulating the upper triangular distances.
496        #Pr = Pr * 2 / len(rho)**2
497        return Pr
498
499
500def calc_Pr(r, rho, points):
501    # P(r) with uniform steps in r is 3x faster; check if we are uniform
502    # before continuing
503    r, rho, points = [np.asarray(v, 'd') for v in (r, rho, points)]
504    if np.max(np.abs(np.diff(r) - r[0])) > r[0]*0.01:
505        Pr = _calc_Pr_nonuniform(r, rho, points)
506    else:
507        Pr = _calc_Pr_uniform(r, rho, points)
508    return Pr / Pr.max()
509
510
511def j0(x):
512    return np.sinc(x/np.pi)
513
514def calc_Iq(q, r, Pr):
515    Iq = np.array([simps(Pr * j0(qk*r), r) for qk in q])
516    #Iq = np.array([np.trapz(Pr * j0(qk*r), r) for qk in q])
517    Iq /= Iq[0]
518    return Iq
519
520# NOTE: copied from sasmodels/resolution.py
521def bin_edges(x):
522    """
523    Determine bin edges from bin centers, assuming that edges are centered
524    between the bins.
525
526    Note: this uses the arithmetic mean, which may not be appropriate for
527    log-scaled data.
528    """
529    if len(x) < 2 or (np.diff(x) < 0).any():
530        raise ValueError("Expected bins to be an increasing set")
531    edges = np.hstack([
532        x[0]  - 0.5*(x[1]  - x[0]),  # first point minus half first interval
533        0.5*(x[1:] + x[:-1]),        # mid points of all central intervals
534        x[-1] + 0.5*(x[-1] - x[-2]), # last point plus half last interval
535        ])
536    return edges
537
538# -------------- plotters ----------------
539def plot_calc(r, Pr, q, Iq, theory=None, title=None):
540    import matplotlib.pyplot as plt
541    plt.subplot(211)
542    plt.plot(r, Pr, '-', label="Pr")
543    plt.xlabel('r (A)')
544    plt.ylabel('Pr (1/A^2)')
545    if title is not None:
546        plt.title(title)
547    plt.subplot(212)
548    plt.loglog(q, Iq, '-', label='from Pr')
549    plt.xlabel('q (1/A')
550    plt.ylabel('Iq')
551    if theory is not None:
552        plt.loglog(theory[0], theory[1]/theory[1][0], '-', label='analytic')
553        plt.legend()
554
555def plot_calc_2d(qx, qy, Iqxy, theory=None, title=None):
556    import matplotlib.pyplot as plt
557    qx, qy = bin_edges(qx), bin_edges(qy)
558    #qx, qy = np.meshgrid(qx, qy)
559    if theory is not None:
560        plt.subplot(121)
561    #plt.pcolor(qx, qy, np.log10(Iqxy))
562    extent = [qx[0], qx[-1], qy[0], qy[-1]]
563    plt.imshow(np.log10(Iqxy), extent=extent, interpolation="nearest",
564               origin='lower')
565    plt.xlabel('qx (1/A)')
566    plt.ylabel('qy (1/A)')
567    plt.axis('equal')
568    plt.axis(extent)
569    #plt.grid(True)
570    if title is not None:
571        plt.title(title)
572    if theory is not None:
573        plt.subplot(122)
574        plt.imshow(np.log10(theory), extent=extent, interpolation="nearest",
575                   origin='lower')
576        plt.axis('equal')
577        plt.axis(extent)
578        plt.xlabel('qx (1/A)')
579
580def plot_points(rho, points):
581    import mpl_toolkits.mplot3d
582    import matplotlib.pyplot as plt
583
584    ax = plt.axes(projection='3d')
585    try:
586        ax.axis('square')
587    except Exception:
588        pass
589    n = len(points)
590    #print("len points", n)
591    index = np.random.choice(n, size=500) if n > 500 else slice(None, None)
592    ax.scatter(points[index, 0], points[index, 1], points[index, 2], c=rho[index])
593    # make square axes
594    minmax = np.array([points.min(), points.max()])
595    ax.scatter(minmax, minmax, minmax, c='w')
596    #low, high = points.min(axis=0), points.max(axis=0)
597    #ax.axis([low[0], high[0], low[1], high[1], low[2], high[2]])
598    ax.set_xlabel("x")
599    ax.set_ylabel("y")
600    ax.set_zlabel("z")
601    ax.autoscale(True)
602
603# ----------- Analytic models --------------
604def sas_sinx_x(x):
605    with np.errstate(all='ignore'):
606        retvalue = sin(x)/x
607    retvalue[x == 0.] = 1.
608    return retvalue
609
610def sas_2J1x_x(x):
611    with np.errstate(all='ignore'):
612        retvalue = 2*J1(x)/x
613    retvalue[x == 0] = 1.
614    return retvalue
615
616def sas_3j1x_x(x):
617    """return 3*j1(x)/x"""
618    with np.errstate(all='ignore'):
619        retvalue = 3*(sin(x) - x*cos(x))/x**3
620    retvalue[x == 0.] = 1.
621    return retvalue
622
624    z, w = leggauss(76)
625    cos_alpha = (z+1)/2
626    sin_alpha = sqrt(1.0 - cos_alpha**2)
627    Iq = np.empty_like(q)
628    for k, qk in enumerate(q):
629        qab, qc = qk*sin_alpha, qk*cos_alpha
630        Fq = sas_2J1x_x(qab*radius) * sas_sinx_x(qc*length/2)
631        Iq[k] = np.sum(w*Fq**2)
632    Iq = Iq
633    return Iq
634
635def cylinder_Iqxy(qx, qy, radius, length, view=(0, 0, 0)):
636    qa, qb, qc = invert_view(qx, qy, view)
637    qab = sqrt(qa**2 + qb**2)
638    Fq = sas_2J1x_x(qab*radius) * sas_sinx_x(qc*length/2)
639    Iq = Fq**2
640    return Iq.reshape(qx.shape)
641
644    return Iq
645
646def box_Iq(q, a, b, c):
647    z, w = leggauss(76)
648    outer_sum = np.zeros_like(q)
649    for cos_alpha, outer_w in zip((z+1)/2, w):
650        sin_alpha = sqrt(1.0-cos_alpha*cos_alpha)
651        qc = q*cos_alpha
652        siC = c*sas_sinx_x(c*qc/2)
653        inner_sum = np.zeros_like(q)
654        for beta, inner_w in zip((z + 1)*pi/4, w):
655            qa, qb = q*sin_alpha*sin(beta), q*sin_alpha*cos(beta)
656            siA = a*sas_sinx_x(a*qa/2)
657            siB = b*sas_sinx_x(b*qb/2)
658            Fq = siA*siB*siC
659            inner_sum += inner_w * Fq**2
660        outer_sum += outer_w * inner_sum
661    Iq = outer_sum / 4  # = outer*um*zm*8.0/(4.0*M_PI)
662    return Iq
663
664def box_Iqxy(qx, qy, a, b, c, view=(0, 0, 0)):
665    qa, qb, qc = invert_view(qx, qy, view)
666    sia = sas_sinx_x(qa*a/2)
667    sib = sas_sinx_x(qb*b/2)
668    sic = sas_sinx_x(qc*c/2)
669    Fq = sia*sib*sic
670    Iq = Fq**2
671    return Iq.reshape(qx.shape)
672
673def csbox_Iq(q, a, b, c, da, db, dc, slda, sldb, sldc, sld_core):
674    z, w = leggauss(76)
675
676    sld_solvent = 0
677    overlapping = False
678    dr0 = sld_core - sld_solvent
679    drA, drB, drC = slda-sld_solvent, sldb-sld_solvent, sldc-sld_solvent
680    tA, tB, tC = a + 2*da, b + 2*db, c + 2*dc
681
682    outer_sum = np.zeros_like(q)
683    for cos_alpha, outer_w in zip((z+1)/2, w):
684        sin_alpha = sqrt(1.0-cos_alpha*cos_alpha)
685        qc = q*cos_alpha
686        siC = c*sas_sinx_x(c*qc/2)
687        siCt = tC*sas_sinx_x(tC*qc/2)
688        inner_sum = np.zeros_like(q)
689        for beta, inner_w in zip((z + 1)*pi/4, w):
690            qa, qb = q*sin_alpha*sin(beta), q*sin_alpha*cos(beta)
691            siA = a*sas_sinx_x(a*qa/2)
692            siB = b*sas_sinx_x(b*qb/2)
693            siAt = tA*sas_sinx_x(tA*qa/2)
694            siBt = tB*sas_sinx_x(tB*qb/2)
695            if overlapping:
696                Fq = (dr0*siA*siB*siC
697                      + drA*(siAt-siA)*siB*siC
698                      + drB*siAt*(siBt-siB)*siC
699                      + drC*siAt*siBt*(siCt-siC))
700            else:
701                Fq = (dr0*siA*siB*siC
702                      + drA*(siAt-siA)*siB*siC
703                      + drB*siA*(siBt-siB)*siC
704                      + drC*siA*siB*(siCt-siC))
705            inner_sum += inner_w * Fq**2
706        outer_sum += outer_w * inner_sum
707    Iq = outer_sum / 4  # = outer*um*zm*8.0/(4.0*M_PI)
708    return Iq/Iq[0]
709
710def csbox_Iqxy(qx, qy, a, b, c, da, db, dc, slda, sldb, sldc, sld_core, view=(0,0,0)):
711    qa, qb, qc = invert_view(qx, qy, view)
712
713    sld_solvent = 0
714    overlapping = False
715    dr0 = sld_core - sld_solvent
716    drA, drB, drC = slda-sld_solvent, sldb-sld_solvent, sldc-sld_solvent
717    tA, tB, tC = a + 2*da, b + 2*db, c + 2*dc
718    siA = a*sas_sinx_x(a*qa/2)
719    siB = b*sas_sinx_x(b*qb/2)
720    siC = c*sas_sinx_x(c*qc/2)
721    siAt = tA*sas_sinx_x(tA*qa/2)
722    siBt = tB*sas_sinx_x(tB*qb/2)
723    siCt = tC*sas_sinx_x(tC*qc/2)
724    Fq = (dr0*siA*siB*siC
725          + drA*(siAt-siA)*siB*siC
726          + drB*siA*(siBt-siB)*siC
727          + drC*siA*siB*(siCt-siC))
728    Iq = Fq**2
729    return Iq.reshape(qx.shape)
730
731def sasmodels_Iq(kernel, q, pars):
732    from sasmodels.data import empty_data1D
733    from sasmodels.direct_model import DirectModel
734    data = empty_data1D(q)
735    calculator = DirectModel(data, kernel)
736    Iq = calculator(**pars)
737    return Iq
738
739def sasmodels_Iqxy(kernel, qx, qy, pars, view):
740    from sasmodels.data import Data2D
741    from sasmodels.direct_model import DirectModel
742    Iq = 100 * np.ones_like(qx)
743    data = Data2D(x=qx, y=qy, z=Iq, dx=None, dy=None, dz=np.sqrt(Iq))
744    data.x_bins = qx[0, :]
745    data.y_bins = qy[:, 0]
746    data.filename = "fake data"
747
748    calculator = DirectModel(data, kernel)
749    pars_plus_view = pars.copy()
750    pars_plus_view.update(theta=view[0], phi=view[1], psi=view[2])
751    Iqxy = calculator(**pars_plus_view)
752    return Iqxy.reshape(qx.shape)
753
754def wrap_sasmodel(name, **pars):
757    fn = lambda q: sasmodels_Iq(kernel, q, pars)
758    fn_xy = lambda qx, qy, view: sasmodels_Iqxy(kernel, qx, qy, pars, view)
759    return fn, fn_xy
760
761
762# --------- Test cases -----------
763
766    fn = lambda q: cylinder_Iq(q, radius, length)*rho**2
767    fn_xy = lambda qx, qy, view: cylinder_Iqxy(qx, qy, radius, length, view=view)*rho**2
768    return shape, fn, fn_xy
769
771DEFAULT_SPHERE_CONTRAST = 2
774    fn = lambda q: sphere_Iq(q, radius)*rho**2
775    fn_xy = lambda qx, qy, view: sphere_Iq(np.sqrt(qx**2+qy**2), radius)*rho**2
776    return shape, fn, fn_xy
777
778def build_box(a=10, b=20, c=30, rho=2.):
779    shape = Box(a, b, c, rho)
780    fn = lambda q: box_Iq(q, a, b, c)*rho**2
781    fn_xy = lambda qx, qy, view: box_Iqxy(qx, qy, a, b, c, view=view)*rho**2
782    return shape, fn, fn_xy
783
784def build_csbox(a=10, b=20, c=30, da=1, db=2, dc=3, slda=1, sldb=2, sldc=3, sld_core=4):
785    shape = csbox(a, b, c, da, db, dc, slda, sldb, sldc, sld_core)
786    fn = lambda q: csbox_Iq(q, a, b, c, da, db, dc, slda, sldb, sldc, sld_core)
787    fn_xy = lambda qx, qy, view: csbox_Iqxy(qx, qy, a, b, c, da, db, dc,
788                                            slda, sldb, sldc, sld_core, view=view)
789    return shape, fn, fn_xy
790
791def build_ellcyl(ra=25, rb=50, length=125, rho=2.):
792    shape = EllipticalCylinder(ra, rb, length, rho)
793    fn, fn_xy = wrap_sasmodel(
794        'elliptical_cylinder',
795        scale=1,
796        background=0,
798        axis_ratio=rb/ra,
799        length=length,
800        sld=rho,
801        sld_solvent=0,
802    )
803    return shape, fn, fn_xy
804
805def build_cscyl(ra=30, rb=90, length=30, thick_rim=8, thick_face=14,
806                sld_core=4, sld_rim=1, sld_face=7):
807    shape = EllipticalBicelle(
808        ra=ra, rb=rb, length=length,
809        thick_rim=thick_rim, thick_face=thick_face,
810        value_core=sld_core, value_rim=sld_rim, value_face=sld_face,
811        )
812    fn, fn_xy = wrap_sasmodel(
813        'core_shell_bicelle_elliptical',
814        scale=1,
815        background=0,
817        x_core=rb/ra,
818        length=length,
819        thick_rim=thick_rim,
820        thick_face=thick_face,
821        sld_core=sld_core,
822        sld_face=sld_face,
823        sld_rim=sld_rim,
824        sld_solvent=0,
825    )
826    return shape, fn, fn_xy
827
828def build_sc_lattice(shape, nx=1, ny=1, nz=1, dx=2, dy=2, dz=2,
829                        shuffle=0, rotate=0):
830    a, b, c = shape.dims
831    corners= [copy(shape)
832              .shift((ix+(randn() if shuffle < 0.3 else rand())*shuffle)*dx*a,
833                     (iy+(randn() if shuffle < 0.3 else rand())*shuffle)*dy*b,
834                     (iz+(randn() if shuffle < 0.3 else rand())*shuffle)*dz*c)
835              .rotate(*((randn(3) if rotate < 30 else rand(3))*rotate))
836              for ix in range(nx)
837              for iy in range(ny)
838              for iz in range(nz)]
839    lattice = Composite(corners)
840    return lattice
841
842def build_bcc_lattice(shape, nx=1, ny=1, nz=1, dx=2, dy=2, dz=2,
843                      shuffle=0, rotate=0):
844    a, b, c = shape.dims
845    corners = [copy(shape)
846               .shift((ix+(randn() if shuffle < 0.3 else rand())*shuffle)*dx*a,
847                      (iy+(randn() if shuffle < 0.3 else rand())*shuffle)*dy*b,
848                      (iz+(randn() if shuffle < 0.3 else rand())*shuffle)*dz*c)
849               .rotate(*((randn(3) if rotate < 30 else rand(3))*rotate))
850               for ix in range(nx)
851               for iy in range(ny)
852               for iz in range(nz)]
853    centers = [copy(shape)
854               .shift((ix+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dx*a,
855                      (iy+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dy*b,
856                      (iz+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dz*c)
857               .rotate(*((randn(3) if rotate < 30 else rand(3))*rotate))
858               for ix in range(nx)
859               for iy in range(ny)
860               for iz in range(nz)]
861    lattice = Composite(corners + centers)
862    return lattice
863
864def build_fcc_lattice(shape, nx=1, ny=1, nz=1, dx=2, dy=2, dz=2,
865                      shuffle=0, rotate=0):
866    a, b, c = shape.dims
867    corners = [copy(shape)
868               .shift((ix+(randn() if shuffle < 0.3 else rand())*shuffle)*dx*a,
869                      (iy+(randn() if shuffle < 0.3 else rand())*shuffle)*dy*b,
870                      (iz+(randn() if shuffle < 0.3 else rand())*shuffle)*dz*c)
871               .rotate(*((randn(3) if rotate < 30 else rand(3))*rotate))
872               for ix in range(nx)
873               for iy in range(ny)
874               for iz in range(nz)]
875    faces_a = [copy(shape)
876               .shift((ix+0.0+(randn() if shuffle < 0.3 else rand())*shuffle)*dx*a,
877                      (iy+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dy*b,
878                      (iz+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dz*c)
879               .rotate(*((randn(3) if rotate < 30 else rand(3))*rotate))
880               for ix in range(nx)
881               for iy in range(ny)
882               for iz in range(nz)]
883    faces_b = [copy(shape)
884               .shift((ix+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dx*a,
885                      (iy+0.0+(randn() if shuffle < 0.3 else rand())*shuffle)*dy*b,
886                      (iz+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dz*c)
887               .rotate(*((randn(3) if rotate < 30 else rand(3))*rotate))
888               for ix in range(nx)
889               for iy in range(ny)
890               for iz in range(nz)]
891    faces_c = [copy(shape)
892               .shift((ix+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dx*a,
893                      (iy+0.5+(randn() if shuffle < 0.3 else rand())*shuffle)*dy*b,
894                      (iz+0.0+(randn() if shuffle < 0.3 else rand())*shuffle)*dz*c)
895               .rotate(*((randn(3) if rotate < 30 else rand(3))*rotate))
896               for ix in range(nx)
897               for iy in range(ny)
898               for iz in range(nz)]
899    lattice = Composite(corners + faces_a + faces_b + faces_c)
900    return lattice
901
902SHAPE_FUNCTIONS = OrderedDict([
903    ("cyl", build_cylinder),
904    ("ellcyl", build_ellcyl),
905    ("sphere", build_sphere),
906    ("box", build_box),
907    ("csbox", build_csbox),
908    ("cscyl", build_cscyl),
909])
910SHAPES = list(SHAPE_FUNCTIONS.keys())
911LATTICE_FUNCTIONS = OrderedDict([
912    ("sc", build_sc_lattice),
913    ("bcc", build_bcc_lattice),
914    ("fcc", build_fcc_lattice),
915])
916LATTICE_TYPES = list(LATTICE_FUNCTIONS.keys())
917
918def check_shape(title, shape, fn=None, show_points=False,
919                mesh=100, qmax=1.0, r_step=0.01, samples=5000):
920    rho_solvent = 0
921    qmin = qmax/100.
922    q = np.logspace(np.log10(qmin), np.log10(qmax), mesh)
923    r = shape.r_bins(q, r_step=r_step)
924    sampling_density = samples / shape.volume
925    print("sampling points")
926    rho, points = shape.sample(sampling_density)
927    print("calculating Pr")
928    t0 = time.time()
929    Pr = calc_Pr(r, rho-rho_solvent, points)
930    print("calc Pr time", time.time() - t0)
931    Iq = calc_Iq(q, r, Pr)
932    theory = (q, fn(q)) if fn is not None else None
933
934    import pylab
935    if show_points:
936        plot_points(rho, points); pylab.figure()
937    plot_calc(r, Pr, q, Iq, theory=theory, title=title)
938    pylab.gcf().canvas.set_window_title(title)
939    pylab.show()
940
941def check_shape_2d(title, shape, fn=None, view=(0, 0, 0), show_points=False,
942                   mesh=100, qmax=1.0, samples=5000):
943    rho_solvent = 0
944    #qx = np.linspace(0.0, qmax, mesh)
945    #qy = np.linspace(0.0, qmax, mesh)
946    qx = np.linspace(-qmax, qmax, mesh)
947    qy = np.linspace(-qmax, qmax, mesh)
948    Qx, Qy = np.meshgrid(qx, qy)
949    sampling_density = samples / shape.volume
950    print("sampling points")
951    t0 = time.time()
952    rho, points = shape.sample(sampling_density)
953    print("point generation time", time.time() - t0)
954    t0 = time.time()
955    Iqxy = calc_Iqxy(Qx, Qy, rho, points, view=view)
956    print("calc Iqxy time", time.time() - t0)
957    t0 = time.time()
958    theory = fn(Qx, Qy, view) if fn is not None else None
959    print("calc theory time", time.time() - t0)
960    Iqxy += 0.001 * Iqxy.max()
961    if theory is not None:
962        theory += 0.001 * theory.max()
963
964    import pylab
965    if show_points:
966        plot_points(rho, points); pylab.figure()
967    plot_calc_2d(qx, qy, Iqxy, theory=theory, title=title)
968    pylab.gcf().canvas.set_window_title(title)
969    pylab.show()
970
971def main():
972    parser = argparse.ArgumentParser(
973        description="Compute scattering from realspace sampling",
974        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
975        )
977                        help='dimension 1 or 2')
979                        help='number of mesh points')
981                        help="number of sample points")
983                        help='max q')
985                        help='theta,phi,psi angles')
987                        help='lattice size')
989                        help='lattice spacing (relative to shape)')
991                        default=LATTICE_TYPES[0],
992                        help='lattice type')
994                        help="rotation relative to lattice, gaussian < 30 degrees, uniform otherwise")
996                        help="position relative to lattice, gaussian < 0.3, uniform otherwise")
998                        help='plot points')
1000                        help='oriented shape')
1001    parser.add_argument('pars', type=str, nargs='*', help='shape parameters')
1002    opts = parser.parse_args()
1003    pars = {key: float(value) for p in opts.pars for key, value in [p.split('=')]}
1004    nx, ny, nz = [int(v) for v in opts.lattice.split(',')]
1005    dx, dy, dz = [float(v) for v in opts.spacing.split(',')]
1006    distortion, rotation = opts.shuffle, opts.rotate
1007    shape, fn, fn_xy = SHAPE_FUNCTIONS[opts.shape](**pars)
1008    view = tuple(float(v) for v in opts.view.split(','))
1009    # If comparing a sphere in a cubic lattice, compare against the
1010    # corresponding paracrystalline model.
1011    if opts.shape == "sphere" and dx == dy == dz and nx*ny*nz > 1:
1013        model_name = opts.type + "_paracrystal"
1014        model_pars = {
1015            "scale": 1.,
1016            "background": 0.,
1018            "lattice_distortion": distortion,
1020            "sld": pars.get('rho', DEFAULT_SPHERE_CONTRAST),
1021            "sld_solvent": 0.,
1022            "theta": view[0],
1023            "phi": view[1],
1024            "psi": view[2],
1025        }
1026        fn, fn_xy = wrap_sasmodel(model_name, **model_pars)
1027    if nx*ny*nz > 1:
1028        if rotation != 0:
1029            print("building %s lattice"%opts.type)
1030            build_lattice = LATTICE_FUNCTIONS[opts.type]
1031            shape = build_lattice(shape, nx, ny, nz, dx, dy, dz,
1032                                  distortion, rotation)
1033        else:
1034            shape.lattice(size=(nx, ny, nz), spacing=(dx, dy, dz),
1035                          type=opts.type,
1036                          rotation=rotation, distortion=distortion)
1037
1038    title = "%s(%s)" % (opts.shape, " ".join(opts.pars))
1039    if opts.dim == 1:
1040        check_shape(title, shape, fn, show_points=opts.plot,
1041                    mesh=opts.mesh, qmax=opts.qmax, samples=opts.samples)
1042    else:
1043        check_shape_2d(title, shape, fn_xy, view=view, show_points=opts.plot,
1044                       mesh=opts.mesh, qmax=opts.qmax, samples=opts.samples)
1045
1046
1047if __name__ == "__main__":
1048    main()
Note: See TracBrowser for help on using the repository browser.