source: sasmodels/sasmodels/data.py @ a557a99

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since a557a99 was 40a87fa, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

lint and latex cleanup

  • Property mode set to 100644
File size: 20.5 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
import functools
import traceback

import numpy as np  # type: ignore

try:
    from typing import Union, Dict, List, Optional, Tuple, Callable
except ImportError:
    pass
else:
    Data = Union["Data1D", "Data2D", "SesansData"]
45
def load_data(filename):
    # type: (str) -> Data
    """
    Read *filename* with the SasView data loader and return the result.

    SasView (the ``sas`` package) must be importable.  Raises IOError when
    the loader hands back None instead of a data set.
    """
    from sas.sascalc.dataloader.loader import Loader  # type: ignore
    result = Loader().load(filename)
    if result is not None:
        return result
    raise IOError("Data %r could not be loaded" % filename)
57
58
def set_beam_stop(data, radius, outer=None):
    # type: (Data, float, Optional[float]) -> None
    """
    Add a beam stop of the given *radius*.  If *outer*, make an annulus.

    This replaces ``data.mask`` (True marks excluded points).

    For 1D data the mask is ``x < radius``, OR-ed with ``x >= outer`` when
    *outer* is given.  Only numpy is needed in this case.

    For 2D data (anything with *qx_data*) the mask comes from SasView's
    ``Ringcut(0, radius)``, plus ``Ringcut(outer, inf)`` when *outer* is
    given.  This path requires SasView to be importable.
    """
    if hasattr(data, 'qx_data'):
        # Import lazily so that 1D data (e.g. from empty_data1D) can be
        # masked without SasView on the path.
        from sas.sascalc.dataloader.manipulations import Ringcut
        data.mask = Ringcut(0, radius)(data)
        if outer is not None:
            data.mask += Ringcut(outer, np.inf)(data)
    else:
        data.mask = (data.x < radius)
        if outer is not None:
            data.mask |= (data.x >= outer)
73
74
def set_half(data, half):
    # type: (Data, str) -> None
    """
    Select half of the data, either "right" or "left".

    The unwanted half is added to ``data.mask`` using SasView's ``Boxcut``.
    For "right", the region with x_min=-inf, x_max=0 is masked.  For
    "left", the region with x_min=0, x_max=inf is masked.  SasView must be
    importable.

    Raises ValueError if *half* is neither "right" nor "left", rather than
    silently leaving the mask unchanged.
    """
    if half == 'right':
        x_min, x_max = -np.inf, 0.0
    elif half == 'left':
        x_min, x_max = 0.0, np.inf
    else:
        raise ValueError("half must be 'right' or 'left', not %r" % (half,))
    # Validate before importing so a bad argument is reported even when
    # SasView is unavailable.
    from sas.sascalc.dataloader.manipulations import Boxcut
    data.mask += \
        Boxcut(x_min=x_min, x_max=x_max, y_min=-np.inf, y_max=np.inf)(data)
87
88
def set_top(data, cutoff):
    # type: (Data, float) -> None
    """
    Chop the top off the data, above *cutoff*.

    Adds to ``data.mask`` the region selected by SasView's
    ``Boxcut(x_min=-inf, x_max=inf, y_min=-inf, y_max=cutoff)``.  SasView
    must be importable.

    NOTE(review): if Boxcut returns True for points inside the box (as
    set_half appears to rely on), this masks the points *below* *cutoff*
    in y, which is the opposite of the summary line.  Confirm the intended
    direction before relying on it.
    """
    from sas.sascalc.dataloader.manipulations import Boxcut
    band = Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)
    data.mask += band(data)
