source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ b1ec23d

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since b1ec23d was b1ec23d, checked in by krzywon, 6 years ago

Remove unneeded shape check, give better error when saving fails, and data info cleanup.

  • Property mode set to 100644
File size: 20.3 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18from sas.sascalc.data_util.nxsunit import Converter
19
20logger = logging.getLogger(__name__)
21
22if sys.version_info[0] < 3:
23    def decode(s):
24        return s
25else:
26    def decode(s):
27        return s.decode() if isinstance(s, bytes) else s
28
29# Data 1D fields for iterative purposes
30FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
31# Data 2D fields for iterative purposes
32FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
33                 'dqx_data', 'dqy_data', 'mask')
34DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
35                       "t not be fully reduced. Support for the reader associat"
36                       "ed with this file type has been removed. An attempt to "
37                       "load the file was made, but, should it be successful, "
38                       "SasView cannot guarantee the accuracy of the data.")
39
40
41class FileReader(object):
42    # String to describe the type of data this reader can load
43    type_name = "ASCII"
44    # Wildcards to display
45    type = ["Text files (*.txt|*.TXT)"]
46    # List of allowed extensions
47    ext = ['.txt']
48    # Deprecated extensions
49    deprecated_extensions = ['.asc', '.nxs']
50    # Bypass extension check and try to load anyway
51    allow_all = False
52    # Able to import the unit converter
53    has_converter = True
54    # Default value of zero
55    _ZERO = 1e-16
56
57    def __init__(self):
58        # List of Data1D and Data2D objects to be sent back to data_loader
59        self.output = []
60        # Current plottable_(1D/2D) object being loaded in
61        self.current_dataset = None
62        # Current DataInfo object being loaded in
63        self.current_datainfo = None
64        # File path sent to reader
65        self.filepath = None
66        # Open file handle
67        self.f_open = None
68
69    def read(self, filepath):
70        """
71        Basic file reader
72
73        :param filepath: The full or relative path to a file to be loaded
74        """
75        self.filepath = filepath
76        if os.path.isfile(filepath):
77            basename, extension = os.path.splitext(os.path.basename(filepath))
78            self.extension = extension.lower()
79            # If the file type is not allowed, return nothing
80            if self.extension in self.ext or self.allow_all:
81                # Try to load the file, but raise an error if unable to.
82                try:
83                    self.f_open = open(filepath, 'rb')
84                    self.get_file_contents()
85
86                except DataReaderException as e:
87                    self.handle_error_message(e.message)
88                except OSError as e:
89                    # If the file cannot be opened
90                    msg = "Unable to open file: {}\n".format(filepath)
91                    msg += e.message
92                    self.handle_error_message(msg)
93                finally:
94                    # Close the file handle if it is open
95                    if not self.f_open.closed:
96                        self.f_open.close()
97                    if any(filepath.lower().endswith(ext) for ext in
98                           self.deprecated_extensions):
99                        self.handle_error_message(DEPRECATION_MESSAGE)
100                    if len(self.output) > 0:
101                        # Sort the data that's been loaded
102                        self.convert_data_units()
103                        self.sort_data()
104        else:
105            msg = "Unable to find file at: {}\n".format(filepath)
106            msg += "Please check your file path and try again."
107            self.handle_error_message(msg)
108
109        # Return a list of parsed entries that data_loader can manage
110        final_data = self.output
111        self.reset_state()
112        return final_data
113
114    def reset_state(self):
115        """
116        Resets the class state to a base case when loading a new data file so previous
117        data files do not appear a second time
118        """
119        self.current_datainfo = None
120        self.current_dataset = None
121        self.filepath = None
122        self.ind = None
123        self.output = []
124
125    def nextline(self):
126        """
127        Returns the next line in the file as a string.
128        """
129        #return self.f_open.readline()
130        return decode(self.f_open.readline())
131
132    def nextlines(self):
133        """
134        Returns the next line in the file as a string.
135        """
136        for line in self.f_open:
137            #yield line
138            yield decode(line)
139
140    def readall(self):
141        """
142        Returns the entire file as a string.
143        """
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158            raise NoKnownLoaderException(msg)
159
160    def send_to_output(self):
161        """
162        Helper that automatically combines the info and set and then appends it
163        to output
164        """
165        data_obj = combine_data_info_with_plottable(self.current_dataset,
166                                                    self.current_datainfo)
167        self.output.append(data_obj)
168
169    def sort_data(self):
170        """
171        Sort 1D data along the X axis for consistency
172        """
173        for data in self.output:
174            if isinstance(data, Data1D):
175                # Normalize the units for
176                data.x_unit = self.format_unit(data.x_unit)
177                data._xunit = data.x_unit
178                data.y_unit = self.format_unit(data.y_unit)
179                data._yunit = data.y_unit
180                # Sort data by increasing x and remove 1st point
181                ind = np.lexsort((data.y, data.x))
182                data.x = self._reorder_1d_array(data.x, ind)
183                data.y = self._reorder_1d_array(data.y, ind)
184                if data.dx is not None:
185                    if len(data.dx) == 0:
186                        data.dx = None
187                        continue
188                    data.dx = self._reorder_1d_array(data.dx, ind)
189                if data.dxl is not None:
190                    data.dxl = self._reorder_1d_array(data.dxl, ind)
191                if data.dxw is not None:
192                    data.dxw = self._reorder_1d_array(data.dxw, ind)
193                if data.dy is not None:
194                    if len(data.dy) == 0:
195                        data.dy = None
196                        continue
197                    data.dy = self._reorder_1d_array(data.dy, ind)
198                if data.lam is not None:
199                    data.lam = self._reorder_1d_array(data.lam, ind)
200                if data.dlam is not None:
201                    data.dlam = self._reorder_1d_array(data.dlam, ind)
202                data = self._remove_nans_in_data(data)
203                if len(data.x) > 0:
204                    data.xmin = np.min(data.x)
205                    data.xmax = np.max(data.x)
206                    data.ymin = np.min(data.y)
207                    data.ymax = np.max(data.y)
208            elif isinstance(data, Data2D):
209                # Normalize the units for
210                data.Q_unit = self.format_unit(data.Q_unit)
211                data.I_unit = self.format_unit(data.I_unit)
212                data._xunit = data.Q_unit
213                data._yunit = data.Q_unit
214                data._zunit = data.I_unit
215                data.data = data.data.astype(np.float64)
216                data.qx_data = data.qx_data.astype(np.float64)
217                data.xmin = np.min(data.qx_data)
218                data.xmax = np.max(data.qx_data)
219                data.qy_data = data.qy_data.astype(np.float64)
220                data.ymin = np.min(data.qy_data)
221                data.ymax = np.max(data.qy_data)
222                data.q_data = np.sqrt(data.qx_data * data.qx_data
223                                         + data.qy_data * data.qy_data)
224                if data.err_data is not None:
225                    data.err_data = data.err_data.astype(np.float64)
226                if data.dqx_data is not None:
227                    data.dqx_data = data.dqx_data.astype(np.float64)
228                if data.dqy_data is not None:
229                    data.dqy_data = data.dqy_data.astype(np.float64)
230                if data.mask is not None:
231                    data.mask = data.mask.astype(dtype=bool)
232
233                n_rows, n_cols = data.data.shape
234                data.y_bins = data.qy_data[0::int(n_cols)]
235                data.x_bins = data.qx_data[:int(n_cols)]
236                data.data = data.data.flatten()
237                data = self._remove_nans_in_data(data)
238                if len(data.data) > 0:
239                    data.xmin = np.min(data.qx_data)
240                    data.xmax = np.max(data.qx_data)
241                    data.ymin = np.min(data.qy_data)
242                    data.ymax = np.max(data.qx_data)
243
244    @staticmethod
245    def _reorder_1d_array(array, ind):
246        """
247        Reorders a 1D array based on the indices passed as ind
248        :param array: Array to be reordered
249        :param ind: Indices used to reorder array
250        :return: reordered array
251        """
252        array = np.asarray(array, dtype=np.float64)
253        return array[ind]
254
255    @staticmethod
256    def _remove_nans_in_data(data):
257        """
258        Remove data points where nan is loaded
259        :param data: 1D or 2D data object
260        :return: data with nan points removed
261        """
262        if isinstance(data, Data1D):
263            fields = FIELDS_1D
264        elif isinstance(data, Data2D):
265            fields = FIELDS_2D
266        else:
267            return data
268        # Make array of good points - all others will be removed
269        good = np.isfinite(getattr(data, fields[0]))
270        for name in fields[1:]:
271            array = getattr(data, name)
272            if array is not None:
273                # Update good points only if not already changed
274                good &= np.isfinite(array)
275        if not np.all(good):
276            for name in fields:
277                array = getattr(data, name)
278                if array is not None:
279                    setattr(data, name, array[good])
280        return data
281
282    @staticmethod
283    def set_default_1d_units(data):
284        """
285        Set the x and y axes to the default 1D units
286        :param data: 1D data set
287        :return:
288        """
289        data.xaxis(r"\rm{Q}", '1/A')
290        data.yaxis(r"\rm{Intensity}", "1/cm")
291        return data
292
293    @staticmethod
294    def set_default_2d_units(data):
295        """
296        Set the x and y axes to the default 2D units
297        :param data: 2D data set
298        :return:
299        """
300        data.xaxis("\\rm{Q_{x}}", '1/A')
301        data.yaxis("\\rm{Q_{y}}", '1/A')
302        data.zaxis("\\rm{Intensity}", "1/cm")
303        return data
304
305    def convert_data_units(self, default_q_unit="1/A", default_i_unit="1/cm"):
306        """
307        Converts al; data to the sasview default of units of A^{-1} for Q and
308        cm^{-1} for I.
309        :param default_q_unit: The default Q unit used by Sasview
310        :param default_i_unit: The default I unit used by Sasview
311        """
312        new_output = []
313        for data in self.output:
314            if data.isSesans:
315                new_output.append(data)
316                continue
317            file_x_unit = data._xunit
318            data_conv_x = Converter(file_x_unit)
319            file_y_unit = data._yunit
320            data_conv_y = Converter(file_y_unit)
321            if isinstance(data, Data1D):
322                try:
323                    data.x = data_conv_x(data.x, units=default_q_unit)
324                    data._xunit = default_q_unit
325                    data.x_unit = default_q_unit
326                    if data.dx is not None:
327                        data.dx = data_conv_x(data.dx, units=default_q_unit)
328                    if data.dxl is not None:
329                        data.dxl = data_conv_x(data.dxl, units=default_q_unit)
330                    if data.dxw is not None:
331                        data.dxw = data_conv_x(data.dxw, units=default_q_unit)
332                except KeyError:
333                    message = "Unable to convert Q units from {0} to 1/A."
334                    message.format(default_q_unit)
335                    data.errors.append(message)
336                try:
337                    data.y = data_conv_y(data.y, units=default_i_unit)
338                    data._yunit = default_i_unit
339                    data.y_unit = default_i_unit
340                    if data.dy is not None:
341                        data.dy = data_conv_y(data.dy, units=default_i_unit)
342                except KeyError:
343                    message = "Unable to convert I units from {0} to 1/cm."
344                    message.format(default_q_unit)
345                    data.errors.append(message)
346            elif isinstance(data, Data2D):
347                try:
348                    data.qx_data = data_conv_x(data.qx_data,
349                                               units=default_q_unit)
350                    if data.dqx_data is not None:
351                        data.dqx_data = data_conv_x(data.dqx_data,
352                                                    units=default_q_unit)
353                    data.qy_data = data_conv_y(data.qy_data,
354                                               units=default_q_unit)
355                    if data.dqy_data is not None:
356                        data.dqy_data = data_conv_y(data.dqy_data,
357                                                    units=default_q_unit)
358                except KeyError:
359                    message = "Unable to convert Q units from {0} to 1/A."
360                    message.format(default_q_unit)
361                    data.errors.append(message)
362                try:
363                    file_z_unit = data._zunit
364                    data_conv_z = Converter(file_z_unit)
365                    data.data = data_conv_z(data.data, units=default_i_unit)
366                    if data.err_data is not None:
367                        data.err_data = data_conv_z(data.err_data,
368                                                    units=default_i_unit)
369                except KeyError:
370                    message = "Unable to convert I units from {0} to 1/cm."
371                    message.format(default_q_unit)
372                    data.errors.append(message)
373            else:
374                # TODO: Throw error of some sort...
375                pass
376            new_output.append(data)
377        self.output = new_output
378
379    def format_unit(self, unit=None):
380        """
381        Format units a common way
382        :param unit:
383        :return:
384        """
385        if unit:
386            split = unit.split("/")
387            if len(split) == 1:
388                return unit
389            elif split[0] == '1':
390                return "{0}^".format(split[1]) + "{-1}"
391            else:
392                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
393
394    def set_all_to_none(self):
395        """
396        Set all mutable values to None for error handling purposes
397        """
398        self.current_dataset = None
399        self.current_datainfo = None
400        self.output = []
401
402    def data_cleanup(self):
403        """
404        Clean up the data sets and refresh everything
405        :return: None
406        """
407        self.remove_empty_q_values()
408        self.send_to_output()  # Combine datasets with DataInfo
409        self.current_datainfo = DataInfo()  # Reset DataInfo
410
411    def remove_empty_q_values(self):
412        """
413        Remove any point where Q == 0
414        """
415        if isinstance(self.current_dataset, plottable_1D):
416            # Booleans for resolutions
417            has_error_dx = self.current_dataset.dx is not None
418            has_error_dxl = self.current_dataset.dxl is not None
419            has_error_dxw = self.current_dataset.dxw is not None
420            has_error_dy = self.current_dataset.dy is not None
421            # Create arrays of zeros for non-existent resolutions
422            if has_error_dxw and not has_error_dxl:
423                array_size = self.current_dataset.dxw.size - 1
424                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
425                                                    np.zeros([array_size]))
426                has_error_dxl = True
427            elif has_error_dxl and not has_error_dxw:
428                array_size = self.current_dataset.dxl.size - 1
429                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
430                                                    np.zeros([array_size]))
431                has_error_dxw = True
432            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
433                array_size = self.current_dataset.x.size - 1
434                self.current_dataset.dx = np.append(self.current_dataset.dx,
435                                                    np.zeros([array_size]))
436                has_error_dx = True
437            if not has_error_dy:
438                array_size = self.current_dataset.y.size - 1
439                self.current_dataset.dy = np.append(self.current_dataset.dy,
440                                                    np.zeros([array_size]))
441                has_error_dy = True
442
443            # Remove points where q = 0
444            x = self.current_dataset.x
445            self.current_dataset.x = self.current_dataset.x[x != 0]
446            self.current_dataset.y = self.current_dataset.y[x != 0]
447            if has_error_dy:
448                self.current_dataset.dy = self.current_dataset.dy[x != 0]
449            if has_error_dx:
450                self.current_dataset.dx = self.current_dataset.dx[x != 0]
451            if has_error_dxl:
452                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
453            if has_error_dxw:
454                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
455        elif isinstance(self.current_dataset, plottable_2D):
456            has_error_dqx = self.current_dataset.dqx_data is not None
457            has_error_dqy = self.current_dataset.dqy_data is not None
458            has_error_dy = self.current_dataset.err_data is not None
459            has_mask = self.current_dataset.mask is not None
460            x = self.current_dataset.qx_data
461            self.current_dataset.data = self.current_dataset.data[x != 0]
462            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
463            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
464            self.current_dataset.q_data = np.sqrt(
465                np.square(self.current_dataset.qx_data) + np.square(
466                    self.current_dataset.qy_data))
467            if has_error_dy:
468                self.current_dataset.err_data = self.current_dataset.err_data[
469                    x != 0]
470            if has_error_dqx:
471                self.current_dataset.dqx_data = self.current_dataset.dqx_data[
472                    x != 0]
473            if has_error_dqy:
474                self.current_dataset.dqy_data = self.current_dataset.dqy_data[
475                    x != 0]
476            if has_mask:
477                self.current_dataset.mask = self.current_dataset.mask[x != 0]
478
479    def reset_data_list(self, no_lines=0):
480        """
481        Reset the plottable_1D object
482        """
483        # Initialize data sets with arrays the maximum possible size
484        x = np.zeros(no_lines)
485        y = np.zeros(no_lines)
486        dx = np.zeros(no_lines)
487        dy = np.zeros(no_lines)
488        self.current_dataset = plottable_1D(x, y, dx, dy)
489
490    @staticmethod
491    def splitline(line):
492        """
493        Splits a line into pieces based on common delimiters
494        :param line: A single line of text
495        :return: list of values
496        """
497        # Initial try for CSV (split on ,)
498        toks = line.split(',')
499        # Now try SCSV (split on ;)
500        if len(toks) < 2:
501            toks = line.split(';')
502        # Now go for whitespace
503        if len(toks) < 2:
504            toks = line.split()
505        return toks
506
507    @abstractmethod
508    def get_file_contents(self):
509        """
510        Reader specific class to access the contents of the file
511        All reader classes that inherit from FileReader must implement
512        """
513        pass
Note: See TracBrowser for help on using the repository browser.