source: sasview/src/sas/sascalc/file_converter/nxcansas_writer.py @ f38d027

Last change on this file since f38d027 was 8db20a9, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 6 years ago

Updated cansas read (cherrypicked and fixed from master).
Fixes: hdf5 returns byte strings so these need to be recasted properly.
https://github.com/h5py/h5py/issues/379

  • Property mode set to 100644
File size: 16.9 KB
Line 
1"""
2    NXcanSAS 1/2D data reader for writing HDF5 formatted NXcanSAS files.
3"""
4
5import h5py
6import numpy as np
7import re
8import os
9
10from sas.sascalc.dataloader.readers.cansas_reader_HDF5 import Reader
11from sas.sascalc.dataloader.data_info import Data1D, Data2D
12
13class NXcanSASWriter(Reader):
14    """
15    A class for writing in NXcanSAS data files. Any number of data sets may be
16    written to the file. Currently 1D and 2D SAS data sets are supported
17
18    NXcanSAS spec: http://download.nexusformat.org/sphinx/classes/contributed_definitions/NXcanSAS.html
19
20    :Dependencies:
21        The NXcanSAS writer requires h5py => v2.5.0 or later.
22    """
23
24    def write(self, dataset, filename):
25        """
26        Write an array of Data1d or Data2D objects to an NXcanSAS file, as
27        one SASEntry with multiple SASData elements. The metadata of the first
28        elememt in the array will be written as the SASentry metadata
29        (detector, instrument, sample, etc).
30
31        :param dataset: A list of Data1D or Data2D objects to write
32        :param filename: Where to write the NXcanSAS file
33        """
34
35        def _h5_string(string):
36            """
37            Convert a string to a numpy string in a numpy array. This way it is
38            written to the HDF5 file as a fixed length ASCII string and is
39            compatible with the Reader read() method.
40            """
41            if isinstance(string, np.ndarray):
42                return string
43            elif not isinstance(string, str):
44                string = str(string)
45
46            return np.array([np.string_(string)])
47
48        def _write_h5_string(entry, value, key):
49            entry[key] = _h5_string(value)
50
51        def _h5_float(x):
52            if not (isinstance(x, list)):
53                x = [x]
54            return np.array(x, dtype=np.float32)
55
56        def _write_h5_float(entry, value, key):
57            entry.create_dataset(key, data=_h5_float(value))
58
59        def _write_h5_vector(entry, vector, names=['x_position', 'y_position'],
60            units=None, write_fn=_write_h5_string):
61            """
62            Write a vector to an h5 entry
63
64            :param entry: The H5Py entry to write to
65            :param vector: The Vector to write
66            :param names: What to call the x,y and z components of the vector
67                when writing to the H5Py entry
68            :param units: The units of the vector (optional)
69            :param write_fn: A function to convert the value to the required
70                format and write it to the H5Py entry, of the form
71                f(entry, value, name) (optional)
72            """
73            if len(names) < 2:
74                raise ValueError("Length of names must be >= 2.")
75
76            if vector.x is not None:
77                write_fn(entry, vector.x, names[0])
78                if units is not None:
79                    entry[names[0]].attrs['units'] = units
80            if vector.y is not None:
81                write_fn(entry, vector.y, names[1])
82                if units is not None:
83                    entry[names[1]].attrs['units'] = units
84            if len(names) == 3 and vector.z is not None:
85                write_fn(entry, vector.z, names[2])
86                if units is not None:
87                    entry[names[2]].attrs['units'] = units
88
89        valid_data = all([isinstance(d, (Data1D, Data2D)) for d in dataset])
90        if not valid_data:
91            raise ValueError("All entries of dataset must be Data1D or Data2D"
92                             "objects")
93
94        # Get run name and number from first Data object
95        data_info = dataset[0]
96        run_number = ''
97        run_name = ''
98        if len(data_info.run) > 0:
99            run_number = data_info.run[0]
100            if len(data_info.run_name) > 0:
101                run_name = data_info.run_name[run_number]
102
103        f = h5py.File(filename, 'w')
104        sasentry = f.create_group('sasentry01')
105        sasentry['definition'] = _h5_string('NXcanSAS')
106        sasentry['run'] = _h5_string(run_number)
107        sasentry['run'].attrs['name'] = run_name
108        sasentry['title'] = _h5_string(data_info.title)
109        sasentry.attrs['canSAS_class'] = 'SASentry'
110        sasentry.attrs['version'] = '1.0'
111
112        for i, data_obj in enumerate(dataset):
113            data_entry = sasentry.create_group("sasdata{0:0=2d}".format(i+1))
114            data_entry.attrs['canSAS_class'] = 'SASdata'
115            if isinstance(data_obj, Data1D):
116                self._write_1d_data(data_obj, data_entry)
117            elif isinstance(data_obj, Data2D):
118                self._write_2d_data(data_obj, data_entry)
119
120        data_info = dataset[0]
121        # Sample metadata
122        sample_entry = sasentry.create_group('sassample')
123        sample_entry.attrs['canSAS_class'] = 'SASsample'
124        sample_entry['ID'] = _h5_string(data_info.sample.name)
125        sample_attrs = ['thickness', 'temperature', 'transmission']
126        for key in sample_attrs:
127            if getattr(data_info.sample, key) is not None:
128                sample_entry.create_dataset(key,
129                    data=_h5_float(getattr(data_info.sample, key)))
130        _write_h5_vector(sample_entry, data_info.sample.position)
131        # NXcanSAS doesn't save information about pitch, only roll
132        # and yaw. The _write_h5_vector method writes vector.y, but we
133        # need to write vector.z for yaw
134        data_info.sample.orientation.y = data_info.sample.orientation.z
135        _write_h5_vector(sample_entry, data_info.sample.orientation,
136            names=['polar_angle', 'azimuthal_angle'])
137        if data_info.sample.details is not None\
138            and data_info.sample.details != []:
139            details = None
140            if len(data_info.sample.details) > 1:
141                details = [np.string_(d) for d in data_info.sample.details]
142                details = np.array(details)
143            elif data_info.sample.details != []:
144                details = _h5_string(data_info.sample.details[0])
145            if details is not None:
146                sample_entry.create_dataset('details', data=details)
147
148        # Instrument metadata
149        instrument_entry = sasentry.create_group('sasinstrument')
150        instrument_entry.attrs['canSAS_class'] = 'SASinstrument'
151        instrument_entry['name'] = _h5_string(data_info.instrument)
152
153        # Source metadata
154        source_entry = instrument_entry.create_group('sassource')
155        source_entry.attrs['canSAS_class'] = 'SASsource'
156        if data_info.source.radiation is None:
157            source_entry['radiation'] = _h5_string('neutron')
158        else:
159            source_entry['radiation'] = _h5_string(data_info.source.radiation)
160        if data_info.source.beam_shape is not None:
161            source_entry['beam_shape'] = _h5_string(data_info.source.beam_shape)
162        wavelength_keys = { 'wavelength': 'incident_wavelength',
163            'wavelength_min':'wavelength_min',
164            'wavelength_max': 'wavelength_max',
165            'wavelength_spread': 'incident_wavelength_spread' }
166        for sasname, nxname in wavelength_keys.items():
167            value = getattr(data_info.source, sasname)
168            units = getattr(data_info.source, sasname + '_unit')
169            if value is not None:
170                source_entry[nxname] = _h5_float(value)
171                source_entry[nxname].attrs['units'] = units
172        _write_h5_vector(source_entry, data_info.source.beam_size,
173            names=['beam_size_x', 'beam_size_y'],
174            units=data_info.source.beam_size_unit, write_fn=_write_h5_float)
175
176        # Collimation metadata
177        if len(data_info.collimation) > 0:
178            for i, coll_info in enumerate(data_info.collimation):
179                collimation_entry = instrument_entry.create_group(
180                    'sascollimation{0:0=2d}'.format(i + 1))
181                collimation_entry.attrs['canSAS_class'] = 'SAScollimation'
182                if coll_info.length is not None:
183                    _write_h5_float(collimation_entry, coll_info.length, 'SDD')
184                    collimation_entry['SDD'].attrs['units'] =\
185                        coll_info.length_unit
186                if coll_info.name is not None:
187                    collimation_entry['name'] = _h5_string(coll_info.name)
188        else:
189            # Create a blank one - at least 1 collimation required by format
190            instrument_entry.create_group('sascollimation01')
191
192        # Detector metadata
193        if len(data_info.detector) > 0:
194            i = 1
195            for i, det_info in enumerate(data_info.detector):
196                detector_entry = instrument_entry.create_group(
197                    'sasdetector{0:0=2d}'.format(i + 1))
198                detector_entry.attrs['canSAS_class'] = 'SASdetector'
199                if det_info.distance is not None:
200                    _write_h5_float(detector_entry, det_info.distance, 'SDD')
201                    detector_entry['SDD'].attrs['units'] =\
202                        det_info.distance_unit
203                if det_info.name is not None:
204                    detector_entry['name'] = _h5_string(det_info.name)
205                else:
206                    detector_entry['name'] = _h5_string('')
207                if det_info.slit_length is not None:
208                    _write_h5_float(detector_entry, det_info.slit_length,
209                                    'slit_length')
210                    detector_entry['slit_length'].attrs['units'] =\
211                        det_info.slit_length_unit
212                _write_h5_vector(detector_entry, det_info.offset)
213                # NXcanSAS doesn't save information about pitch, only roll
214                # and yaw. The _write_h5_vector method writes vector.y, but we
215                # need to write vector.z for yaw
216                det_info.orientation.y = det_info.orientation.z
217                _write_h5_vector(detector_entry, det_info.orientation,
218                    names=['polar_angle', 'azimuthal_angle'])
219                _write_h5_vector(detector_entry, det_info.beam_center,
220                    names=['beam_center_x', 'beam_center_y'],
221                    write_fn=_write_h5_float, units=det_info.beam_center_unit)
222                _write_h5_vector(detector_entry, det_info.pixel_size,
223                    names=['x_pixel_size', 'y_pixel_size'],
224                    write_fn=_write_h5_float, units=det_info.pixel_size_unit)
225        else:
226            # Create a blank one - at least 1 detector required by format
227            detector_entry = instrument_entry.create_group('sasdetector01')
228            detector_entry.attrs['canSAS_class'] = 'SASdetector'
229            detector_entry.attrs['name'] = ''
230
231        # Process meta data
232        for i, process in enumerate(data_info.process):
233            process_entry = sasentry.create_group('sasprocess{0:0=2d}'.format(
234                i + 1))
235            process_entry.attrs['canSAS_class'] = 'SASprocess'
236            if process.name:
237                name = _h5_string(process.name)
238                process_entry.create_dataset('name', data=name)
239            if process.date:
240                date = _h5_string(process.date)
241                process_entry.create_dataset('date', data=date)
242            if process.description:
243                desc = _h5_string(process.description)
244                process_entry.create_dataset('description', data=desc)
245            for j, term in enumerate(process.term):
246                # Don't save empty terms
247                if term:
248                    h5_term = _h5_string(term)
249                    process_entry.create_dataset('term{0:0=2d}'.format(
250                        j + 1), data=h5_term)
251            for j, note in enumerate(process.notes):
252                # Don't save empty notes
253                if note:
254                    h5_note = _h5_string(note)
255                    process_entry.create_dataset('note{0:0=2d}'.format(
256                        j + 1), data=h5_note)
257
258        # Transmission Spectrum
259        for i, trans in enumerate(data_info.trans_spectrum):
260            trans_entry = sasentry.create_group(
261                'sastransmission_spectrum{0:0=2d}'.format(i + 1))
262            trans_entry.attrs['canSAS_class'] = 'SAStransmission_spectrum'
263            trans_entry.attrs['signal'] = 'T'
264            trans_entry.attrs['T_axes'] = 'T'
265            trans_entry.attrs['name'] = trans.name
266            if trans.timestamp is not '':
267                trans_entry.attrs['timestamp'] = trans.timestamp
268            transmission = trans_entry.create_dataset('T',
269                                                      data=trans.transmission)
270            transmission.attrs['unertainties'] = 'Tdev'
271            trans_entry.create_dataset('Tdev',
272                                       data=trans.transmission_deviation)
273            trans_entry.create_dataset('lambda', data=trans.wavelength)
274
275        note_entry = sasentry.create_group('sasnote'.format(i))
276        note_entry.attrs['canSAS_class'] = 'SASnote'
277        notes = None
278        if len(data_info.notes) > 1:
279            notes = [np.string_(n) for n in data_info.notes]
280            notes = np.array(notes)
281        elif data_info.notes != []:
282            notes = _h5_string(data_info.notes[0])
283        if notes is not None:
284            note_entry.create_dataset('SASnote', data=notes)
285
286        f.close()
287
288    def _write_1d_data(self, data_obj, data_entry):
289        """
290        Writes the contents of a Data1D object to a SASdata h5py Group
291
292        :param data_obj: A Data1D object to write to the file
293        :param data_entry: A h5py Group object representing the SASdata
294        """
295        data_entry.attrs['signal'] = 'I'
296        data_entry.attrs['I_axes'] = 'Q'
297        data_entry.attrs['Q_indices'] = [0]
298        q_entry = data_entry.create_dataset('Q', data=data_obj.x)
299        q_entry.attrs['units'] = data_obj.x_unit
300        i_entry = data_entry.create_dataset('I', data=data_obj.y)
301        i_entry.attrs['units'] = data_obj.y_unit
302        if data_obj.dy is not None:
303            i_entry.attrs['uncertainties'] = 'Idev'
304            i_dev_entry = data_entry.create_dataset('Idev', data=data_obj.dy)
305            i_dev_entry.attrs['units'] = data_obj.y_unit
306        if data_obj.dx is not None:
307            q_entry.attrs['resolutions'] = 'dQ'
308            dq_entry = data_entry.create_dataset('dQ', data=data_obj.dx)
309            dq_entry.attrs['units'] = data_obj.x_unit
310        elif data_obj.dxl is not None:
311            q_entry.attrs['resolutions'] = ['dQl','dQw']
312            dql_entry = data_entry.create_dataset('dQl', data=data_obj.dxl)
313            dql_entry.attrs['units'] = data_obj.x_unit
314            dqw_entry = data_entry.create_dataset('dQw', data=data_obj.dxw)
315            dqw_entry.attrs['units'] = data_obj.x_unit
316
317    def _write_2d_data(self, data, data_entry):
318        """
319        Writes the contents of a Data2D object to a SASdata h5py Group
320
321        :param data: A Data2D object to write to the file
322        :param data_entry: A h5py Group object representing the SASdata
323        """
324        data_entry.attrs['signal'] = 'I'
325        data_entry.attrs['I_axes'] = 'Qx,Qy'
326        data_entry.attrs['Q_indices'] = [0,1]
327
328        (n_rows, n_cols) = (len(data.y_bins), len(data.x_bins))
329
330        if (n_rows == 0 and n_cols == 0) or (n_cols*n_rows != data.data.size):
331            # Calculate rows and columns, assuming detector is square
332            # Same logic as used in PlotPanel.py _get_bins
333            n_cols = int(np.floor(np.sqrt(len(data.qy_data))))
334            n_rows = int(np.floor(len(data.qy_data) / n_cols))
335
336            if n_rows * n_cols != len(data.qy_data):
337                raise ValueError("Unable to calculate dimensions of 2D data")
338
339        intensity = np.reshape(data.data, (n_rows, n_cols))
340        qx = np.reshape(data.qx_data, (n_rows, n_cols))
341        qy = np.reshape(data.qy_data, (n_rows, n_cols))
342
343        i_entry = data_entry.create_dataset('I', data=intensity)
344        i_entry.attrs['units'] = data.I_unit
345        qx_entry = data_entry.create_dataset('Qx', data=qx)
346        qx_entry.attrs['units'] = data.Q_unit
347        qy_entry = data_entry.create_dataset('Qy', data=qy)
348        qy_entry.attrs['units'] = data.Q_unit
349        if (data.err_data is not None
350                and not all(v is None for v in data.err_data)):
351            d_i = np.reshape(data.err_data, (n_rows, n_cols))
352            i_entry.attrs['uncertainties'] = 'Idev'
353            i_dev_entry = data_entry.create_dataset('Idev', data=d_i)
354            i_dev_entry.attrs['units'] = data.I_unit
355        if (data.dqx_data is not None
356                and not all(v is None for v in data.dqx_data)):
357            qx_entry.attrs['resolutions'] = 'dQx'
358            dqx_entry = data_entry.create_dataset('dQx', data=data.dqx_data)
359            dqx_entry.attrs['units'] = data.Q_unit
360        if (data.dqy_data is not None
361                and not all(v is None for v in data.dqy_data)):
362            qy_entry.attrs['resolutions'] = 'dQy'
363            dqy_entry = data_entry.create_dataset('dQy', data=data.dqy_data)
364            dqy_entry.attrs['units'] = data.Q_unit
365        if data.mask is not None and not all(v is None for v in data.mask):
366            data_entry.attrs['mask'] = "mask"
367            mask = np.invert(np.asarray(data.mask, dtype=bool))
368            data_entry.create_dataset('mask', data=mask)
Note: See TracBrowser for help on using the repository browser.