source: sasmodels/sasmodels/data.py @ c036ddb

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since c036ddb was c036ddb, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

refactor so Iq is not needed if Fq is defined

  • Property mode set to 100644
File size: 30.3 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38from numpy import sqrt, sin, cos, pi
39
40# pylint: disable=unused-import
41try:
42    from typing import Union, Dict, List, Optional
43except ImportError:
44    pass
45else:
46    Data = Union["Data1D", "Data2D", "SesansData"]
47# pylint: enable=unused-import
48
49def load_data(filename, index=0):
50    # type: (str) -> Data
51    """
52    Load data using a sasview loader.
53    """
54    from sas.sascalc.dataloader.loader import Loader  # type: ignore
55    loader = Loader()
56    # Allow for one part in multipart file
57    if '[' in filename:
58        filename, indexstr = filename[:-1].split('[')
59        index = int(indexstr)
60    datasets = loader.load(filename)
61    if not datasets:  # None or []
62        raise IOError("Data %r could not be loaded" % filename)
63    if not isinstance(datasets, list):
64        datasets = [datasets]
65    for data in datasets:
66        if hasattr(data, 'x'):
67            data.qmin, data.qmax = data.x.min(), data.x.max()
68            data.mask = (np.isnan(data.y) if data.y is not None
69                         else np.zeros_like(data.x, dtype='bool'))
70        elif hasattr(data, 'qx_data'):
71            data.mask = ~data.mask
72    return datasets[index] if index != 'all' else datasets
73
74
75def set_beam_stop(data, radius, outer=None):
76    # type: (Data, float, Optional[float]) -> None
77    """
78    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
79    """
80    from sas.sascalc.dataloader.manipulations import Ringcut
81    if hasattr(data, 'qx_data'):
82        data.mask = Ringcut(0, radius)(data)
83        if outer is not None:
84            data.mask += Ringcut(outer, np.inf)(data)
85    else:
86        data.mask = (data.x < radius)
87        if outer is not None:
88            data.mask |= (data.x >= outer)
89
90
91def set_half(data, half):
92    # type: (Data, str) -> None
93    """
94    Select half of the data, either "right" or "left".
95    """
96    from sas.sascalc.dataloader.manipulations import Boxcut
97    if half == 'right':
98        data.mask += \
99            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
100    if half == 'left':
101        data.mask += \
102            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
103
104
105def set_top(data, cutoff):
106    # type: (Data, float) -> None
107    """
108    Chop the top off the data, above *cutoff*.
109    """
110    from sas.sascalc.dataloader.manipulations import Boxcut
111    data.mask += \
112        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
113
114
115class Data1D(object):
116    """
117    1D data object.
118
119    Note that this definition matches the attributes from sasview, with
120    some generic 1D data vectors and some SAS specific definitions.  Some
121    refactoring to allow consistent naming conventions between 1D, 2D and
122    SESANS data would be helpful.
123
124    **Attributes**
125
126    *x*, *dx*: $q$ vector and gaussian resolution
127
128    *y*, *dy*: $I(q)$ vector and measurement uncertainty
129
130    *mask*: values to include in plotting/analysis
131
132    *dxl*: slit widths for slit smeared data, with *dx* ignored
133
134    *qmin*, *qmax*: range of $q$ values in *x*
135
136    *filename*: label for the data line
137
138    *_xaxis*, *_xunit*: label and units for the *x* axis
139
140    *_yaxis*, *_yunit*: label and units for the *y* axis
141    """
142    def __init__(self,
143                 x=None,  # type: Optional[np.ndarray]
144                 y=None,  # type: Optional[np.ndarray]
145                 dx=None, # type: Optional[np.ndarray]
146                 dy=None  # type: Optional[np.ndarray]
147                ):
148        # type: (...) -> None
149        self.x, self.y, self.dx, self.dy = x, y, dx, dy
150        self.dxl = None
151        self.filename = None
152        self.qmin = x.min() if x is not None else np.NaN
153        self.qmax = x.max() if x is not None else np.NaN
154        # TODO: why is 1D mask False and 2D mask True?
155        self.mask = (np.isnan(y) if y is not None
156                     else np.zeros_like(x, 'b') if x is not None
157                     else None)
158        self._xaxis, self._xunit = "x", ""
159        self._yaxis, self._yunit = "y", ""
160
161    def xaxis(self, label, unit):
162        # type: (str, str) -> None
163        """
164        set the x axis label and unit
165        """
166        self._xaxis = label
167        self._xunit = unit
168
169    def yaxis(self, label, unit):
170        # type: (str, str) -> None
171        """
172        set the y axis label and unit
173        """
174        self._yaxis = label
175        self._yunit = unit
176
177class SesansData(Data1D):
178    """
179    SESANS data object.
180
181    This is just :class:`Data1D` with a wavelength parameter.
182
183    *x* is spin echo length and *y* is polarization (P/P0).
184    """
185    def __init__(self, **kw):
186        Data1D.__init__(self, **kw)
187        self.lam = None # type: Optional[np.ndarray]
188
189class Data2D(object):
190    """
191    2D data object.
192
193    Note that this definition matches the attributes from sasview. Some
194    refactoring to allow consistent naming conventions between 1D, 2D and
195    SESANS data would be helpful.
196
197    **Attributes**
198
199    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
200
201    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
202
203    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
204
205    *mask*: values to exclude from plotting/analysis
206
207    *qmin*, *qmax*: range of $q$ values in *x*
208
209    *filename*: label for the data line
210
211    *_xaxis*, *_xunit*: label and units for the *x* axis
212
213    *_yaxis*, *_yunit*: label and units for the *y* axis
214
215    *_zaxis*, *_zunit*: label and units for the *y* axis
216
217    *Q_unit*, *I_unit*: units for Q and intensity
218
219    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
220    """
221    def __init__(self,
222                 x=None,   # type: Optional[np.ndarray]
223                 y=None,   # type: Optional[np.ndarray]
224                 z=None,   # type: Optional[np.ndarray]
225                 dx=None,  # type: Optional[np.ndarray]
226                 dy=None,  # type: Optional[np.ndarray]
227                 dz=None   # type: Optional[np.ndarray]
228                ):
229        # type: (...) -> None
230        self.qx_data, self.dqx_data = x, dx
231        self.qy_data, self.dqy_data = y, dy
232        self.data, self.err_data = z, dz
233        self.mask = (np.isnan(z) if z is not None
234                     else np.zeros_like(x, dtype='bool') if x is not None
235                     else None)
236        self.q_data = np.sqrt(x**2 + y**2)
237        self.qmin = 1e-16
238        self.qmax = np.inf
239        self.detector = []
240        self.source = Source()
241        self.Q_unit = "1/A"
242        self.I_unit = "1/cm"
243        self.xaxis("Q_x", "1/A")
244        self.yaxis("Q_y", "1/A")
245        self.zaxis("Intensity", "1/cm")
246        self._xaxis, self._xunit = "x", ""
247        self._yaxis, self._yunit = "y", ""
248        self._zaxis, self._zunit = "z", ""
249        self.x_bins, self.y_bins = None, None
250        self.filename = None
251
252    def xaxis(self, label, unit):
253        # type: (str, str) -> None
254        """
255        set the x axis label and unit
256        """
257        self._xaxis = label
258        self._xunit = unit
259
260    def yaxis(self, label, unit):
261        # type: (str, str) -> None
262        """
263        set the y axis label and unit
264        """
265        self._yaxis = label
266        self._yunit = unit
267
268    def zaxis(self, label, unit):
269        # type: (str, str) -> None
270        """
271        set the y axis label and unit
272        """
273        self._zaxis = label
274        self._zunit = unit
275
276
277class Vector(object):
278    """
279    3-space vector of *x*, *y*, *z*
280    """
281    def __init__(self, x=None, y=None, z=None):
282        # type: (float, float, Optional[float]) -> None
283        self.x, self.y, self.z = x, y, z
284
285class Detector(object):
286    """
287    Detector attributes.
288    """
289    def __init__(self, pixel_size=(None, None), distance=None):
290        # type: (Tuple[float, float], float) -> None
291        self.pixel_size = Vector(*pixel_size)
292        self.distance = distance
293
294class Source(object):
295    """
296    Beam attributes.
297    """
298    def __init__(self):
299        # type: () -> None
300        self.wavelength = np.NaN
301        self.wavelength_unit = "A"
302
303
304def empty_data1D(q, resolution=0.0, L=0., dL=0.):
305    # type: (np.ndarray, float) -> Data1D
306    r"""
307    Create empty 1D data using the given *q* as the x value.
308
309    rms *resolution* $\Delta q/q$ defaults to 0%.  If wavelength *L* and rms
310    wavelength divergence *dL* are defined, then *resolution* defines
311    rms $\Delta \theta/\theta$ for the lowest *q*, with $\theta$ derived from
312    $q = 4\pi/\lambda \sin(\theta)$.
313    """
314
315    #Iq = 100 * np.ones_like(q)
316    #dIq = np.sqrt(Iq)
317    Iq, dIq = None, None
318    q = np.asarray(q)
319    if L != 0 and resolution != 0:
320        theta = np.arcsin(q*L/(4*pi))
321        dtheta = theta[0]*resolution
322        ## Solving Gaussian error propagation from
323        ##   Dq^2 = (dq/dL)^2 DL^2 + (dq/dtheta)^2 Dtheta^2
324        ## gives
325        ##   (Dq/q)^2 = (DL/L)**2 + (Dtheta/tan(theta))**2
326        ## Take the square root and multiply by q, giving
327        ##   Dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
328        dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
329    else:
330        dq = resolution * q
331    data = Data1D(q, Iq, dx=dq, dy=dIq)
332    data.filename = "fake data"
333    return data
334
335
336def empty_data2D(qx, qy=None, resolution=0.0):
337    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
338    """
339    Create empty 2D data using the given mesh.
340
341    If *qy* is missing, create a square mesh with *qy=qx*.
342
343    *resolution* dq/q defaults to 5%.
344    """
345    if qy is None:
346        qy = qx
347    qx, qy = np.asarray(qx), np.asarray(qy)
348    # 5% dQ/Q resolution
349    Qx, Qy = np.meshgrid(qx, qy)
350    Qx, Qy = Qx.flatten(), Qy.flatten()
351    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
352    dIq = np.sqrt(Iq)
353    if resolution != 0:
354        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
355        # Should have an additional constant which depends on distances and
356        # radii of the aperture, pixel dimensions and wavelength spread
357        # Instead, assume radial dQ/Q is constant, and perpendicular matches
358        # radial (which instead it should be inverse).
359        Q = np.sqrt(Qx**2 + Qy**2)
360        dqx = resolution * Q
361        dqy = resolution * Q
362    else:
363        dqx = dqy = None
364
365    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
366    data.x_bins = qx
367    data.y_bins = qy
368    data.filename = "fake data"
369
370    # pixel_size in mm, distance in m
371    detector = Detector(pixel_size=(5, 5), distance=4)
372    data.detector.append(detector)
373    data.source.wavelength = 5 # angstroms
374    data.source.wavelength_unit = "A"
375    return data
376
377
378def plot_data(data, view='log', limits=None):
379    # type: (Data, str, Optional[Tuple[float, float]]) -> None
380    """
381    Plot data loaded by the sasview loader.
382
383    *data* is a sasview data object, either 1D, 2D or SESANS.
384
385    *view* is log or linear.
386
387    *limits* sets the intensity limits on the plot; if None then the limits
388    are inferred from the data.
389    """
390    # Note: kind of weird using the plot result functions to plot just the
391    # data, but they already handle the masking and graph markup already, so
392    # do not repeat.
393    if hasattr(data, 'isSesans') and data.isSesans:
394        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
395    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
396        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
397    else:
398        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
399
400
401def plot_theory(data,          # type: Data
402                theory,        # type: Optional[np.ndarray]
403                resid=None,    # type: Optional[np.ndarray]
404                view='log',    # type: str
405                use_data=True, # type: bool
406                limits=None,   # type: Optional[np.ndarray]
407                Iq_calc=None   # type: Optional[np.ndarray]
408               ):
409    # type: (...) -> None
410    """
411    Plot theory calculation.
412
413    *data* is needed to define the graph properties such as labels and
414    units, and to define the data mask.
415
416    *theory* is a matrix of the same shape as the data.
417
418    *view* is log or linear
419
420    *use_data* is True if the data should be plotted as well as the theory.
421
422    *limits* sets the intensity limits on the plot; if None then the limits
423    are inferred from the data.
424
425    *Iq_calc* is the raw theory values without resolution smearing
426    """
427    if hasattr(data, 'isSesans') and data.isSesans:
428        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
429    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
430        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
431    else:
432        _plot_result1D(data, theory, resid, view, use_data,
433                       limits=limits, Iq_calc=Iq_calc)
434
435
436def protect(func):
437    # type: (Callable) -> Callable
438    """
439    Decorator to wrap calls in an exception trapper which prints the
440    exception and continues.  Keyboard interrupts are ignored.
441    """
442    def wrapper(*args, **kw):
443        """
444        Trap and print errors from function.
445        """
446        try:
447            return func(*args, **kw)
448        except Exception:
449            traceback.print_exc()
450
451    return wrapper
452
453
454@protect
455def _plot_result1D(data,         # type: Data1D
456                   theory,       # type: Optional[np.ndarray]
457                   resid,        # type: Optional[np.ndarray]
458                   view,         # type: str
459                   use_data,     # type: bool
460                   limits=None,  # type: Optional[Tuple[float, float]]
461                   Iq_calc=None  # type: Optional[np.ndarray]
462                  ):
463    # type: (...) -> None
464    """
465    Plot the data and residuals for 1D data.
466    """
467    import matplotlib.pyplot as plt  # type: ignore
468    from numpy.ma import masked_array, masked  # type: ignore
469
470    if getattr(data, 'radial', False):
471        data.x = data.q_data
472        data.y = data.data
473
474    use_data = use_data and data.y is not None
475    use_theory = theory is not None
476    use_resid = resid is not None
477    use_calc = use_theory and Iq_calc is not None
478    num_plots = (use_data or use_theory) + use_calc + use_resid
479    non_positive_x = (data.x <= 0.0).any()
480
481    scale = data.x**4 if view == 'q4' else 1.0
482    xscale = yscale = 'linear' if view == 'linear' else 'log'
483
484    if use_data or use_theory:
485        if num_plots > 1:
486            plt.subplot(1, num_plots, 1)
487
488        #print(vmin, vmax)
489        all_positive = True
490        some_present = False
491        if use_data:
492            mdata = masked_array(data.y, data.mask.copy())
493            mdata[~np.isfinite(mdata)] = masked
494            if view is 'log':
495                mdata[mdata <= 0] = masked
496            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
497            all_positive = all_positive and (mdata > 0).all()
498            some_present = some_present or (mdata.count() > 0)
499
500
501        if use_theory:
502            # Note: masks merge, so any masked theory points will stay masked,
503            # and the data mask will be added to it.
504            mtheory = masked_array(theory, data.mask.copy())
505            mtheory[~np.isfinite(mtheory)] = masked
506            if view is 'log':
507                mtheory[mtheory <= 0] = masked
508            plt.plot(data.x, scale*mtheory, '-')
509            all_positive = all_positive and (mtheory > 0).all()
510            some_present = some_present or (mtheory.count() > 0)
511
512        if limits is not None:
513            plt.ylim(*limits)
514
515
516        xscale = ('linear' if not some_present or non_positive_x
517                  else view if view is not None
518                  else 'log')
519        yscale = ('linear'
520                  if view == 'q4' or not some_present or not all_positive
521                  else view if view is not None
522                  else 'log')
523        plt.xscale(xscale)
524        plt.xlabel("$q$/A$^{-1}$")
525        plt.yscale(yscale)
526        plt.ylabel('$I(q)$')
527        title = ("data and model" if use_theory and use_data
528                 else "data" if use_data
529                 else "model")
530        plt.title(title)
531
532    if use_calc:
533        # Only have use_calc if have use_theory
534        plt.subplot(1, num_plots, 2)
535        qx, qy, Iqxy = Iq_calc
536        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
537        plt.xlabel("$q_x$/A$^{-1}$")
538        plt.xlabel("$q_y$/A$^{-1}$")
539        plt.xscale('log')
540        plt.yscale('log')
541        #plt.axis('equal')
542
543    if use_resid:
544        mresid = masked_array(resid, data.mask.copy())
545        mresid[~np.isfinite(mresid)] = masked
546        some_present = (mresid.count() > 0)
547
548        if num_plots > 1:
549            plt.subplot(1, num_plots, use_calc + 2)
550        plt.plot(data.x, mresid, '.')
551        plt.xlabel("$q$/A$^{-1}$")
552        plt.ylabel('residuals')
553        plt.title('(model - Iq)/dIq')
554        plt.xscale(xscale)
555        plt.yscale('linear')
556
557
558@protect
559def _plot_result_sesans(data,        # type: SesansData
560                        theory,      # type: Optional[np.ndarray]
561                        resid,       # type: Optional[np.ndarray]
562                        use_data,    # type: bool
563                        limits=None  # type: Optional[Tuple[float, float]]
564                       ):
565    # type: (...) -> None
566    """
567    Plot SESANS results.
568    """
569    import matplotlib.pyplot as plt  # type: ignore
570    use_data = use_data and data.y is not None
571    use_theory = theory is not None
572    use_resid = resid is not None
573    num_plots = (use_data or use_theory) + use_resid
574
575    if use_data or use_theory:
576        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
577        if num_plots > 1:
578            plt.subplot(1, num_plots, 1)
579        if use_data:
580            if is_tof:
581                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
582                             yerr=data.dy/data.y/(data.lam*data.lam))
583            else:
584                plt.errorbar(data.x, data.y, yerr=data.dy)
585        if theory is not None:
586            if is_tof:
587                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
588            else:
589                plt.plot(data.x, theory, '-')
590        if limits is not None:
591            plt.ylim(*limits)
592
593        plt.xlabel('spin echo length ({})'.format(data._xunit))
594        if is_tof:
595            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
596        else:
597            plt.ylabel('polarization (P/P0)')
598
599
600    if resid is not None:
601        if num_plots > 1:
602            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
603        plt.plot(data.x, resid, 'x')
604        plt.xlabel('spin echo length ({})'.format(data._xunit))
605        plt.ylabel('residuals (P/P0)')
606
607
608@protect
609def _plot_result2D(data,         # type: Data2D
610                   theory,       # type: Optional[np.ndarray]
611                   resid,        # type: Optional[np.ndarray]
612                   view,         # type: str
613                   use_data,     # type: bool
614                   limits=None   # type: Optional[Tuple[float, float]]
615                  ):
616    # type: (...) -> None
617    """
618    Plot the data and residuals for 2D data.
619    """
620    import matplotlib.pyplot as plt  # type: ignore
621    use_data = use_data and data.data is not None
622    use_theory = theory is not None
623    use_resid = resid is not None
624    num_plots = use_data + use_theory + use_resid
625
626    # Put theory and data on a common colormap scale
627    vmin, vmax = np.inf, -np.inf
628    target = None # type: Optional[np.ndarray]
629    if use_data:
630        target = data.data[~data.mask]
631        datamin = target[target > 0].min() if view == 'log' else target.min()
632        datamax = target.max()
633        vmin = min(vmin, datamin)
634        vmax = max(vmax, datamax)
635    if use_theory:
636        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
637        theorymax = theory.max()
638        vmin = min(vmin, theorymin)
639        vmax = max(vmax, theorymax)
640
641    # Override data limits from the caller
642    if limits is not None:
643        vmin, vmax = limits
644
645    # Plot data
646    if use_data:
647        if num_plots > 1:
648            plt.subplot(1, num_plots, 1)
649        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
650        plt.title('data')
651        h = plt.colorbar()
652        h.set_label('$I(q)$')
653
654    # plot theory
655    if use_theory:
656        if num_plots > 1:
657            plt.subplot(1, num_plots, use_data+1)
658        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
659        plt.title('theory')
660        h = plt.colorbar()
661        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
662                    else r'$q^4 I(q)$' if view == 'q4'
663                    else '$I(q)$')
664
665    # plot resid
666    if use_resid:
667        if num_plots > 1:
668            plt.subplot(1, num_plots, use_data+use_theory+1)
669        _plot_2d_signal(data, resid, view='linear')
670        plt.title('residuals')
671        h = plt.colorbar()
672        h.set_label(r'$\Delta I(q)$')
673
674
675@protect
676def _plot_2d_signal(data,       # type: Data2D
677                    signal,     # type: np.ndarray
678                    vmin=None,  # type: Optional[float]
679                    vmax=None,  # type: Optional[float]
680                    view='log'  # type: str
681                   ):
682    # type: (...) -> Tuple[float, float]
683    """
684    Plot the target value for the data.  This could be the data itself,
685    the theory calculation, or the residuals.
686
687    *scale* can be 'log' for log scale data, or 'linear'.
688    """
689    import matplotlib.pyplot as plt  # type: ignore
690    from numpy.ma import masked_array  # type: ignore
691
692    image = np.zeros_like(data.qx_data)
693    image[~data.mask] = signal
694    valid = np.isfinite(image)
695    if view == 'log':
696        valid[valid] = (image[valid] > 0)
697        if vmin is None:
698            vmin = image[valid & ~data.mask].min()
699        if vmax is None:
700            vmax = image[valid & ~data.mask].max()
701        image[valid] = np.log10(image[valid])
702    elif view == 'q4':
703        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
704        if vmin is None:
705            vmin = image[valid & ~data.mask].min()
706        if vmax is None:
707            vmax = image[valid & ~data.mask].max()
708    else:
709        if vmin is None:
710            vmin = image[valid & ~data.mask].min()
711        if vmax is None:
712            vmax = image[valid & ~data.mask].max()
713
714    image[~valid | data.mask] = 0
715    #plottable = Iq
716    plottable = masked_array(image, ~valid | data.mask)
717    # Divide range by 10 to convert from angstroms to nanometers
718    xmin, xmax = min(data.qx_data), max(data.qx_data)
719    ymin, ymax = min(data.qy_data), max(data.qy_data)
720    if view == 'log':
721        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
722    else:
723        vmin_scaled, vmax_scaled = vmin, vmax
724    #nx, ny = len(data.x_bins), len(data.y_bins)
725    x_bins, y_bins, image = _build_matrix(data, plottable)
726    plt.imshow(image,
727               interpolation='nearest', aspect=1, origin='lower',
728               extent=[xmin, xmax, ymin, ymax],
729               vmin=vmin_scaled, vmax=vmax_scaled)
730    plt.xlabel("$q_x$/A$^{-1}$")
731    plt.ylabel("$q_y$/A$^{-1}$")
732    return vmin, vmax
733
734
735# === The following is modified from sas.sasgui.plottools.PlotPanel
736def _build_matrix(self, plottable):
737    """
738    Build a matrix for 2d plot from a vector
739    Returns a matrix (image) with ~ square binning
740    Requirement: need 1d array formats of
741    self.data, self.qx_data, and self.qy_data
742    where each one corresponds to z, x, or y axis values
743
744    """
745    # No qx or qy given in a vector format
746    if self.qx_data is None or self.qy_data is None \
747            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
748        return self.x_bins, self.y_bins, plottable
749
750    # maximum # of loops to fillup_pixels
751    # otherwise, loop could never stop depending on data
752    max_loop = 1
753    # get the x and y_bin arrays.
754    x_bins, y_bins = _get_bins(self)
755    # set zero to None
756
757    #Note: Can not use scipy.interpolate.Rbf:
758    # 'cause too many data points (>10000)<=JHC.
759    # 1d array to use for weighting the data point averaging
760    #when they fall into a same bin.
761    weights_data = np.ones([self.data.size])
762    # get histogram of ones w/len(data); this will provide
763    #the weights of data on each bins
764    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
765                                             y=self.qx_data,
766                                             bins=[y_bins, x_bins],
767                                             weights=weights_data)
768    # get histogram of data, all points into a bin in a way of summing
769    image, xedges, yedges = np.histogram2d(x=self.qy_data,
770                                           y=self.qx_data,
771                                           bins=[y_bins, x_bins],
772                                           weights=plottable)
773    # Now, normalize the image by weights only for weights>1:
774    # If weight == 1, there is only one data point in the bin so
775    # that no normalization is required.
776    image[weights > 1] = image[weights > 1] / weights[weights > 1]
777    # Set image bins w/o a data point (weight==0) as None (was set to zero
778    # by histogram2d.)
779    image[weights == 0] = None
780
781    # Fill empty bins with 8 nearest neighbors only when at least
782    #one None point exists
783    loop = 0
784
785    # do while loop until all vacant bins are filled up up
786    #to loop = max_loop
787    while (weights == 0).any():
788        if loop >= max_loop:  # this protects never-ending loop
789            break
790        image = _fillup_pixels(image=image, weights=weights)
791        loop += 1
792
793    return x_bins, y_bins, image
794
795def _get_bins(self):
796    """
797    get bins
798    set x_bins and y_bins into self, 1d arrays of the index with
799    ~ square binning
800    Requirement: need 1d array formats of
801    self.qx_data, and self.qy_data
802    where each one corresponds to  x, or y axis values
803    """
804    # find max and min values of qx and qy
805    xmax = self.qx_data.max()
806    xmin = self.qx_data.min()
807    ymax = self.qy_data.max()
808    ymin = self.qy_data.min()
809
810    # calculate the range of qx and qy: this way, it is a little
811    # more independent
812    x_size = xmax - xmin
813    y_size = ymax - ymin
814
815    # estimate the # of pixels on each axes
816    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
817    npix_x = int(np.floor(len(self.qy_data) / npix_y))
818
819    # bin size: x- & y-directions
820    xstep = x_size / (npix_x - 1)
821    ystep = y_size / (npix_y - 1)
822
823    # max and min taking account of the bin sizes
824    xmax = xmax + xstep / 2.0
825    xmin = xmin - xstep / 2.0
826    ymax = ymax + ystep / 2.0
827    ymin = ymin - ystep / 2.0
828
829    # store x and y bin centers in q space
830    x_bins = np.linspace(xmin, xmax, npix_x)
831    y_bins = np.linspace(ymin, ymax, npix_y)
832
833    return x_bins, y_bins
834
835def _fillup_pixels(image=None, weights=None):
836    """
837    Fill z values of the empty cells of 2d image matrix
838    with the average over up-to next nearest neighbor points
839
840    :param image: (2d matrix with some zi = None)
841
842    :return: image (2d array )
843
844    :TODO: Find better way to do for-loop below
845
846    """
847    # No image matrix given
848    if image is None or np.ndim(image) != 2 \
849            or np.isfinite(image).all() \
850            or weights is None:
851        return image
852    # Get bin size in y and x directions
853    len_y = len(image)
854    len_x = len(image[1])
855    temp_image = np.zeros([len_y, len_x])
856    weit = np.zeros([len_y, len_x])
857    # do for-loop for all pixels
858    for n_y in range(len(image)):
859        for n_x in range(len(image[1])):
860            # find only null pixels
861            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
862                continue
863            else:
864                # find 4 nearest neighbors
865                # check where or not it is at the corner
866                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
867                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
868                    weit[n_y][n_x] += 1
869                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
870                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
871                    weit[n_y][n_x] += 1
872                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
873                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
874                    weit[n_y][n_x] += 1
875                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
876                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
877                    weit[n_y][n_x] += 1
878                # go 4 next nearest neighbors when no non-zero
879                # neighbor exists
880                if n_y != 0 and n_x != 0 and \
881                        np.isfinite(image[n_y - 1][n_x - 1]):
882                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
883                    weit[n_y][n_x] += 1
884                if n_y != len_y - 1 and n_x != 0 and \
885                        np.isfinite(image[n_y + 1][n_x - 1]):
886                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
887                    weit[n_y][n_x] += 1
888                if n_y != len_y and n_x != len_x - 1 and \
889                        np.isfinite(image[n_y - 1][n_x + 1]):
890                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
891                    weit[n_y][n_x] += 1
892                if n_y != len_y - 1 and n_x != len_x - 1 and \
893                        np.isfinite(image[n_y + 1][n_x + 1]):
894                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
895                    weit[n_y][n_x] += 1
896
897    # get it normalized
898    ind = (weit > 0)
899    image[ind] = temp_image[ind] / weit[ind]
900
901    return image
902
903
904def demo():
905    # type: () -> None
906    """
907    Load and plot a SAS dataset.
908    """
909    data = load_data('DEC07086.DAT')
910    set_beam_stop(data, 0.004)
911    plot_data(data)
912    import matplotlib.pyplot as plt  # type: ignore
913    plt.show()
914
915
916if __name__ == "__main__":
917    demo()
Note: See TracBrowser for help on using the repository browser.