Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions mart/transforms/tensor_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: BSD-3-Clause
#

from functools import singledispatch

import numpy as np
import torch


# A recursive function to convert all np.ndarray in an object to torch.Tensor, or vice versa.
@singledispatch
def convert(obj, device=None):
"""All other types, no change."""
return obj


@convert.register
def _(obj: dict, device=None):
return {key: convert(value, device=device) for key, value in obj.items()}


@convert.register
def _(obj: list, device=None):
return [convert(item, device=device) for item in obj]


@convert.register
def _(obj: tuple, device=None):
return tuple(convert(item, device=device) for item in obj)


@convert.register
def _(obj: np.ndarray, device=None):
return torch.tensor(obj, device=device)


@convert.register
def _(obj: torch.Tensor, device=None):
return obj.detach().cpu().numpy()
21 changes: 21 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#
# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: BSD-3-Clause
#

import numpy as np
import torch

from mart.transforms.tensor_array import convert


def test_tensor_array_two_way_convert():
tensor_expected = [{"key": (torch.tensor(1.0), 2)}]
array_expected = [{"key": (np.array(1.0), 2)}]

array_result = convert(tensor_expected)
assert array_expected == array_result

tensor_result = convert(array_expected)
assert tensor_expected == tensor_result