97
98
class Data1D(object):
    """
    1D data object.

    Note that this definition matches the attributes from sasview, with
    some generic 1D data vectors and some SAS specific definitions.  Some
    refactoring to allow consistent naming conventions between 1D, 2D and
    SESANS data would be helpful.

    **Attributes**

    *x*, *dx*: $q$ vector and gaussian resolution

    *y*, *dy*: $I(q)$ vector and measurement uncertainty

    *mask*: boolean array, True for values to exclude from
    plotting/analysis.  It defaults to the NaN entries of *y*, or to all
    False when only *x* is given.

    *dxl*: slit widths for slit smeared data, with *dx* ignored

    *qmin*, *qmax*: range of $q$ values in *x* (NaN when *x* is None)

    *filename*: label for the data line

    *_xaxis*, *_xunit*: label and units for the *x* axis

    *_yaxis*, *_yunit*: label and units for the *y* axis
    """
    def __init__(self, x=None, y=None, dx=None, dy=None):
        # type: (Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]) -> None
        self.x, self.y, self.dx, self.dy = x, y, dx, dy
        self.dxl = None
        self.filename = None
        # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
        self.qmin = x.min() if x is not None else np.nan
        self.qmax = x.max() if x is not None else np.nan
        # TODO: why is 1D mask False and 2D mask True?
        # Use a genuine bool array in every branch.  dtype 'b' is int8, for
        # which ~0 == -1 rather than True.
        self.mask = (np.isnan(y) if y is not None
                     else np.zeros_like(x, dtype='bool') if x is not None
                     else None)
        self._xaxis, self._xunit = "x", ""
        self._yaxis, self._yunit = "y", ""

    def xaxis(self, label, unit):
        # type: (str, str) -> None
        """
        set the x axis label and unit
        """
        self._xaxis = label
        self._xunit = unit

    def yaxis(self, label, unit):
        # type: (str, str) -> None
        """
        set the y axis label and unit
        """
        self._yaxis = label
        self._yunit = unit
155
class SesansData(Data1D):
    """
    SESANS data object.

    This is just :class:`Data1D` with a wavelength parameter.

    *x* is spin echo length and *y* is polarization (P/P0).

    *lam*: wavelength array, initially None and set by the caller.
    """
    def __init__(self, *args, **kw):
        # Forward positional arguments as well, so SesansData(x, y, ...)
        # works like Data1D(x, y, ...); keyword-only use is unchanged.
        Data1D.__init__(self, *args, **kw)
        self.lam = None  # type: Optional[np.ndarray]
167
class Data2D(object):
    """
    2D data object.

    Note that this definition matches the attributes from sasview. Some
    refactoring to allow consistent naming conventions between 1D, 2D and
    SESANS data would be helpful.

    **Attributes**

    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution

    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution

    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty

    *q_data*: $|q|$ computed from *qx_data* and *qy_data* (None if either
    is missing)

    *mask*: values to exclude from plotting/analysis

    *qmin*, *qmax*: $q$ range limits (initialized to 1e-16 and inf)

    *filename*: label for the data line

    *detector*, *source*: list of detectors and the beam :class:`Source`

    *_xaxis*, *_xunit*: label and units for the *x* axis

    *_yaxis*, *_yunit*: label and units for the *y* axis

    *_zaxis*, *_zunit*: label and units for the *z* axis

    *Q_unit*, *I_unit*: units for Q and intensity

    *x_bins*, *y_bins*: mesh coordinates along *x* and *y*; their lengths
    give the image shape used for plotting
    """
    def __init__(self, x=None, y=None, z=None, dx=None, dy=None, dz=None):
        # type: (Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]) -> None
        self.qx_data, self.dqx_data = x, dx
        self.qy_data, self.dqy_data = y, dy
        self.data, self.err_data = z, dz
        self.mask = (np.isnan(z) if z is not None
                     else np.zeros_like(x, dtype='bool') if x is not None
                     else None)
        # |q| needs both components.  Guarding here lets Data2D() be built
        # with its default arguments instead of raising on None**2.
        self.q_data = (np.sqrt(x**2 + y**2)
                       if x is not None and y is not None else None)
        self.qmin = 1e-16
        self.qmax = np.inf
        self.detector = []
        self.source = Source()
        self.Q_unit = "1/A"
        self.I_unit = "1/cm"
        # Axis labels are generic placeholders; physical units live in
        # Q_unit and I_unit.
        self._xaxis, self._xunit = "x", ""
        self._yaxis, self._yunit = "y", ""
        self._zaxis, self._zunit = "z", ""
        self.x_bins, self.y_bins = None, None
        self.filename = None

    def xaxis(self, label, unit):
        # type: (str, str) -> None
        """
        set the x axis label and unit
        """
        self._xaxis = label
        self._xunit = unit

    def yaxis(self, label, unit):
        # type: (str, str) -> None
        """
        set the y axis label and unit
        """
        self._yaxis = label
        self._yunit = unit

    def zaxis(self, label, unit):
        # type: (str, str) -> None
        """
        set the z axis label and unit
        """
        self._zaxis = label
        self._zunit = unit
