source: sasmodels/sasmodels/data.py @ eafc9fa

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since eafc9fa was eafc9fa, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor kernel wrappers to simplify q input handling

  • Property mode set to 100644
File size: 17.0 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np
38
def load_data(filename):
    """
    Read *filename* with the SasView data loader and return the result.

    Raises IOError if the loader gives back None.
    """
    from sas.dataloader.loader import Loader
    result = Loader().load(filename)
    if result is not None:
        return result
    raise IOError("Data %r could not be loaded" % filename)
49
50
def set_beam_stop(data, radius, outer=None):
    """
    Replace the mask of *data* with a beam stop of the given *radius*.

    When *outer* is given, points beyond it are masked too, leaving an
    annulus.  2D data (anything with ``qx_data``) uses SasView ``Ringcut``
    regions; 1D data compares against ``data.x`` directly.
    """
    from sas.dataloader.manipulations import Ringcut
    if not hasattr(data, 'qx_data'):
        # 1D: mask q below the beam stop and, optionally, at/after outer.
        data.mask = (data.x < radius)
        if outer is not None:
            data.mask |= (data.x >= outer)
        return
    data.mask = Ringcut(0, radius)(data)
    if outer is not None:
        data.mask += Ringcut(outer, np.inf)(data)
64
65
def set_half(data, half):
    """
    Keep only one half of 2D *data*: *half* is "right" or "left".

    The opposite half is added to ``data.mask`` with a SasView ``Boxcut``
    spanning all q_y.  Any other value of *half* leaves the mask untouched.
    """
    from sas.dataloader.manipulations import Boxcut
    # x-range of the region to mask, i.e. the half we are discarding.
    if half == 'right':
        cut_lo, cut_hi = -np.inf, 0.0
    elif half == 'left':
        cut_lo, cut_hi = 0.0, np.inf
    else:
        return
    region = Boxcut(x_min=cut_lo, x_max=cut_hi, y_min=-np.inf, y_max=np.inf)
    data.mask += region(data)
77
78
def set_top(data, cutoff):
    """
    Add a horizontal band of 2D *data* to its mask, split at *cutoff*.

    The masked region is a SasView ``Boxcut`` covering every q_x and q_y
    from -inf up to *cutoff*.

    NOTE(review): historically documented as "chop the top off the data,
    above *cutoff*", but the Boxcut region spans q_y <= cutoff (assuming
    Boxcut returns True inside the box).  Confirm the intended direction.
    """
    from sas.dataloader.manipulations import Boxcut
    box = Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)
    data.mask += box(data)
86
87
class Data1D(object):
    """
    1D data object.

    Note that this definition matches the attributes from sasview, with
    some generic 1D data vectors and some SAS specific definitions.  Some
    refactoring to allow consistent naming conventions between 1D, 2D and
    SESANS data would be helpful.

    **Attributes**

    *x*, *dx*: $q$ vector and gaussian resolution

    *y*, *dy*: $I(q)$ vector and measurement uncertainty

    *mask*: boolean array, True for values to exclude from
    plotting/analysis (NaN intensities are excluded by default)

    *dxl*: slit widths for slit smeared data, with *dx* ignored

    *qmin*, *qmax*: range of $q$ values in *x* (NaN when *x* is None)

    *filename*: label for the data line

    *_xaxis*, *_xunit*: label and units for the *x* axis

    *_yaxis*, *_yunit*: label and units for the *y* axis
    """
    def __init__(self, x=None, y=None, dx=None, dy=None):
        self.x, self.y, self.dx, self.dy = x, y, dx, dy
        self.dxl = None
        self.filename = None
        # np.nan rather than np.NaN: the NaN alias was removed in NumPy 2.0.
        self.qmin = x.min() if x is not None else np.nan
        self.qmax = x.max() if x is not None else np.nan
        # True marks excluded points.  Use a real boolean dtype so that
        # ~mask is a logical negation usable for indexing (an int8 mask
        # would give -1/0 and index by position instead).
        if y is not None:
            self.mask = np.isnan(y)
        elif x is not None:
            self.mask = np.zeros_like(x, dtype=bool)
        else:
            self.mask = None
        self._xaxis, self._xunit = "x", ""
        self._yaxis, self._yunit = "y", ""

    def xaxis(self, label, unit):
        """
        set the x axis label and unit
        """
        self._xaxis = label
        self._xunit = unit

    def yaxis(self, label, unit):
        """
        set the y axis label and unit
        """
        self._yaxis = label
        self._yunit = unit
