source: sasmodels/sasmodels/data.py @ bd7630d

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since bd7630d was bd7630d, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

update scripting docs. Refs #1141

  • Property mode set to 100644
File size: 30.5 KB
Line 
1"""
2SAS data representations.
3
4Plotting functions for data sets:
5
6    :func:`plot_data` plots the data file.
7
8    :func:`plot_theory` plots a calculated result from the model.
9
10Wrappers for the sasview data loader and data manipulations:
11
12    :func:`load_data` loads a sasview data file.
13
14    :func:`set_beam_stop` masks the beam stop from the data.
15
16    :func:`set_half` selects the right or left half of the data, which can
17    be useful for shear measurements which have not been properly corrected
18    for path length and reflections.
19
20    :func:`set_top` cuts the top part off the data.
21
22
23Empty data sets for evaluating models without data:
24
25    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
26    a theory function before the data is measured.
27
28    :func:`empty_data2D` creates an empty 2D dataset.
29
30Note that the empty datasets use a minimal representation of the SasView
31objects so that models can be run without SasView on the path.  You could
32also use these for your own data loader.
33
34"""
35import traceback
36
37import numpy as np  # type: ignore
38from numpy import sqrt, sin, cos, pi
39
40# pylint: disable=unused-import
41try:
42    from typing import Union, Dict, List, Optional
43except ImportError:
44    pass
45else:
46    Data = Union["Data1D", "Data2D", "SesansData"]
47# pylint: enable=unused-import
48
49def load_data(filename, index=0):
50    # type: (str) -> Data
51    """
52    Load data using a sasview loader.
53    """
54    from sas.sascalc.dataloader.loader import Loader  # type: ignore
55    loader = Loader()
56    # Allow for one part in multipart file
57    if '[' in filename:
58        filename, indexstr = filename[:-1].split('[')
59        index = int(indexstr)
60    datasets = loader.load(filename)
61    if not datasets:  # None or []
62        raise IOError("Data %r could not be loaded" % filename)
63    if not isinstance(datasets, list):
64        datasets = [datasets]
65    for data in datasets:
66        if hasattr(data, 'x'):
67            data.qmin, data.qmax = data.x.min(), data.x.max()
68            data.mask = (np.isnan(data.y) if data.y is not None
69                         else np.zeros_like(data.x, dtype='bool'))
70        elif hasattr(data, 'qx_data'):
71            data.mask = ~data.mask
72    return datasets[index] if index != 'all' else datasets
73
74
75def set_beam_stop(data, radius, outer=None):
76    # type: (Data, float, Optional[float]) -> None
77    """
78    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
79    """
80    from sas.sascalc.dataloader.manipulations import Ringcut
81    if hasattr(data, 'qx_data'):
82        data.mask = Ringcut(0, radius)(data)
83        if outer is not None:
84            data.mask += Ringcut(outer, np.inf)(data)
85    else:
86        data.mask = (data.x < radius)
87        if outer is not None:
88            data.mask |= (data.x >= outer)
89
90
91def set_half(data, half):
92    # type: (Data, str) -> None
93    """
94    Select half of the data, either "right" or "left".
95    """
96    from sas.sascalc.dataloader.manipulations import Boxcut
97    if half == 'right':
98        data.mask += \
99            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
100    if half == 'left':
101        data.mask += \
102            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
103
104
105def set_top(data, cutoff):
106    # type: (Data, float) -> None
107    """
108    Chop the top off the data, above *cutoff*.
109    """
110    from sas.sascalc.dataloader.manipulations import Boxcut
111    data.mask += \
112        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
113
114
115class Data1D(object):
116    """
117    1D data object.
118
119    Note that this definition matches the attributes from sasview, with
120    some generic 1D data vectors and some SAS specific definitions.  Some
121    refactoring to allow consistent naming conventions between 1D, 2D and
122    SESANS data would be helpful.
123
124    **Attributes**
125
126    *x*, *dx*: $q$ vector and gaussian resolution
127
128    *y*, *dy*: $I(q)$ vector and measurement uncertainty
129
130    *mask*: values to include in plotting/analysis
131
132    *dxl*: slit widths for slit smeared data, with *dx* ignored
133
134    *qmin*, *qmax*: range of $q$ values in *x*
135
136    *filename*: label for the data line
137
138    *_xaxis*, *_xunit*: label and units for the *x* axis
139
140    *_yaxis*, *_yunit*: label and units for the *y* axis
141    """
142    def __init__(self,
143                 x=None,  # type: Optional[np.ndarray]
144                 y=None,  # type: Optional[np.ndarray]
145                 dx=None, # type: Optional[np.ndarray]
146                 dy=None  # type: Optional[np.ndarray]
147                ):
148        # type: (...) -> None
149        self.x, self.y, self.dx, self.dy = x, y, dx, dy
150        self.dxl = None
151        self.filename = None
152        self.qmin = x.min() if x is not None else np.NaN
153        self.qmax = x.max() if x is not None else np.NaN
154        # TODO: why is 1D mask False and 2D mask True?
155        self.mask = (np.isnan(y) if y is not None
156                     else np.zeros_like(x, 'b') if x is not None
157                     else None)
158        self._xaxis, self._xunit = "x", ""
159        self._yaxis, self._yunit = "y", ""
160
161    def xaxis(self, label, unit):
162        # type: (str, str) -> None
163        """
164        set the x axis label and unit
165        """
166        self._xaxis = label
167        self._xunit = unit
168
169    def yaxis(self, label, unit):
170        # type: (str, str) -> None
171        """
172        set the y axis label and unit
173        """
174        self._yaxis = label
175        self._yunit = unit
176
177class SesansData(Data1D):
178    """
179    SESANS data object.
180
181    This is just :class:`Data1D` with a wavelength parameter.
182
183    *x* is spin echo length and *y* is polarization (P/P0).
184    """
185    isSesans = True
186    def __init__(self, **kw):
187        Data1D.__init__(self, **kw)
188        self.lam = None # type: Optional[np.ndarray]
189
190class Data2D(object):
191    """
192    2D data object.
193
194    Note that this definition matches the attributes from sasview. Some
195    refactoring to allow consistent naming conventions between 1D, 2D and
196    SESANS data would be helpful.
197
198    **Attributes**
199
200    *qx_data*, *dqx_data*: $q_x$ matrix and gaussian resolution
201
202    *qy_data*, *dqy_data*: $q_y$ matrix and gaussian resolution
203
204    *data*, *err_data*: $I(q)$ matrix and measurement uncertainty
205
206    *mask*: values to exclude from plotting/analysis
207
208    *qmin*, *qmax*: range of $q$ values in *x*
209
210    *filename*: label for the data line
211
212    *_xaxis*, *_xunit*: label and units for the *x* axis
213
214    *_yaxis*, *_yunit*: label and units for the *y* axis
215
216    *_zaxis*, *_zunit*: label and units for the *y* axis
217
218    *Q_unit*, *I_unit*: units for Q and intensity
219
220    *x_bins*, *y_bins*: grid steps in *x* and *y* directions
221    """
222    def __init__(self,
223                 x=None,   # type: Optional[np.ndarray]
224                 y=None,   # type: Optional[np.ndarray]
225                 z=None,   # type: Optional[np.ndarray]
226                 dx=None,  # type: Optional[np.ndarray]
227                 dy=None,  # type: Optional[np.ndarray]
228                 dz=None   # type: Optional[np.ndarray]
229                ):
230        # type: (...) -> None
231        self.qx_data, self.dqx_data = x, dx
232        self.qy_data, self.dqy_data = y, dy
233        self.data, self.err_data = z, dz
234        self.mask = (np.isnan(z) if z is not None
235                     else np.zeros_like(x, dtype='bool') if x is not None
236                     else None)
237        self.q_data = np.sqrt(x**2 + y**2)
238        self.qmin = 1e-16
239        self.qmax = np.inf
240        self.detector = []
241        self.source = Source()
242        self.Q_unit = "1/A"
243        self.I_unit = "1/cm"
244        self.xaxis("Q_x", "1/A")
245        self.yaxis("Q_y", "1/A")
246        self.zaxis("Intensity", "1/cm")
247        self._xaxis, self._xunit = "x", ""
248        self._yaxis, self._yunit = "y", ""
249        self._zaxis, self._zunit = "z", ""
250        self.x_bins, self.y_bins = None, None
251        self.filename = None
252
253    def xaxis(self, label, unit):
254        # type: (str, str) -> None
255        """
256        set the x axis label and unit
257        """
258        self._xaxis = label
259        self._xunit = unit
260
261    def yaxis(self, label, unit):
262        # type: (str, str) -> None
263        """
264        set the y axis label and unit
265        """
266        self._yaxis = label
267        self._yunit = unit
268
269    def zaxis(self, label, unit):
270        # type: (str, str) -> None
271        """
272        set the y axis label and unit
273        """
274        self._zaxis = label
275        self._zunit = unit
276
277
278class Vector(object):
279    """
280    3-space vector of *x*, *y*, *z*
281    """
282    def __init__(self, x=None, y=None, z=None):
283        # type: (float, float, Optional[float]) -> None
284        self.x, self.y, self.z = x, y, z
285
286class Detector(object):
287    """
288    Detector attributes.
289    """
290    def __init__(self, pixel_size=(None, None), distance=None):
291        # type: (Tuple[float, float], float) -> None
292        self.pixel_size = Vector(*pixel_size)
293        self.distance = distance
294
295class Source(object):
296    """
297    Beam attributes.
298    """
299    def __init__(self):
300        # type: () -> None
301        self.wavelength = np.NaN
302        self.wavelength_unit = "A"
303
304class Sample(object):
305    """
306    Sample attributes.
307    """
308    def __init__(self):
309        # type: () -> None
310        pass
311
312def empty_data1D(q, resolution=0.0, L=0., dL=0.):
313    # type: (np.ndarray, float) -> Data1D
314    r"""
315    Create empty 1D data using the given *q* as the x value.
316
317    rms *resolution* $\Delta q/q$ defaults to 0%.  If wavelength *L* and rms
318    wavelength divergence *dL* are defined, then *resolution* defines
319    rms $\Delta \theta/\theta$ for the lowest *q*, with $\theta$ derived from
320    $q = 4\pi/\lambda \sin(\theta)$.
321    """
322
323    #Iq = 100 * np.ones_like(q)
324    #dIq = np.sqrt(Iq)
325    Iq, dIq = None, None
326    q = np.asarray(q)
327    if L != 0 and resolution != 0:
328        theta = np.arcsin(q*L/(4*pi))
329        dtheta = theta[0]*resolution
330        ## Solving Gaussian error propagation from
331        ##   Dq^2 = (dq/dL)^2 DL^2 + (dq/dtheta)^2 Dtheta^2
332        ## gives
333        ##   (Dq/q)^2 = (DL/L)**2 + (Dtheta/tan(theta))**2
334        ## Take the square root and multiply by q, giving
335        ##   Dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
336        dq = (4*pi/L) * sqrt((sin(theta)*dL/L)**2 + (cos(theta)*dtheta)**2)
337    else:
338        dq = resolution * q
339
340    data = Data1D(q, Iq, dx=dq, dy=dIq)
341    data.filename = "fake data"
342    return data
343
344
345def empty_data2D(qx, qy=None, resolution=0.0):
346    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D
347    """
348    Create empty 2D data using the given mesh.
349
350    If *qy* is missing, create a square mesh with *qy=qx*.
351
352    *resolution* dq/q defaults to 5%.
353    """
354    if qy is None:
355        qy = qx
356    qx, qy = np.asarray(qx), np.asarray(qy)
357    # 5% dQ/Q resolution
358    Qx, Qy = np.meshgrid(qx, qy)
359    Qx, Qy = Qx.flatten(), Qy.flatten()
360    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray
361    dIq = np.sqrt(Iq)
362    if resolution != 0:
363        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
364        # Should have an additional constant which depends on distances and
365        # radii of the aperture, pixel dimensions and wavelength spread
366        # Instead, assume radial dQ/Q is constant, and perpendicular matches
367        # radial (which instead it should be inverse).
368        Q = np.sqrt(Qx**2 + Qy**2)
369        dqx = resolution * Q
370        dqy = resolution * Q
371    else:
372        dqx = dqy = None
373
374    data = Data2D(x=Qx, y=Qy, z=Iq, dx=dqx, dy=dqy, dz=dIq)
375    data.x_bins = qx
376    data.y_bins = qy
377    data.filename = "fake data"
378
379    # pixel_size in mm, distance in m
380    detector = Detector(pixel_size=(5, 5), distance=4)
381    data.detector.append(detector)
382    data.source.wavelength = 5 # angstroms
383    data.source.wavelength_unit = "A"
384    return data
385
386
387def plot_data(data, view='log', limits=None):
388    # type: (Data, str, Optional[Tuple[float, float]]) -> None
389    """
390    Plot data loaded by the sasview loader.
391
392    *data* is a sasview data object, either 1D, 2D or SESANS.
393
394    *view* is log or linear.
395
396    *limits* sets the intensity limits on the plot; if None then the limits
397    are inferred from the data.
398    """
399    # Note: kind of weird using the plot result functions to plot just the
400    # data, but they already handle the masking and graph markup already, so
401    # do not repeat.
402    if hasattr(data, 'isSesans') and data.isSesans:
403        _plot_result_sesans(data, None, None, use_data=True, limits=limits)
404    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
405        _plot_result2D(data, None, None, view, use_data=True, limits=limits)
406    else:
407        _plot_result1D(data, None, None, view, use_data=True, limits=limits)
408
409
410def plot_theory(data,          # type: Data
411                theory,        # type: Optional[np.ndarray]
412                resid=None,    # type: Optional[np.ndarray]
413                view='log',    # type: str
414                use_data=True, # type: bool
415                limits=None,   # type: Optional[np.ndarray]
416                Iq_calc=None   # type: Optional[np.ndarray]
417               ):
418    # type: (...) -> None
419    """
420    Plot theory calculation.
421
422    *data* is needed to define the graph properties such as labels and
423    units, and to define the data mask.
424
425    *theory* is a matrix of the same shape as the data.
426
427    *view* is log or linear
428
429    *use_data* is True if the data should be plotted as well as the theory.
430
431    *limits* sets the intensity limits on the plot; if None then the limits
432    are inferred from the data.
433
434    *Iq_calc* is the raw theory values without resolution smearing
435    """
436    if hasattr(data, 'isSesans') and data.isSesans:
437        _plot_result_sesans(data, theory, resid, use_data=True, limits=limits)
438    elif hasattr(data, 'qx_data') and not getattr(data, 'radial', False):
439        _plot_result2D(data, theory, resid, view, use_data, limits=limits)
440    else:
441        _plot_result1D(data, theory, resid, view, use_data,
442                       limits=limits, Iq_calc=Iq_calc)
443
444
445def protect(func):
446    # type: (Callable) -> Callable
447    """
448    Decorator to wrap calls in an exception trapper which prints the
449    exception and continues.  Keyboard interrupts are ignored.
450    """
451    def wrapper(*args, **kw):
452        """
453        Trap and print errors from function.
454        """
455        try:
456            return func(*args, **kw)
457        except Exception:
458            traceback.print_exc()
459
460    return wrapper
461
462
463@protect
464def _plot_result1D(data,         # type: Data1D
465                   theory,       # type: Optional[np.ndarray]
466                   resid,        # type: Optional[np.ndarray]
467                   view,         # type: str
468                   use_data,     # type: bool
469                   limits=None,  # type: Optional[Tuple[float, float]]
470                   Iq_calc=None  # type: Optional[np.ndarray]
471                  ):
472    # type: (...) -> None
473    """
474    Plot the data and residuals for 1D data.
475    """
476    import matplotlib.pyplot as plt  # type: ignore
477    from numpy.ma import masked_array, masked  # type: ignore
478
479    if getattr(data, 'radial', False):
480        data.x = data.q_data
481        data.y = data.data
482
483    use_data = use_data and data.y is not None
484    use_theory = theory is not None
485    use_resid = resid is not None
486    use_calc = use_theory and Iq_calc is not None
487    num_plots = (use_data or use_theory) + use_calc + use_resid
488    non_positive_x = (data.x <= 0.0).any()
489
490    scale = data.x**4 if view == 'q4' else 1.0
491    xscale = yscale = 'linear' if view == 'linear' else 'log'
492
493    if use_data or use_theory:
494        if num_plots > 1:
495            plt.subplot(1, num_plots, 1)
496
497        #print(vmin, vmax)
498        all_positive = True
499        some_present = False
500        if use_data:
501            mdata = masked_array(data.y, data.mask.copy())
502            mdata[~np.isfinite(mdata)] = masked
503            if view is 'log':
504                mdata[mdata <= 0] = masked
505            plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
506            all_positive = all_positive and (mdata > 0).all()
507            some_present = some_present or (mdata.count() > 0)
508
509
510        if use_theory:
511            # Note: masks merge, so any masked theory points will stay masked,
512            # and the data mask will be added to it.
513            #mtheory = masked_array(theory, data.mask.copy())
514            theory_x = data.x[data.mask == 0]
515            mtheory = masked_array(theory)
516            mtheory[~np.isfinite(mtheory)] = masked
517            if view is 'log':
518                mtheory[mtheory <= 0] = masked
519            plt.plot(theory_x, scale*mtheory, '-')
520            all_positive = all_positive and (mtheory > 0).all()
521            some_present = some_present or (mtheory.count() > 0)
522
523        if limits is not None:
524            plt.ylim(*limits)
525
526
527        xscale = ('linear' if not some_present or non_positive_x
528                  else view if view is not None
529                  else 'log')
530        yscale = ('linear'
531                  if view == 'q4' or not some_present or not all_positive
532                  else view if view is not None
533                  else 'log')
534        plt.xscale(xscale)
535        plt.xlabel("$q$/A$^{-1}$")
536        plt.yscale(yscale)
537        plt.ylabel('$I(q)$')
538        title = ("data and model" if use_theory and use_data
539                 else "data" if use_data
540                 else "model")
541        plt.title(title)
542
543    if use_calc:
544        # Only have use_calc if have use_theory
545        plt.subplot(1, num_plots, 2)
546        qx, qy, Iqxy = Iq_calc
547        plt.pcolormesh(qx, qy[qy > 0], np.log10(Iqxy[qy > 0, :]))
548        plt.xlabel("$q_x$/A$^{-1}$")
549        plt.xlabel("$q_y$/A$^{-1}$")
550        plt.xscale('log')
551        plt.yscale('log')
552        #plt.axis('equal')
553
554    if use_resid:
555        theory_x = data.x[data.mask == 0]
556        mresid = masked_array(resid)
557        mresid[~np.isfinite(mresid)] = masked
558        some_present = (mresid.count() > 0)
559
560        if num_plots > 1:
561            plt.subplot(1, num_plots, use_calc + 2)
562        plt.plot(theory_x, mresid, '.')
563        plt.xlabel("$q$/A$^{-1}$")
564        plt.ylabel('residuals')
565        plt.title('(model - Iq)/dIq')
566        plt.xscale(xscale)
567        plt.yscale('linear')
568
569
570@protect
571def _plot_result_sesans(data,        # type: SesansData
572                        theory,      # type: Optional[np.ndarray]
573                        resid,       # type: Optional[np.ndarray]
574                        use_data,    # type: bool
575                        limits=None  # type: Optional[Tuple[float, float]]
576                       ):
577    # type: (...) -> None
578    """
579    Plot SESANS results.
580    """
581    import matplotlib.pyplot as plt  # type: ignore
582    use_data = use_data and data.y is not None
583    use_theory = theory is not None
584    use_resid = resid is not None
585    num_plots = (use_data or use_theory) + use_resid
586
587    if use_data or use_theory:
588        is_tof = data.lam is not None and (data.lam != data.lam[0]).any()
589        if num_plots > 1:
590            plt.subplot(1, num_plots, 1)
591        if use_data:
592            if is_tof:
593                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam),
594                             yerr=data.dy/data.y/(data.lam*data.lam))
595            else:
596                plt.errorbar(data.x, data.y, yerr=data.dy)
597        if theory is not None:
598            if is_tof:
599                plt.plot(data.x, np.log(theory)/(data.lam*data.lam), '-')
600            else:
601                plt.plot(data.x, theory, '-')
602        if limits is not None:
603            plt.ylim(*limits)
604
605        plt.xlabel('spin echo length ({})'.format(data._xunit))
606        if is_tof:
607            plt.ylabel(r'(Log (P/P$_0$))/$\lambda^2$')
608        else:
609            plt.ylabel('polarization (P/P0)')
610
611
612    if resid is not None:
613        if num_plots > 1:
614            plt.subplot(1, num_plots, (use_data or use_theory) + 1)
615        plt.plot(data.x, resid, 'x')
616        plt.xlabel('spin echo length ({})'.format(data._xunit))
617        plt.ylabel('residuals (P/P0)')
618
619
620@protect
621def _plot_result2D(data,         # type: Data2D
622                   theory,       # type: Optional[np.ndarray]
623                   resid,        # type: Optional[np.ndarray]
624                   view,         # type: str
625                   use_data,     # type: bool
626                   limits=None   # type: Optional[Tuple[float, float]]
627                  ):
628    # type: (...) -> None
629    """
630    Plot the data and residuals for 2D data.
631    """
632    import matplotlib.pyplot as plt  # type: ignore
633    use_data = use_data and data.data is not None
634    use_theory = theory is not None
635    use_resid = resid is not None
636    num_plots = use_data + use_theory + use_resid
637
638    # Put theory and data on a common colormap scale
639    vmin, vmax = np.inf, -np.inf
640    target = None # type: Optional[np.ndarray]
641    if use_data:
642        target = data.data[~data.mask]
643        datamin = target[target > 0].min() if view == 'log' else target.min()
644        datamax = target.max()
645        vmin = min(vmin, datamin)
646        vmax = max(vmax, datamax)
647    if use_theory:
648        theorymin = theory[theory > 0].min() if view == 'log' else theory.min()
649        theorymax = theory.max()
650        vmin = min(vmin, theorymin)
651        vmax = max(vmax, theorymax)
652
653    # Override data limits from the caller
654    if limits is not None:
655        vmin, vmax = limits
656
657    # Plot data
658    if use_data:
659        if num_plots > 1:
660            plt.subplot(1, num_plots, 1)
661        _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
662        plt.title('data')
663        h = plt.colorbar()
664        h.set_label('$I(q)$')
665
666    # plot theory
667    if use_theory:
668        if num_plots > 1:
669            plt.subplot(1, num_plots, use_data+1)
670        _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
671        plt.title('theory')
672        h = plt.colorbar()
673        h.set_label(r'$\log_{10}I(q)$' if view == 'log'
674                    else r'$q^4 I(q)$' if view == 'q4'
675                    else '$I(q)$')
676
677    # plot resid
678    if use_resid:
679        if num_plots > 1:
680            plt.subplot(1, num_plots, use_data+use_theory+1)
681        _plot_2d_signal(data, resid, view='linear')
682        plt.title('residuals')
683        h = plt.colorbar()
684        h.set_label(r'$\Delta I(q)$')
685
686
687@protect
688def _plot_2d_signal(data,       # type: Data2D
689                    signal,     # type: np.ndarray
690                    vmin=None,  # type: Optional[float]
691                    vmax=None,  # type: Optional[float]
692                    view='log'  # type: str
693                   ):
694    # type: (...) -> Tuple[float, float]
695    """
696    Plot the target value for the data.  This could be the data itself,
697    the theory calculation, or the residuals.
698
699    *scale* can be 'log' for log scale data, or 'linear'.
700    """
701    import matplotlib.pyplot as plt  # type: ignore
702    from numpy.ma import masked_array  # type: ignore
703
704    image = np.zeros_like(data.qx_data)
705    image[~data.mask] = signal
706    valid = np.isfinite(image)
707    if view == 'log':
708        valid[valid] = (image[valid] > 0)
709        if vmin is None:
710            vmin = image[valid & ~data.mask].min()
711        if vmax is None:
712            vmax = image[valid & ~data.mask].max()
713        image[valid] = np.log10(image[valid])
714    elif view == 'q4':
715        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
716        if vmin is None:
717            vmin = image[valid & ~data.mask].min()
718        if vmax is None:
719            vmax = image[valid & ~data.mask].max()
720    else:
721        if vmin is None:
722            vmin = image[valid & ~data.mask].min()
723        if vmax is None:
724            vmax = image[valid & ~data.mask].max()
725
726    image[~valid | data.mask] = 0
727    #plottable = Iq
728    plottable = masked_array(image, ~valid | data.mask)
729    # Divide range by 10 to convert from angstroms to nanometers
730    xmin, xmax = min(data.qx_data), max(data.qx_data)
731    ymin, ymax = min(data.qy_data), max(data.qy_data)
732    if view == 'log':
733        vmin_scaled, vmax_scaled = np.log10(vmin), np.log10(vmax)
734    else:
735        vmin_scaled, vmax_scaled = vmin, vmax
736    #nx, ny = len(data.x_bins), len(data.y_bins)
737    x_bins, y_bins, image = _build_matrix(data, plottable)
738    plt.imshow(image,
739               interpolation='nearest', aspect=1, origin='lower',
740               extent=[xmin, xmax, ymin, ymax],
741               vmin=vmin_scaled, vmax=vmax_scaled)
742    plt.xlabel("$q_x$/A$^{-1}$")
743    plt.ylabel("$q_y$/A$^{-1}$")
744    return vmin, vmax
745
746
747# === The following is modified from sas.sasgui.plottools.PlotPanel
748def _build_matrix(self, plottable):
749    """
750    Build a matrix for 2d plot from a vector
751    Returns a matrix (image) with ~ square binning
752    Requirement: need 1d array formats of
753    self.data, self.qx_data, and self.qy_data
754    where each one corresponds to z, x, or y axis values
755
756    """
757    # No qx or qy given in a vector format
758    if self.qx_data is None or self.qy_data is None \
759            or self.qx_data.ndim != 1 or self.qy_data.ndim != 1:
760        return self.x_bins, self.y_bins, plottable
761
762    # maximum # of loops to fillup_pixels
763    # otherwise, loop could never stop depending on data
764    max_loop = 1
765    # get the x and y_bin arrays.
766    x_bins, y_bins = _get_bins(self)
767    # set zero to None
768
769    #Note: Can not use scipy.interpolate.Rbf:
770    # 'cause too many data points (>10000)<=JHC.
771    # 1d array to use for weighting the data point averaging
772    #when they fall into a same bin.
773    weights_data = np.ones([self.data.size])
774    # get histogram of ones w/len(data); this will provide
775    #the weights of data on each bins
776    weights, xedges, yedges = np.histogram2d(x=self.qy_data,
777                                             y=self.qx_data,
778                                             bins=[y_bins, x_bins],
779                                             weights=weights_data)
780    # get histogram of data, all points into a bin in a way of summing
781    image, xedges, yedges = np.histogram2d(x=self.qy_data,
782                                           y=self.qx_data,
783                                           bins=[y_bins, x_bins],
784                                           weights=plottable)
785    # Now, normalize the image by weights only for weights>1:
786    # If weight == 1, there is only one data point in the bin so
787    # that no normalization is required.
788    image[weights > 1] = image[weights > 1] / weights[weights > 1]
789    # Set image bins w/o a data point (weight==0) as None (was set to zero
790    # by histogram2d.)
791    image[weights == 0] = None
792
793    # Fill empty bins with 8 nearest neighbors only when at least
794    #one None point exists
795    loop = 0
796
797    # do while loop until all vacant bins are filled up up
798    #to loop = max_loop
799    while (weights == 0).any():
800        if loop >= max_loop:  # this protects never-ending loop
801            break
802        image = _fillup_pixels(image=image, weights=weights)
803        loop += 1
804
805    return x_bins, y_bins, image
806
807def _get_bins(self):
808    """
809    get bins
810    set x_bins and y_bins into self, 1d arrays of the index with
811    ~ square binning
812    Requirement: need 1d array formats of
813    self.qx_data, and self.qy_data
814    where each one corresponds to  x, or y axis values
815    """
816    # find max and min values of qx and qy
817    xmax = self.qx_data.max()
818    xmin = self.qx_data.min()
819    ymax = self.qy_data.max()
820    ymin = self.qy_data.min()
821
822    # calculate the range of qx and qy: this way, it is a little
823    # more independent
824    x_size = xmax - xmin
825    y_size = ymax - ymin
826
827    # estimate the # of pixels on each axes
828    npix_y = int(np.floor(np.sqrt(len(self.qy_data))))
829    npix_x = int(np.floor(len(self.qy_data) / npix_y))
830
831    # bin size: x- & y-directions
832    xstep = x_size / (npix_x - 1)
833    ystep = y_size / (npix_y - 1)
834
835    # max and min taking account of the bin sizes
836    xmax = xmax + xstep / 2.0
837    xmin = xmin - xstep / 2.0
838    ymax = ymax + ystep / 2.0
839    ymin = ymin - ystep / 2.0
840
841    # store x and y bin centers in q space
842    x_bins = np.linspace(xmin, xmax, npix_x)
843    y_bins = np.linspace(ymin, ymax, npix_y)
844
845    return x_bins, y_bins
846
847def _fillup_pixels(image=None, weights=None):
848    """
849    Fill z values of the empty cells of 2d image matrix
850    with the average over up-to next nearest neighbor points
851
852    :param image: (2d matrix with some zi = None)
853
854    :return: image (2d array )
855
856    :TODO: Find better way to do for-loop below
857
858    """
859    # No image matrix given
860    if image is None or np.ndim(image) != 2 \
861            or np.isfinite(image).all() \
862            or weights is None:
863        return image
864    # Get bin size in y and x directions
865    len_y = len(image)
866    len_x = len(image[1])
867    temp_image = np.zeros([len_y, len_x])
868    weit = np.zeros([len_y, len_x])
869    # do for-loop for all pixels
870    for n_y in range(len(image)):
871        for n_x in range(len(image[1])):
872            # find only null pixels
873            if weights[n_y][n_x] > 0 or np.isfinite(image[n_y][n_x]):
874                continue
875            else:
876                # find 4 nearest neighbors
877                # check where or not it is at the corner
878                if n_y != 0 and np.isfinite(image[n_y - 1][n_x]):
879                    temp_image[n_y][n_x] += image[n_y - 1][n_x]
880                    weit[n_y][n_x] += 1
881                if n_x != 0 and np.isfinite(image[n_y][n_x - 1]):
882                    temp_image[n_y][n_x] += image[n_y][n_x - 1]
883                    weit[n_y][n_x] += 1
884                if n_y != len_y - 1 and np.isfinite(image[n_y + 1][n_x]):
885                    temp_image[n_y][n_x] += image[n_y + 1][n_x]
886                    weit[n_y][n_x] += 1
887                if n_x != len_x - 1 and np.isfinite(image[n_y][n_x + 1]):
888                    temp_image[n_y][n_x] += image[n_y][n_x + 1]
889                    weit[n_y][n_x] += 1
890                # go 4 next nearest neighbors when no non-zero
891                # neighbor exists
892                if n_y != 0 and n_x != 0 and \
893                        np.isfinite(image[n_y - 1][n_x - 1]):
894                    temp_image[n_y][n_x] += image[n_y - 1][n_x - 1]
895                    weit[n_y][n_x] += 1
896                if n_y != len_y - 1 and n_x != 0 and \
897                        np.isfinite(image[n_y + 1][n_x - 1]):
898                    temp_image[n_y][n_x] += image[n_y + 1][n_x - 1]
899                    weit[n_y][n_x] += 1
900                if n_y != len_y and n_x != len_x - 1 and \
901                        np.isfinite(image[n_y - 1][n_x + 1]):
902                    temp_image[n_y][n_x] += image[n_y - 1][n_x + 1]
903                    weit[n_y][n_x] += 1
904                if n_y != len_y - 1 and n_x != len_x - 1 and \
905                        np.isfinite(image[n_y + 1][n_x + 1]):
906                    temp_image[n_y][n_x] += image[n_y + 1][n_x + 1]
907                    weit[n_y][n_x] += 1
908
909    # get it normalized
910    ind = (weit > 0)
911    image[ind] = temp_image[ind] / weit[ind]
912
913    return image
914
915
916def demo():
917    # type: () -> None
918    """
919    Load and plot a SAS dataset.
920    """
921    data = load_data('DEC07086.DAT')
922    set_beam_stop(data, 0.004)
923    plot_data(data)
924    import matplotlib.pyplot as plt  # type: ignore
925    plt.show()
926
927
928if __name__ == "__main__":
929    demo()
Note: See TracBrowser for help on using the repository browser.