source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ cb11a25

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since cb11a25 was cb11a25, checked in by Stuart Prescott <stuart@…>, 6 years ago

Fix instance vs class variables in FileReader?

Class variables get inherited and shared between different instantiations
of the class and this means that the current usage of TestFileReader? in
utest_generic_file_reader_class.py has side-effects that would cause
subsequent tests to fail.

  • Property mode set to 100644
File size: 13.9 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import re
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28class FileReader(object):
29    # String to describe the type of data this reader can load
30    type_name = "ASCII"
31    # Wildcards to display
32    type = ["Text files (*.txt|*.TXT)"]
33    # List of allowed extensions
34    ext = ['.txt']
35    # Bypass extension check and try to load anyway
36    allow_all = False
37    # Able to import the unit converter
38    has_converter = True
39    # Default value of zero
40    _ZERO = 1e-16
41
42    def __init__(self):
43        # List of Data1D and Data2D objects to be sent back to data_loader
44        self.output = []
45        # Current plottable_(1D/2D) object being loaded in
46        self.current_dataset = None
47        # Current DataInfo object being loaded in
48        self.current_datainfo = None
49        # Open file handle
50        self.f_open = None
51
52    def read(self, filepath):
53        """
54        Basic file reader
55
56        :param filepath: The full or relative path to a file to be loaded
57        """
58        if os.path.isfile(filepath):
59            basename, extension = os.path.splitext(os.path.basename(filepath))
60            self.extension = extension.lower()
61            # If the file type is not allowed, return nothing
62            if self.extension in self.ext or self.allow_all:
63                # Try to load the file, but raise an error if unable to.
64                try:
65                    self.f_open = open(filepath, 'rb')
66                    self.get_file_contents()
67
68                except DataReaderException as e:
69                    self.handle_error_message(e.message)
70                except OSError as e:
71                    # If the file cannot be opened
72                    msg = "Unable to open file: {}\n".format(filepath)
73                    msg += e.message
74                    self.handle_error_message(msg)
75                finally:
76                    # Close the file handle if it is open
77                    if not self.f_open.closed:
78                        self.f_open.close()
79                    if len(self.output) > 0:
80                        # Sort the data that's been loaded
81                        self.sort_one_d_data()
82                        self.sort_two_d_data()
83        else:
84            msg = "Unable to find file at: {}\n".format(filepath)
85            msg += "Please check your file path and try again."
86            self.handle_error_message(msg)
87
88        # Return a list of parsed entries that data_loader can manage
89        final_data = self.output
90        self.reset_state()
91        return final_data
92
93    def reset_state(self):
94        """
95        Resets the class state to a base case when loading a new data file so previous
96        data files do not appear a second time
97        """
98        self.current_datainfo = None
99        self.current_dataset = None
100        self.output = []
101
102    def nextline(self):
103        """
104        Returns the next line in the file as a string.
105        """
106        #return self.f_open.readline()
107        return decode(self.f_open.readline())
108
109    def nextlines(self):
110        """
111        Returns the next line in the file as a string.
112        """
113        for line in self.f_open:
114            #yield line
115            yield decode(line)
116
117    def readall(self):
118        """
119        Returns the entire file as a string.
120        """
121        #return self.f_open.read()
122        return decode(self.f_open.read())
123
124    def handle_error_message(self, msg):
125        """
126        Generic error handler to add an error to the current datainfo to
127        propagate the error up the error chain.
128        :param msg: Error message
129        """
130        if len(self.output) > 0:
131            self.output[-1].errors.append(msg)
132        elif isinstance(self.current_datainfo, DataInfo):
133            self.current_datainfo.errors.append(msg)
134        else:
135            logger.warning(msg)
136
137    def send_to_output(self):
138        """
139        Helper that automatically combines the info and set and then appends it
140        to output
141        """
142        data_obj = combine_data_info_with_plottable(self.current_dataset,
143                                                    self.current_datainfo)
144        self.output.append(data_obj)
145
146    def sort_one_d_data(self):
147        """
148        Sort 1D data along the X axis for consistency
149        """
150        for data in self.output:
151            if isinstance(data, Data1D):
152                # Normalize the units for
153                data.x_unit = self.format_unit(data.x_unit)
154                data.y_unit = self.format_unit(data.y_unit)
155                # Sort data by increasing x and remove 1st point
156                ind = np.lexsort((data.y, data.x))
157                data.x = np.asarray([data.x[i] for i in ind]).astype(np.float64)
158                data.y = np.asarray([data.y[i] for i in ind]).astype(np.float64)
159                if data.dx is not None:
160                    if len(data.dx) == 0:
161                        data.dx = None
162                        continue
163                    data.dx = np.asarray([data.dx[i] for i in ind]).astype(np.float64)
164                if data.dxl is not None:
165                    data.dxl = np.asarray([data.dxl[i] for i in ind]).astype(np.float64)
166                if data.dxw is not None:
167                    data.dxw = np.asarray([data.dxw[i] for i in ind]).astype(np.float64)
168                if data.dy is not None:
169                    if len(data.dy) == 0:
170                        data.dy = None
171                        continue
172                    data.dy = np.asarray([data.dy[i] for i in ind]).astype(np.float64)
173                if data.lam is not None:
174                    data.lam = np.asarray([data.lam[i] for i in ind]).astype(np.float64)
175                if data.dlam is not None:
176                    data.dlam = np.asarray([data.dlam[i] for i in ind]).astype(np.float64)
177                if len(data.x) > 0:
178                    data.xmin = np.min(data.x)
179                    data.xmax = np.max(data.x)
180                    data.ymin = np.min(data.y)
181                    data.ymax = np.max(data.y)
182
183    def sort_two_d_data(self):
184        for dataset in self.output:
185            if isinstance(dataset, Data2D):
186                # Normalize the units for
187                dataset.x_unit = self.format_unit(dataset.Q_unit)
188                dataset.y_unit = self.format_unit(dataset.I_unit)
189                dataset.data = dataset.data.astype(np.float64)
190                dataset.qx_data = dataset.qx_data.astype(np.float64)
191                dataset.xmin = np.min(dataset.qx_data)
192                dataset.xmax = np.max(dataset.qx_data)
193                dataset.qy_data = dataset.qy_data.astype(np.float64)
194                dataset.ymin = np.min(dataset.qy_data)
195                dataset.ymax = np.max(dataset.qy_data)
196                dataset.q_data = np.sqrt(dataset.qx_data * dataset.qx_data
197                                         + dataset.qy_data * dataset.qy_data)
198                if dataset.err_data is not None:
199                    dataset.err_data = dataset.err_data.astype(np.float64)
200                if dataset.dqx_data is not None:
201                    dataset.dqx_data = dataset.dqx_data.astype(np.float64)
202                if dataset.dqy_data is not None:
203                    dataset.dqy_data = dataset.dqy_data.astype(np.float64)
204                if dataset.mask is not None:
205                    dataset.mask = dataset.mask.astype(dtype=bool)
206
207                if len(dataset.data.shape) == 2:
208                    n_rows, n_cols = dataset.data.shape
209                    dataset.y_bins = dataset.qy_data[0::int(n_cols)]
210                    dataset.x_bins = dataset.qx_data[:int(n_cols)]
211                dataset.data = dataset.data.flatten()
212                if len(dataset.data) > 0:
213                    dataset.xmin = np.min(dataset.qx_data)
214                    dataset.xmax = np.max(dataset.qx_data)
215                    dataset.ymin = np.min(dataset.qy_data)
216                    dataset.ymax = np.max(dataset.qx_data)
217
218    def format_unit(self, unit=None):
219        """
220        Format units a common way
221        :param unit:
222        :return:
223        """
224        if unit:
225            split = unit.split("/")
226            if len(split) == 1:
227                return unit
228            elif split[0] == '1':
229                return "{0}^".format(split[1]) + "{-1}"
230            else:
231                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
232
233    def set_all_to_none(self):
234        """
235        Set all mutable values to None for error handling purposes
236        """
237        self.current_dataset = None
238        self.current_datainfo = None
239        self.output = []
240
241    def data_cleanup(self):
242        """
243        Clean up the data sets and refresh everything
244        :return: None
245        """
246        self.remove_empty_q_values()
247        self.send_to_output()  # Combine datasets with DataInfo
248        self.current_datainfo = DataInfo()  # Reset DataInfo
249
250    def remove_empty_q_values(self):
251        """
252        Remove any point where Q == 0
253        """
254        if isinstance(self.current_dataset, plottable_1D):
255            # Booleans for resolutions
256            has_error_dx = self.current_dataset.dx is not None
257            has_error_dxl = self.current_dataset.dxl is not None
258            has_error_dxw = self.current_dataset.dxw is not None
259            has_error_dy = self.current_dataset.dy is not None
260            # Create arrays of zeros for non-existent resolutions
261            if has_error_dxw and not has_error_dxl:
262                array_size = self.current_dataset.dxw.size - 1
263                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
264                                                    np.zeros([array_size]))
265                has_error_dxl = True
266            elif has_error_dxl and not has_error_dxw:
267                array_size = self.current_dataset.dxl.size - 1
268                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
269                                                    np.zeros([array_size]))
270                has_error_dxw = True
271            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
272                array_size = self.current_dataset.x.size - 1
273                self.current_dataset.dx = np.append(self.current_dataset.dx,
274                                                    np.zeros([array_size]))
275                has_error_dx = True
276            if not has_error_dy:
277                array_size = self.current_dataset.y.size - 1
278                self.current_dataset.dy = np.append(self.current_dataset.dy,
279                                                    np.zeros([array_size]))
280                has_error_dy = True
281
282            # Remove points where q = 0
283            x = self.current_dataset.x
284            self.current_dataset.x = self.current_dataset.x[x != 0]
285            self.current_dataset.y = self.current_dataset.y[x != 0]
286            if has_error_dy:
287                self.current_dataset.dy = self.current_dataset.dy[x != 0]
288            if has_error_dx:
289                self.current_dataset.dx = self.current_dataset.dx[x != 0]
290            if has_error_dxl:
291                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
292            if has_error_dxw:
293                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
294        elif isinstance(self.current_dataset, plottable_2D):
295            has_error_dqx = self.current_dataset.dqx_data is not None
296            has_error_dqy = self.current_dataset.dqy_data is not None
297            has_error_dy = self.current_dataset.err_data is not None
298            has_mask = self.current_dataset.mask is not None
299            x = self.current_dataset.qx_data
300            self.current_dataset.data = self.current_dataset.data[x != 0]
301            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
302            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
303            self.current_dataset.q_data = np.sqrt(
304                np.square(self.current_dataset.qx_data) + np.square(
305                    self.current_dataset.qy_data))
306            if has_error_dy:
307                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
308            if has_error_dqx:
309                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
310            if has_error_dqy:
311                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
312            if has_mask:
313                self.current_dataset.mask = self.current_dataset.mask[x != 0]
314
315    def reset_data_list(self, no_lines=0):
316        """
317        Reset the plottable_1D object
318        """
319        # Initialize data sets with arrays the maximum possible size
320        x = np.zeros(no_lines)
321        y = np.zeros(no_lines)
322        dx = np.zeros(no_lines)
323        dy = np.zeros(no_lines)
324        self.current_dataset = plottable_1D(x, y, dx, dy)
325
326    @staticmethod
327    def splitline(line):
328        """
329        Splits a line into pieces based on common delimiters
330        :param line: A single line of text
331        :return: list of values
332        """
333        # Initial try for CSV (split on ,)
334        toks = line.split(',')
335        # Now try SCSV (split on ;)
336        if len(toks) < 2:
337            toks = line.split(';')
338        # Now go for whitespace
339        if len(toks) < 2:
340            toks = line.split()
341        return toks
342
343    @abstractmethod
344    def get_file_contents(self):
345        """
346        Reader specific class to access the contents of the file
347        All reader classes that inherit from FileReader must implement
348        """
349        pass
Note: See TracBrowser for help on using the repository browser.