source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ 2924532

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since 2924532 was 2924532, checked in by krzywon, 5 years ago

Cleanup of the SasView? GUI data loader error handling and a more specific error message for deprecated file extension.

  • Property mode set to 100644
File size: 16.2 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28# Data 1D fields for iterative purposes
29FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
30# Data 2D fields for iterative purposes
31FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
32                 'dqx_data', 'dqy_data', 'mask')
33DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
34                       "t not be fully reduced. Support for the reader associat"
35                       "ed with this file type has been removed. An attempt to "
36                       "load the file was made, however, even if a data set was"
37                       " generated, SasView cannot guarantee the accuracy of th"
38                       "e data.")
39
40class FileReader(object):
41    # String to describe the type of data this reader can load
42    type_name = "ASCII"
43    # Wildcards to display
44    type = ["Text files (*.txt|*.TXT)"]
45    # List of allowed extensions
46    ext = ['.txt']
47    # Deprecated extensions
48    deprecated_extensions = ['.asc', '.nxs']
49    # Bypass extension check and try to load anyway
50    allow_all = False
51    # Able to import the unit converter
52    has_converter = True
53    # Default value of zero
54    _ZERO = 1e-16
55
56    def __init__(self):
57        # List of Data1D and Data2D objects to be sent back to data_loader
58        self.output = []
59        # Current plottable_(1D/2D) object being loaded in
60        self.current_dataset = None
61        # Current DataInfo object being loaded in
62        self.current_datainfo = None
63        # File path sent to reader
64        self.filepath = None
65        # Open file handle
66        self.f_open = None
67
68    def read(self, filepath):
69        """
70        Basic file reader
71
72        :param filepath: The full or relative path to a file to be loaded
73        """
74        self.filepath = filepath
75        if os.path.isfile(filepath):
76            basename, extension = os.path.splitext(os.path.basename(filepath))
77            self.extension = extension.lower()
78            # If the file type is not allowed, return nothing
79            if self.extension in self.ext or self.allow_all:
80                # Try to load the file, but raise an error if unable to.
81                try:
82                    self.f_open = open(filepath, 'rb')
83                    self.get_file_contents()
84
85                except DataReaderException as e:
86                    self.handle_error_message(e.message)
87                except OSError as e:
88                    # If the file cannot be opened
89                    msg = "Unable to open file: {}\n".format(filepath)
90                    msg += e.message
91                    self.handle_error_message(msg)
92                finally:
93                    # Close the file handle if it is open
94                    if not self.f_open.closed:
95                        self.f_open.close()
96                    if any(filepath.lower().endswith(ext) for ext in
97                           self.deprecated_extensions):
98                        self.handle_error_message(DEPRECATION_MESSAGE)
99                    if len(self.output) > 0:
100                        # Sort the data that's been loaded
101                        self.sort_one_d_data()
102                        self.sort_two_d_data()
103        else:
104            msg = "Unable to find file at: {}\n".format(filepath)
105            msg += "Please check your file path and try again."
106            self.handle_error_message(msg)
107
108        # Return a list of parsed entries that data_loader can manage
109        final_data = self.output
110        self.reset_state()
111        return final_data
112
113    def reset_state(self):
114        """
115        Resets the class state to a base case when loading a new data file so previous
116        data files do not appear a second time
117        """
118        self.current_datainfo = None
119        self.current_dataset = None
120        self.filepath = None
121        self.ind = None
122        self.output = []
123
124    def nextline(self):
125        """
126        Returns the next line in the file as a string.
127        """
128        #return self.f_open.readline()
129        return decode(self.f_open.readline())
130
131    def nextlines(self):
132        """
133        Returns the next line in the file as a string.
134        """
135        for line in self.f_open:
136            #yield line
137            yield decode(line)
138
139    def readall(self):
140        """
141        Returns the entire file as a string.
142        """
143        #return self.f_open.read()
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158            raise NoKnownLoaderException(msg)
159
160    def send_to_output(self):
161        """
162        Helper that automatically combines the info and set and then appends it
163        to output
164        """
165        data_obj = combine_data_info_with_plottable(self.current_dataset,
166                                                    self.current_datainfo)
167        self.output.append(data_obj)
168
169    def sort_one_d_data(self):
170        """
171        Sort 1D data along the X axis for consistency
172        """
173        for data in self.output:
174            if isinstance(data, Data1D):
175                # Normalize the units for
176                data.x_unit = self.format_unit(data.x_unit)
177                data.y_unit = self.format_unit(data.y_unit)
178                # Sort data by increasing x and remove 1st point
179                ind = np.lexsort((data.y, data.x))
180                data.x = self._reorder_1d_array(data.x, ind)
181                data.y = self._reorder_1d_array(data.y, ind)
182                if data.dx is not None:
183                    if len(data.dx) == 0:
184                        data.dx = None
185                        continue
186                    data.dx = self._reorder_1d_array(data.dx, ind)
187                if data.dxl is not None:
188                    data.dxl = self._reorder_1d_array(data.dxl, ind)
189                if data.dxw is not None:
190                    data.dxw = self._reorder_1d_array(data.dxw, ind)
191                if data.dy is not None:
192                    if len(data.dy) == 0:
193                        data.dy = None
194                        continue
195                    data.dy = self._reorder_1d_array(data.dy, ind)
196                if data.lam is not None:
197                    data.lam = self._reorder_1d_array(data.lam, ind)
198                if data.dlam is not None:
199                    data.dlam = self._reorder_1d_array(data.dlam, ind)
200                data = self._remove_nans_in_data(data)
201                if len(data.x) > 0:
202                    data.xmin = np.min(data.x)
203                    data.xmax = np.max(data.x)
204                    data.ymin = np.min(data.y)
205                    data.ymax = np.max(data.y)
206
207    @staticmethod
208    def _reorder_1d_array(array, ind):
209        """
210        Reorders a 1D array based on the indices passed as ind
211        :param array: Array to be reordered
212        :param ind: Indices used to reorder array
213        :return: reordered array
214        """
215        array = np.asarray(array, dtype=np.float64)
216        return array[ind]
217
218    @staticmethod
219    def _remove_nans_in_data(data):
220        """
221        Remove data points where nan is loaded
222        :param data: 1D or 2D data object
223        :return: data with nan points removed
224        """
225        if isinstance(data, Data1D):
226            fields = FIELDS_1D
227        elif isinstance(data, Data2D):
228            fields = FIELDS_2D
229        else:
230            return data
231        # Make array of good points - all others will be removed
232        good = np.isfinite(getattr(data, fields[0]))
233        for name in fields[1:]:
234            array = getattr(data, name)
235            if array is not None:
236                # Update good points only if not already changed
237                good &= np.isfinite(array)
238        if not np.all(good):
239            for name in fields:
240                array = getattr(data, name)
241                if array is not None:
242                    setattr(data, name, array[good])
243        return data
244
245    def sort_two_d_data(self):
246        for dataset in self.output:
247            if isinstance(dataset, Data2D):
248                # Normalize the units for
249                dataset.x_unit = self.format_unit(dataset.Q_unit)
250                dataset.y_unit = self.format_unit(dataset.I_unit)
251                dataset.data = dataset.data.astype(np.float64)
252                dataset.qx_data = dataset.qx_data.astype(np.float64)
253                dataset.xmin = np.min(dataset.qx_data)
254                dataset.xmax = np.max(dataset.qx_data)
255                dataset.qy_data = dataset.qy_data.astype(np.float64)
256                dataset.ymin = np.min(dataset.qy_data)
257                dataset.ymax = np.max(dataset.qy_data)
258                dataset.q_data = np.sqrt(dataset.qx_data * dataset.qx_data
259                                         + dataset.qy_data * dataset.qy_data)
260                if dataset.err_data is not None:
261                    dataset.err_data = dataset.err_data.astype(np.float64)
262                if dataset.dqx_data is not None:
263                    dataset.dqx_data = dataset.dqx_data.astype(np.float64)
264                if dataset.dqy_data is not None:
265                    dataset.dqy_data = dataset.dqy_data.astype(np.float64)
266                if dataset.mask is not None:
267                    dataset.mask = dataset.mask.astype(dtype=bool)
268
269                if len(dataset.data.shape) == 2:
270                    n_rows, n_cols = dataset.data.shape
271                    dataset.y_bins = dataset.qy_data[0::int(n_cols)]
272                    dataset.x_bins = dataset.qx_data[:int(n_cols)]
273                dataset.data = dataset.data.flatten()
274                dataset = self._remove_nans_in_data(dataset)
275                if len(dataset.data) > 0:
276                    dataset.xmin = np.min(dataset.qx_data)
277                    dataset.xmax = np.max(dataset.qx_data)
278                    dataset.ymin = np.min(dataset.qy_data)
279                    dataset.ymax = np.max(dataset.qx_data)
280
281    def format_unit(self, unit=None):
282        """
283        Format units a common way
284        :param unit:
285        :return:
286        """
287        if unit:
288            split = unit.split("/")
289            if len(split) == 1:
290                return unit
291            elif split[0] == '1':
292                return "{0}^".format(split[1]) + "{-1}"
293            else:
294                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
295
296    def set_all_to_none(self):
297        """
298        Set all mutable values to None for error handling purposes
299        """
300        self.current_dataset = None
301        self.current_datainfo = None
302        self.output = []
303
304    def data_cleanup(self):
305        """
306        Clean up the data sets and refresh everything
307        :return: None
308        """
309        self.remove_empty_q_values()
310        self.send_to_output()  # Combine datasets with DataInfo
311        self.current_datainfo = DataInfo()  # Reset DataInfo
312
313    def remove_empty_q_values(self):
314        """
315        Remove any point where Q == 0
316        """
317        if isinstance(self.current_dataset, plottable_1D):
318            # Booleans for resolutions
319            has_error_dx = self.current_dataset.dx is not None
320            has_error_dxl = self.current_dataset.dxl is not None
321            has_error_dxw = self.current_dataset.dxw is not None
322            has_error_dy = self.current_dataset.dy is not None
323            # Create arrays of zeros for non-existent resolutions
324            if has_error_dxw and not has_error_dxl:
325                array_size = self.current_dataset.dxw.size - 1
326                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
327                                                    np.zeros([array_size]))
328                has_error_dxl = True
329            elif has_error_dxl and not has_error_dxw:
330                array_size = self.current_dataset.dxl.size - 1
331                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
332                                                    np.zeros([array_size]))
333                has_error_dxw = True
334            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
335                array_size = self.current_dataset.x.size - 1
336                self.current_dataset.dx = np.append(self.current_dataset.dx,
337                                                    np.zeros([array_size]))
338                has_error_dx = True
339            if not has_error_dy:
340                array_size = self.current_dataset.y.size - 1
341                self.current_dataset.dy = np.append(self.current_dataset.dy,
342                                                    np.zeros([array_size]))
343                has_error_dy = True
344
345            # Remove points where q = 0
346            x = self.current_dataset.x
347            self.current_dataset.x = self.current_dataset.x[x != 0]
348            self.current_dataset.y = self.current_dataset.y[x != 0]
349            if has_error_dy:
350                self.current_dataset.dy = self.current_dataset.dy[x != 0]
351            if has_error_dx:
352                self.current_dataset.dx = self.current_dataset.dx[x != 0]
353            if has_error_dxl:
354                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
355            if has_error_dxw:
356                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
357        elif isinstance(self.current_dataset, plottable_2D):
358            has_error_dqx = self.current_dataset.dqx_data is not None
359            has_error_dqy = self.current_dataset.dqy_data is not None
360            has_error_dy = self.current_dataset.err_data is not None
361            has_mask = self.current_dataset.mask is not None
362            x = self.current_dataset.qx_data
363            self.current_dataset.data = self.current_dataset.data[x != 0]
364            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
365            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
366            self.current_dataset.q_data = np.sqrt(
367                np.square(self.current_dataset.qx_data) + np.square(
368                    self.current_dataset.qy_data))
369            if has_error_dy:
370                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
371            if has_error_dqx:
372                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
373            if has_error_dqy:
374                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
375            if has_mask:
376                self.current_dataset.mask = self.current_dataset.mask[x != 0]
377
378    def reset_data_list(self, no_lines=0):
379        """
380        Reset the plottable_1D object
381        """
382        # Initialize data sets with arrays the maximum possible size
383        x = np.zeros(no_lines)
384        y = np.zeros(no_lines)
385        dx = np.zeros(no_lines)
386        dy = np.zeros(no_lines)
387        self.current_dataset = plottable_1D(x, y, dx, dy)
388
389    @staticmethod
390    def splitline(line):
391        """
392        Splits a line into pieces based on common delimiters
393        :param line: A single line of text
394        :return: list of values
395        """
396        # Initial try for CSV (split on ,)
397        toks = line.split(',')
398        # Now try SCSV (split on ;)
399        if len(toks) < 2:
400            toks = line.split(';')
401        # Now go for whitespace
402        if len(toks) < 2:
403            toks = line.split()
404        return toks
405
406    @abstractmethod
407    def get_file_contents(self):
408        """
409        Reader specific class to access the contents of the file
410        All reader classes that inherit from FileReader must implement
411        """
412        pass
Note: See TracBrowser for help on using the repository browser.