Changes in sasmodels/data.py [d6f5da6:a5b8477] in sasmodels


Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/data.py

    rd6f5da6 ra5b8477  
    3535import traceback 
    3636 
    37 import numpy as np 
     37import numpy as np  # type: ignore 
     38 
     39try: 
     40    from typing import Union, Dict, List, Optional 
     41except ImportError: 
     42    pass 
     43else: 
     44    Data = Union["Data1D", "Data2D", "SesansData"] 
    3845 
    3946def load_data(filename): 
     47    # type: (str) -> Data 
    4048    """ 
    4149    Load data using a sasview loader. 
    4250    """ 
    43     from sas.sascalc.dataloader.loader import Loader 
     51    from sas.sascalc.dataloader.loader import Loader  # type: ignore 
    4452    loader = Loader() 
    4553    data = loader.load(filename) 
     
    5058 
    5159def set_beam_stop(data, radius, outer=None): 
     60    # type: (Data, float, Optional[float]) -> None 
    5261    """ 
    5362    Add a beam stop of the given *radius*.  If *outer*, make an annulus. 
    5463    """ 
    55     from sas.dataloader.manipulations import Ringcut 
     64    from sas.dataloader.manipulations import Ringcut  # type: ignore 
    5665    if hasattr(data, 'qx_data'): 
    5766        data.mask = Ringcut(0, radius)(data) 
     
    6574 
    6675def set_half(data, half): 
     76    # type: (Data, str) -> None 
    6777    """ 
    6878    Select half of the data, either "right" or "left". 
    6979    """ 
    70     from sas.dataloader.manipulations import Boxcut 
     80    from sas.dataloader.manipulations import Boxcut  # type: ignore 
    7181    if half == 'right': 
    7282        data.mask += \ 
     
    7888 
    7989def set_top(data, cutoff): 
     90    # type: (Data, float) -> None 
    8091    """ 
    8192    Chop the top off the data, above *cutoff*. 
    8293    """ 
    83     from sas.dataloader.manipulations import Boxcut 
     94    from sas.dataloader.manipulations import Boxcut  # type: ignore 
    8495    data.mask += \ 
    8596        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data) 
     
    114125    """ 
    115126    def __init__(self, x=None, y=None, dx=None, dy=None): 
     127        # type: (Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]) -> None 
    116128        self.x, self.y, self.dx, self.dy = x, y, dx, dy 
    117129        self.dxl = None 
     
    127139 
    128140    def xaxis(self, label, unit): 
     141        # type: (str, str) -> None 
    129142        """ 
    130143        set the x axis label and unit 
     
    134147 
    135148    def yaxis(self, label, unit): 
     149        # type: (str, str) -> None 
    136150        """ 
    137151        set the y axis label and unit 
     
    140154        self._yunit = unit 
    141155 
    142  
     156class SesansData(Data1D): 
     157    def __init__(self, **kw): 
     158        Data1D.__init__(self, **kw) 
     159        self.lam = None # type: Optional[np.ndarray] 
    143160 
    144161class Data2D(object): 
     
    175192    """ 
    176193    def __init__(self, x=None, y=None, z=None, dx=None, dy=None, dz=None): 
     194        # type: (Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]) -> None 
    177195        self.qx_data, self.dqx_data = x, dx 
    178196        self.qy_data, self.dqy_data = y, dy 
     
    197215 
    198216    def xaxis(self, label, unit): 
     217        # type: (str, str) -> None 
    199218        """ 
    200219        set the x axis label and unit 
     
    204223 
    205224    def yaxis(self, label, unit): 
     225        # type: (str, str) -> None 
    206226        """ 
    207227        set the y axis label and unit 
     
    211231 
    212232    def zaxis(self, label, unit): 
     233        # type: (str, str) -> None 
    213234        """ 
    214235        set the y axis label and unit 
     
    223244    """ 
    224245    def __init__(self, x=None, y=None, z=None): 
     246        # type: (float, float, Optional[float]) -> None 
    225247        self.x, self.y, self.z = x, y, z 
    226248 
     
    230252    """ 
    231253    def __init__(self, pixel_size=(None, None), distance=None): 
     254        # type: (Tuple[float, float], float) -> None 
    232255        self.pixel_size = Vector(*pixel_size) 
    233256        self.distance = distance 
     
    238261    """ 
    239262    def __init__(self): 
     263        # type: () -> None 
    240264        self.wavelength = np.NaN 
    241265        self.wavelength_unit = "A" 
     
    243267 
    244268def empty_data1D(q, resolution=0.0): 
     269    # type: (np.ndarray, float) -> Data1D 
    245270    """ 
    246271    Create empty 1D data using the given *q* as the x value. 
     
    259284 
    260285def empty_data2D(qx, qy=None, resolution=0.0): 
     286    # type: (np.ndarray, Optional[np.ndarray], float) -> Data2D 
    261287    """ 
    262288    Create empty 2D data using the given mesh. 
     
    272298    Qx, Qy = np.meshgrid(qx, qy) 
    273299    Qx, Qy = Qx.flatten(), Qy.flatten() 
    274     Iq = 100 * np.ones_like(Qx) 
     300    Iq = 100 * np.ones_like(Qx)  # type: np.ndarray 
    275301    dIq = np.sqrt(Iq) 
    276302    if resolution != 0: 
     
    300326 
    301327def plot_data(data, view='log', limits=None): 
     328    # type: (Data, str, Optional[Tuple[float, float]]) -> None 
    302329    """ 
    303330    Plot data loaded by the sasview loader. 
     
    323350def plot_theory(data, theory, resid=None, view='log', 
    324351                use_data=True, limits=None, Iq_calc=None): 
     352    # type: (Data, Optional[np.ndarray], Optional[np.ndarray], str, bool, Optional[Tuple[float,float]], Optional[np.ndarray]) -> None 
    325353    """ 
    326354    Plot theory calculation. 
     
    337365    *limits* sets the intensity limits on the plot; if None then the limits 
    338366    are inferred from the data. 
     367 
     368    *Iq_calc* is the raw theory values without resolution smearing 
    339369    """ 
    340370    if hasattr(data, 'lam'): 
     
    348378 
    349379def protect(fn): 
     380    # type: (Callable) -> Callable 
    350381    """ 
    351382    Decorator to wrap calls in an exception trapper which prints the 
     
    358389        try: 
    359390            return fn(*args, **kw) 
    360         except KeyboardInterrupt: 
    361             raise 
    362         except: 
     391        except Exception: 
    363392            traceback.print_exc() 
    364393 
     
    369398def _plot_result1D(data, theory, resid, view, use_data, 
    370399                   limits=None, Iq_calc=None): 
     400    # type: (Data1D, Optional[np.ndarray], Optional[np.ndarray], str, bool, Optional[Tuple[float, float]], Optional[np.ndarray]) -> None 
    371401    """ 
    372402    Plot the data and residuals for 1D data. 
    373403    """ 
    374     import matplotlib.pyplot as plt 
    375     from numpy.ma import masked_array, masked 
     404    import matplotlib.pyplot as plt  # type: ignore 
     405    from numpy.ma import masked_array, masked  # type: ignore 
    376406 
    377407    use_data = use_data and data.y is not None 
     
    446476@protect 
    447477def _plot_result_sesans(data, theory, resid, use_data, limits=None): 
     478    # type: (SesansData, Optional[np.ndarray], Optional[np.ndarray], bool, Optional[Tuple[float, float]]) -> None 
    448479    """ 
    449480    Plot SESANS results. 
    450481    """ 
    451     import matplotlib.pyplot as plt 
     482    import matplotlib.pyplot as plt  # type: ignore 
    452483    use_data = use_data and data.y is not None 
    453484    use_theory = theory is not None 
     
    456487 
    457488    if use_data or use_theory: 
    458         is_tof = np.any(data.lam!=data.lam[0]) 
     489        is_tof = (data.lam != data.lam[0]).any() 
    459490        if num_plots > 1: 
    460491            plt.subplot(1, num_plots, 1) 
    461492        if use_data: 
    462493            if is_tof: 
    463                 plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam), yerr=data.dy/data.y/(data.lam*data.lam)) 
     494                plt.errorbar(data.x, np.log(data.y)/(data.lam*data.lam), 
     495                             yerr=data.dy/data.y/(data.lam*data.lam)) 
    464496            else: 
    465497                plt.errorbar(data.x, data.y, yerr=data.dy) 
     
    489521@protect 
    490522def _plot_result2D(data, theory, resid, view, use_data, limits=None): 
     523    # type: (Data2D, Optional[np.ndarray], Optional[np.ndarray], str, bool, Optional[Tuple[float,float]]) -> None 
    491524    """ 
    492525    Plot the data and residuals for 2D data. 
    493526    """ 
    494     import matplotlib.pyplot as plt 
     527    import matplotlib.pyplot as plt  # type: ignore 
    495528    use_data = use_data and data.data is not None 
    496529    use_theory = theory is not None 
     
    500533    # Put theory and data on a common colormap scale 
    501534    vmin, vmax = np.inf, -np.inf 
     535    target = None # type: Optional[np.ndarray] 
    502536    if use_data: 
    503537        target = data.data[~data.mask] 
     
    548582@protect 
    549583def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'): 
     584    # type: (Data2D, np.ndarray, Optional[float], Optional[float], str) -> Tuple[float, float] 
    550585    """ 
    551586    Plot the target value for the data.  This could be the data itself, 
     
    554589    *scale* can be 'log' for log scale data, or 'linear'. 
    555590    """ 
    556     import matplotlib.pyplot as plt 
    557     from numpy.ma import masked_array 
     591    import matplotlib.pyplot as plt  # type: ignore 
     592    from numpy.ma import masked_array  # type: ignore 
    558593 
    559594    image = np.zeros_like(data.qx_data) 
     
    589624 
    590625def demo(): 
     626    # type: () -> None 
    591627    """ 
    592628    Load and plot a SAS dataset. 
     
    595631    set_beam_stop(data, 0.004) 
    596632    plot_data(data) 
    597     import matplotlib.pyplot as plt; plt.show() 
     633    import matplotlib.pyplot as plt  # type: ignore 
     634    plt.show() 
    598635 
    599636 
Note: See TracChangeset for help on using the changeset viewer.