source: sasmodels/explore/multiscat.py @ 8bd379a

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 8bd379a was 049e02d, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

explore/multiscat.py: slightly improved help

  • Property mode set to 100644
File size: 11.9 KB
Line 
1#!/usr/bin/env python
2r"""
3Multiple scattering calculator
4
5Calculate multiple scattering using 2D FFT convolution.
6
7Usage:
8
    -p, --power: show the pattern for exactly n scattering events
    -f, --fraction: the scattering probability
    -q, --qmax: the max q that you care about
    -w, --window: the extension window (q is calculated for qmax*window)
    -n, --nq: the number of mesh points, which must be even
              (dq = 2*qmax*window/(nq-1))
    -s, --seed: generate a random parameter set with the given seed
    -r, --random: generate a random parameter set
    -2, --2d: perform the calculation for an oriented pattern
    model_name
    model_par=value ...
17
Assume the probability of scattering is $p$. After each scattering event,
a fraction $1-p$ of the neutrons will leave the system and go to the detector,
and the remaining fraction $p$ will scatter again.
21
22Let the scattering probability for a single scattering event at $q$ be $f(q)$,
23where
24.. math:: f(q) = p \frac{I_1(q)}{\int I_1(q) dq}
25for $I_1(q)$, the single scattering from the system. After two scattering
26events, the scattering will be the convolution of the first scattering
27and itself, or $(f*f)(q)$.  After $n$ events it will be
28$(f* \cdots * f)(q)$.
29
30The total scattering after multiple events will then be the number that didn't
31scatter the first time $(=1-p)$ plus the number that scattered only once
32$(=(1-p)f)$ plus the number that scattered only twice $(=(1-p)(f*f))$, etc.,
33so
34.. math::
35    I(q) = (1-p)\sum_{k=0}^{\infty} f^{*k}(q)
Since $f$ integrates to $p$, the $k$-fold convolution $f^{*k}$ integrates
to $p^k$, so the contributions from the individual terms fall off as $p^k$
and we can cut off the series at $n = \lceil \ln C / \ln p \rceil$ for some
small fraction $C$ (default is $C = 0.001$).
40
41Using the convolution theorem, where
42$F = \mathcal{F}(f)$ is the Fourier transform,
43.. math::
44
45    f * g = \mathcal{F}^{-1}\{\mathcal{F}\{f\} \cdot \mathcal{F}\{g\}\}
46
47so
48.. math::
49
50    f * \ldots * f = \mathcal{F}^{-1}\{ F^n \}
51
52Since the Fourier transform is a linear operator, we can move the polynomial
53expression for the convolution into the transform, giving
54.. math::
55
56    I(q) = \mathcal{F}^{-1}\left\{ (1-p) \sum_{k=0}^{n} F^k \right\}
57
58We drop the transmission term, $k=0$, and rescale the result by the
59total scattering $\int I_1(q) dq$.
60
For speed we use the fast Fourier transform, which is fastest when the number
of mesh points is a power of two (the default is 1024).  The resulting
$I(q)$ will be linearly spaced and likely heavily oversampled.  The usual
pinhole or slit resolution calculation can be performed from these calculated
values.
65"""
66
67from __future__ import print_function
68
69import argparse
70import time
71
72import numpy as np
73
74from sasmodels import core
75from sasmodels import compare
76from sasmodels import resolution2d
77from sasmodels.resolution import Resolution, bin_edges
78from sasmodels.data import empty_data1D, empty_data2D, plot_data
79from sasmodels.direct_model import call_kernel
80
81
class MultipleScattering(Resolution):
    r"""
    Multiple scattering calculator presented as a sasmodels resolution.

    The theory is evaluated on a square $(q_x, q_y)$ mesh of *nq* x *nq*
    points spanning $\pm$ *qmax* x *window*.  When *is2d* is False, the
    theory is instead evaluated on a 1D $|q|$ grid reaching the mesh corners
    and interpolated onto the mesh.  :meth:`apply` then convolves that
    pattern with itself.  If *power* > 0, it returns the pattern after
    exactly *power* scattering events.  Otherwise, it returns the multiple
    scattering sum for scattering probability *fraction*.
    """
    def __init__(self, qmax, nq, fraction, is2d, window=2, power=0):
        self.qmax = qmax
        self.nq = nq
        self.power = power
        self.fraction = fraction
        self.is2d = is2d
        self.window = window

        q_range = qmax * window
        q = np.linspace(-q_range, q_range, nq)
        qx, qy = np.meshgrid(q, q)

        if is2d:
            q_calc = [qx.flatten(), qy.flatten()]
        else:
            # 1D theory: evaluate |q| out to the corner of the mesh.  Keep
            # |q| for every mesh point (for interpolation) and the number of
            # mesh points per |q| bin (for the circular average in apply).
            q_range_corners = np.sqrt(2.) * q_range
            nq_corners = int(np.sqrt(2.) * nq/2)
            q_corners = np.linspace(0, q_range_corners, nq_corners+1)[1:]
            q_calc = [q_corners]
            self._qxy = np.sqrt(qx**2 + qy**2)
            self._edges = bin_edges(q_corners)
            self._norm = np.histogram(self._qxy, bins=self._edges)[0]

        self.q_calc = q_calc
        self.q_range = q_range

    def apply(self, theory):
        """
        Return the multiple scattering pattern for *theory*.

        *theory* is the single scattering evaluated at :attr:`q_calc`.  The
        result is always the ``(nq, nq)`` 2D pattern, even for 1D data.  In
        the 1D case the circular average is computed and plotted against the
        single scattering for reference only.
        """
        t0 = time.time()
        if self.is2d:
            Iq_calc = theory
        else:
            q_corners = self.q_calc[0]
            Iq_calc = np.interp(self._qxy, q_corners, theory)
        Iq_calc = Iq_calc.reshape(self.nq, self.nq)
        if self.power > 0:
            Iqxy = autoconv_n(Iq_calc, self.power)
        else:
            Iqxy = multiple_scattering(Iq_calc, self.fraction, cutoff=0.001)
        print("multiple scattering calc time", time.time()-t0)

        import pylab
        if self.is2d:
            plotxy(Iq_calc)
            pylab.title("single scattering")
            pylab.figure()
            return Iqxy

        # circular average, no anti-aliasing
        Iq = np.histogram(self._qxy, bins=self._edges, weights=Iqxy)[0]/self._norm
        pylab.loglog(q_corners, theory, label="single scattering")
        if self.power > 0:
            label = "scattering power %d"%self.power
        else:
            # fraction is a probability in [0, 1); %d would print it as 0.
            label = "scattering fraction %g"%self.fraction
        pylab.loglog(q_corners, Iq, label=label)
        pylab.legend()
        pylab.figure()
        # The 2D pattern is returned rather than (q_corners, Iq) because
        # main() adds the background to the result and plots it as an image.
        return Iqxy
