source: sasview/src/sas/sascalc/file_converter/nxcansas_writer.py @ 9a0fc50

magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249unittest-saveload
Last change on this file since 9a0fc50 was 2ca5d57b, checked in by krzywon, 6 years ago

Code cleanup and additional unit tests.

  • Property mode set to 100644
File size: 16.6 KB
Line 
"""
    NXcanSAS 1D/2D data writer: writes HDF5 formatted NXcanSAS files.
"""
4
5import h5py
6import numpy as np
7import re
8import os
9
10from sas.sascalc.dataloader.readers.cansas_reader_HDF5 import Reader
11from sas.sascalc.dataloader.data_info import Data1D, Data2D
12
class NXcanSASWriter(Reader):
    """
    A class for writing NXcanSAS data files. Any number of data sets may be
    written to the file. Currently 1D and 2D SAS data sets are supported.

    NXcanSAS spec: http://download.nexusformat.org/sphinx/classes/contributed_definitions/NXcanSAS.html

    :Dependencies:
        The NXcanSAS writer requires h5py v2.5.0 or later.
    """

    def write(self, dataset, filename):
        """
        Write a list of Data1D or Data2D objects to an NXcanSAS file, as
        one SASentry with multiple SASdata elements. The metadata of the first
        element in the list is written as the SASentry metadata
        (detector, instrument, sample, etc).

        The objects in ``dataset`` are only read; they are not modified.

        :param dataset: A non-empty list of Data1D or Data2D objects to write
        :param filename: Where to write the NXcanSAS file
        :raises ValueError: If ``dataset`` is empty or contains anything that
            is not a Data1D or Data2D object
        """

        def _h5_string(string):
            """
            Convert a string to a numpy string in a numpy array. This way it is
            written to the HDF5 file as a fixed length ASCII string and is
            compatible with the Reader read() method.
            """
            if isinstance(string, np.ndarray):
                return string
            elif not isinstance(string, str):
                string = str(string)

            # np.bytes_ is the same type the old np.string_ alias pointed to;
            # the alias itself was removed in NumPy 2.0.
            return np.array([np.bytes_(string)])

        def _h5_string_list(values):
            """
            Convert a list of strings to an array of fixed length strings.

            :param values: A list of strings (may be None or empty)
            :return: None if there is nothing to write, the _h5_string form
                of a single item, or an array with one entry per item
            """
            if values is None or len(values) == 0:
                return None
            if len(values) > 1:
                return np.array([np.bytes_(value) for value in values])
            return _h5_string(values[0])

        def _write_h5_string(entry, value, key):
            entry[key] = _h5_string(value)

        def _h5_float(x):
            # Anything that is not a list is wrapped in one, so scalars are
            # stored as one-element arrays.
            if not isinstance(x, list):
                x = [x]
            return np.array(x, dtype=np.float32)

        def _write_h5_float(entry, value, key):
            entry.create_dataset(key, data=_h5_float(value))

        def _write_h5_vector(entry, vector, names=('x_position', 'y_position'),
                             units=None, write_fn=_write_h5_string):
            """
            Write a vector to an h5 entry

            :param entry: The H5Py entry to write to
            :param vector: The Vector to write
            :param names: What to call the x,y and z components of the vector
                when writing to the H5Py entry
            :param units: The units of the vector (optional)
            :param write_fn: A function to convert the value to the required
                format and write it to the H5Py entry, of the form
                f(entry, value, name) (optional)
            :raises ValueError: If fewer than two names are given
            """
            if len(names) < 2:
                raise ValueError("Length of names must be >= 2.")

            components = [vector.x, vector.y]
            # The z component is only written when exactly three names are
            # supplied.
            if len(names) == 3:
                components.append(vector.z)

            for name, value in zip(names, components):
                if value is None:
                    continue
                write_fn(entry, value, name)
                if units is not None:
                    entry[name].attrs['units'] = units

        def _write_h5_orientation(entry, orientation):
            """
            Write an orientation vector as polar and azimuthal angles.

            NXcanSAS doesn't save information about pitch, only roll and yaw,
            so the x component is written as the polar angle and the z
            component as the azimuthal angle. The components are read directly
            so that the caller's vector is left untouched.

            :param entry: The H5Py entry to write to
            :param orientation: The orientation Vector to write
            """
            if orientation.x is not None:
                _write_h5_string(entry, orientation.x, 'polar_angle')
            if orientation.z is not None:
                _write_h5_string(entry, orientation.z, 'azimuthal_angle')

        # all() of an empty sequence is True, so emptiness has to be checked
        # separately before the first element is used for the metadata.
        if len(dataset) == 0:
            raise ValueError("dataset must contain at least one Data1D or "
                             "Data2D object")
        if not all(isinstance(d, (Data1D, Data2D)) for d in dataset):
            raise ValueError("All entries of dataset must be Data1D or Data2D "
                             "objects")

        # Get run name and number from first Data object
        data_info = dataset[0]
        run_number = ''
        run_name = ''
        if len(data_info.run) > 0:
            run_number = data_info.run[0]
            if len(data_info.run_name) > 0:
                # NOTE(review): presumably run_name maps run number -> name;
                # a missing key raises here - confirm against data_info.
                run_name = data_info.run_name[run_number]

        # The context manager closes the file even if a write below fails.
        with h5py.File(filename, 'w') as h5_file:
            sasentry = h5_file.create_group('sasentry01')
            sasentry['definition'] = _h5_string('NXcanSAS')
            sasentry['run'] = _h5_string(run_number)
            sasentry['run'].attrs['name'] = run_name
            sasentry['title'] = _h5_string(data_info.title)
            sasentry.attrs['canSAS_class'] = 'SASentry'
            sasentry.attrs['version'] = '1.0'

            # One SASdata group per data set
            for index, data_obj in enumerate(dataset):
                data_entry = sasentry.create_group(
                    "sasdata{0:0=2d}".format(index + 1))
                data_entry.attrs['canSAS_class'] = 'SASdata'
                if isinstance(data_obj, Data1D):
                    self._write_1d_data(data_obj, data_entry)
                elif isinstance(data_obj, Data2D):
                    self._write_2d_data(data_obj, data_entry)

            # Sample metadata
            sample = data_info.sample
            sample_entry = sasentry.create_group('sassample')
            sample_entry.attrs['canSAS_class'] = 'SASsample'
            sample_entry['ID'] = _h5_string(sample.name)
            for key in ('thickness', 'temperature', 'transmission'):
                value = getattr(sample, key)
                if value is not None:
                    sample_entry.create_dataset(key, data=_h5_float(value))
            _write_h5_vector(sample_entry, sample.position)
            _write_h5_orientation(sample_entry, sample.orientation)
            details = _h5_string_list(sample.details)
            if details is not None:
                sample_entry.create_dataset('details', data=details)

            # Instrument metadata
            instrument_entry = sasentry.create_group('sasinstrument')
            instrument_entry.attrs['canSAS_class'] = 'SASinstrument'
            instrument_entry['name'] = _h5_string(data_info.instrument)

            # Source metadata
            source = data_info.source
            source_entry = instrument_entry.create_group('sassource')
            source_entry.attrs['canSAS_class'] = 'SASsource'
            # Radiation falls back to 'neutron' when it is not set
            if source.radiation is None:
                source_entry['radiation'] = _h5_string('neutron')
            else:
                source_entry['radiation'] = _h5_string(source.radiation)
            if source.beam_shape is not None:
                source_entry['beam_shape'] = _h5_string(source.beam_shape)
            # (attribute name on the source, dataset name in the file)
            wavelength_keys = (
                ('wavelength', 'incident_wavelength'),
                ('wavelength_min', 'wavelength_min'),
                ('wavelength_max', 'wavelength_max'),
                ('wavelength_spread', 'incident_wavelength_spread'),
            )
            for sasname, nxname in wavelength_keys:
                value = getattr(source, sasname)
                units = getattr(source, sasname + '_unit')
                if value is not None:
                    source_entry[nxname] = _h5_float(value)
                    source_entry[nxname].attrs['units'] = units
            _write_h5_vector(source_entry, source.beam_size,
                             names=['beam_size_x', 'beam_size_y'],
                             units=source.beam_size_unit,
                             write_fn=_write_h5_float)

            # Collimation metadata
            if len(data_info.collimation) > 0:
                for index, coll_info in enumerate(data_info.collimation):
                    collimation_entry = instrument_entry.create_group(
                        'sascollimation{0:0=2d}'.format(index + 1))
                    collimation_entry.attrs['canSAS_class'] = 'SAScollimation'
                    if coll_info.length is not None:
                        _write_h5_float(collimation_entry, coll_info.length,
                                        'SDD')
                        collimation_entry['SDD'].attrs['units'] = \
                            coll_info.length_unit
                    if coll_info.name is not None:
                        collimation_entry['name'] = _h5_string(coll_info.name)
            else:
                # Create a blank one - at least 1 collimation required by
                # format
                instrument_entry.create_group('sascollimation01')

            # Detector metadata
            if len(data_info.detector) > 0:
                for index, det_info in enumerate(data_info.detector):
                    detector_entry = instrument_entry.create_group(
                        'sasdetector{0:0=2d}'.format(index + 1))
                    detector_entry.attrs['canSAS_class'] = 'SASdetector'
                    if det_info.distance is not None:
                        _write_h5_float(detector_entry, det_info.distance,
                                        'SDD')
                        detector_entry['SDD'].attrs['units'] = \
                            det_info.distance_unit
                    # A name is always written, empty if none is set
                    if det_info.name is not None:
                        detector_entry['name'] = _h5_string(det_info.name)
                    else:
                        detector_entry['name'] = _h5_string('')
                    if det_info.slit_length is not None:
                        _write_h5_float(detector_entry, det_info.slit_length,
                                        'slit_length')
                        detector_entry['slit_length'].attrs['units'] = \
                            det_info.slit_length_unit
                    _write_h5_vector(detector_entry, det_info.offset)
                    _write_h5_orientation(detector_entry, det_info.orientation)
                    _write_h5_vector(detector_entry, det_info.beam_center,
                                     names=['beam_center_x', 'beam_center_y'],
                                     write_fn=_write_h5_float,
                                     units=det_info.beam_center_unit)
                    _write_h5_vector(detector_entry, det_info.pixel_size,
                                     names=['x_pixel_size', 'y_pixel_size'],
                                     write_fn=_write_h5_float,
                                     units=det_info.pixel_size_unit)
            else:
                # Create a blank one - at least 1 detector required by format
                detector_entry = instrument_entry.create_group('sasdetector01')
                detector_entry.attrs['canSAS_class'] = 'SASdetector'
                # NOTE(review): the name is stored as an attribute here but as
                # a dataset in the branch above - confirm which one the reader
                # expects before changing it.
                detector_entry.attrs['name'] = ''

            # Process meta data
            for index, process in enumerate(data_info.process):
                process_entry = sasentry.create_group(
                    'sasprocess{0:0=2d}'.format(index + 1))
                process_entry.attrs['canSAS_class'] = 'SASprocess'
                # Only non-empty fields are written
                for key in ('name', 'date', 'description'):
                    value = getattr(process, key)
                    if value:
                        process_entry.create_dataset(key,
                                                     data=_h5_string(value))
                # Terms and notes are numbered by their position in the list;
                # empty ones are skipped but still consume a number.
                numbered = (('term', process.term), ('note', process.notes))
                for prefix, items in numbered:
                    for j, item in enumerate(items):
                        if item:
                            process_entry.create_dataset(
                                '{0}{1:0=2d}'.format(prefix, j + 1),
                                data=_h5_string(item))

            # Transmission Spectrum
            for index, trans in enumerate(data_info.trans_spectrum):
                trans_entry = sasentry.create_group(
                    'sastransmission_spectrum{0:0=2d}'.format(index + 1))
                trans_entry.attrs['canSAS_class'] = 'SAStransmission_spectrum'
                trans_entry.attrs['signal'] = 'T'
                trans_entry.attrs['T_axes'] = 'T'
                trans_entry.attrs['name'] = trans.name
                # Compare by value: identity checks against a string literal
                # depend on interning and are not reliable.
                if trans.timestamp is not None and trans.timestamp != '':
                    trans_entry.attrs['timestamp'] = trans.timestamp
                transmission = trans_entry.create_dataset(
                    'T', data=trans.transmission)
                transmission.attrs['uncertainties'] = 'Tdev'
                trans_entry.create_dataset('Tdev',
                                           data=trans.transmission_deviation)
                trans_entry.create_dataset('lambda', data=trans.wavelength)

            # Notes: the group is always created, the dataset only when there
            # is at least one note.
            note_entry = sasentry.create_group('sasnote')
            note_entry.attrs['canSAS_class'] = 'SASnote'
            notes = _h5_string_list(data_info.notes)
            if notes is not None:
                note_entry.create_dataset('SASnote', data=notes)

    def _write_1d_data(self, data_obj, data_entry):
        """
        Writes the contents of a Data1D object to a SASdata h5py Group

        Q and I are always written. Idev is written when dy is set. For the
        resolution, dQ is written when dx is set; otherwise, when dxl is set,
        dQl is written together with dQw if dxw is also set.

        :param data_obj: A Data1D object to write to the file
        :param data_entry: A h5py Group object representing the SASdata
        """
        data_entry.attrs['signal'] = 'I'
        data_entry.attrs['I_axes'] = 'Q'
        data_entry.attrs['Q_indices'] = [0]

        q_entry = data_entry.create_dataset('Q', data=data_obj.x)
        q_entry.attrs['units'] = data_obj.x_unit
        i_entry = data_entry.create_dataset('I', data=data_obj.y)
        i_entry.attrs['units'] = data_obj.y_unit

        if data_obj.dy is not None:
            i_entry.attrs['uncertainties'] = 'Idev'
            i_dev_entry = data_entry.create_dataset('Idev', data=data_obj.dy)
            i_dev_entry.attrs['units'] = data_obj.y_unit

        if data_obj.dx is not None:
            q_entry.attrs['resolutions'] = 'dQ'
            dq_entry = data_entry.create_dataset('dQ', data=data_obj.dx)
            dq_entry.attrs['units'] = data_obj.x_unit
        elif data_obj.dxl is not None:
            # dxw may be missing even when dxl is present; only write what
            # is there, since a dataset cannot be created from None.
            slit_data = [('dQl', data_obj.dxl), ('dQw', data_obj.dxw)]
            slit_data = [(name, values) for name, values in slit_data
                         if values is not None]
            names = [name for name, _ in slit_data]
            if len(names) > 1:
                # Fixed length byte strings: h5py cannot store numpy unicode
                # arrays, which is what a list of str becomes on Python 3.
                q_entry.attrs['resolutions'] = np.array(names, dtype='S')
            else:
                q_entry.attrs['resolutions'] = names[0]
            for name, values in slit_data:
                res_entry = data_entry.create_dataset(name, data=values)
                res_entry.attrs['units'] = data_obj.x_unit

    def _write_2d_data(self, data, data_entry):
        """
        Writes the contents of a Data2D object to a SASdata h5py Group

        I, Qx and Qy are reshaped to (rows, columns) before being written.
        The shape is taken from the lengths of y_bins and x_bins when they
        account for every data point; otherwise it is calculated from the
        number of points.

        :param data: A Data2D object to write to the file
        :param data_entry: A h5py Group object representing the SASdata
        :raises ValueError: If the dimensions of the data cannot be calculated
        """

        def _has_values(values):
            # True when values is set and holds at least one non-None item
            return values is not None and \
                not all(value is None for value in values)

        data_entry.attrs['signal'] = 'I'
        data_entry.attrs['I_axes'] = 'Qx,Qy'
        data_entry.attrs['Q_indices'] = [0, 1]

        n_points = len(data.qy_data)
        n_rows = len(data.y_bins) if data.y_bins is not None else 0
        n_cols = len(data.x_bins) if data.x_bins is not None else 0

        if n_points == 0 or n_rows * n_cols != n_points:
            # Calculate rows and columns, assuming detector is square
            # Same logic as used in PlotPanel.py _get_bins
            n_cols = int(np.floor(np.sqrt(n_points)))
            # n_cols is 0 for empty data; checked first to avoid dividing
            # by zero.
            if n_cols == 0 or n_points % n_cols != 0:
                raise ValueError("Unable to calculate dimensions of 2D data")
            n_rows = n_points // n_cols

        shape = (n_rows, n_cols)
        intensity = np.reshape(data.data, shape)
        qx = np.reshape(data.qx_data, shape)
        qy = np.reshape(data.qy_data, shape)

        i_entry = data_entry.create_dataset('I', data=intensity)
        i_entry.attrs['units'] = data.I_unit
        qx_entry = data_entry.create_dataset('Qx', data=qx)
        qx_entry.attrs['units'] = data.Q_unit
        qy_entry = data_entry.create_dataset('Qy', data=qy)
        qy_entry.attrs['units'] = data.Q_unit

        if _has_values(data.err_data):
            d_i = np.reshape(data.err_data, shape)
            i_entry.attrs['uncertainties'] = 'Idev'
            i_dev_entry = data_entry.create_dataset('Idev', data=d_i)
            i_dev_entry.attrs['units'] = data.I_unit
        # NOTE(review): unlike I, Qx, Qy and Idev, the resolution arrays are
        # written without being reshaped - confirm this is what the reader
        # expects.
        if _has_values(data.dqx_data):
            qx_entry.attrs['resolutions'] = 'dQx'
            dqx_entry = data_entry.create_dataset('dQx', data=data.dqx_data)
            dqx_entry.attrs['units'] = data.Q_unit
        if _has_values(data.dqy_data):
            qy_entry.attrs['resolutions'] = 'dQy'
            dqy_entry = data_entry.create_dataset('dQy', data=data.dqy_data)
            dqy_entry.attrs['units'] = data.Q_unit
Note: See TracBrowser for help on using the repository browser.