Changeset 7cf2cfd in sasmodels for sasmodels/bumps_model.py
- Timestamp:
- Nov 22, 2015 11:37:15 PM (8 years ago)
- Branches:
- master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
- Children:
- 3b4243d
- Parents:
- 677ccf1
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/bumps_model.py
r5d80bbf r7cf2cfd 10 10 how far the polydispersity integral extends. 11 11 12 A variety of helper functions are provided:13 14 :func:`load_data` loads a sasview data file.15 16 :func:`empty_data1D` creates an empty dataset, which is useful for plotting17 a theory function before the data is measured.18 19 :func:`empty_data2D` creates an empty 2D dataset.20 21 :func:`set_beam_stop` masks the beam stop from the data.22 23 :func:`set_half` selects the right or left half of the data, which can24 be useful for shear measurements which have not been properly corrected25 for path length and reflections.26 27 :func:`set_top` cuts the top part off the data.28 29 :func:`plot_data` plots the data file.30 31 :func:`plot_theory` plots a calculated result from the model.32 33 12 """ 34 13 35 14 import datetime 36 15 import warnings 37 import traceback38 16 39 17 import numpy as np 40 18 19 from bumps.names import Parameter 20 41 21 from . import sesans 42 from .resolution import Perfect1D, Pinhole1D, Slit1D 43 from .resolution2d import Pinhole2D 44 45 # CRUFT python 2.6 46 if not hasattr(datetime.timedelta, 'total_seconds'): 47 def delay(dt): 48 """Return number date-time delta as number seconds""" 49 return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds 50 else: 51 def delay(dt): 52 """Return number date-time delta as number seconds""" 53 return dt.total_seconds() 54 22 from . import weights 23 from .data import plot_theory 24 from .direct_model import DataMixin 55 25 56 26 # CRUFT: old style bumps wrapper which doesn't separate data and model … … 64 34 65 35 66 67 def tic():68 """69 Timer function.70 71 Use "toc=tic()" to start the clock and "toc()" to measure72 a time interval.73 """74 then = datetime.datetime.now()75 return lambda: delay(datetime.datetime.now() - then)76 77 78 def load_data(filename):79 """80 Load data using a sasview loader.81 """82 from sas.dataloader.loader import Loader83 loader = Loader()84 data = loader.load(filename)85 if data is None:86 raise IOError("Data %r could not be loaded" % filename)87 return data88 89 def plot_data(data, view='log'):90 """91 Plot data loaded by the sasview loader.92 """93 if hasattr(data, 'qx_data'):94 _plot_2d_signal(data, data.data, view=view)95 else:96 # Note: kind of weird using the _plot_result1D to plot just the97 # data, but it handles the masking and graph markup already, so98 # do not repeat.99 _plot_result1D(data, None, None, view)100 101 def plot_theory(data, theory, view='log'):102 if hasattr(data, 'qx_data'):103 _plot_2d_signal(data, theory, view=view)104 else:105 _plot_result1D(data, theory, None, view, include_data=False)106 107 108 def empty_data1D(q, resolution=0.05):109 """110 Create empty 1D data using the given *q* as the x value.111 112 *resolution* dq/q defaults to 5%.113 """114 115 from sas.dataloader.data_info import Data1D116 117 Iq = 100 * np.ones_like(q)118 dIq = np.sqrt(Iq)119 data = Data1D(q, Iq, dx=resolution * q, dy=dIq)120 data.filename = "fake data"121 data.qmin, data.qmax = q.min(), q.max()122 data.mask = np.zeros(len(Iq), dtype='bool')123 return data124 125 126 def empty_data2D(qx, qy=None, resolution=0.05):127 """128 Create empty 2D data using the given mesh.129 130 If *qy* is missing, create a square mesh with *qy=qx*.131 132 *resolution* dq/q defaults to 5%.133 """134 from sas.dataloader.data_info import Data2D, Detector135 136 if qy is None:137 qy = qx138 Qx, Qy = np.meshgrid(qx, qy)139 Qx, Qy = Qx.flatten(), Qy.flatten()140 Iq = 100 * np.ones_like(Qx)141 dIq = np.sqrt(Iq)142 mask = np.ones(len(Iq), dtype='bool')143 144 data = Data2D()145 data.filename = "fake data"146 data.qx_data = Qx147 data.qy_data = Qy148 data.data = Iq149 data.err_data = dIq150 data.mask = mask151 data.qmin = 1e-16152 data.qmax = np.inf153 154 # 5% dQ/Q resolution155 if resolution != 0:156 # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf157 # Should have an additional constant which depends on distances and158 # radii of the aperture, pixel dimensions and wavelength spread159 # Instead, assume radial dQ/Q is constant, and perpendicular matches160 # radial (which instead it should be inverse).161 Q = np.sqrt(Qx**2 + Qy**2)162 data.dqx_data = resolution * Q163 data.dqy_data = resolution * Q164 165 detector = Detector()166 detector.pixel_size.x = 5 # mm167 detector.pixel_size.y = 5 # mm168 detector.distance = 4 # m169 data.detector.append(detector)170 data.xbins = qx171 data.ybins = qy172 data.source.wavelength = 5 # angstroms173 data.source.wavelength_unit = "A"174 data.Q_unit = "1/A"175 data.I_unit = "1/cm"176 data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)177 data.xaxis("Q_x", "A^{-1}")178 data.yaxis("Q_y", "A^{-1}")179 data.zaxis("Intensity", r"\text{cm}^{-1}")180 return data181 182 183 def set_beam_stop(data, radius, outer=None):184 """185 Add a beam stop of the given *radius*. If *outer*, make an annulus.186 """187 from sas.dataloader.manipulations import Ringcut188 if hasattr(data, 'qx_data'):189 data.mask = Ringcut(0, radius)(data)190 if outer is not None:191 data.mask += Ringcut(outer, np.inf)(data)192 else:193 data.mask = (data.x >= radius)194 if outer is not None:195 data.mask &= (data.x < outer)196 197 198 def set_half(data, half):199 """200 Select half of the data, either "right" or "left".201 """202 from sas.dataloader.manipulations import Boxcut203 if half == 'right':204 data.mask += \205 Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)206 if half == 'left':207 data.mask += \208 Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)209 210 211 def set_top(data, cutoff):212 """213 Chop the top off the data, above *cutoff*.214 """215 from sas.dataloader.manipulations import Boxcut216 data.mask += \217 Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)218 219 def protect(fn):220 def wrapper(*args, **kw):221 try:222 return fn(*args, **kw)223 except:224 traceback.print_exc()225 return wrapper226 227 @protect228 def _plot_result1D(data, theory, resid, view, include_data=True):229 """230 Plot the data and residuals for 1D data.231 """232 import matplotlib.pyplot as plt233 from numpy.ma import masked_array, masked234 #print "not a number",sum(np.isnan(data.y))235 #data.y[data.y<0.05] = 0.5236 mdata = masked_array(data.y, data.mask)237 mdata[~np.isfinite(mdata)] = masked238 if view is 'log':239 mdata[mdata <= 0] = masked240 241 scale = data.x**4 if view == 'q4' else 1.0242 if resid is not None:243 plt.subplot(121)244 245 positive = False246 if include_data:247 plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')248 positive = positive or (mdata>0).any()249 if theory is not None:250 mtheory = masked_array(theory, mdata.mask)251 plt.plot(data.x, scale*mtheory, '-', hold=True)252 positive = positive or (mtheory>0).any()253 plt.xscale(view)254 plt.yscale('linear' if view == 'q4' or not positive else view)255 plt.xlabel('Q')256 plt.ylabel('I(Q)')257 if resid is not None:258 mresid = masked_array(resid, mdata.mask)259 plt.subplot(122)260 plt.plot(data.x, mresid, 'x')261 plt.ylabel('residuals')262 plt.xlabel('Q')263 plt.xscale(view)264 265 # pylint: disable=unused-argument266 @protect267 def _plot_sesans(data, theory, resid, view):268 import matplotlib.pyplot as plt269 plt.subplot(121)270 plt.errorbar(data.x, data.y, yerr=data.dy)271 plt.plot(data.x, theory, '-', hold=True)272 plt.xlabel('spin echo length (nm)')273 plt.ylabel('polarization (P/P0)')274 plt.subplot(122)275 plt.plot(data.x, resid, 'x')276 plt.xlabel('spin echo length (nm)')277 plt.ylabel('residuals (P/P0)')278 279 @protect280 def _plot_result2D(data, theory, resid, view):281 """282 Plot the data and residuals for 2D data.283 """284 import matplotlib.pyplot as plt285 target = data.data[~data.mask]286 if view == 'log':287 vmin = min(target[target>0].min(), theory[theory>0].min())288 vmax = max(target.max(), theory.max())289 else:290 vmin = min(target.min(), theory.min())291 vmax = max(target.max(), theory.max())292 #print vmin, vmax293 plt.subplot(131)294 _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)295 plt.title('data')296 plt.colorbar()297 plt.subplot(132)298 _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)299 plt.title('theory')300 plt.colorbar()301 plt.subplot(133)302 _plot_2d_signal(data, resid, view='linear')303 plt.title('residuals')304 plt.colorbar()305 306 @protect307 def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):308 """309 Plot the target value for the data. This could be the data itself,310 the theory calculation, or the residuals.311 312 *scale* can be 'log' for log scale data, or 'linear'.313 """314 import matplotlib.pyplot as plt315 from numpy.ma import masked_array316 317 image = np.zeros_like(data.qx_data)318 image[~data.mask] = signal319 valid = np.isfinite(image)320 if view == 'log':321 valid[valid] = (image[valid] > 0)322 image[valid] = np.log10(image[valid])323 elif view == 'q4':324 image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2325 image[~valid | data.mask] = 0326 #plottable = Iq327 plottable = masked_array(image, ~valid | data.mask)328 xmin, xmax = min(data.qx_data), max(data.qx_data)329 ymin, ymax = min(data.qy_data), max(data.qy_data)330 # TODO: fix vmin, vmax so it is shared for theory/resid331 vmin = vmax = None332 try:333 if vmin is None: vmin = image[valid & ~data.mask].min()334 if vmax is None: vmax = image[valid & ~data.mask].max()335 except:336 vmin, vmax = 0, 1337 #print vmin,vmax338 plt.imshow(plottable.reshape(128, 128),339 interpolation='nearest', aspect=1, origin='upper',340 extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)341 342 343 36 class Model(object): 344 def __init__(self, kernel, **kw): 345 from bumps.names import Parameter 346 347 self.kernel = kernel 348 partype = kernel.info['partype'] 37 def __init__(self, model, **kw): 38 self._sasmodel = model 39 partype = model.info['partype'] 349 40 350 41 pars = [] 351 for p in kernel.info['parameters']:42 for p in model.info['parameters']: 352 43 name, default, limits = p[0], p[2], p[3] 353 44 value = kw.pop(name, default) … … 378 69 return dict((k, getattr(self, k)) for k in self._parameter_names) 379 70 380 class Experiment(object): 71 72 class Experiment(DataMixin): 381 73 """ 382 74 Return a bumps wrapper for a SAS model. … … 397 89 398 90 # remember inputs so we can inspect from outside 399 self.data = data400 91 self.model = model 401 92 self.cutoff = cutoff 402 if hasattr(data, 'lam'): 403 self.data_type = 'sesans' 404 elif hasattr(data, 'qx_data'): 405 self.data_type = 'Iqxy' 406 else: 407 self.data_type = 'Iq' 408 409 # interpret data 410 partype = model.kernel.info['partype'] 411 if self.data_type == 'sesans': 412 q = sesans.make_q(data.sample.zacceptance, data.Rmax) 413 self.index = slice(None, None) 414 self.Iq = data.y 415 self.dIq = data.dy 416 #self._theory = np.zeros_like(q) 417 q_vectors = [q] 418 elif self.data_type == 'Iqxy': 419 q = np.sqrt(data.qx_data**2 + data.qy_data**2) 420 qmin = getattr(data, 'qmin', 1e-16) 421 qmax = getattr(data, 'qmax', np.inf) 422 accuracy = getattr(data, 'accuracy', 'Low') 423 self.index = (~data.mask) & (~np.isnan(data.data)) \ 424 & (q >= qmin) & (q <= qmax) 425 self.Iq = data.data[self.index] 426 self.dIq = data.err_data[self.index] 427 self.resolution = Pinhole2D(data=data, index=self.index, 428 nsigma=3.0, accuracy=accuracy) 429 #self._theory = np.zeros_like(self.Iq) 430 if not partype['orientation'] and not partype['magnetic']: 431 raise ValueError("not 2D without orientation or magnetic parameters") 432 #qx,qy = self.resolution.q_calc 433 #q_vectors = [np.sqrt(qx**2 + qy**2)] 434 else: 435 q_vectors = self.resolution.q_calc 436 elif self.data_type == 'Iq': 437 self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y) 438 self.Iq = data.y[self.index] 439 self.dIq = data.dy[self.index] 440 if getattr(data, 'dx', None) is not None: 441 q, dq = data.x[self.index], data.dx[self.index] 442 if (dq>0).any(): 443 self.resolution = Pinhole1D(q, dq) 444 else: 445 self.resolution = Perfect1D(q) 446 elif (getattr(data, 'dxl', None) is not None and 447 getattr(data, 'dxw', None) is not None): 448 q = data.x[self.index] 449 width = data.dxh[self.index] # Note: dx 450 self.resolution = Slit1D(data.x[self.index], 451 width=data.dxh[self.index], 452 height=data.dxw[self.index]) 453 else: 454 self.resolution = Perfect1D(data.x[self.index]) 455 456 #self._theory = np.zeros_like(self.Iq) 457 q_vectors = [self.resolution.q_calc] 458 else: 459 raise ValueError("Unknown data type") # never gets here 460 461 # Remember function inputs so we can delay loading the function and 462 # so we can save/restore state 463 self._fn_inputs = [v for v in q_vectors] 464 self._fn = None 465 93 self._interpret_data(data, model._sasmodel) 466 94 self.update() 467 95 … … 483 111 def theory(self): 484 112 if 'theory' not in self._cache: 113 pars = dict((k, v.value) for k,v in self.model.parameters().items()) 114 self._cache['theory'] = self._calc_theory(pars, cutoff=self.cutoff) 115 """ 485 116 if self._fn is None: 486 q_input = self.model.kernel.make_input(self._ fn_inputs)117 q_input = self.model.kernel.make_input(self._kernel_inputs) 487 118 self._fn = self.model.kernel(q_input) 488 119 … … 495 126 result = sesans.hankel(self.data.x, self.data.lam * 1e-9, 496 127 self.data.sample.thickness / 10, 497 self._ fn_inputs[0], Iq_calc)128 self._kernel_inputs[0], Iq_calc) 498 129 self._cache['theory'] = result 499 130 else: 500 131 Iq = self.resolution.apply(Iq_calc) 501 132 self._cache['theory'] = Iq 133 """ 502 134 return self._cache['theory'] 503 135 … … 518 150 Plot the data and residuals. 519 151 """ 520 data, theory, resid = self.data, self.theory(), self.residuals() 521 if self.data_type == 'Iq': 522 _plot_result1D(data, theory, resid, view) 523 elif self.data_type == 'Iqxy': 524 _plot_result2D(data, theory, resid, view) 525 elif self.data_type == 'sesans': 526 _plot_sesans(data, theory, resid, view) 527 else: 528 raise ValueError("Unknown data type") 152 data, theory, resid = self._data, self.theory(), self.residuals() 153 plot_theory(data, theory, resid, view) 529 154 530 155 def simulate_data(self, noise=None): 531 theory = self.theory() 532 if noise is not None: 533 self.dIq = theory*noise*0.01 534 dy = self.dIq 535 y = theory + np.random.randn(*dy.shape) * dy 536 self.Iq = y 537 if self.data_type == 'Iq': 538 self.data.dy[self.index] = dy 539 self.data.y[self.index] = y 540 elif self.data_type == 'Iqxy': 541 self.data.data[self.index] = y 542 elif self.data_type == 'sesans': 543 self.data.y[self.index] = y 544 else: 545 raise ValueError("Unknown model") 156 Iq = self.theory() 157 self._set_data(Iq, noise) 546 158 547 159 def save(self, basename): 548 160 pass 549 161 550 def _get_weights(self, par):162 def remove_get_weights(self, name): 551 163 """ 552 164 Get parameter dispersion weights 553 165 """ 554 from . import weights 555 556 relative = self.model.kernel.info['partype']['pd-rel'] 557 limits = self.model.kernel.info['limits'] 166 info = self.model.kernel.info 167 relative = name in info['partype']['pd-rel'] 168 limits = info['limits'][name] 558 169 disperser, value, npts, width, nsigma = [ 559 getattr(self.model, par+ ext)170 getattr(self.model, name + ext) 560 171 for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')] 561 172 value, weight = weights.get_weights( 562 173 disperser, int(npts.value), width.value, nsigma.value, 563 value.value, limits [par], par inrelative)174 value.value, limits, relative) 564 175 return value, weight / np.sum(weight) 565 176 … … 567 178 # Can't pickle gpu functions, so instead make them lazy 568 179 state = self.__dict__.copy() 569 state['_ fn'] = None180 state['_kernel'] = None 570 181 return state 571 182 … … 573 184 # pylint: disable=attribute-defined-outside-init 574 185 self.__dict__ = state 575 576 577 def demo():578 data = load_data('DEC07086.DAT')579 set_beam_stop(data, 0.004)580 plot_data(data)581 import matplotlib.pyplot as plt; plt.show()582 583 584 if __name__ == "__main__":585 demo()
Note: See TracChangeset
for help on using the changeset viewer.