247
248
class Vector(object):
    """
    Plain container for a 3-space vector with components *x*, *y*, *z*.

    Any component may be left as None.
    """
    def __init__(self, x=None, y=None, z=None):
        # type: (Optional[float], Optional[float], Optional[float]) -> None
        self.x = x
        self.y = y
        self.z = z
256
class Detector(object):
    """
    Detector geometry.

    *pixel_size* is unpacked into a :class:`Vector`.  *distance* is stored
    as given.  Units are whatever the caller uses; empty_data2D passes mm
    for pixels and m for distance.
    """
    def __init__(self, pixel_size=(None, None), distance=None):
        # type: (Tuple[float, float], float) -> None
        size_vector = Vector(*pixel_size)
        self.pixel_size = size_vector
        self.distance = distance
265
class Source(object):
    """
    Beam attributes.

    *wavelength* starts as NaN (unknown) and *wavelength_unit* as "A".
    """
    def __init__(self):
        # type: () -> None
        # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
        self.wavelength = np.nan
        self.wavelength_unit = "A"
274
275
def empty_data1D(q, resolution=0.0):
    # type: (np.ndarray, float) -> Data1D
    """
    Create empty 1D data using the given *q* as the x value.

    *resolution* is the relative resolution dq/q, so *dx* is
    ``resolution * q``.  It defaults to 0, giving *dx* of all zeros; pass
    e.g. 0.05 for 5%.

    The data has no intensity: *y* and *dy* are None.
    """
    q = np.asarray(q)
    data = Data1D(q, None, dx=resolution * q, dy=None)
    data.filename = "fake data"
    return data
291
292
def empty_data2D(qx, qy=None, resolution=0.0):
    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
    """
    Create empty 2D data using the given mesh.

    If *qy* is missing, create a square mesh with *qy=qx*.

    *resolution* is the relative resolution dq/q applied to both *dqx* and
    *dqy*.  It defaults to 0, in which case no resolution is attached
    (*dqx_data* and *dqy_data* are None).

    The points are the flattened ``np.meshgrid(qx, qy)`` mesh, so *qx*
    varies fastest.  The intensity is a constant placeholder of 100 with
    uncertainty 10.  A nominal detector (5 mm pixels at 4 m) and a
    wavelength of 5 A are attached.
    """
    if qy is None:
        qy = qx
    qx, qy = np.asarray(qx), np.asarray(qy)
    Qx, Qy = np.meshgrid(qx, qy)
    Qx, Qy = Qx.flatten(), Qy.flatten()
    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
    dIq = np.sqrt(Iq)
    if resolution != 0:
        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
        # Should have an additional constant which depends on distances and
        # radii of the aperture, pixel dimensions and wavelength spread
        # Instead, assume radial dQ/Q is constant, and perpendicular matches
        # radial (which instead it should be inverse).
        Q = np.sqrt(Qx**2 + Qy**2)
        dqx = resolution * Q
        dqy = resolution * Q
    else:
        dqx = dqy = None

    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
    data.x_bins = qx
    data.y_bins = qy
    data.filename = "fake data"

    # pixel_size in mm, distance in m
    detector = Detector(pixel_size=(5, 5), distance=4)
    data.detector.append(detector)
    data.source.wavelength = 5 # angstroms
    data.source.wavelength_unit = "A"
    return data
333
334
def plot_data(data, view='log', limits=None):
    # type: (Data, str, Optional[Tuple[float, float]]) -> None
    """
    Plot a data set on its own, with no theory or residuals.

    *data* is a sasview data object: 1D, 2D or SESANS.  The kind is
    detected from its attributes (*lam* for SESANS, *qx_data* for 2D).

    *view* is 'log' or 'linear'; it is not used for SESANS data.

    *limits* sets the intensity limits on the plot; if None they are
    inferred from the data.
    """
    # The result plotters already handle masking and graph markup, so reuse
    # them with theory and residuals left empty.
    if hasattr(data, 'lam'):
        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
        return
    plotter = _plot_result2D if hasattr(data, 'qx_data') else _plot_result1D
    plotter(data, None, None, view, use_data=True, limits=limits)
