diff --git a/orsopy/fileio/__init__.py b/orsopy/fileio/__init__.py index e9d7bf91..129821e1 100644 --- a/orsopy/fileio/__init__.py +++ b/orsopy/fileio/__init__.py @@ -5,7 +5,7 @@ from .base import (Column, ComplexValue, ErrorColumn, File, Header, Person, Value, ValueRange, ValueVector, _read_header_data, _validate_header_data) from .data_source import DataSource, Experiment, InstrumentSettings, Measurement, Polarization, Sample -from .orso import ORSO_VERSION, Orso, OrsoDataset, load_orso, save_orso +from .orso import ORSO_VERSION, Orso, OrsoDataset, load_orso, save_orso, load_nexus, save_nexus from .reduction import Reduction, Software __all__ = [s for s in dir() if not s.startswith("_")] diff --git a/orsopy/fileio/base.py b/orsopy/fileio/base.py index 3514e2e7..87b2c855 100644 --- a/orsopy/fileio/base.py +++ b/orsopy/fileio/base.py @@ -35,6 +35,8 @@ def _noop(self, *args, **kw): pass +JSON_MIMETYPE = "application/json" + yaml.emitter.Emitter.process_tag = _noop # make sure that datetime strings get loaded as str not datetime instances @@ -82,7 +84,12 @@ def _custom_init_fn(fieldsarg, frozen, has_post_init, self_name, globals): ) +# register all ORSO classes here: +ORSO_DATACLASSES = dict() + + def orsodataclass(cls: type): + ORSO_DATACLASSES[cls.__name__] = cls attrs = cls.__dict__ bases = cls.__bases__ if "__annotations__" in attrs and len([k for k in attrs["__annotations__"].keys() if not k.startswith("_")]) > 0: @@ -275,7 +282,8 @@ def _resolve_type(hint: type, item: Any) -> Any: return item else: warnings.warn( - f"Has to be one of {get_args(hint)} got {item}", RuntimeWarning, + f"Has to be one of {get_args(hint)} got {item}", + RuntimeWarning, ) return str(item) return None @@ -376,6 +384,67 @@ def yaml_representer_compact(self, dumper: yaml.Dumper): output = self._to_object_dict() return dumper.represent_mapping(dumper.DEFAULT_MAPPING_TAG, output, flow_style=True) + def to_nexus(self, root=None, name=None): + """ + Produces an HDF5 representation of the Header object, removing + any optional attributes with the value :code:`None`. 
+
+        :return: h5py.Group containing the header data
+        """
+        classname = self.__class__.__name__
+        import h5py
+
+        assert isinstance(root, h5py.Group)
+        group = root.create_group(classname if name is None else name)
+        group.attrs["ORSO_class"] = classname
+
+        for child_name, value in self.__dict__.items():
+            if child_name.startswith("_") or (value is None and child_name in self._orso_optionals):
+                continue
+
+            if value.__class__ in ORSO_DATACLASSES.values():
+                value.to_nexus(root=group, name=child_name)
+            elif isinstance(value, (list, tuple)):
+                child_group = group.create_group(child_name)
+                child_group.attrs["sequence"] = 1
+                for index, item in enumerate(value):
+                    # use the 'name' attribute of children if it exists, else index:
+                    sub_name = getattr(item, "name", str(index))
+                    if item.__class__ in ORSO_DATACLASSES.values():
+                        item_out = item.to_nexus(root=child_group, name=sub_name)
+                    else:
+                        t_value = nexus_value_converter(item)
+                        if any(isinstance(t_value, t) for t in (str, float, int, bool, np.ndarray)):
+                            item_out = child_group.create_dataset(sub_name, data=t_value)
+                        elif t_value is None:
+                            # special handling for null datasets: no data
+                            item_out = child_group.create_dataset(sub_name, dtype="f")
+                        elif isinstance(t_value, dict):
+                            item_out = child_group.create_dataset(sub_name, data=json.dumps(t_value))
+                            item_out.attrs["mimetype"] = JSON_MIMETYPE
+                        else:
+                            import warnings
+
+                            warnings.warn(f"unserializable attribute found: {child_name}[{index}] = {t_value}")
+                            continue
+                    item_out.attrs["sequence_index"] = index
+            else:
+                # nexus_value_converter handles values that are not Header subclasses
+                # (e.g. datetime, Enum) and therefore have no to_nexus method.
+                t_value = nexus_value_converter(value)
+                if any(isinstance(t_value, t) for t in (str, float, int, bool, np.ndarray)):
+                    group.create_dataset(child_name, data=t_value)
+                elif t_value is None:
+                    group.create_dataset(child_name, dtype="f")
+                elif isinstance(t_value, dict):
+                    dset = group.create_dataset(child_name, data=json.dumps(t_value))
+                    dset.attrs["mimetype"] = JSON_MIMETYPE
+                else:
+                    import warnings
+
+                    warnings.warn(f"unserializable attribute found: {child_name} = {t_value}")
+        return group
+
     @staticmethod
     def _check_unit(unit: str):
         """
@@ -952,6 +1021,24 @@ def _todict(obj: Any, classkey: Any = None) -> dict:
     return obj
 
 
+def json_datetime_trap(obj):
+    if isinstance(obj, datetime.datetime):
+        return obj.isoformat()
+    return obj
+
+
+def enum_trap(obj):
+    if isinstance(obj, Enum):
+        return obj.value
+    return obj
+
+
+def nexus_value_converter(obj):
+    for converter in (json_datetime_trap, enum_trap):
+        obj = converter(obj)
+    return obj
+
+
 def _nested_update(d: dict, u: dict) -> dict:
     """
     Nested dictionary update.
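Reviewer note, not part of the patch: to_nexus relies on the converter helpers added above to normalize leaf values before they are written as HDF5 datasets. A minimal sketch of the intended behaviour; the Spin enum is a made-up stand-in for orsopy's real enums such as Polarization:

    import datetime
    from enum import Enum

    from orsopy.fileio.base import nexus_value_converter

    class Spin(Enum):  # hypothetical stand-in for an orsopy enum
        up = "up"

    # datetimes collapse to ISO 8601 strings, Enum members to their .value,
    # and anything else passes through unchanged
    assert nexus_value_converter(datetime.datetime(2022, 1, 1)) == "2022-01-01T00:00:00"
    assert nexus_value_converter(Spin.up) == "up"
    assert nexus_value_converter(1.54) == 1.54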
""" -from dataclasses import dataclass -from typing import Any, List, Optional, TextIO, Union +from dataclasses import dataclass, fields +from typing import BinaryIO, List, Optional, Sequence, TextIO, Union import numpy as np import yaml -from .base import (Column, ErrorColumn, Header, _dict_diff, _nested_update, _possibly_open_file, _read_header_data, - orsodataclass) +from .base import (JSON_MIMETYPE, ORSO_DATACLASSES, Column, ErrorColumn, Header, _dict_diff, _nested_update, + _possibly_open_file, _read_header_data, orsodataclass) from .data_source import DataSource from .reduction import Reduction @@ -163,7 +163,7 @@ class OrsoDataset: """ info: Orso - data: Any + data: Union[np.ndarray, Sequence[np.ndarray], Sequence[Sequence]] def __post_init__(self): if self.data.shape[1] != len(self.info.columns): @@ -210,6 +210,9 @@ def __eq__(self, other: "OrsoDataset"): return self.info == other.info and (self.data == other.data).all() +ORSO_DATACLASSES["OrsoDataset"] = OrsoDataset + + def save_orso( datasets: List[OrsoDataset], fname: Union[TextIO, str], comment: Optional[str] = None, data_separator: str = "" ) -> None: @@ -273,3 +276,111 @@ def load_orso(fname: Union[TextIO, str]) -> List[OrsoDataset]: od = OrsoDataset(o, data) ods.append(od) return ods + + +def _from_nexus_group(group): + if group.attrs.get("sequence", None) is not None: + sort_list = [[v.attrs["sequence_index"], v] for v in group.values()] + return [_get_nexus_item(v) for _, v in sorted(sort_list)] + else: + dct = dict() + for name, value in group.items(): + if value.attrs.get("NX_class", None) == "NXdata": + # remove NXdata folder, which exists only for NeXus plotting + continue + dct[name] = _get_nexus_item(value) + + ORSO_class = group.attrs.get("ORSO_class", None) + if ORSO_class is not None: + if ORSO_class == "OrsoDataset": + # TODO: remove swapaxes if order of data is changed (PR #107) + # reorder columns so column index is second: + dct["data"] = np.asarray(dct["data"]).swapaxes(0, 1) + cls = ORSO_DATACLASSES[ORSO_class] + return cls(**dct) + else: + return dct + + +def _get_nexus_item(value): + import json + + import h5py + + if isinstance(value, h5py.Group): + return _from_nexus_group(value) + elif isinstance(value, h5py.Dataset): + v = value[()] + if isinstance(v, h5py.Empty): + return None + elif value.attrs.get("mimetype", None) == JSON_MIMETYPE: + return json.loads(v) + elif hasattr(v, "decode"): + # it is a bytes object, should be string + return v.decode() + else: + return v + + +def load_nexus(fname: Union[str, BinaryIO]) -> List[OrsoDataset]: + import h5py + + f = h5py.File(fname, "r") + # Use '/' because order is not tracked on the File object, but is on the '/' group! 
+
+
+def save_nexus(datasets: List[OrsoDataset], fname: Union[str, BinaryIO], comment: Optional[str] = None) -> None:
+    import h5py
+    h5py.get_config().track_order = True
+
+    for idx, dataset in enumerate(datasets):
+        info = dataset.info
+        data_set = info.data_set
+        if data_set is None or (isinstance(data_set, str) and len(data_set) == 0):
+            # it's not set, or is a zero-length string
+            info.data_set = idx
+
+    dsets = [dataset.info.data_set for dataset in datasets]
+    if len(set(dsets)) != len(dsets):
+        raise ValueError("All `OrsoDataset.info.data_set` values must be unique")
+
+    with h5py.File(fname, mode="w") as f:
+        f.attrs["NX_class"] = "NXroot"
+        if comment is not None:
+            f.attrs["comment"] = comment
+
+        for dsi in datasets:
+            info = dsi.info
+            entry = f.create_group(str(info.data_set))
+            entry.attrs["ORSO_class"] = "OrsoDataset"
+            entry.attrs["ORSO_VERSION"] = ORSO_VERSION
+            entry.attrs["NX_class"] = "NXentry"
+            entry.attrs["default"] = "plottable_data"
+            info.to_nexus(root=entry, name="info")
+            data_group = entry.create_group("data")
+            data_group.attrs["sequence"] = 1
+            plottable_data_group = entry.create_group("plottable_data", track_order=True)
+            plottable_data_group.attrs["NX_class"] = "NXdata"
+            plottable_data_group.attrs["sequence"] = 1
+            plottable_data_group.attrs["axes"] = [info.columns[0].name]
+            plottable_data_group.attrs["signal"] = info.columns[1].name
+            plottable_data_group.attrs[f"{info.columns[0].name}_indices"] = [0]
+            for column_index, column in enumerate(info.columns):
+                # dsi.data is row-first (npoints x ncolumns), as produced by load_orso,
+                # so each column is sliced along the second dimension
+                col_data = data_group.create_dataset(column.name, data=dsi.data[:, column_index])
+                col_data.attrs["sequence_index"] = column_index
+                col_data.attrs["target"] = col_data.name
+                physical_quantity = getattr(column, "physical_quantity", None)
+                if physical_quantity is not None:
+                    col_data.attrs["physical_quantity"] = physical_quantity
+                if isinstance(column, ErrorColumn):
+                    nexus_colname = column.error_of + "_errors"
+                else:
+                    nexus_colname = column.name
+                if column.unit is not None:
+                    col_data.attrs["units"] = column.unit
+
+                plottable_data_group[nexus_colname] = col_data
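Reviewer note, not part of the patch: the default/signal/axes attributes written above follow the NeXus default-plot convention, so a generic NeXus viewer can locate the plottable columns without knowing anything about ORSO. A sketch of how a consumer follows that chain, assuming "test.orb" holds one dataset stored under entry "0" with conventional Qz/R columns:

    import h5py

    with h5py.File("test.orb", "r") as f:
        entry = f["0"]
        nxdata = entry[entry.attrs["default"]]        # the "plottable_data" group
        signal = nxdata[nxdata.attrs["signal"]][...]  # e.g. the "R" column
        axis = nxdata[nxdata.attrs["axes"][0]][...]   # e.g. the "Qz" column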
diff --git a/orsopy/fileio/tests/test_orso.py b/orsopy/fileio/tests/test_orso.py
index f2e66403..5ff3823a 100644
--- a/orsopy/fileio/tests/test_orso.py
+++ b/orsopy/fileio/tests/test_orso.py
@@ -149,6 +149,8 @@ def test_write_read(self):
             columns=info.columns,
         )
         ds3 = fileio.OrsoDataset(info3, data)
+
+        # .ort read/write
         fileio.save_orso([ds, ds2, ds3], "test.ort", comment="Interdiffusion")
 
         ls1, ls2, ls3 = fileio.load_orso("test.ort")
@@ -156,6 +158,14 @@
         assert ls2 == ds2
         assert ls3 == ds3
 
+        # .orb read/write
+        fileio.save_nexus([ds, ds2, ds3], "test.orb", comment="Interdiffusion")
+
+        ls1, ls2, ls3 = fileio.load_nexus("test.orb")
+        assert ls1 == ds
+        assert ls2 == ds2
+        assert ls3 == ds3
+
         # test empty lines between datasets
         fileio.save_orso([ds, ds2, ds3], "test.ort", data_separator="\n\n")
 
@@ -238,11 +248,18 @@ def test_save_numpy_scalar_dtypes(self):
         info.data_source.measurement.instrument_settings.wavelength = Value(np.float64(10.0))
         info.data_source.measurement.instrument_settings.incident_angle = Value(np.int32(2))
         ds = fileio.orso.OrsoDataset(info, np.arange(20.).reshape(10, 2))
+        # .ort test:
         fileio.save_orso([ds], "test_numpy.ort")
         ls = fileio.load_orso("test_numpy.ort")
         i_s = ls[0].info.data_source.measurement.instrument_settings
         assert i_s.wavelength.magnitude == 10.0
         assert i_s.incident_angle.magnitude == 2
+        # .orb test:
+        fileio.save_nexus([ds], "test_numpy.orb")
+        ln = fileio.load_nexus("test_numpy.orb")
+        i_n = ln[0].info.data_source.measurement.instrument_settings
+        assert i_n.wavelength.magnitude == 10.0
+        assert i_n.incident_angle.magnitude == 2
 
 
 class TestFunctions(unittest.TestCase):
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 68469a17..9f648130 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -14,3 +14,4 @@ typing_extensions
 coverage
 coveralls
 pint
+h5py>=3.1.0
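Reviewer note, not part of the patch: end to end, the new functions mirror save_orso/load_orso. A minimal round-trip sketch; Orso.empty() and the Qz/R column names follow orsopy's usual examples, and the exact header content is immaterial here:

    import numpy as np
    from orsopy import fileio

    info = fileio.Orso.empty()
    info.columns = [
        fileio.Column("Qz", unit="1/angstrom"),
        fileio.Column("R"),
        fileio.ErrorColumn(error_of="R"),
    ]
    # data is row-first: shape (npoints, ncolumns)
    ds = fileio.OrsoDataset(info, np.zeros((10, 3)))

    fileio.save_nexus([ds], "minimal.orb")
    loaded = fileio.load_nexus("minimal.orb")[0]
    assert loaded == ds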