source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ c36c09f

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since c36c09f was c36c09f, checked in by krzywon, 6 years ago

Show a dialog warning the user the data set may not load properly. All logic is stored within sascalc.dataloader.

  • Property mode set to 100644
File size: 16.1 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28# Data 1D fields for iterative purposes
29FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
30# Data 2D fields for iterative purposes
31FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
32                 'dqx_data', 'dqy_data', 'mask')
33DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
34                       "t not be fully reduced. Support for the reader associat"
35                       "ed with this file type has been removed. An attempt to "
36                       "load the file was made, but SasView cannot guarantee th"
37                       "e accuracy of the data.")
38
39
40class FileReader(object):
41    # String to describe the type of data this reader can load
42    type_name = "ASCII"
43    # Wildcards to display
44    type = ["Text files (*.txt|*.TXT)"]
45    # List of allowed extensions
46    ext = ['.txt']
47    # Deprecated extensions
48    deprecated_extensions = ['.asc', '.nxs']
49    # Bypass extension check and try to load anyway
50    allow_all = False
51    # Able to import the unit converter
52    has_converter = True
53    # Default value of zero
54    _ZERO = 1e-16
55
56    def __init__(self):
57        # List of Data1D and Data2D objects to be sent back to data_loader
58        self.output = []
59        # Current plottable_(1D/2D) object being loaded in
60        self.current_dataset = None
61        # Current DataInfo object being loaded in
62        self.current_datainfo = None
63        # File path sent to reader
64        self.filepath = None
65        # Open file handle
66        self.f_open = None
67
68    def read(self, filepath):
69        """
70        Basic file reader
71
72        :param filepath: The full or relative path to a file to be loaded
73        """
74        self.filepath = filepath
75        if os.path.isfile(filepath):
76            basename, extension = os.path.splitext(os.path.basename(filepath))
77            self.extension = extension.lower()
78            # If the file type is not allowed, return nothing
79            if self.extension in self.ext or self.allow_all:
80                # Try to load the file, but raise an error if unable to.
81                try:
82                    self.f_open = open(filepath, 'rb')
83                    self.get_file_contents()
84
85                except DataReaderException as e:
86                    self.handle_error_message(e.message)
87                except OSError as e:
88                    # If the file cannot be opened
89                    msg = "Unable to open file: {}\n".format(filepath)
90                    msg += e.message
91                    self.handle_error_message(msg)
92                finally:
93                    # Close the file handle if it is open
94                    if not self.f_open.closed:
95                        self.f_open.close()
96                    if len(self.output) > 0:
97                        if any(filepath.lower().endswith(ext) for ext in
98                               self.deprecated_extensions):
99                            self.handle_error_message(DEPRECATION_MESSAGE)
100                        # Sort the data that's been loaded
101                        self.sort_one_d_data()
102                        self.sort_two_d_data()
103        else:
104            msg = "Unable to find file at: {}\n".format(filepath)
105            msg += "Please check your file path and try again."
106            self.handle_error_message(msg)
107
108        # Return a list of parsed entries that data_loader can manage
109        final_data = self.output
110        self.reset_state()
111        return final_data
112
113    def reset_state(self):
114        """
115        Resets the class state to a base case when loading a new data file so previous
116        data files do not appear a second time
117        """
118        self.current_datainfo = None
119        self.current_dataset = None
120        self.filepath = None
121        self.ind = None
122        self.output = []
123
124    def nextline(self):
125        """
126        Returns the next line in the file as a string.
127        """
128        #return self.f_open.readline()
129        return decode(self.f_open.readline())
130
131    def nextlines(self):
132        """
133        Returns the next line in the file as a string.
134        """
135        for line in self.f_open:
136            #yield line
137            yield decode(line)
138
139    def readall(self):
140        """
141        Returns the entire file as a string.
142        """
143        #return self.f_open.read()
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158
159    def send_to_output(self):
160        """
161        Helper that automatically combines the info and set and then appends it
162        to output
163        """
164        data_obj = combine_data_info_with_plottable(self.current_dataset,
165                                                    self.current_datainfo)
166        self.output.append(data_obj)
167
168    def sort_one_d_data(self):
169        """
170        Sort 1D data along the X axis for consistency
171        """
172        for data in self.output:
173            if isinstance(data, Data1D):
174                # Normalize the units for
175                data.x_unit = self.format_unit(data.x_unit)
176                data.y_unit = self.format_unit(data.y_unit)
177                # Sort data by increasing x and remove 1st point
178                ind = np.lexsort((data.y, data.x))
179                data.x = self._reorder_1d_array(data.x, ind)
180                data.y = self._reorder_1d_array(data.y, ind)
181                if data.dx is not None:
182                    if len(data.dx) == 0:
183                        data.dx = None
184                        continue
185                    data.dx = self._reorder_1d_array(data.dx, ind)
186                if data.dxl is not None:
187                    data.dxl = self._reorder_1d_array(data.dxl, ind)
188                if data.dxw is not None:
189                    data.dxw = self._reorder_1d_array(data.dxw, ind)
190                if data.dy is not None:
191                    if len(data.dy) == 0:
192                        data.dy = None
193                        continue
194                    data.dy = self._reorder_1d_array(data.dy, ind)
195                if data.lam is not None:
196                    data.lam = self._reorder_1d_array(data.lam, ind)
197                if data.dlam is not None:
198                    data.dlam = self._reorder_1d_array(data.dlam, ind)
199                data = self._remove_nans_in_data(data)
200                if len(data.x) > 0:
201                    data.xmin = np.min(data.x)
202                    data.xmax = np.max(data.x)
203                    data.ymin = np.min(data.y)
204                    data.ymax = np.max(data.y)
205
206    @staticmethod
207    def _reorder_1d_array(array, ind):
208        """
209        Reorders a 1D array based on the indices passed as ind
210        :param array: Array to be reordered
211        :param ind: Indices used to reorder array
212        :return: reordered array
213        """
214        array = np.asarray(array, dtype=np.float64)
215        return array[ind]
216
217    @staticmethod
218    def _remove_nans_in_data(data):
219        """
220        Remove data points where nan is loaded
221        :param data: 1D or 2D data object
222        :return: data with nan points removed
223        """
224        if isinstance(data, Data1D):
225            fields = FIELDS_1D
226        elif isinstance(data, Data2D):
227            fields = FIELDS_2D
228        else:
229            return data
230        # Make array of good points - all others will be removed
231        good = np.isfinite(getattr(data, fields[0]))
232        for name in fields[1:]:
233            array = getattr(data, name)
234            if array is not None:
235                # Update good points only if not already changed
236                good &= np.isfinite(array)
237        if not np.all(good):
238            for name in fields:
239                array = getattr(data, name)
240                if array is not None:
241                    setattr(data, name, array[good])
242        return data
243
244    def sort_two_d_data(self):
245        for dataset in self.output:
246            if isinstance(dataset, Data2D):
247                # Normalize the units for
248                dataset.x_unit = self.format_unit(dataset.Q_unit)
249                dataset.y_unit = self.format_unit(dataset.I_unit)
250                dataset.data = dataset.data.astype(np.float64)
251                dataset.qx_data = dataset.qx_data.astype(np.float64)
252                dataset.xmin = np.min(dataset.qx_data)
253                dataset.xmax = np.max(dataset.qx_data)
254                dataset.qy_data = dataset.qy_data.astype(np.float64)
255                dataset.ymin = np.min(dataset.qy_data)
256                dataset.ymax = np.max(dataset.qy_data)
257                dataset.q_data = np.sqrt(dataset.qx_data * dataset.qx_data
258                                         + dataset.qy_data * dataset.qy_data)
259                if dataset.err_data is not None:
260                    dataset.err_data = dataset.err_data.astype(np.float64)
261                if dataset.dqx_data is not None:
262                    dataset.dqx_data = dataset.dqx_data.astype(np.float64)
263                if dataset.dqy_data is not None:
264                    dataset.dqy_data = dataset.dqy_data.astype(np.float64)
265                if dataset.mask is not None:
266                    dataset.mask = dataset.mask.astype(dtype=bool)
267
268                if len(dataset.data.shape) == 2:
269                    n_rows, n_cols = dataset.data.shape
270                    dataset.y_bins = dataset.qy_data[0::int(n_cols)]
271                    dataset.x_bins = dataset.qx_data[:int(n_cols)]
272                dataset.data = dataset.data.flatten()
273                dataset = self._remove_nans_in_data(dataset)
274                if len(dataset.data) > 0:
275                    dataset.xmin = np.min(dataset.qx_data)
276                    dataset.xmax = np.max(dataset.qx_data)
277                    dataset.ymin = np.min(dataset.qy_data)
278                    dataset.ymax = np.max(dataset.qx_data)
279
280    def format_unit(self, unit=None):
281        """
282        Format units a common way
283        :param unit:
284        :return:
285        """
286        if unit:
287            split = unit.split("/")
288            if len(split) == 1:
289                return unit
290            elif split[0] == '1':
291                return "{0}^".format(split[1]) + "{-1}"
292            else:
293                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
294
295    def set_all_to_none(self):
296        """
297        Set all mutable values to None for error handling purposes
298        """
299        self.current_dataset = None
300        self.current_datainfo = None
301        self.output = []
302
303    def data_cleanup(self):
304        """
305        Clean up the data sets and refresh everything
306        :return: None
307        """
308        self.remove_empty_q_values()
309        self.send_to_output()  # Combine datasets with DataInfo
310        self.current_datainfo = DataInfo()  # Reset DataInfo
311
312    def remove_empty_q_values(self):
313        """
314        Remove any point where Q == 0
315        """
316        if isinstance(self.current_dataset, plottable_1D):
317            # Booleans for resolutions
318            has_error_dx = self.current_dataset.dx is not None
319            has_error_dxl = self.current_dataset.dxl is not None
320            has_error_dxw = self.current_dataset.dxw is not None
321            has_error_dy = self.current_dataset.dy is not None
322            # Create arrays of zeros for non-existent resolutions
323            if has_error_dxw and not has_error_dxl:
324                array_size = self.current_dataset.dxw.size - 1
325                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
326                                                    np.zeros([array_size]))
327                has_error_dxl = True
328            elif has_error_dxl and not has_error_dxw:
329                array_size = self.current_dataset.dxl.size - 1
330                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
331                                                    np.zeros([array_size]))
332                has_error_dxw = True
333            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
334                array_size = self.current_dataset.x.size - 1
335                self.current_dataset.dx = np.append(self.current_dataset.dx,
336                                                    np.zeros([array_size]))
337                has_error_dx = True
338            if not has_error_dy:
339                array_size = self.current_dataset.y.size - 1
340                self.current_dataset.dy = np.append(self.current_dataset.dy,
341                                                    np.zeros([array_size]))
342                has_error_dy = True
343
344            # Remove points where q = 0
345            x = self.current_dataset.x
346            self.current_dataset.x = self.current_dataset.x[x != 0]
347            self.current_dataset.y = self.current_dataset.y[x != 0]
348            if has_error_dy:
349                self.current_dataset.dy = self.current_dataset.dy[x != 0]
350            if has_error_dx:
351                self.current_dataset.dx = self.current_dataset.dx[x != 0]
352            if has_error_dxl:
353                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
354            if has_error_dxw:
355                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
356        elif isinstance(self.current_dataset, plottable_2D):
357            has_error_dqx = self.current_dataset.dqx_data is not None
358            has_error_dqy = self.current_dataset.dqy_data is not None
359            has_error_dy = self.current_dataset.err_data is not None
360            has_mask = self.current_dataset.mask is not None
361            x = self.current_dataset.qx_data
362            self.current_dataset.data = self.current_dataset.data[x != 0]
363            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
364            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
365            self.current_dataset.q_data = np.sqrt(
366                np.square(self.current_dataset.qx_data) + np.square(
367                    self.current_dataset.qy_data))
368            if has_error_dy:
369                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
370            if has_error_dqx:
371                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
372            if has_error_dqy:
373                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
374            if has_mask:
375                self.current_dataset.mask = self.current_dataset.mask[x != 0]
376
377    def reset_data_list(self, no_lines=0):
378        """
379        Reset the plottable_1D object
380        """
381        # Initialize data sets with arrays the maximum possible size
382        x = np.zeros(no_lines)
383        y = np.zeros(no_lines)
384        dx = np.zeros(no_lines)
385        dy = np.zeros(no_lines)
386        self.current_dataset = plottable_1D(x, y, dx, dy)
387
388    @staticmethod
389    def splitline(line):
390        """
391        Splits a line into pieces based on common delimiters
392        :param line: A single line of text
393        :return: list of values
394        """
395        # Initial try for CSV (split on ,)
396        toks = line.split(',')
397        # Now try SCSV (split on ;)
398        if len(toks) < 2:
399            toks = line.split(';')
400        # Now go for whitespace
401        if len(toks) < 2:
402            toks = line.split()
403        return toks
404
405    @abstractmethod
406    def get_file_contents(self):
407        """
408        Reader specific class to access the contents of the file
409        All reader classes that inherit from FileReader must implement
410        """
411        pass
Note: See TracBrowser for help on using the repository browser.