141
142
143
class Data2D(object):
    """
    2D data object.

    Note that this definition matches the attributes from sasview. Some
    refactoring to allow consistent naming conventions between 1D, 2D and
    SESANS data would be helpful.

    **Attributes**

    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution

    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution

    *q_data*: $|q| = \\sqrt{q_x^2 + q_y^2}$ (None if $q_x$ or $q_y$ missing)

    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty

    *mask*: boolean array, True for values to exclude from
    plotting/analysis (NaN intensities are excluded by default)

    *qmin*, *qmax*: $q$ range limits (default 1e-16 and inf)

    *filename*: label for the data line

    *_xaxis*, *_xunit*: label and units for the *x* axis

    *_yaxis*, *_yunit*: label and units for the *y* axis

    *_zaxis*, *_zunit*: label and units for the *z* axis

    *Q_unit*, *I_unit*: units for Q and intensity

    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
    """
    def __init__(self, x=None, y=None, z=None, dx=None, dy=None, dz=None):
        self.qx_data, self.dqx_data = x, dx
        self.qy_data, self.dqy_data = y, dy
        self.data, self.err_data = z, dz
        # True marks *excluded* points, as the attribute docs state.  The
        # previous ~isnan(z) / ones_like(x) marked every valid point as
        # excluded (and ones_like gave a float array, not a boolean one).
        if z is not None:
            self.mask = np.isnan(z)
        elif x is not None:
            self.mask = np.zeros_like(x, dtype=bool)
        else:
            self.mask = None
        # Guard so that Data2D() with default arguments can be constructed.
        self.q_data = (np.sqrt(x**2 + y**2)
                       if x is not None and y is not None else None)
        self.qmin = 1e-16
        self.qmax = np.inf
        self.detector = []
        self.source = Source()
        self.Q_unit = "1/A"
        self.I_unit = "1/cm"
        # Generic axis labels; use xaxis()/yaxis()/zaxis() to set others.
        self._xaxis, self._xunit = "x", ""
        self._yaxis, self._yunit = "y", ""
        self._zaxis, self._zunit = "z", ""
        self.x_bins, self.y_bins = None, None

    def xaxis(self, label, unit):
        """
        set the x axis label and unit
        """
        self._xaxis = label
        self._xunit = unit

    def yaxis(self, label, unit):
        """
        set the y axis label and unit
        """
        self._yaxis = label
        self._yunit = unit

    def zaxis(self, label, unit):
        """
        set the z axis label and unit
        """
        self._zaxis = label
        self._zunit = unit
218
219
class Vector(object):
    """
    Plain container for a 3-space vector.

    Components *x*, *y* and *z* are stored as given and default to None.
    """
    def __init__(self, x=None, y=None, z=None):
        self.x = x
        self.y = y
        self.z = z
226
class Detector(object):
    """
    Detector geometry attributes.

    *pixel_size* is a sequence of components unpacked into a Vector
    (missing trailing components stay None); *distance* is stored as given.
    No units are enforced here.
    """
    def __init__(self, pixel_size=(None, None), distance=None):
        size = Vector(*pixel_size)
        self.pixel_size = size
        self.distance = distance
234
class Source(object):
    """
    Beam attributes.

    *wavelength* starts as NaN (not yet known) and *wavelength_unit* as "A";
    callers assign the actual wavelength afterwards.
    """
    def __init__(self):
        # np.nan rather than np.NaN: the NaN alias was removed in NumPy 2.0.
        self.wavelength = np.nan
        self.wavelength_unit = "A"
242
243
def empty_data1D(q, resolution=0.05):
    """
    Create empty 1D data using the given *q* as the x value.

    *q* may be any array-like of q values (list, tuple, ndarray); it is
    converted with np.asarray, so an ndarray is used as-is.

    *resolution* dq/q defaults to 5%; the data's *dx* is ``resolution*q``.

    The result has no intensities (*y* and *dy* are None) and is labelled
    "fake data".
    """
    # asarray so that list input works with resolution*q and q.min().
    q = np.asarray(q)
    data = Data1D(q, None, dx=resolution * q, dy=None)
    data.filename = "fake data"
    return data
