source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ f02a0c6

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since f02a0c6 was f02a0c6, checked in by krzywon, 6 years ago

Extend nan removal to 2D data.

  • Property mode set to 100644
File size: 16.6 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18
19logger = logging.getLogger(__name__)
20
21if sys.version_info[0] < 3:
22    def decode(s):
23        return s
24else:
25    def decode(s):
26        return s.decode() if isinstance(s, bytes) else s
27
28class FileReader(object):
29    # String to describe the type of data this reader can load
30    type_name = "ASCII"
31    # Wildcards to display
32    type = ["Text files (*.txt|*.TXT)"]
33    # List of allowed extensions
34    ext = ['.txt']
35    # Bypass extension check and try to load anyway
36    allow_all = False
37    # Able to import the unit converter
38    has_converter = True
39    # Default value of zero
40    _ZERO = 1e-16
41
42    def __init__(self):
43        # List of Data1D and Data2D objects to be sent back to data_loader
44        self.output = []
45        # Current plottable_(1D/2D) object being loaded in
46        self.current_dataset = None
47        # Current DataInfo object being loaded in
48        self.current_datainfo = None
49        # File path sent to reader
50        self.filepath = None
51        # Open file handle
52        self.f_open = None
53
54    def read(self, filepath):
55        """
56        Basic file reader
57
58        :param filepath: The full or relative path to a file to be loaded
59        """
60        self.filepath = filepath
61        if os.path.isfile(filepath):
62            basename, extension = os.path.splitext(os.path.basename(filepath))
63            self.extension = extension.lower()
64            # If the file type is not allowed, return nothing
65            if self.extension in self.ext or self.allow_all:
66                # Try to load the file, but raise an error if unable to.
67                try:
68                    self.f_open = open(filepath, 'rb')
69                    self.get_file_contents()
70
71                except DataReaderException as e:
72                    self.handle_error_message(e.message)
73                except OSError as e:
74                    # If the file cannot be opened
75                    msg = "Unable to open file: {}\n".format(filepath)
76                    msg += e.message
77                    self.handle_error_message(msg)
78                finally:
79                    # Close the file handle if it is open
80                    if not self.f_open.closed:
81                        self.f_open.close()
82                    if len(self.output) > 0:
83                        # Sort the data that's been loaded
84                        self.sort_one_d_data()
85                        self.sort_two_d_data()
86        else:
87            msg = "Unable to find file at: {}\n".format(filepath)
88            msg += "Please check your file path and try again."
89            self.handle_error_message(msg)
90
91        # Return a list of parsed entries that data_loader can manage
92        final_data = self.output
93        self.reset_state()
94        return final_data
95
96    def reset_state(self):
97        """
98        Resets the class state to a base case when loading a new data file so previous
99        data files do not appear a second time
100        """
101        self.current_datainfo = None
102        self.current_dataset = None
103        self.filepath = None
104        self.ind = None
105        self.output = []
106
107    def nextline(self):
108        """
109        Returns the next line in the file as a string.
110        """
111        #return self.f_open.readline()
112        return decode(self.f_open.readline())
113
114    def nextlines(self):
115        """
116        Returns the next line in the file as a string.
117        """
118        for line in self.f_open:
119            #yield line
120            yield decode(line)
121
122    def readall(self):
123        """
124        Returns the entire file as a string.
125        """
126        #return self.f_open.read()
127        return decode(self.f_open.read())
128
129    def handle_error_message(self, msg):
130        """
131        Generic error handler to add an error to the current datainfo to
132        propagate the error up the error chain.
133        :param msg: Error message
134        """
135        if len(self.output) > 0:
136            self.output[-1].errors.append(msg)
137        elif isinstance(self.current_datainfo, DataInfo):
138            self.current_datainfo.errors.append(msg)
139        else:
140            logger.warning(msg)
141
142    def send_to_output(self):
143        """
144        Helper that automatically combines the info and set and then appends it
145        to output
146        """
147        data_obj = combine_data_info_with_plottable(self.current_dataset,
148                                                    self.current_datainfo)
149        self.output.append(data_obj)
150
151    def sort_one_d_data(self):
152        """
153        Sort 1D data along the X axis for consistency
154        """
155        for data in self.output:
156            if isinstance(data, Data1D):
157                # Normalize the units for
158                data.x_unit = self.format_unit(data.x_unit)
159                data.y_unit = self.format_unit(data.y_unit)
160                # Sort data by increasing x and remove 1st point
161                ind = np.lexsort((data.y, data.x))
162                data.x = self._reorder_1d_array(data.x, ind)
163                data.y = self._reorder_1d_array(data.y, ind)
164                if data.dx is not None:
165                    if len(data.dx) == 0:
166                        data.dx = None
167                        continue
168                    data.dx = self._reorder_1d_array(data.dx, ind)
169                if data.dxl is not None:
170                    data.dxl = self._reorder_1d_array(data.dxl, ind)
171                if data.dxw is not None:
172                    data.dxw = self._reorder_1d_array(data.dxw, ind)
173                if data.dy is not None:
174                    if len(data.dy) == 0:
175                        data.dy = None
176                        continue
177                    data.dy = self._reorder_1d_array(data.dy, ind)
178                if data.lam is not None:
179                    data.lam = self._reorder_1d_array(data.lam, ind)
180                if data.dlam is not None:
181                    data.dlam = self._reorder_1d_array(data.dlam, ind)
182                data = self._remove_nans_in_data(data)
183                if len(data.x) > 0:
184                    data.xmin = np.min(data.x)
185                    data.xmax = np.max(data.x)
186                    data.ymin = np.min(data.y)
187                    data.ymax = np.max(data.y)
188
189    @staticmethod
190    def _reorder_1d_array(array, ind):
191        """
192        Reorders a 1D array based on the indices passed as ind
193        :param array: Array to be reordered
194        :param ind: Indices used to reorder array
195        :return: reordered array
196        """
197        array = np.asarray(array, dtype=np.float64)
198        return array[ind]
199
200    @staticmethod
201    def _remove_nans_in_data(data):
202        """
203        Remove data points where nan is loaded
204        :param data: 1D data set
205        :return: data with mask=0 for any value of nan in data .x, .y, .dx, .dy
206        """
207        if isinstance(data, Data1D):
208            mask = np.ones(data.x.shape)
209            data_list = [data.x, data.y, data.dx, data.dy, data.dxl, data.dxw]
210        elif isinstance(data, Data2D):
211            mask = np.ones(data.data.shape)
212            data_list = [data.data, data.qx_data, data.qy_data, data.q_data,
213                         data.err_data, data.dqx_data, data.dqy_data, data.mask]
214        else:
215            mask = np.ones(0)
216            data_list = []
217        for array in data_list:
218            if array is not None:
219                # Set mask[i] to 0 when data.<param> is nan
220                mask[np.isnan(array)] = 0
221        # Data indices to mask/remove from the data
222        nans = np.where(mask == 0)[0]
223        if len(nans) > 0:
224            if isinstance(data, Data1D):
225                data.x = np.delete(data.x, nans)
226                data.y = np.delete(data.y, nans)
227                if data.dx is not None:
228                    data.dx = np.delete(data.dx, nans)
229                if data.dxl is not None:
230                    data.dxl = np.delete(data.dxl, nans)
231                if data.dxw is not None:
232                    data.dxw = np.delete(data.dxw, nans)
233                if data.dy is not None:
234                    data.dy = np.delete(data.dy, nans)
235            elif isinstance(data, Data2D):
236                data.data = np.delete(data.data, nans)
237                data.qx_data = np.delete(data.qx_data, nans)
238                data.qy_data = np.delete(data.qy_data, nans)
239                if data.q_data is not None:
240                    data.q_data = np.delete(data.q_data, nans)
241                if data.err_data is not None:
242                    data.err_data = np.delete(data.err_data, nans)
243                if data.dqx_data is not None:
244                    data.dqx_data = np.delete(data.dqx_data, nans)
245                if data.dqy_data is not None:
246                    data.dqy_data = np.delete(data.dqy_data, nans)
247                if data.mask is not None:
248                    data.mask = np.delete(data.mask, nans)
249        return data
250
251    def sort_two_d_data(self):
252        for dataset in self.output:
253            if isinstance(dataset, Data2D):
254                # Normalize the units for
255                dataset.x_unit = self.format_unit(dataset.Q_unit)
256                dataset.y_unit = self.format_unit(dataset.I_unit)
257                dataset.data = dataset.data.astype(np.float64)
258                dataset.qx_data = dataset.qx_data.astype(np.float64)
259                dataset.xmin = np.min(dataset.qx_data)
260                dataset.xmax = np.max(dataset.qx_data)
261                dataset.qy_data = dataset.qy_data.astype(np.float64)
262                dataset.ymin = np.min(dataset.qy_data)
263                dataset.ymax = np.max(dataset.qy_data)
264                dataset.q_data = np.sqrt(dataset.qx_data * dataset.qx_data
265                                         + dataset.qy_data * dataset.qy_data)
266                if dataset.err_data is not None:
267                    dataset.err_data = dataset.err_data.astype(np.float64)
268                if dataset.dqx_data is not None:
269                    dataset.dqx_data = dataset.dqx_data.astype(np.float64)
270                if dataset.dqy_data is not None:
271                    dataset.dqy_data = dataset.dqy_data.astype(np.float64)
272                if dataset.mask is not None:
273                    dataset.mask = dataset.mask.astype(dtype=bool)
274
275                if len(dataset.data.shape) == 2:
276                    n_rows, n_cols = dataset.data.shape
277                    dataset.y_bins = dataset.qy_data[0::int(n_cols)]
278                    dataset.x_bins = dataset.qx_data[:int(n_cols)]
279                dataset.data = dataset.data.flatten()
280                dataset = self._remove_nans_in_data(dataset)
281                if len(dataset.data) > 0:
282                    dataset.xmin = np.min(dataset.qx_data)
283                    dataset.xmax = np.max(dataset.qx_data)
284                    dataset.ymin = np.min(dataset.qy_data)
285                    dataset.ymax = np.max(dataset.qx_data)
286
287    def format_unit(self, unit=None):
288        """
289        Format units a common way
290        :param unit:
291        :return:
292        """
293        if unit:
294            split = unit.split("/")
295            if len(split) == 1:
296                return unit
297            elif split[0] == '1':
298                return "{0}^".format(split[1]) + "{-1}"
299            else:
300                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
301
302    def set_all_to_none(self):
303        """
304        Set all mutable values to None for error handling purposes
305        """
306        self.current_dataset = None
307        self.current_datainfo = None
308        self.output = []
309
310    def data_cleanup(self):
311        """
312        Clean up the data sets and refresh everything
313        :return: None
314        """
315        self.remove_empty_q_values()
316        self.send_to_output()  # Combine datasets with DataInfo
317        self.current_datainfo = DataInfo()  # Reset DataInfo
318
319    def remove_empty_q_values(self):
320        """
321        Remove any point where Q == 0
322        """
323        if isinstance(self.current_dataset, plottable_1D):
324            # Booleans for resolutions
325            has_error_dx = self.current_dataset.dx is not None
326            has_error_dxl = self.current_dataset.dxl is not None
327            has_error_dxw = self.current_dataset.dxw is not None
328            has_error_dy = self.current_dataset.dy is not None
329            # Create arrays of zeros for non-existent resolutions
330            if has_error_dxw and not has_error_dxl:
331                array_size = self.current_dataset.dxw.size - 1
332                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
333                                                    np.zeros([array_size]))
334                has_error_dxl = True
335            elif has_error_dxl and not has_error_dxw:
336                array_size = self.current_dataset.dxl.size - 1
337                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
338                                                    np.zeros([array_size]))
339                has_error_dxw = True
340            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
341                array_size = self.current_dataset.x.size - 1
342                self.current_dataset.dx = np.append(self.current_dataset.dx,
343                                                    np.zeros([array_size]))
344                has_error_dx = True
345            if not has_error_dy:
346                array_size = self.current_dataset.y.size - 1
347                self.current_dataset.dy = np.append(self.current_dataset.dy,
348                                                    np.zeros([array_size]))
349                has_error_dy = True
350
351            # Remove points where q = 0
352            x = self.current_dataset.x
353            self.current_dataset.x = self.current_dataset.x[x != 0]
354            self.current_dataset.y = self.current_dataset.y[x != 0]
355            if has_error_dy:
356                self.current_dataset.dy = self.current_dataset.dy[x != 0]
357            if has_error_dx:
358                self.current_dataset.dx = self.current_dataset.dx[x != 0]
359            if has_error_dxl:
360                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
361            if has_error_dxw:
362                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
363        elif isinstance(self.current_dataset, plottable_2D):
364            has_error_dqx = self.current_dataset.dqx_data is not None
365            has_error_dqy = self.current_dataset.dqy_data is not None
366            has_error_dy = self.current_dataset.err_data is not None
367            has_mask = self.current_dataset.mask is not None
368            x = self.current_dataset.qx_data
369            self.current_dataset.data = self.current_dataset.data[x != 0]
370            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
371            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
372            self.current_dataset.q_data = np.sqrt(
373                np.square(self.current_dataset.qx_data) + np.square(
374                    self.current_dataset.qy_data))
375            if has_error_dy:
376                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
377            if has_error_dqx:
378                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
379            if has_error_dqy:
380                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
381            if has_mask:
382                self.current_dataset.mask = self.current_dataset.mask[x != 0]
383
384    def reset_data_list(self, no_lines=0):
385        """
386        Reset the plottable_1D object
387        """
388        # Initialize data sets with arrays the maximum possible size
389        x = np.zeros(no_lines)
390        y = np.zeros(no_lines)
391        dx = np.zeros(no_lines)
392        dy = np.zeros(no_lines)
393        self.current_dataset = plottable_1D(x, y, dx, dy)
394
395    @staticmethod
396    def splitline(line):
397        """
398        Splits a line into pieces based on common delimiters
399        :param line: A single line of text
400        :return: list of values
401        """
402        # Initial try for CSV (split on ,)
403        toks = line.split(',')
404        # Now try SCSV (split on ;)
405        if len(toks) < 2:
406            toks = line.split(';')
407        # Now go for whitespace
408        if len(toks) < 2:
409            toks = line.split()
410        return toks
411
412    @abstractmethod
413    def get_file_contents(self):
414        """
415        Reader specific class to access the contents of the file
416        All reader classes that inherit from FileReader must implement
417        """
418        pass
Note: See TracBrowser for help on using the repository browser.