148
def _pad_frame(Iq_calc, dtype='float64'):
    """
    Return *Iq_calc* embedded in a zero-padded ``(2*nq, 2*nq)`` FFT frame.

    The four quadrants of the ``(nq, nq)`` pattern are moved to the corners
    of the frame, so the sample just above the centre lands at index [0, 0]
    as the FFT expects.  The zero band in the middle gives the circular
    convolution room to spread before it wraps around.

    NOTE(review): when the pattern is sampled on a symmetric grid with no
    q=0 point (as MultipleScattering builds it for even nq), index [0, 0]
    is q=(+dq/2, +dq/2).  Each convolution then shifts the result by a
    further half pixel.  Confirm whether this offset matters.

    Raises ValueError if *Iq_calc* is not a square 2D array with even size.
    """
    Iq_calc = np.asarray(Iq_calc)
    if (Iq_calc.ndim != 2 or Iq_calc.shape[0] != Iq_calc.shape[1]
            or Iq_calc.shape[0] % 2):
        raise ValueError("expected a square (nq, nq) array with even nq, got shape %s"
                         % (Iq_calc.shape,))
    nq = Iq_calc.shape[0]
    half_nq = nq//2
    frame = np.zeros((2*nq, 2*nq), dtype=dtype)
    frame[:half_nq, :half_nq] = Iq_calc[half_nq:, half_nq:]
    frame[-half_nq:, :half_nq] = Iq_calc[:half_nq, half_nq:]
    frame[:half_nq, -half_nq:] = Iq_calc[half_nq:, :half_nq]
    frame[-half_nq:, -half_nq:] = Iq_calc[:half_nq, :half_nq]
    return frame

def _unpad_frame(conv_frame, nq):
    """Inverse of :func:`_pad_frame`: extract the ``(nq, nq)`` pattern."""
    half_nq = nq//2
    Iq_conv = np.empty((nq, nq))
    Iq_conv[half_nq:, half_nq:] = conv_frame[:half_nq, :half_nq]
    Iq_conv[:half_nq, half_nq:] = conv_frame[-half_nq:, :half_nq]
    Iq_conv[half_nq:, :half_nq] = conv_frame[:half_nq, -half_nq:]
    Iq_conv[:half_nq, :half_nq] = conv_frame[-half_nq:, -half_nq:]
    return Iq_conv

def _num_scatter(frac, cutoff):
    """
    Return the number of scattering terms to keep, ceil(ln cutoff / ln frac).

    At least one term is always kept, so frac=0 yields the single scattering
    pattern.  Raises ValueError unless 0 <= frac < 1.
    """
    if not 0 <= frac < 1:
        raise ValueError("scattering fraction must be in [0, 1), got %g" % frac)
    if frac == 0:
        return 1
    return max(1, int(np.ceil(np.log(cutoff)/np.log(frac))))

