1 | #!/usr/bin/env python |
---|
2 | r""" |
---|
3 | Multiple scattering calculator |
---|
4 | |
---|
5 | Calculate multiple scattering using 2D FFT convolution. |
---|
6 | |
---|
7 | Usage: |
---|
8 | |
---|
-p, --power: show the pattern for the nth scattering event only
-f, --fraction: the scattering probability
-q, --qmax: the max q that you care about
-w, --window: the extension window (q is calculated for qmax*window)
-n, --nq: the (even) number of mesh points (dq = 2*qmax*window/(nq-1))
-r, --random: generate a random parameter set
-s, --seed: generate a random parameter set with the given seed
-2, --2d: perform the calculation for an oriented pattern
model_name
model_par=value ...
---|
17 | |
---|
18 | Assume the probability of scattering is $p$. After each scattering event, |
---|
19 | $1-p$ neutrons will leave the system and go to the detector, and the remaining |
---|
20 | $p$ will scatter again. |
---|
21 | |
---|
22 | Let the scattering probability for a single scattering event at $q$ be $f(q)$, |
---|
23 | where |
---|
24 | .. math:: f(q) = p \frac{I_1(q)}{\int I_1(q) dq} |
---|
25 | for $I_1(q)$, the single scattering from the system. After two scattering |
---|
26 | events, the scattering will be the convolution of the first scattering |
---|
27 | and itself, or $(f*f)(q)$. After $n$ events it will be |
---|
28 | $(f* \cdots * f)(q)$. |
---|
29 | |
---|
30 | The total scattering after multiple events will then be the number that didn't |
---|
31 | scatter the first time $(=1-p)$ plus the number that scattered only once |
---|
32 | $(=(1-p)f)$ plus the number that scattered only twice $(=(1-p)(f*f))$, etc., |
---|
33 | so |
---|
34 | .. math:: |
---|
35 | I(q) = (1-p)\sum_{k=0}^{\infty} f^{*k}(q) |
---|
Since $\int f(q) dq = p$ and the integral of a convolution is the product of
the integrals of its factors, the $k$-th term integrates to $(1-p)p^k$. The
contributions from the individual terms therefore fall off as $p^k$, so we can
cut off the series at $n = \lceil \ln C / \ln p \rceil$ for some small
fraction $C$ (default is $C = 0.001$).
---|
40 | |
---|
41 | Using the convolution theorem, where |
---|
42 | $F = \mathcal{F}(f)$ is the Fourier transform, |
---|
43 | .. math:: |
---|
44 | |
---|
45 | f * g = \mathcal{F}^{-1}\{\mathcal{F}\{f\} \cdot \mathcal{F}\{g\}\} |
---|
46 | |
---|
47 | so |
---|
48 | .. math:: |
---|
49 | |
---|
50 | f * \ldots * f = \mathcal{F}^{-1}\{ F^n \} |
---|
51 | |
---|
52 | Since the Fourier transform is a linear operator, we can move the polynomial |
---|
53 | expression for the convolution into the transform, giving |
---|
54 | .. math:: |
---|
55 | |
---|
56 | I(q) = \mathcal{F}^{-1}\left\{ (1-p) \sum_{k=0}^{n} F^k \right\} |
---|
57 | |
---|
58 | We drop the transmission term, $k=0$, and rescale the result by the |
---|
59 | total scattering $\int I_1(q) dq$. |
---|
60 | |
---|
For speed we use the fast Fourier transform, which is most efficient when the
mesh size is a power of two. The resulting
$I(q)$ will be linearly spaced and likely heavily oversampled. The usual
pinhole or slit resolution calculation can be performed from these calculated
values.
---|
65 | """ |
---|
66 | |
---|
67 | from __future__ import print_function |
---|
68 | |
---|
69 | import argparse |
---|
70 | import time |
---|
71 | |
---|
72 | import numpy as np |
---|
73 | |
---|
74 | from sasmodels import core |
---|
75 | from sasmodels import compare |
---|
76 | from sasmodels import resolution2d |
---|
77 | from sasmodels.resolution import Resolution, bin_edges |
---|
78 | from sasmodels.data import empty_data1D, empty_data2D, plot_data |
---|
79 | from sasmodels.direct_model import call_kernel |
---|
80 | |
---|
81 | |
---|
class MultipleScattering(Resolution):
    """
    Fold multiple scattering into a model evaluated on a square q mesh.

    The mesh spans ``[-qmax*window, qmax*window]`` along both qx and qy with
    *nq* points per axis. *nq* should be even, because the FFT helpers split
    the mesh into two equal halves.

    For an oriented pattern (*is2d* True), the model is evaluated at every
    mesh point. For an isotropic pattern, the model is evaluated along a
    radial line that reaches the mesh corners. That line is interpolated onto
    the mesh before convolving.

    *fraction* is the scattering probability passed to
    :func:`multiple_scattering`. If *power* > 0, :func:`autoconv_n` is used
    instead, which yields only the *power*-fold self-convolution.
    """

    #: Debugging toggle (formerly hard-coded ``if 1:`` blocks).
    #: When True (default), :meth:`apply` draws diagnostic plots with pylab
    #: and returns the 2D mesh even for isotropic patterns.
    #: When False, no plots are drawn and the isotropic case returns the
    #: circular average ``(q, I(q))``.
    show_plots = True

    def __init__(self, qmax, nq, fraction, is2d, window=2, power=0):
        self.qmax = qmax
        self.nq = nq
        self.power = power
        self.fraction = fraction
        self.is2d = is2d
        self.window = window

        # Square mesh from -qmax*window to +qmax*window along each axis.
        q_range = qmax * window
        q = np.linspace(-q_range, q_range, nq)
        qx, qy = np.meshgrid(q, q)

        if is2d:
            # Oriented pattern: evaluate the model at every mesh point.
            q_calc = [qx.flatten(), qy.flatten()]
        else:
            # Isotropic pattern: evaluate I(q) on a radial line long enough
            # to reach the mesh corners (|q| = sqrt(2)*q_range). The spacing
            # is close to the mesh spacing. q=0 itself is skipped.
            q_range_corners = np.sqrt(2.) * q_range
            nq_corners = int(np.sqrt(2.) * nq/2)
            q_corners = np.linspace(0, q_range_corners, nq_corners+1)[1:]
            q_calc = [q_corners]
            self._qxy = np.sqrt(qx**2 + qy**2)
            # Radial bins, and the number of mesh points per bin, for the
            # circular average in apply().
            self._edges = bin_edges(q_corners)
            self._norm = np.histogram(self._qxy, bins=self._edges)[0]

        self.q_calc = q_calc
        self.q_range = q_range

    def apply(self, theory):
        """
        Return the multiple scattering computed from *theory*.

        *theory* is the single scattering evaluated at :attr:`q_calc`: the
        flattened mesh for 2D patterns, or the radial points for 1D patterns.

        With :attr:`show_plots` True, the (nq, nq) mesh is returned in both
        cases; ``main()`` relies on this. With :attr:`show_plots` False, the
        isotropic case returns ``(q_corners, Iq)``, the circular average of
        the mesh.
        """
        t0 = time.time()
        if self.is2d:
            Iq_calc = theory
        else:
            q_corners = self.q_calc[0]
            # Spread the radial I(q) onto the 2D mesh.
            Iq_calc = np.interp(self._qxy, q_corners, theory)
        Iq_calc = Iq_calc.reshape(self.nq, self.nq)
        if self.power > 0:
            Iqxy = autoconv_n(Iq_calc, self.power)
        else:
            Iqxy = multiple_scattering(Iq_calc, self.fraction, cutoff=0.001)
        print("multiple scattering calc time", time.time()-t0)

        if self.is2d:
            if self.show_plots:
                import pylab
                plotxy(Iq_calc)
                pylab.title("single scattering")
                pylab.figure()
            return Iqxy

        # Circular average, no anti-aliasing.
        Iq = np.histogram(self._qxy, bins=self._edges, weights=Iqxy)[0]/self._norm
        if self.show_plots:
            import pylab
            pylab.loglog(q_corners, theory, label="single scattering")
            if self.power > 0:
                label = "scattering power %d"%self.power
            else:
                # %g rather than %d: the fraction is a float such as 0.1.
                label = "scattering fraction %g"%self.fraction
            pylab.loglog(q_corners, Iq, label=label)
            pylab.legend()
            pylab.figure()
            return Iqxy
        return q_corners, Iq