257
258
def empty_data2D(qx, qy=None, resolution=0.05):
    """
    Build synthetic 2D data on the mesh spanned by *qx* and *qy*.

    *qy* defaults to *qx*, giving a square mesh.  Intensities are a constant
    placeholder of 100 with uncertainty sqrt(100).

    *resolution* is the constant dq/q (default 5%) applied to both q_x and
    q_y; pass 0 to leave dqx/dqy as None.
    """
    if qy is None:
        qy = qx
    qx_flat, qy_flat = (grid.flatten() for grid in np.meshgrid(qx, qy))
    intensity = 100 * np.ones_like(qx_flat)
    uncertainty = np.sqrt(intensity)

    if resolution == 0:
        dqx = dqy = None
    else:
        # Simplified resolution model: radial dQ/Q is taken as constant and
        # the perpendicular component is set equal to it.  A full treatment
        # (see Hammouda, "Probing Nanoscale Structures", ch. 15:
        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf)
        # depends on aperture geometry, pixel size and wavelength spread.
        q_mag = np.sqrt(qx_flat**2 + qy_flat**2)
        dqx = resolution * q_mag
        dqy = resolution * q_mag

    data = Data2D(x=qx_flat, y=qy_flat, z=intensity,
                  dx=dqx, dy=dqy, dz=uncertainty)
    data.x_bins, data.y_bins = qx, qy
    data.filename = "fake data"

    # pixel_size in mm, distance in m
    data.detector.append(Detector(pixel_size=(5, 5), distance=4))
    data.source.wavelength = 5  # angstroms
    data.source.wavelength_unit = "A"
    return data
297
298
def plot_data(data, view='log', limits=None):
    """
    Plot a data set loaded by the sasview loader.

    *data* may be 1D, 2D (has ``qx_data``) or SESANS (has ``lam``).

    *view* is 'log' or 'linear' (ignored for SESANS).

    *limits* fixes the intensity range of the plot; when None it is derived
    from the data.
    """
    # The result plotters already handle masking and axis markup, so reuse
    # them with no theory or residuals rather than duplicating that logic.
    if hasattr(data, 'lam'):
        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
        return
    if hasattr(data, 'qx_data'):
        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
        return
    _plot_result1D(data, None, None, view, use_data=True, limits=limits)
319
320
def plot_theory(data, theory, resid=None, view='log',
                use_data=True, limits=None):
    """
    Plot theory calculation.

    *data* is needed to define the graph properties such as labels and
    units, and to define the data mask.

    *theory* is a matrix of the same shape as the data.

    *resid* are optional residuals, plotted in their own panel.

    *view* is log or linear (ignored for SESANS data).

    *use_data* is True if the data should be plotted as well as the theory.

    *limits* sets the intensity limits on the plot; if None then the limits
    are inferred from the data.
    """
    if hasattr(data, 'lam'):
        # Honour the caller's use_data for SESANS too (it was hard-coded
        # to True, so SESANS data could not be suppressed).
        _plot_result_sesans(data, theory, resid, use_data=use_data,
                            limits=limits)
    elif hasattr(data, 'qx_data'):
        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
    else:
        _plot_result1D(data, theory, resid, view, use_data, limits=limits)
