source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ d76beb4

ESS_GUI
Last change on this file since d76beb4 was 8db20a9, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 6 years ago

Updated cansas read (cherrypicked and fixed from master).
Fixes: hdf5 returns byte strings so these need to be recasted properly.
https://github.com/h5py/h5py/issues/379

  • Property mode set to 100644
File size: 19.7 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18from sas.sascalc.data_util.nxsunit import Converter
19
20logger = logging.getLogger(__name__)
21
22if sys.version_info[0] < 3:
23    def decode(s):
24        return s
25else:
26    def decode(s):
27        return s.decode() if isinstance(s, bytes) else s
28
29# Data 1D fields for iterative purposes
30FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
31# Data 2D fields for iterative purposes
32FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
33                 'dqx_data', 'dqy_data', 'mask')
34DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
35                       "t not be fully reduced. Support for the reader associat"
36                       "ed with this file type has been removed. An attempt to "
37                       "load the file was made, but, should it be successful, "
38                       "SasView cannot guarantee the accuracy of the data.")
39
40
41class FileReader(object):
42    # String to describe the type of data this reader can load
43    type_name = "ASCII"
44    # Wildcards to display
45    type = ["Text files (*.txt|*.TXT)"]
46    # List of allowed extensions
47    ext = ['.txt']
48    # Deprecated extensions
49    deprecated_extensions = ['.asc']
50    # Bypass extension check and try to load anyway
51    allow_all = False
52    # Able to import the unit converter
53    has_converter = True
54    # Default value of zero
55    _ZERO = 1e-16
56
57    def __init__(self):
58        # List of Data1D and Data2D objects to be sent back to data_loader
59        self.output = []
60        # Current plottable_(1D/2D) object being loaded in
61        self.current_dataset = None
62        # Current DataInfo object being loaded in
63        self.current_datainfo = None
64        # File path sent to reader
65        self.filepath = None
66        # Open file handle
67        self.f_open = None
68
69    def read(self, filepath):
70        """
71        Basic file reader
72
73        :param filepath: The full or relative path to a file to be loaded
74        """
75        self.filepath = filepath
76        if os.path.isfile(filepath):
77            basename, extension = os.path.splitext(os.path.basename(filepath))
78            self.extension = extension.lower()
79            # If the file type is not allowed, return nothing
80            if self.extension in self.ext or self.allow_all:
81                # Try to load the file, but raise an error if unable to.
82                try:
83                    self.f_open = open(filepath, 'rb')
84                    self.get_file_contents()
85
86                except DataReaderException as e:
87                    self.handle_error_message(e.message)
88                except OSError as e:
89                    # If the file cannot be opened
90                    msg = "Unable to open file: {}\n".format(filepath)
91                    msg += e.message
92                    self.handle_error_message(msg)
93                finally:
94                    # Close the file handle if it is open
95                    if not self.f_open.closed:
96                        self.f_open.close()
97                    if any(filepath.lower().endswith(ext) for ext in
98                           self.deprecated_extensions):
99                        self.handle_error_message(DEPRECATION_MESSAGE)
100                    if len(self.output) > 0:
101                        # Sort the data that's been loaded
102                        self.convert_data_units()
103                        self.sort_data()
104        else:
105            msg = "Unable to find file at: {}\n".format(filepath)
106            msg += "Please check your file path and try again."
107            self.handle_error_message(msg)
108
109        # Return a list of parsed entries that data_loader can manage
110        final_data = self.output
111        self.reset_state()
112        return final_data
113
114    def reset_state(self):
115        """
116        Resets the class state to a base case when loading a new data file so previous
117        data files do not appear a second time
118        """
119        self.current_datainfo = None
120        self.current_dataset = None
121        self.filepath = None
122        self.ind = None
123        self.output = []
124
125    def nextline(self):
126        """
127        Returns the next line in the file as a string.
128        """
129        #return self.f_open.readline()
130        return decode(self.f_open.readline())
131
132    def nextlines(self):
133        """
134        Returns the next line in the file as a string.
135        """
136        for line in self.f_open:
137            #yield line
138            yield decode(line)
139
140    def readall(self):
141        """
142        Returns the entire file as a string.
143        """
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158            raise NoKnownLoaderException(msg)
159
160    def send_to_output(self):
161        """
162        Helper that automatically combines the info and set and then appends it
163        to output
164        """
165        data_obj = combine_data_info_with_plottable(self.current_dataset,
166                                                    self.current_datainfo)
167        self.output.append(data_obj)
168
169    def sort_data(self):
170        """
171        Sort 1D data along the X axis for consistency
172        """
173        for data in self.output:
174            if isinstance(data, Data1D):
175                # Normalize the units for
176                data.x_unit = self.format_unit(data.x_unit)
177                data._xunit = data.x_unit
178                data.y_unit = self.format_unit(data.y_unit)
179                data._yunit = data.y_unit
180                # Sort data by increasing x and remove 1st point
181                ind = np.lexsort((data.y, data.x))
182                data.x = self._reorder_1d_array(data.x, ind)
183                data.y = self._reorder_1d_array(data.y, ind)
184                if data.dx is not None:
185                    if len(data.dx) == 0:
186                        data.dx = None
187                        continue
188                    data.dx = self._reorder_1d_array(data.dx, ind)
189                if data.dxl is not None:
190                    data.dxl = self._reorder_1d_array(data.dxl, ind)
191                if data.dxw is not None:
192                    data.dxw = self._reorder_1d_array(data.dxw, ind)
193                if data.dy is not None:
194                    if len(data.dy) == 0:
195                        data.dy = None
196                        continue
197                    data.dy = self._reorder_1d_array(data.dy, ind)
198                if data.lam is not None:
199                    data.lam = self._reorder_1d_array(data.lam, ind)
200                if data.dlam is not None:
201                    data.dlam = self._reorder_1d_array(data.dlam, ind)
202                data = self._remove_nans_in_data(data)
203                if len(data.x) > 0:
204                    data.xmin = np.min(data.x)
205                    data.xmax = np.max(data.x)
206                    data.ymin = np.min(data.y)
207                    data.ymax = np.max(data.y)
208            elif isinstance(data, Data2D):
209                # Normalize the units for
210                data.Q_unit = self.format_unit(data.Q_unit)
211                data.I_unit = self.format_unit(data.I_unit)
212                data._xunit = data.Q_unit
213                data._yunit = data.Q_unit
214                data._zunit = data.I_unit
215                data.data = data.data.astype(np.float64)
216                data.qx_data = data.qx_data.astype(np.float64)
217                data.xmin = np.min(data.qx_data)
218                data.xmax = np.max(data.qx_data)
219                data.qy_data = data.qy_data.astype(np.float64)
220                data.ymin = np.min(data.qy_data)
221                data.ymax = np.max(data.qy_data)
222                data.q_data = np.sqrt(data.qx_data * data.qx_data
223                                         + data.qy_data * data.qy_data)
224                if data.err_data is not None:
225                    data.err_data = data.err_data.astype(np.float64)
226                if data.dqx_data is not None:
227                    data.dqx_data = data.dqx_data.astype(np.float64)
228                if data.dqy_data is not None:
229                    data.dqy_data = data.dqy_data.astype(np.float64)
230                if data.mask is not None:
231                    data.mask = data.mask.astype(dtype=bool)
232
233                if len(data.data.shape) == 2:
234                    n_rows, n_cols = data.data.shape
235                    data.y_bins = data.qy_data[0::int(n_cols)]
236                    data.x_bins = data.qx_data[:int(n_cols)]
237                    data.data = data.data.flatten()
238                data = self._remove_nans_in_data(data)
239                if len(data.data) > 0:
240                    data.xmin = np.min(data.qx_data)
241                    data.xmax = np.max(data.qx_data)
242                    data.ymin = np.min(data.qy_data)
243                    data.ymax = np.max(data.qy_data)
244
245    @staticmethod
246    def _reorder_1d_array(array, ind):
247        """
248        Reorders a 1D array based on the indices passed as ind
249        :param array: Array to be reordered
250        :param ind: Indices used to reorder array
251        :return: reordered array
252        """
253        array = np.asarray(array, dtype=np.float64)
254        return array[ind]
255
256    @staticmethod
257    def _remove_nans_in_data(data):
258        """
259        Remove data points where nan is loaded
260        :param data: 1D or 2D data object
261        :return: data with nan points removed
262        """
263        if isinstance(data, Data1D):
264            fields = FIELDS_1D
265        elif isinstance(data, Data2D):
266            fields = FIELDS_2D
267        else:
268            return data
269        # Make array of good points - all others will be removed
270        good = np.isfinite(getattr(data, fields[0]))
271        for name in fields[1:]:
272            array = getattr(data, name)
273            if array is not None:
274                # Update good points only if not already changed
275                good &= np.isfinite(array)
276        if not np.all(good):
277            for name in fields:
278                array = getattr(data, name)
279                if array is not None:
280                    setattr(data, name, array[good])
281        return data
282
283    @staticmethod
284    def set_default_1d_units(data):
285        """
286        Set the x and y axes to the default 1D units
287        :param data: 1D data set
288        :return:
289        """
290        data.xaxis(r"\rm{Q}", '1/A')
291        data.yaxis(r"\rm{Intensity}", "1/cm")
292        return data
293
294    @staticmethod
295    def set_default_2d_units(data):
296        """
297        Set the x and y axes to the default 2D units
298        :param data: 2D data set
299        :return:
300        """
301        data.xaxis("\\rm{Q_{x}}", '1/A')
302        data.yaxis("\\rm{Q_{y}}", '1/A')
303        data.zaxis("\\rm{Intensity}", "1/cm")
304        return data
305
306    def convert_data_units(self, default_q_unit="1/A"):
307        """
308        Converts al; data to the sasview default of units of A^{-1} for Q and
309        cm^{-1} for I.
310        :param default_q_unit: The default Q unit used by Sasview
311        """
312        convert_q = True
313        new_output = []
314        for data in self.output:
315            if data.isSesans:
316                new_output.append(data)
317                continue
318            try:
319                file_x_unit = data._xunit
320                data_conv_x = Converter(file_x_unit)
321            except KeyError:
322                logger.info("Unrecognized Q units in data file. No data "
323                            "conversion attempted")
324                convert_q = False
325            try:
326
327                if isinstance(data, Data1D):
328                        if convert_q:
329                            data.x = data_conv_x(data.x, units=default_q_unit)
330                            data._xunit = default_q_unit
331                            data.x_unit = default_q_unit
332                            if data.dx is not None:
333                                data.dx = data_conv_x(data.dx,
334                                                      units=default_q_unit)
335                            if data.dxl is not None:
336                                data.dxl = data_conv_x(data.dxl,
337                                                       units=default_q_unit)
338                            if data.dxw is not None:
339                                data.dxw = data_conv_x(data.dxw,
340                                                       units=default_q_unit)
341                elif isinstance(data, Data2D):
342                    if convert_q:
343                        data.qx_data = data_conv_x(data.qx_data,
344                                                   units=default_q_unit)
345                        if data.dqx_data is not None:
346                            data.dqx_data = data_conv_x(data.dqx_data,
347                                                        units=default_q_unit)
348                        try:
349                            file_y_unit = data._yunit
350                            data_conv_y = Converter(file_y_unit)
351                            data.qy_data = data_conv_y(data.qy_data,
352                                                       units=default_q_unit)
353                            if data.dqy_data is not None:
354                                data.dqy_data = data_conv_y(data.dqy_data,
355                                                            units=default_q_unit)
356                        except KeyError:
357                            logger.info("Unrecognized Qy units in data file. No"
358                                        " data conversion attempted")
359            except KeyError:
360                message = "Unable to convert Q units from {0} to 1/A."
361                message.format(default_q_unit)
362                data.errors.append(message)
363            new_output.append(data)
364        self.output = new_output
365
366    def format_unit(self, unit=None):
367        """
368        Format units a common way
369        :param unit:
370        :return:
371        """
372        if unit:
373            split = unit.split("/")
374            if len(split) == 1:
375                return unit
376            elif split[0] == '1':
377                return "{0}^".format(split[1]) + "{-1}"
378            else:
379                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
380
381    def set_all_to_none(self):
382        """
383        Set all mutable values to None for error handling purposes
384        """
385        self.current_dataset = None
386        self.current_datainfo = None
387        self.output = []
388
389    def data_cleanup(self):
390        """
391        Clean up the data sets and refresh everything
392        :return: None
393        """
394        self.remove_empty_q_values()
395        self.send_to_output()  # Combine datasets with DataInfo
396        self.current_datainfo = DataInfo()  # Reset DataInfo
397
398    def remove_empty_q_values(self):
399        """
400        Remove any point where Q == 0
401        """
402        if isinstance(self.current_dataset, plottable_1D):
403            # Booleans for resolutions
404            has_error_dx = self.current_dataset.dx is not None
405            has_error_dxl = self.current_dataset.dxl is not None
406            has_error_dxw = self.current_dataset.dxw is not None
407            has_error_dy = self.current_dataset.dy is not None
408            # Create arrays of zeros for non-existent resolutions
409            if has_error_dxw and not has_error_dxl:
410                array_size = self.current_dataset.dxw.size - 1
411                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
412                                                    np.zeros([array_size]))
413                has_error_dxl = True
414            elif has_error_dxl and not has_error_dxw:
415                array_size = self.current_dataset.dxl.size - 1
416                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
417                                                    np.zeros([array_size]))
418                has_error_dxw = True
419            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
420                array_size = self.current_dataset.x.size - 1
421                self.current_dataset.dx = np.append(self.current_dataset.dx,
422                                                    np.zeros([array_size]))
423                has_error_dx = True
424            if not has_error_dy:
425                array_size = self.current_dataset.y.size - 1
426                self.current_dataset.dy = np.append(self.current_dataset.dy,
427                                                    np.zeros([array_size]))
428                has_error_dy = True
429
430            # Remove points where q = 0
431            x = self.current_dataset.x
432            self.current_dataset.x = self.current_dataset.x[x != 0]
433            self.current_dataset.y = self.current_dataset.y[x != 0]
434            if has_error_dy:
435                self.current_dataset.dy = self.current_dataset.dy[x != 0]
436            if has_error_dx:
437                self.current_dataset.dx = self.current_dataset.dx[x != 0]
438            if has_error_dxl:
439                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
440            if has_error_dxw:
441                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
442        elif isinstance(self.current_dataset, plottable_2D):
443            has_error_dqx = self.current_dataset.dqx_data is not None
444            has_error_dqy = self.current_dataset.dqy_data is not None
445            has_error_dy = self.current_dataset.err_data is not None
446            has_mask = self.current_dataset.mask is not None
447            x = self.current_dataset.qx_data
448            self.current_dataset.data = self.current_dataset.data[x != 0]
449            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
450            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
451            self.current_dataset.q_data = np.sqrt(
452                np.square(self.current_dataset.qx_data) + np.square(
453                    self.current_dataset.qy_data))
454            if has_error_dy:
455                self.current_dataset.err_data = self.current_dataset.err_data[
456                    x != 0]
457            if has_error_dqx:
458                self.current_dataset.dqx_data = self.current_dataset.dqx_data[
459                    x != 0]
460            if has_error_dqy:
461                self.current_dataset.dqy_data = self.current_dataset.dqy_data[
462                    x != 0]
463            if has_mask:
464                self.current_dataset.mask = self.current_dataset.mask[x != 0]
465
466    def reset_data_list(self, no_lines=0):
467        """
468        Reset the plottable_1D object
469        """
470        # Initialize data sets with arrays the maximum possible size
471        x = np.zeros(no_lines)
472        y = np.zeros(no_lines)
473        dx = np.zeros(no_lines)
474        dy = np.zeros(no_lines)
475        self.current_dataset = plottable_1D(x, y, dx, dy)
476
477    @staticmethod
478    def splitline(line):
479        """
480        Splits a line into pieces based on common delimiters
481        :param line: A single line of text
482        :return: list of values
483        """
484        # Initial try for CSV (split on ,)
485        toks = line.split(',')
486        # Now try SCSV (split on ;)
487        if len(toks) < 2:
488            toks = line.split(';')
489        # Now go for whitespace
490        if len(toks) < 2:
491            toks = line.split()
492        return toks
493
494    @abstractmethod
495    def get_file_contents(self):
496        """
497        Reader specific class to access the contents of the file
498        All reader classes that inherit from FileReader must implement
499        """
500        pass
Note: See TracBrowser for help on using the repository browser.