Changeset 7d0afa3 in sasmodels for explore


Ignore:
Timestamp:
Jan 31, 2018 4:04:37 PM (7 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
ae91fce
Parents:
ecf895e
Message:

explore/realspace: update monte carlo simulation with command line parameters

File:
1 edited

Legend:

Unmodified
Added
Removed
  • explore/realspace.py

    r8fb2a94 r7d0afa3  
    33import time 
    44from copy import copy 
     5import os 
     6import argparse 
     7from collections import OrderedDict 
    58 
    69import numpy as np 
     
    1013from scipy.integrate import simps 
    1114from scipy.special import j1 as J1 
     15 
     16try: 
     17    import numba 
     18    USE_NUMBA = True 
     19except ImportError: 
     20    USE_NUMBA = False 
    1221 
    1322# Definition of rotation matrices comes from wikipedia: 
     
    244253        return values, self._adjust(points) 
    245254 
    246 NUMBA = False 
    247 if NUMBA: 
     255def csbox(a=10, b=20, c=30, da=1, db=2, dc=3, slda=1, sldb=2, sldc=3, sld_core=4): 
     256    core = Box(a, b, c, sld_core) 
     257    side_a = Box(da, b, c, slda, center=((a+da)/2, 0, 0)) 
     258    side_b = Box(a, db, c, sldb, center=(0, (b+db)/2, 0)) 
     259    side_c = Box(a, b, dc, sldc, center=(0, 0, (c+dc)/2)) 
     260    side_a2 = copy(side_a).shift(-a-da, 0, 0) 
     261    side_b2 = copy(side_b).shift(0, -b-db, 0) 
     262    side_c2 = copy(side_c).shift(0, 0, -c-dc) 
     263    shape = Composite((core, side_a, side_b, side_c, side_a2, side_b2, side_c2)) 
     264    return shape 
     265 
     266def _Iqxy(values, x, y, z, qa, qb, qc): 
     267    """I(q) = |sum V(r) rho(r) e^(1j q.r)|^2 / sum V(r)""" 
     268    Iq = [abs(np.sum(values*np.exp(1j*(qa_k*x + qb_k*y + qc_k*z))))**2 
     269            for qa_k, qb_k, qc_k in zip(qa.flat, qb.flat, qc.flat)] 
     270    return Iq 
     271 
     272if USE_NUMBA: 
     273    # Override simple numpy solution with numba if available 
    248274    from numba import njit 
    249275    @njit("f8[:](f8[:],f8[:],f8[:],f8[:],f8[:],f8[:],f8[:])") 
    250     def _Iqxy(values, x, y, z, qa, qb, qc): 
     276    def _Iqxy_jit(values, x, y, z, qa, qb, qc): 
    251277        Iq = np.zeros_like(qa) 
    252278        for j in range(len(Iq)): 
     
    265291 
    266292    # I(q) = |sum V(r) rho(r) e^(1j q.r)|^2 / sum V(r) 
    267     if NUMBA: 
    268         Iq = _Iqxy(values, x, y, z, qa.flatten(), qb.flatten(), qc.flatten()) 
    269     else: 
    270         Iq = [abs(np.sum(values*np.exp(1j*(qa_k*x + qb_k*y + qc_k*z))))**2 
    271               for qa_k, qb_k, qc_k in zip(qa.flat, qb.flat, qc.flat)] 
     293    Iq = _Iqxy(values, x, y, z, qa.flatten(), qb.flatten(), qc.flatten()) 
    272294    return np.asarray(Iq).reshape(qx.shape) / np.sum(volume) 
    273295 
     
    338360""" 
    339361 
    340 if NUMBA: 
     362if USE_NUMBA: 
     363    # Override simple numpy solution with numba if available 
    341364    @njit("f8[:](f8[:], f8[:], f8[:,:])") 
    342     def _calc_Pr_uniform_jit(r, rho, points): 
     365    def _calc_Pr_uniform(r, rho, points): 
    343366        dr = r[0] 
    344367        n_max = len(r) 
     
    365388        Pr = _calc_Pr_nonuniform(r, rho, points) 
    366389    else: 
    367         if NUMBA: 
    368             Pr = _calc_Pr_uniform_jit(r, rho, points) 
    369         else: 
    370             Pr = _calc_Pr_uniform(r, rho, points) 
     390        Pr = _calc_Pr_uniform(r, rho, points) 
    371391    return Pr / Pr.max() 
    372392 
     
    399419    return edges 
    400420 
     421# -------------- plotters ---------------- 
    401422def plot_calc(r, Pr, q, Iq, theory=None): 
    402423    import matplotlib.pyplot as plt 
     
    410431    plt.ylabel('Iq') 
    411432    if theory is not None: 
    412         plt.loglog(theory[0], theory[1], '-', label='analytic') 
     433        plt.loglog(theory[0], theory[1]/theory[1][0], '-', label='analytic') 
    413434        plt.legend() 
    414435 
     
    444465    ax.autoscale(True) 
    445466 
    446 def check_shape(shape, fn=None): 
    447     rho_solvent = 0 
    448     q = np.logspace(-3, 0, 200) 
    449     r = shape.r_bins(q, r_step=0.01) 
    450     sampling_density = 6*5000 / shape.volume() 
    451     rho, points = shape.sample(sampling_density) 
    452     t0 = time.time() 
    453     Pr = calc_Pr(r, rho-rho_solvent, points) 
    454     print("calc Pr time", time.time() - t0) 
    455     Iq = calc_Iq(q, r, Pr) 
    456     theory = (q, fn(q)) if fn is not None else None 
    457  
    458     import pylab 
    459     #plot_points(rho, points); pylab.figure() 
    460     plot_calc(r, Pr, q, Iq, theory=theory) 
    461     pylab.show() 
    462  
    463 def check_shape_2d(shape, fn=None, view=(0, 0, 0)): 
    464     rho_solvent = 0 
    465     nq, qmax = 100, 1.0 
    466     qx = np.linspace(0.0, qmax, nq) 
    467     qy = np.linspace(0.0, qmax, nq) 
    468     Qx, Qy = np.meshgrid(qx, qy) 
    469     sampling_density = 50000 / shape.volume() 
    470     #t0 = time.time() 
    471     rho, points = shape.sample(sampling_density) 
    472     #print("sample time", time.time() - t0) 
    473     t0 = time.time() 
    474     Iqxy = calc_Iqxy(Qx, Qy, rho, points, view=view) 
    475     print("calc time", time.time() - t0) 
    476     theory = fn(Qx, Qy) if fn is not None else None 
    477     Iqxy += 0.001 * Iqxy.max() 
    478     if theory is not None: 
    479         theory += 0.001 * theory.max() 
    480  
    481     import pylab 
    482     #plot_points(rho, points); pylab.figure() 
    483     plot_calc_2d(qx, qy, Iqxy, theory=theory) 
    484     pylab.show() 
    485  
     467# ----------- Analytic models -------------- 
    486468def sas_sinx_x(x): 
    487469    with np.errstate(all='ignore'): 
     
    510492    for k, qk in enumerate(q): 
    511493        qab, qc = qk*sin_alpha, qk*cos_alpha 
    512         Fq = sas_2J1x_x(qab*radius) * j0(qc*length/2) 
     494        Fq = sas_2J1x_x(qab*radius) * sas_sinx_x(qc*length/2) 
    513495        Iq[k] = np.sum(w*Fq**2) 
    514496    Iq = Iq/Iq[0] 
     
    517499def cylinder_Iqxy(qx, qy, radius, length, view=(0, 0, 0)): 
    518500    qa, qb, qc = invert_view(qx, qy, view) 
    519     qab = np.sqrt(qa**2 + qb**2) 
    520     Fq = sas_2J1x_x(qab*radius) * j0(qc*length/2) 
     501    qab = sqrt(qa**2 + qb**2) 
     502    Fq = sas_2J1x_x(qab*radius) * sas_sinx_x(qc*length/2) 
    521503    Iq = Fq**2 
    522504    return Iq.reshape(qx.shape) 
     
    525507    Iq = sas_3j1x_x(q*radius)**2 
    526508    return Iq/Iq[0] 
     509 
     510def box_Iq(q, a, b, c): 
     511    z, w = leggauss(76) 
     512    outer_sum = np.zeros_like(q) 
     513    for cos_alpha, outer_w in zip((z+1)/2, w): 
     514        sin_alpha = sqrt(1.0-cos_alpha*cos_alpha) 
     515        qc = q*cos_alpha 
     516        siC = c*sas_sinx_x(c*qc/2) 
     517        inner_sum = np.zeros_like(q) 
     518        for beta, inner_w in zip((z + 1)*pi/4, w): 
     519            qa, qb = q*sin_alpha*sin(beta), q*sin_alpha*cos(beta) 
     520            siA = a*sas_sinx_x(a*qa/2) 
     521            siB = b*sas_sinx_x(b*qb/2) 
     522            Fq = siA*siB*siC 
     523            inner_sum += inner_w * Fq**2 
     524        outer_sum += outer_w * inner_sum 
     525    Iq = outer_sum / 4  # = outer*um*zm*8.0/(4.0*M_PI) 
     526    return Iq/Iq[0] 
     527 
     528def box_Iqxy(qx, qy, a, b, c, view=(0, 0, 0)): 
     529    qa, qb, qc = invert_view(qx, qy, view) 
     530    sia = sas_sinx_x(qa*a/2) 
     531    sib = sas_sinx_x(qb*b/2) 
     532    sic = sas_sinx_x(qc*c/2) 
     533    Fq = sia*sib*sic 
     534    Iq = Fq**2 
     535    return Iq.reshape(qx.shape) 
    527536 
    528537def csbox_Iq(q, a, b, c, da, db, dc, slda, sldb, sldc, sld_core): 
     
    539548        sin_alpha = sqrt(1.0-cos_alpha*cos_alpha) 
    540549        qc = q*cos_alpha 
    541         siC = c*j0(c*qc/2) 
    542         siCt = tC*j0(tC*qc/2) 
     550        siC = c*sas_sinx_x(c*qc/2) 
     551        siCt = tC*sas_sinx_x(tC*qc/2) 
    543552        inner_sum = np.zeros_like(q) 
    544553        for beta, inner_w in zip((z + 1)*pi/4, w): 
    545554            qa, qb = q*sin_alpha*sin(beta), q*sin_alpha*cos(beta) 
    546             siA = a*j0(a*qa/2) 
    547             siB = b*j0(b*qb/2) 
    548             siAt = tA*j0(tA*qa/2) 
    549             siBt = tB*j0(tB*qb/2) 
     555            siA = a*sas_sinx_x(a*qa/2) 
     556            siB = b*sas_sinx_x(b*qb/2) 
     557            siAt = tA*sas_sinx_x(tA*qa/2) 
     558            siBt = tB*sas_sinx_x(tB*qb/2) 
    550559            if overlapping: 
    551560                Fq = (dr0*siA*siB*siC 
     
    584593    return Iq.reshape(qx.shape) 
    585594 
     595# --------- Test cases ----------- 
     596 
    586597def check_cylinder(radius=25, length=125, rho=2.): 
    587598    shape = EllipticalCylinder(radius, radius, length, rho) 
    588     fn = lambda q: cylinder_Iq(q, radius, length) 
    589     check_shape(shape, fn) 
    590  
    591 def check_cylinder_2d(radius=25, length=125, rho=2., view=(0, 0, 0)): 
    592     shape = EllipticalCylinder(radius, radius, length, rho) 
    593     fn = lambda qx, qy, view=view: cylinder_Iqxy(qx, qy, radius, length, view=view) 
    594     check_shape_2d(shape, fn, view=view) 
    595  
    596 def check_cylinder_2d_lattice(radius=25, length=125, rho=2., 
    597                               view=(0, 0, 0)): 
     599    fn = lambda q: cylinder_Iq(q, radius, length)*rho**2 
     600    fn_xy = lambda qx, qy, view: cylinder_Iqxy(qx, qy, radius, length, view=view)*rho**2 
     601    return shape, fn, fn_xy 
     602 
     603def check_cylinder_lattice(radius=25, length=125, rho=2.): 
    598604    nx, dx = 1, 2*radius 
    599605    ny, dy = 30, 2*radius 
     
    604610        space = 2 
    605611        return [(space*n+np.random.randn()*sigma)*x for n, x in args] 
     612    t0 = time.time() 
    606613    shapes = [EllipticalCylinder(radius, radius, length, rho, 
    607614                                 #center=(ix*dx, iy*dy, iz*dz) 
     
    613620              for iz in range(nz)] 
    614621    shape = Composite(shapes) 
    615     fn = lambda qx, qy, view=view: cylinder_Iqxy(qx, qy, radius, length, view=view) 
    616     check_shape_2d(shape, fn, view=view) 
     622    print("generate points time", time.time() - t0) 
     623    fn = None 
     624    fn_xy = lambda qx, qy, view: cylinder_Iqxy(qx, qy, radius, length, view=view) 
     625    return shape, fn, fn_xy 
    617626 
    618627def check_sphere(radius=125, rho=2): 
    619628    shape = TriaxialEllipsoid(radius, radius, radius, rho) 
    620     fn = lambda q: sphere_Iq(q, radius) 
    621     check_shape(shape, fn) 
     629    fn = lambda q: cylinder_Iq(q, radius, length)*rho**2 
     630    fn_xy = lambda qx, qy, view: sphere_Iq(np.sqrt(qx**2+qy**2), radius)*rho**2 
     631    return shape, fn, fn_xy 
     632 
     633def check_box(a=10, b=20, c=30, rho=2.): 
     634    shape = Box(a, b, c, rho) 
     635    fn = lambda q: box_Iq(q, a, b, c)*rho**2 
     636    fn_xy = lambda qx, qy, view: box_Iqxy(qx, qy, a, b, c, view=view)*rho**2 
     637    return shape, fn, fn_xy 
     638 
     639def check_box_lattice(a=10, b=20, c=30, rho=2.): 
     640    nx, dx = 3, a 
     641    ny, dy = 5, b 
     642    nz, dz = 5, c 
     643    dx, dy, dz = 2*dx, 2*dy, 2*dz 
     644    def center(*args): 
     645        sigma = 0.333 
     646        space = 2 
     647        return [(space*n+np.random.randn()*sigma)*x for n, x in args] 
     648    t0 = time.time() 
     649    shapes = [Box(a, b, c, rho, 
     650                  #center=(ix*dx, iy*dy, iz*dz) 
     651                  orientation=np.random.randn(3)*10, 
     652                  center=center((ix, dx), (iy, dy), (iz, dz)) 
     653                 ) 
     654              for ix in range(nx) 
     655              for iy in range(ny) 
     656              for iz in range(nz)] 
     657    shape = Composite(shapes) 
     658    print("generate points time", time.time() - t0) 
     659    fn = None 
     660    fn_xy = lambda qx, qy, view: box_Iqxy(qx, qy, a, b, c, view=view) 
     661    return shape, fn, fn_xy 
     662 
    622663 
    623664def check_csbox(a=10, b=20, c=30, da=1, db=2, dc=3, slda=1, sldb=2, sldc=3, sld_core=4): 
    624     core = Box(a, b, c, sld_core) 
    625     side_a = Box(da, b, c, slda, center=((a+da)/2, 0, 0)) 
    626     side_b = Box(a, db, c, sldb, center=(0, (b+db)/2, 0)) 
    627     side_c = Box(a, b, dc, sldc, center=(0, 0, (c+dc)/2)) 
    628     side_a2 = copy(side_a).shift(-a-da, 0, 0) 
    629     side_b2 = copy(side_b).shift(0, -b-db, 0) 
    630     side_c2 = copy(side_c).shift(0, 0, -c-dc) 
    631     shape = Composite((core, side_a, side_b, side_c, side_a2, side_b2, side_c2)) 
    632     def fn(q): 
    633         return csbox_Iq(q, a, b, c, da, db, dc, slda, sldb, sldc, sld_core) 
    634     #check_shape(shape, fn) 
    635  
    636     view = (20, 30, 40) 
    637     def fn_xy(qx, qy): 
    638         return csbox_Iqxy(qx, qy, a, b, c, da, db, dc, 
    639                           slda, sldb, sldc, sld_core, view=view) 
    640     check_shape_2d(shape, fn_xy, view=view) 
     665    shape = csbox(a, b, c, da, db, dc, slda, sldb, sldc, sld_core) 
     666    fn = lambda q: csbox_Iq(q, a, b, c, da, db, dc, slda, sldb, sldc, sld_core) 
     667    fn_xy = lambda qx, qy, view: csbox_Iqxy(qx, qy, a, b, c, da, db, dc, 
     668                                            slda, sldb, sldc, sld_core, view=view) 
     669    return shape, fn, fn_xy 
     670 
     671 
     672SHAPE_FUNCTIONS = OrderedDict([ 
     673    ("cylinder", check_cylinder), 
     674    ("sphere", check_sphere), 
     675    ("box", check_box), 
     676    ("csbox", check_csbox), 
     677    ("multicyl", check_cylinder_lattice), 
     678    ("multibox", check_box_lattice), 
     679]) 
     680SHAPES = list(SHAPE_FUNCTIONS.keys()) 
     681 
     682def check_shape(title, shape, fn=None, show_points=False, 
     683                mesh=100, qmax=1.0, r_step=0.01, samples=5000): 
     684    rho_solvent = 0 
     685    qmin = qmax/1000. 
     686    q = np.logspace(np.log10(qmin), np.log10(qmax), mesh) 
     687    r = shape.r_bins(q, r_step=r_step) 
     688    sampling_density = samples / shape.volume() 
     689    rho, points = shape.sample(sampling_density) 
     690    t0 = time.time() 
     691    Pr = calc_Pr(r, rho-rho_solvent, points) 
     692    print("calc Pr time", time.time() - t0) 
     693    Iq = calc_Iq(q, r, Pr) 
     694    theory = (q, fn(q)) if fn is not None else None 
     695 
     696    import pylab 
     697    if show_points: 
     698         plot_points(rho, points); pylab.figure() 
     699    plot_calc(r, Pr, q, Iq, theory=theory) 
     700    pylab.gcf().canvas.set_window_title(title) 
     701    pylab.show() 
     702 
     703def check_shape_2d(title, shape, fn=None, view=(0, 0, 0), show_points=False, 
     704                   mesh=100, qmax=1.0, samples=5000): 
     705    rho_solvent = 0 
     706    qx = np.linspace(0.0, qmax, mesh) 
     707    qy = np.linspace(0.0, qmax, mesh) 
     708    Qx, Qy = np.meshgrid(qx, qy) 
     709    sampling_density = samples / shape.volume() 
     710    #t0 = time.time() 
     711    rho, points = shape.sample(sampling_density) 
     712    #print("sample time", time.time() - t0) 
     713    t0 = time.time() 
     714    Iqxy = calc_Iqxy(Qx, Qy, rho, points, view=view) 
     715    print("calc Iqxy time", time.time() - t0) 
     716    theory = fn(Qx, Qy, view) if fn is not None else None 
     717    Iqxy += 0.001 * Iqxy.max() 
     718    if theory is not None: 
     719        theory += 0.001 * theory.max() 
     720 
     721    import pylab 
     722    if show_points: 
     723        plot_points(rho, points); pylab.figure() 
     724    plot_calc_2d(qx, qy, Iqxy, theory=theory) 
     725    pylab.gcf().canvas.set_window_title(title) 
     726    pylab.show() 
     727 
     728def main(): 
     729    parser = argparse.ArgumentParser( 
     730        description="Compute scattering from realspace sampling", 
     731        formatter_class=argparse.ArgumentDefaultsHelpFormatter, 
     732        ) 
     733    parser.add_argument('-d', '--dim', type=int, default=1, help='dimension 1 or 2') 
     734    parser.add_argument('-m', '--mesh', type=int, default=100, help='number of mesh points') 
     735    parser.add_argument('-s', '--samples', type=int, default=5000, help="number of sample points") 
     736    parser.add_argument('-q', '--qmax', type=float, default=0.5, help='max q') 
     737    parser.add_argument('-v', '--view', type=str, default='0,0,0', help='theta,phi,psi angles') 
     738    parser.add_argument('-p', '--plot', action='store_true', help='plot points') 
     739    parser.add_argument('shape', choices=SHAPES, nargs='?', default=SHAPES[0], help='oriented shape') 
     740    parser.add_argument('pars', type=str, nargs='*', help='shape parameters') 
     741    opts = parser.parse_args() 
     742    pars = {key: float(value) for p in opts.pars for key, value in [p.split('=')]} 
     743    shape, fn, fn_xy = SHAPE_FUNCTIONS[opts.shape](**pars) 
     744    title = "%s(%s)" % (opts.shape, " ".join(opts.pars)) 
     745    if opts.dim == 1: 
     746        check_shape(title, shape, fn, show_points=opts.plot, 
     747                    mesh=opts.mesh, qmax=opts.qmax, samples=opts.samples) 
     748    else: 
     749        view = tuple(float(v) for v in opts.view.split(',')) 
     750        check_shape_2d(title, shape, fn_xy, view=view, show_points=opts.plot, 
     751                       mesh=opts.mesh, qmax=opts.qmax, samples=opts.samples) 
     752 
    641753 
    642754if __name__ == "__main__": 
    643     check_cylinder(radius=10, length=40) 
    644     #check_cylinder_2d(radius=10, length=40, view=(90,30,0)) 
    645     #check_cylinder_2d_lattice(radius=10, length=50, view=(90,30,0)) 
    646     #check_sphere() 
    647     #check_csbox() 
    648     #check_csbox(da=100, db=200, dc=300) 
     755    main() 
Note: See TracChangeset for help on using the changeset viewer.