source: sasmodels/sasmodels/data.py @ 01c8d9e

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 01c8d9e was 01c8d9e, checked in by Suczewski <ges3@…>, 6 years ago

beta approximation, first pass

  • Property mode set to 100644
File size: 30.3 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38from numpy import sqrt, sin, cos, pi
39
40# pylint: disable=unused-import
41try:
42    from typing import Union, Dict, List, Optional
43except ImportError:
44    pass
45else:
46    Data = Union["Data1D", "Data2D", "SesansData"]
47# pylint: enable=unused-import
48
49def load_data(filename, index=0):
50    # type: (str) -> Data
51    """
52    Load data using a sasview loader.
53    """
54    from sas.sascalc.dataloader.loader import Loader  # type: ignore
55    loader = Loader()
56    # Allow for one part in multipart file
57    if '[' in filename:
58        filename, indexstr = filename[:-1].split('[')
59        index = int(indexstr)
60    datasets = loader.load(filename)
61    if not datasets:  # None or []
62        raise IOError("Data %r could not be loaded" % filename)
63    if not isinstance(datasets, list):
64        datasets = [datasets]
65    for data in datasets:
66        if hasattr(data, 'x'):
67            data.qmin, data.qmax = data.x.min(), data.x.max()
68            data.mask = (np.isnan(data.y) if data.y is not None
69                         else np.zeros_like(data.x, dtype='bool'))
70        elif hasattr(data, 'qx_data'):
71            data.mask = ~data.mask
72    return datasets[index] if index != 'all' else datasets
73
74
75def set_beam_stop(data, radius, outer=None):
76    # type: (Data, float, Optional[float]) -> None
77    """
78    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
79    """
80    from sas.sascalc.dataloader.manipulations import Ringcut
81    if hasattr(data, 'qx_data'):
82        data.mask = Ringcut(0, radius)(data)
83        if outer is not None:
84            data.mask += Ringcut(outer, np.inf)(data)
85    else:
86        data.mask = (data.x < radius)
87        if outer is not None:
88            data.mask |= (data.x >= outer)
89
90
91def set_half(data, half):
92    # type: (Data, str) -> None
93    """
94    Select half of the data, either "right" or "left".
95    """
96    from sas.sascalc.dataloader.manipulations import Boxcut
97    if half == 'right':
98        data.mask += \
99            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
100    if half == 'left':
101        data.mask += \
102            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
103
104
105def set_top(data, cutoff):
106    # type: (Data, float) -> None
107    """
108    Chop the top off the data, above *cutoff*.
109    """
110    from sas.sascalc.dataloader.manipulations import Boxcut
111    data.mask += \
112        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
113
114
115class Data1D(object):
116    """
117    1D data object.
118
119    Note that this definition matches the attributes from sasview, with
120    some generic 1D data vectors and some SAS specific definitions.  Some
121    refactoring to allow consistent naming conventions between 1D, 2D and
122    SESANS data would be helpful.
123
124    **Attributes**
125
126    *x*, *dx*: $q$ vector and gaussian resolution
127
128    *y*, *dy*: $I(q)$ vector and measurement uncertainty
129
130    *mask*: values to include in plotting/analysis
131
132    *dxl*: slit widths for slit smeared data, with *dx* ignored
133
134    *qmin*, *qmax*: range of $q$ values in *x*
135
136    *filename*: label for the data line
137
138    *_xaxis*, *_xunit*: label and units for the *x* axis
139
140    *_yaxis*, *_yunit*: label and units for the *y* axis
141    """
142    def __init__(self,
143                 x=None,  # type: Optional[np.ndarray]
144                 y=None,  # type: Optional[np.ndarray]
145                 dx=None, # type: Optional[np.ndarray]
146                 dy=None  # type: Optional[np.ndarray]
147                ):
148        # type: (...) -> None
149        self.x, self.y, self.dx, self.dy = x, y, dx, dy
150        self.dxl = None
151        self.filename = None
152        self.qmin = x.min() if x is not None else np.NaN
153        self.qmax = x.max() if x is not None else np.NaN
154        # TODO: why is 1D mask False and 2D mask True?
155        self.mask = (np.isnan(y) if y is not None
156                     else np.zeros_like(x, 'b') if x is not None
157                     else None)
158        self._xaxis, self._xunit = "x", ""
159        self._yaxis, self._yunit = "y", ""
160
161    def xaxis(self, label, unit):
162        # type: (str, str) -> None
163        """
164        set the x axis label and unit
165        """
166        self._xaxis = label
167        self._xunit = unit
168
169    def yaxis(self, label, unit):
170        # type: (str, str) -> None
171        """
172        set the y axis label and unit
173        """
174        self._yaxis = label
175        self._yunit = unit
176
177class SesansData(Data1D):
178    """
179    SESANS data object.
180
181    This is just :class:`Data1D` with a wavelength parameter.
182
183    *x* is spin echo length and *y* is polarization (P/P0).
184    """
185    def __init__(self, **kw):
186        Data1D.__init__(self, **kw)
187        self.lam = None # type: Optional[np.ndarray]
188
189class Data2D(object):
190    """
191    2D data object.
192
193    Note that this definition matches the attributes from sasview. Some
194    refactoring to allow consistent naming conventions between 1D, 2D and
195    SESANS data would be helpful.
196
197    **Attributes**
198
199    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
200
201    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
202
203    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
204
205    *mask*: values to exclude from plotting/analysis
206
207    *qmin*, *qmax*: range of $q$ values in *x*
208
209    *filename*: label for the data line
210
211    *_xaxis*, *_xunit*: label and units for the *x* axis
212
213    *_yaxis*, *_yunit*: label and units for the *y* axis
214
215    *_zaxis*, *_zunit*: label and units for the *y* axis
216
217    *Q_unit*, *I_unit*: units for Q and intensity
218
219    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
220    """
221    def __init__(self,
222                 x=None,   # type: Optional[np.ndarray]
223                 y=None,   # type: Optional[np.ndarray]
224                 z=None,   # type: Optional[np.ndarray]
225                 dx=None,  # type: Optional[np.ndarray]
226                 dy=None,  # type: Optional[np.ndarray]
227                 dz=None   # type: Optional[np.ndarray]
228                ):
229        # type: (...) -> None
230        self.qx_data, self.dqx_data = x, dx
231        self.qy_data, self.dqy_data = y, dy
232        self.data, self.err_data = z, dz
233        self.mask = (np.isnan(z) if z is not None
234                     else np.zeros_like(x, dtype='bool') if x is not None
235                     else None)
236        self.q_data = np.sqrt(x**2 + y**2)
237        self.qmin = 1e-16
238        self.qmax = np.inf
239        self.detector = []
240        self.source = Source()
241        self.Q_unit = "1/A"
242        self.I_unit = "1/cm"
243        self.xaxis("Q_x", "1/A")
244        self.yaxis("Q_y", "1/A")
245        self.zaxis("Intensity", "1/cm")
246        self._xaxis, self._xunit = "x", ""
247        self._yaxis, self._yunit = "y", ""
248        self._zaxis, self._zunit = "z", ""
249        self.x_bins, self.y_bins = None, None
250        self.filename = None
251
252    def xaxis(self, label, unit):
253        # type: (str, str) -> None
254        """
255        set the x axis label and unit
256        """
257        self._xaxis = label
258        self._xunit = unit
259
260    def yaxis(self, label, unit):
261        # type: (str, str) -> None
262        """
263        set the y axis label and unit
264        """
265        self._yaxis = label
266        self._yunit = unit
267
268    def zaxis(self, label, unit):
269        # type: (str, str) -> None
270        """
271        set the y axis label and unit
272        """
273        self._zaxis = label
274        self._zunit = unit
275
276
277class Vector(object):
278    """
279    3-space vector of *x*, *y*, *z*
280    """
281    def __init__(self, x=None, y=None, z=None):
282        # type: (float, float, Optional[float]) -> None
283        self.x, self.y, self.z = x, y, z
284
285class Detector(object):
286    """
287    Detector attributes.
288    """
289    def __init__(self, pixel_size=(None, None), distance=None):
290        # type: (Tuple[float, float], float) -> None
291        self.pixel_size = Vector(*pixel_size)
292        self.distance = distance
293
294class Source(object):
295    """
296    Beam attributes.
297    """
298    def __init__(self):
299        # type: () -> None
300        self.wavelength = np.NaN
301        self.wavelength_unit = "A"
302
303
304def empty_data1D(q, resolution=0.0, L=0., dL=0.):
305    # type: (np.ndarray, float) -> Data1D
306    r"""
307    Create empty 1D data using the given *q* as the x value.
308
309    rms *resolution* $\Delta q/q$ defaults to 0%.  If wavelength *L* and rms
310    wavelength divergence *dL* are defined, then *resolution* defines
311    rms $\Delta \theta/\theta$ for the lowest *q*, with $\theta$ derived from
312    $q = 4\pi/\lambda \sin(\theta)$.
313    """
314
315    #Iq = 100 * np.ones_like(q)
316    #dIq = np.sqrt(Iq)
317    Iq, dIq = None, None
318    q = np.asarray(q)
319    if L != 0 and resolution != 0:
320        theta = np.arcsin(q*L/(4*pi))
321        dtheta = theta[0]*resolution
322        ## Solving Gaussian error propagation from
323        ##   Dq^2 = (dq/dL)^2 DL^2 + (dq/dtheta)^2 Dtheta^2
324        ## gives
325        ##   (Dq/q)^2 = (DL/L)**2 + (Dtheta/tan(theta))**2
326        ## Take the square root and multiply by q, giving
327        ##   Dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
328        dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
329    else:
330        dq = resolution * q
331    data = Data1D(q, Iq, dx=dq, dy=dIq)
332    data.filename = "fake data"
333    return data
334
335
336def empty_data2D(qx, qy=None, resolution=0.0):
337    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
338    """
339    Create empty 2D data using the given mesh.
340
341    If *qy* is missing, create a square mesh with *qy=qx*.
342
343    *resolution* dq/q defaults to 5%.
344    """
345    if qy is None:
346        qy = qx
347    qx, qy = np.asarray(qx), np.asarray(qy)
348    # 5% dQ/Q resolution
349    Qx, Qy = np.meshgrid(qx, qy)
350    Qx, Qy = Qx.flatten(), Qy.flatten()
351    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
352    dIq = np.sqrt(Iq)
353    if resolution != 0:
354        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
355        # Should have an additional constant which depends on distances and
356        # radii of the aperture, pixel dimensions and wavelength spread
357        # Instead, assume radial dQ/Q is constant, and perpendicular matches
358        # radial (which instead it should be inverse).
359        Q = np.sqrt(Qx**2 + Qy**2)
360        dqx = resolution * Q
361        dqy = resolution * Q
362    else:
363        dqx = dqy = None
364
365    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
366    data.x_bins = qx
367    data.y_bins = qy
368    data.filename = "fake data"
369
370    # pixel_size in mm, distance in m
371    detector = Detector(pixel_size=(5, 5), distance=4)
372    data.detector.append(detector)
373    data.source.wavelength = 5 # angstroms
374    data.source.wavelength_unit = "A"
375    return data
376
377
378def plot_data(data, view='log', limits=None):
379    # type: (Data, str, Optional[Tuple[float, float]]) -> None
380    """
381    Plot data loaded by the sasview loader.
382
383    *data* is a sasview data object, either 1D, 2D or SESANS.
384
385    *view* is log or linear.
386
387    *limits* sets the intensity limits on the plot; if None then the limits
388    are inferred from the data.
389    """
390    # Note: kind of weird using the plot result functions to plot just the
391    # data, but they already handle the masking and graph markup already, so
392    # do not repeat.
393    if hasattr(data, 'isSesans') and data.isSesans:
394        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
395    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
396        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
397    else:
398        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
399
400
401def plot_theory(data,          # type: Data
402                theory,        # type: Optional[np.ndarray]
403                resid=None,    # type: Optional[np.ndarray]
404                view='log',    # type: str
405                use_data=True, # type: bool
406                limits=None,   # type: Optional[np.ndarray]
407                Iq_calc=None   # type: Optional[np.ndarray]
408               ):
409    # type: (...) -> None
410    """
411    Plot theory calculation.
412
413    *data* is needed to define the graph properties such as labels and
414    units, and to define the data mask.
415
416    *theory* is a matrix of the same shape as the data.
417
418    *view* is log or linear
419
420    *use_data* is True if the data should be plotted as well as the theory.
421
422    *limits* sets the intensity limits on the plot; if None then the limits
423    are inferred from the data.
424
425    *Iq_calc* is the raw theory values without resolution smearing
426    """
427    if hasattr(data, 'isSesans') and data.isSesans:
428        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
429    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
430        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
431    else:
432        _plot_result1D(data, theory, resid, view, use_data,
433                       limits=limits, Iq_calc=Iq_calc)
434
435
436def protect(func):
437    # type: (Callable) -> Callable
438    """
439    Decorator to wrap calls in an exception trapper which prints the
440    exception and continues.  Keyboard interrupts are ignored.
441    """
442    def wrapper(*args, **kw):
443        """
444        Trap and print errors from function.
445        """
446        try:
447            return func(*args, **kw)
448        except Exception:
449            traceback.print_exc()
450
451    return wrapper
452
453
454@protect
455def _plot_result1D(data,         # type: Data1D
456                   theory,       # type: Optional[np.ndarray]
457                   resid,        # type: Optional[np.ndarray]
458                   view,         # type: str
459                   use_data,     # type: bool
460                   limits=None,  # type: Optional[Tuple[float, float]]
461                   Iq_calc=None  # type: Optional[np.ndarray]
462                  ):
463    # type: (...) -> None
464    """
465    Plot the data and residuals for 1D data.
466    """
467    import matplotlib.pyplot as plt  # type: ignore
468    from numpy.ma import masked_array, masked  # type: ignore
469
470    if getattr(data, 'radial', False):
471        data.x = data.q_data
472        data.y = data.data
473
474    use_data = use_data and data.y is not None
475    use_theory = theory is not None
476    use_resid = resid is not None
477    use_calc = use_theory and Iq_calc is not None
478    num_plots = (use_data or use_theory) + use_calc + use_resid
479    non_positive_x = (data.x <= 0.0).any()
480
481    scale = data.x**4 if view == 'q4' else 1.0
482    xscale = yscale = 'linear' if view == 'linear' else 'log'
483
484    if use_data or use_theory:
485        if num_plots > 1:
486            plt.subplot(1, num_plots, 1)
487
488        #print(vmin, vmax)
489        all_positive = True
490        some_present = False
491        if use_data:
492            mdata = masked_array(data.y, data.mask.copy())
493            mdata[~np.isfinite(mdata)] = masked
494            if view is 'log':
495                mdata[mdata <= 0] = masked
496            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
497            all_positive = all_positive and (mdata > 0).all()
498            some_present = some_present or (mdata.count() > 0)
499
500
501        if use_theory:
502            # Note: masks merge, so any masked theory points will stay masked,
503            # and the data mask will be added to it.
504            mtheory = masked_array(theory, data.mask.copy())
505            mtheory[~np.isfinite(mtheory)] = masked
506            if view is 'log':
507                mtheory[mtheory <= 0] = masked
508            plt.plot(data.x, scale*mtheory, '-')
509            all_positive = all_positive and (mtheory > 0).all()
510            some_present = some_present or (mtheory.count() > 0)
511
512        if limits is not None:
513            plt.ylim(*limits)
514
515
516        xscale = ('linear' if not some_present or non_positive_x
517                  else view if view is not None
518                  else 'log')
519        yscale = ('linear'
520                  if view == 'q4' or not some_present or not all_positive
521                  else view if view is not None
522                  else 'log')
523        plt.xscale(xscale)
524       
525        plt.xlabel("$q$/A$^{-1}$")
526        plt.yscale(yscale)
527        plt.ylabel('$I(q)$')
528        title = ("data and model" if use_theory and use_data
529                 else "data" if use_data
530                 else "model")
531        plt.title(title)
532
533    if use_calc:
534        # Only have use_calc if have use_theory
535        plt.subplot(1, num_plots, 2)
536        qx, qy, Iqxy = Iq_calc
537        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
538        plt.xlabel("$q_x$/A$^{-1}$")
539        plt.xlabel("$q_y$/A$^{-1}$")
540        plt.xscale('log')
541        plt.yscale('log')
542        #plt.axis('equal')
543
544    if use_resid:
545        mresid = masked_array(resid, data.mask.copy())
546        mresid[~np.isfinite(mresid)] = masked
547        some_present = (mresid.count() > 0)
548
549        if num_plots > 1:
550            plt.subplot(1, num_plots, use_calc + 2)
551        plt.plot(data.x, mresid, '.')
552        plt.xlabel("$q$/A$^{-1}$")
553        plt.ylabel('residuals')
554        plt.title('(model - Iq)/dIq')
555        plt.xscale(xscale)
556        plt.yscale('linear')
557
558
559@protect
560def _plot_result_sesans(data,        # type: SesansData
561                        theory,      # type: Optional[np.ndarray]
562                        resid,       # type: Optional[np.ndarray]
563                        use_data,    # type: bool
564                        limits=None  # type: Optional[Tuple[float, float]]
565                       ):
566    # type: (...) -> None
567    """
568    Plot SESANS results.
569    """
570    import matplotlib.pyplot as plt  # type: ignore
571    use_data = use_data and data.y is not None
572    use_theory = theory is not None
573    use_resid = resid is not None
574    num_plots = (use_data or use_theory) + use_resid
575
576    if use_data or use_theory:
577        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
578        if num_plots > 1:
579            plt.subplot(1, num_plots, 1)
580        if use_data:
581            if is_tof:
582                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
583                             yerr=data.dy/data.y/(data.lam*data.lam))
584            else:
585                plt.errorbar(data.x, data.y, yerr=data.dy)
586        if theory is not None:
587            if is_tof:
588                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
589            else:
590                plt.plot(data.x, theory, '-')
591        if limits is not None:
592            plt.ylim(*limits)
593
594        plt.xlabel('spin echo length ({})'.format(data._xunit))
595        if is_tof:
596            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
597        else:
598            plt.ylabel('polarization (P/P0)')
599
600
601    if resid is not None:
602        if num_plots > 1:
603            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
604        plt.plot(data.x, resid, 'x')
605        plt.xlabel('spin echo length ({})'.format(data._xunit))
606        plt.ylabel('residuals (P/P0)')
607
608
609@protect
610def _plot_result2D(data,         # type: Data2D
611                   theory,       # type: Optional[np.ndarray]
612                   resid,        # type: Optional[np.ndarray]
613                   view,         # type: str
614                   use_data,     # type: bool
615                   limits=None   # type: Optional[Tuple[float, float]]
616                  ):
617    # type: (...) -> None
618    """
619    Plot the data and residuals for 2D data.
620    """
621    import matplotlib.pyplot as plt  # type: ignore
622    use_data = use_data and data.data is not None
623    use_theory = theory is not None
624    use_resid = resid is not None
625    num_plots = use_data + use_theory + use_resid
626
627    # Put theory and data on a common colormap scale
628    vmin, vmax = np.inf, -np.inf
629    target = None # type: Optional[np.ndarray]
630    if use_data:
631        target = data.data[~data.mask]
632        datamin = target[target > 0].min() if view == 'log' else target.min()
633        datamax = target.max()
634        vmin = min(vmin, datamin)
635        vmax = max(vmax, datamax)
636    if use_theory:
637        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
638        theorymax = theory.max()
639        vmin = min(vmin, theorymin)
640        vmax = max(vmax, theorymax)
641
642    # Override data limits from the caller
643    if limits is not None:
644        vmin, vmax = limits
645
646    # Plot data
647    if use_data:
648        if num_plots > 1:
649            plt.subplot(1, num_plots, 1)
650        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
651        plt.title('data')
652        h = plt.colorbar()
653        h.set_label('$I(q)$')
654
655    # plot theory
656    if use_theory:
657        if num_plots > 1:
658            plt.subplot(1, num_plots, use_data+1)
659        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
660        plt.title('theory')
661        h = plt.colorbar()
662        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
663                    else r'$q^4 I(q)$' if view == 'q4'
664                    else '$I(q)$')
665
666    # plot resid
667    if use_resid:
668        if num_plots > 1:
669            plt.subplot(1, num_plots, use_data+use_theory+1)
670        _plot_2d_signal(data, resid, view='linear')
671        plt.title('residuals')
672        h = plt.colorbar()
673        h.set_label(r'$\Delta I(q)$')
674
675
676@protect
677def _plot_2d_signal(data,       # type: Data2D
678                    signal,     # type: np.ndarray
679                    vmin=None,  # type: Optional[float]
680                    vmax=None,  # type: Optional[float]
681                    view='log'  # type: str
682                   ):
683    # type: (...) -> Tuple[float, float]
684    """
685    Plot the target value for the data.  This could be the data itself,
686    the theory calculation, or the residuals.
687
688    *scale* can be 'log' for log scale data, or 'linear'.
689    """
690    import matplotlib.pyplot as plt  # type: ignore
691    from numpy.ma import masked_array  # type: ignore
692
693    image = np.zeros_like(data.qx_data)
694    image[~data.mask] = signal
695    valid = np.isfinite(image)
696    if view == 'log':
697        valid[valid] = (image[valid] > 0)
698        if vmin is None:
699            vmin = image[valid & ~data.mask].min()
700        if vmax is None:
701            vmax = image[valid & ~data.mask].max()
702        image[valid] = np.log10(image[valid])
703    elif view == 'q4':
704        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
705        if vmin is None:
706            vmin = image[valid & ~data.mask].min()
707        if vmax is None:
708            vmax = image[valid & ~data.mask].max()
709    else:
710        if vmin is None:
711            vmin = image[valid & ~data.mask].min()
712        if vmax is None:
713            vmax = image[valid & ~data.mask].max()
714
715    image[~valid | data.mask] = 0
716    #plottable = Iq
717    plottable = masked_array(image, ~valid | data.mask)
718    # Divide range by 10 to convert from angstroms to nanometers
719    xmin, xmax = min(data.qx_data), max(data.qx_data)
720    ymin, ymax = min(data.qy_data), max(data.qy_data)
721    if view == 'log':
722        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
723    else:
724        vmin_scaled, vmax_scaled = vmin, vmax
725    #nx, ny = len(data.x_bins), len(data.y_bins)
726    x_bins, y_bins, image = _build_matrix(data, plottable)
727    plt.imshow(image,
728               interpolation='nearest', aspect=1, origin='lower',
729               extent=[xmin, xmax, ymin, ymax],
730               vmin=vmin_scaled, vmax=vmax_scaled)
731    plt.xlabel("$q_x$/A$^{-1}$")
732    plt.ylabel("$q_y$/A$^{-1}$")
733    return vmin, vmax
734
735
736# === The following is modified from sas.sasgui.plottools.PlotPanel
737def _build_matrix(self, plottable):
738    """
739    Build a matrix for 2d plot from a vector
740    Returns a matrix (image) with ~ square binning
741    Requirement: need 1d array formats of
742    self.data, self.qx_data, and self.qy_data
743    where each one corresponds to z, x, or y axis values
744
745    """
746    # No qx or qy given in a vector format
747    if self.qx_data is None or self.qy_data is None \
748            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
749        return self.x_bins, self.y_bins, plottable
750
751    # maximum # of loops to fillup_pixels
752    # otherwise, loop could never stop depending on data
753    max_loop = 1
754    # get the x and y_bin arrays.
755    x_bins, y_bins = _get_bins(self)
756    # set zero to None
757
758    #Note: Can not use scipy.interpolate.Rbf:
759    # 'cause too many data points (>10000)<=JHC.
760    # 1d array to use for weighting the data point averaging
761    #when they fall into a same bin.
762    weights_data = np.ones([self.data.size])
763    # get histogram of ones w/len(data); this will provide
764    #the weights of data on each bins
765    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
766                                             y=self.qx_data,
767                                             bins=[y_bins, x_bins],
768                                             weights=weights_data)
769    # get histogram of data, all points into a bin in a way of summing
770    image, xedges, yedges = np.histogram2d(x=self.qy_data,
771                                           y=self.qx_data,
772                                           bins=[y_bins, x_bins],
773                                           weights=plottable)
774    # Now, normalize the image by weights only for weights>1:
775    # If weight == 1, there is only one data point in the bin so
776    # that no normalization is required.
777    image[weights > 1] = image[weights > 1] / weights[weights > 1]
778    # Set image bins w/o a data point (weight==0) as None (was set to zero
779    # by histogram2d.)
780    image[weights == 0] = None
781
782    # Fill empty bins with 8 nearest neighbors only when at least
783    #one None point exists
784    loop = 0
785
786    # do while loop until all vacant bins are filled up up
787    #to loop = max_loop
788    while (weights == 0).any():
789        if loop >= max_loop:  # this protects never-ending loop
790            break
791        image = _fillup_pixels(image=image, weights=weights)
792        loop += 1
793
794    return x_bins, y_bins, image
795
796def _get_bins(self):
797    """
798    get bins
799    set x_bins and y_bins into self, 1d arrays of the index with
800    ~ square binning
801    Requirement: need 1d array formats of
802    self.qx_data, and self.qy_data
803    where each one corresponds to  x, or y axis values
804    """
805    # find max and min values of qx and qy
806    xmax = self.qx_data.max()
807    xmin = self.qx_data.min()
808    ymax = self.qy_data.max()
809    ymin = self.qy_data.min()
810
811    # calculate the range of qx and qy: this way, it is a little
812    # more independent
813    x_size = xmax - xmin
814    y_size = ymax - ymin
815
816    # estimate the # of pixels on each axes
817    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
818    npix_x = int(np.floor(len(self.qy_data) / npix_y))
819
820    # bin size: x- & y-directions
821    xstep = x_size / (npix_x - 1)
822    ystep = y_size / (npix_y - 1)
823
824    # max and min taking account of the bin sizes
825    xmax = xmax + xstep / 2.0
826    xmin = xmin - xstep / 2.0
827    ymax = ymax + ystep / 2.0
828    ymin = ymin - ystep / 2.0
829
830    # store x and y bin centers in q space
831    x_bins = np.linspace(xmin, xmax, npix_x)
832    y_bins = np.linspace(ymin, ymax, npix_y)
833
834    return x_bins, y_bins
835
836def _fillup_pixels(image=None, weights=None):
837    """
838    Fill z values of the empty cells of 2d image matrix
839    with the average over up-to next nearest neighbor points
840
841    :param image: (2d matrix with some zi = None)
842
843    :return: image (2d array )
844
845    :TODO: Find better way to do for-loop below
846
847    """
848    # No image matrix given
849    if image is None or np.ndim(image) != 2 \
850            or np.isfinite(image).all() \
851            or weights is None:
852        return image
853    # Get bin size in y and x directions
854    len_y = len(image)
855    len_x = len(image[1])
856    temp_image = np.zeros([len_y, len_x])
857    weit = np.zeros([len_y, len_x])
858    # do for-loop for all pixels
859    for n_y in range(len(image)):
860        for n_x in range(len(image[1])):
861            # find only null pixels
862            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
863                continue
864            else:
865                # find 4 nearest neighbors
866                # check where or not it is at the corner
867                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
868                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
869                    weit[n_y][n_x] += 1
870                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
871                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
872                    weit[n_y][n_x] += 1
873                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
874                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
875                    weit[n_y][n_x] += 1
876                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
877                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
878                    weit[n_y][n_x] += 1
879                # go 4 next nearest neighbors when no non-zero
880                # neighbor exists
881                if n_y != 0 and n_x != 0 and \
882                        np.isfinite(image[n_y - 1][n_x - 1]):
883                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
884                    weit[n_y][n_x] += 1
885                if n_y != len_y - 1 and n_x != 0 and \
886                        np.isfinite(image[n_y + 1][n_x - 1]):
887                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
888                    weit[n_y][n_x] += 1
889                if n_y != len_y and n_x != len_x - 1 and \
890                        np.isfinite(image[n_y - 1][n_x + 1]):
891                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
892                    weit[n_y][n_x] += 1
893                if n_y != len_y - 1 and n_x != len_x - 1 and \
894                        np.isfinite(image[n_y + 1][n_x + 1]):
895                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
896                    weit[n_y][n_x] += 1
897
898    # get it normalized
899    ind = (weit > 0)
900    image[ind] = temp_image[ind] / weit[ind]
901
902    return image
903
904
905def demo():
906    # type: () -> None
907    """
908    Load and plot a SAS dataset.
909    """
910    data = load_data('DEC07086.DAT')
911    set_beam_stop(data, 0.004)
912    plot_data(data)
913    import matplotlib.pyplot as plt  # type: ignore
914    plt.show()
915
916
917if __name__ == "__main__":
918    demo()
Note: See TracBrowser for help on using the repository browser.