source: sasview/src/sas/sascalc/dataloader/file_reader_base_class.py @ b799f09

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since b799f09 was 8f882fe, checked in by krzywon, 6 years ago

Changes to NXcanSAS reader to comply with specification. refs #976

  • Property mode set to 100644
File size: 20.1 KB
Line 
1"""
2This is the base file reader class most file readers should inherit from.
3All generic functionality required for a file loader/reader is built into this
4class
5"""
6
7import os
8import sys
9import math
10import logging
11from abc import abstractmethod
12
13import numpy as np
14from .loader_exceptions import NoKnownLoaderException, FileContentsException,\
15    DataReaderException, DefaultReaderException
16from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\
17    combine_data_info_with_plottable
18from sas.sascalc.data_util.nxsunit import Converter
19
20logger = logging.getLogger(__name__)
21
22if sys.version_info[0] < 3:
23    def decode(s):
24        return s
25else:
26    def decode(s):
27        return s.decode() if isinstance(s, bytes) else s
28
29# Data 1D fields for iterative purposes
30FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw')
31# Data 2D fields for iterative purposes
32FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data',
33                 'dqx_data', 'dqy_data', 'mask')
34DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh"
35                       "t not be fully reduced. Support for the reader associat"
36                       "ed with this file type has been removed. An attempt to "
37                       "load the file was made, but, should it be successful, "
38                       "SasView cannot guarantee the accuracy of the data.")
39
40class FileReader(object):
41    # String to describe the type of data this reader can load
42    type_name = "ASCII"
43    # Wildcards to display
44    type = ["Text files (*.txt|*.TXT)"]
45    # List of allowed extensions
46    ext = ['.txt']
47    # Deprecated extensions
48    deprecated_extensions = ['.asc', '.nxs']
49    # Bypass extension check and try to load anyway
50    allow_all = False
51    # Able to import the unit converter
52    has_converter = True
53    # Default value of zero
54    _ZERO = 1e-16
55
56    def __init__(self):
57        # List of Data1D and Data2D objects to be sent back to data_loader
58        self.output = []
59        # Current plottable_(1D/2D) object being loaded in
60        self.current_dataset = None
61        # Current DataInfo object being loaded in
62        self.current_datainfo = None
63        # File path sent to reader
64        self.filepath = None
65        # Open file handle
66        self.f_open = None
67
68    def read(self, filepath):
69        """
70        Basic file reader
71
72        :param filepath: The full or relative path to a file to be loaded
73        """
74        self.filepath = filepath
75        if os.path.isfile(filepath):
76            basename, extension = os.path.splitext(os.path.basename(filepath))
77            self.extension = extension.lower()
78            # If the file type is not allowed, return nothing
79            if self.extension in self.ext or self.allow_all:
80                # Try to load the file, but raise an error if unable to.
81                try:
82                    self.f_open = open(filepath, 'rb')
83                    self.get_file_contents()
84
85                except DataReaderException as e:
86                    self.handle_error_message(e.message)
87                except OSError as e:
88                    # If the file cannot be opened
89                    msg = "Unable to open file: {}\n".format(filepath)
90                    msg += e.message
91                    self.handle_error_message(msg)
92                finally:
93                    # Close the file handle if it is open
94                    if not self.f_open.closed:
95                        self.f_open.close()
96                    if any(filepath.lower().endswith(ext) for ext in
97                           self.deprecated_extensions):
98                        self.handle_error_message(DEPRECATION_MESSAGE)
99                    if len(self.output) > 0:
100                        # Sort the data that's been loaded
101                        self.convert_data_units()
102                        self.sort_data()
103        else:
104            msg = "Unable to find file at: {}\n".format(filepath)
105            msg += "Please check your file path and try again."
106            self.handle_error_message(msg)
107
108        # Return a list of parsed entries that data_loader can manage
109        final_data = self.output
110        self.reset_state()
111        return final_data
112
113    def reset_state(self):
114        """
115        Resets the class state to a base case when loading a new data file so previous
116        data files do not appear a second time
117        """
118        self.current_datainfo = None
119        self.current_dataset = None
120        self.filepath = None
121        self.ind = None
122        self.output = []
123
124    def nextline(self):
125        """
126        Returns the next line in the file as a string.
127        """
128        #return self.f_open.readline()
129        return decode(self.f_open.readline())
130
131    def nextlines(self):
132        """
133        Returns the next line in the file as a string.
134        """
135        for line in self.f_open:
136            #yield line
137            yield decode(line)
138
139    def readall(self):
140        """
141        Returns the entire file as a string.
142        """
143        #return self.f_open.read()
144        return decode(self.f_open.read())
145
146    def handle_error_message(self, msg):
147        """
148        Generic error handler to add an error to the current datainfo to
149        propagate the error up the error chain.
150        :param msg: Error message
151        """
152        if len(self.output) > 0:
153            self.output[-1].errors.append(msg)
154        elif isinstance(self.current_datainfo, DataInfo):
155            self.current_datainfo.errors.append(msg)
156        else:
157            logger.warning(msg)
158            raise NoKnownLoaderException(msg)
159
160    def send_to_output(self):
161        """
162        Helper that automatically combines the info and set and then appends it
163        to output
164        """
165        data_obj = combine_data_info_with_plottable(self.current_dataset,
166                                                    self.current_datainfo)
167        self.output.append(data_obj)
168
169    def sort_data(self):
170        """
171        Sort 1D data along the X axis for consistency
172        """
173        for data in self.output:
174            if isinstance(data, Data1D):
175                # Normalize the units for
176                data.x_unit = self.format_unit(data.x_unit)
177                data._xunit = data.x_unit
178                data.y_unit = self.format_unit(data.y_unit)
179                data._yunit = data.y_unit
180                # Sort data by increasing x and remove 1st point
181                ind = np.lexsort((data.y, data.x))
182                data.x = self._reorder_1d_array(data.x, ind)
183                data.y = self._reorder_1d_array(data.y, ind)
184                if data.dx is not None:
185                    if len(data.dx) == 0:
186                        data.dx = None
187                        continue
188                    data.dx = self._reorder_1d_array(data.dx, ind)
189                if data.dxl is not None:
190                    data.dxl = self._reorder_1d_array(data.dxl, ind)
191                if data.dxw is not None:
192                    data.dxw = self._reorder_1d_array(data.dxw, ind)
193                if data.dy is not None:
194                    if len(data.dy) == 0:
195                        data.dy = None
196                        continue
197                    data.dy = self._reorder_1d_array(data.dy, ind)
198                if data.lam is not None:
199                    data.lam = self._reorder_1d_array(data.lam, ind)
200                if data.dlam is not None:
201                    data.dlam = self._reorder_1d_array(data.dlam, ind)
202                data = self._remove_nans_in_data(data)
203                if len(data.x) > 0:
204                    data.xmin = np.min(data.x)
205                    data.xmax = np.max(data.x)
206                    data.ymin = np.min(data.y)
207                    data.ymax = np.max(data.y)
208            elif isinstance(data, Data2D):
209                # Normalize the units for
210                data.Q_unit = self.format_unit(data.Q_unit)
211                data.I_unit = self.format_unit(data.I_unit)
212                data._xunit = data.Q_unit
213                data._yunit = data.Q_unit
214                data._zunit = data.I_unit
215                data.data = data.data.astype(np.float64)
216                data.qx_data = data.qx_data.astype(np.float64)
217                data.xmin = np.min(data.qx_data)
218                data.xmax = np.max(data.qx_data)
219                data.qy_data = data.qy_data.astype(np.float64)
220                data.ymin = np.min(data.qy_data)
221                data.ymax = np.max(data.qy_data)
222                data.q_data = np.sqrt(data.qx_data * data.qx_data
223                                         + data.qy_data * data.qy_data)
224                if data.err_data is not None:
225                    data.err_data = data.err_data.astype(np.float64)
226                if data.dqx_data is not None:
227                    data.dqx_data = data.dqx_data.astype(np.float64)
228                if data.dqy_data is not None:
229                    data.dqy_data = data.dqy_data.astype(np.float64)
230                if data.mask is not None:
231                    data.mask = data.mask.astype(dtype=bool)
232
233                if len(data.data.shape) == 2:
234                    n_rows, n_cols = data.data.shape
235                    data.y_bins = data.qy_data[0::int(n_cols)]
236                    data.x_bins = data.qx_data[:int(n_cols)]
237                    data.data = data.data.flatten()
238                    data = self._remove_nans_in_data(data)
239                if len(data.data) > 0:
240                    data.xmin = np.min(data.qx_data)
241                    data.xmax = np.max(data.qx_data)
242                    data.ymin = np.min(data.qy_data)
243                    data.ymax = np.max(data.qx_data)
244
245    @staticmethod
246    def _reorder_1d_array(array, ind):
247        """
248        Reorders a 1D array based on the indices passed as ind
249        :param array: Array to be reordered
250        :param ind: Indices used to reorder array
251        :return: reordered array
252        """
253        array = np.asarray(array, dtype=np.float64)
254        return array[ind]
255
256    @staticmethod
257    def _remove_nans_in_data(data):
258        """
259        Remove data points where nan is loaded
260        :param data: 1D or 2D data object
261        :return: data with nan points removed
262        """
263        if isinstance(data, Data1D):
264            fields = FIELDS_1D
265        elif isinstance(data, Data2D):
266            fields = FIELDS_2D
267        else:
268            return data
269        # Make array of good points - all others will be removed
270        good = np.isfinite(getattr(data, fields[0]))
271        for name in fields[1:]:
272            array = getattr(data, name)
273            if array is not None:
274                # Update good points only if not already changed
275                good &= np.isfinite(array)
276        if not np.all(good):
277            for name in fields:
278                array = getattr(data, name)
279                if array is not None:
280                    setattr(data, name, array[good])
281        return data
282
283    @staticmethod
284    def set_default_1d_units(data):
285        """
286        Set the x and y axes to the default 1D units
287        :param data: 1D data set
288        :return:
289        """
290        data.xaxis("\\rm{Q}", '1/A')
291        data.yaxis("\\rm{Intensity}", "1/cm")
292        return data
293
294    @staticmethod
295    def set_default_2d_units(data):
296        """
297        Set the x and y axes to the default 2D units
298        :param data: 2D data set
299        :return:
300        """
301        data.xaxis("\\rm{Q_{x}}", '1/A')
302        data.yaxis("\\rm{Q_{y}}", '1/A')
303        data.zaxis("\\rm{Intensity}", "1/cm")
304        return data
305
306    def convert_data_units(self, default_q_unit="1/A", default_i_unit="1/cm"):
307        """
308        Converts al; data to the sasview default of units of A^{-1} for Q and
309        cm^{-1} for I.
310        :param default_x_unit: The default x unit used by Sasview
311        :param default_y_unit: The default y unit used by Sasview
312        """
313        new_output = []
314        for data in self.output:
315            if data.isSesans:
316                new_output.append(data)
317                continue
318            file_x_unit = data._xunit
319            data_conv_x = Converter(file_x_unit)
320            file_y_unit = data._yunit
321            data_conv_y = Converter(file_y_unit)
322            if isinstance(data, Data1D):
323                try:
324                    data.x = data_conv_x(data.x, units=default_q_unit)
325                    data._xunit = default_q_unit
326                    data.x_unit = default_q_unit
327                    if data.dx is not None:
328                        data.dx = data_conv_x(data.dx, units=default_q_unit)
329                    if data.dxl is not None:
330                        data.dxl = data_conv_x(data.dxl, units=default_q_unit)
331                    if data.dxw is not None:
332                        data.dxw = data_conv_x(data.dxw, units=default_q_unit)
333                except KeyError:
334                    message = "Unable to convert Q units from {0} to 1/A."
335                    message.format(default_q_unit)
336                    data.errors.append(message)
337                try:
338                    data.y = data_conv_y(data.y, units=default_i_unit)
339                    data._yunit = default_i_unit
340                    data.y_unit = default_i_unit
341                    if data.dy is not None:
342                        data.dy = data_conv_y(data.dy, units=default_i_unit)
343                except KeyError:
344                    message = "Unable to convert I units from {0} to 1/cm."
345                    message.format(default_q_unit)
346                    data.errors.append(message)
347            elif isinstance(data, Data2D):
348                try:
349                    data.qx_data = data_conv_x(data.qx_data, units=default_q_unit)
350                    if data.dqx_data is not None:
351                        data.dqx_data = data_conv_x(data.dqx_data, units=default_q_unit)
352                    data.qy_data = data_conv_y(data.qy_data, units=default_q_unit)
353                    if data.dqy_data is not None:
354                        data.dqy_data = data_conv_y(data.dqy_data, units=default_q_unit)
355                except KeyError:
356                    message = "Unable to convert Q units from {0} to 1/A."
357                    message.format(default_q_unit)
358                    data.errors.append(message)
359                try:
360                    file_z_unit = data._zunit
361                    data_conv_z = Converter(file_z_unit)
362                    data.data = data_conv_z(data.data, units=default_i_unit)
363                    if data.err_data is not None:
364                        data.err_data = data_conv_z(data.err_data, units=default_i_unit)
365                except KeyError:
366                    message = "Unable to convert I units from {0} to 1/cm."
367                    message.format(default_q_unit)
368                    data.errors.append(message)
369            else:
370                # TODO: Throw error of some sort...
371                pass
372            new_output.append(data)
373        self.output = new_output
374
375    def format_unit(self, unit=None):
376        """
377        Format units a common way
378        :param unit:
379        :return:
380        """
381        if unit:
382            split = unit.split("/")
383            if len(split) == 1:
384                return unit
385            elif split[0] == '1':
386                return "{0}^".format(split[1]) + "{-1}"
387            else:
388                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
389
390    def set_all_to_none(self):
391        """
392        Set all mutable values to None for error handling purposes
393        """
394        self.current_dataset = None
395        self.current_datainfo = None
396        self.output = []
397
398    def data_cleanup(self):
399        """
400        Clean up the data sets and refresh everything
401        :return: None
402        """
403        self.remove_empty_q_values()
404        self.send_to_output()  # Combine datasets with DataInfo
405        self.current_datainfo = DataInfo()  # Reset DataInfo
406
407    def remove_empty_q_values(self):
408        """
409        Remove any point where Q == 0
410        """
411        if isinstance(self.current_dataset, plottable_1D):
412            # Booleans for resolutions
413            has_error_dx = self.current_dataset.dx is not None
414            has_error_dxl = self.current_dataset.dxl is not None
415            has_error_dxw = self.current_dataset.dxw is not None
416            has_error_dy = self.current_dataset.dy is not None
417            # Create arrays of zeros for non-existent resolutions
418            if has_error_dxw and not has_error_dxl:
419                array_size = self.current_dataset.dxw.size - 1
420                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
421                                                    np.zeros([array_size]))
422                has_error_dxl = True
423            elif has_error_dxl and not has_error_dxw:
424                array_size = self.current_dataset.dxl.size - 1
425                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
426                                                    np.zeros([array_size]))
427                has_error_dxw = True
428            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
429                array_size = self.current_dataset.x.size - 1
430                self.current_dataset.dx = np.append(self.current_dataset.dx,
431                                                    np.zeros([array_size]))
432                has_error_dx = True
433            if not has_error_dy:
434                array_size = self.current_dataset.y.size - 1
435                self.current_dataset.dy = np.append(self.current_dataset.dy,
436                                                    np.zeros([array_size]))
437                has_error_dy = True
438
439            # Remove points where q = 0
440            x = self.current_dataset.x
441            self.current_dataset.x = self.current_dataset.x[x != 0]
442            self.current_dataset.y = self.current_dataset.y[x != 0]
443            if has_error_dy:
444                self.current_dataset.dy = self.current_dataset.dy[x != 0]
445            if has_error_dx:
446                self.current_dataset.dx = self.current_dataset.dx[x != 0]
447            if has_error_dxl:
448                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
449            if has_error_dxw:
450                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
451        elif isinstance(self.current_dataset, plottable_2D):
452            has_error_dqx = self.current_dataset.dqx_data is not None
453            has_error_dqy = self.current_dataset.dqy_data is not None
454            has_error_dy = self.current_dataset.err_data is not None
455            has_mask = self.current_dataset.mask is not None
456            x = self.current_dataset.qx_data
457            self.current_dataset.data = self.current_dataset.data[x != 0]
458            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
459            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
460            self.current_dataset.q_data = np.sqrt(
461                np.square(self.current_dataset.qx_data) + np.square(
462                    self.current_dataset.qy_data))
463            if has_error_dy:
464                self.current_dataset.err_data = self.current_dataset.err_data[x != 0]
465            if has_error_dqx:
466                self.current_dataset.dqx_data = self.current_dataset.dqx_data[x != 0]
467            if has_error_dqy:
468                self.current_dataset.dqy_data = self.current_dataset.dqy_data[x != 0]
469            if has_mask:
470                self.current_dataset.mask = self.current_dataset.mask[x != 0]
471
472    def reset_data_list(self, no_lines=0):
473        """
474        Reset the plottable_1D object
475        """
476        # Initialize data sets with arrays the maximum possible size
477        x = np.zeros(no_lines)
478        y = np.zeros(no_lines)
479        dx = np.zeros(no_lines)
480        dy = np.zeros(no_lines)
481        self.current_dataset = plottable_1D(x, y, dx, dy)
482
483    @staticmethod
484    def splitline(line):
485        """
486        Splits a line into pieces based on common delimiters
487        :param line: A single line of text
488        :return: list of values
489        """
490        # Initial try for CSV (split on ,)
491        toks = line.split(',')
492        # Now try SCSV (split on ;)
493        if len(toks) < 2:
494            toks = line.split(';')
495        # Now go for whitespace
496        if len(toks) < 2:
497            toks = line.split()
498        return toks
499
500    @abstractmethod
501    def get_file_contents(self):
502        """
503        Reader specific class to access the contents of the file
504        All reader classes that inherit from FileReader must implement
505        """
506        pass
Note: See TracBrowser for help on using the repository browser.