356
357
def plot_theory(data, theory, resid=None, view='log',
                use_data=True, limits=None, Iq_calc=None):
    # type: (Data, Optional[np.ndarray], Optional[np.ndarray], str, bool, Optional[Tuple[float,float]], Optional[np.ndarray]) -> None
    """
    Plot theory calculation.

    *data* is needed to define the graph properties such as labels and
    units, and to define the data mask.

    *theory* is a matrix of the same shape as the data.

    *resid* holds the residuals; they are plotted in their own panel when
    not None.

    *view* is log or linear (ignored for SESANS data).

    *use_data* is True if the data should be plotted as well as the theory.

    *limits* sets the intensity limits on the plot; if None then the limits
    are inferred from the data.

    *Iq_calc* is the raw theory values without resolution smearing, given
    as a (qx, qy, Iqxy) tuple; it is only used for 1D data.
    """
    if hasattr(data, 'lam'):
        # Pass the caller's use_data through; this branch previously forced
        # the data to be drawn.
        _plot_result_sesans(data, theory, resid, use_data=use_data,
                            limits=limits)
    elif hasattr(data, 'qx_data'):
        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
    else:
        _plot_result1D(data, theory, resid, view, use_data,
                       limits=limits, Iq_calc=Iq_calc)
385
386
def protect(func):
    # type: (Callable) -> Callable
    """
    Decorator to wrap calls in an exception trapper which prints the
    exception and continues.

    Any :class:`Exception` raised by *func* is printed with its traceback
    and the wrapper returns None.  KeyboardInterrupt and SystemExit do not
    derive from Exception, so they are not trapped and still propagate.

    The wrapper keeps *func*'s name and docstring.
    """
    @functools.wraps(func)
    def wrapper(*args, **kw):
        """
        Trap and print errors from function.
        """
        try:
            return func(*args, **kw)
        except Exception:
            traceback.print_exc()
            return None

    return wrapper
403
404
@protect
def _plot_result1D(data, theory, resid, view, use_data,
                   limits=None, Iq_calc=None):
    # type: (Data1D, Optional[np.ndarray], Optional[np.ndarray], str, bool, Optional[Tuple[float, float]], Optional[np.ndarray]) -> None
    """
    Plot the data and residuals for 1D data.

    Up to three side-by-side panels are drawn:

    * data and/or *theory*;
    * the unsmeared 2D calculation *Iq_calc*, a (qx, qy, Iqxy) tuple, shown
      only when *theory* is also given;
    * the residuals *resid*.

    *view* is 'log', 'linear' or 'q4'.  For 'q4' the data, its error bars
    and the theory are multiplied by $q^4$ and drawn on a linear y axis.
    """
    import matplotlib.pyplot as plt  # type: ignore
    from numpy.ma import masked_array, masked  # type: ignore

    use_data = use_data and data.y is not None
    use_theory = theory is not None
    use_resid = resid is not None
    use_calc = use_theory and Iq_calc is not None
    num_plots = (use_data or use_theory) + use_calc + use_resid
    non_positive_x = (data.x <= 0.0).any()

    scale = data.x**4 if view == 'q4' else 1.0
    # 'q4' is not a matplotlib scale name.  It only rescales the y values,
    # so q itself is drawn on a log axis in that view.
    q_view = 'log' if view == 'q4' else view

    if use_data or use_theory:
        if num_plots > 1:
            plt.subplot(1, num_plots, 1)

        all_positive = True
        some_present = False
        if use_data:
            mdata = masked_array(data.y, data.mask.copy())
            mdata[~np.isfinite(mdata)] = masked
            if view == 'log':
                mdata[mdata <= 0] = masked
            # Error bars must be scaled along with the values in q4 view.
            yerr = data.dy
            if view == 'q4' and yerr is not None:
                yerr = scale*yerr
            plt.errorbar(data.x, scale*mdata, yerr=yerr, fmt='.')
            all_positive = all_positive and (mdata > 0).all()
            some_present = some_present or (mdata.count() > 0)

        if use_theory:
            # Note: masks merge, so any masked theory points will stay masked,
            # and the data mask will be added to it.
            mtheory = masked_array(theory, data.mask.copy())
            mtheory[~np.isfinite(mtheory)] = masked
            if view == 'log':
                mtheory[mtheory <= 0] = masked
            # The 'hold' keyword was removed in matplotlib 3.0; drawing onto
            # the current axes is the default behaviour.
            plt.plot(data.x, scale*mtheory, '-')
            all_positive = all_positive and (mtheory > 0).all()
            some_present = some_present or (mtheory.count() > 0)

        if limits is not None:
            plt.ylim(*limits)

        plt.xscale('linear' if not some_present or non_positive_x else q_view)
        plt.yscale('linear'
                   if view == 'q4' or not some_present or not all_positive
                   else view)
        plt.xlabel("$q$/A$^{-1}$")
        plt.ylabel(r'$q^4 I(q)$' if view == 'q4' else '$I(q)$')

    if use_calc:
        # Only have use_calc if have use_theory
        plt.subplot(1, num_plots, 2)
        qx, qy, Iqxy = Iq_calc
        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
        plt.xlabel("$q_x$/A$^{-1}$")
        plt.ylabel("$q_y$/A$^{-1}$")
        plt.xscale('log')
        plt.yscale('log')

    if use_resid:
        mresid = masked_array(resid, data.mask.copy())
        mresid[~np.isfinite(mresid)] = masked
        some_present = (mresid.count() > 0)

        if num_plots > 1:
            plt.subplot(1, num_plots, use_calc + 2)
        plt.plot(data.x, mresid, '-')
        plt.xlabel("$q$/A$^{-1}$")
        plt.ylabel('residuals')
        plt.xscale('linear' if not some_present or non_positive_x else q_view)