344
345
def protect(fn):
    """
    Decorator that traps exceptions raised by *fn*.

    Any ``Exception`` from *fn* has its traceback printed and the call
    returns None, so plotting errors do not abort the caller.
    ``KeyboardInterrupt`` and other non-``Exception`` signals such as
    ``SystemExit`` propagate normally.  The wrapper keeps *fn*'s name and
    docstring.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kw):
        try:
            return fn(*args, **kw)
        except Exception:
            # Deliberate best-effort: report and carry on.
            traceback.print_exc()
            return None

    return wrapper
363
364
@protect
def _plot_result1D(data, theory, resid, view, use_data, limits=None):
    """
    Plot the data and residuals for 1D data.

    Data (with error bars) and theory share the first panel; residuals, if
    given, go in a second panel.  *view* is 'log', 'linear' or 'q4' (plot
    $q^4 I(q)$ on linear axes).  The q axis is shown in 1/nm (data.x / 10).
    *limits*, if given, sets the y range of the first panel.
    """
    import matplotlib.pyplot as plt
    from numpy.ma import masked_array, masked

    use_data = use_data and data.y is not None
    use_theory = theory is not None
    use_resid = resid is not None
    num_plots = (use_data or use_theory) + use_resid

    scale = data.x**4 if view == 'q4' else 1.0
    # Only the 'log' view uses a log q axis; 'q4' is not a matplotlib scale.
    log_x = (view == 'log')

    if use_data or use_theory:
        # Select the panel *before* drawing so the curves land in panel 1.
        if num_plots > 1:
            plt.subplot(1, num_plots, 1)
        all_positive = True
        some_present = False
        if use_data:
            mdata = masked_array(data.y, data.mask.copy())
            mdata[~np.isfinite(mdata)] = masked
            if view == 'log':
                mdata[mdata <= 0] = masked
            # Scale the error bars with the points (matters for 'q4').
            yerr = scale*data.dy if data.dy is not None else None
            plt.errorbar(data.x/10, scale*mdata, yerr=yerr, fmt='.')
            all_positive = all_positive and (mdata > 0).all()
            some_present = some_present or (mdata.count() > 0)

        if use_theory:
            mtheory = masked_array(theory, data.mask.copy())
            mtheory[~np.isfinite(mtheory)] = masked
            if view == 'log':
                mtheory[mtheory <= 0] = masked
            # hold= was removed in Matplotlib 3.0; plots overlay by default.
            plt.plot(data.x/10, scale*mtheory, '-')
            all_positive = all_positive and (mtheory > 0).all()
            some_present = some_present or (mtheory.count() > 0)

        plt.xscale('log' if log_x and some_present else 'linear')
        plt.yscale('log'
                   if view == 'log' and some_present and all_positive
                   else 'linear')
        # Apply limits after the scales so they are not reset.
        if limits is not None:
            plt.ylim(*limits)
        plt.xlabel("$q$/nm$^{-1}$")
        plt.ylabel('$I(q)$')

    if use_resid:
        mresid = masked_array(resid, data.mask.copy())
        mresid[~np.isfinite(mresid)] = masked
        some_present = (mresid.count() > 0)

        if num_plots > 1:
            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
        plt.plot(data.x/10, mresid, '-')
        plt.xlabel("$q$/nm$^{-1}$")
        plt.ylabel('residuals')
        plt.xscale('log' if log_x and some_present else 'linear')
426
427
@protect
def _plot_result_sesans(data, theory, resid, use_data, limits=None):
    """
    Plot SESANS results.

    Data (with error bars) and/or theory are drawn against spin echo length
    in the first panel; residuals, if given, go in a second panel.
    *limits*, if given, sets the y range of the first panel.
    """
    import matplotlib.pyplot as plt
    use_data = use_data and data.y is not None
    use_theory = theory is not None
    use_resid = resid is not None
    num_plots = (use_data or use_theory) + use_resid

    if use_data or use_theory:
        if num_plots > 1:
            plt.subplot(1, num_plots, 1)
        if use_data:
            plt.errorbar(data.x, data.y, yerr=data.dy)
        if use_theory:
            # hold= was removed in Matplotlib 3.0; plots overlay by default.
            plt.plot(data.x, theory, '-')
        if limits is not None:
            plt.ylim(*limits)
        plt.xlabel('spin echo length (nm)')
        plt.ylabel('polarization (P/P0)')

    if use_resid:
        if num_plots > 1:
            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
        plt.plot(data.x, resid, 'x')
        plt.xlabel('spin echo length (nm)')
        plt.ylabel('residuals (P/P0)')
457
458
@protect
def _plot_result2D(data, theory, resid, view, use_data, limits=None):
    """
    Plot the data and residuals for 2D data.

    Up to three side-by-side panels are drawn: data, theory and residuals.
    Data and theory share one colour range, computed on the values as they
    are displayed for *view* ('log', 'q4' or linear), unless *limits*
    overrides it.  Residuals use a linear scale with their own range.

    *theory* and *resid* hold one value per unmasked point of *data*.
    """
    import matplotlib.pyplot as plt
    use_data = use_data and data.data is not None
    use_theory = theory is not None
    use_resid = resid is not None
    num_plots = use_data + use_theory + use_resid

    included = ~data.mask
    # q^4 for each included point, so that in the 'q4' view the shared
    # colour range is computed on the same scaled values that get drawn.
    q4 = (data.qx_data[included]**2 + data.qy_data[included]**2)**2
    label = (r'$\log_{10}I(q)$' if view == 'log'
             else r'$q^4 I(q)$' if view == 'q4'
             else '$I(q)$')

    # Put theory and data on a common colormap scale
    vmin, vmax = np.inf, -np.inf
    if use_data:
        target = data.data[included]
        lo, hi = _signal_range(target, q4, view)
        vmin, vmax = min(vmin, lo), max(vmax, hi)
    if use_theory:
        lo, hi = _signal_range(theory, q4, view)
        vmin, vmax = min(vmin, lo), max(vmax, hi)

    if limits is not None:
        # Caller-supplied limits override the computed range.
        vmin, vmax = limits
    else:
        # Nothing displayable found: let _plot_2d_signal pick its own range.
        if not np.isfinite(vmin):
            vmin = None
        if not np.isfinite(vmax):
            vmax = None

    # Plot data
    if use_data:
        if num_plots > 1:
            plt.subplot(1, num_plots, 1)
        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
        plt.title('data')
        h = plt.colorbar()
        h.set_label(label)

    # plot theory
    if use_theory:
        if num_plots > 1:
            plt.subplot(1, num_plots, use_data+1)
        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
        plt.title('theory')
        h = plt.colorbar()
        h.set_label(label)

    # plot resid
    if use_resid:
        if num_plots > 1:
            plt.subplot(1, num_plots, use_data+use_theory+1)
        _plot_2d_signal(data, resid, view='linear')
        plt.title('residuals')
        h = plt.colorbar()
        h.set_label(r'$\Delta I(q)$')


def _signal_range(signal, q4, view):
    """
    Return (min, max) of *signal* as it will be displayed for *view*.

    *q4* holds $q^4$ for each element of *signal* and is applied only in the
    'q4' view.  Non-finite values are ignored, as are non-positive values in
    the 'log' view.  Returns (inf, -inf) if no values remain.
    """
    values = np.asarray(signal, dtype=float)
    if view == 'q4':
        values = values * q4
    keep = np.isfinite(values)
    if view == 'log':
        keep[keep] = values[keep] > 0
    values = values[keep]
    if values.size == 0:
        return np.inf, -np.inf
    return values.min(), values.max()
516
517
@protect
def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
    """
    Plot the target value for the data.  This could be the data itself,
    the theory calculation, or the residuals.

    *signal* holds one value per unmasked point; it is scattered into the
    positions where ``data.mask`` is False.

    *view* is 'log' (log10 colour scale), 'q4' (signal scaled by
    $(q_x^2+q_y^2)^2$) or anything else for a linear colour scale.

    *vmin*, *vmax* are colour limits in displayed units before the log10
    step; when None they are taken from the valid unmasked points.

    Returns the (vmin, vmax) passed to imshow (log10 units for 'log').
    """
    import matplotlib.pyplot as plt
    from numpy.ma import masked_array

    image = np.zeros_like(data.qx_data, dtype=float)
    image[~data.mask] = signal
    valid = np.isfinite(image)
    if view == 'log':
        valid[valid] = (image[valid] > 0)
    elif view == 'q4':
        image[valid] *= (data.qx_data[valid]**2 + data.qy_data[valid]**2)**2
    shown = valid & ~data.mask
    if vmin is None:
        vmin = image[shown].min()
    if vmax is None:
        vmax = image[shown].max()
    if view == 'log':
        image[valid] = np.log10(image[valid])

    image[~valid | data.mask] = 0
    plottable = masked_array(image, ~valid | data.mask)
    # Divide by 10 to convert 1/Angstrom to 1/nm; np.min/np.max avoid
    # iterating the arrays element by element in Python.
    xmin, xmax = np.min(data.qx_data)/10, np.max(data.qx_data)/10
    ymin, ymax = np.min(data.qy_data)/10, np.max(data.qy_data)/10
    if view == 'log':
        vmin, vmax = np.log10(vmin), np.log10(vmax)
    # Points are laid out with q_x varying fastest (np.meshgrid(qx, qy)
    # then flatten(), as empty_data2D does), giving one row per q_y bin and
    # one column per q_x bin.  origin='lower' puts row 0 at y=ymin so the
    # image agrees with the extent.  NOTE(review): assumes q_y ascends with
    # the row index -- confirm for data read by the sasview loader.
    plt.imshow(plottable.reshape(len(data.y_bins), len(data.x_bins)),
               interpolation='nearest', aspect=1, origin='lower',
               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
    plt.xlabel("$q_x$/nm$^{-1}$")
    plt.ylabel("$q_y$/nm$^{-1}$")
    return vmin, vmax
558
def demo():
    """
    Demonstrate the module: load DEC07086.DAT, apply a 0.004 beam stop,
    plot it and show the figure.
    """
    sample = load_data('DEC07086.DAT')
    set_beam_stop(sample, 0.004)
    plot_data(sample)
    from matplotlib import pyplot
    pyplot.show()
567
568
# Script entry point: running this module directly executes demo().
if __name__ == "__main__":
    demo()
Note: See TracBrowser for help on using the repository browser.