From 8acf1c657c0566d7eb44e7286d5d45bfdd0e8012 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 21 Sep 2023 11:21:09 -0700 Subject: [PATCH 1/3] Add conversion between PyTorch tensors and Numpy arrays. --- mart/transforms/tensor_array.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 mart/transforms/tensor_array.py diff --git a/mart/transforms/tensor_array.py b/mart/transforms/tensor_array.py new file mode 100644 index 00000000..ca88072f --- /dev/null +++ b/mart/transforms/tensor_array.py @@ -0,0 +1,42 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from functools import singledispatch + +import numpy as np +import torch + + +# A recursive function to convert all np.ndarray in an object to torch.Tensor, or vice versa. +@singledispatch +def convert(obj, device=None): + """All other types, no change.""" + return obj + + +@convert.register +def _(obj: dict, device=None): + return {key: convert(value, device=device) for key, value in obj.items()} + + +@convert.register +def _(obj: list, device=None): + return [convert(item, device=device) for item in obj] + + +@convert.register +def _(obj: tuple, device=None): + return tuple(convert(obj, device=device)) + + +@convert.register +def _(obj: np.ndarray, device=None): + return torch.tensor(obj, device=device) + + +@convert.register +def _(obj: torch.Tensor, device=None): + return obj.detach().cpu().numpy() From 45b735bdf358853fe9e0a8e8a356349fd04d73c2 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 21 Sep 2023 11:36:53 -0700 Subject: [PATCH 2/3] Add test. --- tests/test_transforms.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tests/test_transforms.py diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 00000000..9af53288 --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,21 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +import numpy as np +import torch + +from mart.transforms.tensor_array import convert + + +def test_tensor_array_two_way_convert(): + tensor_expected = [{"key": (torch.tensor(1.0), 2)}] + array_expected = [{"key": (np.array(1.0), 2)}] + + array_result = convert(tensor_expected) + assert array_expected == array_result + + tensor_result = convert(array_expected) + assert tensor_expected == tensor_result From 1d44eed290bbb794eaeb7e60cb464624d57e17d6 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 21 Sep 2023 11:37:40 -0700 Subject: [PATCH 3/3] Fix a bug. --- mart/transforms/tensor_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/transforms/tensor_array.py b/mart/transforms/tensor_array.py index ca88072f..00a9345a 100644 --- a/mart/transforms/tensor_array.py +++ b/mart/transforms/tensor_array.py @@ -29,7 +29,7 @@ def _(obj: list, device=None): @convert.register def _(obj: tuple, device=None): - return tuple(convert(obj, device=device)) + return tuple(convert(item, device=device) for item in obj) @convert.register