source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ 058f6c3

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since 058f6c3 was 058f6c3, checked in by krzywon, 4 years ago

Fixes for failing ascii, cansasXML, red2D, and sesans reader unit tests.

  • Property mode set to 100644
File size: 19.7 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18from sas.sascalc.data_util.nxsunit import Converter
19
20logger = logging.getLogger(__name__)
21
22if sys.version_info[0] < 3:
23    def decode(s):
24        return s
25else:
26    def decode(s):
27        return s.decode() if isinstance(s, bytes) else s
28
29# Data 1D fields for iterative purposes
30FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
31# Data 2D fields for iterative purposes
32FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
33                 'dqx_data', 'dqy_data', 'mask')
34DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
35                       "t not be fully reduced. Support for the reader associat"
36                       "ed with this file type has been removed. An attempt to "
37                       "load the file was made, but, should it be successful, "
38                       "SasView cannot guarantee the accuracy of the data.")
39
40class FileReader(object):
41    # String to describe the type of data this reader can load
42    type_name = "ASCII"
43    # Wildcards to display
44    type = ["Text files (*.txt|*.TXT)"]
45    # List of allowed extensions
46    ext = ['.txt']
47    # Deprecated extensions
48    deprecated_extensions = ['.asc', '.nxs']
49    # Bypass extension check and try to load anyway
50    allow_all = False
51    # Able to import the unit converter
52    has_converter = True
53    # Default value of zero
54    _ZERO = 1e-16
55
56    def __init__(self):
57        # List of Data1D and Data2D objects to be sent back to data_loader
58        self.output = []
59        # Current plottable_(1D/2D) object being loaded in
60        self.current_dataset = None
61        # Current DataInfo object being loaded in
62        self.current_datainfo = None
63        # File path sent to reader
64        self.filepath = None
65        # Open file handle
66        self.f_open = None
67
68    def read(self, filepath):
69        """
70        Basic file reader
71
72        :param filepath: The full or relative path to a file to be loaded
73        """
74        self.filepath = filepath
75        if os.path.isfile(filepath):
76            basename, extension = os.path.splitext(os.path.basename(filepath))
77            self.extension = extension.lower()
78            # If the file type is not allowed, return nothing
79            if self.extension in self.ext or self.allow_all:
80                # Try to load the file, but raise an error if unable to.
81                try:
82                    self.f_open = open(filepath, 'rb')
83                    self.get_file_contents()
84
85                except DataReaderException as e:
86                    self.handle_error_message(e.message)
87                except OSError as e:
88                    # If the file cannot be opened
89                    msg = "Unable to open file: {}\n".format(filepath)
90                    msg += e.message
91                    self.handle_error_message(msg)
92                finally:
93                    # Close the file handle if it is open
94                    if not self.f_open.closed:
95                        self.f_open.close()
96                    if any(filepath.lower().endswith(ext) for ext in
97                           self.deprecated_extensions):
98                        self.handle_error_message(DEPRECATION_MESSAGE)
99                    if len(self.output) > 0:
100                        # Sort the data that's been loaded
101                        self.convert_data_units()
102                        self.sort_data()
103        else:
104            msg = "Unable to find file at: {}\n".format(filepath)
105            msg += "Please check your file path and try again."
106            self.handle_error_message(msg)
107
108        # Return a list of parsed entries that data_loader can manage
109        final_data = self.output
110        self.reset_state()
111        return final_data
112
113    def reset_state(self):
114        """
115        Resets the class state to a base case when loading a new data file so previous
116        data files do not appear a second time
117        """
118        self.current_datainfo = None
119        self.current_dataset = None
120        self.filepath = None
121        self.ind = None
122        self.output = []
123
124    def nextline(self):
125        """
126        Returns the next line in the file as a string.
127        """
128        #return self.f_open.readline()
129        return decode(self.f_open.readline())
130
131    def nextlines(self):
132        """
133        Returns the next line in the file as a string.
134        """
135        for line in self.f_open:
136            #yield line
137            yield decode(line)
138
139    def readall(self):
140        """
141        Returns the entire file as a string.
142        """
143        #return self.f_open.read()
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158            raise NoKnownLoaderException(msg)
159
160    def send_to_output(self):
161        """
162        Helper that automatically combines the info and set and then appends it
163        to output
164        """
165        data_obj = combine_data_info_with_plottable(self.current_dataset,
166                                                    self.current_datainfo)
167        self.output.append(data_obj)
168
169    def sort_data(self):
170        """
171        Sort 1D data along the X axis for consistency
172        """
173        for data in self.output:
174            if isinstance(data, Data1D):
175                # Normalize the units for
176                data.x_unit = self.format_unit(data.x_unit)
177                data.y_unit = self.format_unit(data.y_unit)
178                # Sort data by increasing x and remove 1st point
179                ind = np.lexsort((data.y, data.x))
180                data.x = self._reorder_1d_array(data.x, ind)
181                data.y = self._reorder_1d_array(data.y, ind)
182                if data.dx is not None:
183                    if len(data.dx) == 0:
184                        data.dx = None
185                        continue
186                    data.dx = self._reorder_1d_array(data.dx, ind)
187                if data.dxl is not None:
188                    data.dxl = self._reorder_1d_array(data.dxl, ind)
189                if data.dxw is not None:
190                    data.dxw = self._reorder_1d_array(data.dxw, ind)
191                if data.dy is not None:
192                    if len(data.dy) == 0:
193                        data.dy = None
194                        continue
195                    data.dy = self._reorder_1d_array(data.dy, ind)
196                if data.lam is not None:
197                    data.lam = self._reorder_1d_array(data.lam, ind)
198                if data.dlam is not None:
199                    data.dlam = self._reorder_1d_array(data.dlam, ind)
200                data = self._remove_nans_in_data(data)
201                if len(data.x) > 0:
202                    data.xmin = np.min(data.x)
203                    data.xmax = np.max(data.x)
204                    data.ymin = np.min(data.y)
205                    data.ymax = np.max(data.y)
206            elif isinstance(data, Data2D):
207                # Normalize the units for
208                data.x_unit = self.format_unit(data.Q_unit)
209                data.y_unit = self.format_unit(data.I_unit)
210                data.data = data.data.astype(np.float64)
211                data.qx_data = data.qx_data.astype(np.float64)
212                data.xmin = np.min(data.qx_data)
213                data.xmax = np.max(data.qx_data)
214                data.qy_data = data.qy_data.astype(np.float64)
215                data.ymin = np.min(data.qy_data)
216                data.ymax = np.max(data.qy_data)
217                data.q_data = np.sqrt(data.qx_data * data.qx_data
218                                         + data.qy_data * data.qy_data)
219                if data.err_data is not None:
220                    data.err_data = data.err_data.astype(np.float64)
221                if data.dqx_data is not None:
222                    data.dqx_data = data.dqx_data.astype(np.float64)
223                if data.dqy_data is not None:
224                    data.dqy_data = data.dqy_data.astype(np.float64)
225                if data.mask is not None:
226                    data.mask = data.mask.astype(dtype=bool)
227
228                if len(data.data.shape) == 2:
229                    n_rows, n_cols = data.data.shape
230                    data.y_bins = data.qy_data[0::int(n_cols)]
231                    data.x_bins = data.qx_data[:int(n_cols)]
232                    data.data = data.data.flatten()
233                    data = self._remove_nans_in_data(data)
234                if len(data.data) > 0:
235                    data.xmin = np.min(data.qx_data)
236                    data.xmax = np.max(data.qx_data)
237                    data.ymin = np.min(data.qy_data)
238                    data.ymax = np.max(data.qx_data)
239
240    @staticmethod
241    def _reorder_1d_array(array, ind):
242        """
243        Reorders a 1D array based on the indices passed as ind
244        :param array: Array to be reordered
245        :param ind: Indices used to reorder array
246        :return: reordered array
247        """
248        array = np.asarray(array, dtype=np.float64)
249        return array[ind]
250
251    @staticmethod
252    def _remove_nans_in_data(data):
253        """
254        Remove data points where nan is loaded
255        :param data: 1D or 2D data object
256        :return: data with nan points removed
257        """
258        if isinstance(data, Data1D):
259            fields = FIELDS_1D
260        elif isinstance(data, Data2D):
261            fields = FIELDS_2D
262        else:
263            return data
264        # Make array of good points - all others will be removed
265        good = np.isfinite(getattr(data, fields[0]))
266        for name in fields[1:]:
267            array = getattr(data, name)
268            if array is not None:
269                # Update good points only if not already changed
270                good &= np.isfinite(array)
271        if not np.all(good):
272            for name in fields:
273                array = getattr(data, name)
274                if array is not None:
275                    setattr(data, name, array[good])
276        return data
277
278    @staticmethod
279    def set_default_1d_units(data):
280        """
281        Set the x and y axes to the default 1D units
282        :param data: 1D data set
283        :return:
284        """
285        data.xaxis("\\rm{Q}", '1/A')
286        data.yaxis("\\rm{Intensity}", "1/cm")
287        return data
288
289    @staticmethod
290    def set_default_2d_units(data):
291        """
292        Set the x and y axes to the default 2D units
293        :param data: 2D data set
294        :return:
295        """
296        data.xaxis("\\rm{Q_{x}}", '1/A')
297        data.yaxis("\\rm{Q_{y}}", '1/A')
298        data.zaxis("\\rm{Intensity}", "1/cm")
299        return data
300
301    def convert_data_units(self, default_q_unit="1/A", default_i_unit="1/cm"):
302        """
303        Converts al; data to the sasview default of units of A^{-1} for Q and
304        cm^{-1} for I.
305        :param default_x_unit: The default x unit used by Sasview
306        :param default_y_unit: The default y unit used by Sasview
307        """
308        new_output = []
309        for data in self.output:
310            if data.isSesans:
311                new_output.append(data)
312                continue
313            file_x_unit = data._xunit
314            data_conv_x = Converter(file_x_unit)
315            file_y_unit = data._yunit
316            data_conv_y = Converter(file_y_unit)
317            if isinstance(data, Data1D):
318                try:
319                    data.x = data_conv_x(data.x, units=default_q_unit)
320                    if data.dx is not None:
321                        data.dx = data_conv_x(data.dx, units=default_q_unit)
322                    if data.dxl is not None:
323                        data.dxl = data_conv_x(data.dxl, units=default_q_unit)
324                    if data.dxw is not None:
325                        data.dxw = data_conv_x(data.dxw, units=default_q_unit)
326                except KeyError:
327                    message = "Unable to convert Q units from {0} to 1/A."
328                    message.format(default_q_unit)
329                    data.errors.append(message)
330                try:
331                    data.y = data_conv_y(data.y, units=default_i_unit)
332                    if data.dy is not None:
333                        data.dy = data_conv_y(data.dy, units=default_i_unit)
334                except KeyError:
335                    message = "Unable to convert I units from {0} to 1/cm."
336                    message.format(default_q_unit)
337                    data.errors.append(message)
338            elif isinstance(data, Data2D):
339                try:
340                    data.qx_data = data_conv_x(data.qx_data, units=default_q_unit)
341                    if data.dqx_data is not None:
342                        data.dqx_data = data_conv_x(data.dqx_data, units=default_q_unit)
343                    data.qy_data = data_conv_y(data.qy_data, units=default_q_unit)
344                    if data.dqy_data is not None:
345                        data.dqy_data = data_conv_y(data.dqy_data, units=default_q_unit)
346                except KeyError:
347                    message = "Unable to convert Q units from {0} to 1/A."
348                    message.format(default_q_unit)
349                    data.errors.append(message)
350                try:
351                    file_z_unit = data._zunit
352                    data_conv_z = Converter(file_z_unit)
353                    data.data = data_conv_z(data.data, units=default_i_unit)
354                    if data.err_data is not None:
355                        data.err_data = data_conv_z(data.err_data, units=default_i_unit)
356                except KeyError:
357                    message = "Unable to convert I units from {0} to 1/cm."
358                    message.format(default_q_unit)
359                    data.errors.append(message)
360            else:
361                # TODO: Throw error of some sort...
362                pass
363            new_output.append(data)
364        self.output = new_output
365
366    def format_unit(self, unit=None):
367        """
368        Format units a common way
369        :param unit:
370        :return:
371        """
372        if unit:
373            split = unit.split("/")
374            if len(split) == 1:
375                return unit
376            elif split[0] == '1':
377                return "{0}^".format(split[1]) + "{-1}"
378            else:
379                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
380
381    def set_all_to_none(self):
382        """
383        Set all mutable values to None for error handling purposes
384        """
385        self.current_dataset = None
386        self.current_datainfo = None
387        self.output = []
388
389    def data_cleanup(self):
390        """
391        Clean up the data sets and refresh everything
392        :return: None
393        """
394        self.remove_empty_q_values()
395        self.send_to_output()  # Combine datasets with DataInfo
396        self.current_datainfo = DataInfo()  # Reset DataInfo
397
398    def remove_empty_q_values(self):
399        """
400        Remove any point where Q == 0
401        """
402        if isinstance(self.current_dataset, plottable_1D):
403            # Booleans for resolutions
404            has_error_dx = self.current_dataset.dx is not None
405            has_error_dxl = self.current_dataset.dxl is not None
406            has_error_dxw = self.current_dataset.dxw is not None
407            has_error_dy = self.current_dataset.dy is not None
408            # Create arrays of zeros for non-existent resolutions
409            if has_error_dxw and not has_error_dxl:
410                array_size = self.current_dataset.dxw.size - 1
411                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
412                                                    np.zeros([array_size]))
413                has_error_dxl = True
414            elif has_error_dxl and not has_error_dxw:
415                array_size = self.current_dataset.dxl.size - 1
416                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
417                                                    np.zeros([array_size]))
418                has_error_dxw = True
419            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
420                array_size = self.current_dataset.x.size - 1
421                self.current_dataset.dx = np.append(self.current_dataset.dx,
422                                                    np.zeros([array_size]))
423                has_error_dx = True
424            if not has_error_dy:
425                array_size = self.current_dataset.y.size - 1
426                self.current_dataset.dy = np.append(self.current_dataset.dy,
427                                                    np.zeros([array_size]))
428                has_error_dy = True
429
430            # Remove points where q = 0
431            x = self.current_dataset.x
432            self.current_dataset.x = self.current_dataset.x[x != 0]
433            self.current_dataset.y = self.current_dataset.y[x != 0]
434            if has_error_dy:
435                self.current_dataset.dy = self.current_dataset.dy[x != 0]
436            if has_error_dx:
437                self.current_dataset.dx = self.current_dataset.dx[x != 0]
438            if has_error_dxl:
439                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
440            if has_error_dxw:
441                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
442        elif isinstance(self.current_dataset, plottable_2D):
443            has_error_dqx = self.current_dataset.dqx_data is not None
444            has_error_dqy = self.current_dataset.dqy_data is not None
445            has_error_dy = self.current_dataset.err_data is not None
446            has_mask = self.current_dataset.mask is not None
447            x = self.current_dataset.qx_data
448            self.current_dataset.data = self.current_dataset.data[x != 0]
449            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
450            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
451            self.current_dataset.q_data = np.sqrt(
452                np.square(self.current_dataset.qx_data) + np.square(
453                    self.current_dataset.qy_data))
454            if has_error_dy:
455                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
456            if has_error_dqx:
457                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
458            if has_error_dqy:
459                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
460            if has_mask:
461                self.current_dataset.mask = self.current_dataset.mask[x != 0]
462
463    def reset_data_list(self, no_lines=0):
464        """
465        Reset the plottable_1D object
466        """
467        # Initialize data sets with arrays the maximum possible size
468        x = np.zeros(no_lines)
469        y = np.zeros(no_lines)
470        dx = np.zeros(no_lines)
471        dy = np.zeros(no_lines)
472        self.current_dataset = plottable_1D(x, y, dx, dy)
473
474    @staticmethod
475    def splitline(line):
476        """
477        Splits a line into pieces based on common delimiters
478        :param line: A single line of text
479        :return: list of values
480        """
481        # Initial try for CSV (split on ,)
482        toks = line.split(',')
483        # Now try SCSV (split on ;)
484        if len(toks) < 2:
485            toks = line.split(';')
486        # Now go for whitespace
487        if len(toks) < 2:
488            toks = line.split()
489        return toks
490
491    @abstractmethod
492    def get_file_contents(self):
493        """
494        Reader specific class to access the contents of the file
495        All reader classes that inherit from FileReader must implement
496        """
497        pass
Note: See TracBrowser for help on using the repository browser.