Ignore:
Timestamp:
Oct 19, 2018 4:17:38 PM (6 years ago)
Author:
krzywon
Branches:
unittest-saveload
Children:
08f921e
Parents:
497e06d (diff), 9fb4572 (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge branch 'ticket-1111' into unittest-saveload

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sas/sascalc/dataloader/file_reader_base_class.py

    r9e6aeaf r9fb4572  
    77import os 
    88import sys 
    9 import re 
     9import math 
    1010import logging 
    1111from abc import abstractmethod 
     
    1616from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\ 
    1717    combine_data_info_with_plottable 
     18from sas.sascalc.data_util.nxsunit import Converter 
    1819 
    1920logger = logging.getLogger(__name__) 
     
    2627        return s.decode() if isinstance(s, bytes) else s 
    2728 
     29# Data 1D fields for iterative purposes 
     30FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw') 
     31# Data 2D fields for iterative purposes 
     32FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data', 
     33                 'dqx_data', 'dqy_data', 'mask') 
     34DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh" 
     35                       "t not be fully reduced. Support for the reader associat" 
     36                       "ed with this file type has been removed. An attempt to " 
     37                       "load the file was made, but, should it be successful, " 
     38                       "SasView cannot guarantee the accuracy of the data.") 
     39 
     40 
    2841class FileReader(object): 
    29     # List of Data1D and Data2D objects to be sent back to data_loader 
    30     output = [] 
    31     # Current plottable_(1D/2D) object being loaded in 
    32     current_dataset = None 
    33     # Current DataInfo object being loaded in 
    34     current_datainfo = None 
    3542    # String to describe the type of data this reader can load 
    3643    type_name = "ASCII" 
     
    3946    # List of allowed extensions 
    4047    ext = ['.txt'] 
     48    # Deprecated extensions 
     49    deprecated_extensions = ['.asc', '.nxs'] 
    4150    # Bypass extension check and try to load anyway 
    4251    allow_all = False 
    4352    # Able to import the unit converter 
    4453    has_converter = True 
    45     # Open file handle 
    46     f_open = None 
    4754    # Default value of zero 
    4855    _ZERO = 1e-16 
    4956 
     57    def __init__(self): 
     58        # List of Data1D and Data2D objects to be sent back to data_loader 
     59        self.output = [] 
     60        # Current plottable_(1D/2D) object being loaded in 
     61        self.current_dataset = None 
     62        # Current DataInfo object being loaded in 
     63        self.current_datainfo = None 
     64        # File path sent to reader 
     65        self.filepath = None 
     66        # Open file handle 
     67        self.f_open = None 
     68 
    5069    def read(self, filepath): 
    5170        """ 
     
    5473        :param filepath: The full or relative path to a file to be loaded 
    5574        """ 
     75        self.filepath = filepath 
    5676        if os.path.isfile(filepath): 
    5777            basename, extension = os.path.splitext(os.path.basename(filepath)) 
     
    7595                    if not self.f_open.closed: 
    7696                        self.f_open.close() 
     97                    if any(filepath.lower().endswith(ext) for ext in 
     98                           self.deprecated_extensions): 
     99                        self.handle_error_message(DEPRECATION_MESSAGE) 
    77100                    if len(self.output) > 0: 
    78101                        # Sort the data that's been loaded 
    79                         self.sort_one_d_data() 
    80                         self.sort_two_d_data() 
     102                        self.convert_data_units() 
     103                        self.sort_data() 
    81104        else: 
    82105            msg = "Unable to find file at: {}\n".format(filepath) 
     
    85108 
    86109        # Return a list of parsed entries that data_loader can manage 
    87         return self.output 
     110        final_data = self.output 
     111        self.reset_state() 
     112        return final_data 
     113 
     114    def reset_state(self): 
     115        """ 
     116        Resets the class state to a base case when loading a new data file so previous 
     117        data files do not appear a second time 
     118        """ 
     119        self.current_datainfo = None 
     120        self.current_dataset = None 
     121        self.filepath = None 
     122        self.ind = None 
     123        self.output = [] 
    88124 
    89125    def nextline(self): 
     
    106142        Returns the entire file as a string. 
    107143        """ 
    108         #return self.f_open.read() 
    109144        return decode(self.f_open.read()) 
    110145 
     
    112147        """ 
    113148        Generic error handler to add an error to the current datainfo to 
    114         propogate the error up the error chain. 
     149        propagate the error up the error chain. 
    115150        :param msg: Error message 
    116151        """ 
     
    121156        else: 
    122157            logger.warning(msg) 
     158            raise NoKnownLoaderException(msg) 
    123159 
    124160    def send_to_output(self): 
     
    131167        self.output.append(data_obj) 
    132168 
    133     def sort_one_d_data(self): 
     169    def sort_data(self): 
    134170        """ 
    135171        Sort 1D data along the X axis for consistency 
     
    139175                # Normalize the units for 
    140176                data.x_unit = self.format_unit(data.x_unit) 
     177                data._xunit = data.x_unit 
    141178                data.y_unit = self.format_unit(data.y_unit) 
     179                data._yunit = data.y_unit 
    142180                # Sort data by increasing x and remove 1st point 
    143181                ind = np.lexsort((data.y, data.x)) 
    144                 data.x = np.asarray([data.x[i] for i in ind]).astype(np.float64) 
    145                 data.y = np.asarray([data.y[i] for i in ind]).astype(np.float64) 
     182                data.x = self._reorder_1d_array(data.x, ind) 
     183                data.y = self._reorder_1d_array(data.y, ind) 
    146184                if data.dx is not None: 
    147185                    if len(data.dx) == 0: 
    148186                        data.dx = None 
    149187                        continue 
    150                     data.dx = np.asarray([data.dx[i] for i in ind]).astype(np.float64) 
     188                    data.dx = self._reorder_1d_array(data.dx, ind) 
    151189                if data.dxl is not None: 
    152                     data.dxl = np.asarray([data.dxl[i] for i in ind]).astype(np.float64) 
     190                    data.dxl = self._reorder_1d_array(data.dxl, ind) 
    153191                if data.dxw is not None: 
    154                     data.dxw = np.asarray([data.dxw[i] for i in ind]).astype(np.float64) 
     192                    data.dxw = self._reorder_1d_array(data.dxw, ind) 
    155193                if data.dy is not None: 
    156194                    if len(data.dy) == 0: 
    157195                        data.dy = None 
    158196                        continue 
    159                     data.dy = np.asarray([data.dy[i] for i in ind]).astype(np.float64) 
     197                    data.dy = self._reorder_1d_array(data.dy, ind) 
    160198                if data.lam is not None: 
    161                     data.lam = np.asarray([data.lam[i] for i in ind]).astype(np.float64) 
     199                    data.lam = self._reorder_1d_array(data.lam, ind) 
    162200                if data.dlam is not None: 
    163                     data.dlam = np.asarray([data.dlam[i] for i in ind]).astype(np.float64) 
     201                    data.dlam = self._reorder_1d_array(data.dlam, ind) 
     202                data = self._remove_nans_in_data(data) 
    164203                if len(data.x) > 0: 
    165204                    data.xmin = np.min(data.x) 
     
    167206                    data.ymin = np.min(data.y) 
    168207                    data.ymax = np.max(data.y) 
    169  
    170     def sort_two_d_data(self): 
    171         for dataset in self.output: 
    172             if isinstance(dataset, Data2D): 
     208            elif isinstance(data, Data2D): 
    173209                # Normalize the units for 
    174                 dataset.x_unit = self.format_unit(dataset.Q_unit) 
    175                 dataset.y_unit = self.format_unit(dataset.I_unit) 
    176                 dataset.data = dataset.data.astype(np.float64) 
    177                 dataset.qx_data = dataset.qx_data.astype(np.float64) 
    178                 dataset.xmin = np.min(dataset.qx_data) 
    179                 dataset.xmax = np.max(dataset.qx_data) 
    180                 dataset.qy_data = dataset.qy_data.astype(np.float64) 
    181                 dataset.ymin = np.min(dataset.qy_data) 
    182                 dataset.ymax = np.max(dataset.qy_data) 
    183                 dataset.q_data = np.sqrt(dataset.qx_data * dataset.qx_data 
    184                                          + dataset.qy_data * dataset.qy_data) 
    185                 if dataset.err_data is not None: 
    186                     dataset.err_data = dataset.err_data.astype(np.float64) 
    187                 if dataset.dqx_data is not None: 
    188                     dataset.dqx_data = dataset.dqx_data.astype(np.float64) 
    189                 if dataset.dqy_data is not None: 
    190                     dataset.dqy_data = dataset.dqy_data.astype(np.float64) 
    191                 if dataset.mask is not None: 
    192                     dataset.mask = dataset.mask.astype(dtype=bool) 
    193  
    194                 if len(dataset.data.shape) == 2: 
    195                     n_rows, n_cols = dataset.data.shape 
    196                     dataset.y_bins = dataset.qy_data[0::int(n_cols)] 
    197                     dataset.x_bins = dataset.qx_data[:int(n_cols)] 
    198                 dataset.data = dataset.data.flatten() 
    199                 if len(dataset.data) > 0: 
    200                     dataset.xmin = np.min(dataset.qx_data) 
    201                     dataset.xmax = np.max(dataset.qx_data) 
    202                     dataset.ymin = np.min(dataset.qy_data) 
    203                     dataset.ymax = np.max(dataset.qx_data) 
     210                data.Q_unit = self.format_unit(data.Q_unit) 
     211                data.I_unit = self.format_unit(data.I_unit) 
     212                data._xunit = data.Q_unit 
     213                data._yunit = data.Q_unit 
     214                data._zunit = data.I_unit 
     215                data.data = data.data.astype(np.float64) 
     216                data.qx_data = data.qx_data.astype(np.float64) 
     217                data.xmin = np.min(data.qx_data) 
     218                data.xmax = np.max(data.qx_data) 
     219                data.qy_data = data.qy_data.astype(np.float64) 
     220                data.ymin = np.min(data.qy_data) 
     221                data.ymax = np.max(data.qy_data) 
     222                data.q_data = np.sqrt(data.qx_data * data.qx_data 
     223                                         + data.qy_data * data.qy_data) 
     224                if data.err_data is not None: 
     225                    data.err_data = data.err_data.astype(np.float64) 
     226                if data.dqx_data is not None: 
     227                    data.dqx_data = data.dqx_data.astype(np.float64) 
     228                if data.dqy_data is not None: 
     229                    data.dqy_data = data.dqy_data.astype(np.float64) 
     230                if data.mask is not None: 
     231                    data.mask = data.mask.astype(dtype=bool) 
     232 
     233                if len(data.data.shape) == 2: 
     234                    n_rows, n_cols = data.data.shape 
     235                    data.y_bins = data.qy_data[0::int(n_cols)] 
     236                    data.x_bins = data.qx_data[:int(n_cols)] 
     237                    data.data = data.data.flatten() 
     238                data = self._remove_nans_in_data(data) 
     239                if len(data.data) > 0: 
     240                    data.xmin = np.min(data.qx_data) 
     241                    data.xmax = np.max(data.qx_data) 
     242                    data.ymin = np.min(data.qy_data) 
     243                    data.ymax = np.max(data.qx_data) 
     244 
     245    @staticmethod 
     246    def _reorder_1d_array(array, ind): 
     247        """ 
     248        Reorders a 1D array based on the indices passed as ind 
     249        :param array: Array to be reordered 
     250        :param ind: Indices used to reorder array 
     251        :return: reordered array 
     252        """ 
     253        array = np.asarray(array, dtype=np.float64) 
     254        return array[ind] 
     255 
     256    @staticmethod 
     257    def _remove_nans_in_data(data): 
     258        """ 
     259        Remove data points where nan is loaded 
     260        :param data: 1D or 2D data object 
     261        :return: data with nan points removed 
     262        """ 
     263        if isinstance(data, Data1D): 
     264            fields = FIELDS_1D 
     265        elif isinstance(data, Data2D): 
     266            fields = FIELDS_2D 
     267        else: 
     268            return data 
     269        # Make array of good points - all others will be removed 
     270        good = np.isfinite(getattr(data, fields[0])) 
     271        for name in fields[1:]: 
     272            array = getattr(data, name) 
     273            if array is not None: 
     274                # Update good points only if not already changed 
     275                good &= np.isfinite(array) 
     276        if not np.all(good): 
     277            for name in fields: 
     278                array = getattr(data, name) 
     279                if array is not None: 
     280                    setattr(data, name, array[good]) 
     281        return data 
     282 
     283    @staticmethod 
     284    def set_default_1d_units(data): 
     285        """ 
     286        Set the x and y axes to the default 1D units 
     287        :param data: 1D data set 
     288        :return: 
     289        """ 
     290        data.xaxis(r"\rm{Q}", '1/A') 
     291        data.yaxis(r"\rm{Intensity}", "1/cm") 
     292        return data 
     293 
     294    @staticmethod 
     295    def set_default_2d_units(data): 
     296        """ 
     297        Set the x and y axes to the default 2D units 
     298        :param data: 2D data set 
     299        :return: 
     300        """ 
     301        data.xaxis("\\rm{Q_{x}}", '1/A') 
     302        data.yaxis("\\rm{Q_{y}}", '1/A') 
     303        data.zaxis("\\rm{Intensity}", "1/cm") 
     304        return data 
     305 
     306    def convert_data_units(self, default_q_unit="1/A", default_i_unit="1/cm"): 
     307        """ 
     308        Converts al; data to the sasview default of units of A^{-1} for Q and 
     309        cm^{-1} for I. 
     310        :param default_q_unit: The default Q unit used by Sasview 
     311        :param default_i_unit: The default I unit used by Sasview 
     312        """ 
     313        new_output = [] 
     314        for data in self.output: 
     315            if data.isSesans: 
     316                new_output.append(data) 
     317                continue 
     318            file_x_unit = data._xunit 
     319            data_conv_x = Converter(file_x_unit) 
     320            file_y_unit = data._yunit 
     321            data_conv_y = Converter(file_y_unit) 
     322            if isinstance(data, Data1D): 
     323                try: 
     324                    data.x = data_conv_x(data.x, units=default_q_unit) 
     325                    data._xunit = default_q_unit 
     326                    data.x_unit = default_q_unit 
     327                    if data.dx is not None: 
     328                        data.dx = data_conv_x(data.dx, units=default_q_unit) 
     329                    if data.dxl is not None: 
     330                        data.dxl = data_conv_x(data.dxl, units=default_q_unit) 
     331                    if data.dxw is not None: 
     332                        data.dxw = data_conv_x(data.dxw, units=default_q_unit) 
     333                except KeyError: 
     334                    message = "Unable to convert Q units from {0} to 1/A." 
     335                    message.format(default_q_unit) 
     336                    data.errors.append(message) 
     337                try: 
     338                    data.y = data_conv_y(data.y, units=default_i_unit) 
     339                    data._yunit = default_i_unit 
     340                    data.y_unit = default_i_unit 
     341                    if data.dy is not None: 
     342                        data.dy = data_conv_y(data.dy, units=default_i_unit) 
     343                except KeyError: 
     344                    message = "Unable to convert I units from {0} to 1/cm." 
     345                    message.format(default_q_unit) 
     346                    data.errors.append(message) 
     347            elif isinstance(data, Data2D): 
     348                try: 
     349                    data.qx_data = data_conv_x(data.qx_data, 
     350                                               units=default_q_unit) 
     351                    if data.dqx_data is not None: 
     352                        data.dqx_data = data_conv_x(data.dqx_data, 
     353                                                    units=default_q_unit) 
     354                    data.qy_data = data_conv_y(data.qy_data, 
     355                                               units=default_q_unit) 
     356                    if data.dqy_data is not None: 
     357                        data.dqy_data = data_conv_y(data.dqy_data, 
     358                                                    units=default_q_unit) 
     359                except KeyError: 
     360                    message = "Unable to convert Q units from {0} to 1/A." 
     361                    message.format(default_q_unit) 
     362                    data.errors.append(message) 
     363                try: 
     364                    file_z_unit = data._zunit 
     365                    data_conv_z = Converter(file_z_unit) 
     366                    data.data = data_conv_z(data.data, units=default_i_unit) 
     367                    if data.err_data is not None: 
     368                        data.err_data = data_conv_z(data.err_data, 
     369                                                    units=default_i_unit) 
     370                except KeyError: 
     371                    message = "Unable to convert I units from {0} to 1/cm." 
     372                    message.format(default_q_unit) 
     373                    data.errors.append(message) 
     374            else: 
     375                # TODO: Throw error of some sort... 
     376                pass 
     377            new_output.append(data) 
     378        self.output = new_output 
    204379 
    205380    def format_unit(self, unit=None): 
     
    292467                    self.current_dataset.qy_data)) 
    293468            if has_error_dy: 
    294                 self.current_dataset.err_data = self.current_dataset.err_data[x != 0] 
     469                self.current_dataset.err_data = self.current_dataset.err_data[ 
     470                    x != 0] 
    295471            if has_error_dqx: 
    296                 self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0] 
     472                self.current_dataset.dqx_data = self.current_dataset.dqx_data[ 
     473                    x != 0] 
    297474            if has_error_dqy: 
    298                 self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0] 
     475                self.current_dataset.dqy_data = self.current_dataset.dqy_data[ 
     476                    x != 0] 
    299477            if has_mask: 
    300478                self.current_dataset.mask = self.current_dataset.mask[x != 0] 
     
    314492    def splitline(line): 
    315493        """ 
    316         Splits a line into pieces based on common delimeters 
     494        Splits a line into pieces based on common delimiters 
    317495        :param line: A single line of text 
    318496        :return: list of values 
Note: See TracChangeset for help on using the changeset viewer.