source: sasmodels/sasmodels/data.py @ 644430f

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 644430f was 644430f, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix plots. 1D: handle all NaN. 2D: handle arbitrary detector shape. both: add x,y labels

  • Property mode set to 100644
File size: 12.8 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np
38
39def load_data(filename):
40    """
41    Load data using a sasview loader.
42    """
43    from sas.dataloader.loader import Loader
44    loader = Loader()
45    data = loader.load(filename)
46    if data is None:
47        raise IOError("Data %r could not be loaded" % filename)
48    return data
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54    """
55    from sas.dataloader.manipulations import Ringcut
56    if hasattr(data, 'qx_data'):
57        data.mask = Ringcut(0, radius)(data)
58        if outer is not None:
59            data.mask += Ringcut(outer, np.inf)(data)
60    else:
61        data.mask = (data.x < radius)
62        if outer is not None:
63            data.mask |= (data.x >= outer)
64
65
66def set_half(data, half):
67    """
68    Select half of the data, either "right" or "left".
69    """
70    from sas.dataloader.manipulations import Boxcut
71    if half == 'right':
72        data.mask += \
73            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
74    if half == 'left':
75        data.mask += \
76            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
77
78
79def set_top(data, cutoff):
80    """
81    Chop the top off the data, above *cutoff*.
82    """
83    from sas.dataloader.manipulations import Boxcut
84    data.mask += \
85        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
86
87
88class Data1D(object):
89    def __init__(self, x=None, y=None, dx=None, dy=None):
90        self.x, self.y, self.dx, self.dy = x, y, dx, dy
91        self.dxl = None
92
93    def xaxis(self, label, unit):
94        """
95        set the x axis label and unit
96        """
97        self._xaxis = label
98        self._xunit = unit
99
100    def yaxis(self, label, unit):
101        """
102        set the y axis label and unit
103        """
104        self._yaxis = label
105        self._yunit = unit
106
107
108
109class Data2D(object):
110    def __init__(self):
111        self.detector = []
112        self.source = Source()
113
114    def xaxis(self, label, unit):
115        """
116        set the x axis label and unit
117        """
118        self._xaxis = label
119        self._xunit = unit
120
121    def yaxis(self, label, unit):
122        """
123        set the y axis label and unit
124        """
125        self._yaxis = label
126        self._yunit = unit
127
128    def zaxis(self, label, unit):
129        """
130        set the y axis label and unit
131        """
132        self._zaxis = label
133        self._zunit = unit
134
135
136class Vector(object):
137    def __init__(self, x=None, y=None, z=None):
138        self.x, self.y, self.z = x, y, z
139
140class Detector(object):
141    def __init__(self):
142        self.pixel_size = Vector()
143
144class Source(object):
145    pass
146
147
148def empty_data1D(q, resolution=0.05):
149    """
150    Create empty 1D data using the given *q* as the x value.
151
152    *resolution* dq/q defaults to 5%.
153    """
154
155    #Iq = 100 * np.ones_like(q)
156    #dIq = np.sqrt(Iq)
157    Iq, dIq = None, None
158    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
159    data.filename = "fake data"
160    data.qmin, data.qmax = q.min(), q.max()
161    data.mask = np.zeros(len(q), dtype='bool')
162    return data
163
164
165def empty_data2D(qx, qy=None, resolution=0.05):
166    """
167    Create empty 2D data using the given mesh.
168
169    If *qy* is missing, create a square mesh with *qy=qx*.
170
171    *resolution* dq/q defaults to 5%.
172    """
173    if qy is None:
174        qy = qx
175    Qx, Qy = np.meshgrid(qx, qy)
176    Qx, Qy = Qx.flatten(), Qy.flatten()
177    Iq = 100 * np.ones_like(Qx)
178    dIq = np.sqrt(Iq)
179    mask = np.ones(len(Iq), dtype='bool')
180
181    data = Data2D()
182    data.filename = "fake data"
183    data.qx_data = Qx
184    data.qy_data = Qy
185    data.data = Iq
186    data.err_data = dIq
187    data.mask = mask
188    data.qmin = 1e-16
189    data.qmax = np.inf
190
191    # 5% dQ/Q resolution
192    if resolution != 0:
193        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
194        # Should have an additional constant which depends on distances and
195        # radii of the aperture, pixel dimensions and wavelength spread
196        # Instead, assume radial dQ/Q is constant, and perpendicular matches
197        # radial (which instead it should be inverse).
198        Q = np.sqrt(Qx**2 + Qy**2)
199        data.dqx_data = resolution * Q
200        data.dqy_data = resolution * Q
201    else:
202        data.dqx_data = data.dqy_data = None
203
204    detector = Detector()
205    detector.pixel_size.x = 5 # mm
206    detector.pixel_size.y = 5 # mm
207    detector.distance = 4 # m
208    data.detector.append(detector)
209    data.xbins = qx
210    data.ybins = qy
211    data.source.wavelength = 5 # angstroms
212    data.source.wavelength_unit = "A"
213    data.Q_unit = "1/A"
214    data.I_unit = "1/cm"
215    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
216    data.xaxis("Q_x", "A^{-1}")
217    data.yaxis("Q_y", "A^{-1}")
218    data.zaxis("Intensity", r"\text{cm}^{-1}")
219    return data
220
221
222def plot_data(data, view='log'):
223    """
224    Plot data loaded by the sasview loader.
225    """
226    # Note: kind of weird using the plot result functions to plot just the
227    # data, but they already handle the masking and graph markup already, so
228    # do not repeat.
229    if hasattr(data, 'lam'):
230        _plot_result_sesans(data, None, None, plot_data=True)
231    elif hasattr(data, 'qx_data'):
232        _plot_result2D(data, None, None, view, plot_data=True)
233    else:
234        _plot_result1D(data, None, None, view, plot_data=True)
235
236
237def plot_theory(data, theory, resid=None, view='log', plot_data=True):
238    if hasattr(data, 'lam'):
239        _plot_result_sesans(data, theory, resid, plot_data=True)
240    elif hasattr(data, 'qx_data'):
241        _plot_result2D(data, theory, resid, view, plot_data)
242    else:
243        _plot_result1D(data, theory, resid, view, plot_data)
244
245
246def protect(fn):
247    def wrapper(*args, **kw):
248        try:
249            return fn(*args, **kw)
250        except:
251            traceback.print_exc()
252            pass
253
254    return wrapper
255
256
257@protect
258def _plot_result1D(data, theory, resid, view, plot_data):
259    """
260    Plot the data and residuals for 1D data.
261    """
262    import matplotlib.pyplot as plt
263    from numpy.ma import masked_array, masked
264
265    plot_theory = theory is not None
266    plot_resid = resid is not None
267
268    if data.y is None:
269        plot_data = False
270    scale = data.x**4 if view == 'q4' else 1.0
271
272    if plot_data or plot_theory:
273        if plot_resid:
274            plt.subplot(121)
275
276        #print(vmin, vmax)
277        all_positive = True
278        some_present = False
279        if plot_data:
280            mdata = masked_array(data.y, data.mask.copy())
281            mdata[~np.isfinite(mdata)] = masked
282            if view is 'log':
283                mdata[mdata <= 0] = masked
284            plt.errorbar(data.x/10, scale*mdata, yerr=data.dy, fmt='.')
285            all_positive = all_positive and (mdata>0).all()
286            some_present = some_present or (mdata.count() > 0)
287
288
289        if plot_theory:
290            mtheory = masked_array(theory, data.mask.copy())
291            mtheory[~np.isfinite(mtheory)] = masked
292            if view is 'log':
293                mtheory[mtheory<=0] = masked
294            plt.plot(data.x/10, scale*mtheory, '-', hold=True)
295            all_positive = all_positive and (mtheory>0).all()
296            some_present = some_present or (mtheory.count() > 0)
297
298        plt.xscale('linear' if not some_present else view)
299        plt.yscale('linear'
300                   if view == 'q4' or not some_present or not all_positive
301                   else view)
302        plt.xlabel("$q$/nm$^{-1}$")
303        plt.ylabel('$I(q)$')
304
305    if plot_resid:
306        if plot_data or plot_theory:
307            plt.subplot(122)
308
309        mresid = masked_array(resid, data.mask.copy())
310        mresid[~np.isfinite(mresid)] = masked
311        some_present = (mresid.count() > 0)
312        plt.plot(data.x/10, mresid, '-')
313        plt.xlabel("$q$/nm$^{-1}$")
314        plt.ylabel('residuals')
315        plt.xscale('linear' if not some_present else view)
316
317
318@protect
319def _plot_result_sesans(data, theory, resid, plot_data):
320    import matplotlib.pyplot as plt
321    if data.y is None:
322        plot_data = False
323    plot_theory = theory is not None
324    plot_resid = resid is not None
325
326    if plot_data or plot_theory:
327        if plot_resid:
328            plt.subplot(121)
329        if plot_data:
330            plt.errorbar(data.x, data.y, yerr=data.dy)
331        if theory is not None:
332            plt.plot(data.x, theory, '-', hold=True)
333        plt.xlabel('spin echo length (nm)')
334        plt.ylabel('polarization (P/P0)')
335
336    if resid is not None:
337        if plot_data or plot_theory:
338            plt.subplot(122)
339
340        plt.plot(data.x, resid, 'x')
341        plt.xlabel('spin echo length (nm)')
342        plt.ylabel('residuals (P/P0)')
343
344
345@protect
346def _plot_result2D(data, theory, resid, view, plot_data):
347    """
348    Plot the data and residuals for 2D data.
349    """
350    import matplotlib.pyplot as plt
351    if data.data is None:
352        plot_data = False
353    plot_theory = theory is not None
354    plot_resid = resid is not None
355
356    # Put theory and data on a common colormap scale
357    vmin, vmax = np.inf, -np.inf
358    if plot_data:
359        target = data.data[~data.mask]
360        datamin = target[target>0].min() if view == 'log' else target.min()
361        datamax = target.max()
362        vmin = min(vmin, datamin)
363        vmax = max(vmax, datamax)
364    if plot_theory:
365        theorymin = theory[theory>0].min() if view == 'log' else theory.min()
366        theorymax = theory.max()
367        vmin = min(vmin, theorymin)
368        vmax = max(vmax, theorymax)
369
370    if plot_data:
371        if plot_theory and plot_resid:
372            plt.subplot(131)
373        elif plot_theory or plot_resid:
374            plt.subplot(121)
375        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
376        plt.title('data')
377        h = plt.colorbar()
378        h.set_label('$I(q)$')
379
380    if plot_theory:
381        if plot_data and plot_resid:
382            plt.subplot(132)
383        elif plot_data:
384            plt.subplot(122)
385        elif plot_resid:
386            plt.subplot(121)
387        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
388        plt.title('theory')
389        h = plt.colorbar()
390        h.set_label('$I(q)$')
391
392    #if plot_data or plot_theory:
393    #    plt.colorbar()
394
395    if plot_resid:
396        if plot_data and plot_theory:
397            plt.subplot(133)
398        elif plot_data or plot_theory:
399            plt.subplot(122)
400        _plot_2d_signal(data, resid, view='linear')
401        plt.title('residuals')
402        h = plt.colorbar()
403        h.set_label('$\Delta I(q)$')
404
405
406@protect
407def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
408    """
409    Plot the target value for the data.  This could be the data itself,
410    the theory calculation, or the residuals.
411
412    *scale* can be 'log' for log scale data, or 'linear'.
413    """
414    import matplotlib.pyplot as plt
415    from numpy.ma import masked_array
416
417    image = np.zeros_like(data.qx_data)
418    image[~data.mask] = signal
419    valid = np.isfinite(image)
420    if view == 'log':
421        valid[valid] = (image[valid] > 0)
422        image[valid] = np.log10(image[valid])
423    elif view == 'q4':
424        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
425    image[~valid | data.mask] = 0
426    #plottable = Iq
427    plottable = masked_array(image, ~valid | data.mask)
428    xmin, xmax = min(data.qx_data)/10, max(data.qx_data)/10
429    ymin, ymax = min(data.qy_data)/10, max(data.qy_data)/10
430    # TODO: fix vmin, vmax so it is shared for theory/resid
431    vmin = vmax = None
432    try:
433        if vmin is None: vmin = image[valid & ~data.mask].min()
434        if vmax is None: vmax = image[valid & ~data.mask].max()
435    except:
436        vmin, vmax = 0, 1
437    plt.imshow(plottable.reshape(len(data.xbins), len(data.ybins)),
438               interpolation='nearest', aspect=1, origin='upper',
439               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
440    plt.xlabel("$q_x$/nm$^{-1}$")
441    plt.ylabel("$q_y$/nm$^{-1}$")
442
443
444def demo():
445    data = load_data('DEC07086.DAT')
446    set_beam_stop(data, 0.004)
447    plot_data(data)
448    import matplotlib.pyplot as plt; plt.show()
449
450
451if __name__ == "__main__":
452    demo()
Note: See TracBrowser for help on using the repository browser.