source: sasmodels/sasmodels/data.py @ 69ec80f

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 69ec80f was 69ec80f, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor code to reduce lint count

  • Property mode set to 100644
File size: 14.2 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np
38
39def load_data(filename):
40    """
41    Load data using a sasview loader.
42    """
43    from sas.dataloader.loader import Loader
44    loader = Loader()
45    data = loader.load(filename)
46    if data is None:
47        raise IOError("Data %r could not be loaded" % filename)
48    return data
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54    """
55    from sas.dataloader.manipulations import Ringcut
56    if hasattr(data, 'qx_data'):
57        data.mask = Ringcut(0, radius)(data)
58        if outer is not None:
59            data.mask += Ringcut(outer, np.inf)(data)
60    else:
61        data.mask = (data.x < radius)
62        if outer is not None:
63            data.mask |= (data.x >= outer)
64
65
66def set_half(data, half):
67    """
68    Select half of the data, either "right" or "left".
69    """
70    from sas.dataloader.manipulations import Boxcut
71    if half == 'right':
72        data.mask += \
73            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
74    if half == 'left':
75        data.mask += \
76            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
77
78
79def set_top(data, cutoff):
80    """
81    Chop the top off the data, above *cutoff*.
82    """
83    from sas.dataloader.manipulations import Boxcut
84    data.mask += \
85        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
86
87
88class Data1D(object):
89    def __init__(self, x=None, y=None, dx=None, dy=None):
90        self.x, self.y, self.dx, self.dy = x, y, dx, dy
91        self.dxl = None
92        self.filename = None
93        self.qmin = x.min() if x is not None else np.NaN
94        self.qmax = x.max() if x is not None else np.NaN
95        self.mask = np.isnan(y) if y is not None else None
96        self._xaxis, self._xunit = "x", ""
97        self._yaxis, self._yunit = "y", ""
98
99    def xaxis(self, label, unit):
100        """
101        set the x axis label and unit
102        """
103        self._xaxis = label
104        self._xunit = unit
105
106    def yaxis(self, label, unit):
107        """
108        set the y axis label and unit
109        """
110        self._yaxis = label
111        self._yunit = unit
112
113
114
115class Data2D(object):
116    def __init__(self, x=None, y=None, z=None, dx=None, dy=None, dz=None):
117        self.qx_data, self.dqx_data = x, dx
118        self.qy_data, self.dqy_data = y, dy
119        self.data, self.err_data = z, dz
120        self.mask = ~np.isnan(z) if z is not None else None
121        self.q_data = np.sqrt(x**2 + y**2)
122        self.qmin = 1e-16
123        self.qmax = np.inf
124        self.detector = []
125        self.source = Source()
126        self.Q_unit = "1/A"
127        self.I_unit = "1/cm"
128        self.xaxis("Q_x", "A^{-1}")
129        self.yaxis("Q_y", "A^{-1}")
130        self.zaxis("Intensity", r"\text{cm}^{-1}")
131        self._xaxis, self._xunit = "x", ""
132        self._yaxis, self._yunit = "y", ""
133        self._zaxis, self._zunit = "z", ""
134        self.x_bins, self.y_bins = None, None
135
136    def xaxis(self, label, unit):
137        """
138        set the x axis label and unit
139        """
140        self._xaxis = label
141        self._xunit = unit
142
143    def yaxis(self, label, unit):
144        """
145        set the y axis label and unit
146        """
147        self._yaxis = label
148        self._yunit = unit
149
150    def zaxis(self, label, unit):
151        """
152        set the y axis label and unit
153        """
154        self._zaxis = label
155        self._zunit = unit
156
157
158class Vector(object):
159    def __init__(self, x=None, y=None, z=None):
160        self.x, self.y, self.z = x, y, z
161
162class Detector(object):
163    """
164    Detector attributes.
165    """
166    def __init__(self, pixel_size=(None, None), distance=None):
167        self.pixel_size = Vector(*pixel_size)
168        self.distance = distance
169
170class Source(object):
171    """
172    Beam attributes.
173    """
174    def __init__(self):
175        self.wavelength = np.NaN
176        self.wavelength_unit = "A"
177
178
179def empty_data1D(q, resolution=0.05):
180    """
181    Create empty 1D data using the given *q* as the x value.
182
183    *resolution* dq/q defaults to 5%.
184    """
185
186    #Iq = 100 * np.ones_like(q)
187    #dIq = np.sqrt(Iq)
188    Iq, dIq = None, None
189    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
190    data.filename = "fake data"
191    return data
192
193
194def empty_data2D(qx, qy=None, resolution=0.05):
195    """
196    Create empty 2D data using the given mesh.
197
198    If *qy* is missing, create a square mesh with *qy=qx*.
199
200    *resolution* dq/q defaults to 5%.
201    """
202    if qy is None:
203        qy = qx
204    # 5% dQ/Q resolution
205    Qx, Qy = np.meshgrid(qx, qy)
206    Qx, Qy = Qx.flatten(), Qy.flatten()
207    Iq = 100 * np.ones_like(Qx)
208    dIq = np.sqrt(Iq)
209    if resolution != 0:
210        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
211        # Should have an additional constant which depends on distances and
212        # radii of the aperture, pixel dimensions and wavelength spread
213        # Instead, assume radial dQ/Q is constant, and perpendicular matches
214        # radial (which instead it should be inverse).
215        Q = np.sqrt(Qx**2 + Qy**2)
216        dqx = resolution * Q
217        dqy = resolution * Q
218    else:
219        dqx = dqy = None
220
221    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
222    data.x_bins = qx
223    data.y_bins = qy
224    data.filename = "fake data"
225
226    # pixel_size in mm, distance in m
227    detector = Detector(pixel_size=(5, 5), distance=4)
228    data.detector.append(detector)
229    data.source.wavelength = 5 # angstroms
230    data.source.wavelength_unit = "A"
231    return data
232
233
234def plot_data(data, view='log', limits=None):
235    """
236    Plot data loaded by the sasview loader.
237    """
238    # Note: kind of weird using the plot result functions to plot just the
239    # data, but they already handle the masking and graph markup already, so
240    # do not repeat.
241    if hasattr(data, 'lam'):
242        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
243    elif hasattr(data, 'qx_data'):
244        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
245    else:
246        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
247
248
249def plot_theory(data, theory, resid=None, view='log',
250                use_data=True, limits=None):
251    if hasattr(data, 'lam'):
252        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
253    elif hasattr(data, 'qx_data'):
254        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
255    else:
256        _plot_result1D(data, theory, resid, view, use_data, limits=limits)
257
258
259def protect(fn):
260    def wrapper(*args, **kw):
261        try:
262            return fn(*args, **kw)
263        except:
264            traceback.print_exc()
265
266    return wrapper
267
268
269@protect
270def _plot_result1D(data, theory, resid, view, use_data, limits=None):
271    """
272    Plot the data and residuals for 1D data.
273    """
274    import matplotlib.pyplot as plt
275    from numpy.ma import masked_array, masked
276
277    use_data = use_data and data.y is not None
278    use_theory = theory is not None
279    use_resid = resid is not None
280    num_plots = (use_data or use_theory) + use_resid
281
282    scale = data.x**4 if view == 'q4' else 1.0
283
284    if use_data or use_theory:
285        #print(vmin, vmax)
286        all_positive = True
287        some_present = False
288        if use_data:
289            mdata = masked_array(data.y, data.mask.copy())
290            mdata[~np.isfinite(mdata)] = masked
291            if view is 'log':
292                mdata[mdata <= 0] = masked
293            plt.errorbar(data.x/10, scale*mdata, yerr=data.dy, fmt='.')
294            all_positive = all_positive and (mdata > 0).all()
295            some_present = some_present or (mdata.count() > 0)
296
297
298        if use_theory:
299            mtheory = masked_array(theory, data.mask.copy())
300            mtheory[~np.isfinite(mtheory)] = masked
301            if view is 'log':
302                mtheory[mtheory <= 0] = masked
303            plt.plot(data.x/10, scale*mtheory, '-', hold=True)
304            all_positive = all_positive and (mtheory > 0).all()
305            some_present = some_present or (mtheory.count() > 0)
306
307        if limits is not None:
308            plt.ylim(*limits)
309
310        if num_plots > 1:
311            plt.subplot(1, num_plots, 1)
312        plt.xscale('linear' if not some_present else view)
313        plt.yscale('linear'
314                   if view == 'q4' or not some_present or not all_positive
315                   else view)
316        plt.xlabel("$q$/nm$^{-1}$")
317        plt.ylabel('$I(q)$')
318
319    if use_resid:
320        mresid = masked_array(resid, data.mask.copy())
321        mresid[~np.isfinite(mresid)] = masked
322        some_present = (mresid.count() > 0)
323
324        if num_plots > 1:
325            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
326        plt.plot(data.x/10, mresid, '-')
327        plt.xlabel("$q$/nm$^{-1}$")
328        plt.ylabel('residuals')
329        plt.xscale('linear' if not some_present else view)
330
331
332@protect
333def _plot_result_sesans(data, theory, resid, use_data, limits=None):
334    import matplotlib.pyplot as plt
335    use_data = use_data and data.y is not None
336    use_theory = theory is not None
337    use_resid = resid is not None
338    num_plots = (use_data or use_theory) + use_resid
339
340    if use_data or use_theory:
341        if num_plots > 1:
342            plt.subplot(1, num_plots, 1)
343        if use_data:
344            plt.errorbar(data.x, data.y, yerr=data.dy)
345        if theory is not None:
346            plt.plot(data.x, theory, '-', hold=True)
347        if limits is not None:
348            plt.ylim(*limits)
349        plt.xlabel('spin echo length (nm)')
350        plt.ylabel('polarization (P/P0)')
351
352    if resid is not None:
353        if num_plots > 1:
354            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
355        plt.plot(data.x, resid, 'x')
356        plt.xlabel('spin echo length (nm)')
357        plt.ylabel('residuals (P/P0)')
358
359
360@protect
361def _plot_result2D(data, theory, resid, view, use_data, limits=None):
362    """
363    Plot the data and residuals for 2D data.
364    """
365    import matplotlib.pyplot as plt
366    use_data = use_data and data.data is not None
367    use_theory = theory is not None
368    use_resid = resid is not None
369    num_plots = use_data + use_theory + use_resid
370
371    # Put theory and data on a common colormap scale
372    vmin, vmax = np.inf, -np.inf
373    if use_data:
374        target = data.data[~data.mask]
375        datamin = target[target > 0].min() if view == 'log' else target.min()
376        datamax = target.max()
377        vmin = min(vmin, datamin)
378        vmax = max(vmax, datamax)
379    if use_theory:
380        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
381        theorymax = theory.max()
382        vmin = min(vmin, theorymin)
383        vmax = max(vmax, theorymax)
384
385    # Override data limits from the caller
386    if limits is not None:
387        vmin, vmax = limits
388
389    # Plot data
390    if use_data:
391        if num_plots > 1:
392            plt.subplot(1, num_plots, 1)
393        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
394        plt.title('data')
395        h = plt.colorbar()
396        h.set_label('$I(q)$')
397
398    # plot theory
399    if use_theory:
400        if num_plots > 1:
401            plt.subplot(1, num_plots, use_data+1)
402        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
403        plt.title('theory')
404        h = plt.colorbar()
405        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
406                    else r'$q^4 I(q)$' if view == 'q4'
407                    else '$I(q)$')
408
409    # plot resid
410    if use_resid:
411        if num_plots > 1:
412            plt.subplot(1, num_plots, use_data+use_theory+1)
413        _plot_2d_signal(data, resid, view='linear')
414        plt.title('residuals')
415        h = plt.colorbar()
416        h.set_label(r'$\Delta I(q)$')
417
418
419@protect
420def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
421    """
422    Plot the target value for the data.  This could be the data itself,
423    the theory calculation, or the residuals.
424
425    *scale* can be 'log' for log scale data, or 'linear'.
426    """
427    import matplotlib.pyplot as plt
428    from numpy.ma import masked_array
429
430    image = np.zeros_like(data.qx_data)
431    image[~data.mask] = signal
432    valid = np.isfinite(image)
433    if view == 'log':
434        valid[valid] = (image[valid] > 0)
435        if vmin is None: vmin = image[valid & ~data.mask].min()
436        if vmax is None: vmax = image[valid & ~data.mask].max()
437        image[valid] = np.log10(image[valid])
438    elif view == 'q4':
439        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
440        if vmin is None: vmin = image[valid & ~data.mask].min()
441        if vmax is None: vmax = image[valid & ~data.mask].max()
442    else:
443        if vmin is None: vmin = image[valid & ~data.mask].min()
444        if vmax is None: vmax = image[valid & ~data.mask].max()
445
446    image[~valid | data.mask] = 0
447    #plottable = Iq
448    plottable = masked_array(image, ~valid | data.mask)
449    xmin, xmax = min(data.qx_data)/10, max(data.qx_data)/10
450    ymin, ymax = min(data.qy_data)/10, max(data.qy_data)/10
451    if view == 'log':
452        vmin, vmax = np.log10(vmin), np.log10(vmax)
453    plt.imshow(plottable.reshape(len(data.x_bins), len(data.y_bins)),
454               interpolation='nearest', aspect=1, origin='upper',
455               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
456    plt.xlabel("$q_x$/nm$^{-1}$")
457    plt.ylabel("$q_y$/nm$^{-1}$")
458    return vmin, vmax
459
460def demo():
461    data = load_data('DEC07086.DAT')
462    set_beam_stop(data, 0.004)
463    plot_data(data)
464    import matplotlib.pyplot as plt; plt.show()
465
466
467if __name__ == "__main__":
468    demo()
Note: See TracBrowser for help on using the repository browser.