source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ b9cc210

Last change on this file since b9cc210 was 4a8d55c, checked in by krzywon, 7 years ago

Propagate through loader when errors are thrown regardless of the error. Add tests using the same file with different extensions (including deprecated extensions).

  • Property mode set to 100644
File size: 16.2 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28# Data 1D fields for iterative purposes
29FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
30# Data 2D fields for iterative purposes
31FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
32                 'dqx_data', 'dqy_data', 'mask')
33DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
34                       "t not be fully reduced. Support for the reader associat"
35                       "ed with this file type has been removed. An attempt to "
36                       "load the file was made, but, should it be successful, "
37                       "SasView cannot guarantee the accuracy of the data.")
38
39class FileReader(object):
40    # String to describe the type of data this reader can load
41    type_name = "ASCII"
42    # Wildcards to display
43    type = ["Text files (*.txt|*.TXT)"]
44    # List of allowed extensions
45    ext = ['.txt']
46    # Deprecated extensions
47    deprecated_extensions = ['.asc', '.nxs']
48    # Bypass extension check and try to load anyway
49    allow_all = False
50    # Able to import the unit converter
51    has_converter = True
52    # Default value of zero
53    _ZERO = 1e-16
54
55    def __init__(self):
56        # List of Data1D and Data2D objects to be sent back to data_loader
57        self.output = []
58        # Current plottable_(1D/2D) object being loaded in
59        self.current_dataset = None
60        # Current DataInfo object being loaded in
61        self.current_datainfo = None
62        # File path sent to reader
63        self.filepath = None
64        # Open file handle
65        self.f_open = None
66
67    def read(self, filepath):
68        """
69        Basic file reader
70
71        :param filepath: The full or relative path to a file to be loaded
72        """
73        self.filepath = filepath
74        if os.path.isfile(filepath):
75            basename, extension = os.path.splitext(os.path.basename(filepath))
76            self.extension = extension.lower()
77            # If the file type is not allowed, return nothing
78            if self.extension in self.ext or self.allow_all:
79                # Try to load the file, but raise an error if unable to.
80                try:
81                    self.f_open = open(filepath, 'rb')
82                    self.get_file_contents()
83
84                except DataReaderException as e:
85                    self.handle_error_message(e.message)
86                except OSError as e:
87                    # If the file cannot be opened
88                    msg = "Unable to open file: {}\n".format(filepath)
89                    msg += e.message
90                    self.handle_error_message(msg)
91                finally:
92                    # Close the file handle if it is open
93                    if not self.f_open.closed:
94                        self.f_open.close()
95                    if any(filepath.lower().endswith(ext) for ext in
96                           self.deprecated_extensions):
97                        self.handle_error_message(DEPRECATION_MESSAGE)
98                    if len(self.output) > 0:
99                        # Sort the data that's been loaded
100                        self.sort_one_d_data()
101                        self.sort_two_d_data()
102        else:
103            msg = "Unable to find file at: {}\n".format(filepath)
104            msg += "Please check your file path and try again."
105            self.handle_error_message(msg)
106
107        # Return a list of parsed entries that data_loader can manage
108        final_data = self.output
109        self.reset_state()
110        return final_data
111
112    def reset_state(self):
113        """
114        Resets the class state to a base case when loading a new data file so previous
115        data files do not appear a second time
116        """
117        self.current_datainfo = None
118        self.current_dataset = None
119        self.filepath = None
120        self.ind = None
121        self.output = []
122
123    def nextline(self):
124        """
125        Returns the next line in the file as a string.
126        """
127        #return self.f_open.readline()
128        return decode(self.f_open.readline())
129
130    def nextlines(self):
131        """
132        Returns the next line in the file as a string.
133        """
134        for line in self.f_open:
135            #yield line
136            yield decode(line)
137
138    def readall(self):
139        """
140        Returns the entire file as a string.
141        """
142        #return self.f_open.read()
143        return decode(self.f_open.read())
144
145    def handle_error_message(self, msg):
146        """
147        Generic error handler to add an error to the current datainfo to
148        propagate the error up the error chain.
149        :param msg: Error message
150        """
151        if len(self.output) > 0:
152            self.output[-1].errors.append(msg)
153        elif isinstance(self.current_datainfo, DataInfo):
154            self.current_datainfo.errors.append(msg)
155        else:
156            logger.warning(msg)
157            raise NoKnownLoaderException(msg)
158
159    def send_to_output(self):
160        """
161        Helper that automatically combines the info and set and then appends it
162        to output
163        """
164        data_obj = combine_data_info_with_plottable(self.current_dataset,
165                                                    self.current_datainfo)
166        self.output.append(data_obj)
167
168    def sort_one_d_data(self):
169        """
170        Sort 1D data along the X axis for consistency
171        """
172        for data in self.output:
173            if isinstance(data, Data1D):
174                # Normalize the units for
175                data.x_unit = self.format_unit(data.x_unit)
176                data.y_unit = self.format_unit(data.y_unit)
177                # Sort data by increasing x and remove 1st point
178                ind = np.lexsort((data.y, data.x))
179                data.x = self._reorder_1d_array(data.x, ind)
180                data.y = self._reorder_1d_array(data.y, ind)
181                if data.dx is not None:
182                    if len(data.dx) == 0:
183                        data.dx = None
184                        continue
185                    data.dx = self._reorder_1d_array(data.dx, ind)
186                if data.dxl is not None:
187                    data.dxl = self._reorder_1d_array(data.dxl, ind)
188                if data.dxw is not None:
189                    data.dxw = self._reorder_1d_array(data.dxw, ind)
190                if data.dy is not None:
191                    if len(data.dy) == 0:
192                        data.dy = None
193                        continue
194                    data.dy = self._reorder_1d_array(data.dy, ind)
195                if data.lam is not None:
196                    data.lam = self._reorder_1d_array(data.lam, ind)
197                if data.dlam is not None:
198                    data.dlam = self._reorder_1d_array(data.dlam, ind)
199                data = self._remove_nans_in_data(data)
200                if len(data.x) > 0:
201                    data.xmin = np.min(data.x)
202                    data.xmax = np.max(data.x)
203                    data.ymin = np.min(data.y)
204                    data.ymax = np.max(data.y)
205
206    @staticmethod
207    def _reorder_1d_array(array, ind):
208        """
209        Reorders a 1D array based on the indices passed as ind
210        :param array: Array to be reordered
211        :param ind: Indices used to reorder array
212        :return: reordered array
213        """
214        array = np.asarray(array, dtype=np.float64)
215        return array[ind]
216
217    @staticmethod
218    def _remove_nans_in_data(data):
219        """
220        Remove data points where nan is loaded
221        :param data: 1D or 2D data object
222        :return: data with nan points removed
223        """
224        if isinstance(data, Data1D):
225            fields = FIELDS_1D
226        elif isinstance(data, Data2D):
227            fields = FIELDS_2D
228        else:
229            return data
230        # Make array of good points - all others will be removed
231        good = np.isfinite(getattr(data, fields[0]))
232        for name in fields[1:]:
233            array = getattr(data, name)
234            if array is not None:
235                # Update good points only if not already changed
236                good &= np.isfinite(array)
237        if not np.all(good):
238            for name in fields:
239                array = getattr(data, name)
240                if array is not None:
241                    setattr(data, name, array[good])
242        return data
243
244    def sort_two_d_data(self):
245        for dataset in self.output:
246            if isinstance(dataset, Data2D):
247                # Normalize the units for
248                dataset.x_unit = self.format_unit(dataset.Q_unit)
249                dataset.y_unit = self.format_unit(dataset.I_unit)
250                dataset.data = dataset.data.astype(np.float64)
251                dataset.qx_data = dataset.qx_data.astype(np.float64)
252                dataset.xmin = np.min(dataset.qx_data)
253                dataset.xmax = np.max(dataset.qx_data)
254                dataset.qy_data = dataset.qy_data.astype(np.float64)
255                dataset.ymin = np.min(dataset.qy_data)
256                dataset.ymax = np.max(dataset.qy_data)
257                dataset.q_data = np.sqrt(dataset.qx_data * dataset.qx_data
258                                         + dataset.qy_data * dataset.qy_data)
259                if dataset.err_data is not None:
260                    dataset.err_data = dataset.err_data.astype(np.float64)
261                if dataset.dqx_data is not None:
262                    dataset.dqx_data = dataset.dqx_data.astype(np.float64)
263                if dataset.dqy_data is not None:
264                    dataset.dqy_data = dataset.dqy_data.astype(np.float64)
265                if dataset.mask is not None:
266                    dataset.mask = dataset.mask.astype(dtype=bool)
267
268                if len(dataset.data.shape) == 2:
269                    n_rows, n_cols = dataset.data.shape
270                    dataset.y_bins = dataset.qy_data[0::int(n_cols)]
271                    dataset.x_bins = dataset.qx_data[:int(n_cols)]
272                dataset.data = dataset.data.flatten()
273                dataset = self._remove_nans_in_data(dataset)
274                if len(dataset.data) > 0:
275                    dataset.xmin = np.min(dataset.qx_data)
276                    dataset.xmax = np.max(dataset.qx_data)
277                    dataset.ymin = np.min(dataset.qy_data)
278                    dataset.ymax = np.max(dataset.qx_data)
279
280    def format_unit(self, unit=None):
281        """
282        Format units a common way
283        :param unit:
284        :return:
285        """
286        if unit:
287            split = unit.split("/")
288            if len(split) == 1:
289                return unit
290            elif split[0] == '1':
291                return "{0}^".format(split[1]) + "{-1}"
292            else:
293                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
294
295    def set_all_to_none(self):
296        """
297        Set all mutable values to None for error handling purposes
298        """
299        self.current_dataset = None
300        self.current_datainfo = None
301        self.output = []
302
303    def data_cleanup(self):
304        """
305        Clean up the data sets and refresh everything
306        :return: None
307        """
308        self.remove_empty_q_values()
309        self.send_to_output()  # Combine datasets with DataInfo
310        self.current_datainfo = DataInfo()  # Reset DataInfo
311
312    def remove_empty_q_values(self):
313        """
314        Remove any point where Q == 0
315        """
316        if isinstance(self.current_dataset, plottable_1D):
317            # Booleans for resolutions
318            has_error_dx = self.current_dataset.dx is not None
319            has_error_dxl = self.current_dataset.dxl is not None
320            has_error_dxw = self.current_dataset.dxw is not None
321            has_error_dy = self.current_dataset.dy is not None
322            # Create arrays of zeros for non-existent resolutions
323            if has_error_dxw and not has_error_dxl:
324                array_size = self.current_dataset.dxw.size - 1
325                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
326                                                    np.zeros([array_size]))
327                has_error_dxl = True
328            elif has_error_dxl and not has_error_dxw:
329                array_size = self.current_dataset.dxl.size - 1
330                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
331                                                    np.zeros([array_size]))
332                has_error_dxw = True
333            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
334                array_size = self.current_dataset.x.size - 1
335                self.current_dataset.dx = np.append(self.current_dataset.dx,
336                                                    np.zeros([array_size]))
337                has_error_dx = True
338            if not has_error_dy:
339                array_size = self.current_dataset.y.size - 1
340                self.current_dataset.dy = np.append(self.current_dataset.dy,
341                                                    np.zeros([array_size]))
342                has_error_dy = True
343
344            # Remove points where q = 0
345            x = self.current_dataset.x
346            self.current_dataset.x = self.current_dataset.x[x != 0]
347            self.current_dataset.y = self.current_dataset.y[x != 0]
348            if has_error_dy:
349                self.current_dataset.dy = self.current_dataset.dy[x != 0]
350            if has_error_dx:
351                self.current_dataset.dx = self.current_dataset.dx[x != 0]
352            if has_error_dxl:
353                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
354            if has_error_dxw:
355                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
356        elif isinstance(self.current_dataset, plottable_2D):
357            has_error_dqx = self.current_dataset.dqx_data is not None
358            has_error_dqy = self.current_dataset.dqy_data is not None
359            has_error_dy = self.current_dataset.err_data is not None
360            has_mask = self.current_dataset.mask is not None
361            x = self.current_dataset.qx_data
362            self.current_dataset.data = self.current_dataset.data[x != 0]
363            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
364            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
365            self.current_dataset.q_data = np.sqrt(
366                np.square(self.current_dataset.qx_data) + np.square(
367                    self.current_dataset.qy_data))
368            if has_error_dy:
369                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
370            if has_error_dqx:
371                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
372            if has_error_dqy:
373                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
374            if has_mask:
375                self.current_dataset.mask = self.current_dataset.mask[x != 0]
376
377    def reset_data_list(self, no_lines=0):
378        """
379        Reset the plottable_1D object
380        """
381        # Initialize data sets with arrays the maximum possible size
382        x = np.zeros(no_lines)
383        y = np.zeros(no_lines)
384        dx = np.zeros(no_lines)
385        dy = np.zeros(no_lines)
386        self.current_dataset = plottable_1D(x, y, dx, dy)
387
388    @staticmethod
389    def splitline(line):
390        """
391        Splits a line into pieces based on common delimiters
392        :param line: A single line of text
393        :return: list of values
394        """
395        # Initial try for CSV (split on ,)
396        toks = line.split(',')
397        # Now try SCSV (split on ;)
398        if len(toks) < 2:
399            toks = line.split(';')
400        # Now go for whitespace
401        if len(toks) < 2:
402            toks = line.split()
403        return toks
404
405    @abstractmethod
406    def get_file_contents(self):
407        """
408        Reader specific class to access the contents of the file
409        All reader classes that inherit from FileReader must implement
410        """
411        pass
Note: See TracBrowser for help on using the repository browser.