---|
148 | |
---|
def multiple_scattering(Iq_calc, frac, cutoff=0.001):
    r"""
    Return the multiple scattering pattern for single scattering *Iq_calc*.

    *Iq_calc* is a square (nq, nq) array with nq even. Its centre
    (index nq//2) is treated as the origin of the convolution.

    *frac* is the scattering probability $p$, with $0 \le p < 1$. Terms are
    kept up to $n = \lceil \ln C / \ln p \rceil$ with $C$ = *cutoff*, and at
    least one term is always kept. For $p = 0$ the result is the single
    scattering itself.

    The pattern is normalized by ``sum(Iq_calc)`` before the transform and
    rescaled by it afterwards. The unscattered ($k=0$) term is omitted. The
    result has the same shape as *Iq_calc*.

    Raises ValueError if *frac* is outside $[0, 1)$.
    """
    if not 0 <= frac < 1:
        raise ValueError("scattering fraction must be in [0, 1), got %r" % (frac,))
    # Number of series terms. Keep at least one, so that frac=0 (or
    # cutoff >= 1) gives single scattering rather than an all-zero result.
    if frac == 0:
        num_scatter = 1
    else:
        num_scatter = max(1, int(np.ceil(np.log(cutoff)/np.log(frac))))

    # Zero-pad to (2*nq, 2*nq) so the circular FFT convolution has room
    # before it wraps around. The mesh centre (index nq//2) moves to the
    # [0, 0] corner of the padded frame.
    nq = Iq_calc.shape[0]
    half_nq = nq//2
    frame = np.zeros((2*nq, 2*nq))
    frame[:half_nq, :half_nq] = Iq_calc[half_nq:, half_nq:]
    frame[-half_nq:, :half_nq] = Iq_calc[:half_nq, half_nq:]
    frame[:half_nq, -half_nq:] = Iq_calc[half_nq:, :half_nq]
    frame[-half_nq:, -half_nq:] = Iq_calc[:half_nq, :half_nq]

    # Compute multiple scattering via convolution.
    scale = np.sum(Iq_calc)
    fourier_frame = np.fft.fft2(frame/scale)
    # total = (1-a)f + (1-a)af^2 + (1-a)a^2f^3 + ...
    #       = (1-a)f[1 + af + (af)^2 + (af)^3 + ...]
    # np.polyval with all-ones coefficients evaluates the truncated series
    # 1 + x + ... + x^(num_scatter-1) by Horner's rule.
    fourier_total = (
        (1-frac)*fourier_frame
        *np.polyval(np.ones(num_scatter), frac*fourier_frame))
    conv_frame = scale*np.fft.ifft2(fourier_total).real

    # Undo the corner shuffle to recover the centred (nq, nq) mesh.
    Iq_conv = np.empty((nq, nq))
    Iq_conv[half_nq:, half_nq:] = conv_frame[:half_nq, :half_nq]
    Iq_conv[:half_nq, half_nq:] = conv_frame[-half_nq:, :half_nq]
    Iq_conv[half_nq:, :half_nq] = conv_frame[:half_nq, -half_nq:]
    Iq_conv[:half_nq, :half_nq] = conv_frame[-half_nq:, -half_nq:]
    return Iq_conv
