source: sasmodels/sasmodels/data.py @ 1a8c11c

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 1a8c11c was 1a8c11c, checked in by Paul Kienzle <pkienzle@…>, 16 months ago

fix handling of 1D data mask when no mask specified

  • Property mode set to 100644
File size: 30.4 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38from numpy import sqrt, sin, cos, pi
39
40# pylint: disable=unused-import
41try:
42    from typing import Union, Dict, List, Optional
43except ImportError:
44    pass
45else:
46    Data = Union["Data1D", "Data2D", "SesansData"]
47# pylint: enable=unused-import
48
49def load_data(filename, index=0):
50    # type: (str) -> Data
51    """
52    Load data using a sasview loader.
53    """
54    from sas.sascalc.dataloader.loader import Loader  # type: ignore
55    loader = Loader()
56    # Allow for one part in multipart file
57    if '[' in filename:
58        filename, indexstr = filename[:-1].split('[')
59        index = int(indexstr)
60    datasets = loader.load(filename)
61    if not datasets:  # None or []
62        raise IOError("Data %r could not be loaded" % filename)
63    if not isinstance(datasets, list):
64        datasets = [datasets]
65    for data in datasets:
66        if hasattr(data, 'x'):
67            data.qmin, data.qmax = data.x.min(), data.x.max()
68            data.mask = (np.isnan(data.y) if data.y is not None
69                         else np.zeros_like(data.x, dtype='bool'))
70        elif hasattr(data, 'qx_data'):
71            data.mask = ~data.mask
72    return datasets[index] if index != 'all' else datasets
73
74
75def set_beam_stop(data, radius, outer=None):
76    # type: (Data, float, Optional[float]) -> None
77    """
78    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
79    """
80    from sas.sascalc.dataloader.manipulations import Ringcut
81    if hasattr(data, 'qx_data'):
82        data.mask = Ringcut(0, radius)(data)
83        if outer is not None:
84            data.mask += Ringcut(outer, np.inf)(data)
85    else:
86        data.mask = (data.x < radius)
87        if outer is not None:
88            data.mask |= (data.x >= outer)
89
90
91def set_half(data, half):
92    # type: (Data, str) -> None
93    """
94    Select half of the data, either "right" or "left".
95    """
96    from sas.sascalc.dataloader.manipulations import Boxcut
97    if half == 'right':
98        data.mask += \
99            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
100    if half == 'left':
101        data.mask += \
102            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
103
104
105def set_top(data, cutoff):
106    # type: (Data, float) -> None
107    """
108    Chop the top off the data, above *cutoff*.
109    """
110    from sas.sascalc.dataloader.manipulations import Boxcut
111    data.mask += \
112        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
113
114
115class Data1D(object):
116    """
117    1D data object.
118
119    Note that this definition matches the attributes from sasview, with
120    some generic 1D data vectors and some SAS specific definitions.  Some
121    refactoring to allow consistent naming conventions between 1D, 2D and
122    SESANS data would be helpful.
123
124    **Attributes**
125
126    *x*, *dx*: $q$ vector and gaussian resolution
127
128    *y*, *dy*: $I(q)$ vector and measurement uncertainty
129
130    *mask*: values to include in plotting/analysis
131
132    *dxl*: slit widths for slit smeared data, with *dx* ignored
133
134    *qmin*, *qmax*: range of $q$ values in *x*
135
136    *filename*: label for the data line
137
138    *_xaxis*, *_xunit*: label and units for the *x* axis
139
140    *_yaxis*, *_yunit*: label and units for the *y* axis
141    """
142    def __init__(self,
143                 x=None,  # type: Optional[np.ndarray]
144                 y=None,  # type: Optional[np.ndarray]
145                 dx=None, # type: Optional[np.ndarray]
146                 dy=None  # type: Optional[np.ndarray]
147                ):
148        # type: (...) -> None
149        self.x, self.y, self.dx, self.dy = x, y, dx, dy
150        self.dxl = None
151        self.filename = None
152        self.qmin = x.min() if x is not None else np.NaN
153        self.qmax = x.max() if x is not None else np.NaN
154        # TODO: why is 1D mask False and 2D mask True?
155        self.mask = (np.isnan(y) if y is not None
156                     else np.zeros_like(x, 'b') if x is not None
157                     else None)
158        self._xaxis, self._xunit = "x", ""
159        self._yaxis, self._yunit = "y", ""
160
161    def xaxis(self, label, unit):
162        # type: (str, str) -> None
163        """
164        set the x axis label and unit
165        """
166        self._xaxis = label
167        self._xunit = unit
168
169    def yaxis(self, label, unit):
170        # type: (str, str) -> None
171        """
172        set the y axis label and unit
173        """
174        self._yaxis = label
175        self._yunit = unit
176
177class SesansData(Data1D):
178    """
179    SESANS data object.
180
181    This is just :class:`Data1D` with a wavelength parameter.
182
183    *x* is spin echo length and *y* is polarization (P/P0).
184    """
185    def __init__(self, **kw):
186        Data1D.__init__(self, **kw)
187        self.lam = None # type: Optional[np.ndarray]
188
189class Data2D(object):
190    """
191    2D data object.
192
193    Note that this definition matches the attributes from sasview. Some
194    refactoring to allow consistent naming conventions between 1D, 2D and
195    SESANS data would be helpful.
196
197    **Attributes**
198
199    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
200
201    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
202
203    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
204
205    *mask*: values to exclude from plotting/analysis
206
207    *qmin*, *qmax*: range of $q$ values in *x*
208
209    *filename*: label for the data line
210
211    *_xaxis*, *_xunit*: label and units for the *x* axis
212
213    *_yaxis*, *_yunit*: label and units for the *y* axis
214
215    *_zaxis*, *_zunit*: label and units for the *y* axis
216
217    *Q_unit*, *I_unit*: units for Q and intensity
218
219    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
220    """
221    def __init__(self,
222                 x=None,   # type: Optional[np.ndarray]
223                 y=None,   # type: Optional[np.ndarray]
224                 z=None,   # type: Optional[np.ndarray]
225                 dx=None,  # type: Optional[np.ndarray]
226                 dy=None,  # type: Optional[np.ndarray]
227                 dz=None   # type: Optional[np.ndarray]
228                ):
229        # type: (...) -> None
230        self.qx_data, self.dqx_data = x, dx
231        self.qy_data, self.dqy_data = y, dy
232        self.data, self.err_data = z, dz
233        self.mask = (np.isnan(z) if z is not None
234                     else np.zeros_like(x, dtype='bool') if x is not None
235                     else None)
236        self.q_data = np.sqrt(x**2 + y**2)
237        self.qmin = 1e-16
238        self.qmax = np.inf
239        self.detector = []
240        self.source = Source()
241        self.Q_unit = "1/A"
242        self.I_unit = "1/cm"
243        self.xaxis("Q_x", "1/A")
244        self.yaxis("Q_y", "1/A")
245        self.zaxis("Intensity", "1/cm")
246        self._xaxis, self._xunit = "x", ""
247        self._yaxis, self._yunit = "y", ""
248        self._zaxis, self._zunit = "z", ""
249        self.x_bins, self.y_bins = None, None
250        self.filename = None
251
252    def xaxis(self, label, unit):
253        # type: (str, str) -> None
254        """
255        set the x axis label and unit
256        """
257        self._xaxis = label
258        self._xunit = unit
259
260    def yaxis(self, label, unit):
261        # type: (str, str) -> None
262        """
263        set the y axis label and unit
264        """
265        self._yaxis = label
266        self._yunit = unit
267
268    def zaxis(self, label, unit):
269        # type: (str, str) -> None
270        """
271        set the y axis label and unit
272        """
273        self._zaxis = label
274        self._zunit = unit
275
276
277class Vector(object):
278    """
279    3-space vector of *x*, *y*, *z*
280    """
281    def __init__(self, x=None, y=None, z=None):
282        # type: (float, float, Optional[float]) -> None
283        self.x, self.y, self.z = x, y, z
284
285class Detector(object):
286    """
287    Detector attributes.
288    """
289    def __init__(self, pixel_size=(None, None), distance=None):
290        # type: (Tuple[float, float], float) -> None
291        self.pixel_size = Vector(*pixel_size)
292        self.distance = distance
293
294class Source(object):
295    """
296    Beam attributes.
297    """
298    def __init__(self):
299        # type: () -> None
300        self.wavelength = np.NaN
301        self.wavelength_unit = "A"
302
303
304def empty_data1D(q, resolution=0.0, L=0., dL=0.):
305    # type: (np.ndarray, float) -> Data1D
306    r"""
307    Create empty 1D data using the given *q* as the x value.
308
309    rms *resolution* $\Delta q/q$ defaults to 0%.  If wavelength *L* and rms
310    wavelength divergence *dL* are defined, then *resolution* defines
311    rms $\Delta \theta/\theta$ for the lowest *q*, with $\theta$ derived from
312    $q = 4\pi/\lambda \sin(\theta)$.
313    """
314
315    #Iq = 100 * np.ones_like(q)
316    #dIq = np.sqrt(Iq)
317    Iq, dIq = None, None
318    q = np.asarray(q)
319    if L != 0 and resolution != 0:
320        theta = np.arcsin(q*L/(4*pi))
321        dtheta = theta[0]*resolution
322        ## Solving Gaussian error propagation from
323        ##   Dq^2 = (dq/dL)^2 DL^2 + (dq/dtheta)^2 Dtheta^2
324        ## gives
325        ##   (Dq/q)^2 = (DL/L)**2 + (Dtheta/tan(theta))**2
326        ## Take the square root and multiply by q, giving
327        ##   Dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
328        dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
329    else:
330        dq = resolution * q
331
332    data = Data1D(q, Iq, dx=dq, dy=dIq)
333    data.filename = "fake data"
334    return data
335
336
337def empty_data2D(qx, qy=None, resolution=0.0):
338    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
339    """
340    Create empty 2D data using the given mesh.
341
342    If *qy* is missing, create a square mesh with *qy=qx*.
343
344    *resolution* dq/q defaults to 5%.
345    """
346    if qy is None:
347        qy = qx
348    qx, qy = np.asarray(qx), np.asarray(qy)
349    # 5% dQ/Q resolution
350    Qx, Qy = np.meshgrid(qx, qy)
351    Qx, Qy = Qx.flatten(), Qy.flatten()
352    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
353    dIq = np.sqrt(Iq)
354    if resolution != 0:
355        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
356        # Should have an additional constant which depends on distances and
357        # radii of the aperture, pixel dimensions and wavelength spread
358        # Instead, assume radial dQ/Q is constant, and perpendicular matches
359        # radial (which instead it should be inverse).
360        Q = np.sqrt(Qx**2 + Qy**2)
361        dqx = resolution * Q
362        dqy = resolution * Q
363    else:
364        dqx = dqy = None
365
366    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
367    data.x_bins = qx
368    data.y_bins = qy
369    data.filename = "fake data"
370
371    # pixel_size in mm, distance in m
372    detector = Detector(pixel_size=(5, 5), distance=4)
373    data.detector.append(detector)
374    data.source.wavelength = 5 # angstroms
375    data.source.wavelength_unit = "A"
376    return data
377
378
379def plot_data(data, view='log', limits=None):
380    # type: (Data, str, Optional[Tuple[float, float]]) -> None
381    """
382    Plot data loaded by the sasview loader.
383
384    *data* is a sasview data object, either 1D, 2D or SESANS.
385
386    *view* is log or linear.
387
388    *limits* sets the intensity limits on the plot; if None then the limits
389    are inferred from the data.
390    """
391    # Note: kind of weird using the plot result functions to plot just the
392    # data, but they already handle the masking and graph markup already, so
393    # do not repeat.
394    if hasattr(data, 'isSesans') and data.isSesans:
395        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
396    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
397        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
398    else:
399        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
400
401
402def plot_theory(data,          # type: Data
403                theory,        # type: Optional[np.ndarray]
404                resid=None,    # type: Optional[np.ndarray]
405                view='log',    # type: str
406                use_data=True, # type: bool
407                limits=None,   # type: Optional[np.ndarray]
408                Iq_calc=None   # type: Optional[np.ndarray]
409               ):
410    # type: (...) -> None
411    """
412    Plot theory calculation.
413
414    *data* is needed to define the graph properties such as labels and
415    units, and to define the data mask.
416
417    *theory* is a matrix of the same shape as the data.
418
419    *view* is log or linear
420
421    *use_data* is True if the data should be plotted as well as the theory.
422
423    *limits* sets the intensity limits on the plot; if None then the limits
424    are inferred from the data.
425
426    *Iq_calc* is the raw theory values without resolution smearing
427    """
428    if hasattr(data, 'isSesans') and data.isSesans:
429        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
430    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
431        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
432    else:
433        _plot_result1D(data, theory, resid, view, use_data,
434                       limits=limits, Iq_calc=Iq_calc)
435
436
437def protect(func):
438    # type: (Callable) -> Callable
439    """
440    Decorator to wrap calls in an exception trapper which prints the
441    exception and continues.  Keyboard interrupts are ignored.
442    """
443    def wrapper(*args, **kw):
444        """
445        Trap and print errors from function.
446        """
447        try:
448            return func(*args, **kw)
449        except Exception:
450            traceback.print_exc()
451
452    return wrapper
453
454
455@protect
456def _plot_result1D(data,         # type: Data1D
457                   theory,       # type: Optional[np.ndarray]
458                   resid,        # type: Optional[np.ndarray]
459                   view,         # type: str
460                   use_data,     # type: bool
461                   limits=None,  # type: Optional[Tuple[float, float]]
462                   Iq_calc=None  # type: Optional[np.ndarray]
463                  ):
464    # type: (...) -> None
465    """
466    Plot the data and residuals for 1D data.
467    """
468    import matplotlib.pyplot as plt  # type: ignore
469    from numpy.ma import masked_array, masked  # type: ignore
470
471    if getattr(data, 'radial', False):
472        data.x = data.q_data
473        data.y = data.data
474
475    use_data = use_data and data.y is not None
476    use_theory = theory is not None
477    use_resid = resid is not None
478    use_calc = use_theory and Iq_calc is not None
479    num_plots = (use_data or use_theory) + use_calc + use_resid
480    non_positive_x = (data.x <= 0.0).any()
481
482    scale = data.x**4 if view == 'q4' else 1.0
483    xscale = yscale = 'linear' if view == 'linear' else 'log'
484
485    if use_data or use_theory:
486        if num_plots > 1:
487            plt.subplot(1, num_plots, 1)
488
489        #print(vmin, vmax)
490        all_positive = True
491        some_present = False
492        if use_data:
493            mdata = masked_array(data.y, data.mask.copy())
494            mdata[~np.isfinite(mdata)] = masked
495            if view is 'log':
496                mdata[mdata <= 0] = masked
497            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
498            all_positive = all_positive and (mdata > 0).all()
499            some_present = some_present or (mdata.count() > 0)
500
501
502        if use_theory:
503            # Note: masks merge, so any masked theory points will stay masked,
504            # and the data mask will be added to it.
505            #mtheory = masked_array(theory, data.mask.copy())
506            theory_x = data.x[~data.mask]
507            mtheory = masked_array(theory)
508            mtheory[~np.isfinite(mtheory)] = masked
509            if view is 'log':
510                mtheory[mtheory <= 0] = masked
511            plt.plot(theory_x, scale*mtheory, '-')
512            all_positive = all_positive and (mtheory > 0).all()
513            some_present = some_present or (mtheory.count() > 0)
514
515        if limits is not None:
516            plt.ylim(*limits)
517
518
519        xscale = ('linear' if not some_present or non_positive_x
520                  else view if view is not None
521                  else 'log')
522        yscale = ('linear'
523                  if view == 'q4' or not some_present or not all_positive
524                  else view if view is not None
525                  else 'log')
526        plt.xscale(xscale)
527        plt.xlabel("$q$/A$^{-1}$")
528        plt.yscale(yscale)
529        plt.ylabel('$I(q)$')
530        title = ("data and model" if use_theory and use_data
531                 else "data" if use_data
532                 else "model")
533        plt.title(title)
534
535    if use_calc:
536        # Only have use_calc if have use_theory
537        plt.subplot(1, num_plots, 2)
538        qx, qy, Iqxy = Iq_calc
539        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
540        plt.xlabel("$q_x$/A$^{-1}$")
541        plt.xlabel("$q_y$/A$^{-1}$")
542        plt.xscale('log')
543        plt.yscale('log')
544        #plt.axis('equal')
545
546    if use_resid:
547        theory_x = data.x[~data.mask]
548        mresid = masked_array(resid)
549        mresid[~np.isfinite(mresid)] = masked
550        some_present = (mresid.count() > 0)
551
552        if num_plots > 1:
553            plt.subplot(1, num_plots, use_calc + 2)
554        plt.plot(theory_x, mresid, '.')
555        plt.xlabel("$q$/A$^{-1}$")
556        plt.ylabel('residuals')
557        plt.title('(model - Iq)/dIq')
558        plt.xscale(xscale)
559        plt.yscale('linear')
560
561
562@protect
563def _plot_result_sesans(data,        # type: SesansData
564                        theory,      # type: Optional[np.ndarray]
565                        resid,       # type: Optional[np.ndarray]
566                        use_data,    # type: bool
567                        limits=None  # type: Optional[Tuple[float, float]]
568                       ):
569    # type: (...) -> None
570    """
571    Plot SESANS results.
572    """
573    import matplotlib.pyplot as plt  # type: ignore
574    use_data = use_data and data.y is not None
575    use_theory = theory is not None
576    use_resid = resid is not None
577    num_plots = (use_data or use_theory) + use_resid
578
579    if use_data or use_theory:
580        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
581        if num_plots > 1:
582            plt.subplot(1, num_plots, 1)
583        if use_data:
584            if is_tof:
585                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
586                             yerr=data.dy/data.y/(data.lam*data.lam))
587            else:
588                plt.errorbar(data.x, data.y, yerr=data.dy)
589        if theory is not None:
590            if is_tof:
591                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
592            else:
593                plt.plot(data.x, theory, '-')
594        if limits is not None:
595            plt.ylim(*limits)
596
597        plt.xlabel('spin echo length ({})'.format(data._xunit))
598        if is_tof:
599            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
600        else:
601            plt.ylabel('polarization (P/P0)')
602
603
604    if resid is not None:
605        if num_plots > 1:
606            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
607        plt.plot(data.x, resid, 'x')
608        plt.xlabel('spin echo length ({})'.format(data._xunit))
609        plt.ylabel('residuals (P/P0)')
610
611
612@protect
613def _plot_result2D(data,         # type: Data2D
614                   theory,       # type: Optional[np.ndarray]
615                   resid,        # type: Optional[np.ndarray]
616                   view,         # type: str
617                   use_data,     # type: bool
618                   limits=None   # type: Optional[Tuple[float, float]]
619                  ):
620    # type: (...) -> None
621    """
622    Plot the data and residuals for 2D data.
623    """
624    import matplotlib.pyplot as plt  # type: ignore
625    use_data = use_data and data.data is not None
626    use_theory = theory is not None
627    use_resid = resid is not None
628    num_plots = use_data + use_theory + use_resid
629
630    # Put theory and data on a common colormap scale
631    vmin, vmax = np.inf, -np.inf
632    target = None # type: Optional[np.ndarray]
633    if use_data:
634        target = data.data[~data.mask]
635        datamin = target[target > 0].min() if view == 'log' else target.min()
636        datamax = target.max()
637        vmin = min(vmin, datamin)
638        vmax = max(vmax, datamax)
639    if use_theory:
640        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
641        theorymax = theory.max()
642        vmin = min(vmin, theorymin)
643        vmax = max(vmax, theorymax)
644
645    # Override data limits from the caller
646    if limits is not None:
647        vmin, vmax = limits
648
649    # Plot data
650    if use_data:
651        if num_plots > 1:
652            plt.subplot(1, num_plots, 1)
653        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
654        plt.title('data')
655        h = plt.colorbar()
656        h.set_label('$I(q)$')
657
658    # plot theory
659    if use_theory:
660        if num_plots > 1:
661            plt.subplot(1, num_plots, use_data+1)
662        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
663        plt.title('theory')
664        h = plt.colorbar()
665        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
666                    else r'$q^4 I(q)$' if view == 'q4'
667                    else '$I(q)$')
668
669    # plot resid
670    if use_resid:
671        if num_plots > 1:
672            plt.subplot(1, num_plots, use_data+use_theory+1)
673        _plot_2d_signal(data, resid, view='linear')
674        plt.title('residuals')
675        h = plt.colorbar()
676        h.set_label(r'$\Delta I(q)$')
677
678
679@protect
680def _plot_2d_signal(data,       # type: Data2D
681                    signal,     # type: np.ndarray
682                    vmin=None,  # type: Optional[float]
683                    vmax=None,  # type: Optional[float]
684                    view='log'  # type: str
685                   ):
686    # type: (...) -> Tuple[float, float]
687    """
688    Plot the target value for the data.  This could be the data itself,
689    the theory calculation, or the residuals.
690
691    *scale* can be 'log' for log scale data, or 'linear'.
692    """
693    import matplotlib.pyplot as plt  # type: ignore
694    from numpy.ma import masked_array  # type: ignore
695
696    image = np.zeros_like(data.qx_data)
697    image[~data.mask] = signal
698    valid = np.isfinite(image)
699    if view == 'log':
700        valid[valid] = (image[valid] > 0)
701        if vmin is None:
702            vmin = image[valid & ~data.mask].min()
703        if vmax is None:
704            vmax = image[valid & ~data.mask].max()
705        image[valid] = np.log10(image[valid])
706    elif view == 'q4':
707        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
708        if vmin is None:
709            vmin = image[valid & ~data.mask].min()
710        if vmax is None:
711            vmax = image[valid & ~data.mask].max()
712    else:
713        if vmin is None:
714            vmin = image[valid & ~data.mask].min()
715        if vmax is None:
716            vmax = image[valid & ~data.mask].max()
717
718    image[~valid | data.mask] = 0
719    #plottable = Iq
720    plottable = masked_array(image, ~valid | data.mask)
721    # Divide range by 10 to convert from angstroms to nanometers
722    xmin, xmax = min(data.qx_data), max(data.qx_data)
723    ymin, ymax = min(data.qy_data), max(data.qy_data)
724    if view == 'log':
725        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
726    else:
727        vmin_scaled, vmax_scaled = vmin, vmax
728    #nx, ny = len(data.x_bins), len(data.y_bins)
729    x_bins, y_bins, image = _build_matrix(data, plottable)
730    plt.imshow(image,
731               interpolation='nearest', aspect=1, origin='lower',
732               extent=[xmin, xmax, ymin, ymax],
733               vmin=vmin_scaled, vmax=vmax_scaled)
734    plt.xlabel("$q_x$/A$^{-1}$")
735    plt.ylabel("$q_y$/A$^{-1}$")
736    return vmin, vmax
737
738
739# === The following is modified from sas.sasgui.plottools.PlotPanel
740def _build_matrix(self, plottable):
741    """
742    Build a matrix for 2d plot from a vector
743    Returns a matrix (image) with ~ square binning
744    Requirement: need 1d array formats of
745    self.data, self.qx_data, and self.qy_data
746    where each one corresponds to z, x, or y axis values
747
748    """
749    # No qx or qy given in a vector format
750    if self.qx_data is None or self.qy_data is None \
751            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
752        return self.x_bins, self.y_bins, plottable
753
754    # maximum # of loops to fillup_pixels
755    # otherwise, loop could never stop depending on data
756    max_loop = 1
757    # get the x and y_bin arrays.
758    x_bins, y_bins = _get_bins(self)
759    # set zero to None
760
761    #Note: Can not use scipy.interpolate.Rbf:
762    # 'cause too many data points (>10000)<=JHC.
763    # 1d array to use for weighting the data point averaging
764    #when they fall into a same bin.
765    weights_data = np.ones([self.data.size])
766    # get histogram of ones w/len(data); this will provide
767    #the weights of data on each bins
768    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
769                                             y=self.qx_data,
770                                             bins=[y_bins, x_bins],
771                                             weights=weights_data)
772    # get histogram of data, all points into a bin in a way of summing
773    image, xedges, yedges = np.histogram2d(x=self.qy_data,
774                                           y=self.qx_data,
775                                           bins=[y_bins, x_bins],
776                                           weights=plottable)
777    # Now, normalize the image by weights only for weights>1:
778    # If weight == 1, there is only one data point in the bin so
779    # that no normalization is required.
780    image[weights > 1] = image[weights > 1] / weights[weights > 1]
781    # Set image bins w/o a data point (weight==0) as None (was set to zero
782    # by histogram2d.)
783    image[weights == 0] = None
784
785    # Fill empty bins with 8 nearest neighbors only when at least
786    #one None point exists
787    loop = 0
788
789    # do while loop until all vacant bins are filled up up
790    #to loop = max_loop
791    while (weights == 0).any():
792        if loop >= max_loop:  # this protects never-ending loop
793            break
794        image = _fillup_pixels(image=image, weights=weights)
795        loop += 1
796
797    return x_bins, y_bins, image
798
799def _get_bins(self):
800    """
801    get bins
802    set x_bins and y_bins into self, 1d arrays of the index with
803    ~ square binning
804    Requirement: need 1d array formats of
805    self.qx_data, and self.qy_data
806    where each one corresponds to  x, or y axis values
807    """
808    # find max and min values of qx and qy
809    xmax = self.qx_data.max()
810    xmin = self.qx_data.min()
811    ymax = self.qy_data.max()
812    ymin = self.qy_data.min()
813
814    # calculate the range of qx and qy: this way, it is a little
815    # more independent
816    x_size = xmax - xmin
817    y_size = ymax - ymin
818
819    # estimate the # of pixels on each axes
820    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
821    npix_x = int(np.floor(len(self.qy_data) / npix_y))
822
823    # bin size: x- & y-directions
824    xstep = x_size / (npix_x - 1)
825    ystep = y_size / (npix_y - 1)
826
827    # max and min taking account of the bin sizes
828    xmax = xmax + xstep / 2.0
829    xmin = xmin - xstep / 2.0
830    ymax = ymax + ystep / 2.0
831    ymin = ymin - ystep / 2.0
832
833    # store x and y bin centers in q space
834    x_bins = np.linspace(xmin, xmax, npix_x)
835    y_bins = np.linspace(ymin, ymax, npix_y)
836
837    return x_bins, y_bins
838
839def _fillup_pixels(image=None, weights=None):
840    """
841    Fill z values of the empty cells of 2d image matrix
842    with the average over up-to next nearest neighbor points
843
844    :param image: (2d matrix with some zi = None)
845
846    :return: image (2d array )
847
848    :TODO: Find better way to do for-loop below
849
850    """
851    # No image matrix given
852    if image is None or np.ndim(image) != 2 \
853            or np.isfinite(image).all() \
854            or weights is None:
855        return image
856    # Get bin size in y and x directions
857    len_y = len(image)
858    len_x = len(image[1])
859    temp_image = np.zeros([len_y, len_x])
860    weit = np.zeros([len_y, len_x])
861    # do for-loop for all pixels
862    for n_y in range(len(image)):
863        for n_x in range(len(image[1])):
864            # find only null pixels
865            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
866                continue
867            else:
868                # find 4 nearest neighbors
869                # check where or not it is at the corner
870                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
871                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
872                    weit[n_y][n_x] += 1
873                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
874                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
875                    weit[n_y][n_x] += 1
876                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
877                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
878                    weit[n_y][n_x] += 1
879                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
880                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
881                    weit[n_y][n_x] += 1
882                # go 4 next nearest neighbors when no non-zero
883                # neighbor exists
884                if n_y != 0 and n_x != 0 and \
885                        np.isfinite(image[n_y - 1][n_x - 1]):
886                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
887                    weit[n_y][n_x] += 1
888                if n_y != len_y - 1 and n_x != 0 and \
889                        np.isfinite(image[n_y + 1][n_x - 1]):
890                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
891                    weit[n_y][n_x] += 1
892                if n_y != len_y and n_x != len_x - 1 and \
893                        np.isfinite(image[n_y - 1][n_x + 1]):
894                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
895                    weit[n_y][n_x] += 1
896                if n_y != len_y - 1 and n_x != len_x - 1 and \
897                        np.isfinite(image[n_y + 1][n_x + 1]):
898                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
899                    weit[n_y][n_x] += 1
900
901    # get it normalized
902    ind = (weit > 0)
903    image[ind] = temp_image[ind] / weit[ind]
904
905    return image
906
907
908def demo():
909    # type: () -> None
910    """
911    Load and plot a SAS dataset.
912    """
913    data = load_data('DEC07086.DAT')
914    set_beam_stop(data, 0.004)
915    plot_data(data)
916    import matplotlib.pyplot as plt  # type: ignore
917    plt.show()
918
919
920if __name__ == "__main__":
921    demo()
Note: See TracBrowser for help on using the repository browser.