source: sasmodels/sasmodels/data.py @ 2c1bb7b0

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 2c1bb7b0 was 2c1bb7b0, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

fix mask broken by recent data refactoring

  • Property mode set to 100644
File size: 14.4 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np
38
39def load_data(filename):
40    """
41    Load data using a sasview loader.
42    """
43    from sas.dataloader.loader import Loader
44    loader = Loader()
45    data = loader.load(filename)
46    if data is None:
47        raise IOError("Data %r could not be loaded" % filename)
48    return data
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54    """
55    from sas.dataloader.manipulations import Ringcut
56    if hasattr(data, 'qx_data'):
57        data.mask = Ringcut(0, radius)(data)
58        if outer is not None:
59            data.mask += Ringcut(outer, np.inf)(data)
60    else:
61        data.mask = (data.x < radius)
62        if outer is not None:
63            data.mask |= (data.x >= outer)
64
65
66def set_half(data, half):
67    """
68    Select half of the data, either "right" or "left".
69    """
70    from sas.dataloader.manipulations import Boxcut
71    if half == 'right':
72        data.mask += \
73            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
74    if half == 'left':
75        data.mask += \
76            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
77
78
79def set_top(data, cutoff):
80    """
81    Chop the top off the data, above *cutoff*.
82    """
83    from sas.dataloader.manipulations import Boxcut
84    data.mask += \
85        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
86
87
88class Data1D(object):
89    def __init__(self, x=None, y=None, dx=None, dy=None):
90        self.x, self.y, self.dx, self.dy = x, y, dx, dy
91        self.dxl = None
92        self.filename = None
93        self.qmin = x.min() if x is not None else np.NaN
94        self.qmax = x.max() if x is not None else np.NaN
95        # TODO: why is 1D mask False and 2D mask True?
96        self.mask = (np.isnan(y) if y is not None
97                     else np.zeros_like(x,'b') if x is not None
98                     else None)
99        self._xaxis, self._xunit = "x", ""
100        self._yaxis, self._yunit = "y", ""
101
102    def xaxis(self, label, unit):
103        """
104        set the x axis label and unit
105        """
106        self._xaxis = label
107        self._xunit = unit
108
109    def yaxis(self, label, unit):
110        """
111        set the y axis label and unit
112        """
113        self._yaxis = label
114        self._yunit = unit
115
116
117
118class Data2D(object):
119    def __init__(self, x=None, y=None, z=None, dx=None, dy=None, dz=None):
120        self.qx_data, self.dqx_data = x, dx
121        self.qy_data, self.dqy_data = y, dy
122        self.data, self.err_data = z, dz
123        self.mask = (~np.isnan(z) if z is not None
124                     else np.ones_like(x) if x is not None
125                     else None)
126        self.q_data = np.sqrt(x**2 + y**2)
127        self.qmin = 1e-16
128        self.qmax = np.inf
129        self.detector = []
130        self.source = Source()
131        self.Q_unit = "1/A"
132        self.I_unit = "1/cm"
133        self.xaxis("Q_x", "A^{-1}")
134        self.yaxis("Q_y", "A^{-1}")
135        self.zaxis("Intensity", r"\text{cm}^{-1}")
136        self._xaxis, self._xunit = "x", ""
137        self._yaxis, self._yunit = "y", ""
138        self._zaxis, self._zunit = "z", ""
139        self.x_bins, self.y_bins = None, None
140
141    def xaxis(self, label, unit):
142        """
143        set the x axis label and unit
144        """
145        self._xaxis = label
146        self._xunit = unit
147
148    def yaxis(self, label, unit):
149        """
150        set the y axis label and unit
151        """
152        self._yaxis = label
153        self._yunit = unit
154
155    def zaxis(self, label, unit):
156        """
157        set the y axis label and unit
158        """
159        self._zaxis = label
160        self._zunit = unit
161
162
163class Vector(object):
164    def __init__(self, x=None, y=None, z=None):
165        self.x, self.y, self.z = x, y, z
166
167class Detector(object):
168    """
169    Detector attributes.
170    """
171    def __init__(self, pixel_size=(None, None), distance=None):
172        self.pixel_size = Vector(*pixel_size)
173        self.distance = distance
174
175class Source(object):
176    """
177    Beam attributes.
178    """
179    def __init__(self):
180        self.wavelength = np.NaN
181        self.wavelength_unit = "A"
182
183
184def empty_data1D(q, resolution=0.05):
185    """
186    Create empty 1D data using the given *q* as the x value.
187
188    *resolution* dq/q defaults to 5%.
189    """
190
191    #Iq = 100 * np.ones_like(q)
192    #dIq = np.sqrt(Iq)
193    Iq, dIq = None, None
194    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
195    data.filename = "fake data"
196    return data
197
198
199def empty_data2D(qx, qy=None, resolution=0.05):
200    """
201    Create empty 2D data using the given mesh.
202
203    If *qy* is missing, create a square mesh with *qy=qx*.
204
205    *resolution* dq/q defaults to 5%.
206    """
207    if qy is None:
208        qy = qx
209    # 5% dQ/Q resolution
210    Qx, Qy = np.meshgrid(qx, qy)
211    Qx, Qy = Qx.flatten(), Qy.flatten()
212    Iq = 100 * np.ones_like(Qx)
213    dIq = np.sqrt(Iq)
214    if resolution != 0:
215        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
216        # Should have an additional constant which depends on distances and
217        # radii of the aperture, pixel dimensions and wavelength spread
218        # Instead, assume radial dQ/Q is constant, and perpendicular matches
219        # radial (which instead it should be inverse).
220        Q = np.sqrt(Qx**2 + Qy**2)
221        dqx = resolution * Q
222        dqy = resolution * Q
223    else:
224        dqx = dqy = None
225
226    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
227    data.x_bins = qx
228    data.y_bins = qy
229    data.filename = "fake data"
230
231    # pixel_size in mm, distance in m
232    detector = Detector(pixel_size=(5, 5), distance=4)
233    data.detector.append(detector)
234    data.source.wavelength = 5 # angstroms
235    data.source.wavelength_unit = "A"
236    return data
237
238
239def plot_data(data, view='log', limits=None):
240    """
241    Plot data loaded by the sasview loader.
242    """
243    # Note: kind of weird using the plot result functions to plot just the
244    # data, but they already handle the masking and graph markup already, so
245    # do not repeat.
246    if hasattr(data, 'lam'):
247        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
248    elif hasattr(data, 'qx_data'):
249        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
250    else:
251        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
252
253
254def plot_theory(data, theory, resid=None, view='log',
255                use_data=True, limits=None):
256    if hasattr(data, 'lam'):
257        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
258    elif hasattr(data, 'qx_data'):
259        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
260    else:
261        _plot_result1D(data, theory, resid, view, use_data, limits=limits)
262
263
264def protect(fn):
265    def wrapper(*args, **kw):
266        try:
267            return fn(*args, **kw)
268        except:
269            traceback.print_exc()
270
271    return wrapper
272
273
274@protect
275def _plot_result1D(data, theory, resid, view, use_data, limits=None):
276    """
277    Plot the data and residuals for 1D data.
278    """
279    import matplotlib.pyplot as plt
280    from numpy.ma import masked_array, masked
281
282    use_data = use_data and data.y is not None
283    use_theory = theory is not None
284    use_resid = resid is not None
285    num_plots = (use_data or use_theory) + use_resid
286
287    scale = data.x**4 if view == 'q4' else 1.0
288
289    if use_data or use_theory:
290        #print(vmin, vmax)
291        all_positive = True
292        some_present = False
293        if use_data:
294            mdata = masked_array(data.y, data.mask.copy())
295            mdata[~np.isfinite(mdata)] = masked
296            if view is 'log':
297                mdata[mdata <= 0] = masked
298            plt.errorbar(data.x/10, scale*mdata, yerr=data.dy, fmt='.')
299            all_positive = all_positive and (mdata > 0).all()
300            some_present = some_present or (mdata.count() > 0)
301
302
303        if use_theory:
304            mtheory = masked_array(theory, data.mask.copy())
305            mtheory[~np.isfinite(mtheory)] = masked
306            if view is 'log':
307                mtheory[mtheory <= 0] = masked
308            plt.plot(data.x/10, scale*mtheory, '-', hold=True)
309            all_positive = all_positive and (mtheory > 0).all()
310            some_present = some_present or (mtheory.count() > 0)
311
312        if limits is not None:
313            plt.ylim(*limits)
314
315        if num_plots > 1:
316            plt.subplot(1, num_plots, 1)
317        plt.xscale('linear' if not some_present else view)
318        plt.yscale('linear'
319                   if view == 'q4' or not some_present or not all_positive
320                   else view)
321        plt.xlabel("$q$/nm$^{-1}$")
322        plt.ylabel('$I(q)$')
323
324    if use_resid:
325        mresid = masked_array(resid, data.mask.copy())
326        mresid[~np.isfinite(mresid)] = masked
327        some_present = (mresid.count() > 0)
328
329        if num_plots > 1:
330            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
331        plt.plot(data.x/10, mresid, '-')
332        plt.xlabel("$q$/nm$^{-1}$")
333        plt.ylabel('residuals')
334        plt.xscale('linear' if not some_present else view)
335
336
337@protect
338def _plot_result_sesans(data, theory, resid, use_data, limits=None):
339    import matplotlib.pyplot as plt
340    use_data = use_data and data.y is not None
341    use_theory = theory is not None
342    use_resid = resid is not None
343    num_plots = (use_data or use_theory) + use_resid
344
345    if use_data or use_theory:
346        if num_plots > 1:
347            plt.subplot(1, num_plots, 1)
348        if use_data:
349            plt.errorbar(data.x, data.y, yerr=data.dy)
350        if theory is not None:
351            plt.plot(data.x, theory, '-', hold=True)
352        if limits is not None:
353            plt.ylim(*limits)
354        plt.xlabel('spin echo length (nm)')
355        plt.ylabel('polarization (P/P0)')
356
357    if resid is not None:
358        if num_plots > 1:
359            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
360        plt.plot(data.x, resid, 'x')
361        plt.xlabel('spin echo length (nm)')
362        plt.ylabel('residuals (P/P0)')
363
364
365@protect
366def _plot_result2D(data, theory, resid, view, use_data, limits=None):
367    """
368    Plot the data and residuals for 2D data.
369    """
370    import matplotlib.pyplot as plt
371    use_data = use_data and data.data is not None
372    use_theory = theory is not None
373    use_resid = resid is not None
374    num_plots = use_data + use_theory + use_resid
375
376    # Put theory and data on a common colormap scale
377    vmin, vmax = np.inf, -np.inf
378    if use_data:
379        target = data.data[~data.mask]
380        datamin = target[target > 0].min() if view == 'log' else target.min()
381        datamax = target.max()
382        vmin = min(vmin, datamin)
383        vmax = max(vmax, datamax)
384    if use_theory:
385        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
386        theorymax = theory.max()
387        vmin = min(vmin, theorymin)
388        vmax = max(vmax, theorymax)
389
390    # Override data limits from the caller
391    if limits is not None:
392        vmin, vmax = limits
393
394    # Plot data
395    if use_data:
396        if num_plots > 1:
397            plt.subplot(1, num_plots, 1)
398        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
399        plt.title('data')
400        h = plt.colorbar()
401        h.set_label('$I(q)$')
402
403    # plot theory
404    if use_theory:
405        if num_plots > 1:
406            plt.subplot(1, num_plots, use_data+1)
407        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
408        plt.title('theory')
409        h = plt.colorbar()
410        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
411                    else r'$q^4 I(q)$' if view == 'q4'
412                    else '$I(q)$')
413
414    # plot resid
415    if use_resid:
416        if num_plots > 1:
417            plt.subplot(1, num_plots, use_data+use_theory+1)
418        _plot_2d_signal(data, resid, view='linear')
419        plt.title('residuals')
420        h = plt.colorbar()
421        h.set_label(r'$\Delta I(q)$')
422
423
424@protect
425def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
426    """
427    Plot the target value for the data.  This could be the data itself,
428    the theory calculation, or the residuals.
429
430    *scale* can be 'log' for log scale data, or 'linear'.
431    """
432    import matplotlib.pyplot as plt
433    from numpy.ma import masked_array
434
435    image = np.zeros_like(data.qx_data)
436    image[~data.mask] = signal
437    valid = np.isfinite(image)
438    if view == 'log':
439        valid[valid] = (image[valid] > 0)
440        if vmin is None: vmin = image[valid & ~data.mask].min()
441        if vmax is None: vmax = image[valid & ~data.mask].max()
442        image[valid] = np.log10(image[valid])
443    elif view == 'q4':
444        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
445        if vmin is None: vmin = image[valid & ~data.mask].min()
446        if vmax is None: vmax = image[valid & ~data.mask].max()
447    else:
448        if vmin is None: vmin = image[valid & ~data.mask].min()
449        if vmax is None: vmax = image[valid & ~data.mask].max()
450
451    image[~valid | data.mask] = 0
452    #plottable = Iq
453    plottable = masked_array(image, ~valid | data.mask)
454    xmin, xmax = min(data.qx_data)/10, max(data.qx_data)/10
455    ymin, ymax = min(data.qy_data)/10, max(data.qy_data)/10
456    if view == 'log':
457        vmin, vmax = np.log10(vmin), np.log10(vmax)
458    plt.imshow(plottable.reshape(len(data.x_bins), len(data.y_bins)),
459               interpolation='nearest', aspect=1, origin='upper',
460               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
461    plt.xlabel("$q_x$/nm$^{-1}$")
462    plt.ylabel("$q_y$/nm$^{-1}$")
463    return vmin, vmax
464
465def demo():
466    data = load_data('DEC07086.DAT')
467    set_beam_stop(data, 0.004)
468    plot_data(data)
469    import matplotlib.pyplot as plt; plt.show()
470
471
472if __name__ == "__main__":
473    demo()
Note: See TracBrowser for help on using the repository browser.