source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ a58b5a0

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since a58b5a0 was a58b5a0, checked in by krzywon, 6 years ago

Update documentation to match code.

  • Property mode set to 100644
File size: 15.5 KB
RevLine 
[beba407]1"""
[b09095a]2This is the base file reader class most file readers should inherit from.
[beba407]3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
[7b50f14]8import sys
[8475d16]9import math
[beba407]10import logging
11from abc import abstractmethod
[574adc7]12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
[da8bb53]15    DataReaderException, DefaultReaderException
[574adc7]16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
[beba407]17    combine_data_info_with_plottable
18
19logger = logging.getLogger(__name__)
20
[7b50f14]21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
[beba407]27
[6fd7b20]28# Data 1D fields for iterative purposes
29FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
30# Data 2D fields for iterative purposes
31FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
[340291a]32                 'dqx_data', 'dqy_data', 'mask')
[6fd7b20]33
34
[beba407]35class FileReader(object):
[b09095a]36    # String to describe the type of data this reader can load
37    type_name = "ASCII"
38    # Wildcards to display
39    type = ["Text files (*.txt|*.TXT)"]
[beba407]40    # List of allowed extensions
41    ext = ['.txt']
42    # Bypass extension check and try to load anyway
43    allow_all = False
[b09095a]44    # Able to import the unit converter
45    has_converter = True
46    # Default value of zero
47    _ZERO = 1e-16
[beba407]48
[cb11a25]49    def __init__(self):
50        # List of Data1D and Data2D objects to be sent back to data_loader
51        self.output = []
52        # Current plottable_(1D/2D) object being loaded in
53        self.current_dataset = None
54        # Current DataInfo object being loaded in
55        self.current_datainfo = None
[3053a4a]56        # File path sent to reader
57        self.filepath = None
[cb11a25]58        # Open file handle
59        self.f_open = None
60
[beba407]61    def read(self, filepath):
62        """
[bc570f4]63        Basic file reader
64
[beba407]65        :param filepath: The full or relative path to a file to be loaded
66        """
[3053a4a]67        self.filepath = filepath
[beba407]68        if os.path.isfile(filepath):
69            basename, extension = os.path.splitext(os.path.basename(filepath))
[da8bb53]70            self.extension = extension.lower()
[beba407]71            # If the file type is not allowed, return nothing
[da8bb53]72            if self.extension in self.ext or self.allow_all:
[beba407]73                # Try to load the file, but raise an error if unable to.
74                try:
[b09095a]75                    self.f_open = open(filepath, 'rb')
76                    self.get_file_contents()
[0b79323]77
[bc570f4]78                except DataReaderException as e:
[da8bb53]79                    self.handle_error_message(e.message)
[beba407]80                except OSError as e:
[b09095a]81                    # If the file cannot be opened
[beba407]82                    msg = "Unable to open file: {}\n".format(filepath)
83                    msg += e.message
84                    self.handle_error_message(msg)
[b09095a]85                finally:
[da8bb53]86                    # Close the file handle if it is open
[b09095a]87                    if not self.f_open.closed:
88                        self.f_open.close()
[248ff73]89                    if len(self.output) > 0:
90                        # Sort the data that's been loaded
91                        self.sort_one_d_data()
92                        self.sort_two_d_data()
[beba407]93        else:
94            msg = "Unable to find file at: {}\n".format(filepath)
95            msg += "Please check your file path and try again."
96            self.handle_error_message(msg)
[a78433dd]97
[b09095a]98        # Return a list of parsed entries that data_loader can manage
[1576693]99        final_data = self.output
100        self.reset_state()
101        return final_data
[beba407]102
[61f329f0]103    def reset_state(self):
104        """
105        Resets the class state to a base case when loading a new data file so previous
106        data files do not appear a second time
107        """
108        self.current_datainfo = None
109        self.current_dataset = None
[3053a4a]110        self.filepath = None
[8475d16]111        self.ind = None
[61f329f0]112        self.output = []
113
[26183bf]114    def nextline(self):
115        """
116        Returns the next line in the file as a string.
117        """
118        #return self.f_open.readline()
[7b50f14]119        return decode(self.f_open.readline())
[26183bf]120
121    def nextlines(self):
122        """
123        Returns the next line in the file as a string.
124        """
125        for line in self.f_open:
126            #yield line
[7b50f14]127            yield decode(line)
[26183bf]128
129    def readall(self):
130        """
131        Returns the entire file as a string.
132        """
133        #return self.f_open.read()
[7b50f14]134        return decode(self.f_open.read())
[26183bf]135
[beba407]136    def handle_error_message(self, msg):
137        """
138        Generic error handler to add an error to the current datainfo to
[20fa5fe]139        propagate the error up the error chain.
[beba407]140        :param msg: Error message
141        """
[dcb91cf]142        if len(self.output) > 0:
143            self.output[-1].errors.append(msg)
144        elif isinstance(self.current_datainfo, DataInfo):
[beba407]145            self.current_datainfo.errors.append(msg)
146        else:
147            logger.warning(msg)
148
149    def send_to_output(self):
150        """
151        Helper that automatically combines the info and set and then appends it
152        to output
153        """
154        data_obj = combine_data_info_with_plottable(self.current_dataset,
155                                                    self.current_datainfo)
156        self.output.append(data_obj)
157
[b09095a]158    def sort_one_d_data(self):
159        """
160        Sort 1D data along the X axis for consistency
161        """
162        for data in self.output:
163            if isinstance(data, Data1D):
[a78a02f]164                # Normalize the units for
165                data.x_unit = self.format_unit(data.x_unit)
166                data.y_unit = self.format_unit(data.y_unit)
[7477fb9]167                # Sort data by increasing x and remove 1st point
[e3133dc]168                ind = np.lexsort((data.y, data.x))
169                data.x = self._reorder_1d_array(data.x, ind)
170                data.y = self._reorder_1d_array(data.y, ind)
[b09095a]171                if data.dx is not None:
[4660990]172                    if len(data.dx) == 0:
173                        data.dx = None
174                        continue
[e3133dc]175                    data.dx = self._reorder_1d_array(data.dx, ind)
[b09095a]176                if data.dxl is not None:
[e3133dc]177                    data.dxl = self._reorder_1d_array(data.dxl, ind)
[b09095a]178                if data.dxw is not None:
[e3133dc]179                    data.dxw = self._reorder_1d_array(data.dxw, ind)
[b09095a]180                if data.dy is not None:
[4660990]181                    if len(data.dy) == 0:
182                        data.dy = None
183                        continue
[e3133dc]184                    data.dy = self._reorder_1d_array(data.dy, ind)
[b09095a]185                if data.lam is not None:
[e3133dc]186                    data.lam = self._reorder_1d_array(data.lam, ind)
[b09095a]187                if data.dlam is not None:
[e3133dc]188                    data.dlam = self._reorder_1d_array(data.dlam, ind)
[f02a0c6]189                data = self._remove_nans_in_data(data)
[dcb91cf]190                if len(data.x) > 0:
[248ff73]191                    data.xmin = np.min(data.x)
192                    data.xmax = np.max(data.x)
193                    data.ymin = np.min(data.y)
194                    data.ymax = np.max(data.y)
[b09095a]195
[e3133dc]196    @staticmethod
197    def _reorder_1d_array(array, ind):
198        """
199        Reorders a 1D array based on the indices passed as ind
200        :param array: Array to be reordered
201        :param ind: Indices used to reorder array
202        :return: reordered array
203        """
204        array = np.asarray(array, dtype=np.float64)
205        return array[ind]
206
207    @staticmethod
[f02a0c6]208    def _remove_nans_in_data(data):
[e3133dc]209        """
210        Remove data points where nan is loaded
[a58b5a0]211        :param data: 1D or 2D data object
212        :return: data with nan points removed
[e3133dc]213        """
[f02a0c6]214        if isinstance(data, Data1D):
[6fd7b20]215            fields = FIELDS_1D
[f02a0c6]216        elif isinstance(data, Data2D):
[6fd7b20]217            fields = FIELDS_2D
[f02a0c6]218        else:
[6fd7b20]219            return data
[a58b5a0]220        # Make array of good points - all others will be removed
[6fd7b20]221        good = np.isfinite(getattr(data, fields[0]))
222        for name in fields[1:]:
223            array = getattr(data, name)
[e3133dc]224            if array is not None:
[a58b5a0]225                # Update good points only if not already changed
[6fd7b20]226                good &= np.isfinite(array)
227        if not np.all(good):
228            for name in fields:
229                array = getattr(data, name)
230                if array is not None:
231                    setattr(data, name, array[good])
[e3133dc]232        return data
[8475d16]233
[0b79323]234    def sort_two_d_data(self):
235        for dataset in self.output:
[9d786e5]236            if isinstance(dataset, Data2D):
[a78a02f]237                # Normalize the units for
238                dataset.x_unit = self.format_unit(dataset.Q_unit)
239                dataset.y_unit = self.format_unit(dataset.I_unit)
[9d786e5]240                dataset.data = dataset.data.astype(np.float64)
241                dataset.qx_data = dataset.qx_data.astype(np.float64)
242                dataset.xmin = np.min(dataset.qx_data)
243                dataset.xmax = np.max(dataset.qx_data)
244                dataset.qy_data = dataset.qy_data.astype(np.float64)
245                dataset.ymin = np.min(dataset.qy_data)
246                dataset.ymax = np.max(dataset.qy_data)
247                dataset.q_data = np.sqrt(dataset.qx_data * dataset.qx_data
248                                         + dataset.qy_data * dataset.qy_data)
249                if dataset.err_data is not None:
250                    dataset.err_data = dataset.err_data.astype(np.float64)
251                if dataset.dqx_data is not None:
252                    dataset.dqx_data = dataset.dqx_data.astype(np.float64)
253                if dataset.dqy_data is not None:
254                    dataset.dqy_data = dataset.dqy_data.astype(np.float64)
255                if dataset.mask is not None:
256                    dataset.mask = dataset.mask.astype(dtype=bool)
257
258                if len(dataset.data.shape) == 2:
259                    n_rows, n_cols = dataset.data.shape
260                    dataset.y_bins = dataset.qy_data[0::int(n_cols)]
261                    dataset.x_bins = dataset.qx_data[:int(n_cols)]
[2f85af7]262                dataset.data = dataset.data.flatten()
[f02a0c6]263                dataset = self._remove_nans_in_data(dataset)
[deaa0c6]264                if len(dataset.data) > 0:
265                    dataset.xmin = np.min(dataset.qx_data)
266                    dataset.xmax = np.max(dataset.qx_data)
267                    dataset.ymin = np.min(dataset.qy_data)
268                    dataset.ymax = np.max(dataset.qx_data)
[0b79323]269
[a78a02f]270    def format_unit(self, unit=None):
271        """
272        Format units a common way
273        :param unit:
274        :return:
275        """
276        if unit:
277            split = unit.split("/")
278            if len(split) == 1:
279                return unit
280            elif split[0] == '1':
281                return "{0}^".format(split[1]) + "{-1}"
282            else:
283                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
284
[da8bb53]285    def set_all_to_none(self):
286        """
287        Set all mutable values to None for error handling purposes
288        """
289        self.current_dataset = None
290        self.current_datainfo = None
291        self.output = []
292
[7b07fbe]293    def data_cleanup(self):
294        """
295        Clean up the data sets and refresh everything
296        :return: None
297        """
298        self.remove_empty_q_values()
299        self.send_to_output()  # Combine datasets with DataInfo
300        self.current_datainfo = DataInfo()  # Reset DataInfo
301
302    def remove_empty_q_values(self):
[ad92c5a]303        """
304        Remove any point where Q == 0
305        """
[7b07fbe]306        if isinstance(self.current_dataset, plottable_1D):
307            # Booleans for resolutions
308            has_error_dx = self.current_dataset.dx is not None
309            has_error_dxl = self.current_dataset.dxl is not None
310            has_error_dxw = self.current_dataset.dxw is not None
311            has_error_dy = self.current_dataset.dy is not None
312            # Create arrays of zeros for non-existent resolutions
313            if has_error_dxw and not has_error_dxl:
314                array_size = self.current_dataset.dxw.size - 1
315                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
316                                                    np.zeros([array_size]))
317                has_error_dxl = True
318            elif has_error_dxl and not has_error_dxw:
319                array_size = self.current_dataset.dxl.size - 1
320                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
321                                                    np.zeros([array_size]))
322                has_error_dxw = True
323            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
324                array_size = self.current_dataset.x.size - 1
325                self.current_dataset.dx = np.append(self.current_dataset.dx,
326                                                    np.zeros([array_size]))
327                has_error_dx = True
328            if not has_error_dy:
329                array_size = self.current_dataset.y.size - 1
330                self.current_dataset.dy = np.append(self.current_dataset.dy,
331                                                    np.zeros([array_size]))
332                has_error_dy = True
333
334            # Remove points where q = 0
335            x = self.current_dataset.x
336            self.current_dataset.x = self.current_dataset.x[x != 0]
337            self.current_dataset.y = self.current_dataset.y[x != 0]
338            if has_error_dy:
339                self.current_dataset.dy = self.current_dataset.dy[x != 0]
340            if has_error_dx:
341                self.current_dataset.dx = self.current_dataset.dx[x != 0]
342            if has_error_dxl:
343                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
344            if has_error_dxw:
345                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
346        elif isinstance(self.current_dataset, plottable_2D):
347            has_error_dqx = self.current_dataset.dqx_data is not None
348            has_error_dqy = self.current_dataset.dqy_data is not None
349            has_error_dy = self.current_dataset.err_data is not None
350            has_mask = self.current_dataset.mask is not None
351            x = self.current_dataset.qx_data
352            self.current_dataset.data = self.current_dataset.data[x != 0]
353            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
354            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
[deaa0c6]355            self.current_dataset.q_data = np.sqrt(
356                np.square(self.current_dataset.qx_data) + np.square(
357                    self.current_dataset.qy_data))
[7b07fbe]358            if has_error_dy:
359                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
360            if has_error_dqx:
361                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
362            if has_error_dqy:
363                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
364            if has_mask:
365                self.current_dataset.mask = self.current_dataset.mask[x != 0]
[ad92c5a]366
367    def reset_data_list(self, no_lines=0):
368        """
369        Reset the plottable_1D object
370        """
371        # Initialize data sets with arrays the maximum possible size
372        x = np.zeros(no_lines)
373        y = np.zeros(no_lines)
[4660990]374        dx = np.zeros(no_lines)
375        dy = np.zeros(no_lines)
376        self.current_dataset = plottable_1D(x, y, dx, dy)
[ad92c5a]377
[b09095a]378    @staticmethod
379    def splitline(line):
380        """
[20fa5fe]381        Splits a line into pieces based on common delimiters
[b09095a]382        :param line: A single line of text
383        :return: list of values
384        """
385        # Initial try for CSV (split on ,)
386        toks = line.split(',')
387        # Now try SCSV (split on ;)
388        if len(toks) < 2:
389            toks = line.split(';')
390        # Now go for whitespace
391        if len(toks) < 2:
392            toks = line.split()
393        return toks
394
[beba407]395    @abstractmethod
[b09095a]396    def get_file_contents(self):
[beba407]397        """
[ad92c5a]398        Reader specific class to access the contents of the file
[b09095a]399        All reader classes that inherit from FileReader must implement
[beba407]400        """
401        pass
Note: See TracBrowser for help on using the repository browser.