source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ a78446ce

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since a78446ce was 8d5e11c, checked in by krzywon, 6 years ago

File loader code cleanup.

  • Property mode set to 100644
File size: 20.4 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18from sas.sascalc.data_util.nxsunit import Converter
19
20logger = logging.getLogger(__name__)
21
22if sys.version_info[0] < 3:
23    def decode(s):
24        return s
25else:
26    def decode(s):
27        return s.decode() if isinstance(s, bytes) else s
28
29# Data 1D fields for iterative purposes
30FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
31# Data 2D fields for iterative purposes
32FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
33                 'dqx_data', 'dqy_data', 'mask')
34DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
35                       "t not be fully reduced. Support for the reader associat"
36                       "ed with this file type has been removed. An attempt to "
37                       "load the file was made, but, should it be successful, "
38                       "SasView cannot guarantee the accuracy of the data.")
39
40
41class FileReader(object):
42    # String to describe the type of data this reader can load
43    type_name = "ASCII"
44    # Wildcards to display
45    type = ["Text files (*.txt|*.TXT)"]
46    # List of allowed extensions
47    ext = ['.txt']
48    # Deprecated extensions
49    deprecated_extensions = ['.asc', '.nxs']
50    # Bypass extension check and try to load anyway
51    allow_all = False
52    # Able to import the unit converter
53    has_converter = True
54    # Default value of zero
55    _ZERO = 1e-16
56
57    def __init__(self):
58        # List of Data1D and Data2D objects to be sent back to data_loader
59        self.output = []
60        # Current plottable_(1D/2D) object being loaded in
61        self.current_dataset = None
62        # Current DataInfo object being loaded in
63        self.current_datainfo = None
64        # File path sent to reader
65        self.filepath = None
66        # Open file handle
67        self.f_open = None
68
69    def read(self, filepath):
70        """
71        Basic file reader
72
73        :param filepath: The full or relative path to a file to be loaded
74        """
75        self.filepath = filepath
76        if os.path.isfile(filepath):
77            basename, extension = os.path.splitext(os.path.basename(filepath))
78            self.extension = extension.lower()
79            # If the file type is not allowed, return nothing
80            if self.extension in self.ext or self.allow_all:
81                # Try to load the file, but raise an error if unable to.
82                try:
83                    self.f_open = open(filepath, 'rb')
84                    self.get_file_contents()
85
86                except DataReaderException as e:
87                    self.handle_error_message(e.message)
88                except OSError as e:
89                    # If the file cannot be opened
90                    msg = "Unable to open file: {}\n".format(filepath)
91                    msg += e.message
92                    self.handle_error_message(msg)
93                finally:
94                    # Close the file handle if it is open
95                    if not self.f_open.closed:
96                        self.f_open.close()
97                    if any(filepath.lower().endswith(ext) for ext in
98                           self.deprecated_extensions):
99                        self.handle_error_message(DEPRECATION_MESSAGE)
100                    if len(self.output) > 0:
101                        # Sort the data that's been loaded
102                        self.convert_data_units()
103                        self.sort_data()
104        else:
105            msg = "Unable to find file at: {}\n".format(filepath)
106            msg += "Please check your file path and try again."
107            self.handle_error_message(msg)
108
109        # Return a list of parsed entries that data_loader can manage
110        final_data = self.output
111        self.reset_state()
112        return final_data
113
114    def reset_state(self):
115        """
116        Resets the class state to a base case when loading a new data file so previous
117        data files do not appear a second time
118        """
119        self.current_datainfo = None
120        self.current_dataset = None
121        self.filepath = None
122        self.ind = None
123        self.output = []
124
125    def nextline(self):
126        """
127        Returns the next line in the file as a string.
128        """
129        #return self.f_open.readline()
130        return decode(self.f_open.readline())
131
132    def nextlines(self):
133        """
134        Returns the next line in the file as a string.
135        """
136        for line in self.f_open:
137            #yield line
138            yield decode(line)
139
140    def readall(self):
141        """
142        Returns the entire file as a string.
143        """
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158            raise NoKnownLoaderException(msg)
159
160    def send_to_output(self):
161        """
162        Helper that automatically combines the info and set and then appends it
163        to output
164        """
165        data_obj = combine_data_info_with_plottable(self.current_dataset,
166                                                    self.current_datainfo)
167        self.output.append(data_obj)
168
169    def sort_data(self):
170        """
171        Sort 1D data along the X axis for consistency
172        """
173        for data in self.output:
174            if isinstance(data, Data1D):
175                # Normalize the units for
176                data.x_unit = self.format_unit(data.x_unit)
177                data._xunit = data.x_unit
178                data.y_unit = self.format_unit(data.y_unit)
179                data._yunit = data.y_unit
180                # Sort data by increasing x and remove 1st point
181                ind = np.lexsort((data.y, data.x))
182                data.x = self._reorder_1d_array(data.x, ind)
183                data.y = self._reorder_1d_array(data.y, ind)
184                if data.dx is not None:
185                    if len(data.dx) == 0:
186                        data.dx = None
187                        continue
188                    data.dx = self._reorder_1d_array(data.dx, ind)
189                if data.dxl is not None:
190                    data.dxl = self._reorder_1d_array(data.dxl, ind)
191                if data.dxw is not None:
192                    data.dxw = self._reorder_1d_array(data.dxw, ind)
193                if data.dy is not None:
194                    if len(data.dy) == 0:
195                        data.dy = None
196                        continue
197                    data.dy = self._reorder_1d_array(data.dy, ind)
198                if data.lam is not None:
199                    data.lam = self._reorder_1d_array(data.lam, ind)
200                if data.dlam is not None:
201                    data.dlam = self._reorder_1d_array(data.dlam, ind)
202                data = self._remove_nans_in_data(data)
203                if len(data.x) > 0:
204                    data.xmin = np.min(data.x)
205                    data.xmax = np.max(data.x)
206                    data.ymin = np.min(data.y)
207                    data.ymax = np.max(data.y)
208            elif isinstance(data, Data2D):
209                # Normalize the units for
210                data.Q_unit = self.format_unit(data.Q_unit)
211                data.I_unit = self.format_unit(data.I_unit)
212                data._xunit = data.Q_unit
213                data._yunit = data.Q_unit
214                data._zunit = data.I_unit
215                data.data = data.data.astype(np.float64)
216                data.qx_data = data.qx_data.astype(np.float64)
217                data.xmin = np.min(data.qx_data)
218                data.xmax = np.max(data.qx_data)
219                data.qy_data = data.qy_data.astype(np.float64)
220                data.ymin = np.min(data.qy_data)
221                data.ymax = np.max(data.qy_data)
222                data.q_data = np.sqrt(data.qx_data * data.qx_data
223                                         + data.qy_data * data.qy_data)
224                if data.err_data is not None:
225                    data.err_data = data.err_data.astype(np.float64)
226                if data.dqx_data is not None:
227                    data.dqx_data = data.dqx_data.astype(np.float64)
228                if data.dqy_data is not None:
229                    data.dqy_data = data.dqy_data.astype(np.float64)
230                if data.mask is not None:
231                    data.mask = data.mask.astype(dtype=bool)
232
233                if len(data.data.shape) == 2:
234                    n_rows, n_cols = data.data.shape
235                    data.y_bins = data.qy_data[0::int(n_cols)]
236                    data.x_bins = data.qx_data[:int(n_cols)]
237                    data.data = data.data.flatten()
238                    data = self._remove_nans_in_data(data)
239                if len(data.data) > 0:
240                    data.xmin = np.min(data.qx_data)
241                    data.xmax = np.max(data.qx_data)
242                    data.ymin = np.min(data.qy_data)
243                    data.ymax = np.max(data.qx_data)
244
245    @staticmethod
246    def _reorder_1d_array(array, ind):
247        """
248        Reorders a 1D array based on the indices passed as ind
249        :param array: Array to be reordered
250        :param ind: Indices used to reorder array
251        :return: reordered array
252        """
253        array = np.asarray(array, dtype=np.float64)
254        return array[ind]
255
256    @staticmethod
257    def _remove_nans_in_data(data):
258        """
259        Remove data points where nan is loaded
260        :param data: 1D or 2D data object
261        :return: data with nan points removed
262        """
263        if isinstance(data, Data1D):
264            fields = FIELDS_1D
265        elif isinstance(data, Data2D):
266            fields = FIELDS_2D
267        else:
268            return data
269        # Make array of good points - all others will be removed
270        good = np.isfinite(getattr(data, fields[0]))
271        for name in fields[1:]:
272            array = getattr(data, name)
273            if array is not None:
274                # Update good points only if not already changed
275                good &= np.isfinite(array)
276        if not np.all(good):
277            for name in fields:
278                array = getattr(data, name)
279                if array is not None:
280                    setattr(data, name, array[good])
281        return data
282
283    @staticmethod
284    def set_default_1d_units(data):
285        """
286        Set the x and y axes to the default 1D units
287        :param data: 1D data set
288        :return:
289        """
290        data.xaxis("\\rm{Q}", '1/A')
291        data.yaxis("\\rm{Intensity}", "1/cm")
292        return data
293
294    @staticmethod
295    def set_default_2d_units(data):
296        """
297        Set the x and y axes to the default 2D units
298        :param data: 2D data set
299        :return:
300        """
301        data.xaxis("\\rm{Q_{x}}", '1/A')
302        data.yaxis("\\rm{Q_{y}}", '1/A')
303        data.zaxis("\\rm{Intensity}", "1/cm")
304        return data
305
306    def convert_data_units(self, default_q_unit="1/A", default_i_unit="1/cm"):
307        """
308        Converts al; data to the sasview default of units of A^{-1} for Q and
309        cm^{-1} for I.
310        :param default_q_unit: The default Q unit used by Sasview
311        :param default_i_unit: The default I unit used by Sasview
312        """
313        new_output = []
314        for data in self.output:
315            if data.isSesans:
316                new_output.append(data)
317                continue
318            file_x_unit = data._xunit
319            data_conv_x = Converter(file_x_unit)
320            file_y_unit = data._yunit
321            data_conv_y = Converter(file_y_unit)
322            if isinstance(data, Data1D):
323                try:
324                    data.x = data_conv_x(data.x, units=default_q_unit)
325                    data._xunit = default_q_unit
326                    data.x_unit = default_q_unit
327                    if data.dx is not None:
328                        data.dx = data_conv_x(data.dx, units=default_q_unit)
329                    if data.dxl is not None:
330                        data.dxl = data_conv_x(data.dxl, units=default_q_unit)
331                    if data.dxw is not None:
332                        data.dxw = data_conv_x(data.dxw, units=default_q_unit)
333                except KeyError:
334                    message = "Unable to convert Q units from {0} to 1/A."
335                    message.format(default_q_unit)
336                    data.errors.append(message)
337                try:
338                    data.y = data_conv_y(data.y, units=default_i_unit)
339                    data._yunit = default_i_unit
340                    data.y_unit = default_i_unit
341                    if data.dy is not None:
342                        data.dy = data_conv_y(data.dy, units=default_i_unit)
343                except KeyError:
344                    message = "Unable to convert I units from {0} to 1/cm."
345                    message.format(default_q_unit)
346                    data.errors.append(message)
347            elif isinstance(data, Data2D):
348                try:
349                    data.qx_data = data_conv_x(data.qx_data,
350                                               units=default_q_unit)
351                    if data.dqx_data is not None:
352                        data.dqx_data = data_conv_x(data.dqx_data,
353                                                    units=default_q_unit)
354                    data.qy_data = data_conv_y(data.qy_data,
355                                               units=default_q_unit)
356                    if data.dqy_data is not None:
357                        data.dqy_data = data_conv_y(data.dqy_data,
358                                                    units=default_q_unit)
359                except KeyError:
360                    message = "Unable to convert Q units from {0} to 1/A."
361                    message.format(default_q_unit)
362                    data.errors.append(message)
363                try:
364                    file_z_unit = data._zunit
365                    data_conv_z = Converter(file_z_unit)
366                    data.data = data_conv_z(data.data, units=default_i_unit)
367                    if data.err_data is not None:
368                        data.err_data = data_conv_z(data.err_data,
369                                                    units=default_i_unit)
370                except KeyError:
371                    message = "Unable to convert I units from {0} to 1/cm."
372                    message.format(default_q_unit)
373                    data.errors.append(message)
374            else:
375                # TODO: Throw error of some sort...
376                pass
377            new_output.append(data)
378        self.output = new_output
379
380    def format_unit(self, unit=None):
381        """
382        Format units a common way
383        :param unit:
384        :return:
385        """
386        if unit:
387            split = unit.split("/")
388            if len(split) == 1:
389                return unit
390            elif split[0] == '1':
391                return "{0}^".format(split[1]) + "{-1}"
392            else:
393                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
394
395    def set_all_to_none(self):
396        """
397        Set all mutable values to None for error handling purposes
398        """
399        self.current_dataset = None
400        self.current_datainfo = None
401        self.output = []
402
403    def data_cleanup(self):
404        """
405        Clean up the data sets and refresh everything
406        :return: None
407        """
408        self.remove_empty_q_values()
409        self.send_to_output()  # Combine datasets with DataInfo
410        self.current_datainfo = DataInfo()  # Reset DataInfo
411
412    def remove_empty_q_values(self):
413        """
414        Remove any point where Q == 0
415        """
416        if isinstance(self.current_dataset, plottable_1D):
417            # Booleans for resolutions
418            has_error_dx = self.current_dataset.dx is not None
419            has_error_dxl = self.current_dataset.dxl is not None
420            has_error_dxw = self.current_dataset.dxw is not None
421            has_error_dy = self.current_dataset.dy is not None
422            # Create arrays of zeros for non-existent resolutions
423            if has_error_dxw and not has_error_dxl:
424                array_size = self.current_dataset.dxw.size - 1
425                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
426                                                    np.zeros([array_size]))
427                has_error_dxl = True
428            elif has_error_dxl and not has_error_dxw:
429                array_size = self.current_dataset.dxl.size - 1
430                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
431                                                    np.zeros([array_size]))
432                has_error_dxw = True
433            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
434                array_size = self.current_dataset.x.size - 1
435                self.current_dataset.dx = np.append(self.current_dataset.dx,
436                                                    np.zeros([array_size]))
437                has_error_dx = True
438            if not has_error_dy:
439                array_size = self.current_dataset.y.size - 1
440                self.current_dataset.dy = np.append(self.current_dataset.dy,
441                                                    np.zeros([array_size]))
442                has_error_dy = True
443
444            # Remove points where q = 0
445            x = self.current_dataset.x
446            self.current_dataset.x = self.current_dataset.x[x != 0]
447            self.current_dataset.y = self.current_dataset.y[x != 0]
448            if has_error_dy:
449                self.current_dataset.dy = self.current_dataset.dy[x != 0]
450            if has_error_dx:
451                self.current_dataset.dx = self.current_dataset.dx[x != 0]
452            if has_error_dxl:
453                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
454            if has_error_dxw:
455                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
456        elif isinstance(self.current_dataset, plottable_2D):
457            has_error_dqx = self.current_dataset.dqx_data is not None
458            has_error_dqy = self.current_dataset.dqy_data is not None
459            has_error_dy = self.current_dataset.err_data is not None
460            has_mask = self.current_dataset.mask is not None
461            x = self.current_dataset.qx_data
462            self.current_dataset.data = self.current_dataset.data[x != 0]
463            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
464            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
465            self.current_dataset.q_data = np.sqrt(
466                np.square(self.current_dataset.qx_data) + np.square(
467                    self.current_dataset.qy_data))
468            if has_error_dy:
469                self.current_dataset.err_data = self.current_dataset.err_data[
470                    x != 0]
471            if has_error_dqx:
472                self.current_dataset.dqx_data = self.current_dataset.dqx_data[
473                    x != 0]
474            if has_error_dqy:
475                self.current_dataset.dqy_data = self.current_dataset.dqy_data[
476                    x != 0]
477            if has_mask:
478                self.current_dataset.mask = self.current_dataset.mask[x != 0]
479
480    def reset_data_list(self, no_lines=0):
481        """
482        Reset the plottable_1D object
483        """
484        # Initialize data sets with arrays the maximum possible size
485        x = np.zeros(no_lines)
486        y = np.zeros(no_lines)
487        dx = np.zeros(no_lines)
488        dy = np.zeros(no_lines)
489        self.current_dataset = plottable_1D(x, y, dx, dy)
490
491    @staticmethod
492    def splitline(line):
493        """
494        Splits a line into pieces based on common delimiters
495        :param line: A single line of text
496        :return: list of values
497        """
498        # Initial try for CSV (split on ,)
499        toks = line.split(',')
500        # Now try SCSV (split on ;)
501        if len(toks) < 2:
502            toks = line.split(';')
503        # Now go for whitespace
504        if len(toks) < 2:
505            toks = line.split()
506        return toks
507
508    @abstractmethod
509    def get_file_contents(self):
510        """
511        Reader specific class to access the contents of the file
512        All reader classes that inherit from FileReader must implement
513        """
514        pass
Note: See TracBrowser for help on using the repository browser.