source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ 3bab401

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since 3bab401 was 3bab401, checked in by krzywon, 6 years ago

Create a universal unit converter for all data loaders. refs #1111

  • Property mode set to 100644
File size: 19.7 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18from sas.sascalc.data_util.nxsunit import Converter
19
20logger = logging.getLogger(__name__)
21
22if sys.version_info[0] < 3:
23    def decode(s):
24        return s
25else:
26    def decode(s):
27        return s.decode() if isinstance(s, bytes) else s
28
29# Data 1D fields for iterative purposes
30FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
31# Data 2D fields for iterative purposes
32FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
33                 'dqx_data', 'dqy_data', 'mask')
34DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
35                       "t not be fully reduced. Support for the reader associat"
36                       "ed with this file type has been removed. An attempt to "
37                       "load the file was made, but, should it be successful, "
38                       "SasView cannot guarantee the accuracy of the data.")
39
40class FileReader(object):
41    # String to describe the type of data this reader can load
42    type_name = "ASCII"
43    # Wildcards to display
44    type = ["Text files (*.txt|*.TXT)"]
45    # List of allowed extensions
46    ext = ['.txt']
47    # Deprecated extensions
48    deprecated_extensions = ['.asc', '.nxs']
49    # Bypass extension check and try to load anyway
50    allow_all = False
51    # Able to import the unit converter
52    has_converter = True
53    # Default value of zero
54    _ZERO = 1e-16
55
56    def __init__(self):
57        # List of Data1D and Data2D objects to be sent back to data_loader
58        self.output = []
59        # Current plottable_(1D/2D) object being loaded in
60        self.current_dataset = None
61        # Current DataInfo object being loaded in
62        self.current_datainfo = None
63        # File path sent to reader
64        self.filepath = None
65        # Open file handle
66        self.f_open = None
67
68    def read(self, filepath):
69        """
70        Basic file reader
71
72        :param filepath: The full or relative path to a file to be loaded
73        """
74        self.filepath = filepath
75        if os.path.isfile(filepath):
76            basename, extension = os.path.splitext(os.path.basename(filepath))
77            self.extension = extension.lower()
78            # If the file type is not allowed, return nothing
79            if self.extension in self.ext or self.allow_all:
80                # Try to load the file, but raise an error if unable to.
81                try:
82                    self.f_open = open(filepath, 'rb')
83                    self.get_file_contents()
84
85                except DataReaderException as e:
86                    self.handle_error_message(e.message)
87                except OSError as e:
88                    # If the file cannot be opened
89                    msg = "Unable to open file: {}\n".format(filepath)
90                    msg += e.message
91                    self.handle_error_message(msg)
92                finally:
93                    # Close the file handle if it is open
94                    if not self.f_open.closed:
95                        self.f_open.close()
96                    if any(filepath.lower().endswith(ext) for ext in
97                           self.deprecated_extensions):
98                        self.handle_error_message(DEPRECATION_MESSAGE)
99                    if len(self.output) > 0:
100                        # Sort the data that's been loaded
101                        self.convert_data_units()
102                        self.sort_data()
103        else:
104            msg = "Unable to find file at: {}\n".format(filepath)
105            msg += "Please check your file path and try again."
106            self.handle_error_message(msg)
107
108        # Return a list of parsed entries that data_loader can manage
109        final_data = self.output
110        self.reset_state()
111        return final_data
112
113    def reset_state(self):
114        """
115        Resets the class state to a base case when loading a new data file so previous
116        data files do not appear a second time
117        """
118        self.current_datainfo = None
119        self.current_dataset = None
120        self.filepath = None
121        self.ind = None
122        self.output = []
123
124    def nextline(self):
125        """
126        Returns the next line in the file as a string.
127        """
128        #return self.f_open.readline()
129        return decode(self.f_open.readline())
130
131    def nextlines(self):
132        """
133        Returns the next line in the file as a string.
134        """
135        for line in self.f_open:
136            #yield line
137            yield decode(line)
138
139    def readall(self):
140        """
141        Returns the entire file as a string.
142        """
143        #return self.f_open.read()
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158            raise NoKnownLoaderException(msg)
159
160    def send_to_output(self):
161        """
162        Helper that automatically combines the info and set and then appends it
163        to output
164        """
165        data_obj = combine_data_info_with_plottable(self.current_dataset,
166                                                    self.current_datainfo)
167        self.output.append(data_obj)
168
169    def sort_data(self):
170        """
171        Sort 1D data along the X axis for consistency
172        """
173        for data in self.output:
174            if isinstance(data, Data1D):
175                # Normalize the units for
176                data.x_unit = self.format_unit(data.x_unit)
177                data.y_unit = self.format_unit(data.y_unit)
178                # Sort data by increasing x and remove 1st point
179                ind = np.lexsort((data.y, data.x))
180                data.x = self._reorder_1d_array(data.x, ind)
181                data.y = self._reorder_1d_array(data.y, ind)
182                if data.dx is not None:
183                    if len(data.dx) == 0:
184                        data.dx = None
185                        continue
186                    data.dx = self._reorder_1d_array(data.dx, ind)
187                if data.dxl is not None:
188                    data.dxl = self._reorder_1d_array(data.dxl, ind)
189                if data.dxw is not None:
190                    data.dxw = self._reorder_1d_array(data.dxw, ind)
191                if data.dy is not None:
192                    if len(data.dy) == 0:
193                        data.dy = None
194                        continue
195                    data.dy = self._reorder_1d_array(data.dy, ind)
196                if data.lam is not None:
197                    data.lam = self._reorder_1d_array(data.lam, ind)
198                if data.dlam is not None:
199                    data.dlam = self._reorder_1d_array(data.dlam, ind)
200                data = self._remove_nans_in_data(data)
201                if len(data.x) > 0:
202                    data.xmin = np.min(data.x)
203                    data.xmax = np.max(data.x)
204                    data.ymin = np.min(data.y)
205                    data.ymax = np.max(data.y)
206            elif isinstance(data, Data2D):
207                # Normalize the units for
208                data.x_unit = self.format_unit(data.Q_unit)
209                data.y_unit = self.format_unit(data.I_unit)
210                data.data = data.data.astype(np.float64)
211                data.qx_data = data.qx_data.astype(np.float64)
212                data.xmin = np.min(data.qx_data)
213                data.xmax = np.max(data.qx_data)
214                data.qy_data = data.qy_data.astype(np.float64)
215                data.ymin = np.min(data.qy_data)
216                data.ymax = np.max(data.qy_data)
217                data.q_data = np.sqrt(data.qx_data * data.qx_data
218                                         + data.qy_data * data.qy_data)
219                if data.err_data is not None:
220                    data.err_data = data.err_data.astype(np.float64)
221                if data.dqx_data is not None:
222                    data.dqx_data = data.dqx_data.astype(np.float64)
223                if data.dqy_data is not None:
224                    data.dqy_data = data.dqy_data.astype(np.float64)
225                if data.mask is not None:
226                    data.mask = data.mask.astype(dtype=bool)
227
228                if len(data.data.shape) == 2:
229                    n_rows, n_cols = data.data.shape
230                    data.y_bins = data.qy_data[0::int(n_cols)]
231                    data.x_bins = data.qx_data[:int(n_cols)]
232                    data.data = data.data.flatten()
233                    data = self._remove_nans_in_data(data)
234                if len(data.data) > 0:
235                    data.xmin = np.min(data.qx_data)
236                    data.xmax = np.max(data.qx_data)
237                    data.ymin = np.min(data.qy_data)
238                    data.ymax = np.max(data.qx_data)
239
240    @staticmethod
241    def _reorder_1d_array(array, ind):
242        """
243        Reorders a 1D array based on the indices passed as ind
244        :param array: Array to be reordered
245        :param ind: Indices used to reorder array
246        :return: reordered array
247        """
248        array = np.asarray(array, dtype=np.float64)
249        return array[ind]
250
251    @staticmethod
252    def _remove_nans_in_data(data):
253        """
254        Remove data points where nan is loaded
255        :param data: 1D or 2D data object
256        :return: data with nan points removed
257        """
258        if isinstance(data, Data1D):
259            fields = FIELDS_1D
260        elif isinstance(data, Data2D):
261            fields = FIELDS_2D
262        else:
263            return data
264        # Make array of good points - all others will be removed
265        good = np.isfinite(getattr(data, fields[0]))
266        for name in fields[1:]:
267            array = getattr(data, name)
268            if array is not None:
269                # Update good points only if not already changed
270                good &= np.isfinite(array)
271        if not np.all(good):
272            for name in fields:
273                array = getattr(data, name)
274                if array is not None:
275                    setattr(data, name, array[good])
276        return data
277
278    @staticmethod
279    def set_default_1d_units(data):
280        """
281        Set the x and y axes to the default 1D units
282        :param data: 1D data set
283        :return:
284        """
285        data.xaxis("\\rm{Q}", '1/A')
286        data.yaxis("\\rm{Intensity}", "1/cm")
287        return data
288
289    @staticmethod
290    def set_default_2d_units(data):
291        """
292        Set the x and y axes to the default 2D units
293        :param data: 2D data set
294        :return:
295        """
296        data.xaxis("\\rm{Q_{x}}", '1/A')
297        data.yaxis("\\rm{Q_{y}}", '1/A')
298        data.zaxis("\\rm{Intensity}", "1/cm")
299        return data
300
301    def convert_data_units(self, default_q_unit="1/A", default_i_unit="1/cm"):
302        """
303        Converts al; data to the sasview default of units of A^{-1} for Q and
304        cm^{-1} for I.
305        :param default_x_unit: The default x unit used by Sasview
306        :param default_y_unit: The default y unit used by Sasview
307        """
308        new_output = []
309        for data in self.output:
310            file_x_unit = data._xunit
311            data_conv_x = Converter(file_x_unit)
312            file_y_unit = data._yunit
313            data_conv_y = Converter(file_y_unit)
314            if isinstance(data, Data1D):
315                try:
316                    data.x = data_conv_x(data.x, units=default_q_unit)
317                    if data.dx is not None:
318                        data.dx = data_conv_x(data.dx, units=default_q_unit)
319                    if data.dxl is not None:
320                        data.dxl = data_conv_x(data.dxl, units=default_q_unit)
321                    if data.dxw is not None:
322                        data.dxw = data_conv_x(data.dxw, units=default_q_unit)
323                except KeyError:
324                    message = "Unable to convert Q units from {0} to 1/A."
325                    message.format(default_q_unit)
326                    data.errors.append(message)
327                try:
328                    data.y = data_conv_y(data.y, units=default_i_unit)
329                    if data.dy is not None:
330                        data.dy = data_conv_y(data.dy, units=default_i_unit)
331                except KeyError:
332                    message = "Unable to convert I units from {0} to 1/cm."
333                    message.format(default_q_unit)
334                    data.errors.append(message)
335                new_output.append(data)
336            elif isinstance(data, Data2D):
337                try:
338                    data.qx_data = data_conv_x(data.qx_data, units=default_q_unit)
339                    if data.dqx_data is not None:
340                        data.dqx_data = data_conv_x(data.dqx_data, units=default_q_unit)
341                    data.qy_data = data_conv_y(data.qy_data, units=default_q_unit)
342                    if data.dqy_data is not None:
343                        data.dqy_data = data_conv_y(data.dqy_data, units=default_q_unit)
344                except KeyError:
345                    message = "Unable to convert Q units from {0} to 1/A."
346                    message.format(default_q_unit)
347                    data.errors.append(message)
348                try:
349                    file_z_unit = data._zunit
350                    data_conv_z = Converter(file_z_unit)
351                    data.data = data_conv_z(data.data, units=default_i_unit)
352                    if data.err_data is not None:
353                        data.err_data = data_conv_z(data.err_data, units=default_i_unit)
354                except KeyError:
355                    message = "Unable to convert I units from {0} to 1/cm."
356                    message.format(default_q_unit)
357                    data.errors.append(message)
358                new_output.append(data)
359            else:
360                # TODO: Throw error of some sort...
361                pass
362        self.output = new_output
363
364    def format_unit(self, unit=None):
365        """
366        Format units a common way
367        :param unit:
368        :return:
369        """
370        if unit:
371            split = unit.split("/")
372            if len(split) == 1:
373                return unit
374            elif split[0] == '1':
375                return "{0}^".format(split[1]) + "{-1}"
376            else:
377                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
378
379    def set_all_to_none(self):
380        """
381        Set all mutable values to None for error handling purposes
382        """
383        self.current_dataset = None
384        self.current_datainfo = None
385        self.output = []
386
387    def data_cleanup(self):
388        """
389        Clean up the data sets and refresh everything
390        :return: None
391        """
392        self.remove_empty_q_values()
393        self.send_to_output()  # Combine datasets with DataInfo
394        self.current_datainfo = DataInfo()  # Reset DataInfo
395
396    def remove_empty_q_values(self):
397        """
398        Remove any point where Q == 0
399        """
400        if isinstance(self.current_dataset, plottable_1D):
401            # Booleans for resolutions
402            has_error_dx = self.current_dataset.dx is not None
403            has_error_dxl = self.current_dataset.dxl is not None
404            has_error_dxw = self.current_dataset.dxw is not None
405            has_error_dy = self.current_dataset.dy is not None
406            # Create arrays of zeros for non-existent resolutions
407            if has_error_dxw and not has_error_dxl:
408                array_size = self.current_dataset.dxw.size - 1
409                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
410                                                    np.zeros([array_size]))
411                has_error_dxl = True
412            elif has_error_dxl and not has_error_dxw:
413                array_size = self.current_dataset.dxl.size - 1
414                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
415                                                    np.zeros([array_size]))
416                has_error_dxw = True
417            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
418                array_size = self.current_dataset.x.size - 1
419                self.current_dataset.dx = np.append(self.current_dataset.dx,
420                                                    np.zeros([array_size]))
421                has_error_dx = True
422            if not has_error_dy:
423                array_size = self.current_dataset.y.size - 1
424                self.current_dataset.dy = np.append(self.current_dataset.dy,
425                                                    np.zeros([array_size]))
426                has_error_dy = True
427
428            # Remove points where q = 0
429            x = self.current_dataset.x
430            self.current_dataset.x = self.current_dataset.x[x != 0]
431            self.current_dataset.y = self.current_dataset.y[x != 0]
432            if has_error_dy:
433                self.current_dataset.dy = self.current_dataset.dy[x != 0]
434            if has_error_dx:
435                self.current_dataset.dx = self.current_dataset.dx[x != 0]
436            if has_error_dxl:
437                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
438            if has_error_dxw:
439                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
440        elif isinstance(self.current_dataset, plottable_2D):
441            has_error_dqx = self.current_dataset.dqx_data is not None
442            has_error_dqy = self.current_dataset.dqy_data is not None
443            has_error_dy = self.current_dataset.err_data is not None
444            has_mask = self.current_dataset.mask is not None
445            x = self.current_dataset.qx_data
446            self.current_dataset.data = self.current_dataset.data[x != 0]
447            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
448            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
449            self.current_dataset.q_data = np.sqrt(
450                np.square(self.current_dataset.qx_data) + np.square(
451                    self.current_dataset.qy_data))
452            if has_error_dy:
453                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
454            if has_error_dqx:
455                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
456            if has_error_dqy:
457                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
458            if has_mask:
459                self.current_dataset.mask = self.current_dataset.mask[x != 0]
460
461    def reset_data_list(self, no_lines=0):
462        """
463        Reset the plottable_1D object
464        """
465        # Initialize data sets with arrays the maximum possible size
466        x = np.zeros(no_lines)
467        y = np.zeros(no_lines)
468        dx = np.zeros(no_lines)
469        dy = np.zeros(no_lines)
470        self.current_dataset = plottable_1D(x, y, dx, dy)
471
472    @staticmethod
473    def splitline(line):
474        """
475        Splits a line into pieces based on common delimiters
476        :param line: A single line of text
477        :return: list of values
478        """
479        # Initial try for CSV (split on ,)
480        toks = line.split(',')
481        # Now try SCSV (split on ;)
482        if len(toks) < 2:
483            toks = line.split(';')
484        # Now go for whitespace
485        if len(toks) < 2:
486            toks = line.split()
487        return toks
488
489    @abstractmethod
490    def get_file_contents(self):
491        """
492        Reader specific class to access the contents of the file
493        All reader classes that inherit from FileReader must implement
494        """
495        pass
Note: See TracBrowser for help on using the repository browser.