source: sasmodels/sasmodels/data.py @ 013adb7

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 013adb7 was 013adb7, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix vlimits on 2D plots during parameter exploration

  • Property mode set to 100644
File size: 13.6 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np
38
39def load_data(filename):
40    """
41    Load data using a sasview loader.
42    """
43    from sas.dataloader.loader import Loader
44    loader = Loader()
45    data = loader.load(filename)
46    if data is None:
47        raise IOError("Data %r could not be loaded" % filename)
48    return data
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54    """
55    from sas.dataloader.manipulations import Ringcut
56    if hasattr(data, 'qx_data'):
57        data.mask = Ringcut(0, radius)(data)
58        if outer is not None:
59            data.mask += Ringcut(outer, np.inf)(data)
60    else:
61        data.mask = (data.x < radius)
62        if outer is not None:
63            data.mask |= (data.x >= outer)
64
65
66def set_half(data, half):
67    """
68    Select half of the data, either "right" or "left".
69    """
70    from sas.dataloader.manipulations import Boxcut
71    if half == 'right':
72        data.mask += \
73            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
74    if half == 'left':
75        data.mask += \
76            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
77
78
79def set_top(data, cutoff):
80    """
81    Chop the top off the data, above *cutoff*.
82    """
83    from sas.dataloader.manipulations import Boxcut
84    data.mask += \
85        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
86
87
88class Data1D(object):
89    def __init__(self, x=None, y=None, dx=None, dy=None):
90        self.x, self.y, self.dx, self.dy = x, y, dx, dy
91        self.dxl = None
92
93    def xaxis(self, label, unit):
94        """
95        set the x axis label and unit
96        """
97        self._xaxis = label
98        self._xunit = unit
99
100    def yaxis(self, label, unit):
101        """
102        set the y axis label and unit
103        """
104        self._yaxis = label
105        self._yunit = unit
106
107
108
109class Data2D(object):
110    def __init__(self):
111        self.detector = []
112        self.source = Source()
113
114    def xaxis(self, label, unit):
115        """
116        set the x axis label and unit
117        """
118        self._xaxis = label
119        self._xunit = unit
120
121    def yaxis(self, label, unit):
122        """
123        set the y axis label and unit
124        """
125        self._yaxis = label
126        self._yunit = unit
127
128    def zaxis(self, label, unit):
129        """
130        set the y axis label and unit
131        """
132        self._zaxis = label
133        self._zunit = unit
134
135
136class Vector(object):
137    def __init__(self, x=None, y=None, z=None):
138        self.x, self.y, self.z = x, y, z
139
140class Detector(object):
141    def __init__(self):
142        self.pixel_size = Vector()
143
144class Source(object):
145    pass
146
147
148def empty_data1D(q, resolution=0.05):
149    """
150    Create empty 1D data using the given *q* as the x value.
151
152    *resolution* dq/q defaults to 5%.
153    """
154
155    #Iq = 100 * np.ones_like(q)
156    #dIq = np.sqrt(Iq)
157    Iq, dIq = None, None
158    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
159    data.filename = "fake data"
160    data.qmin, data.qmax = q.min(), q.max()
161    data.mask = np.zeros(len(q), dtype='bool')
162    return data
163
164
165def empty_data2D(qx, qy=None, resolution=0.05):
166    """
167    Create empty 2D data using the given mesh.
168
169    If *qy* is missing, create a square mesh with *qy=qx*.
170
171    *resolution* dq/q defaults to 5%.
172    """
173    if qy is None:
174        qy = qx
175    Qx, Qy = np.meshgrid(qx, qy)
176    Qx, Qy = Qx.flatten(), Qy.flatten()
177    Iq = 100 * np.ones_like(Qx)
178    dIq = np.sqrt(Iq)
179    mask = np.ones(len(Iq), dtype='bool')
180
181    data = Data2D()
182    data.filename = "fake data"
183    data.qx_data = Qx
184    data.qy_data = Qy
185    data.data = Iq
186    data.err_data = dIq
187    data.mask = mask
188    data.qmin = 1e-16
189    data.qmax = np.inf
190
191    # 5% dQ/Q resolution
192    if resolution != 0:
193        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
194        # Should have an additional constant which depends on distances and
195        # radii of the aperture, pixel dimensions and wavelength spread
196        # Instead, assume radial dQ/Q is constant, and perpendicular matches
197        # radial (which instead it should be inverse).
198        Q = np.sqrt(Qx**2 + Qy**2)
199        data.dqx_data = resolution * Q
200        data.dqy_data = resolution * Q
201    else:
202        data.dqx_data = data.dqy_data = None
203
204    detector = Detector()
205    detector.pixel_size.x = 5 # mm
206    detector.pixel_size.y = 5 # mm
207    detector.distance = 4 # m
208    data.detector.append(detector)
209    data.x_bins = qx
210    data.y_bins = qy
211    data.source.wavelength = 5 # angstroms
212    data.source.wavelength_unit = "A"
213    data.Q_unit = "1/A"
214    data.I_unit = "1/cm"
215    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
216    data.xaxis("Q_x", "A^{-1}")
217    data.yaxis("Q_y", "A^{-1}")
218    data.zaxis("Intensity", r"\text{cm}^{-1}")
219    return data
220
221
222def plot_data(data, view='log', limits=None):
223    """
224    Plot data loaded by the sasview loader.
225    """
226    # Note: kind of weird using the plot result functions to plot just the
227    # data, but they already handle the masking and graph markup already, so
228    # do not repeat.
229    if hasattr(data, 'lam'):
230        _plot_result_sesans(data, None, None, plot_data=True, limits=limits)
231    elif hasattr(data, 'qx_data'):
232        _plot_result2D(data, None, None, view, plot_data=True, limits=limits)
233    else:
234        _plot_result1D(data, None, None, view, plot_data=True, limits=limits)
235
236
237def plot_theory(data, theory, resid=None, view='log',
238                plot_data=True, limits=None):
239    if hasattr(data, 'lam'):
240        _plot_result_sesans(data, theory, resid, plot_data=True, limits=limits)
241    elif hasattr(data, 'qx_data'):
242        _plot_result2D(data, theory, resid, view, plot_data, limits=limits)
243    else:
244        _plot_result1D(data, theory, resid, view, plot_data, limits=limits)
245
246
247def protect(fn):
248    def wrapper(*args, **kw):
249        try:
250            return fn(*args, **kw)
251        except:
252            traceback.print_exc()
253            pass
254
255    return wrapper
256
257
258@protect
259def _plot_result1D(data, theory, resid, view, plot_data, limits=None):
260    """
261    Plot the data and residuals for 1D data.
262    """
263    import matplotlib.pyplot as plt
264    from numpy.ma import masked_array, masked
265
266    plot_theory = theory is not None
267    plot_resid = resid is not None
268
269    if data.y is None:
270        plot_data = False
271    scale = data.x**4 if view == 'q4' else 1.0
272
273    if plot_data or plot_theory:
274        if plot_resid:
275            plt.subplot(121)
276
277        #print(vmin, vmax)
278        all_positive = True
279        some_present = False
280        if plot_data:
281            mdata = masked_array(data.y, data.mask.copy())
282            mdata[~np.isfinite(mdata)] = masked
283            if view is 'log':
284                mdata[mdata <= 0] = masked
285            plt.errorbar(data.x/10, scale*mdata, yerr=data.dy, fmt='.')
286            all_positive = all_positive and (mdata>0).all()
287            some_present = some_present or (mdata.count() > 0)
288
289
290        if plot_theory:
291            mtheory = masked_array(theory, data.mask.copy())
292            mtheory[~np.isfinite(mtheory)] = masked
293            if view is 'log':
294                mtheory[mtheory<=0] = masked
295            plt.plot(data.x/10, scale*mtheory, '-', hold=True)
296            all_positive = all_positive and (mtheory>0).all()
297            some_present = some_present or (mtheory.count() > 0)
298
299        if limits is not None:
300            plt.ylim(*limits)
301        plt.xscale('linear' if not some_present else view)
302        plt.yscale('linear'
303                   if view == 'q4' or not some_present or not all_positive
304                   else view)
305        plt.xlabel("$q$/nm$^{-1}$")
306        plt.ylabel('$I(q)$')
307
308    if plot_resid:
309        if plot_data or plot_theory:
310            plt.subplot(122)
311
312        mresid = masked_array(resid, data.mask.copy())
313        mresid[~np.isfinite(mresid)] = masked
314        some_present = (mresid.count() > 0)
315        plt.plot(data.x/10, mresid, '-')
316        plt.xlabel("$q$/nm$^{-1}$")
317        plt.ylabel('residuals')
318        plt.xscale('linear' if not some_present else view)
319
320
321@protect
322def _plot_result_sesans(data, theory, resid, plot_data, limits=None):
323    import matplotlib.pyplot as plt
324    if data.y is None:
325        plot_data = False
326    plot_theory = theory is not None
327    plot_resid = resid is not None
328
329    if plot_data or plot_theory:
330        if plot_resid:
331            plt.subplot(121)
332        if plot_data:
333            plt.errorbar(data.x, data.y, yerr=data.dy)
334        if theory is not None:
335            plt.plot(data.x, theory, '-', hold=True)
336        if limits is not None:
337            plt.ylim(*limits)
338        plt.xlabel('spin echo length (nm)')
339        plt.ylabel('polarization (P/P0)')
340
341    if resid is not None:
342        if plot_data or plot_theory:
343            plt.subplot(122)
344
345        plt.plot(data.x, resid, 'x')
346        plt.xlabel('spin echo length (nm)')
347        plt.ylabel('residuals (P/P0)')
348
349
350@protect
351def _plot_result2D(data, theory, resid, view, plot_data, limits=None):
352    """
353    Plot the data and residuals for 2D data.
354    """
355    import matplotlib.pyplot as plt
356    if data.data is None:
357        plot_data = False
358    plot_theory = theory is not None
359    plot_resid = resid is not None
360
361    # Put theory and data on a common colormap scale
362    if limits is None:
363        vmin, vmax = np.inf, -np.inf
364        if plot_data:
365            target = data.data[~data.mask]
366            datamin = target[target>0].min() if view == 'log' else target.min()
367            datamax = target.max()
368            vmin = min(vmin, datamin)
369            vmax = max(vmax, datamax)
370        if plot_theory:
371            theorymin = theory[theory>0].min() if view=='log' else theory.min()
372            theorymax = theory.max()
373            vmin = min(vmin, theorymin)
374            vmax = max(vmax, theorymax)
375    else:
376        vmin, vmax = limits
377
378    if plot_data:
379        if plot_theory and plot_resid:
380            plt.subplot(131)
381        elif plot_theory or plot_resid:
382            plt.subplot(121)
383        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
384        plt.title('data')
385        h = plt.colorbar()
386        h.set_label('$I(q)$')
387
388    if plot_theory:
389        if plot_data and plot_resid:
390            plt.subplot(132)
391        elif plot_data:
392            plt.subplot(122)
393        elif plot_resid:
394            plt.subplot(121)
395        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
396        plt.title('theory')
397        h = plt.colorbar()
398        h.set_label(r'$\log_{10}I(q)$' if view=='log'
399                    else r'$q^4 I(q)$' if view == 'q4'
400                    else '$I(q)$')
401
402    #if plot_data or plot_theory:
403    #    plt.colorbar()
404
405    if plot_resid:
406        if plot_data and plot_theory:
407            plt.subplot(133)
408        elif plot_data or plot_theory:
409            plt.subplot(122)
410        _plot_2d_signal(data, resid, view='linear')
411        plt.title('residuals')
412        h = plt.colorbar()
413        h.set_label('$\Delta I(q)$')
414
415
416@protect
417def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
418    """
419    Plot the target value for the data.  This could be the data itself,
420    the theory calculation, or the residuals.
421
422    *scale* can be 'log' for log scale data, or 'linear'.
423    """
424    import matplotlib.pyplot as plt
425    from numpy.ma import masked_array
426
427    image = np.zeros_like(data.qx_data)
428    image[~data.mask] = signal
429    valid = np.isfinite(image)
430    if view == 'log':
431        valid[valid] = (image[valid] > 0)
432        if vmin is None: vmin = image[valid & ~data.mask].min()
433        if vmax is None: vmax = image[valid & ~data.mask].max()
434        image[valid] = np.log10(image[valid])
435    elif view == 'q4':
436        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
437        if vmin is None: vmin = image[valid & ~data.mask].min()
438        if vmax is None: vmax = image[valid & ~data.mask].max()
439    else:
440        if vmin is None: vmin = image[valid & ~data.mask].min()
441        if vmax is None: vmax = image[valid & ~data.mask].max()
442
443    image[~valid | data.mask] = 0
444    #plottable = Iq
445    plottable = masked_array(image, ~valid | data.mask)
446    xmin, xmax = min(data.qx_data)/10, max(data.qx_data)/10
447    ymin, ymax = min(data.qy_data)/10, max(data.qy_data)/10
448    if view == 'log':
449        vmin, vmax = np.log10(vmin), np.log10(vmax)
450    plt.imshow(plottable.reshape(len(data.x_bins), len(data.y_bins)),
451               interpolation='nearest', aspect=1, origin='upper',
452               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
453    plt.xlabel("$q_x$/nm$^{-1}$")
454    plt.ylabel("$q_y$/nm$^{-1}$")
455    return vmin, vmax
456
457def demo():
458    data = load_data('DEC07086.DAT')
459    set_beam_stop(data, 0.004)
460    plot_data(data)
461    import matplotlib.pyplot as plt; plt.show()
462
463
464if __name__ == "__main__":
465    demo()
Note: See TracBrowser for help on using the repository browser.