source: sasmodels/sasmodels/data.py @ 65fbf7c

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 65fbf7c was 65fbf7c, checked in by Paul Kienzle <pkienzle@…>, 4 years ago

play with resolution defined by fixed dtheta,dlambda

  • Property mode set to 100644
File size: 30.3 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38from numpy import sqrt, sin, cos, pi
39
40# pylint: disable=unused-import
41try:
42    from typing import Union, Dict, List, Optional
43except ImportError:
44    pass
45else:
46    Data = Union["Data1D", "Data2D", "SesansData"]
47# pylint: enable=unused-import
48
49def load_data(filename, index=0):
50    # type: (str) -> Data
51    """
52    Load data using a sasview loader.
53    """
54    from sas.sascalc.dataloader.loader import Loader  # type: ignore
55    loader = Loader()
56    # Allow for one part in multipart file
57    if '[' in filename:
58        filename, indexstr = filename[:-1].split('[')
59        index = int(indexstr)
60    datasets = loader.load(filename)
61    if not datasets:  # None or []
62        raise IOError("Data %r could not be loaded" % filename)
63    if not isinstance(datasets, list):
64        datasets = [datasets]
65    for data in datasets:
66        if hasattr(data, 'x'):
67            data.qmin, data.qmax = data.x.min(), data.x.max()
68            data.mask = (np.isnan(data.y) if data.y is not None
69                         else np.zeros_like(data.x, dtype='bool'))
70        elif hasattr(data, 'qx_data'):
71            data.mask = ~data.mask
72    return datasets[index] if index != 'all' else datasets
73
74
75def set_beam_stop(data, radius, outer=None):
76    # type: (Data, float, Optional[float]) -> None
77    """
78    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
79    """
80    from sas.sascalc.dataloader.manipulations import Ringcut
81    if hasattr(data, 'qx_data'):
82        data.mask = Ringcut(0, radius)(data)
83        if outer is not None:
84            data.mask += Ringcut(outer, np.inf)(data)
85    else:
86        data.mask = (data.x < radius)
87        if outer is not None:
88            data.mask |= (data.x >= outer)
89
90
91def set_half(data, half):
92    # type: (Data, str) -> None
93    """
94    Select half of the data, either "right" or "left".
95    """
96    from sas.sascalc.dataloader.manipulations import Boxcut
97    if half == 'right':
98        data.mask += \
99            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
100    if half == 'left':
101        data.mask += \
102            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
103
104
105def set_top(data, cutoff):
106    # type: (Data, float) -> None
107    """
108    Chop the top off the data, above *cutoff*.
109    """
110    from sas.sascalc.dataloader.manipulations import Boxcut
111    data.mask += \
112        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
113
114
115class Data1D(object):
116    """
117    1D data object.
118
119    Note that this definition matches the attributes from sasview, with
120    some generic 1D data vectors and some SAS specific definitions.  Some
121    refactoring to allow consistent naming conventions between 1D, 2D and
122    SESANS data would be helpful.
123
124    **Attributes**
125
126    *x*, *dx*: $q$ vector and gaussian resolution
127
128    *y*, *dy*: $I(q)$ vector and measurement uncertainty
129
130    *mask*: values to include in plotting/analysis
131
132    *dxl*: slit widths for slit smeared data, with *dx* ignored
133
134    *qmin*, *qmax*: range of $q$ values in *x*
135
136    *filename*: label for the data line
137
138    *_xaxis*, *_xunit*: label and units for the *x* axis
139
140    *_yaxis*, *_yunit*: label and units for the *y* axis
141    """
142    def __init__(self,
143                 x=None,  # type: Optional[np.ndarray]
144                 y=None,  # type: Optional[np.ndarray]
145                 dx=None, # type: Optional[np.ndarray]
146                 dy=None  # type: Optional[np.ndarray]
147                ):
148        # type: (...) -> None
149        self.x, self.y, self.dx, self.dy = x, y, dx, dy
150        self.dxl = None
151        self.filename = None
152        self.qmin = x.min() if x is not None else np.NaN
153        self.qmax = x.max() if x is not None else np.NaN
154        # TODO: why is 1D mask False and 2D mask True?
155        self.mask = (np.isnan(y) if y is not None
156                     else np.zeros_like(x, 'b') if x is not None
157                     else None)
158        self._xaxis, self._xunit = "x", ""
159        self._yaxis, self._yunit = "y", ""
160
161    def xaxis(self, label, unit):
162        # type: (str, str) -> None
163        """
164        set the x axis label and unit
165        """
166        self._xaxis = label
167        self._xunit = unit
168
169    def yaxis(self, label, unit):
170        # type: (str, str) -> None
171        """
172        set the y axis label and unit
173        """
174        self._yaxis = label
175        self._yunit = unit
176
177class SesansData(Data1D):
178    """
179    SESANS data object.
180
181    This is just :class:`Data1D` with a wavelength parameter.
182
183    *x* is spin echo length and *y* is polarization (P/P0).
184    """
185    def __init__(self, **kw):
186        Data1D.__init__(self, **kw)
187        self.lam = None # type: Optional[np.ndarray]
188
189class Data2D(object):
190    """
191    2D data object.
192
193    Note that this definition matches the attributes from sasview. Some
194    refactoring to allow consistent naming conventions between 1D, 2D and
195    SESANS data would be helpful.
196
197    **Attributes**
198
199    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
200
201    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
202
203    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
204
205    *mask*: values to exclude from plotting/analysis
206
207    *qmin*, *qmax*: range of $q$ values in *x*
208
209    *filename*: label for the data line
210
211    *_xaxis*, *_xunit*: label and units for the *x* axis
212
213    *_yaxis*, *_yunit*: label and units for the *y* axis
214
215    *_zaxis*, *_zunit*: label and units for the *y* axis
216
217    *Q_unit*, *I_unit*: units for Q and intensity
218
219    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
220    """
221    def __init__(self,
222                 x=None,   # type: Optional[np.ndarray]
223                 y=None,   # type: Optional[np.ndarray]
224                 z=None,   # type: Optional[np.ndarray]
225                 dx=None,  # type: Optional[np.ndarray]
226                 dy=None,  # type: Optional[np.ndarray]
227                 dz=None   # type: Optional[np.ndarray]
228                ):
229        # type: (...) -> None
230        self.qx_data, self.dqx_data = x, dx
231        self.qy_data, self.dqy_data = y, dy
232        self.data, self.err_data = z, dz
233        self.mask = (np.isnan(z) if z is not None
234                     else np.zeros_like(x, dtype='bool') if x is not None
235                     else None)
236        self.q_data = np.sqrt(x**2 + y**2)
237        self.qmin = 1e-16
238        self.qmax = np.inf
239        self.detector = []
240        self.source = Source()
241        self.Q_unit = "1/A"
242        self.I_unit = "1/cm"
243        self.xaxis("Q_x", "1/A")
244        self.yaxis("Q_y", "1/A")
245        self.zaxis("Intensity", "1/cm")
246        self._xaxis, self._xunit = "x", ""
247        self._yaxis, self._yunit = "y", ""
248        self._zaxis, self._zunit = "z", ""
249        self.x_bins, self.y_bins = None, None
250        self.filename = None
251
252    def xaxis(self, label, unit):
253        # type: (str, str) -> None
254        """
255        set the x axis label and unit
256        """
257        self._xaxis = label
258        self._xunit = unit
259
260    def yaxis(self, label, unit):
261        # type: (str, str) -> None
262        """
263        set the y axis label and unit
264        """
265        self._yaxis = label
266        self._yunit = unit
267
268    def zaxis(self, label, unit):
269        # type: (str, str) -> None
270        """
271        set the y axis label and unit
272        """
273        self._zaxis = label
274        self._zunit = unit
275
276
277class Vector(object):
278    """
279    3-space vector of *x*, *y*, *z*
280    """
281    def __init__(self, x=None, y=None, z=None):
282        # type: (float, float, Optional[float]) -> None
283        self.x, self.y, self.z = x, y, z
284
285class Detector(object):
286    """
287    Detector attributes.
288    """
289    def __init__(self, pixel_size=(None, None), distance=None):
290        # type: (Tuple[float, float], float) -> None
291        self.pixel_size = Vector(*pixel_size)
292        self.distance = distance
293
294class Source(object):
295    """
296    Beam attributes.
297    """
298    def __init__(self):
299        # type: () -> None
300        self.wavelength = np.NaN
301        self.wavelength_unit = "A"
302
303
304def empty_data1D(q, resolution=0.0, L=0., dL=0.):
305    # type: (np.ndarray, float) -> Data1D
306    r"""
307    Create empty 1D data using the given *q* as the x value.
308
309    rms *resolution* $\Delta q/q$ defaults to 0%.  If wavelength *L* and rms
310    wavelength divergence *dL* are defined, then *resolution* defines
311    rms $\Delta \theta/\theta$ for the lowest *q*, with $\theta$ derived from
312    $q = 4\pi/\lambda \sin(\theta)$.
313    """
314
315    #Iq = 100 * np.ones_like(q)
316    #dIq = np.sqrt(Iq)
317    Iq, dIq = None, None
318    q = np.asarray(q)
319    if L != 0 and resolution != 0:
320        theta = np.arcsin(q*L/(4*pi))
321        dtheta = theta[0]*resolution
322        ## Solving Gaussian error propagation from
323        ##   Dq^2 = (dq/dL)^2 DL^2 + (dq/dtheta)^2 Dtheta^2
324        ## gives
325        ##   (Dq/q)^2 = (DL/L)**2 + (Dtheta/tan(theta))**2
326        ## Take the square root and multiply by q, giving
327        ##   Dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
328        dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
329    else:
330        dq = resolution * q
331
332    data = Data1D(q, Iq, dx=dq, dy=dIq)
333    data.filename = "fake data"
334    return data
335
336
337def empty_data2D(qx, qy=None, resolution=0.0):
338    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
339    """
340    Create empty 2D data using the given mesh.
341
342    If *qy* is missing, create a square mesh with *qy=qx*.
343
344    *resolution* dq/q defaults to 5%.
345    """
346    if qy is None:
347        qy = qx
348    qx, qy = np.asarray(qx), np.asarray(qy)
349    # 5% dQ/Q resolution
350    Qx, Qy = np.meshgrid(qx, qy)
351    Qx, Qy = Qx.flatten(), Qy.flatten()
352    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
353    dIq = np.sqrt(Iq)
354    if resolution != 0:
355        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
356        # Should have an additional constant which depends on distances and
357        # radii of the aperture, pixel dimensions and wavelength spread
358        # Instead, assume radial dQ/Q is constant, and perpendicular matches
359        # radial (which instead it should be inverse).
360        Q = np.sqrt(Qx**2 + Qy**2)
361        dqx = resolution * Q
362        dqy = resolution * Q
363    else:
364        dqx = dqy = None
365
366    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
367    data.x_bins = qx
368    data.y_bins = qy
369    data.filename = "fake data"
370
371    # pixel_size in mm, distance in m
372    detector = Detector(pixel_size=(5, 5), distance=4)
373    data.detector.append(detector)
374    data.source.wavelength = 5 # angstroms
375    data.source.wavelength_unit = "A"
376    return data
377
378
379def plot_data(data, view='log', limits=None):
380    # type: (Data, str, Optional[Tuple[float, float]]) -> None
381    """
382    Plot data loaded by the sasview loader.
383
384    *data* is a sasview data object, either 1D, 2D or SESANS.
385
386    *view* is log or linear.
387
388    *limits* sets the intensity limits on the plot; if None then the limits
389    are inferred from the data.
390    """
391    # Note: kind of weird using the plot result functions to plot just the
392    # data, but they already handle the masking and graph markup already, so
393    # do not repeat.
394    if hasattr(data, 'isSesans') and data.isSesans:
395        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
396    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
397        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
398    else:
399        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
400
401
402def plot_theory(data,          # type: Data
403                theory,        # type: Optional[np.ndarray]
404                resid=None,    # type: Optional[np.ndarray]
405                view='log',    # type: str
406                use_data=True, # type: bool
407                limits=None,   # type: Optional[np.ndarray]
408                Iq_calc=None   # type: Optional[np.ndarray]
409               ):
410    # type: (...) -> None
411    """
412    Plot theory calculation.
413
414    *data* is needed to define the graph properties such as labels and
415    units, and to define the data mask.
416
417    *theory* is a matrix of the same shape as the data.
418
419    *view* is log or linear
420
421    *use_data* is True if the data should be plotted as well as the theory.
422
423    *limits* sets the intensity limits on the plot; if None then the limits
424    are inferred from the data.
425
426    *Iq_calc* is the raw theory values without resolution smearing
427    """
428    if hasattr(data, 'isSesans') and data.isSesans:
429        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
430    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
431        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
432    else:
433        _plot_result1D(data, theory, resid, view, use_data,
434                       limits=limits, Iq_calc=Iq_calc)
435
436
437def protect(func):
438    # type: (Callable) -> Callable
439    """
440    Decorator to wrap calls in an exception trapper which prints the
441    exception and continues.  Keyboard interrupts are ignored.
442    """
443    def wrapper(*args, **kw):
444        """
445        Trap and print errors from function.
446        """
447        try:
448            return func(*args, **kw)
449        except Exception:
450            traceback.print_exc()
451
452    return wrapper
453
454
455@protect
456def _plot_result1D(data,         # type: Data1D
457                   theory,       # type: Optional[np.ndarray]
458                   resid,        # type: Optional[np.ndarray]
459                   view,         # type: str
460                   use_data,     # type: bool
461                   limits=None,  # type: Optional[Tuple[float, float]]
462                   Iq_calc=None  # type: Optional[np.ndarray]
463                  ):
464    # type: (...) -> None
465    """
466    Plot the data and residuals for 1D data.
467    """
468    import matplotlib.pyplot as plt  # type: ignore
469    from numpy.ma import masked_array, masked  # type: ignore
470
471    if getattr(data, 'radial', False):
472        data.x = data.q_data
473        data.y = data.data
474
475    use_data = use_data and data.y is not None
476    use_theory = theory is not None
477    use_resid = resid is not None
478    use_calc = use_theory and Iq_calc is not None
479    num_plots = (use_data or use_theory) + use_calc + use_resid
480    non_positive_x = (data.x <= 0.0).any()
481
482    scale = data.x**4 if view == 'q4' else 1.0
483    xscale = yscale = 'linear' if view == 'linear' else 'log'
484
485    if use_data or use_theory:
486        if num_plots > 1:
487            plt.subplot(1, num_plots, 1)
488
489        #print(vmin, vmax)
490        all_positive = True
491        some_present = False
492        if use_data:
493            mdata = masked_array(data.y, data.mask.copy())
494            mdata[~np.isfinite(mdata)] = masked
495            if view is 'log':
496                mdata[mdata <= 0] = masked
497            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
498            all_positive = all_positive and (mdata > 0).all()
499            some_present = some_present or (mdata.count() > 0)
500
501
502        if use_theory:
503            # Note: masks merge, so any masked theory points will stay masked,
504            # and the data mask will be added to it.
505            mtheory = masked_array(theory, data.mask.copy())
506            mtheory[~np.isfinite(mtheory)] = masked
507            if view is 'log':
508                mtheory[mtheory <= 0] = masked
509            plt.plot(data.x, scale*mtheory, '-')
510            all_positive = all_positive and (mtheory > 0).all()
511            some_present = some_present or (mtheory.count() > 0)
512
513        if limits is not None:
514            plt.ylim(*limits)
515
516
517        xscale = ('linear' if not some_present or non_positive_x
518                  else view if view is not None
519                  else 'log')
520        yscale = ('linear'
521                  if view == 'q4' or not some_present or not all_positive
522                  else view if view is not None
523                  else 'log')
524        plt.xscale(xscale)
525        plt.xlabel("$q$/A$^{-1}$")
526        plt.yscale(yscale)
527        plt.ylabel('$I(q)$')
528        title = ("data and model" if use_theory and use_data
529                 else "data" if use_data
530                 else "model")
531        plt.title(title)
532
533    if use_calc:
534        # Only have use_calc if have use_theory
535        plt.subplot(1, num_plots, 2)
536        qx, qy, Iqxy = Iq_calc
537        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
538        plt.xlabel("$q_x$/A$^{-1}$")
539        plt.xlabel("$q_y$/A$^{-1}$")
540        plt.xscale('log')
541        plt.yscale('log')
542        #plt.axis('equal')
543
544    if use_resid:
545        mresid = masked_array(resid, data.mask.copy())
546        mresid[~np.isfinite(mresid)] = masked
547        some_present = (mresid.count() > 0)
548
549        if num_plots > 1:
550            plt.subplot(1, num_plots, use_calc + 2)
551        plt.plot(data.x, mresid, '.')
552        plt.xlabel("$q$/A$^{-1}$")
553        plt.ylabel('residuals')
554        plt.title('(model - Iq)/dIq')
555        plt.xscale(xscale)
556        plt.yscale('linear')
557
558
559@protect
560def _plot_result_sesans(data,        # type: SesansData
561                        theory,      # type: Optional[np.ndarray]
562                        resid,       # type: Optional[np.ndarray]
563                        use_data,    # type: bool
564                        limits=None  # type: Optional[Tuple[float, float]]
565                       ):
566    # type: (...) -> None
567    """
568    Plot SESANS results.
569    """
570    import matplotlib.pyplot as plt  # type: ignore
571    use_data = use_data and data.y is not None
572    use_theory = theory is not None
573    use_resid = resid is not None
574    num_plots = (use_data or use_theory) + use_resid
575
576    if use_data or use_theory:
577        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
578        if num_plots > 1:
579            plt.subplot(1, num_plots, 1)
580        if use_data:
581            if is_tof:
582                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
583                             yerr=data.dy/data.y/(data.lam*data.lam))
584            else:
585                plt.errorbar(data.x, data.y, yerr=data.dy)
586        if theory is not None:
587            if is_tof:
588                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
589            else:
590                plt.plot(data.x, theory, '-')
591        if limits is not None:
592            plt.ylim(*limits)
593
594        plt.xlabel('spin echo length ({})'.format(data._xunit))
595        if is_tof:
596            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
597        else:
598            plt.ylabel('polarization (P/P0)')
599
600
601    if resid is not None:
602        if num_plots > 1:
603            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
604        plt.plot(data.x, resid, 'x')
605        plt.xlabel('spin echo length ({})'.format(data._xunit))
606        plt.ylabel('residuals (P/P0)')
607
608
609@protect
610def _plot_result2D(data,         # type: Data2D
611                   theory,       # type: Optional[np.ndarray]
612                   resid,        # type: Optional[np.ndarray]
613                   view,         # type: str
614                   use_data,     # type: bool
615                   limits=None   # type: Optional[Tuple[float, float]]
616                  ):
617    # type: (...) -> None
618    """
619    Plot the data and residuals for 2D data.
620    """
621    import matplotlib.pyplot as plt  # type: ignore
622    use_data = use_data and data.data is not None
623    use_theory = theory is not None
624    use_resid = resid is not None
625    num_plots = use_data + use_theory + use_resid
626
627    # Put theory and data on a common colormap scale
628    vmin, vmax = np.inf, -np.inf
629    target = None # type: Optional[np.ndarray]
630    if use_data:
631        target = data.data[~data.mask]
632        datamin = target[target > 0].min() if view == 'log' else target.min()
633        datamax = target.max()
634        vmin = min(vmin, datamin)
635        vmax = max(vmax, datamax)
636    if use_theory:
637        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
638        theorymax = theory.max()
639        vmin = min(vmin, theorymin)
640        vmax = max(vmax, theorymax)
641
642    # Override data limits from the caller
643    if limits is not None:
644        vmin, vmax = limits
645
646    # Plot data
647    if use_data:
648        if num_plots > 1:
649            plt.subplot(1, num_plots, 1)
650        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
651        plt.title('data')
652        h = plt.colorbar()
653        h.set_label('$I(q)$')
654
655    # plot theory
656    if use_theory:
657        if num_plots > 1:
658            plt.subplot(1, num_plots, use_data+1)
659        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
660        plt.title('theory')
661        h = plt.colorbar()
662        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
663                    else r'$q^4 I(q)$' if view == 'q4'
664                    else '$I(q)$')
665
666    # plot resid
667    if use_resid:
668        if num_plots > 1:
669            plt.subplot(1, num_plots, use_data+use_theory+1)
670        _plot_2d_signal(data, resid, view='linear')
671        plt.title('residuals')
672        h = plt.colorbar()
673        h.set_label(r'$\Delta I(q)$')
674
675
676@protect
677def _plot_2d_signal(data,       # type: Data2D
678                    signal,     # type: np.ndarray
679                    vmin=None,  # type: Optional[float]
680                    vmax=None,  # type: Optional[float]
681                    view='log'  # type: str
682                   ):
683    # type: (...) -> Tuple[float, float]
684    """
685    Plot the target value for the data.  This could be the data itself,
686    the theory calculation, or the residuals.
687
688    *scale* can be 'log' for log scale data, or 'linear'.
689    """
690    import matplotlib.pyplot as plt  # type: ignore
691    from numpy.ma import masked_array  # type: ignore
692
693    image = np.zeros_like(data.qx_data)
694    image[~data.mask] = signal
695    valid = np.isfinite(image)
696    if view == 'log':
697        valid[valid] = (image[valid] > 0)
698        if vmin is None:
699            vmin = image[valid & ~data.mask].min()
700        if vmax is None:
701            vmax = image[valid & ~data.mask].max()
702        image[valid] = np.log10(image[valid])
703    elif view == 'q4':
704        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
705        if vmin is None:
706            vmin = image[valid & ~data.mask].min()
707        if vmax is None:
708            vmax = image[valid & ~data.mask].max()
709    else:
710        if vmin is None:
711            vmin = image[valid & ~data.mask].min()
712        if vmax is None:
713            vmax = image[valid & ~data.mask].max()
714
715    image[~valid | data.mask] = 0
716    #plottable = Iq
717    plottable = masked_array(image, ~valid | data.mask)
718    # Divide range by 10 to convert from angstroms to nanometers
719    xmin, xmax = min(data.qx_data), max(data.qx_data)
720    ymin, ymax = min(data.qy_data), max(data.qy_data)
721    if view == 'log':
722        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
723    else:
724        vmin_scaled, vmax_scaled = vmin, vmax
725    #nx, ny = len(data.x_bins), len(data.y_bins)
726    x_bins, y_bins, image = _build_matrix(data, plottable)
727    plt.imshow(image,
728               interpolation='nearest', aspect=1, origin='lower',
729               extent=[xmin, xmax, ymin, ymax],
730               vmin=vmin_scaled, vmax=vmax_scaled)
731    plt.xlabel("$q_x$/A$^{-1}$")
732    plt.ylabel("$q_y$/A$^{-1}$")
733    return vmin, vmax
734
735
736# === The following is modified from sas.sasgui.plottools.PlotPanel
737def _build_matrix(self, plottable):
738    """
739    Build a matrix for 2d plot from a vector
740    Returns a matrix (image) with ~ square binning
741    Requirement: need 1d array formats of
742    self.data, self.qx_data, and self.qy_data
743    where each one corresponds to z, x, or y axis values
744
745    """
746    # No qx or qy given in a vector format
747    if self.qx_data is None or self.qy_data is None \
748            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
749        return self.x_bins, self.y_bins, plottable
750
751    # maximum # of loops to fillup_pixels
752    # otherwise, loop could never stop depending on data
753    max_loop = 1
754    # get the x and y_bin arrays.
755    x_bins, y_bins = _get_bins(self)
756    # set zero to None
757
758    #Note: Can not use scipy.interpolate.Rbf:
759    # 'cause too many data points (>10000)<=JHC.
760    # 1d array to use for weighting the data point averaging
761    #when they fall into a same bin.
762    weights_data = np.ones([self.data.size])
763    # get histogram of ones w/len(data); this will provide
764    #the weights of data on each bins
765    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
766                                             y=self.qx_data,
767                                             bins=[y_bins, x_bins],
768                                             weights=weights_data)
769    # get histogram of data, all points into a bin in a way of summing
770    image, xedges, yedges = np.histogram2d(x=self.qy_data,
771                                           y=self.qx_data,
772                                           bins=[y_bins, x_bins],
773                                           weights=plottable)
774    # Now, normalize the image by weights only for weights>1:
775    # If weight == 1, there is only one data point in the bin so
776    # that no normalization is required.
777    image[weights > 1] = image[weights > 1] / weights[weights > 1]
778    # Set image bins w/o a data point (weight==0) as None (was set to zero
779    # by histogram2d.)
780    image[weights == 0] = None
781
782    # Fill empty bins with 8 nearest neighbors only when at least
783    #one None point exists
784    loop = 0
785
786    # do while loop until all vacant bins are filled up up
787    #to loop = max_loop
788    while (weights == 0).any():
789        if loop >= max_loop:  # this protects never-ending loop
790            break
791        image = _fillup_pixels(image=image, weights=weights)
792        loop += 1
793
794    return x_bins, y_bins, image
795
796def _get_bins(self):
797    """
798    get bins
799    set x_bins and y_bins into self, 1d arrays of the index with
800    ~ square binning
801    Requirement: need 1d array formats of
802    self.qx_data, and self.qy_data
803    where each one corresponds to  x, or y axis values
804    """
805    # find max and min values of qx and qy
806    xmax = self.qx_data.max()
807    xmin = self.qx_data.min()
808    ymax = self.qy_data.max()
809    ymin = self.qy_data.min()
810
811    # calculate the range of qx and qy: this way, it is a little
812    # more independent
813    x_size = xmax - xmin
814    y_size = ymax - ymin
815
816    # estimate the # of pixels on each axes
817    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
818    npix_x = int(np.floor(len(self.qy_data) / npix_y))
819
820    # bin size: x- & y-directions
821    xstep = x_size / (npix_x - 1)
822    ystep = y_size / (npix_y - 1)
823
824    # max and min taking account of the bin sizes
825    xmax = xmax + xstep / 2.0
826    xmin = xmin - xstep / 2.0
827    ymax = ymax + ystep / 2.0
828    ymin = ymin - ystep / 2.0
829
830    # store x and y bin centers in q space
831    x_bins = np.linspace(xmin, xmax, npix_x)
832    y_bins = np.linspace(ymin, ymax, npix_y)
833
834    return x_bins, y_bins
835
836def _fillup_pixels(image=None, weights=None):
837    """
838    Fill z values of the empty cells of 2d image matrix
839    with the average over up-to next nearest neighbor points
840
841    :param image: (2d matrix with some zi = None)
842
843    :return: image (2d array )
844
845    :TODO: Find better way to do for-loop below
846
847    """
848    # No image matrix given
849    if image is None or np.ndim(image) != 2 \
850            or np.isfinite(image).all() \
851            or weights is None:
852        return image
853    # Get bin size in y and x directions
854    len_y = len(image)
855    len_x = len(image[1])
856    temp_image = np.zeros([len_y, len_x])
857    weit = np.zeros([len_y, len_x])
858    # do for-loop for all pixels
859    for n_y in range(len(image)):
860        for n_x in range(len(image[1])):
861            # find only null pixels
862            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
863                continue
864            else:
865                # find 4 nearest neighbors
866                # check where or not it is at the corner
867                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
868                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
869                    weit[n_y][n_x] += 1
870                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
871                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
872                    weit[n_y][n_x] += 1
873                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
874                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
875                    weit[n_y][n_x] += 1
876                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
877                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
878                    weit[n_y][n_x] += 1
879                # go 4 next nearest neighbors when no non-zero
880                # neighbor exists
881                if n_y != 0 and n_x != 0 and \
882                        np.isfinite(image[n_y - 1][n_x - 1]):
883                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
884                    weit[n_y][n_x] += 1
885                if n_y != len_y - 1 and n_x != 0 and \
886                        np.isfinite(image[n_y + 1][n_x - 1]):
887                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
888                    weit[n_y][n_x] += 1
889                if n_y != len_y and n_x != len_x - 1 and \
890                        np.isfinite(image[n_y - 1][n_x + 1]):
891                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
892                    weit[n_y][n_x] += 1
893                if n_y != len_y - 1 and n_x != len_x - 1 and \
894                        np.isfinite(image[n_y + 1][n_x + 1]):
895                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
896                    weit[n_y][n_x] += 1
897
898    # get it normalized
899    ind = (weit > 0)
900    image[ind] = temp_image[ind] / weit[ind]
901
902    return image
903
904
905def demo():
906    # type: () -> None
907    """
908    Load and plot a SAS dataset.
909    """
910    data = load_data('DEC07086.DAT')
911    set_beam_stop(data, 0.004)
912    plot_data(data)
913    import matplotlib.pyplot as plt  # type: ignore
914    plt.show()
915
916
917if __name__ == "__main__":
918    demo()
Note: See TracBrowser for help on using the repository browser.