---|
183 | |
---|
def multiple_scattering_cl(Iq_calc, frac, cutoff=0.001):
    """
    OpenCL draft of :func:`multiple_scattering`, using pyopencl and gpyfft.

    Not available yet: the function raises NotImplementedError on entry.
    Everything after the ``raise`` is unreachable draft code kept for
    reference, and it has never been exercised.

    NOTE(review): issues spotted in the draft, to resolve before enabling it:

    * The inverse transform is never multiplied back by *scale*. The numpy
      version, by contrast, returns ``scale*ifft2(...).real``.
    * The complex inverse FFT is written into ``frame_gpu``, which holds
      float32 data. Presumably a complex output buffer is needed.
    * ``cla.from_device`` does not appear to be part of pyopencl.array;
      ``frame_gpu.get()`` may be what was intended. Verify.
    * ``(1 - frac)/frac`` divides by zero when frac == 0.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("no support for opencl calculations at this time")

    import pyopencl as cl
    import pyopencl.array as cla
    from gpyfft.fft import FFT
    context = cl.create_some_context()
    queue = cl.CommandQueue(context)

    #plotxy(Iq_calc)
    num_scatter = int(np.ceil(np.log(cutoff)/np.log(frac)))

    # Prepare padded array for transform.
    # Same (2*nq, 2*nq) corner shuffle as multiple_scattering, but in float32.
    nq = Iq_calc.shape[0]
    half_nq = nq//2
    frame = np.zeros((2*nq, 2*nq), dtype='float32')
    frame[:half_nq, :half_nq] = Iq_calc[half_nq:, half_nq:]
    frame[-half_nq:, :half_nq] = Iq_calc[:half_nq, half_nq:]
    frame[:half_nq, -half_nq:] = Iq_calc[half_nq:, :half_nq]
    frame[-half_nq:, -half_nq:] = Iq_calc[:half_nq, :half_nq]
    #plotxy(frame)

    # Compute multiple scattering via convolution (OpenCL operations)
    frame_gpu = cla.to_device(queue, frame)
    fourier_frame_gpu = cla.zeros(frame.shape, dtype='complex64')
    # NOTE(review): confirm the device-side sum() result can be used
    # directly as the divisor below.
    scale = frame_gpu.sum()
    frame_gpu /= scale
    transform = FFT(context, queue, frame_gpu, fourier_frame_gpu, axes=(0,1))
    event, = transform.enqueue()
    event.wait()
    fourier_frame_gpu *= frac
    # Horner evaluation. Starting from x = aF, each pass computes
    # x = (x + 1)*aF, so after the loop x = aF + (aF)^2 + ... + (aF)^n.
    # Multiplying by (1-a)/a then gives (1-a)(F + aF^2 + ... + a^(n-1)F^n).
    multiple_scattering_gpu = fourier_frame_gpu.copy()
    for _ in range(num_scatter-1):
        multiple_scattering_gpu += 1
        multiple_scattering_gpu *= fourier_frame_gpu
    multiple_scattering_gpu *= (1 - frac)/frac
    transform = FFT(context, queue, multiple_scattering_gpu, frame_gpu, axes=(0,1))
    event, = transform.enqueue(forward=False)
    event.wait()
    conv_frame = cla.from_device(queue, frame_gpu)

    # Recover the transformed data
    #plotxy(conv_frame)
    Iq_conv = np.empty((nq, nq))
    Iq_conv[half_nq:, half_nq:] = conv_frame[:half_nq, :half_nq]
    Iq_conv[:half_nq, half_nq:] = conv_frame[-half_nq:, :half_nq]
    Iq_conv[half_nq:, :half_nq] = conv_frame[:half_nq, -half_nq:]
    Iq_conv[:half_nq, :half_nq] = conv_frame[-half_nq:, -half_nq:]
    #plotxy(Iq_conv)
    return Iq_conv
---|
234 | |
---|
def autoconv_n(Iq_calc, power):
    """
    Return the *power*-fold self-convolution of the mesh *Iq_calc*.

    *Iq_calc* is a square (nq, nq) array, and nq is expected to be even.

    The steps are:

    1. Normalize the mesh to unit sum.
    2. Zero-pad it to (2*nq, 2*nq), with the mesh centre (index nq//2) at
       the [0, 0] corner of the padded frame.
    3. Raise its FFT to *power*.
    4. Transform back, then rescale by the original sum.

    The result has the same shape as *Iq_calc* and shows the pattern
    after exactly *power* scattering events.
    """
    total = np.sum(Iq_calc)
    n = Iq_calc.shape[0]
    h = n // 2

    # Pairs of (frame index, mesh index). Each quadrant of the mesh is
    # mapped to a corner of the padded frame, so the mesh centre lands at
    # the FFT origin.
    near, far = slice(None, h), slice(-h, None)
    upper, lower = slice(h, None), slice(None, h)
    quadrants = [
        ((near, near), (upper, upper)),
        ((far, near), (lower, upper)),
        ((near, far), (upper, lower)),
        ((far, far), (lower, lower)),
    ]

    padded = np.zeros((2*n, 2*n))
    for frame_idx, mesh_idx in quadrants:
        padded[frame_idx] = Iq_calc[mesh_idx]

    spectrum = np.fft.fft2(padded/total)**power
    convolved = total*np.fft.ifft2(spectrum).real

    # Undo the corner shuffle to recover the centred (n, n) mesh.
    result = np.empty((n, n))
    for frame_idx, mesh_idx in quadrants:
        result[mesh_idx] = convolved[frame_idx]
    return result
---|
259 | |
---|
def parse_pars(model, opts):
    # type: (KernelModel, argparse.Namespace) -> Dict[str, float]
    """
    Return the model parameter values selected on the command line.

    *model* is the loaded model; its ``info`` attribute is passed on to
    ``compare.parse_pars``. *opts* holds the parsed command line options,
    and explicit ``name=value`` pairs are taken from ``opts.pars``.

    If ``opts.random`` is set and the seed is negative (the default), a
    fresh seed is drawn. Otherwise ``opts.seed`` is passed through.
    """
    # argparse hands --seed over as a string unless a type is declared.
    # Normalize it so the comparison works and compare receives a number.
    seed = int(opts.seed)
    if opts.random and seed < 0:
        seed = np.random.randint(1000000)
    compare_opts = {
        'info': (model.info, model.info),
        'use_demo': False,
        'seed': seed,
        'mono': True,
        'magnetic': False,
        'values': opts.pars,
        'show_pars': True,
        'is2d': opts.is2d,
    }
    pars, _ = compare.parse_pars(compare_opts)
    return pars
---|
276 | |
---|
277 | |
---|
def main():
    """
    Command line entry point.

    Parses the options, evaluates the model's single scattering, applies
    multiple scattering, adds the background and plots the result.
    """
    parser = argparse.ArgumentParser(
        description="Compute multiple scattering",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--power', type=int, default=0, help="show pattern for nth scattering")
    parser.add_argument('-f', '--fraction', type=float, default=0.1, help="scattering fraction")
    parser.add_argument('-n', '--nq', type=int, default=1024, help='number of mesh points')
    parser.add_argument('-q', '--qmax', type=float, default=0.5, help='max q')
    parser.add_argument('-w', '--window', type=float, default=2.0, help='q calc = q max * window')
    parser.add_argument('-2', '--2d', dest='is2d', action='store_true', help='oriented sample')
    parser.add_argument('-s', '--seed', type=int, default=-1, help='random pars with given seed')
    parser.add_argument('-r', '--random', action='store_true', help='random pars with random seed')
    parser.add_argument('model', type=str, help='sas model name such as cylinder')
    parser.add_argument('pars', type=str, nargs='*', help='model parameters such as radius=30')
    opts = parser.parse_args()
    # The FFT corner shuffle splits the mesh into two equal halves.
    # Use parser.error rather than assert, which is stripped under -O.
    if opts.nq % 2 != 0:
        parser.error("require even # points")

    model = core.load_model(opts.model)
    pars = parse_pars(model, opts)
    res = MultipleScattering(opts.qmax, opts.nq, opts.fraction, opts.is2d,
                             window=opts.window, power=opts.power)
    kernel = model.make_kernel(res.q_calc)
    # Compute the single scattering without background. The background is
    # added once to the final pattern instead of being convolved.
    bg = pars.get('background', 0.0)
    pars['background'] = 0.0
    # The first call is untimed (presumably a warm-up); then time the
    # average of several repeats.
    Iq_calc = call_kernel(kernel, pars)
    n_repeat = 10
    t0 = time.time()
    for _ in range(n_repeat):
        Iq_calc = call_kernel(kernel, pars)
    print("single scattering calc time", (time.time()-t0)/n_repeat)
    Iq = res.apply(Iq_calc) + bg
    plotxy(Iq)
    import pylab
    if opts.power > 0:
        pylab.title('scattering power %d'%opts.power)
    else:
        pylab.title('multiple scattering with fraction %g'%opts.fraction)
    pylab.show()
---|
317 | |
---|
def plotxy(Iq):
    """
    Plot *Iq* with pylab for quick inspection.

    A ``(q, I)`` tuple is drawn as a log-log curve. An array is drawn as an
    image of log10(I).

    Non-positive pixels are replaced by half the smallest positive value so
    the logarithm is defined. If no value is positive (e.g. an all-zero
    image), they are replaced by 1, which shows as 0 on the log scale. The
    caller's array is not modified.
    """
    import pylab
    if isinstance(Iq, tuple):
        q, Iq = Iq
        pylab.loglog(q, Iq)
    else:
        data = Iq+0.
        positive = Iq[Iq > 0]
        # np.min on an empty selection would raise ValueError.
        floor = np.min(positive)/2 if positive.size else 1.0
        data[Iq <= 0] = floor
        pylab.imshow(np.log10(data))
---|
328 | |
---|
if __name__ == '__main__':
    # Run the command line interface only when executed as a script.
    main()
330 | main() |
---|