Changeset b3703f5 in sasmodels for sasmodels/multiscat.py


Ignore:
Timestamp:
Mar 26, 2018 3:50:35 PM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
c11d09f, c462169
Parents:
802c412
Message:

lint reduction

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/multiscat.py

    rd86f0fc rb3703f5  
    7373import argparse 
    7474import time 
    75 import os.path 
    7675 
    7776import numpy as np 
     
    8180from sasmodels import core 
    8281from sasmodels import compare 
    83 from sasmodels import resolution2d 
    8482from sasmodels.resolution import Resolution, bin_edges 
    85 from sasmodels.data import empty_data1D, empty_data2D, plot_data 
    8683from sasmodels.direct_model import call_kernel 
    8784import sasmodels.kernelcl 
     
    106103USE_FAST = True  # OpenCL faster, less accurate math 
    107104 
    108 class NumpyCalculator: 
     105class ICalculator: 
     106    """ 
     107    Multiple scattering calculator 
     108    """ 
     109    def fft(self, Iq): 
     110        """ 
     111        Compute the forward FFT for an image, real -> complex. 
     112        """ 
     113        raise NotImplementedError() 
     114 
     115    def ifft(self, Iq): 
     116        """ 
     117        Compute the inverse FFT for an image, complex -> complex. 
     118        """ 
     119        raise NotImplementedError() 
     120 
     121    def mulitple_scattering(self, Iq): 
     122        r""" 
     123        Compute multiple scattering for I(q) given scattering probability p. 
     124 
     125        Given a probability p of scattering with the thickness, the expected 
     126        number of scattering events, $\lambda$ is $-\log(1 - p)$, giving a 
     127        Poisson weighted sum of single, double, triple, etc. scattering patterns. 
     128        The number of patterns used is based on coverage (default 99%). 
     129        """ 
     130        raise NotImplementedError() 
     131 
     132class NumpyCalculator(ICalculator): 
     133    """ 
     134    Multiple scattering calculator using numpy fft. 
     135    """ 
    109136    def __init__(self, dims=None, dtype=PRECISION): 
    110137        self.dtype = dtype 
    111138        self.complex_dtype = np.dtype('F') if dtype == np.dtype('f') else np.dtype('D') 
    112         pass 
    113139 
    114140    def fft(self, Iq): 
     
    127153 
    128154    def multiple_scattering(self, Iq, p, coverage=0.99): 
    129         r""" 
    130         Compute multiple scattering for I(q) given scattering probability p. 
    131  
    132         Given a probability p of scattering with the thickness, the expected 
    133         number of scattering events, $\lambda$ is $-\log(1 - p)$, giving a 
    134         Poisson weighted sum of single, double, triple, etc. scattering patterns. 
    135         The number of patterns used is based on coverage (default 99%). 
    136         """ 
    137155        #t0 = time.time() 
    138156        coeffs = scattering_coeffs(p, coverage) 
     
    140158        scale = np.sum(Iq) 
    141159        frame = _forward_shift(Iq/scale, dtype=self.dtype) 
    142         F = np.fft.fft2(frame) 
    143         F_convolved = F * np.polyval(poly, F) 
    144         frame = np.fft.ifft2(F_convolved) 
     160        fourier_frame = np.fft.fft2(frame) 
     161        convolved = fourier_frame * np.polyval(poly, fourier_frame) 
     162        frame = np.fft.ifft2(convolved) 
    145163        result = scale * _inverse_shift(frame.real, dtype=self.dtype) 
    146164        #print("numpy multiscat time", time.time()-t0) 
     
    173191""" 
    174192 
    175 class OpenclCalculator(NumpyCalculator): 
     193class OpenclCalculator(ICalculator): 
     194    """ 
     195    Multiple scattering calculator using OpenCL via pyfft. 
     196    """ 
    176197    polyval1f = None 
    177198    polyval1d = None 
     
    180201        context = env.get_context(dtype) 
    181202        if dtype == np.dtype('f'): 
    182             if self.polyval1f is None: 
     203            if OpenclCalculator.polyval1f is None: 
    183204                program = sasmodels.kernelcl.compile_model( 
    184205                    context, POLYVAL1_KERNEL, dtype, fast=USE_FAST) 
     
    187208            self.dtype = dtype 
    188209            self.complex_dtype = np.dtype('F') 
    189             self.polyval1 = self.polyval1f 
     210            self.polyval1 = OpenclCalculator.polyval1f 
    190211        else: 
    191             if self.polyval1d is None: 
     212            if OpenclCalculator.polyval1d is None: 
    192213                program = sasmodels.kernelcl.compile_model( 
    193214                    context, POLYVAL1_KERNEL, dtype, fast=False) 
     
    196217            self.dtype = dtype 
    197218            self.complex_type = np.dtype('D') 
    198             self.polyval1 = self.polyval1d 
     219            self.polyval1 = OpenclCalculator.polyval1d 
    199220        self.queue = env.get_queue(dtype) 
    200221        self.plan = pyfft.cl.Plan(dims, queue=self.queue) 
     
    229250        gpu_poly = cl_array.to_device(self.queue, poly) 
    230251        self.plan.execute(gpu_data.data) 
    231         degree, n = poly.shape[0], frame.shape[0]*frame.shape[1] 
     252        degree, data_size= poly.shape[0], frame.shape[0]*frame.shape[1] 
    232253        self.polyval1( 
    233             self.queue, [n], None, 
    234             np.int32(degree), gpu_poly.data, np.int32(n), gpu_data.data) 
     254            self.queue, [data_size], None, 
     255            np.int32(degree), gpu_poly.data, np.int32(data_size), gpu_data.data) 
    235256        self.plan.execute(gpu_data.data, inverse=True) 
    236257        frame = gpu_data.get() 
     
    251272    """ 
    252273    if transform is None: 
    253         nx, ny = Iq.shape 
    254         transform = Calculator(dims=(nx*2, ny*2), dtype=dtype) 
     274        n_x, n_y = Iq.shape 
     275        transform = Calculator(dims=(n_x*2, n_y*2), dtype=dtype) 
    255276    scale = np.sum(Iq) 
    256277    frame = _forward_shift(Iq/scale, dtype=dtype) 
     
    528549def parse_pars(model, opts): 
    529550    # type: (ModelInfo, argparse.Namespace) -> Dict[str, float] 
     551    """ 
     552    Parse par=val arguments from the command line. 
     553    """ 
    530554 
    531555    seed = np.random.randint(1000000) if opts.random and opts.seed < 0 else opts.seed 
     
    541565        'is2d': opts.is2d, 
    542566    } 
    543     pars, pars2 = compare.parse_pars(compare_opts) 
     567    # Note: sascomp allows comparison on a pair of models, so ignore the second. 
     568    pars, _ = compare.parse_pars(compare_opts) 
    544569    return pars 
    545570 
     
    550575        formatter_class=argparse.ArgumentDefaultsHelpFormatter, 
    551576        ) 
    552     parser.add_argument('-p', '--probability', type=float, default=0.1, help="scattering probability") 
    553     parser.add_argument('-n', '--nq', type=int, default=1024, help='number of mesh points') 
    554     parser.add_argument('-q', '--qmax', type=float, default=0.5, help='max q') 
    555     parser.add_argument('-w', '--window', type=float, default=2.0, help='q calc = q max * window') 
    556     parser.add_argument('-2', '--2d', dest='is2d', action='store_true', help='oriented sample') 
    557     parser.add_argument('-s', '--seed', default=-1, help='random pars with given seed') 
    558     parser.add_argument('-r', '--random', action='store_true', help='random pars with random seed') 
    559     parser.add_argument('-o', '--outfile', type=str, default="", help='random pars with random seed') 
    560     parser.add_argument('model', type=str, help='sas model name such as cylinder') 
    561     parser.add_argument('pars', type=str, nargs='*', help='model parameters such as radius=30') 
     577    parser.add_argument('-p', '--probability', type=float, default=0.1, 
     578                        help="scattering probability") 
     579    parser.add_argument('-n', '--nq', type=int, default=1024, 
     580                        help='number of mesh points') 
     581    parser.add_argument('-q', '--qmax', type=float, default=0.5, 
     582                        help='max q') 
     583    parser.add_argument('-w', '--window', type=float, default=2.0, 
     584                        help='q calc = q max * window') 
     585    parser.add_argument('-2', '--2d', dest='is2d', action='store_true', 
     586                        help='oriented sample') 
     587    parser.add_argument('-s', '--seed', default=-1, 
     588                        help='random pars with given seed') 
     589    parser.add_argument('-r', '--random', action='store_true', 
     590                        help='random pars with random seed') 
     591    parser.add_argument('-o', '--outfile', type=str, default="", 
     592                        help='random pars with random seed') 
     593    parser.add_argument('model', type=str, 
     594                        help='sas model name such as cylinder') 
     595    parser.add_argument('pars', type=str, nargs='*', 
     596                        help='model parameters such as radius=30') 
    562597    opts = parser.parse_args() 
    563598    assert opts.nq%2 == 0, "require even # points" 
     
    607642            plotxy((res._q_steps, res._q_steps), res.Iqxy+background) 
    608643            pylab.title("total scattering for p=%g" % probability) 
     644            if res.resolution is not None: 
     645                pylab.figure() 
     646                plotxy((res._q_steps, res._q_steps), result) 
     647                pylab.title("total scattering with resolution") 
    609648    else: 
    610649        q = res._q 
     
    624663            # Plot 1D pattern for partial scattering 
    625664            pylab.loglog(q, res.Iq+background, label="total for p=%g"%probability) 
     665            if res.resolution is not None: 
     666                pylab.loglog(q, result, label="total with dQ") 
    626667            #new_annulus = annular_average(res._radius, res.Iqxy, res._edges) 
    627668            #pylab.loglog(q, new_annulus+background, label="new total for p=%g"%probability) 
Note: See TracChangeset for help on using the changeset viewer.