2 changes: 1 addition & 1 deletion orsopy/fileio/__init__.py
@@ -5,7 +5,7 @@
 from .base import (Column, ComplexValue, ErrorColumn, File, Header, Person, Value, ValueRange, ValueVector,
                    _read_header_data, _validate_header_data)
 from .data_source import DataSource, Experiment, InstrumentSettings, Measurement, Polarization, Sample
-from .orso import ORSO_VERSION, Orso, OrsoDataset, load_orso, save_orso
+from .orso import ORSO_VERSION, Orso, OrsoDataset, load_orso, save_orso, load_nexus, save_nexus
 from .reduction import Reduction, Software
 
 __all__ = [s for s in dir() if not s.startswith("_")]
89 changes: 88 additions & 1 deletion orsopy/fileio/base.py
Expand Up @@ -35,6 +35,8 @@ def _noop(self, *args, **kw):
pass


JSON_MIMETYPE = "application/json"

yaml.emitter.Emitter.process_tag = _noop

# make sure that datetime strings get loaded as str not datetime instances
@@ -82,7 +84,12 @@ def _custom_init_fn(fieldsarg, frozen, has_post_init, self_name, globals):
     )
 
 
+# register all ORSO classes here:
+ORSO_DATACLASSES = dict()
+
+
 def orsodataclass(cls: type):
+    ORSO_DATACLASSES[cls.__name__] = cls
Collaborator:

I think it would be cleaner to use built-in Python functionality instead of keeping track of our subclasses ourselves. All header classes are derived from Header, so you could replace:

if value.__class__ in ORSO_DATACLASSES.values():

with

if isinstance(value, Header):

and

cls = ORSO_DATACLASSES[ORSO_class]

with

orso_classes = dict((c.__name__, c) for c in Header.__subclasses__())
...
  cls = orso_classes[ORSO_class]

(I think the first call misses sub-subclasses, so it's probably easier to implement it as a recursive method in the Header class.)

In this case we are also safe if people subclass the ORSO Header without using the decorator (e.g. to add functionality, not attributes).

Contributor Author:

The subclass recursive search would miss OrsoDataset, which needs to be recognized as an ORSO class during deserialization. I think trying to make an inheritance tree that includes both Header and OrsoDataset is going to be more complicated.
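
For illustration, a recursive version of the suggested lookup might look like this (a hypothetical helper, not part of this PR; as noted above, OrsoDataset would still need to be registered by hand because it is not a Header subclass):

def _all_subclasses(cls: type) -> dict:
    # collect cls.__subclasses__() transitively, keyed by class name,
    # so that sub-subclasses are found as well
    found = {}
    for sub in cls.__subclasses__():
        found[sub.__name__] = sub
        found.update(_all_subclasses(sub))
    return found

# usage during deserialization:
# orso_classes = _all_subclasses(Header)
# orso_classes["OrsoDataset"] = OrsoDataset  # not derived from Header
# cls = orso_classes[ORSO_class]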

     attrs = cls.__dict__
     bases = cls.__bases__
     if "__annotations__" in attrs and len([k for k in attrs["__annotations__"].keys() if not k.startswith("_")]) > 0:
@@ -275,7 +282,8 @@ def _resolve_type(hint: type, item: Any) -> Any:
             return item
         else:
             warnings.warn(
-                f"Has to be one of {get_args(hint)} got {item}", RuntimeWarning,
+                f"Has to be one of {get_args(hint)} got {item}",
+                RuntimeWarning,
             )
             return str(item)
     return None
@@ -376,6 +384,67 @@ def yaml_representer_compact(self, dumper: yaml.Dumper):
         output = self._to_object_dict()
         return dumper.represent_mapping(dumper.DEFAULT_MAPPING_TAG, output, flow_style=True)
 
+    def to_nexus(self, root=None, name=None):
+        """
+        Produces an HDF5 representation of the Header object, removing
+        any optional attributes with the value :code:`None`.
+
+        :return: HDF5 object
+        """
+        classname = self.__class__.__name__
+        import h5py
+
+        assert isinstance(root, h5py.Group)
+        group = root.create_group(classname if name is None else name)
+        group.attrs["ORSO_class"] = classname
+
+        for child_name, value in self.__dict__.items():
+            if child_name.startswith("_") or (value is None and child_name in self._orso_optionals):
+                continue
+
+            if value.__class__ in ORSO_DATACLASSES.values():
+                value.to_nexus(root=group, name=child_name)
+            elif isinstance(value, (list, tuple)):
+                child_group = group.create_group(child_name)
+                child_group.attrs["sequence"] = 1
+                for index, item in enumerate(value):
+                    # use the 'name' attribute of children if it exists, else index:
+                    sub_name = getattr(item, "name", str(index))
+                    if item.__class__ in ORSO_DATACLASSES.values():
+                        item_out = item.to_nexus(root=child_group, name=sub_name)
+                    else:
+                        t_value = nexus_value_converter(item)
+                        if any(isinstance(t_value, t) for t in (str, float, int, bool, np.ndarray)):
+                            item_out = child_group.create_dataset(sub_name, data=t_value)
+                        elif t_value is None:
+                            # special handling for null datasets: no data
+                            item_out = child_group.create_dataset(sub_name, dtype="f")
+                        elif isinstance(t_value, dict):
+                            item_out = child_group.create_dataset(sub_name, data=json.dumps(t_value))
+                            item_out.attrs["mimetype"] = JSON_MIMETYPE
+                        else:
+                            import warnings
+                            # raise ValueError(f"unserializable attribute found: {child_name}[{index}] = {t_value}")
+                            warnings.warn(f"unserializable attribute found: {child_name}[{index}] = {t_value}")
+                            continue
+                    item_out.attrs["sequence_index"] = index
+            else:
+                # here _todict converts objects that aren't derived from Header
+                # and therefore don't have to_dict methods.
+                t_value = nexus_value_converter(value)
+                if any(isinstance(t_value, t) for t in (str, float, int, bool, np.ndarray)):
+                    group.create_dataset(child_name, data=t_value)
+                elif t_value is None:
+                    group.create_dataset(child_name, dtype="f")
+                elif isinstance(t_value, dict):
+                    dset = group.create_dataset(child_name, data=json.dumps(t_value))
+                    dset.attrs["mimetype"] = JSON_MIMETYPE
+                else:
+                    import warnings
+                    warnings.warn(f"unserializable attribute found: {child_name} = {t_value}")
+                    # raise ValueError(f"unserializable attribute found: {child_name} = {t_value}")
+        return group
+
     @staticmethod
     def _check_unit(unit: str):
         """
@@ -952,6 +1021,24 @@ def _todict(obj: Any, classkey: Any = None) -> dict:
         return obj
 
 
+def json_datetime_trap(obj):
+    if isinstance(obj, datetime.datetime):
+        return obj.isoformat()
+    return obj
+
+
+def enum_trap(obj):
+    if isinstance(obj, Enum):
+        return obj.value
+    return obj
+
+
+def nexus_value_converter(obj):
+    for converter in (json_datetime_trap, enum_trap):
+        obj = converter(obj)
+    return obj
+
+
 def _nested_update(d: dict, u: dict) -> dict:
     """
     Nested dictionary update.
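
A quick illustration of the converter chain (values made up): datetimes are flattened to ISO-8601 strings and Enum members to their values before anything is handed to h5py:

import datetime
from enum import Enum

class Spin(Enum):  # stand-in for an ORSO enum such as Polarization
    UP = "up"

nexus_value_converter(datetime.datetime(2022, 1, 1))  # -> '2022-01-01T00:00:00'
nexus_value_converter(Spin.UP)                        # -> 'up'
nexus_value_converter(3.14)                           # passed through -> 3.14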
121 changes: 116 additions & 5 deletions orsopy/fileio/orso.py
@@ -2,14 +2,14 @@
 Implementation of the top level class for the ORSO header.
 """
 
-from dataclasses import dataclass
-from typing import Any, List, Optional, TextIO, Union
+from dataclasses import dataclass, fields
+from typing import BinaryIO, List, Optional, Sequence, TextIO, Union
 
 import numpy as np
 import yaml
 
-from .base import (Column, ErrorColumn, Header, _dict_diff, _nested_update, _possibly_open_file, _read_header_data,
-                   orsodataclass)
+from .base import (JSON_MIMETYPE, ORSO_DATACLASSES, Column, ErrorColumn, Header, _dict_diff, _nested_update,
+                   _possibly_open_file, _read_header_data, orsodataclass)
 from .data_source import DataSource
 from .reduction import Reduction
@@ -163,7 +163,7 @@ class OrsoDataset:
     """
 
     info: Orso
-    data: Any
+    data: Union[np.ndarray, Sequence[np.ndarray], Sequence[Sequence]]
 
     def __post_init__(self):
         if self.data.shape[1] != len(self.info.columns):
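
The check above pins down the expected layout: data is row-first, with the columns along the second axis. A minimal sketch (assuming an existing Orso header info with three columns):

import numpy as np

data = np.zeros((100, 3))  # 100 points x 3 columns -> accepted
dataset = OrsoDataset(info=info, data=data)
# np.zeros((3, 100)) would fail the check: shape[1] != len(info.columns)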
@@ -210,6 +210,9 @@ def __eq__(self, other: "OrsoDataset"):
         return self.info == other.info and (self.data == other.data).all()
 
 
+ORSO_DATACLASSES["OrsoDataset"] = OrsoDataset
+
+
 def save_orso(
     datasets: List[OrsoDataset], fname: Union[TextIO, str], comment: Optional[str] = None, data_separator: str = ""
 ) -> None:
@@ -273,3 +276,111 @@ def load_orso(fname: Union[TextIO, str]) -> List[OrsoDataset]:
         od = OrsoDataset(o, data)
         ods.append(od)
     return ods
+
+
+def _from_nexus_group(group):
+    if group.attrs.get("sequence", None) is not None:
+        sort_list = [[v.attrs["sequence_index"], v] for v in group.values()]
+        return [_get_nexus_item(v) for _, v in sorted(sort_list)]
+    else:
+        dct = dict()
+        for name, value in group.items():
+            if value.attrs.get("NX_class", None) == "NXdata":
+                # remove NXdata folder, which exists only for NeXus plotting
+                continue
+            dct[name] = _get_nexus_item(value)
+
+        ORSO_class = group.attrs.get("ORSO_class", None)
+        if ORSO_class is not None:
+            if ORSO_class == "OrsoDataset":
+                # TODO: remove swapaxes if order of data is changed (PR #107)
+                # reorder columns so column index is second:
+                dct["data"] = np.asarray(dct["data"]).swapaxes(0, 1)
+            cls = ORSO_DATACLASSES[ORSO_class]
+            return cls(**dct)
+        else:
+            return dct
+
+
+def _get_nexus_item(value):
+    import json
+
+    import h5py
+
+    if isinstance(value, h5py.Group):
+        return _from_nexus_group(value)
+    elif isinstance(value, h5py.Dataset):
+        v = value[()]
+        if isinstance(v, h5py.Empty):
+            return None
+        elif value.attrs.get("mimetype", None) == JSON_MIMETYPE:
+            return json.loads(v)
+        elif hasattr(v, "decode"):
+            # it is a bytes object, should be string
+            return v.decode()
+        else:
+            return v
+
+
+def load_nexus(fname: Union[str, BinaryIO]) -> List[OrsoDataset]:
Collaborator:

I think it would be advantageous to have a single load_orso function that chooses which loader to use automatically. That would allow external software to be agnostic about the file type that's used.
We can do this in the next release, though.

Suggested way of implementation:

  1. load_orso scans through a list of known data format classes
  2. Each format class has a static method to check if it supports the file in question (e.g. first header line for .ort, HDF5 header for .orb)
  3. If it's supported, the function uses the format class to retrieve the Orso objects from the file
  4. Potentially we could store the format used to unpack in the Orso class
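
A minimal sketch of that dispatch (all names hypothetical; none of this is in the PR):

def load_any_orso(fname):
    for fmt in (OrtFormat, OrbFormat):  # known format classes
        if fmt.supports(fname):  # e.g. first header line for .ort, HDF5 magic bytes for .orb
            return fmt.load(fname)
    raise ValueError(f"{fname} is not a recognized ORSO file")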

+    import h5py
+
+    f = h5py.File(fname, "r")
+    # Use '/' because order is not tracked on the File object, but is on the '/' group!
+    root = f['/']
+    return [_from_nexus_group(g) for g in root.values() if g.attrs.get("ORSO_class", None) == "OrsoDataset"]
+
+
+def save_nexus(datasets: List[OrsoDataset], fname: Union[str, BinaryIO], comment: Optional[str] = None) -> None:
+    import h5py
+    h5py.get_config().track_order = True
+
+    for idx, dataset in enumerate(datasets):
+        info = dataset.info
+        data_set = info.data_set
+        if data_set is None or (isinstance(data_set, str) and len(data_set) == 0):
+            # it's not set, or is a zero-length string
+            info.data_set = idx
+
+    dsets = [dataset.info.data_set for dataset in datasets]
+    if len(set(dsets)) != len(dsets):
+        raise ValueError("All `OrsoDataset.info.data_set` values must be unique")
+
+    with h5py.File(fname, mode="w") as f:
+        f.attrs["NX_class"] = "NXroot"
+        if comment is not None:
+            f.attrs["comment"] = comment
+
+        for dsi in datasets:
+            info = dsi.info
+            entry = f.create_group(str(info.data_set))
+            entry.attrs["ORSO_class"] = "OrsoDataset"
+            entry.attrs["ORSO_VERSION"] = ORSO_VERSION
+            entry.attrs["NX_class"] = "NXentry"
+            entry.attrs["default"] = "plottable_data"
+            info.to_nexus(root=entry, name="info")
+            data_group = entry.create_group("data")
+            data_group.attrs["sequence"] = 1
+            plottable_data_group = entry.create_group("plottable_data", track_order=True)
+            plottable_data_group.attrs["NX_class"] = "NXdata"
+            plottable_data_group.attrs["sequence"] = 1
+            plottable_data_group.attrs["axes"] = [info.columns[0].name]
+            plottable_data_group.attrs["signal"] = info.columns[1].name
+            plottable_data_group.attrs[f"{info.columns[0].name}_indices"] = [0]
+            for column_index, column in enumerate(info.columns):
+                # dataset.data holds ncolumns along the second axis (row-first
+                # layout, matching what e.g. load_orso produces)
+                col_data = data_group.create_dataset(column.name, data=dsi.data[:, column_index])
+                col_data.attrs["sequence_index"] = column_index
+                col_data.attrs["target"] = col_data.name
+                physical_quantity = getattr(column, 'physical_quantity', None)
+                if physical_quantity is not None:
+                    col_data.attrs["physical_quantity"] = physical_quantity
+                if isinstance(column, ErrorColumn):
+                    nexus_colname = column.error_of + "_errors"
+                else:
+                    nexus_colname = column.name
+                if column.unit is not None:
+                    col_data.attrs["units"] = column.unit
+
+                plottable_data_group[nexus_colname] = col_data
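
To inspect the layout save_nexus produces, the file can be walked with h5py (a sketch; the file name and the Qz/R column names depend on the datasets that were saved):

import h5py

with h5py.File("test.orb", "r") as f:
    f.visit(print)

# roughly, for a dataset named '0' with columns Qz and R:
# 0                 <- NXentry carrying ORSO_class='OrsoDataset'
# 0/info            <- the Orso header, written by Header.to_nexus
# 0/data/Qz         <- one dataset per column, sequence_index = column index
# 0/data/R
# 0/plottable_data  <- NXdata group of links used for default NeXus plotting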
17 changes: 17 additions & 0 deletions orsopy/fileio/tests/test_orso.py
@@ -149,13 +149,23 @@ def test_write_read(self):
             columns=info.columns,
         )
         ds3 = fileio.OrsoDataset(info3, data)
+
+        # .ort read/write
         fileio.save_orso([ds, ds2, ds3], "test.ort", comment="Interdiffusion")
 
         ls1, ls2, ls3 = fileio.load_orso("test.ort")
         assert ls1 == ds
         assert ls2 == ds2
         assert ls3 == ds3
 
+        # .orb read/write
+        fileio.save_nexus([ds, ds2, ds3], "test.orb", comment="Interdiffusion")
+
+        ls1, ls2, ls3 = fileio.load_nexus("test.orb")
+        assert ls1 == ds
+        assert ls2 == ds2
+        assert ls3 == ds3
+
         # test empty lines between datasets
         fileio.save_orso([ds, ds2, ds3], "test.ort", data_separator="\n\n")
 
@@ -238,11 +248,18 @@ def test_save_numpy_scalar_dtypes(self):
         info.data_source.measurement.instrument_settings.wavelength = Value(np.float64(10.0))
         info.data_source.measurement.instrument_settings.incident_angle = Value(np.int32(2))
         ds = fileio.orso.OrsoDataset(info, np.arange(20.).reshape(10, 2))
+        # .ort test:
         fileio.save_orso([ds], "test_numpy.ort")
         ls = fileio.load_orso("test_numpy.ort")
         i_s = ls[0].info.data_source.measurement.instrument_settings
         assert i_s.wavelength.magnitude == 10.0
         assert i_s.incident_angle.magnitude == 2
+        # .orb test:
+        fileio.save_nexus([ds], "test_numpy.orb")
+        ln = fileio.load_nexus("test_numpy.orb")
+        i_n = ln[0].info.data_source.measurement.instrument_settings
+        assert i_n.wavelength.magnitude == 10.0
+        assert i_n.incident_angle.magnitude == 2
 
 
 class TestFunctions(unittest.TestCase):
1 change: 1 addition & 0 deletions requirements_dev.txt
@@ -14,3 +14,4 @@ typing_extensions
 coverage
 coveralls
 pint
+h5py>=3.1.0