def multiple_scattering(Iq_calc, frac, cutoff=0.001):
    r"""
    Return the multiple scattering pattern for single scattering *Iq_calc*.

    *Iq_calc* is a square ``(nq, nq)`` array with even *nq*, centred on
    q=0.  *frac* is the scattering probability $p$, with $0 \le p < 1$.
    The series is truncated after $n = \lceil \ln C/\ln p \rceil$ terms
    (at least one), where $C$ is *cutoff*.  The result is
    $(1-p)\sum_{k=1}^{n} p^{k-1} f^{*k}$, where $f$ is *Iq_calc*
    normalized to unit sum.  The result is scaled back by ``sum(Iq_calc)``.

    Raises ValueError for *frac* outside [0, 1) or a badly shaped *Iq_calc*.
    """
    num_scatter = _num_scatter(frac, cutoff)
    frame = _pad_frame(Iq_calc)
    nq = frame.shape[0]//2
    if not np.any(Iq_calc):
        # Nothing scatters; avoid 0/0 in the normalization below.
        return np.zeros((nq, nq))

    # Compute multiple scattering via convolution.
    scale = np.sum(Iq_calc)
    fourier_frame = np.fft.fft2(frame/scale)
    # total = (1-a)f + (1-a)af^2 + (1-a)a^2f^3 + ...
    #       = (1-a)f[1 + af + (af)^2 + (af)^3 + ...]
    # polyval with all-ones coefficients evaluates the geometric sum
    # 1 + x + ... + x^(num_scatter-1).
    series = (
        (1-frac)*fourier_frame
        *np.polyval(np.ones(num_scatter), frac*fourier_frame))
    conv_frame = scale*np.fft.ifft2(series).real
    return _unpad_frame(conv_frame, nq)

def multiple_scattering_cl(Iq_calc, frac, cutoff=0.001):
    """
    OpenCL version of :func:`multiple_scattering`; not yet supported.

    Always raises NotImplementedError.  The code after the raise is an
    unfinished sketch kept for reference.
    """
    raise NotImplementedError("no support for opencl calculations at this time")

    # NOTE(review): unreachable sketch.  Apparent gaps to address before
    # enabling it:
    # - the result is never multiplied back by *scale*;
    # - the inverse FFT seems to write complex output into the real float32
    #   frame_gpu;
    # - (1 - frac)/frac fails for frac=0;
    # - pyopencl.array may not provide from_device (Array.get() is the usual
    #   way to copy back).
    # Confirm these against the pyopencl and gpyfft docs.
    import pyopencl as cl
    import pyopencl.array as cla
    from gpyfft.fft import FFT
    context = cl.create_some_context()
    queue = cl.CommandQueue(context)

    num_scatter = _num_scatter(frac, cutoff)

    # Prepare padded array for transform
    nq = Iq_calc.shape[0]
    frame = _pad_frame(Iq_calc, dtype='float32')

    # Compute multiple scattering via convolution (OpenCL operations)
    frame_gpu = cla.to_device(queue, frame)
    fourier_frame_gpu = cla.zeros(frame.shape, dtype='complex64')
    scale = frame_gpu.sum()
    frame_gpu /= scale
    transform = FFT(context, queue, frame_gpu, fourier_frame_gpu, axes=(0,1))
    event, = transform.enqueue()
    event.wait()
    fourier_frame_gpu *= frac
    # Horner evaluation of sum_{k=1}^{n} (aF)^k, then scaled by (1-a)/a.
    multiple_scattering_gpu = fourier_frame_gpu.copy()
    for _ in range(num_scatter-1):
        multiple_scattering_gpu += 1
        multiple_scattering_gpu *= fourier_frame_gpu
    multiple_scattering_gpu *= (1 - frac)/frac
    transform = FFT(context, queue, multiple_scattering_gpu, frame_gpu, axes=(0,1))
    event, = transform.enqueue(forward=False)
    event.wait()
    conv_frame = cla.from_device(queue, frame_gpu)

    # Recover the transformed data
    return _unpad_frame(conv_frame, nq)

def autoconv_n(Iq_calc, power):
    """
    Return *Iq_calc* convolved with itself to give *power* copies in total.

    This is the pattern after exactly *power* scattering events.  The
    convolution is done on *Iq_calc* normalized to unit sum and then
    scaled back by ``sum(Iq_calc)``, so the total intensity is preserved.
    *Iq_calc* must be a square ``(nq, nq)`` array with even *nq*.
    """
    frame = _pad_frame(Iq_calc)
    nq = frame.shape[0]//2
    if not np.any(Iq_calc):
        # Nothing scatters; avoid 0/0 in the normalization below.
        return np.zeros((nq, nq))
    scale = np.sum(Iq_calc)
    fourier_frame = np.fft.fft2(frame/scale)
    conv_frame = scale*np.fft.ifft2(fourier_frame**power).real
    return _unpad_frame(conv_frame, nq)
