source: sasmodels/sasmodels/data.py @ 71b751d

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 71b751d was 7e923c2, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

Merge branch 'master' into beta_approx

  • Property mode set to 100644
File size: 30.4 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38from numpy import sqrt, sin, cos, pi
39
40# pylint: disable=unused-import
41try:
42    from typing import Union, Dict, List, Optional
43except ImportError:
44    pass
45else:
46    Data = Union["Data1D", "Data2D", "SesansData"]
47# pylint: enable=unused-import
48
49def load_data(filename, index=0):
50    # type: (str) -> Data
51    """
52    Load data using a sasview loader.
53    """
54    from sas.sascalc.dataloader.loader import Loader  # type: ignore
55    loader = Loader()
56    # Allow for one part in multipart file
57    if '[' in filename:
58        filename, indexstr = filename[:-1].split('[')
59        index = int(indexstr)
60    datasets = loader.load(filename)
61    if not datasets:  # None or []
62        raise IOError("Data %r could not be loaded" % filename)
63    if not isinstance(datasets, list):
64        datasets = [datasets]
65    for data in datasets:
66        if hasattr(data, 'x'):
67            data.qmin, data.qmax = data.x.min(), data.x.max()
68            data.mask = (np.isnan(data.y) if data.y is not None
69                         else np.zeros_like(data.x, dtype='bool'))
70        elif hasattr(data, 'qx_data'):
71            data.mask = ~data.mask
72    return datasets[index] if index != 'all' else datasets
73
74
75def set_beam_stop(data, radius, outer=None):
76    # type: (Data, float, Optional[float]) -> None
77    """
78    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
79    """
80    from sas.sascalc.dataloader.manipulations import Ringcut
81    if hasattr(data, 'qx_data'):
82        data.mask = Ringcut(0, radius)(data)
83        if outer is not None:
84            data.mask += Ringcut(outer, np.inf)(data)
85    else:
86        data.mask = (data.x < radius)
87        if outer is not None:
88            data.mask |= (data.x >= outer)
89
90
91def set_half(data, half):
92    # type: (Data, str) -> None
93    """
94    Select half of the data, either "right" or "left".
95    """
96    from sas.sascalc.dataloader.manipulations import Boxcut
97    if half == 'right':
98        data.mask += \
99            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
100    if half == 'left':
101        data.mask += \
102            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
103
104
105def set_top(data, cutoff):
106    # type: (Data, float) -> None
107    """
108    Chop the top off the data, above *cutoff*.
109    """
110    from sas.sascalc.dataloader.manipulations import Boxcut
111    data.mask += \
112        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
113
114
115class Data1D(object):
116    """
117    1D data object.
118
119    Note that this definition matches the attributes from sasview, with
120    some generic 1D data vectors and some SAS specific definitions.  Some
121    refactoring to allow consistent naming conventions between 1D, 2D and
122    SESANS data would be helpful.
123
124    **Attributes**
125
126    *x*, *dx*: $q$ vector and gaussian resolution
127
128    *y*, *dy*: $I(q)$ vector and measurement uncertainty
129
130    *mask*: values to include in plotting/analysis
131
132    *dxl*: slit widths for slit smeared data, with *dx* ignored
133
134    *qmin*, *qmax*: range of $q$ values in *x*
135
136    *filename*: label for the data line
137
138    *_xaxis*, *_xunit*: label and units for the *x* axis
139
140    *_yaxis*, *_yunit*: label and units for the *y* axis
141    """
142    def __init__(self,
143                 x=None,  # type: Optional[np.ndarray]
144                 y=None,  # type: Optional[np.ndarray]
145                 dx=None, # type: Optional[np.ndarray]
146                 dy=None  # type: Optional[np.ndarray]
147                ):
148        # type: (...) -> None
149        self.x, self.y, self.dx, self.dy = x, y, dx, dy
150        self.dxl = None
151        self.filename = None
152        self.qmin = x.min() if x is not None else np.NaN
153        self.qmax = x.max() if x is not None else np.NaN
154        # TODO: why is 1D mask False and 2D mask True?
155        self.mask = (np.isnan(y) if y is not None
156                     else np.zeros_like(x, 'b') if x is not None
157                     else None)
158        self._xaxis, self._xunit = "x", ""
159        self._yaxis, self._yunit = "y", ""
160
161    def xaxis(self, label, unit):
162        # type: (str, str) -> None
163        """
164        set the x axis label and unit
165        """
166        self._xaxis = label
167        self._xunit = unit
168
169    def yaxis(self, label, unit):
170        # type: (str, str) -> None
171        """
172        set the y axis label and unit
173        """
174        self._yaxis = label
175        self._yunit = unit
176
177class SesansData(Data1D):
178    """
179    SESANS data object.
180
181    This is just :class:`Data1D` with a wavelength parameter.
182
183    *x* is spin echo length and *y* is polarization (P/P0).
184    """
185    def __init__(self, **kw):
186        Data1D.__init__(self, **kw)
187        self.lam = None # type: Optional[np.ndarray]
188
189class Data2D(object):
190    """
191    2D data object.
192
193    Note that this definition matches the attributes from sasview. Some
194    refactoring to allow consistent naming conventions between 1D, 2D and
195    SESANS data would be helpful.
196
197    **Attributes**
198
199    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
200
201    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
202
203    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
204
205    *mask*: values to exclude from plotting/analysis
206
207    *qmin*, *qmax*: range of $q$ values in *x*
208
209    *filename*: label for the data line
210
211    *_xaxis*, *_xunit*: label and units for the *x* axis
212
213    *_yaxis*, *_yunit*: label and units for the *y* axis
214
215    *_zaxis*, *_zunit*: label and units for the *y* axis
216
217    *Q_unit*, *I_unit*: units for Q and intensity
218
219    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
220    """
221    def __init__(self,
222                 x=None,   # type: Optional[np.ndarray]
223                 y=None,   # type: Optional[np.ndarray]
224                 z=None,   # type: Optional[np.ndarray]
225                 dx=None,  # type: Optional[np.ndarray]
226                 dy=None,  # type: Optional[np.ndarray]
227                 dz=None   # type: Optional[np.ndarray]
228                ):
229        # type: (...) -> None
230        self.qx_data, self.dqx_data = x, dx
231        self.qy_data, self.dqy_data = y, dy
232        self.data, self.err_data = z, dz
233        self.mask = (np.isnan(z) if z is not None
234                     else np.zeros_like(x, dtype='bool') if x is not None
235                     else None)
236        self.q_data = np.sqrt(x**2 + y**2)
237        self.qmin = 1e-16
238        self.qmax = np.inf
239        self.detector = []
240        self.source = Source()
241        self.Q_unit = "1/A"
242        self.I_unit = "1/cm"
243        self.xaxis("Q_x", "1/A")
244        self.yaxis("Q_y", "1/A")
245        self.zaxis("Intensity", "1/cm")
246        self._xaxis, self._xunit = "x", ""
247        self._yaxis, self._yunit = "y", ""
248        self._zaxis, self._zunit = "z", ""
249        self.x_bins, self.y_bins = None, None
250        self.filename = None
251
252    def xaxis(self, label, unit):
253        # type: (str, str) -> None
254        """
255        set the x axis label and unit
256        """
257        self._xaxis = label
258        self._xunit = unit
259
260    def yaxis(self, label, unit):
261        # type: (str, str) -> None
262        """
263        set the y axis label and unit
264        """
265        self._yaxis = label
266        self._yunit = unit
267
268    def zaxis(self, label, unit):
269        # type: (str, str) -> None
270        """
271        set the y axis label and unit
272        """
273        self._zaxis = label
274        self._zunit = unit
275
276
277class Vector(object):
278    """
279    3-space vector of *x*, *y*, *z*
280    """
281    def __init__(self, x=None, y=None, z=None):
282        # type: (float, float, Optional[float]) -> None
283        self.x, self.y, self.z = x, y, z
284
285class Detector(object):
286    """
287    Detector attributes.
288    """
289    def __init__(self, pixel_size=(None, None), distance=None):
290        # type: (Tuple[float, float], float) -> None
291        self.pixel_size = Vector(*pixel_size)
292        self.distance = distance
293
294class Source(object):
295    """
296    Beam attributes.
297    """
298    def __init__(self):
299        # type: () -> None
300        self.wavelength = np.NaN
301        self.wavelength_unit = "A"
302
303
304def empty_data1D(q, resolution=0.0, L=0., dL=0.):
305    # type: (np.ndarray, float) -> Data1D
306    r"""
307    Create empty 1D data using the given *q* as the x value.
308
309    rms *resolution* $\Delta q/q$ defaults to 0%.  If wavelength *L* and rms
310    wavelength divergence *dL* are defined, then *resolution* defines
311    rms $\Delta \theta/\theta$ for the lowest *q*, with $\theta$ derived from
312    $q = 4\pi/\lambda \sin(\theta)$.
313    """
314
315    #Iq = 100 * np.ones_like(q)
316    #dIq = np.sqrt(Iq)
317    Iq, dIq = None, None
318    q = np.asarray(q)
319    if L != 0 and resolution != 0:
320        theta = np.arcsin(q*L/(4*pi))
321        dtheta = theta[0]*resolution
322        ## Solving Gaussian error propagation from
323        ##   Dq^2 = (dq/dL)^2 DL^2 + (dq/dtheta)^2 Dtheta^2
324        ## gives
325        ##   (Dq/q)^2 = (DL/L)**2 + (Dtheta/tan(theta))**2
326        ## Take the square root and multiply by q, giving
327        ##   Dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
328        dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
329    else:
330        dq = resolution * q
331    data = Data1D(q, Iq, dx=dq, dy=dIq)
332    data.filename = "fake data"
333    return data
334
335
336def empty_data2D(qx, qy=None, resolution=0.0):
337    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
338    """
339    Create empty 2D data using the given mesh.
340
341    If *qy* is missing, create a square mesh with *qy=qx*.
342
343    *resolution* dq/q defaults to 5%.
344    """
345    if qy is None:
346        qy = qx
347    qx, qy = np.asarray(qx), np.asarray(qy)
348    # 5% dQ/Q resolution
349    Qx, Qy = np.meshgrid(qx, qy)
350    Qx, Qy = Qx.flatten(), Qy.flatten()
351    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
352    dIq = np.sqrt(Iq)
353    if resolution != 0:
354        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
355        # Should have an additional constant which depends on distances and
356        # radii of the aperture, pixel dimensions and wavelength spread
357        # Instead, assume radial dQ/Q is constant, and perpendicular matches
358        # radial (which instead it should be inverse).
359        Q = np.sqrt(Qx**2 + Qy**2)
360        dqx = resolution * Q
361        dqy = resolution * Q
362    else:
363        dqx = dqy = None
364
365    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
366    data.x_bins = qx
367    data.y_bins = qy
368    data.filename = "fake data"
369
370    # pixel_size in mm, distance in m
371    detector = Detector(pixel_size=(5, 5), distance=4)
372    data.detector.append(detector)
373    data.source.wavelength = 5 # angstroms
374    data.source.wavelength_unit = "A"
375    return data
376
377
378def plot_data(data, view='log', limits=None):
379    # type: (Data, str, Optional[Tuple[float, float]]) -> None
380    """
381    Plot data loaded by the sasview loader.
382
383    *data* is a sasview data object, either 1D, 2D or SESANS.
384
385    *view* is log or linear.
386
387    *limits* sets the intensity limits on the plot; if None then the limits
388    are inferred from the data.
389    """
390    # Note: kind of weird using the plot result functions to plot just the
391    # data, but they already handle the masking and graph markup already, so
392    # do not repeat.
393    if hasattr(data, 'isSesans') and data.isSesans:
394        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
395    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
396        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
397    else:
398        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
399
400
401def plot_theory(data,          # type: Data
402                theory,        # type: Optional[np.ndarray]
403                resid=None,    # type: Optional[np.ndarray]
404                view='log',    # type: str
405                use_data=True, # type: bool
406                limits=None,   # type: Optional[np.ndarray]
407                Iq_calc=None   # type: Optional[np.ndarray]
408               ):
409    # type: (...) -> None
410    """
411    Plot theory calculation.
412
413    *data* is needed to define the graph properties such as labels and
414    units, and to define the data mask.
415
416    *theory* is a matrix of the same shape as the data.
417
418    *view* is log or linear
419
420    *use_data* is True if the data should be plotted as well as the theory.
421
422    *limits* sets the intensity limits on the plot; if None then the limits
423    are inferred from the data.
424
425    *Iq_calc* is the raw theory values without resolution smearing
426    """
427    if hasattr(data, 'isSesans') and data.isSesans:
428        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
429    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
430        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
431    else:
432        _plot_result1D(data, theory, resid, view, use_data,
433                       limits=limits, Iq_calc=Iq_calc)
434
435
436def protect(func):
437    # type: (Callable) -> Callable
438    """
439    Decorator to wrap calls in an exception trapper which prints the
440    exception and continues.  Keyboard interrupts are ignored.
441    """
442    def wrapper(*args, **kw):
443        """
444        Trap and print errors from function.
445        """
446        try:
447            return func(*args, **kw)
448        except Exception:
449            traceback.print_exc()
450
451    return wrapper
452
453
454@protect
455def _plot_result1D(data,         # type: Data1D
456                   theory,       # type: Optional[np.ndarray]
457                   resid,        # type: Optional[np.ndarray]
458                   view,         # type: str
459                   use_data,     # type: bool
460                   limits=None,  # type: Optional[Tuple[float, float]]
461                   Iq_calc=None  # type: Optional[np.ndarray]
462                  ):
463    # type: (...) -> None
464    """
465    Plot the data and residuals for 1D data.
466    """
467    import matplotlib.pyplot as plt  # type: ignore
468    from numpy.ma import masked_array, masked  # type: ignore
469
470    if getattr(data, 'radial', False):
471        data.x = data.q_data
472        data.y = data.data
473
474    use_data = use_data and data.y is not None
475    use_theory = theory is not None
476    use_resid = resid is not None
477    use_calc = use_theory and Iq_calc is not None
478    num_plots = (use_data or use_theory) + use_calc + use_resid
479    non_positive_x = (data.x <= 0.0).any()
480
481    scale = data.x**4 if view == 'q4' else 1.0
482    xscale = yscale = 'linear' if view == 'linear' else 'log'
483
484    if use_data or use_theory:
485        if num_plots > 1:
486            plt.subplot(1, num_plots, 1)
487
488        #print(vmin, vmax)
489        all_positive = True
490        some_present = False
491        if use_data:
492            mdata = masked_array(data.y, data.mask.copy())
493            mdata[~np.isfinite(mdata)] = masked
494            if view is 'log':
495                mdata[mdata <= 0] = masked
496            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
497            all_positive = all_positive and (mdata > 0).all()
498            some_present = some_present or (mdata.count() > 0)
499
500
501        if use_theory:
502            # Note: masks merge, so any masked theory points will stay masked,
503            # and the data mask will be added to it.
504            #mtheory = masked_array(theory, data.mask.copy())
505            theory_x = data.x[data.mask == 0]
506            mtheory = masked_array(theory)
507            mtheory[~np.isfinite(mtheory)] = masked
508            if view is 'log':
509                mtheory[mtheory <= 0] = masked
510            plt.plot(theory_x, scale*mtheory, '-')
511            all_positive = all_positive and (mtheory > 0).all()
512            some_present = some_present or (mtheory.count() > 0)
513
514        if limits is not None:
515            plt.ylim(*limits)
516
517
518        xscale = ('linear' if not some_present or non_positive_x
519                  else view if view is not None
520                  else 'log')
521        yscale = ('linear'
522                  if view == 'q4' or not some_present or not all_positive
523                  else view if view is not None
524                  else 'log')
525        plt.xscale(xscale)
526        plt.xlabel("$q$/A$^{-1}$")
527        plt.yscale(yscale)
528        plt.ylabel('$I(q)$')
529        title = ("data and model" if use_theory and use_data
530                 else "data" if use_data
531                 else "model")
532        plt.title(title)
533
534    if use_calc:
535        # Only have use_calc if have use_theory
536        plt.subplot(1, num_plots, 2)
537        qx, qy, Iqxy = Iq_calc
538        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
539        plt.xlabel("$q_x$/A$^{-1}$")
540        plt.xlabel("$q_y$/A$^{-1}$")
541        plt.xscale('log')
542        plt.yscale('log')
543        #plt.axis('equal')
544
545    if use_resid:
546        theory_x = data.x[data.mask == 0]
547        mresid = masked_array(resid)
548        mresid[~np.isfinite(mresid)] = masked
549        some_present = (mresid.count() > 0)
550
551        if num_plots > 1:
552            plt.subplot(1, num_plots, use_calc + 2)
553        plt.plot(theory_x, mresid, '.')
554        plt.xlabel("$q$/A$^{-1}$")
555        plt.ylabel('residuals')
556        plt.title('(model - Iq)/dIq')
557        plt.xscale(xscale)
558        plt.yscale('linear')
559
560
561@protect
562def _plot_result_sesans(data,        # type: SesansData
563                        theory,      # type: Optional[np.ndarray]
564                        resid,       # type: Optional[np.ndarray]
565                        use_data,    # type: bool
566                        limits=None  # type: Optional[Tuple[float, float]]
567                       ):
568    # type: (...) -> None
569    """
570    Plot SESANS results.
571    """
572    import matplotlib.pyplot as plt  # type: ignore
573    use_data = use_data and data.y is not None
574    use_theory = theory is not None
575    use_resid = resid is not None
576    num_plots = (use_data or use_theory) + use_resid
577
578    if use_data or use_theory:
579        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
580        if num_plots > 1:
581            plt.subplot(1, num_plots, 1)
582        if use_data:
583            if is_tof:
584                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
585                             yerr=data.dy/data.y/(data.lam*data.lam))
586            else:
587                plt.errorbar(data.x, data.y, yerr=data.dy)
588        if theory is not None:
589            if is_tof:
590                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
591            else:
592                plt.plot(data.x, theory, '-')
593        if limits is not None:
594            plt.ylim(*limits)
595
596        plt.xlabel('spin echo length ({})'.format(data._xunit))
597        if is_tof:
598            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
599        else:
600            plt.ylabel('polarization (P/P0)')
601
602
603    if resid is not None:
604        if num_plots > 1:
605            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
606        plt.plot(data.x, resid, 'x')
607        plt.xlabel('spin echo length ({})'.format(data._xunit))
608        plt.ylabel('residuals (P/P0)')
609
610
611@protect
612def _plot_result2D(data,         # type: Data2D
613                   theory,       # type: Optional[np.ndarray]
614                   resid,        # type: Optional[np.ndarray]
615                   view,         # type: str
616                   use_data,     # type: bool
617                   limits=None   # type: Optional[Tuple[float, float]]
618                  ):
619    # type: (...) -> None
620    """
621    Plot the data and residuals for 2D data.
622    """
623    import matplotlib.pyplot as plt  # type: ignore
624    use_data = use_data and data.data is not None
625    use_theory = theory is not None
626    use_resid = resid is not None
627    num_plots = use_data + use_theory + use_resid
628
629    # Put theory and data on a common colormap scale
630    vmin, vmax = np.inf, -np.inf
631    target = None # type: Optional[np.ndarray]
632    if use_data:
633        target = data.data[~data.mask]
634        datamin = target[target > 0].min() if view == 'log' else target.min()
635        datamax = target.max()
636        vmin = min(vmin, datamin)
637        vmax = max(vmax, datamax)
638    if use_theory:
639        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
640        theorymax = theory.max()
641        vmin = min(vmin, theorymin)
642        vmax = max(vmax, theorymax)
643
644    # Override data limits from the caller
645    if limits is not None:
646        vmin, vmax = limits
647
648    # Plot data
649    if use_data:
650        if num_plots > 1:
651            plt.subplot(1, num_plots, 1)
652        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
653        plt.title('data')
654        h = plt.colorbar()
655        h.set_label('$I(q)$')
656
657    # plot theory
658    if use_theory:
659        if num_plots > 1:
660            plt.subplot(1, num_plots, use_data+1)
661        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
662        plt.title('theory')
663        h = plt.colorbar()
664        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
665                    else r'$q^4 I(q)$' if view == 'q4'
666                    else '$I(q)$')
667
668    # plot resid
669    if use_resid:
670        if num_plots > 1:
671            plt.subplot(1, num_plots, use_data+use_theory+1)
672        _plot_2d_signal(data, resid, view='linear')
673        plt.title('residuals')
674        h = plt.colorbar()
675        h.set_label(r'$\Delta I(q)$')
676
677
678@protect
679def _plot_2d_signal(data,       # type: Data2D
680                    signal,     # type: np.ndarray
681                    vmin=None,  # type: Optional[float]
682                    vmax=None,  # type: Optional[float]
683                    view='log'  # type: str
684                   ):
685    # type: (...) -> Tuple[float, float]
686    """
687    Plot the target value for the data.  This could be the data itself,
688    the theory calculation, or the residuals.
689
690    *scale* can be 'log' for log scale data, or 'linear'.
691    """
692    import matplotlib.pyplot as plt  # type: ignore
693    from numpy.ma import masked_array  # type: ignore
694
695    image = np.zeros_like(data.qx_data)
696    image[~data.mask] = signal
697    valid = np.isfinite(image)
698    if view == 'log':
699        valid[valid] = (image[valid] > 0)
700        if vmin is None:
701            vmin = image[valid & ~data.mask].min()
702        if vmax is None:
703            vmax = image[valid & ~data.mask].max()
704        image[valid] = np.log10(image[valid])
705    elif view == 'q4':
706        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
707        if vmin is None:
708            vmin = image[valid & ~data.mask].min()
709        if vmax is None:
710            vmax = image[valid & ~data.mask].max()
711    else:
712        if vmin is None:
713            vmin = image[valid & ~data.mask].min()
714        if vmax is None:
715            vmax = image[valid & ~data.mask].max()
716
717    image[~valid | data.mask] = 0
718    #plottable = Iq
719    plottable = masked_array(image, ~valid | data.mask)
720    # Divide range by 10 to convert from angstroms to nanometers
721    xmin, xmax = min(data.qx_data), max(data.qx_data)
722    ymin, ymax = min(data.qy_data), max(data.qy_data)
723    if view == 'log':
724        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
725    else:
726        vmin_scaled, vmax_scaled = vmin, vmax
727    #nx, ny = len(data.x_bins), len(data.y_bins)
728    x_bins, y_bins, image = _build_matrix(data, plottable)
729    plt.imshow(image,
730               interpolation='nearest', aspect=1, origin='lower',
731               extent=[xmin, xmax, ymin, ymax],
732               vmin=vmin_scaled, vmax=vmax_scaled)
733    plt.xlabel("$q_x$/A$^{-1}$")
734    plt.ylabel("$q_y$/A$^{-1}$")
735    return vmin, vmax
736
737
738# === The following is modified from sas.sasgui.plottools.PlotPanel
739def _build_matrix(self, plottable):
740    """
741    Build a matrix for 2d plot from a vector
742    Returns a matrix (image) with ~ square binning
743    Requirement: need 1d array formats of
744    self.data, self.qx_data, and self.qy_data
745    where each one corresponds to z, x, or y axis values
746
747    """
748    # No qx or qy given in a vector format
749    if self.qx_data is None or self.qy_data is None \
750            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
751        return self.x_bins, self.y_bins, plottable
752
753    # maximum # of loops to fillup_pixels
754    # otherwise, loop could never stop depending on data
755    max_loop = 1
756    # get the x and y_bin arrays.
757    x_bins, y_bins = _get_bins(self)
758    # set zero to None
759
760    #Note: Can not use scipy.interpolate.Rbf:
761    # 'cause too many data points (>10000)<=JHC.
762    # 1d array to use for weighting the data point averaging
763    #when they fall into a same bin.
764    weights_data = np.ones([self.data.size])
765    # get histogram of ones w/len(data); this will provide
766    #the weights of data on each bins
767    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
768                                             y=self.qx_data,
769                                             bins=[y_bins, x_bins],
770                                             weights=weights_data)
771    # get histogram of data, all points into a bin in a way of summing
772    image, xedges, yedges = np.histogram2d(x=self.qy_data,
773                                           y=self.qx_data,
774                                           bins=[y_bins, x_bins],
775                                           weights=plottable)
776    # Now, normalize the image by weights only for weights>1:
777    # If weight == 1, there is only one data point in the bin so
778    # that no normalization is required.
779    image[weights > 1] = image[weights > 1] / weights[weights > 1]
780    # Set image bins w/o a data point (weight==0) as None (was set to zero
781    # by histogram2d.)
782    image[weights == 0] = None
783
784    # Fill empty bins with 8 nearest neighbors only when at least
785    #one None point exists
786    loop = 0
787
788    # do while loop until all vacant bins are filled up up
789    #to loop = max_loop
790    while (weights == 0).any():
791        if loop >= max_loop:  # this protects never-ending loop
792            break
793        image = _fillup_pixels(image=image, weights=weights)
794        loop += 1
795
796    return x_bins, y_bins, image
797
798def _get_bins(self):
799    """
800    get bins
801    set x_bins and y_bins into self, 1d arrays of the index with
802    ~ square binning
803    Requirement: need 1d array formats of
804    self.qx_data, and self.qy_data
805    where each one corresponds to  x, or y axis values
806    """
807    # find max and min values of qx and qy
808    xmax = self.qx_data.max()
809    xmin = self.qx_data.min()
810    ymax = self.qy_data.max()
811    ymin = self.qy_data.min()
812
813    # calculate the range of qx and qy: this way, it is a little
814    # more independent
815    x_size = xmax - xmin
816    y_size = ymax - ymin
817
818    # estimate the # of pixels on each axes
819    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
820    npix_x = int(np.floor(len(self.qy_data) / npix_y))
821
822    # bin size: x- & y-directions
823    xstep = x_size / (npix_x - 1)
824    ystep = y_size / (npix_y - 1)
825
826    # max and min taking account of the bin sizes
827    xmax = xmax + xstep / 2.0
828    xmin = xmin - xstep / 2.0
829    ymax = ymax + ystep / 2.0
830    ymin = ymin - ystep / 2.0
831
832    # store x and y bin centers in q space
833    x_bins = np.linspace(xmin, xmax, npix_x)
834    y_bins = np.linspace(ymin, ymax, npix_y)
835
836    return x_bins, y_bins
837
838def _fillup_pixels(image=None, weights=None):
839    """
840    Fill z values of the empty cells of 2d image matrix
841    with the average over up-to next nearest neighbor points
842
843    :param image: (2d matrix with some zi = None)
844
845    :return: image (2d array )
846
847    :TODO: Find better way to do for-loop below
848
849    """
850    # No image matrix given
851    if image is None or np.ndim(image) != 2 \
852            or np.isfinite(image).all() \
853            or weights is None:
854        return image
855    # Get bin size in y and x directions
856    len_y = len(image)
857    len_x = len(image[1])
858    temp_image = np.zeros([len_y, len_x])
859    weit = np.zeros([len_y, len_x])
860    # do for-loop for all pixels
861    for n_y in range(len(image)):
862        for n_x in range(len(image[1])):
863            # find only null pixels
864            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
865                continue
866            else:
867                # find 4 nearest neighbors
868                # check where or not it is at the corner
869                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
870                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
871                    weit[n_y][n_x] += 1
872                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
873                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
874                    weit[n_y][n_x] += 1
875                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
876                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
877                    weit[n_y][n_x] += 1
878                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
879                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
880                    weit[n_y][n_x] += 1
881                # go 4 next nearest neighbors when no non-zero
882                # neighbor exists
883                if n_y != 0 and n_x != 0 and \
884                        np.isfinite(image[n_y - 1][n_x - 1]):
885                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
886                    weit[n_y][n_x] += 1
887                if n_y != len_y - 1 and n_x != 0 and \
888                        np.isfinite(image[n_y + 1][n_x - 1]):
889                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
890                    weit[n_y][n_x] += 1
891                if n_y != len_y and n_x != len_x - 1 and \
892                        np.isfinite(image[n_y - 1][n_x + 1]):
893                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
894                    weit[n_y][n_x] += 1
895                if n_y != len_y - 1 and n_x != len_x - 1 and \
896                        np.isfinite(image[n_y + 1][n_x + 1]):
897                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
898                    weit[n_y][n_x] += 1
899
900    # get it normalized
901    ind = (weit > 0)
902    image[ind] = temp_image[ind] / weit[ind]
903
904    return image
905
906
907def demo():
908    # type: () -> None
909    """
910    Load and plot a SAS dataset.
911    """
912    data = load_data('DEC07086.DAT')
913    set_beam_stop(data, 0.004)
914    plot_data(data)
915    import matplotlib.pyplot as plt  # type: ignore
916    plt.show()
917
918
919if __name__ == "__main__":
920    demo()
Note: See TracBrowser for help on using the repository browser.