484
485
@protect
def _plot_result_sesans(data, theory, resid, use_data, limits=None):
    # type: (SesansData, Optional[np.ndarray], Optional[np.ndarray], bool, Optional[Tuple[float, float]]) -> None
    """
    Plot SESANS results.

    Data and/or *theory* go in the first panel, and *resid* in a second
    panel when both are present.  When the wavelengths in ``data.lam`` are
    not all equal (time-of-flight data), the first panel shows
    log(P/P0)/lambda^2 instead of the polarization itself.
    """
    import matplotlib.pyplot as plt  # type: ignore
    use_data = use_data and data.y is not None
    use_theory = theory is not None
    use_resid = resid is not None
    num_plots = (use_data or use_theory) + use_resid

    if use_data or use_theory:
        is_tof = (data.lam != data.lam[0]).any()
        lam2 = data.lam*data.lam
        if num_plots > 1:
            plt.subplot(1, num_plots, 1)
        if use_data:
            if is_tof:
                # dy may be missing; only propagate it when present.
                yerr = None if data.dy is None else data.dy/data.y/lam2
                plt.errorbar(data.x, np.log(data.y)/lam2, yerr=yerr)
            else:
                plt.errorbar(data.x, data.y, yerr=data.dy)
        if use_theory:
            # The 'hold' keyword was removed in matplotlib 3.0; drawing onto
            # the current axes is the default behaviour.
            if is_tof:
                plt.plot(data.x, np.log(theory)/lam2, '-')
            else:
                plt.plot(data.x, theory, '-')
        if limits is not None:
            plt.ylim(*limits)

        plt.xlabel('spin echo length ({})'.format(data._xunit))
        if is_tof:
            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
        else:
            plt.ylabel('polarization (P/P0)')

    if use_resid:
        if num_plots > 1:
            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
        plt.plot(data.x, resid, 'x')
        plt.xlabel('spin echo length ({})'.format(data._xunit))
        plt.ylabel('residuals (P/P0)')
