diff --git a/mart/transforms/tensor_array.py b/mart/transforms/tensor_array.py new file mode 100644 index 00000000..00a9345a --- /dev/null +++ b/mart/transforms/tensor_array.py @@ -0,0 +1,42 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from functools import singledispatch + +import numpy as np +import torch + + +# A recursive function to convert all np.ndarray in an object to torch.Tensor, or vice versa. +@singledispatch +def convert(obj, device=None): + """All other types, no change.""" + return obj + + +@convert.register +def _(obj: dict, device=None): + return {key: convert(value, device=device) for key, value in obj.items()} + + +@convert.register +def _(obj: list, device=None): + return [convert(item, device=device) for item in obj] + + +@convert.register +def _(obj: tuple, device=None): + return tuple(convert(item, device=device) for item in obj) + + +@convert.register +def _(obj: np.ndarray, device=None): + return torch.tensor(obj, device=device) + + +@convert.register +def _(obj: torch.Tensor, device=None): + return obj.detach().cpu().numpy() diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 00000000..9af53288 --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,21 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +import numpy as np +import torch + +from mart.transforms.tensor_array import convert + + +def test_tensor_array_two_way_convert(): + tensor_expected = [{"key": (torch.tensor(1.0), 2)}] + array_expected = [{"key": (np.array(1.0), 2)}] + + array_result = convert(tensor_expected) + assert array_expected == array_result + + tensor_result = convert(array_expected) + assert tensor_expected == tensor_result