Changeset b3703f5 in sasmodels for sasmodels/multiscat.py
- Timestamp:
- Mar 26, 2018 5:50:35 PM (7 years ago)
- Branches:
- master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- c11d09f, c462169
- Parents:
- 802c412
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/multiscat.py
rd86f0fc rb3703f5 73 73 import argparse 74 74 import time 75 import os.path76 75 77 76 import numpy as np … … 81 80 from sasmodels import core 82 81 from sasmodels import compare 83 from sasmodels import resolution2d84 82 from sasmodels.resolution import Resolution, bin_edges 85 from sasmodels.data import empty_data1D, empty_data2D, plot_data86 83 from sasmodels.direct_model import call_kernel 87 84 import sasmodels.kernelcl … … 106 103 USE_FAST = True # OpenCL faster, less accurate math 107 104 108 class NumpyCalculator: 105 class ICalculator: 106 """ 107 Multiple scattering calculator 108 """ 109 def fft(self, Iq): 110 """ 111 Compute the forward FFT for an image, real -> complex. 112 """ 113 raise NotImplementedError() 114 115 def ifft(self, Iq): 116 """ 117 Compute the inverse FFT for an image, complex -> complex. 118 """ 119 raise NotImplementedError() 120 121 def mulitple_scattering(self, Iq): 122 r""" 123 Compute multiple scattering for I(q) given scattering probability p. 124 125 Given a probability p of scattering with the thickness, the expected 126 number of scattering events, $\lambda$ is $-\log(1 - p)$, giving a 127 Poisson weighted sum of single, double, triple, etc. scattering patterns. 128 The number of patterns used is based on coverage (default 99%). 129 """ 130 raise NotImplementedError() 131 132 class NumpyCalculator(ICalculator): 133 """ 134 Multiple scattering calculator using numpy fft. 135 """ 109 136 def __init__(self, dims=None, dtype=PRECISION): 110 137 self.dtype = dtype 111 138 self.complex_dtype = np.dtype('F') if dtype == np.dtype('f') else np.dtype('D') 112 pass113 139 114 140 def fft(self, Iq): … … 127 153 128 154 def multiple_scattering(self, Iq, p, coverage=0.99): 129 r"""130 Compute multiple scattering for I(q) given scattering probability p.131 132 Given a probability p of scattering with the thickness, the expected133 number of scattering events, $\lambda$ is $-\log(1 - p)$, giving a134 Poisson weighted sum of single, double, triple, etc. scattering patterns.135 The number of patterns used is based on coverage (default 99%).136 """137 155 #t0 = time.time() 138 156 coeffs = scattering_coeffs(p, coverage) … … 140 158 scale = np.sum(Iq) 141 159 frame = _forward_shift(Iq/scale, dtype=self.dtype) 142 F= np.fft.fft2(frame)143 F_convolved = F * np.polyval(poly, F)144 frame = np.fft.ifft2( F_convolved)160 fourier_frame = np.fft.fft2(frame) 161 convolved = fourier_frame * np.polyval(poly, fourier_frame) 162 frame = np.fft.ifft2(convolved) 145 163 result = scale * _inverse_shift(frame.real, dtype=self.dtype) 146 164 #print("numpy multiscat time", time.time()-t0) … … 173 191 """ 174 192 175 class OpenclCalculator(NumpyCalculator): 193 class OpenclCalculator(ICalculator): 194 """ 195 Multiple scattering calculator using OpenCL via pyfft. 196 """ 176 197 polyval1f = None 177 198 polyval1d = None … … 180 201 context = env.get_context(dtype) 181 202 if dtype == np.dtype('f'): 182 if self.polyval1f is None:203 if OpenclCalculator.polyval1f is None: 183 204 program = sasmodels.kernelcl.compile_model( 184 205 context, POLYVAL1_KERNEL, dtype, fast=USE_FAST) … … 187 208 self.dtype = dtype 188 209 self.complex_dtype = np.dtype('F') 189 self.polyval1 = self.polyval1f210 self.polyval1 = OpenclCalculator.polyval1f 190 211 else: 191 if self.polyval1d is None:212 if OpenclCalculator.polyval1d is None: 192 213 program = sasmodels.kernelcl.compile_model( 193 214 context, POLYVAL1_KERNEL, dtype, fast=False) … … 196 217 self.dtype = dtype 197 218 self.complex_type = np.dtype('D') 198 self.polyval1 = self.polyval1d219 self.polyval1 = OpenclCalculator.polyval1d 199 220 self.queue = env.get_queue(dtype) 200 221 self.plan = pyfft.cl.Plan(dims, queue=self.queue) … … 229 250 gpu_poly = cl_array.to_device(self.queue, poly) 230 251 self.plan.execute(gpu_data.data) 231 degree, n= poly.shape[0], frame.shape[0]*frame.shape[1]252 degree, data_size= poly.shape[0], frame.shape[0]*frame.shape[1] 232 253 self.polyval1( 233 self.queue, [ n], None,234 np.int32(degree), gpu_poly.data, np.int32( n), gpu_data.data)254 self.queue, [data_size], None, 255 np.int32(degree), gpu_poly.data, np.int32(data_size), gpu_data.data) 235 256 self.plan.execute(gpu_data.data, inverse=True) 236 257 frame = gpu_data.get() … … 251 272 """ 252 273 if transform is None: 253 n x, ny = Iq.shape254 transform = Calculator(dims=(n x*2, ny*2), dtype=dtype)274 n_x, n_y = Iq.shape 275 transform = Calculator(dims=(n_x*2, n_y*2), dtype=dtype) 255 276 scale = np.sum(Iq) 256 277 frame = _forward_shift(Iq/scale, dtype=dtype) … … 528 549 def parse_pars(model, opts): 529 550 # type: (ModelInfo, argparse.Namespace) -> Dict[str, float] 551 """ 552 Parse par=val arguments from the command line. 553 """ 530 554 531 555 seed = np.random.randint(1000000) if opts.random and opts.seed < 0 else opts.seed … … 541 565 'is2d': opts.is2d, 542 566 } 543 pars, pars2 = compare.parse_pars(compare_opts) 567 # Note: sascomp allows comparison on a pair of models, so ignore the second. 568 pars, _ = compare.parse_pars(compare_opts) 544 569 return pars 545 570 … … 550 575 formatter_class=argparse.ArgumentDefaultsHelpFormatter, 551 576 ) 552 parser.add_argument('-p', '--probability', type=float, default=0.1, help="scattering probability") 553 parser.add_argument('-n', '--nq', type=int, default=1024, help='number of mesh points') 554 parser.add_argument('-q', '--qmax', type=float, default=0.5, help='max q') 555 parser.add_argument('-w', '--window', type=float, default=2.0, help='q calc = q max * window') 556 parser.add_argument('-2', '--2d', dest='is2d', action='store_true', help='oriented sample') 557 parser.add_argument('-s', '--seed', default=-1, help='random pars with given seed') 558 parser.add_argument('-r', '--random', action='store_true', help='random pars with random seed') 559 parser.add_argument('-o', '--outfile', type=str, default="", help='random pars with random seed') 560 parser.add_argument('model', type=str, help='sas model name such as cylinder') 561 parser.add_argument('pars', type=str, nargs='*', help='model parameters such as radius=30') 577 parser.add_argument('-p', '--probability', type=float, default=0.1, 578 help="scattering probability") 579 parser.add_argument('-n', '--nq', type=int, default=1024, 580 help='number of mesh points') 581 parser.add_argument('-q', '--qmax', type=float, default=0.5, 582 help='max q') 583 parser.add_argument('-w', '--window', type=float, default=2.0, 584 help='q calc = q max * window') 585 parser.add_argument('-2', '--2d', dest='is2d', action='store_true', 586 help='oriented sample') 587 parser.add_argument('-s', '--seed', default=-1, 588 help='random pars with given seed') 589 parser.add_argument('-r', '--random', action='store_true', 590 help='random pars with random seed') 591 parser.add_argument('-o', '--outfile', type=str, default="", 592 help='random pars with random seed') 593 parser.add_argument('model', type=str, 594 help='sas model name such as cylinder') 595 parser.add_argument('pars', type=str, nargs='*', 596 help='model parameters such as radius=30') 562 597 opts = parser.parse_args() 563 598 assert opts.nq%2 == 0, "require even # points" … … 607 642 plotxy((res._q_steps, res._q_steps), res.Iqxy+background) 608 643 pylab.title("total scattering for p=%g" % probability) 644 if res.resolution is not None: 645 pylab.figure() 646 plotxy((res._q_steps, res._q_steps), result) 647 pylab.title("total scattering with resolution") 609 648 else: 610 649 q = res._q … … 624 663 # Plot 1D pattern for partial scattering 625 664 pylab.loglog(q, res.Iq+background, label="total for p=%g"%probability) 665 if res.resolution is not None: 666 pylab.loglog(q, result, label="total with dQ") 626 667 #new_annulus = annular_average(res._radius, res.Iqxy, res._edges) 627 668 #pylab.loglog(q, new_annulus+background, label="new total for p=%g"%probability)
Note: See TracChangeset
for help on using the changeset viewer.