529
530
@protect
def _plot_result2D(data, theory, resid, view, use_data, limits=None):
    # type: (Data2D, Optional[np.ndarray], Optional[np.ndarray], str, bool, Optional[Tuple[float,float]]) -> None
    """
    Plot the data and residuals for 2D data.

    Data and *theory* share one colour scale, which *limits* overrides.
    Residuals are always drawn on a linear scale with their own limits.
    *theory* and *resid* hold one value per unmasked point of *data*.
    """
    import matplotlib.pyplot as plt  # type: ignore
    use_data = use_data and data.data is not None
    use_theory = theory is not None
    use_resid = resid is not None
    num_plots = use_data + use_theory + use_resid

    # _plot_2d_signal multiplies the signal by q^4 in the q4 view, so the
    # shared colour limits must be computed on the same scaled values.
    if view == 'q4':
        unmasked = ~data.mask
        weight = (data.qx_data[unmasked]**2 + data.qy_data[unmasked]**2)**2
    else:
        weight = 1.0

    # Put theory and data on a common colormap scale
    vmin, vmax = np.inf, -np.inf
    target = None # type: Optional[np.ndarray]
    if use_data:
        target = data.data[~data.mask]
        shown = weight*target
        datamin = shown[shown > 0].min() if view == 'log' else shown.min()
        datamax = shown.max()
        vmin = min(vmin, datamin)
        vmax = max(vmax, datamax)
    if use_theory:
        shown = weight*theory
        theorymin = shown[shown > 0].min() if view == 'log' else shown.min()
        theorymax = shown.max()
        vmin = min(vmin, theorymin)
        vmax = max(vmax, theorymax)

    # Override data limits from the caller
    if limits is not None:
        vmin, vmax = limits

    # Data and theory are drawn with the same view, so they share a label.
    signal_label = (r'$\log_{10}I(q)$' if view == 'log'
                    else r'$q^4 I(q)$' if view == 'q4'
                    else '$I(q)$')

    # Plot data
    if use_data:
        if num_plots > 1:
            plt.subplot(1, num_plots, 1)
        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
        plt.title('data')
        h = plt.colorbar()
        h.set_label(signal_label)

    # plot theory
    if use_theory:
        if num_plots > 1:
            plt.subplot(1, num_plots, use_data+1)
        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
        plt.title('theory')
        h = plt.colorbar()
        h.set_label(signal_label)

    # plot resid
    if use_resid:
        if num_plots > 1:
            plt.subplot(1, num_plots, use_data+use_theory+1)
        _plot_2d_signal(data, resid, view='linear')
        plt.title('residuals')
        h = plt.colorbar()
        h.set_label(r'$\Delta I(q)$')
590
591
@protect
def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
    # type: (Data2D, np.ndarray, Optional[float], Optional[float], str) -> Tuple[float, float]
    """
    Plot the target value for the data.  This could be the data itself,
    the theory calculation, or the residuals.

    *signal* holds one value per unmasked point of *data*.  It is scattered
    into the positions where ``data.mask`` is False.

    *view* selects what is shown:

    * 'log': log10 of the signal;
    * 'q4': $q^4$ times the signal;
    * anything else (e.g. 'linear'): the raw signal.

    *vmin*, *vmax* are colour limits in signal units, before any log is
    taken.  When None, they come from the valid unmasked values.

    Returns the colour limits actually used, in plotted units.
    """
    import matplotlib.pyplot as plt  # type: ignore
    from numpy.ma import masked_array  # type: ignore

    image = np.zeros_like(data.qx_data)
    image[~data.mask] = signal
    valid = np.isfinite(image)
    if view == 'log':
        # Only positive values can be shown on a log scale.
        valid[valid] = (image[valid] > 0)
    elif view == 'q4':
        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
    # Default limits come from the values before any log10 is applied.
    if vmin is None:
        vmin = image[valid & ~data.mask].min()
    if vmax is None:
        vmax = image[valid & ~data.mask].max()
    if view == 'log':
        image[valid] = np.log10(image[valid])

    image[~valid | data.mask] = 0
    plottable = masked_array(image, ~valid | data.mask)
    # numpy reductions are vectorized; the builtins iterate in Python.
    xmin, xmax = np.min(data.qx_data), np.max(data.qx_data)
    ymin, ymax = np.min(data.qy_data), np.max(data.qy_data)
    if view == 'log':
        vmin, vmax = np.log10(vmin), np.log10(vmax)
    # imshow rows are y and columns are x.  This assumes qx varies fastest
    # in the flattened data, as in the np.meshgrid layout of empty_data2D,
    # so the shape is (ny, nx).  (nx, ny) scrambles non-square meshes.
    plt.imshow(plottable.reshape(len(data.y_bins), len(data.x_bins)),
               interpolation='nearest', aspect=1, origin='lower',
               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
    plt.xlabel("$q_x$/A$^{-1}$")
    plt.ylabel("$q_y$/A$^{-1}$")
    return vmin, vmax
634
def demo():
    # type: () -> None
    """
    Demonstrate loading and plotting a SAS dataset.

    Loads DEC07086.DAT (a relative path), masks a beam stop of radius
    0.004, plots it, and shows the matplotlib window.
    """
    sample = load_data('DEC07086.DAT')
    set_beam_stop(sample, 0.004)
    plot_data(sample)
    from matplotlib import pyplot  # type: ignore
    pyplot.show()


if __name__ == "__main__":
    demo()
Note: See TracBrowser for help on using the repository browser.