source: sasmodels/sasmodels/data.py @ d86f0fc

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since d86f0fc was d86f0fc, checked in by Paul Kienzle <pkienzle@…>, 21 months ago

lint reduction

  • Property mode set to 100644
File size: 29.5 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38
39# pylint: disable=unused-import
40try:
41    from typing import Union, Dict, List, Optional
42except ImportError:
43    pass
44else:
45    Data = Union["Data1D", "Data2D", "SesansData"]
46# pylint: enable=unused-import
47
48def load_data(filename, index=0):
49    # type: (str) -> Data
50    """
51    Load data using a sasview loader.
52    """
53    from sas.sascalc.dataloader.loader import Loader  # type: ignore
54    loader = Loader()
55    # Allow for one part in multipart file
56    if '[' in filename:
57        filename, indexstr = filename[:-1].split('[')
58        index = int(indexstr)
59    datasets = loader.load(filename)
60    if not datasets:  # None or []
61        raise IOError("Data %r could not be loaded" % filename)
62    if not isinstance(datasets, list):
63        datasets = [datasets]
64    for data in datasets:
65        if hasattr(data, 'x'):
66            data.qmin, data.qmax = data.x.min(), data.x.max()
67            data.mask = (np.isnan(data.y) if data.y is not None
68                         else np.zeros_like(data.x, dtype='bool'))
69        elif hasattr(data, 'qx_data'):
70            data.mask = ~data.mask
71    return datasets[index] if index != 'all' else datasets
72
73
74def set_beam_stop(data, radius, outer=None):
75    # type: (Data, float, Optional[float]) -> None
76    """
77    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
78    """
79    from sas.sascalc.dataloader.manipulations import Ringcut
80    if hasattr(data, 'qx_data'):
81        data.mask = Ringcut(0, radius)(data)
82        if outer is not None:
83            data.mask += Ringcut(outer, np.inf)(data)
84    else:
85        data.mask = (data.x < radius)
86        if outer is not None:
87            data.mask |= (data.x >= outer)
88
89
90def set_half(data, half):
91    # type: (Data, str) -> None
92    """
93    Select half of the data, either "right" or "left".
94    """
95    from sas.sascalc.dataloader.manipulations import Boxcut
96    if half == 'right':
97        data.mask += \
98            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
99    if half == 'left':
100        data.mask += \
101            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
102
103
104def set_top(data, cutoff):
105    # type: (Data, float) -> None
106    """
107    Chop the top off the data, above *cutoff*.
108    """
109    from sas.sascalc.dataloader.manipulations import Boxcut
110    data.mask += \
111        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
112
113
114class Data1D(object):
115    """
116    1D data object.
117
118    Note that this definition matches the attributes from sasview, with
119    some generic 1D data vectors and some SAS specific definitions.  Some
120    refactoring to allow consistent naming conventions between 1D, 2D and
121    SESANS data would be helpful.
122
123    **Attributes**
124
125    *x*, *dx*: $q$ vector and gaussian resolution
126
127    *y*, *dy*: $I(q)$ vector and measurement uncertainty
128
129    *mask*: values to include in plotting/analysis
130
131    *dxl*: slit widths for slit smeared data, with *dx* ignored
132
133    *qmin*, *qmax*: range of $q$ values in *x*
134
135    *filename*: label for the data line
136
137    *_xaxis*, *_xunit*: label and units for the *x* axis
138
139    *_yaxis*, *_yunit*: label and units for the *y* axis
140    """
141    def __init__(self,
142                 x=None,  # type: Optional[np.ndarray]
143                 y=None,  # type: Optional[np.ndarray]
144                 dx=None, # type: Optional[np.ndarray]
145                 dy=None  # type: Optional[np.ndarray]
146                ):
147        # type: (...) -> None
148        self.x, self.y, self.dx, self.dy = x, y, dx, dy
149        self.dxl = None
150        self.filename = None
151        self.qmin = x.min() if x is not None else np.NaN
152        self.qmax = x.max() if x is not None else np.NaN
153        # TODO: why is 1D mask False and 2D mask True?
154        self.mask = (np.isnan(y) if y is not None
155                     else np.zeros_like(x, 'b') if x is not None
156                     else None)
157        self._xaxis, self._xunit = "x", ""
158        self._yaxis, self._yunit = "y", ""
159
160    def xaxis(self, label, unit):
161        # type: (str, str) -> None
162        """
163        set the x axis label and unit
164        """
165        self._xaxis = label
166        self._xunit = unit
167
168    def yaxis(self, label, unit):
169        # type: (str, str) -> None
170        """
171        set the y axis label and unit
172        """
173        self._yaxis = label
174        self._yunit = unit
175
176class SesansData(Data1D):
177    """
178    SESANS data object.
179
180    This is just :class:`Data1D` with a wavelength parameter.
181
182    *x* is spin echo length and *y* is polarization (P/P0).
183    """
184    def __init__(self, **kw):
185        Data1D.__init__(self, **kw)
186        self.lam = None # type: Optional[np.ndarray]
187
188class Data2D(object):
189    """
190    2D data object.
191
192    Note that this definition matches the attributes from sasview. Some
193    refactoring to allow consistent naming conventions between 1D, 2D and
194    SESANS data would be helpful.
195
196    **Attributes**
197
198    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
199
200    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
201
202    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
203
204    *mask*: values to exclude from plotting/analysis
205
206    *qmin*, *qmax*: range of $q$ values in *x*
207
208    *filename*: label for the data line
209
210    *_xaxis*, *_xunit*: label and units for the *x* axis
211
212    *_yaxis*, *_yunit*: label and units for the *y* axis
213
214    *_zaxis*, *_zunit*: label and units for the *y* axis
215
216    *Q_unit*, *I_unit*: units for Q and intensity
217
218    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
219    """
220    def __init__(self,
221                 x=None,   # type: Optional[np.ndarray]
222                 y=None,   # type: Optional[np.ndarray]
223                 z=None,   # type: Optional[np.ndarray]
224                 dx=None,  # type: Optional[np.ndarray]
225                 dy=None,  # type: Optional[np.ndarray]
226                 dz=None   # type: Optional[np.ndarray]
227                ):
228        # type: (...) -> None
229        self.qx_data, self.dqx_data = x, dx
230        self.qy_data, self.dqy_data = y, dy
231        self.data, self.err_data = z, dz
232        self.mask = (np.isnan(z) if z is not None
233                     else np.zeros_like(x, dtype='bool') if x is not None
234                     else None)
235        self.q_data = np.sqrt(x**2 + y**2)
236        self.qmin = 1e-16
237        self.qmax = np.inf
238        self.detector = []
239        self.source = Source()
240        self.Q_unit = "1/A"
241        self.I_unit = "1/cm"
242        self.xaxis("Q_x", "1/A")
243        self.yaxis("Q_y", "1/A")
244        self.zaxis("Intensity", "1/cm")
245        self._xaxis, self._xunit = "x", ""
246        self._yaxis, self._yunit = "y", ""
247        self._zaxis, self._zunit = "z", ""
248        self.x_bins, self.y_bins = None, None
249        self.filename = None
250
251    def xaxis(self, label, unit):
252        # type: (str, str) -> None
253        """
254        set the x axis label and unit
255        """
256        self._xaxis = label
257        self._xunit = unit
258
259    def yaxis(self, label, unit):
260        # type: (str, str) -> None
261        """
262        set the y axis label and unit
263        """
264        self._yaxis = label
265        self._yunit = unit
266
267    def zaxis(self, label, unit):
268        # type: (str, str) -> None
269        """
270        set the y axis label and unit
271        """
272        self._zaxis = label
273        self._zunit = unit
274
275
276class Vector(object):
277    """
278    3-space vector of *x*, *y*, *z*
279    """
280    def __init__(self, x=None, y=None, z=None):
281        # type: (float, float, Optional[float]) -> None
282        self.x, self.y, self.z = x, y, z
283
284class Detector(object):
285    """
286    Detector attributes.
287    """
288    def __init__(self, pixel_size=(None, None), distance=None):
289        # type: (Tuple[float, float], float) -> None
290        self.pixel_size = Vector(*pixel_size)
291        self.distance = distance
292
293class Source(object):
294    """
295    Beam attributes.
296    """
297    def __init__(self):
298        # type: () -> None
299        self.wavelength = np.NaN
300        self.wavelength_unit = "A"
301
302
303def empty_data1D(q, resolution=0.0):
304    # type: (np.ndarray, float) -> Data1D
305    """
306    Create empty 1D data using the given *q* as the x value.
307
308    *resolution* dq/q defaults to 5%.
309    """
310
311    #Iq = 100 * np.ones_like(q)
312    #dIq = np.sqrt(Iq)
313    Iq, dIq = None, None
314    q = np.asarray(q)
315    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
316    data.filename = "fake data"
317    return data
318
319
320def empty_data2D(qx, qy=None, resolution=0.0):
321    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
322    """
323    Create empty 2D data using the given mesh.
324
325    If *qy* is missing, create a square mesh with *qy=qx*.
326
327    *resolution* dq/q defaults to 5%.
328    """
329    if qy is None:
330        qy = qx
331    qx, qy = np.asarray(qx), np.asarray(qy)
332    # 5% dQ/Q resolution
333    Qx, Qy = np.meshgrid(qx, qy)
334    Qx, Qy = Qx.flatten(), Qy.flatten()
335    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
336    dIq = np.sqrt(Iq)
337    if resolution != 0:
338        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
339        # Should have an additional constant which depends on distances and
340        # radii of the aperture, pixel dimensions and wavelength spread
341        # Instead, assume radial dQ/Q is constant, and perpendicular matches
342        # radial (which instead it should be inverse).
343        Q = np.sqrt(Qx**2 + Qy**2)
344        dqx = resolution * Q
345        dqy = resolution * Q
346    else:
347        dqx = dqy = None
348
349    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
350    data.x_bins = qx
351    data.y_bins = qy
352    data.filename = "fake data"
353
354    # pixel_size in mm, distance in m
355    detector = Detector(pixel_size=(5, 5), distance=4)
356    data.detector.append(detector)
357    data.source.wavelength = 5 # angstroms
358    data.source.wavelength_unit = "A"
359    return data
360
361
362def plot_data(data, view='log', limits=None):
363    # type: (Data, str, Optional[Tuple[float, float]]) -> None
364    """
365    Plot data loaded by the sasview loader.
366
367    *data* is a sasview data object, either 1D, 2D or SESANS.
368
369    *view* is log or linear.
370
371    *limits* sets the intensity limits on the plot; if None then the limits
372    are inferred from the data.
373    """
374    # Note: kind of weird using the plot result functions to plot just the
375    # data, but they already handle the masking and graph markup already, so
376    # do not repeat.
377    if hasattr(data, 'isSesans') and data.isSesans:
378        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
379    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
380        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
381    else:
382        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
383
384
385def plot_theory(data,          # type: Data
386                theory,        # type: Optional[np.ndarray]
387                resid=None,    # type: Optional[np.ndarray]
388                view='log',    # type: str
389                use_data=True, # type: bool
390                limits=None,   # type: Optional[np.ndarray]
391                Iq_calc=None   # type: Optional[np.ndarray]
392               ):
393    # type: (...) -> None
394    """
395    Plot theory calculation.
396
397    *data* is needed to define the graph properties such as labels and
398    units, and to define the data mask.
399
400    *theory* is a matrix of the same shape as the data.
401
402    *view* is log or linear
403
404    *use_data* is True if the data should be plotted as well as the theory.
405
406    *limits* sets the intensity limits on the plot; if None then the limits
407    are inferred from the data.
408
409    *Iq_calc* is the raw theory values without resolution smearing
410    """
411    if hasattr(data, 'isSesans') and data.isSesans:
412        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
413    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
414        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
415    else:
416        _plot_result1D(data, theory, resid, view, use_data,
417                       limits=limits, Iq_calc=Iq_calc)
418
419
420def protect(func):
421    # type: (Callable) -> Callable
422    """
423    Decorator to wrap calls in an exception trapper which prints the
424    exception and continues.  Keyboard interrupts are ignored.
425    """
426    def wrapper(*args, **kw):
427        """
428        Trap and print errors from function.
429        """
430        try:
431            return func(*args, **kw)
432        except Exception:
433            traceback.print_exc()
434
435    return wrapper
436
437
438@protect
439def _plot_result1D(data,         # type: Data1D
440                   theory,       # type: Optional[np.ndarray]
441                   resid,        # type: Optional[np.ndarray]
442                   view,         # type: str
443                   use_data,     # type: bool
444                   limits=None,  # type: Optional[Tuple[float, float]]
445                   Iq_calc=None  # type: Optional[np.ndarray]
446                  ):
447    # type: (...) -> None
448    """
449    Plot the data and residuals for 1D data.
450    """
451    import matplotlib.pyplot as plt  # type: ignore
452    from numpy.ma import masked_array, masked  # type: ignore
453
454    if getattr(data, 'radial', False):
455        data.x = data.q_data
456        data.y = data.data
457
458    use_data = use_data and data.y is not None
459    use_theory = theory is not None
460    use_resid = resid is not None
461    use_calc = use_theory and Iq_calc is not None
462    num_plots = (use_data or use_theory) + use_calc + use_resid
463    non_positive_x = (data.x <= 0.0).any()
464
465    scale = data.x**4 if view == 'q4' else 1.0
466    xscale = yscale = 'linear' if view == 'linear' else 'log'
467
468    if use_data or use_theory:
469        if num_plots > 1:
470            plt.subplot(1, num_plots, 1)
471
472        #print(vmin, vmax)
473        all_positive = True
474        some_present = False
475        if use_data:
476            mdata = masked_array(data.y, data.mask.copy())
477            mdata[~np.isfinite(mdata)] = masked
478            if view is 'log':
479                mdata[mdata <= 0] = masked
480            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
481            all_positive = all_positive and (mdata > 0).all()
482            some_present = some_present or (mdata.count() > 0)
483
484
485        if use_theory:
486            # Note: masks merge, so any masked theory points will stay masked,
487            # and the data mask will be added to it.
488            mtheory = masked_array(theory, data.mask.copy())
489            mtheory[~np.isfinite(mtheory)] = masked
490            if view is 'log':
491                mtheory[mtheory <= 0] = masked
492            plt.plot(data.x, scale*mtheory, '-')
493            all_positive = all_positive and (mtheory > 0).all()
494            some_present = some_present or (mtheory.count() > 0)
495
496        if limits is not None:
497            plt.ylim(*limits)
498
499
500        xscale = ('linear' if not some_present or non_positive_x
501                  else view if view is not None
502                  else 'log')
503        yscale = ('linear'
504                  if view == 'q4' or not some_present or not all_positive
505                  else view if view is not None
506                  else 'log')
507        plt.xscale(xscale)
508        plt.xlabel("$q$/A$^{-1}$")
509        plt.yscale(yscale)
510        plt.ylabel('$I(q)$')
511        title = ("data and model" if use_theory and use_data
512                 else "data" if use_data
513                 else "model")
514        plt.title(title)
515
516    if use_calc:
517        # Only have use_calc if have use_theory
518        plt.subplot(1, num_plots, 2)
519        qx, qy, Iqxy = Iq_calc
520        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
521        plt.xlabel("$q_x$/A$^{-1}$")
522        plt.xlabel("$q_y$/A$^{-1}$")
523        plt.xscale('log')
524        plt.yscale('log')
525        #plt.axis('equal')
526
527    if use_resid:
528        mresid = masked_array(resid, data.mask.copy())
529        mresid[~np.isfinite(mresid)] = masked
530        some_present = (mresid.count() > 0)
531
532        if num_plots > 1:
533            plt.subplot(1, num_plots, use_calc + 2)
534        plt.plot(data.x, mresid, '.')
535        plt.xlabel("$q$/A$^{-1}$")
536        plt.ylabel('residuals')
537        plt.title('(model - Iq)/dIq')
538        plt.xscale(xscale)
539        plt.yscale('linear')
540
541
542@protect
543def _plot_result_sesans(data,        # type: SesansData
544                        theory,      # type: Optional[np.ndarray]
545                        resid,       # type: Optional[np.ndarray]
546                        use_data,    # type: bool
547                        limits=None  # type: Optional[Tuple[float, float]]
548                       ):
549    # type: (...) -> None
550    """
551    Plot SESANS results.
552    """
553    import matplotlib.pyplot as plt  # type: ignore
554    use_data = use_data and data.y is not None
555    use_theory = theory is not None
556    use_resid = resid is not None
557    num_plots = (use_data or use_theory) + use_resid
558
559    if use_data or use_theory:
560        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
561        if num_plots > 1:
562            plt.subplot(1, num_plots, 1)
563        if use_data:
564            if is_tof:
565                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
566                             yerr=data.dy/data.y/(data.lam*data.lam))
567            else:
568                plt.errorbar(data.x, data.y, yerr=data.dy)
569        if theory is not None:
570            if is_tof:
571                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
572            else:
573                plt.plot(data.x, theory, '-')
574        if limits is not None:
575            plt.ylim(*limits)
576
577        plt.xlabel('spin echo length ({})'.format(data._xunit))
578        if is_tof:
579            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
580        else:
581            plt.ylabel('polarization (P/P0)')
582
583
584    if resid is not None:
585        if num_plots > 1:
586            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
587        plt.plot(data.x, resid, 'x')
588        plt.xlabel('spin echo length ({})'.format(data._xunit))
589        plt.ylabel('residuals (P/P0)')
590
591
592@protect
593def _plot_result2D(data,         # type: Data2D
594                   theory,       # type: Optional[np.ndarray]
595                   resid,        # type: Optional[np.ndarray]
596                   view,         # type: str
597                   use_data,     # type: bool
598                   limits=None   # type: Optional[Tuple[float, float]]
599                  ):
600    # type: (...) -> None
601    """
602    Plot the data and residuals for 2D data.
603    """
604    import matplotlib.pyplot as plt  # type: ignore
605    use_data = use_data and data.data is not None
606    use_theory = theory is not None
607    use_resid = resid is not None
608    num_plots = use_data + use_theory + use_resid
609
610    # Put theory and data on a common colormap scale
611    vmin, vmax = np.inf, -np.inf
612    target = None # type: Optional[np.ndarray]
613    if use_data:
614        target = data.data[~data.mask]
615        datamin = target[target > 0].min() if view == 'log' else target.min()
616        datamax = target.max()
617        vmin = min(vmin, datamin)
618        vmax = max(vmax, datamax)
619    if use_theory:
620        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
621        theorymax = theory.max()
622        vmin = min(vmin, theorymin)
623        vmax = max(vmax, theorymax)
624
625    # Override data limits from the caller
626    if limits is not None:
627        vmin, vmax = limits
628
629    # Plot data
630    if use_data:
631        if num_plots > 1:
632            plt.subplot(1, num_plots, 1)
633        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
634        plt.title('data')
635        h = plt.colorbar()
636        h.set_label('$I(q)$')
637
638    # plot theory
639    if use_theory:
640        if num_plots > 1:
641            plt.subplot(1, num_plots, use_data+1)
642        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
643        plt.title('theory')
644        h = plt.colorbar()
645        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
646                    else r'$q^4 I(q)$' if view == 'q4'
647                    else '$I(q)$')
648
649    # plot resid
650    if use_resid:
651        if num_plots > 1:
652            plt.subplot(1, num_plots, use_data+use_theory+1)
653        _plot_2d_signal(data, resid, view='linear')
654        plt.title('residuals')
655        h = plt.colorbar()
656        h.set_label(r'$\Delta I(q)$')
657
658
659@protect
660def _plot_2d_signal(data,       # type: Data2D
661                    signal,     # type: np.ndarray
662                    vmin=None,  # type: Optional[float]
663                    vmax=None,  # type: Optional[float]
664                    view='log'  # type: str
665                   ):
666    # type: (...) -> Tuple[float, float]
667    """
668    Plot the target value for the data.  This could be the data itself,
669    the theory calculation, or the residuals.
670
671    *scale* can be 'log' for log scale data, or 'linear'.
672    """
673    import matplotlib.pyplot as plt  # type: ignore
674    from numpy.ma import masked_array  # type: ignore
675
676    image = np.zeros_like(data.qx_data)
677    image[~data.mask] = signal
678    valid = np.isfinite(image)
679    if view == 'log':
680        valid[valid] = (image[valid] > 0)
681        if vmin is None:
682            vmin = image[valid & ~data.mask].min()
683        if vmax is None:
684            vmax = image[valid & ~data.mask].max()
685        image[valid] = np.log10(image[valid])
686    elif view == 'q4':
687        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
688        if vmin is None:
689            vmin = image[valid & ~data.mask].min()
690        if vmax is None:
691            vmax = image[valid & ~data.mask].max()
692    else:
693        if vmin is None:
694            vmin = image[valid & ~data.mask].min()
695        if vmax is None:
696            vmax = image[valid & ~data.mask].max()
697
698    image[~valid | data.mask] = 0
699    #plottable = Iq
700    plottable = masked_array(image, ~valid | data.mask)
701    # Divide range by 10 to convert from angstroms to nanometers
702    xmin, xmax = min(data.qx_data), max(data.qx_data)
703    ymin, ymax = min(data.qy_data), max(data.qy_data)
704    if view == 'log':
705        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
706    else:
707        vmin_scaled, vmax_scaled = vmin, vmax
708    #nx, ny = len(data.x_bins), len(data.y_bins)
709    x_bins, y_bins, image = _build_matrix(data, plottable)
710    plt.imshow(image,
711               interpolation='nearest', aspect=1, origin='lower',
712               extent=[xmin, xmax, ymin, ymax],
713               vmin=vmin_scaled, vmax=vmax_scaled)
714    plt.xlabel("$q_x$/A$^{-1}$")
715    plt.ylabel("$q_y$/A$^{-1}$")
716    return vmin, vmax
717
718
719# === The following is modified from sas.sasgui.plottools.PlotPanel
720def _build_matrix(self, plottable):
721    """
722    Build a matrix for 2d plot from a vector
723    Returns a matrix (image) with ~ square binning
724    Requirement: need 1d array formats of
725    self.data, self.qx_data, and self.qy_data
726    where each one corresponds to z, x, or y axis values
727
728    """
729    # No qx or qy given in a vector format
730    if self.qx_data is None or self.qy_data is None \
731            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
732        return self.x_bins, self.y_bins, plottable
733
734    # maximum # of loops to fillup_pixels
735    # otherwise, loop could never stop depending on data
736    max_loop = 1
737    # get the x and y_bin arrays.
738    x_bins, y_bins = _get_bins(self)
739    # set zero to None
740
741    #Note: Can not use scipy.interpolate.Rbf:
742    # 'cause too many data points (>10000)<=JHC.
743    # 1d array to use for weighting the data point averaging
744    #when they fall into a same bin.
745    weights_data = np.ones([self.data.size])
746    # get histogram of ones w/len(data); this will provide
747    #the weights of data on each bins
748    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
749                                             y=self.qx_data,
750                                             bins=[y_bins, x_bins],
751                                             weights=weights_data)
752    # get histogram of data, all points into a bin in a way of summing
753    image, xedges, yedges = np.histogram2d(x=self.qy_data,
754                                           y=self.qx_data,
755                                           bins=[y_bins, x_bins],
756                                           weights=plottable)
757    # Now, normalize the image by weights only for weights>1:
758    # If weight == 1, there is only one data point in the bin so
759    # that no normalization is required.
760    image[weights > 1] = image[weights > 1] / weights[weights > 1]
761    # Set image bins w/o a data point (weight==0) as None (was set to zero
762    # by histogram2d.)
763    image[weights == 0] = None
764
765    # Fill empty bins with 8 nearest neighbors only when at least
766    #one None point exists
767    loop = 0
768
769    # do while loop until all vacant bins are filled up up
770    #to loop = max_loop
771    while (weights == 0).any():
772        if loop >= max_loop:  # this protects never-ending loop
773            break
774        image = _fillup_pixels(image=image, weights=weights)
775        loop += 1
776
777    return x_bins, y_bins, image
778
779def _get_bins(self):
780    """
781    get bins
782    set x_bins and y_bins into self, 1d arrays of the index with
783    ~ square binning
784    Requirement: need 1d array formats of
785    self.qx_data, and self.qy_data
786    where each one corresponds to  x, or y axis values
787    """
788    # find max and min values of qx and qy
789    xmax = self.qx_data.max()
790    xmin = self.qx_data.min()
791    ymax = self.qy_data.max()
792    ymin = self.qy_data.min()
793
794    # calculate the range of qx and qy: this way, it is a little
795    # more independent
796    x_size = xmax - xmin
797    y_size = ymax - ymin
798
799    # estimate the # of pixels on each axes
800    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
801    npix_x = int(np.floor(len(self.qy_data) / npix_y))
802
803    # bin size: x- & y-directions
804    xstep = x_size / (npix_x - 1)
805    ystep = y_size / (npix_y - 1)
806
807    # max and min taking account of the bin sizes
808    xmax = xmax + xstep / 2.0
809    xmin = xmin - xstep / 2.0
810    ymax = ymax + ystep / 2.0
811    ymin = ymin - ystep / 2.0
812
813    # store x and y bin centers in q space
814    x_bins = np.linspace(xmin, xmax, npix_x)
815    y_bins = np.linspace(ymin, ymax, npix_y)
816
817    return x_bins, y_bins
818
819def _fillup_pixels(image=None, weights=None):
820    """
821    Fill z values of the empty cells of 2d image matrix
822    with the average over up-to next nearest neighbor points
823
824    :param image: (2d matrix with some zi = None)
825
826    :return: image (2d array )
827
828    :TODO: Find better way to do for-loop below
829
830    """
831    # No image matrix given
832    if image is None or np.ndim(image) != 2 \
833            or np.isfinite(image).all() \
834            or weights is None:
835        return image
836    # Get bin size in y and x directions
837    len_y = len(image)
838    len_x = len(image[1])
839    temp_image = np.zeros([len_y, len_x])
840    weit = np.zeros([len_y, len_x])
841    # do for-loop for all pixels
842    for n_y in range(len(image)):
843        for n_x in range(len(image[1])):
844            # find only null pixels
845            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
846                continue
847            else:
848                # find 4 nearest neighbors
849                # check where or not it is at the corner
850                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
851                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
852                    weit[n_y][n_x] += 1
853                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
854                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
855                    weit[n_y][n_x] += 1
856                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
857                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
858                    weit[n_y][n_x] += 1
859                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
860                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
861                    weit[n_y][n_x] += 1
862                # go 4 next nearest neighbors when no non-zero
863                # neighbor exists
864                if n_y != 0 and n_x != 0 and \
865                        np.isfinite(image[n_y - 1][n_x - 1]):
866                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
867                    weit[n_y][n_x] += 1
868                if n_y != len_y - 1 and n_x != 0 and \
869                        np.isfinite(image[n_y + 1][n_x - 1]):
870                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
871                    weit[n_y][n_x] += 1
872                if n_y != len_y and n_x != len_x - 1 and \
873                        np.isfinite(image[n_y - 1][n_x + 1]):
874                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
875                    weit[n_y][n_x] += 1
876                if n_y != len_y - 1 and n_x != len_x - 1 and \
877                        np.isfinite(image[n_y + 1][n_x + 1]):
878                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
879                    weit[n_y][n_x] += 1
880
881    # get it normalized
882    ind = (weit > 0)
883    image[ind] = temp_image[ind] / weit[ind]
884
885    return image
886
887
888def demo():
889    # type: () -> None
890    """
891    Load and plot a SAS dataset.
892    """
893    data = load_data('DEC07086.DAT')
894    set_beam_stop(data, 0.004)
895    plot_data(data)
896    import matplotlib.pyplot as plt  # type: ignore
897    plt.show()
898
899
900if __name__ == "__main__":
901    demo()
Note: See TracBrowser for help on using the repository browser.