source: sasmodels/sasmodels/data.py @ 5efe850

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5efe850 was 4e00c13, checked in by mathieu, 8 years ago

Update sas.dataloader to sas.sascalc.dataloader

  • Property mode set to 100644
File size: 18.3 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np
38
39def load_data(filename):
40    """
41    Load data using a sasview loader.
42    """
43    from sas.sascalc.dataloader.loader import Loader
44    loader = Loader()
45    data = loader.load(filename)
46    if data is None:
47        raise IOError("Data %r could not be loaded" % filename)
48    return data
49
50
51def set_beam_stop(data, radius, outer=None):
52    """
53    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
54    """
55    from sas.sascalc.dataloader.manipulations import Ringcut
56    if hasattr(data, 'qx_data'):
57        data.mask = Ringcut(0, radius)(data)
58        if outer is not None:
59            data.mask += Ringcut(outer, np.inf)(data)
60    else:
61        data.mask = (data.x < radius)
62        if outer is not None:
63            data.mask |= (data.x >= outer)
64
65
66def set_half(data, half):
67    """
68    Select half of the data, either "right" or "left".
69    """
70    from sas.sascalc.dataloader.manipulations import Boxcut
71    if half == 'right':
72        data.mask += \
73            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
74    if half == 'left':
75        data.mask += \
76            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
77
78
79def set_top(data, cutoff):
80    """
81    Chop the top off the data, above *cutoff*.
82    """
83    from sas.sascalc.dataloader.manipulations import Boxcut
84    data.mask += \
85        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
86
87
88class Data1D(object):
89    """
90    1D data object.
91
92    Note that this definition matches the attributes from sasview, with
93    some generic 1D data vectors and some SAS specific definitions.  Some
94    refactoring to allow consistent naming conventions between 1D, 2D and
95    SESANS data would be helpful.
96
97    **Attributes**
98
99    *x*, *dx*: $q$ vector and gaussian resolution
100
101    *y*, *dy*: $I(q)$ vector and measurement uncertainty
102
103    *mask*: values to include in plotting/analysis
104
105    *dxl*: slit widths for slit smeared data, with *dx* ignored
106
107    *qmin*, *qmax*: range of $q$ values in *x*
108
109    *filename*: label for the data line
110
111    *_xaxis*, *_xunit*: label and units for the *x* axis
112
113    *_yaxis*, *_yunit*: label and units for the *y* axis
114    """
115    def __init__(self, x=None, y=None, dx=None, dy=None):
116        self.x, self.y, self.dx, self.dy = x, y, dx, dy
117        self.dxl = None
118        self.filename = None
119        self.qmin = x.min() if x is not None else np.NaN
120        self.qmax = x.max() if x is not None else np.NaN
121        # TODO: why is 1D mask False and 2D mask True?
122        self.mask = (np.isnan(y) if y is not None
123                     else np.zeros_like(x, 'b') if x is not None
124                     else None)
125        self._xaxis, self._xunit = "x", ""
126        self._yaxis, self._yunit = "y", ""
127
128    def xaxis(self, label, unit):
129        """
130        set the x axis label and unit
131        """
132        self._xaxis = label
133        self._xunit = unit
134
135    def yaxis(self, label, unit):
136        """
137        set the y axis label and unit
138        """
139        self._yaxis = label
140        self._yunit = unit
141
142
143
144class Data2D(object):
145    """
146    2D data object.
147
148    Note that this definition matches the attributes from sasview. Some
149    refactoring to allow consistent naming conventions between 1D, 2D and
150    SESANS data would be helpful.
151
152    **Attributes**
153
154    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
155
156    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
157
158    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
159
160    *mask*: values to exclude from plotting/analysis
161
162    *qmin*, *qmax*: range of $q$ values in *x*
163
164    *filename*: label for the data line
165
166    *_xaxis*, *_xunit*: label and units for the *x* axis
167
168    *_yaxis*, *_yunit*: label and units for the *y* axis
169
170    *_zaxis*, *_zunit*: label and units for the *y* axis
171
172    *Q_unit*, *I_unit*: units for Q and intensity
173
174    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
175    """
176    def __init__(self, x=None, y=None, z=None, dx=None, dy=None, dz=None):
177        self.qx_data, self.dqx_data = x, dx
178        self.qy_data, self.dqy_data = y, dy
179        self.data, self.err_data = z, dz
180        self.mask = (np.isnan(z) if z is not None
181                     else np.zeros_like(x, dtype='bool') if x is not None
182                     else None)
183        self.q_data = np.sqrt(x**2 + y**2)
184        self.qmin = 1e-16
185        self.qmax = np.inf
186        self.detector = []
187        self.source = Source()
188        self.Q_unit = "1/A"
189        self.I_unit = "1/cm"
190        self.xaxis("Q_x", "1/A")
191        self.yaxis("Q_y", "1/A")
192        self.zaxis("Intensity", "1/cm")
193        self._xaxis, self._xunit = "x", ""
194        self._yaxis, self._yunit = "y", ""
195        self._zaxis, self._zunit = "z", ""
196        self.x_bins, self.y_bins = None, None
197
198    def xaxis(self, label, unit):
199        """
200        set the x axis label and unit
201        """
202        self._xaxis = label
203        self._xunit = unit
204
205    def yaxis(self, label, unit):
206        """
207        set the y axis label and unit
208        """
209        self._yaxis = label
210        self._yunit = unit
211
212    def zaxis(self, label, unit):
213        """
214        set the y axis label and unit
215        """
216        self._zaxis = label
217        self._zunit = unit
218
219
220class Vector(object):
221    """
222    3-space vector of *x*, *y*, *z*
223    """
224    def __init__(self, x=None, y=None, z=None):
225        self.x, self.y, self.z = x, y, z
226
227class Detector(object):
228    """
229    Detector attributes.
230    """
231    def __init__(self, pixel_size=(None, None), distance=None):
232        self.pixel_size = Vector(*pixel_size)
233        self.distance = distance
234
235class Source(object):
236    """
237    Beam attributes.
238    """
239    def __init__(self):
240        self.wavelength = np.NaN
241        self.wavelength_unit = "A"
242
243
244def empty_data1D(q, resolution=0.0):
245    """
246    Create empty 1D data using the given *q* as the x value.
247
248    *resolution* dq/q defaults to 5%.
249    """
250
251    #Iq = 100 * np.ones_like(q)
252    #dIq = np.sqrt(Iq)
253    Iq, dIq = None, None
254    q = np.asarray(q)
255    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
256    data.filename = "fake data"
257    return data
258
259
260def empty_data2D(qx, qy=None, resolution=0.0):
261    """
262    Create empty 2D data using the given mesh.
263
264    If *qy* is missing, create a square mesh with *qy=qx*.
265
266    *resolution* dq/q defaults to 5%.
267    """
268    if qy is None:
269        qy = qx
270    qx, qy = np.asarray(qx), np.asarray(qy)
271    # 5% dQ/Q resolution
272    Qx, Qy = np.meshgrid(qx, qy)
273    Qx, Qy = Qx.flatten(), Qy.flatten()
274    Iq = 100 * np.ones_like(Qx)
275    dIq = np.sqrt(Iq)
276    if resolution != 0:
277        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
278        # Should have an additional constant which depends on distances and
279        # radii of the aperture, pixel dimensions and wavelength spread
280        # Instead, assume radial dQ/Q is constant, and perpendicular matches
281        # radial (which instead it should be inverse).
282        Q = np.sqrt(Qx**2 + Qy**2)
283        dqx = resolution * Q
284        dqy = resolution * Q
285    else:
286        dqx = dqy = None
287
288    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
289    data.x_bins = qx
290    data.y_bins = qy
291    data.filename = "fake data"
292
293    # pixel_size in mm, distance in m
294    detector = Detector(pixel_size=(5, 5), distance=4)
295    data.detector.append(detector)
296    data.source.wavelength = 5 # angstroms
297    data.source.wavelength_unit = "A"
298    return data
299
300
301def plot_data(data, view='log', limits=None):
302    """
303    Plot data loaded by the sasview loader.
304
305    *data* is a sasview data object, either 1D, 2D or SESANS.
306
307    *view* is log or linear.
308
309    *limits* sets the intensity limits on the plot; if None then the limits
310    are inferred from the data.
311    """
312    # Note: kind of weird using the plot result functions to plot just the
313    # data, but they already handle the masking and graph markup already, so
314    # do not repeat.
315    if hasattr(data, 'lam'):
316        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
317    elif hasattr(data, 'qx_data'):
318        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
319    else:
320        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
321
322
323def plot_theory(data, theory, resid=None, view='log',
324                use_data=True, limits=None, Iq_calc=None):
325    """
326    Plot theory calculation.
327
328    *data* is needed to define the graph properties such as labels and
329    units, and to define the data mask.
330
331    *theory* is a matrix of the same shape as the data.
332
333    *view* is log or linear
334
335    *use_data* is True if the data should be plotted as well as the theory.
336
337    *limits* sets the intensity limits on the plot; if None then the limits
338    are inferred from the data.
339    """
340    if hasattr(data, 'lam'):
341        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
342    elif hasattr(data, 'qx_data'):
343        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
344    else:
345        _plot_result1D(data, theory, resid, view, use_data,
346                       limits=limits, Iq_calc=Iq_calc)
347
348
349def protect(fn):
350    """
351    Decorator to wrap calls in an exception trapper which prints the
352    exception and continues.  Keyboard interrupts are ignored.
353    """
354    def wrapper(*args, **kw):
355        """
356        Trap and print errors from function.
357        """
358        try:
359            return fn(*args, **kw)
360        except KeyboardInterrupt:
361            raise
362        except:
363            traceback.print_exc()
364
365    return wrapper
366
367
368@protect
369def _plot_result1D(data, theory, resid, view, use_data,
370                   limits=None, Iq_calc=None):
371    """
372    Plot the data and residuals for 1D data.
373    """
374    import matplotlib.pyplot as plt
375    from numpy.ma import masked_array, masked
376
377    use_data = use_data and data.y is not None
378    use_theory = theory is not None
379    use_resid = resid is not None
380    use_calc = use_theory and Iq_calc is not None
381    num_plots = (use_data or use_theory) + use_calc + use_resid
382    non_positive_x = (data.x<=0.0).any()
383
384    scale = data.x**4 if view == 'q4' else 1.0
385
386    if use_data or use_theory:
387        if num_plots > 1:
388            plt.subplot(1, num_plots, 1)
389
390        #print(vmin, vmax)
391        all_positive = True
392        some_present = False
393        if use_data:
394            mdata = masked_array(data.y, data.mask.copy())
395            mdata[~np.isfinite(mdata)] = masked
396            if view is 'log':
397                mdata[mdata <= 0] = masked
398            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
399            all_positive = all_positive and (mdata > 0).all()
400            some_present = some_present or (mdata.count() > 0)
401
402
403        if use_theory:
404            # Note: masks merge, so any masked theory points will stay masked,
405            # and the data mask will be added to it.
406            mtheory = masked_array(theory, data.mask.copy())
407            mtheory[~np.isfinite(mtheory)] = masked
408            if view is 'log':
409                mtheory[mtheory <= 0] = masked
410            plt.plot(data.x, scale*mtheory, '-', hold=True)
411            all_positive = all_positive and (mtheory > 0).all()
412            some_present = some_present or (mtheory.count() > 0)
413
414        if limits is not None:
415            plt.ylim(*limits)
416
417        plt.xscale('linear' if not some_present or non_positive_x  else view)
418        plt.yscale('linear'
419                   if view == 'q4' or not some_present or not all_positive
420                   else view)
421        plt.xlabel("$q$/A$^{-1}$")
422        plt.ylabel('$I(q)$')
423
424    if use_calc:
425        # Only have use_calc if have use_theory
426        plt.subplot(1, num_plots, 2)
427        qx, qy, Iqxy = Iq_calc
428        plt.pcolormesh(qx, qy[qy>0], np.log10(Iqxy[qy>0,:]))
429        plt.xlabel("$q_x$/A$^{-1}$")
430        plt.xlabel("$q_y$/A$^{-1}$")
431        plt.xscale('log')
432        plt.yscale('log')
433        #plt.axis('equal')
434
435    if use_resid:
436        mresid = masked_array(resid, data.mask.copy())
437        mresid[~np.isfinite(mresid)] = masked
438        some_present = (mresid.count() > 0)
439
440        if num_plots > 1:
441            plt.subplot(1, num_plots, use_calc + 2)
442        plt.plot(data.x, mresid, '-')
443        plt.xlabel("$q$/A$^{-1}$")
444        plt.ylabel('residuals')
445        plt.xscale('linear' if not some_present or non_positive_x else view)
446
447
448@protect
449def _plot_result_sesans(data, theory, resid, use_data, limits=None):
450    """
451    Plot SESANS results.
452    """
453    import matplotlib.pyplot as plt
454    use_data = use_data and data.y is not None
455    use_theory = theory is not None
456    use_resid = resid is not None
457    num_plots = (use_data or use_theory) + use_resid
458
459    if use_data or use_theory:
460        is_tof = np.any(data.lam!=data.lam[0])
461        if num_plots > 1:
462            plt.subplot(1, num_plots, 1)
463        if use_data:
464            if is_tof:
465                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam), yerr=data.dy/data.y/(data.lam*data.lam))
466            else:
467                plt.errorbar(data.x, data.y, yerr=data.dy)
468        if theory is not None:
469            if is_tof:
470                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-', hold=True)
471            else:
472                plt.plot(data.x, theory, '-', hold=True)
473        if limits is not None:
474            plt.ylim(*limits)
475
476        plt.xlabel('spin echo length ({})'.format(data._xunit))
477        if is_tof:
478            plt.ylabel('(Log (P/P$_0$))/$\lambda^2$')
479        else:
480            plt.ylabel('polarization (P/P0)')
481
482
483    if resid is not None:
484        if num_plots > 1:
485            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
486        plt.plot(data.x, resid, 'x')
487        plt.xlabel('spin echo length ({})'.format(data._xunit))
488        plt.ylabel('residuals (P/P0)')
489
490
491@protect
492def _plot_result2D(data, theory, resid, view, use_data, limits=None):
493    """
494    Plot the data and residuals for 2D data.
495    """
496    import matplotlib.pyplot as plt
497    use_data = use_data and data.data is not None
498    use_theory = theory is not None
499    use_resid = resid is not None
500    num_plots = use_data + use_theory + use_resid
501
502    # Put theory and data on a common colormap scale
503    vmin, vmax = np.inf, -np.inf
504    if use_data:
505        target = data.data[~data.mask]
506        datamin = target[target > 0].min() if view == 'log' else target.min()
507        datamax = target.max()
508        vmin = min(vmin, datamin)
509        vmax = max(vmax, datamax)
510    if use_theory:
511        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
512        theorymax = theory.max()
513        vmin = min(vmin, theorymin)
514        vmax = max(vmax, theorymax)
515
516    # Override data limits from the caller
517    if limits is not None:
518        vmin, vmax = limits
519
520    # Plot data
521    if use_data:
522        if num_plots > 1:
523            plt.subplot(1, num_plots, 1)
524        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
525        plt.title('data')
526        h = plt.colorbar()
527        h.set_label('$I(q)$')
528
529    # plot theory
530    if use_theory:
531        if num_plots > 1:
532            plt.subplot(1, num_plots, use_data+1)
533        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
534        plt.title('theory')
535        h = plt.colorbar()
536        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
537                    else r'$q^4 I(q)$' if view == 'q4'
538                    else '$I(q)$')
539
540    # plot resid
541    if use_resid:
542        if num_plots > 1:
543            plt.subplot(1, num_plots, use_data+use_theory+1)
544        _plot_2d_signal(data, resid, view='linear')
545        plt.title('residuals')
546        h = plt.colorbar()
547        h.set_label(r'$\Delta I(q)$')
548
549
550@protect
551def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
552    """
553    Plot the target value for the data.  This could be the data itself,
554    the theory calculation, or the residuals.
555
556    *scale* can be 'log' for log scale data, or 'linear'.
557    """
558    import matplotlib.pyplot as plt
559    from numpy.ma import masked_array
560
561    image = np.zeros_like(data.qx_data)
562    image[~data.mask] = signal
563    valid = np.isfinite(image)
564    if view == 'log':
565        valid[valid] = (image[valid] > 0)
566        if vmin is None: vmin = image[valid & ~data.mask].min()
567        if vmax is None: vmax = image[valid & ~data.mask].max()
568        image[valid] = np.log10(image[valid])
569    elif view == 'q4':
570        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
571        if vmin is None: vmin = image[valid & ~data.mask].min()
572        if vmax is None: vmax = image[valid & ~data.mask].max()
573    else:
574        if vmin is None: vmin = image[valid & ~data.mask].min()
575        if vmax is None: vmax = image[valid & ~data.mask].max()
576
577    image[~valid | data.mask] = 0
578    #plottable = Iq
579    plottable = masked_array(image, ~valid | data.mask)
580    # Divide range by 10 to convert from angstroms to nanometers
581    xmin, xmax = min(data.qx_data), max(data.qx_data)
582    ymin, ymax = min(data.qy_data), max(data.qy_data)
583    if view == 'log':
584        vmin, vmax = np.log10(vmin), np.log10(vmax)
585    plt.imshow(plottable.reshape(len(data.x_bins), len(data.y_bins)),
586               interpolation='nearest', aspect=1, origin='lower',
587               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
588    plt.xlabel("$q_x$/A$^{-1}$")
589    plt.ylabel("$q_y$/A$^{-1}$")
590    return vmin, vmax
591
592def demo():
593    """
594    Load and plot a SAS dataset.
595    """
596    data = load_data('DEC07086.DAT')
597    set_beam_stop(data, 0.004)
598    plot_data(data)
599    import matplotlib.pyplot as plt; plt.show()
600
601
602if __name__ == "__main__":
603    demo()
Note: See TracBrowser for help on using the repository browser.