source: sasmodels/sasmodels/data.py @ b297ba9

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since b297ba9 was b297ba9, checked in by Paul Kienzle <pkienzle@…>, 9 months ago

lint

  • Property mode set to 100644
File size: 30.5 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38from numpy import sqrt, sin, cos, pi
39
40# pylint: disable=unused-import
41try:
42    from typing import Union, Dict, List, Optional
43except ImportError:
44    pass
45else:
46    Data = Union["Data1D", "Data2D", "SesansData"]
47# pylint: enable=unused-import
48
49def load_data(filename, index=0):
50    # type: (str) -> Data
51    """
52    Load data using a sasview loader.
53    """
54    from sas.sascalc.dataloader.loader import Loader  # type: ignore
55    loader = Loader()
56    # Allow for one part in multipart file
57    if '[' in filename:
58        filename, indexstr = filename[:-1].split('[')
59        index = int(indexstr)
60    datasets = loader.load(filename)
61    if not datasets:  # None or []
62        raise IOError("Data %r could not be loaded" % filename)
63    if not isinstance(datasets, list):
64        datasets = [datasets]
65    for data in datasets:
66        if hasattr(data, 'x'):
67            data.qmin, data.qmax = data.x.min(), data.x.max()
68            data.mask = (np.isnan(data.y) if data.y is not None
69                         else np.zeros_like(data.x, dtype='bool'))
70        elif hasattr(data, 'qx_data'):
71            data.mask = ~data.mask
72    return datasets[index] if index != 'all' else datasets
73
74
75def set_beam_stop(data, radius, outer=None):
76    # type: (Data, float, Optional[float]) -> None
77    """
78    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
79    """
80    from sas.sascalc.dataloader.manipulations import Ringcut
81    if hasattr(data, 'qx_data'):
82        data.mask = Ringcut(0, radius)(data)
83        if outer is not None:
84            data.mask += Ringcut(outer, np.inf)(data)
85    else:
86        data.mask = (data.x < radius)
87        if outer is not None:
88            data.mask |= (data.x >= outer)
89
90
91def set_half(data, half):
92    # type: (Data, str) -> None
93    """
94    Select half of the data, either "right" or "left".
95    """
96    from sas.sascalc.dataloader.manipulations import Boxcut
97    if half == 'right':
98        data.mask += \
99            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
100    if half == 'left':
101        data.mask += \
102            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
103
104
105def set_top(data, cutoff):
106    # type: (Data, float) -> None
107    """
108    Chop the top off the data, above *cutoff*.
109    """
110    from sas.sascalc.dataloader.manipulations import Boxcut
111    data.mask += \
112        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
113
114
115class Data1D(object):
116    """
117    1D data object.
118
119    Note that this definition matches the attributes from sasview, with
120    some generic 1D data vectors and some SAS specific definitions.  Some
121    refactoring to allow consistent naming conventions between 1D, 2D and
122    SESANS data would be helpful.
123
124    **Attributes**
125
126    *x*, *dx*: $q$ vector and gaussian resolution
127
128    *y*, *dy*: $I(q)$ vector and measurement uncertainty
129
130    *mask*: values to include in plotting/analysis
131
132    *dxl*: slit widths for slit smeared data, with *dx* ignored
133
134    *qmin*, *qmax*: range of $q$ values in *x*
135
136    *filename*: label for the data line
137
138    *_xaxis*, *_xunit*: label and units for the *x* axis
139
140    *_yaxis*, *_yunit*: label and units for the *y* axis
141    """
142    def __init__(self,
143                 x=None,  # type: Optional[np.ndarray]
144                 y=None,  # type: Optional[np.ndarray]
145                 dx=None, # type: Optional[np.ndarray]
146                 dy=None  # type: Optional[np.ndarray]
147                ):
148        # type: (...) -> None
149        self.x, self.y, self.dx, self.dy = x, y, dx, dy
150        self.dxl = None
151        self.filename = None
152        self.qmin = x.min() if x is not None else np.NaN
153        self.qmax = x.max() if x is not None else np.NaN
154        # TODO: why is 1D mask False and 2D mask True?
155        self.mask = (np.isnan(y) if y is not None
156                     else np.zeros_like(x, 'b') if x is not None
157                     else None)
158        self._xaxis, self._xunit = "x", ""
159        self._yaxis, self._yunit = "y", ""
160
161    def xaxis(self, label, unit):
162        # type: (str, str) -> None
163        """
164        set the x axis label and unit
165        """
166        self._xaxis = label
167        self._xunit = unit
168
169    def yaxis(self, label, unit):
170        # type: (str, str) -> None
171        """
172        set the y axis label and unit
173        """
174        self._yaxis = label
175        self._yunit = unit
176
177class SesansData(Data1D):
178    """
179    SESANS data object.
180
181    This is just :class:`Data1D` with a wavelength parameter.
182
183    *x* is spin echo length and *y* is polarization (P/P0).
184    """
185    isSesans = True
186    def __init__(self, **kw):
187        Data1D.__init__(self, **kw)
188        self.lam = None # type: Optional[np.ndarray]
189
190class Data2D(object):
191    """
192    2D data object.
193
194    Note that this definition matches the attributes from sasview. Some
195    refactoring to allow consistent naming conventions between 1D, 2D and
196    SESANS data would be helpful.
197
198    **Attributes**
199
200    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
201
202    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
203
204    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
205
206    *mask*: values to exclude from plotting/analysis
207
208    *qmin*, *qmax*: range of $q$ values in *x*
209
210    *filename*: label for the data line
211
212    *_xaxis*, *_xunit*: label and units for the *x* axis
213
214    *_yaxis*, *_yunit*: label and units for the *y* axis
215
216    *_zaxis*, *_zunit*: label and units for the *y* axis
217
218    *Q_unit*, *I_unit*: units for Q and intensity
219
220    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
221    """
222    def __init__(self,
223                 x=None,   # type: Optional[np.ndarray]
224                 y=None,   # type: Optional[np.ndarray]
225                 z=None,   # type: Optional[np.ndarray]
226                 dx=None,  # type: Optional[np.ndarray]
227                 dy=None,  # type: Optional[np.ndarray]
228                 dz=None   # type: Optional[np.ndarray]
229                ):
230        # type: (...) -> None
231        self.qx_data, self.dqx_data = x, dx
232        self.qy_data, self.dqy_data = y, dy
233        self.data, self.err_data = z, dz
234        self.mask = (np.isnan(z) if z is not None
235                     else np.zeros_like(x, dtype='bool') if x is not None
236                     else None)
237        self.q_data = np.sqrt(x**2 + y**2)
238        self.qmin = 1e-16
239        self.qmax = np.inf
240        self.detector = []
241        self.source = Source()
242        self.Q_unit = "1/A"
243        self.I_unit = "1/cm"
244        self.xaxis("Q_x", "1/A")
245        self.yaxis("Q_y", "1/A")
246        self.zaxis("Intensity", "1/cm")
247        self._xaxis, self._xunit = "x", ""
248        self._yaxis, self._yunit = "y", ""
249        self._zaxis, self._zunit = "z", ""
250        self.x_bins, self.y_bins = None, None
251        self.filename = None
252
253    def xaxis(self, label, unit):
254        # type: (str, str) -> None
255        """
256        set the x axis label and unit
257        """
258        self._xaxis = label
259        self._xunit = unit
260
261    def yaxis(self, label, unit):
262        # type: (str, str) -> None
263        """
264        set the y axis label and unit
265        """
266        self._yaxis = label
267        self._yunit = unit
268
269    def zaxis(self, label, unit):
270        # type: (str, str) -> None
271        """
272        set the y axis label and unit
273        """
274        self._zaxis = label
275        self._zunit = unit
276
277
278class Vector(object):
279    """
280    3-space vector of *x*, *y*, *z*
281    """
282    def __init__(self, x=None, y=None, z=None):
283        # type: (float, float, Optional[float]) -> None
284        self.x, self.y, self.z = x, y, z
285
286class Detector(object):
287    """
288    Detector attributes.
289    """
290    def __init__(self, pixel_size=(None, None), distance=None):
291        # type: (Tuple[float, float], float) -> None
292        self.pixel_size = Vector(*pixel_size)
293        self.distance = distance
294
295class Source(object):
296    """
297    Beam attributes.
298    """
299    def __init__(self):
300        # type: () -> None
301        self.wavelength = np.NaN
302        self.wavelength_unit = "A"
303
304class Sample(object):
305    """
306    Sample attributes.
307    """
308    def __init__(self):
309        # type: () -> None
310        pass
311
312def empty_data1D(q, resolution=0.0, L=0., dL=0.):
313    # type: (np.ndarray, float) -> Data1D
314    r"""
315    Create empty 1D data using the given *q* as the x value.
316
317    rms *resolution* $\Delta q/q$ defaults to 0%.  If wavelength *L* and rms
318    wavelength divergence *dL* are defined, then *resolution* defines
319    rms $\Delta \theta/\theta$ for the lowest *q*, with $\theta$ derived from
320    $q = 4\pi/\lambda \sin(\theta)$.
321    """
322
323    #Iq = 100 * np.ones_like(q)
324    #dIq = np.sqrt(Iq)
325    Iq, dIq = None, None
326    q = np.asarray(q)
327    if L != 0 and resolution != 0:
328        theta = np.arcsin(q*L/(4*pi))
329        dtheta = theta[0]*resolution
330        ## Solving Gaussian error propagation from
331        ##   Dq^2 = (dq/dL)^2 DL^2 + (dq/dtheta)^2 Dtheta^2
332        ## gives
333        ##   (Dq/q)^2 = (DL/L)**2 + (Dtheta/tan(theta))**2
334        ## Take the square root and multiply by q, giving
335        ##   Dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
336        dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
337    else:
338        dq = resolution * q
339    data = Data1D(q, Iq, dx=dq, dy=dIq)
340    data.filename = "fake data"
341    return data
342
343
344def empty_data2D(qx, qy=None, resolution=0.0):
345    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
346    """
347    Create empty 2D data using the given mesh.
348
349    If *qy* is missing, create a square mesh with *qy=qx*.
350
351    *resolution* dq/q defaults to 5%.
352    """
353    if qy is None:
354        qy = qx
355    qx, qy = np.asarray(qx), np.asarray(qy)
356    # 5% dQ/Q resolution
357    Qx, Qy = np.meshgrid(qx, qy)
358    Qx, Qy = Qx.flatten(), Qy.flatten()
359    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
360    dIq = np.sqrt(Iq)
361    if resolution != 0:
362        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
363        # Should have an additional constant which depends on distances and
364        # radii of the aperture, pixel dimensions and wavelength spread
365        # Instead, assume radial dQ/Q is constant, and perpendicular matches
366        # radial (which instead it should be inverse).
367        Q = np.sqrt(Qx**2 + Qy**2)
368        dqx = resolution * Q
369        dqy = resolution * Q
370    else:
371        dqx = dqy = None
372
373    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
374    data.x_bins = qx
375    data.y_bins = qy
376    data.filename = "fake data"
377
378    # pixel_size in mm, distance in m
379    detector = Detector(pixel_size=(5, 5), distance=4)
380    data.detector.append(detector)
381    data.source.wavelength = 5 # angstroms
382    data.source.wavelength_unit = "A"
383    return data
384
385
386def plot_data(data, view='log', limits=None):
387    # type: (Data, str, Optional[Tuple[float, float]]) -> None
388    """
389    Plot data loaded by the sasview loader.
390
391    *data* is a sasview data object, either 1D, 2D or SESANS.
392
393    *view* is log or linear.
394
395    *limits* sets the intensity limits on the plot; if None then the limits
396    are inferred from the data.
397    """
398    # Note: kind of weird using the plot result functions to plot just the
399    # data, but they already handle the masking and graph markup already, so
400    # do not repeat.
401    if hasattr(data, 'isSesans') and data.isSesans:
402        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
403    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
404        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
405    else:
406        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
407
408
409def plot_theory(data,          # type: Data
410                theory,        # type: Optional[np.ndarray]
411                resid=None,    # type: Optional[np.ndarray]
412                view='log',    # type: str
413                use_data=True, # type: bool
414                limits=None,   # type: Optional[np.ndarray]
415                Iq_calc=None   # type: Optional[np.ndarray]
416               ):
417    # type: (...) -> None
418    """
419    Plot theory calculation.
420
421    *data* is needed to define the graph properties such as labels and
422    units, and to define the data mask.
423
424    *theory* is a matrix of the same shape as the data.
425
426    *view* is log or linear
427
428    *use_data* is True if the data should be plotted as well as the theory.
429
430    *limits* sets the intensity limits on the plot; if None then the limits
431    are inferred from the data.
432
433    *Iq_calc* is the raw theory values without resolution smearing
434    """
435    if hasattr(data, 'isSesans') and data.isSesans:
436        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
437    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
438        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
439    else:
440        _plot_result1D(data, theory, resid, view, use_data,
441                       limits=limits, Iq_calc=Iq_calc)
442
443
444def protect(func):
445    # type: (Callable) -> Callable
446    """
447    Decorator to wrap calls in an exception trapper which prints the
448    exception and continues.  Keyboard interrupts are ignored.
449    """
450    def wrapper(*args, **kw):
451        """
452        Trap and print errors from function.
453        """
454        try:
455            return func(*args, **kw)
456        except Exception:
457            traceback.print_exc()
458
459    return wrapper
460
461
462@protect
463def _plot_result1D(data,         # type: Data1D
464                   theory,       # type: Optional[np.ndarray]
465                   resid,        # type: Optional[np.ndarray]
466                   view,         # type: str
467                   use_data,     # type: bool
468                   limits=None,  # type: Optional[Tuple[float, float]]
469                   Iq_calc=None  # type: Optional[np.ndarray]
470                  ):
471    # type: (...) -> None
472    """
473    Plot the data and residuals for 1D data.
474    """
475    import matplotlib.pyplot as plt  # type: ignore
476    from numpy.ma import masked_array, masked  # type: ignore
477
478    if getattr(data, 'radial', False):
479        data.x = data.q_data
480        data.y = data.data
481
482    use_data = use_data and data.y is not None
483    use_theory = theory is not None
484    use_resid = resid is not None
485    use_calc = use_theory and Iq_calc is not None
486    num_plots = (use_data or use_theory) + use_calc + use_resid
487    non_positive_x = (data.x <= 0.0).any()
488
489    scale = data.x**4 if view == 'q4' else 1.0
490    xscale = yscale = 'linear' if view == 'linear' else 'log'
491
492    if use_data or use_theory:
493        if num_plots > 1:
494            plt.subplot(1, num_plots, 1)
495
496        #print(vmin, vmax)
497        all_positive = True
498        some_present = False
499        if use_data:
500            mdata = masked_array(data.y, data.mask.copy())
501            mdata[~np.isfinite(mdata)] = masked
502            if view == 'log':
503                mdata[mdata <= 0] = masked
504            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
505            all_positive = all_positive and (mdata > 0).all()
506            some_present = some_present or (mdata.count() > 0)
507
508
509        if use_theory:
510            # Note: masks merge, so any masked theory points will stay masked,
511            # and the data mask will be added to it.
512            #mtheory = masked_array(theory, data.mask.copy())
513            theory_x = data.x[data.mask == 0]
514            mtheory = masked_array(theory)
515            mtheory[~np.isfinite(mtheory)] = masked
516            if view == 'log':
517                mtheory[mtheory <= 0] = masked
518            plt.plot(theory_x, scale*mtheory, '-')
519            all_positive = all_positive and (mtheory > 0).all()
520            some_present = some_present or (mtheory.count() > 0)
521
522        if limits is not None:
523            plt.ylim(*limits)
524
525
526        xscale = ('linear' if not some_present or non_positive_x
527                  else view if view is not None
528                  else 'log')
529        yscale = ('linear'
530                  if view == 'q4' or not some_present or not all_positive
531                  else view if view is not None
532                  else 'log')
533        plt.xscale(xscale)
534        plt.xlabel("$q$/A$^{-1}$")
535        plt.yscale(yscale)
536        plt.ylabel('$I(q)$')
537        title = ("data and model" if use_theory and use_data
538                 else "data" if use_data
539                 else "model")
540        plt.title(title)
541
542    if use_calc:
543        # Only have use_calc if have use_theory
544        plt.subplot(1, num_plots, 2)
545        qx, qy, Iqxy = Iq_calc
546        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
547        plt.xlabel("$q_x$/A$^{-1}$")
548        plt.xlabel("$q_y$/A$^{-1}$")
549        plt.xscale('log')
550        plt.yscale('log')
551        #plt.axis('equal')
552
553    if use_resid:
554        theory_x = data.x[data.mask == 0]
555        mresid = masked_array(resid)
556        mresid[~np.isfinite(mresid)] = masked
557        some_present = (mresid.count() > 0)
558
559        if num_plots > 1:
560            plt.subplot(1, num_plots, use_calc + 2)
561        plt.plot(theory_x, mresid, '.')
562        plt.xlabel("$q$/A$^{-1}$")
563        plt.ylabel('residuals')
564        plt.title('(model - Iq)/dIq')
565        plt.xscale(xscale)
566        plt.yscale('linear')
567
568
569@protect
570def _plot_result_sesans(data,        # type: SesansData
571                        theory,      # type: Optional[np.ndarray]
572                        resid,       # type: Optional[np.ndarray]
573                        use_data,    # type: bool
574                        limits=None  # type: Optional[Tuple[float, float]]
575                       ):
576    # type: (...) -> None
577    """
578    Plot SESANS results.
579    """
580    import matplotlib.pyplot as plt  # type: ignore
581    use_data = use_data and data.y is not None
582    use_theory = theory is not None
583    use_resid = resid is not None
584    num_plots = (use_data or use_theory) + use_resid
585
586    if use_data or use_theory:
587        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
588        if num_plots > 1:
589            plt.subplot(1, num_plots, 1)
590        if use_data:
591            if is_tof:
592                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
593                             yerr=data.dy/data.y/(data.lam*data.lam))
594            else:
595                plt.errorbar(data.x, data.y, yerr=data.dy)
596        if theory is not None:
597            if is_tof:
598                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
599            else:
600                plt.plot(data.x, theory, '-')
601        if limits is not None:
602            plt.ylim(*limits)
603
604        plt.xlabel('spin echo length ({})'.format(data._xunit))
605        if is_tof:
606            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
607        else:
608            plt.ylabel('polarization (P/P0)')
609
610
611    if resid is not None:
612        if num_plots > 1:
613            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
614        plt.plot(data.x, resid, 'x')
615        plt.xlabel('spin echo length ({})'.format(data._xunit))
616        plt.ylabel('residuals (P/P0)')
617
618
619@protect
620def _plot_result2D(data,         # type: Data2D
621                   theory,       # type: Optional[np.ndarray]
622                   resid,        # type: Optional[np.ndarray]
623                   view,         # type: str
624                   use_data,     # type: bool
625                   limits=None   # type: Optional[Tuple[float, float]]
626                  ):
627    # type: (...) -> None
628    """
629    Plot the data and residuals for 2D data.
630    """
631    import matplotlib.pyplot as plt  # type: ignore
632    use_data = use_data and data.data is not None
633    use_theory = theory is not None
634    use_resid = resid is not None
635    num_plots = use_data + use_theory + use_resid
636
637    # Put theory and data on a common colormap scale
638    vmin, vmax = np.inf, -np.inf
639    target = None # type: Optional[np.ndarray]
640    if use_data:
641        target = data.data[~data.mask]
642        datamin = target[target > 0].min() if view == 'log' else target.min()
643        datamax = target.max()
644        vmin = min(vmin, datamin)
645        vmax = max(vmax, datamax)
646    if use_theory:
647        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
648        theorymax = theory.max()
649        vmin = min(vmin, theorymin)
650        vmax = max(vmax, theorymax)
651
652    # Override data limits from the caller
653    if limits is not None:
654        vmin, vmax = limits
655
656    # Plot data
657    if use_data:
658        if num_plots > 1:
659            plt.subplot(1, num_plots, 1)
660        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
661        plt.title('data')
662        h = plt.colorbar()
663        h.set_label('$I(q)$')
664
665    # plot theory
666    if use_theory:
667        if num_plots > 1:
668            plt.subplot(1, num_plots, use_data+1)
669        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
670        plt.title('theory')
671        h = plt.colorbar()
672        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
673                    else r'$q^4 I(q)$' if view == 'q4'
674                    else '$I(q)$')
675
676    # plot resid
677    if use_resid:
678        if num_plots > 1:
679            plt.subplot(1, num_plots, use_data+use_theory+1)
680        _plot_2d_signal(data, resid, view='linear')
681        plt.title('residuals')
682        h = plt.colorbar()
683        h.set_label(r'$\Delta I(q)$')
684
685
686@protect
687def _plot_2d_signal(data,       # type: Data2D
688                    signal,     # type: np.ndarray
689                    vmin=None,  # type: Optional[float]
690                    vmax=None,  # type: Optional[float]
691                    view='log'  # type: str
692                   ):
693    # type: (...) -> Tuple[float, float]
694    """
695    Plot the target value for the data.  This could be the data itself,
696    the theory calculation, or the residuals.
697
698    *scale* can be 'log' for log scale data, or 'linear'.
699    """
700    import matplotlib.pyplot as plt  # type: ignore
701    from numpy.ma import masked_array  # type: ignore
702
703    image = np.zeros_like(data.qx_data)
704    image[~data.mask] = signal
705    valid = np.isfinite(image)
706    if view == 'log':
707        valid[valid] = (image[valid] > 0)
708        if vmin is None:
709            vmin = image[valid & ~data.mask].min()
710        if vmax is None:
711            vmax = image[valid & ~data.mask].max()
712        image[valid] = np.log10(image[valid])
713    elif view == 'q4':
714        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
715        if vmin is None:
716            vmin = image[valid & ~data.mask].min()
717        if vmax is None:
718            vmax = image[valid & ~data.mask].max()
719    else:
720        if vmin is None:
721            vmin = image[valid & ~data.mask].min()
722        if vmax is None:
723            vmax = image[valid & ~data.mask].max()
724
725    image[~valid | data.mask] = 0
726    #plottable = Iq
727    plottable = masked_array(image, ~valid | data.mask)
728    # Divide range by 10 to convert from angstroms to nanometers
729    xmin, xmax = min(data.qx_data), max(data.qx_data)
730    ymin, ymax = min(data.qy_data), max(data.qy_data)
731    if view == 'log':
732        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
733    else:
734        vmin_scaled, vmax_scaled = vmin, vmax
735    #nx, ny = len(data.x_bins), len(data.y_bins)
736    x_bins, y_bins, image = _build_matrix(data, plottable)
737    plt.imshow(image,
738               interpolation='nearest', aspect=1, origin='lower',
739               extent=[xmin, xmax, ymin, ymax],
740               vmin=vmin_scaled, vmax=vmax_scaled)
741    plt.xlabel("$q_x$/A$^{-1}$")
742    plt.ylabel("$q_y$/A$^{-1}$")
743    return vmin, vmax
744
745
746# === The following is modified from sas.sasgui.plottools.PlotPanel
747def _build_matrix(self, plottable):
748    """
749    Build a matrix for 2d plot from a vector
750    Returns a matrix (image) with ~ square binning
751    Requirement: need 1d array formats of
752    self.data, self.qx_data, and self.qy_data
753    where each one corresponds to z, x, or y axis values
754
755    """
756    # No qx or qy given in a vector format
757    if self.qx_data is None or self.qy_data is None \
758            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
759        return self.x_bins, self.y_bins, plottable
760
761    # maximum # of loops to fillup_pixels
762    # otherwise, loop could never stop depending on data
763    max_loop = 1
764    # get the x and y_bin arrays.
765    x_bins, y_bins = _get_bins(self)
766    # set zero to None
767
768    #Note: Can not use scipy.interpolate.Rbf:
769    # 'cause too many data points (>10000)<=JHC.
770    # 1d array to use for weighting the data point averaging
771    #when they fall into a same bin.
772    weights_data = np.ones([self.data.size])
773    # get histogram of ones w/len(data); this will provide
774    #the weights of data on each bins
775    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
776                                             y=self.qx_data,
777                                             bins=[y_bins, x_bins],
778                                             weights=weights_data)
779    # get histogram of data, all points into a bin in a way of summing
780    image, xedges, yedges = np.histogram2d(x=self.qy_data,
781                                           y=self.qx_data,
782                                           bins=[y_bins, x_bins],
783                                           weights=plottable)
784    # Now, normalize the image by weights only for weights>1:
785    # If weight == 1, there is only one data point in the bin so
786    # that no normalization is required.
787    image[weights > 1] = image[weights > 1] / weights[weights > 1]
788    # Set image bins w/o a data point (weight==0) as None (was set to zero
789    # by histogram2d.)
790    image[weights == 0] = None
791
792    # Fill empty bins with 8 nearest neighbors only when at least
793    #one None point exists
794    loop = 0
795
796    # do while loop until all vacant bins are filled up up
797    #to loop = max_loop
798    while (weights == 0).any():
799        if loop >= max_loop:  # this protects never-ending loop
800            break
801        image = _fillup_pixels(image=image, weights=weights)
802        loop += 1
803
804    return x_bins, y_bins, image
805
806def _get_bins(self):
807    """
808    get bins
809    set x_bins and y_bins into self, 1d arrays of the index with
810    ~ square binning
811    Requirement: need 1d array formats of
812    self.qx_data, and self.qy_data
813    where each one corresponds to  x, or y axis values
814    """
815    # find max and min values of qx and qy
816    xmax = self.qx_data.max()
817    xmin = self.qx_data.min()
818    ymax = self.qy_data.max()
819    ymin = self.qy_data.min()
820
821    # calculate the range of qx and qy: this way, it is a little
822    # more independent
823    x_size = xmax - xmin
824    y_size = ymax - ymin
825
826    # estimate the # of pixels on each axes
827    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
828    npix_x = int(np.floor(len(self.qy_data) / npix_y))
829
830    # bin size: x- & y-directions
831    xstep = x_size / (npix_x - 1)
832    ystep = y_size / (npix_y - 1)
833
834    # max and min taking account of the bin sizes
835    xmax = xmax + xstep / 2.0
836    xmin = xmin - xstep / 2.0
837    ymax = ymax + ystep / 2.0
838    ymin = ymin - ystep / 2.0
839
840    # store x and y bin centers in q space
841    x_bins = np.linspace(xmin, xmax, npix_x)
842    y_bins = np.linspace(ymin, ymax, npix_y)
843
844    return x_bins, y_bins
845
846def _fillup_pixels(image=None, weights=None):
847    """
848    Fill z values of the empty cells of 2d image matrix
849    with the average over up-to next nearest neighbor points
850
851    :param image: (2d matrix with some zi = None)
852
853    :return: image (2d array )
854
855    :TODO: Find better way to do for-loop below
856
857    """
858    # No image matrix given
859    if image is None or np.ndim(image) != 2 \
860            or np.isfinite(image).all() \
861            or weights is None:
862        return image
863    # Get bin size in y and x directions
864    len_y = len(image)
865    len_x = len(image[1])
866    temp_image = np.zeros([len_y, len_x])
867    weit = np.zeros([len_y, len_x])
868    # do for-loop for all pixels
869    for n_y in range(len(image)):
870        for n_x in range(len(image[1])):
871            # find only null pixels
872            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
873                continue
874            else:
875                # find 4 nearest neighbors
876                # check where or not it is at the corner
877                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
878                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
879                    weit[n_y][n_x] += 1
880                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
881                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
882                    weit[n_y][n_x] += 1
883                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
884                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
885                    weit[n_y][n_x] += 1
886                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
887                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
888                    weit[n_y][n_x] += 1
889                # go 4 next nearest neighbors when no non-zero
890                # neighbor exists
891                if n_y != 0 and n_x != 0 and \
892                        np.isfinite(image[n_y - 1][n_x - 1]):
893                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
894                    weit[n_y][n_x] += 1
895                if n_y != len_y - 1 and n_x != 0 and \
896                        np.isfinite(image[n_y + 1][n_x - 1]):
897                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
898                    weit[n_y][n_x] += 1
899                if n_y != len_y and n_x != len_x - 1 and \
900                        np.isfinite(image[n_y - 1][n_x + 1]):
901                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
902                    weit[n_y][n_x] += 1
903                if n_y != len_y - 1 and n_x != len_x - 1 and \
904                        np.isfinite(image[n_y + 1][n_x + 1]):
905                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
906                    weit[n_y][n_x] += 1
907
908    # get it normalized
909    ind = (weit > 0)
910    image[ind] = temp_image[ind] / weit[ind]
911
912    return image
913
914
915def demo():
916    # type: () -> None
917    """
918    Load and plot a SAS dataset.
919    """
920    data = load_data('DEC07086.DAT')
921    set_beam_stop(data, 0.004)
922    plot_data(data)
923    import matplotlib.pyplot as plt  # type: ignore
924    plt.show()
925
926
927if __name__ == "__main__":
928    demo()
Note: See TracBrowser for help on using the repository browser.