Changeset 5778279 in sasmodels


Ignore:
Timestamp:
Sep 25, 2018 5:39:08 PM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
ticket_1156
Children:
b3f4831
Parents:
2a39ca4
Message:

faster lattice sampling for realspace code

File:
1 edited

Legend:

Unmodified
Added
Removed
  • explore/realspace.py

    r2a39ca4 r5778279  
    99import numpy as np 
    1010from numpy import pi, radians, sin, cos, sqrt 
    11 from numpy.random import poisson, uniform, randn, rand 
     11from numpy.random import poisson, uniform, randn, rand, randint 
    1212from numpy.polynomial.legendre import leggauss 
    1313from scipy.integrate import simps 
     
    8484    center = np.array([0., 0., 0.])[:, None] 
    8585    r_max = None 
     86    lattice_size = np.array((1, 1, 1)) 
     87    lattice_spacing = np.array((1., 1., 1.)) 
     88    lattice_distortion = 0.0 
     89    lattice_rotation = 0.0 
     90    lattice_type = "" 
    8691 
    8792    def volume(self): 
     
    105110        self.center = self.center + np.array([x, y, z])[:, None] 
    106111        return self 
     112 
     113    def lattice(self, size=(1, 1, 1), spacing=(2, 2, 2), type="sc", 
     114                distortion=0.0, rotation=0.0): 
     115        self.lattice_size = np.asarray(size, 'i') 
     116        self.lattice_spacing = np.asarray(spacing, 'd') 
     117        self.lattice_type = type 
     118        self.lattice_distortion = distortion 
     119        self.lattice_rotation = rotation 
    107120 
    108121    def _adjust(self, points): 
     
    111124        else: 
    112125            points = np.asarray(self.rotation * np.matrix(points.T)) + self.center 
     126        if self.lattice_type: 
     127            points = self._apply_lattice(points) 
    113128        return points.T 
    114129 
    115     def r_bins(self, q, over_sampling=1, r_step=0.): 
    116         r_max = min(2 * pi / q[0], self.r_max) 
     130    def r_bins(self, q, over_sampling=10, r_step=0.): 
     131        if self.lattice_type: 
     132            r_max = np.sqrt(np.sum(self.lattice_size*self.lattice_spacing*self.dims)**2)/2 
     133        else: 
     134            r_max = self.r_max 
     135        #r_max = min(2 * pi / q[0], r_max) 
    117136        if r_step == 0.: 
    118137            r_step = 2 * pi / q[-1] / over_sampling 
    119138        #r_step = 0.01 
    120139        return np.arange(r_step, r_max, r_step) 
     140 
     141    def _apply_lattice(self, points): 
     142        """Spread points to different lattice positions""" 
     143        size = self.lattice_size 
     144        spacing = self.lattice_spacing 
     145        shuffle = self.lattice_distortion 
     146        rotate = self.lattice_rotation 
     147        lattice = self.lattice_type 
     148 
     149        if rotate != 0: 
     150            # To vectorize the rotations we will need to unwrap the matrix multiply 
     151            raise NotImplementedError("don't handle rotations yet") 
     152 
     153        # Determine the number of lattice points in the lattice 
     154        shapes_per_cell = 2 if lattice == "bcc" else 4 if lattice == "fcc" else 1 
     155        number_of_lattice_points = np.prod(size) * shapes_per_cell 
     156 
     157        # For each point in the original shape, figure out which lattice point 
     158        # to translate it to.  This is both cell index (i*ny*nz + j*nz  + k) as 
     159        # well as the point in the cell (corner, body center or face center). 
     160        nsamples = points.shape[1] 
     161        lattice_point = randint(number_of_lattice_points, size=nsamples) 
     162 
     163        # Translate the cell index into the i,j,k coordinates of the senter 
     164        cell_index = lattice_point // shapes_per_cell 
     165        center = np.vstack((cell_index//(size[1]*size[2]), 
     166                            (cell_index%(size[1]*size[2]))//size[2], 
     167                            cell_index%size[2])) 
     168        center = np.asarray(center, dtype='d') 
     169        if lattice == "bcc": 
     170            center[:, lattice_point % shapes_per_cell == 1] += [[0.5], [0.5], [0.5]] 
     171        elif lattice == "fcc": 
     172            center[:, lattice_point % shapes_per_cell == 1] += [[0.0], [0.5], [0.5]] 
     173            center[:, lattice_point % shapes_per_cell == 2] += [[0.5], [0.0], [0.5]] 
     174            center[:, lattice_point % shapes_per_cell == 3] += [[0.5], [0.5], [0.0]] 
     175 
     176        # Each lattice point has its own displacement from the ideal position. 
     177        # Not checking that shapes do not overlap if displacement is too large. 
     178        offset = shuffle*(randn(3, number_of_lattice_points) if shuffle < 0.3 
     179                          else rand(3, number_of_lattice_points)) 
     180        center += offset[:, cell_index] 
     181 
     182        # Each lattice point has its own rotation.  Rotate the point prior to 
     183        # applying any displacement. 
     184        # rotation = rotate*(randn(size=(shapes, 3)) if shuffle < 30 else rand(size=(nsamples, 3))) 
     185        # for k in shapes: points[k] = rotation[k]*points[k] 
     186        points += center*(np.array([spacing])*np.array(self.dims)).T 
     187        return points 
    121188 
    122189class Composite(Shape): 
     
    867934    import pylab 
    868935    if show_points: 
    869          plot_points(rho, points); pylab.figure() 
     936        plot_points(rho, points); pylab.figure() 
    870937    plot_calc(r, Pr, q, Iq, theory=theory, title=title) 
    871938    pylab.gcf().canvas.set_window_title(title) 
     
    9371004    nx, ny, nz = [int(v) for v in opts.lattice.split(',')] 
    9381005    dx, dy, dz = [float(v) for v in opts.spacing.split(',')] 
    939     shuffle, rotate = opts.shuffle, opts.rotate 
     1006    distortion, rotation = opts.shuffle, opts.rotate 
    9401007    shape, fn, fn_xy = SHAPE_FUNCTIONS[opts.shape](**pars) 
    9411008    view = tuple(float(v) for v in opts.view.split(',')) 
    942     if nx > 1 or ny > 1 or nz > 1: 
    943         print("building %s lattice"%opts.type) 
    944         lattice = LATTICE_FUNCTIONS[opts.type] 
    945         shape = lattice(shape, nx, ny, nz, dx, dy, dz, shuffle, rotate) 
    946         # If comparing a sphere in a cubic lattice, compare against the 
    947         # corresponding paracrystalline model. 
    948         if opts.shape == "sphere" and dx == dy == dz: 
    949             radius = pars.get('radius', DEFAULT_SPHERE_RADIUS) 
    950             model_name = opts.type + "_paracrystal" 
    951             model_pars = { 
    952                 "scale": 1., 
    953                 "background": 0., 
    954                 "lattice_spacing": 2*radius*dx, 
    955                 "lattice_distortion": shuffle, 
    956                 "radius": radius, 
    957                 "sld": pars.get('rho', DEFAULT_SPHERE_CONTRAST), 
    958                 "sld_solvent": 0., 
    959                 "theta": view[0], 
    960                 "phi": view[1], 
    961                 "psi": view[2], 
    962             } 
    963             fn, fn_xy = wrap_sasmodel(model_name, **model_pars) 
     1009    # If comparing a sphere in a cubic lattice, compare against the 
     1010    # corresponding paracrystalline model. 
     1011    if opts.shape == "sphere" and dx == dy == dz and nx*ny*nz > 1: 
     1012        radius = pars.get('radius', DEFAULT_SPHERE_RADIUS) 
     1013        model_name = opts.type + "_paracrystal" 
     1014        model_pars = { 
     1015            "scale": 1., 
     1016            "background": 0., 
     1017            "lattice_spacing": 2*radius*dx, 
     1018            "lattice_distortion": distortion, 
     1019            "radius": radius, 
     1020            "sld": pars.get('rho', DEFAULT_SPHERE_CONTRAST), 
     1021            "sld_solvent": 0., 
     1022            "theta": view[0], 
     1023            "phi": view[1], 
     1024            "psi": view[2], 
     1025        } 
     1026        fn, fn_xy = wrap_sasmodel(model_name, **model_pars) 
     1027    if nx*ny*nz > 1: 
     1028        if rotation != 0: 
     1029            print("building %s lattice"%opts.type) 
     1030            build_lattice = LATTICE_FUNCTIONS[opts.type] 
     1031            shape = build_lattice(shape, nx, ny, nz, dx, dy, dz, 
     1032                                  distortion, rotation) 
     1033        else: 
     1034            shape.lattice(size=(nx, ny, nz), spacing=(dx, dy, dz), 
     1035                          type=opts.type, 
     1036                          rotation=rotation, distortion=distortion) 
     1037 
    9641038    title = "%s(%s)" % (opts.shape, " ".join(opts.pars)) 
    9651039    if opts.dim == 1: 
Note: See TracChangeset for help on using the changeset viewer.