Changeset aa6989b in sasmodels for explore/jitter.py
- Timestamp:
- Oct 18, 2017 10:41:12 PM (7 years ago)
- Branches:
- master, core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 9b7b23f
- Parents:
- ef8e68c
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
explore/jitter.py
- Property mode changed from 100644 to 100755
rd4c33d6 raa6989b 1 #!/usr/bin/env python 1 2 """ 2 3 Application to explore the difference between sasview 3.x orientation 3 4 dispersity and possible replacement algorithms. 4 5 """ 5 import sys 6 from __future__ import division, print_function 7 8 import sys, os 9 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 6 10 7 11 import mpl_toolkits.mplot3d # Adds projection='3d' option to subplot … … 13 17 14 18 def draw_beam(ax, view=(0, 0)): 19 """ 20 Draw the beam going from source at (0, 0, 1) to detector at (0, 0, -1) 21 """ 15 22 #ax.plot([0,0],[0,0],[1,-1]) 16 23 #ax.scatter([0]*100,[0]*100,np.linspace(1, -1, 100), alpha=0.8) … … 33 40 ax.plot_surface(x, y, z, rstride=4, cstride=4, color='y', alpha=0.5) 34 41 35 def draw_jitter(ax, view, jitter): 36 size = [0.1, 0.4, 1.0] 42 def draw_jitter(ax, view, jitter, dist='gaussian', size=(0.1, 0.4, 1.0)): 43 """ 44 Represent jitter as a set of shapes at different orientations. 45 """ 46 # set max diagonal to 0.95 47 scale = 0.95/sqrt(sum(v**2 for v in size)) 48 size = tuple(scale*v for v in size) 37 49 draw_shape = draw_parallelepiped 38 50 #draw_shape = draw_ellipsoid … … 77 89 cloud = [v for v in cloud if v[2] == 0] 78 90 draw_shape(ax, size, view, [0, 0, 0], steps=100, alpha=0.8) 91 scale = 1/sqrt(3) if dist == 'rectangle' else 1 79 92 for point in cloud: 80 delta = [ dtheta*point[0], dphi*point[1],dpsi*point[2]]93 delta = [scale*dtheta*point[0], scale*dphi*point[1], scale*dpsi*point[2]] 81 94 draw_shape(ax, size, view, delta, alpha=0.8) 82 95 for v in 'xyz': … … 87 100 88 101 def draw_ellipsoid(ax, size, view, jitter, steps=25, alpha=1): 102 """Draw an ellipsoid.""" 89 103 a,b,c = size 90 104 u = np.linspace(0, 2 * np.pi, steps) … … 107 121 108 122 def draw_parallelepiped(ax, size, view, jitter, steps=None, alpha=1): 123 """Draw a parallelepiped.""" 109 124 a,b,c = size 110 125 x = a*np.array([ 1,-1, 1,-1, 1,-1, 1,-1]) … … 134 149 ]) 135 150 136 def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gauss'): 151 def draw_sphere(ax, radius=10., steps=100): 152 """Draw a sphere""" 153 u = np.linspace(0, 2 * np.pi, steps) 154 v = np.linspace(0, np.pi, steps) 155 156 x = radius * np.outer(np.cos(u), np.sin(v)) 157 y = radius * np.outer(np.sin(u), np.sin(v)) 158 z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) 159 ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 160 161 def draw_mesh(ax, view, jitter, radius=1.2, n=11, dist='gaussian'): 162 """ 163 Draw the dispersion mesh showing the theta-phi orientations at which 164 the model will be evaluated. 165 """ 137 166 theta, phi, psi = view 138 167 dtheta, dphi, dpsi = jitter 139 if dist == 'gauss': 168 169 if dist == 'gaussian': 140 170 t = np.linspace(-3, 3, n) 141 171 weights = exp(-0.5*t**2) 142 elif dist == 'rect': 143 t = np.linspace(0, 1, n) 172 elif dist == 'rectangle': 173 # Note: uses sasmodels ridiculous definition of rectangle width 174 t = np.linspace(-1, 1, n)*sqrt(3) 144 175 weights = np.ones_like(t) 145 176 else: 146 raise ValueError("expected dist to be 'gauss ' or 'rect'")177 raise ValueError("expected dist to be 'gaussian' or 'rectangle'") 147 178 148 179 # mesh in theta, phi formed by rotating z … … 154 185 points = orient_relative_to_beam(view, points) 155 186 156 w = np.outer(weights , weights)187 w = np.outer(weights*cos(radians(dtheta*t)), weights) 157 188 158 189 x, y, z = [np.array(v).flatten() for v in points] 159 190 ax.scatter(x, y, z, c=w.flatten(), marker='o', vmin=0., vmax=1.) 160 191 192 def draw_labels(ax, view, jitter, text): 193 """ 194 Draw text at a particular location. 195 """ 196 labels, locations, orientations = zip(*text) 197 px, py, pz = zip(*locations) 198 dx, dy, dz = zip(*orientations) 199 200 px, py, pz = transform_xyz(view, jitter, px, py, pz) 201 dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz) 202 203 # TODO: zdir for labels is broken, and labels aren't appearing. 204 for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)): 205 zdir = np.asarray(zdir).flatten() 206 ax.text(p[0], p[1], p[2], label, zdir=zdir) 207 208 # Definition of rotation matrices comes from wikipedia: 209 # https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations 161 210 def Rx(angle): 211 """Construct a matrix to rotate points about *x* by *angle* degrees.""" 162 212 a = radians(angle) 163 R = [[1 ., 0., 0.],164 [0 ., cos(a),sin(a)],165 [0 ., -sin(a),cos(a)]]213 R = [[1, 0, 0], 214 [0, +cos(a), -sin(a)], 215 [0, +sin(a), +cos(a)]] 166 216 return np.matrix(R) 167 217 168 218 def Ry(angle): 219 """Construct a matrix to rotate points about *y* by *angle* degrees.""" 169 220 a = radians(angle) 170 R = [[ cos(a), 0., -sin(a)],171 [0 ., 1., 0.],172 [ sin(a), 0.,cos(a)]]221 R = [[+cos(a), 0, +sin(a)], 222 [0, 1, 0], 223 [-sin(a), 0, +cos(a)]] 173 224 return np.matrix(R) 174 225 175 226 def Rz(angle): 227 """Construct a matrix to rotate points about *z* by *angle* degrees.""" 176 228 a = radians(angle) 177 R = [[ cos(a), -sin(a), 0.],178 [ sin(a), cos(a), 0.],179 [0 ., 0., 1.]]229 R = [[+cos(a), -sin(a), 0], 230 [+sin(a), +cos(a), 0], 231 [0, 0, 1]] 180 232 return np.matrix(R) 181 233 182 234 def transform_xyz(view, jitter, x, y, z): 235 """ 236 Send a set of (x,y,z) points through the jitter and view transforms. 237 """ 183 238 x, y, z = [np.asarray(v) for v in (x, y, z)] 184 239 shape = x.shape … … 190 245 191 246 def apply_jitter(jitter, points): 247 """ 248 Apply the jitter transform to a set of points. 249 250 Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 251 """ 192 252 dtheta, dphi, dpsi = jitter 193 253 points = Rx(dphi)*Ry(dtheta)*Rz(dpsi)*points … … 195 255 196 256 def orient_relative_to_beam(view, points): 257 """ 258 Apply the view transform to a set of points. 259 260 Points are stored in a 3 x n numpy matrix, not a numpy array or tuple. 261 """ 197 262 theta, phi, psi = view 198 263 points = Rz(phi)*Ry(theta)*Rz(psi)*points 199 264 return points 200 265 201 def draw_labels(ax, view, jitter, text): 202 labels, locations, orientations = zip(*text) 203 px, py, pz = zip(*locations) 204 dx, dy, dz = zip(*orientations) 205 206 px, py, pz = transform_xyz(view, jitter, px, py, pz) 207 dx, dy, dz = transform_xyz(view, jitter, dx, dy, dz) 208 209 for label, p, zdir in zip(labels, zip(px, py, pz), zip(dx, dy, dz)): 210 zdir = np.asarray(zdir).flatten() 211 ax.text(p[0], p[1], p[2], label, zdir=zdir) 212 213 def draw_sphere(ax, radius=10., steps=100): 214 u = np.linspace(0, 2 * np.pi, steps) 215 v = np.linspace(0, np.pi, steps) 216 217 x = radius * np.outer(np.cos(u), np.sin(v)) 218 y = radius * np.outer(np.sin(u), np.sin(v)) 219 z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) 220 ax.plot_surface(x, y, z, rstride=4, cstride=4, color='w') 221 222 def main(): 266 # translate between number of dimension of dispersity and the number of 267 # points along each dimension. 268 PD_N_TABLE = { 269 (0, 0, 0): (0, 0, 0), # 0 270 (1, 0, 0): (100, 0, 0), # 100 271 (0, 1, 0): (0, 100, 0), 272 (0, 0, 1): (0, 0, 100), 273 (1, 1, 0): (30, 30, 0), # 900 274 (1, 0, 1): (30, 0, 30), 275 (0, 1, 1): (0, 30, 30), 276 (1, 1, 1): (15, 15, 15), # 3375 277 } 278 279 def clipped_range(data, portion=1.0, mode='central'): 280 """ 281 Determine range from data. 282 283 If *portion* is 1, use full range, otherwise use the center of the range 284 or the top of the range, depending on whether *mode* is 'central' or 'top'. 285 """ 286 if portion == 1.0: 287 return data.min(), data.max() 288 elif mode == 'central': 289 data = np.sort(data.flatten()) 290 offset = int(portion*len(data)/2 + 0.5) 291 return data[offset], data[-offset] 292 elif mode == 'top': 293 data = np.sort(data.flatten()) 294 offset = int(portion*len(data) + 0.5) 295 return data[offset], data[-1] 296 297 def draw_scattering(calculator, ax, view, jitter, dist='gaussian'): 298 """ 299 Plot the scattering for the particular view. 300 301 *calculator* is returned from :func:`build_model`. *ax* are the 3D axes 302 on which the data will be plotted. *view* and *jitter* are the current 303 orientation and orientation dispersity. *dist* is one of the sasmodels 304 weight distributions. 305 """ 306 ## Sasmodels use sqrt(3)*width for the rectangle range; scale to the 307 ## proper width for comparison. Commented out since now using the 308 ## sasmodels definition of width for rectangle. 309 #scale = 1/sqrt(3) if dist == 'rectangle' else 1 310 scale = 1 311 312 # add the orientation parameters to the model parameters 313 theta, phi, psi = view 314 theta_pd, phi_pd, psi_pd = [scale*v for v in jitter] 315 theta_pd_n, phi_pd_n, psi_pd_n = PD_N_TABLE[(theta_pd>0, phi_pd>0, psi_pd>0)] 316 ## increase pd_n for testing jitter integration rather than simple viz 317 #theta_pd_n, phi_pd_n, psi_pd_n = [5*v for v in (theta_pd_n, phi_pd_n, psi_pd_n)] 318 319 pars = dict( 320 theta=theta, theta_pd=theta_pd, theta_pd_type=dist, theta_pd_n=theta_pd_n, 321 phi=phi, phi_pd=phi_pd, phi_pd_type=dist, phi_pd_n=phi_pd_n, 322 psi=psi, psi_pd=psi_pd, psi_pd_type=dist, psi_pd_n=psi_pd_n, 323 ) 324 pars.update(calculator.pars) 325 326 # compute the pattern 327 qx, qy = calculator._data.x_bins, calculator._data.y_bins 328 Iqxy = calculator(**pars).reshape(len(qx), len(qy)) 329 330 # scale it and draw it 331 Iqxy = np.log(Iqxy) 332 if calculator.limits: 333 # use limits from orientation (0,0,0) 334 vmin, vmax = calculator.limits 335 else: 336 vmin, vmax = clipped_range(Iqxy, portion=0.95, mode='top') 337 #print("range",(vmin,vmax)) 338 #qx, qy = np.meshgrid(qx, qy) 339 if 0: 340 level = np.asarray(255*(Iqxy - vmin)/(vmax - vmin), 'i') 341 level[level<0] = 0 342 colors = plt.get_cmap()(level) 343 ax.plot_surface(qx, qy, -1.1, rstride=1, cstride=1, facecolors=colors) 344 elif 1: 345 ax.contourf(qx/qx.max(), qy/qy.max(), Iqxy, zdir='z', offset=-1.1, 346 levels=np.linspace(vmin, vmax, 24)) 347 else: 348 ax.pcolormesh(qx, qy, Iqxy) 349 350 def build_model(model_name, n=150, qmax=0.5, **pars): 351 """ 352 Build a calculator for the given shape. 353 354 *model_name* is any sasmodels model. *n* and *qmax* define an n x n mesh 355 on which to evaluate the model. The remaining parameters are stored in 356 the returned calculator as *calculator.pars*. They are used by 357 :func:`draw_scattering` to set the non-orientation parameters in the 358 calculation. 359 360 Returns a *calculator* function which takes a dictionary or parameters and 361 produces Iqxy. The Iqxy value needs to be reshaped to an n x n matrix 362 for plotting. See the :class:`sasmodels.direct_model.DirectModel` class 363 for details. 364 """ 365 from sasmodels.core import load_model_info, build_model 366 from sasmodels.data import empty_data2D 367 from sasmodels.direct_model import DirectModel 368 369 model_info = load_model_info(model_name) 370 model = build_model(model_info) #, dtype='double!') 371 q = np.linspace(-qmax, qmax, n) 372 data = empty_data2D(q, q) 373 calculator = DirectModel(data, model) 374 375 # stuff the values for non-orientation parameters into the calculator 376 calculator.pars = pars.copy() 377 calculator.pars.setdefault('backgound', 1e-3) 378 379 # fix the data limits so that we can see if the pattern fades 380 # under rotation or angular dispersion 381 Iqxy = calculator(theta=0, phi=0, psi=0, **calculator.pars) 382 Iqxy = np.log(Iqxy) 383 vmin, vmax = clipped_range(Iqxy, 0.95, mode='top') 384 calculator.limits = vmin, vmax+1 385 386 return calculator 387 388 def select_calculator(model_name, n=150): 389 """ 390 Create a model calculator for the given shape. 391 392 *model_name* is one of sphere, cylinder, ellipsoid, triaxial_ellipsoid, 393 parallelepiped or bcc_paracrystal. *n* is the number of points to use 394 in the q range. *qmax* is chosen based on model parameters for the 395 given model to show something intersting. 396 397 Returns *calculator* and tuple *size* (a,b,c) giving minor and major 398 equitorial axes and polar axis respectively. See :func:`build_model` 399 for details on the returned calculator. 400 """ 401 a, b, c = 10, 40, 100 402 if model_name == 'sphere': 403 calculator = build_model('sphere', n=n, radius=c) 404 a = b = c 405 elif model_name == 'bcc_paracrystal': 406 calculator = build_model('bcc_paracrystal', n=n, dnn=c, 407 d_factor=0.06, radius=40) 408 a = b = c 409 elif model_name == 'cylinder': 410 calculator = build_model('cylinder', n=n, qmax=0.3, radius=b, length=c) 411 a = b 412 elif model_name == 'ellipsoid': 413 calculator = build_model('ellipsoid', n=n, qmax=1.0, 414 radius_polar=c, radius_equatorial=b) 415 a = b 416 elif model_name == 'triaxial_ellipsoid': 417 calculator = build_model('triaxial_ellipsoid', n=n, qmax=0.5, 418 radius_equat_minor=a, 419 radius_equat_major=b, 420 radius_polar=c) 421 elif model_name == 'parallelepiped': 422 calculator = build_model('parallelepiped', n=n, a=a, b=b, c=c) 423 else: 424 raise ValueError("unknown model %s"%model_name) 425 426 return calculator, (a, b, c) 427 428 def main(model_name='parallelepiped'): 429 """ 430 Show an interactive orientation and jitter demo. 431 432 *model_name* is one of the models available in :func:`select_model`. 433 """ 434 # set up calculator 435 calculator, size = select_calculator(model_name, n=150) 436 437 ## uncomment to set an independent the colour range for every view 438 ## If left commented, the colour range is fixed for all views 439 calculator.limits = None 440 441 ## use gaussian distribution unless testing integration 442 #dist = 'rectangle' 443 dist = 'gaussian' 444 445 ## initial view 446 #theta, dtheta = 70., 10. 447 #phi, dphi = -45., 3. 448 #psi, dpsi = -45., 3. 449 theta, phi, psi = 0, 0, 0 450 dtheta, dphi, dpsi = 0, 0, 0 451 452 ## create the plot window 223 453 #plt.hold(True) 224 454 plt.set_cmap('gist_earth') … … 227 457 #ax = plt.subplot(gs[0], projection='3d') 228 458 ax = plt.axes([0.0, 0.2, 1.0, 0.8], projection='3d') 229 230 theta, dtheta = 70., 10. 231 phi, dphi = -45., 3. 232 psi, dpsi = -45., 3. 233 theta, phi, psi = 0, 0, 0 234 dtheta, dphi, dpsi = 0, 0, 0 235 #dist = 'rect' 236 dist = 'gauss' 459 ax.axis('square') 237 460 238 461 axcolor = 'lightgoldenrodyellow' 239 462 463 ## add control widgets to plot 240 464 axtheta = plt.axes([0.1, 0.15, 0.45, 0.04], axisbg=axcolor) 241 465 axphi = plt.axes([0.1, 0.1, 0.45, 0.04], axisbg=axcolor) … … 248 472 axdphi = plt.axes([0.75, 0.1, 0.15, 0.04], axisbg=axcolor) 249 473 axdpsi= plt.axes([0.75, 0.05, 0.15, 0.04], axisbg=axcolor) 250 sdtheta = Slider(axdtheta, 'dTheta', 0, 30, valinit=dtheta) 251 sdphi = Slider(axdphi, 'dPhi', 0, 30, valinit=dphi) 252 sdpsi = Slider(axdpsi, 'dPsi', 0, 30, valinit=dpsi) 253 474 # Note: using ridiculous definition of rectangle distribution, whose width 475 # in sasmodels is sqrt(3) times the given width. Divide by sqrt(3) to keep 476 # the maximum width to 90. 477 dlimit = 30 if dist == 'gaussian' else 90/sqrt(3) 478 sdtheta = Slider(axdtheta, 'dTheta', 0, dlimit, valinit=dtheta) 479 sdphi = Slider(axdphi, 'dPhi', 0, 2*dlimit, valinit=dphi) 480 sdpsi = Slider(axdpsi, 'dPsi', 0, 2*dlimit, valinit=dpsi) 481 482 ## callback to draw the new view 254 483 def update(val, axis=None): 255 484 view = stheta.val, sphi.val, spsi.val 256 485 jitter = sdtheta.val, sdphi.val, sdpsi.val 486 # set small jitter as 0 if multiple pd dims 487 dims = sum(v > 0 for v in jitter) 488 limit = [0, 0, 2, 5][dims] 489 jitter = [0 if v < limit else v for v in jitter] 257 490 ax.cla() 258 491 draw_beam(ax, (0, 0)) 259 draw_jitter(ax, view, jitter )492 draw_jitter(ax, view, jitter, dist=dist, size=size) 260 493 #draw_jitter(ax, view, (0,0,0)) 261 draw_mesh(ax, view, jitter) 494 draw_mesh(ax, view, jitter, dist=dist) 495 draw_scattering(calculator, ax, view, jitter, dist=dist) 262 496 plt.gcf().canvas.draw() 263 497 498 ## bind control widgets to view updater 264 499 stheta.on_changed(lambda v: update(v,'theta')) 265 500 sphi.on_changed(lambda v: update(v, 'phi')) … … 269 504 sdpsi.on_changed(lambda v: update(v, 'dpsi')) 270 505 506 ## initialize view 271 507 update(None, 'phi') 272 508 509 ## go interactive 273 510 plt.show() 274 511 275 512 if __name__ == "__main__": 276 main() 513 model_name = sys.argv[1] if len(sys.argv) > 1 else 'parallelepiped' 514 main(model_name)
Note: See TracChangeset
for help on using the changeset viewer.