259
def parse_pars(model, opts):
    # type: (ModelInfo, argparse.Namespace) -> Dict[str, float]
    """
    Return the parameter values for *model* selected on the command line.

    The request is forwarded to ``compare.parse_pars`` for a single
    monodisperse, non-magnetic model.  It uses the ``pars``, ``is2d``,
    ``random`` and ``seed`` options from *opts*.  With ``--random`` and no
    explicit seed (seed < 0), a random seed is drawn.
    """
    # argparse hands --seed through as a string unless a type is declared,
    # and str < int raises TypeError in Python 3, so coerce first.
    seed = int(opts.seed)
    if opts.random and seed < 0:
        seed = np.random.randint(1000000)
    compare_opts = {
        'info': (model.info, model.info),
        'use_demo': False,
        'seed': seed,
        'mono': True,
        'magnetic': False,
        'values': opts.pars,
        'show_pars': True,
        'is2d': opts.is2d,
    }
    # compare.parse_pars returns a pair; only the first set is needed.
    pars, _ = compare.parse_pars(compare_opts)
    return pars
276
277
def main():
    """
    Command line entry point.

    Parses the options, evaluates the single scattering for the requested
    model, applies :class:`MultipleScattering` and plots the result.
    """
    parser = argparse.ArgumentParser(
        description="Compute multiple scattering",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    parser.add_argument('-p', '--power', type=int, default=0, help="show pattern for nth scattering")
    parser.add_argument('-f', '--fraction', type=float, default=0.1, help="scattering fraction")
    parser.add_argument('-n', '--nq', type=int, default=1024, help='number of mesh points')
    parser.add_argument('-q', '--qmax', type=float, default=0.5, help='max q')
    parser.add_argument('-w', '--window', type=float, default=2.0, help='q calc = q max * window')
    parser.add_argument('-2', '--2d', dest='is2d', action='store_true', help='oriented sample')
    # type=int: without it a seed given on the command line arrives as a
    # string and the "seed < 0" comparison in parse_pars fails.
    parser.add_argument('-s', '--seed', type=int, default=-1, help='random pars with given seed')
    parser.add_argument('-r', '--random', action='store_true', help='random pars with random seed')
    parser.add_argument('model', type=str, help='sas model name such as cylinder')
    parser.add_argument('pars', type=str, nargs='*', help='model parameters such as radius=30')
    opts = parser.parse_args()
    # Report bad input through argparse rather than assert, which is
    # stripped when python runs with -O.
    if opts.nq % 2 != 0:
        parser.error("require even # points")

    model = core.load_model(opts.model)
    pars = parse_pars(model, opts)
    res = MultipleScattering(opts.qmax, opts.nq, opts.fraction, opts.is2d,
                             window=opts.window, power=opts.power)
    kernel = model.make_kernel(res.q_calc)
    # The background is removed before the convolution and added back to the
    # final pattern (presumably so that it is not itself convolved).
    bg = pars.get('background', 0.0)
    pars['background'] = 0.0
    # One untimed call first, then the average time over 10 calls.
    Iq_calc = call_kernel(kernel, pars)
    t0 = time.time()
    for _ in range(10):
        Iq_calc = call_kernel(kernel, pars)
    print("single scattering calc time", (time.time()-t0)/10)
    Iq = res.apply(Iq_calc) + bg
    plotxy(Iq)
    import pylab
    if opts.power > 0:
        pylab.title('scattering power %d'%opts.power)
    else:
        pylab.title('multiple scattering with fraction %g'%opts.fraction)
    pylab.show()
317
def plotxy(Iq):
    """
    Plot *Iq* with pylab.

    A ``(q, Iq)`` tuple is drawn as a log-log curve.  Anything else is
    treated as a 2D array and drawn as an image of log10(Iq).  Non-positive
    values are replaced by half the smallest positive value so the log is
    defined.  If there are no positive values, a floor of 1.0 is used
    (an all-zero log image) instead of failing on an empty minimum.
    """
    import pylab
    if isinstance(Iq, tuple):
        q, Iq = Iq
        pylab.loglog(q, Iq)
        return
    Iq = np.asarray(Iq, dtype=float)
    positive = Iq[Iq > 0]
    floor = np.min(positive)/2 if positive.size else 1.0
    # Test Iq <= 0 (not Iq > 0) so NaN values pass through unchanged.
    data = np.where(Iq <= 0, floor, Iq)
    pylab.imshow(np.log10(data))
328
if __name__ == "__main__":
    # Run the command line interface only when executed as a script.
    main()
Note: See TracBrowser for help on using the repository browser.