From 68eab7c0235f6cce5b8acef43a6998ef26b36f9a Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Wed, 17 Apr 2024 00:19:51 -0600 Subject: [PATCH 1/3] Modifies collation for numpy arrays --- pfrl/utils/batch_states.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pfrl/utils/batch_states.py b/pfrl/utils/batch_states.py index bd8af6a97..8f419c183 100644 --- a/pfrl/utils/batch_states.py +++ b/pfrl/utils/batch_states.py @@ -1,8 +1,23 @@ + from typing import Any, Callable, Sequence +import numpy as np import torch -from torch.utils.data._utils.collate import default_collate +from torch.utils.data._utils.collate import collate +from torch.utils.data._utils.collate import default_collate_fn_map + + +def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): + """Forked from: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/collate.py#L216 + """ + elem = batch[0] + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + return collate([torch.tensor(b) for b in batch], collate_fn_map=collate_fn_map) + +pfrl_default_collate_fn_map[np.ndarray] = collate_numpy_array_fn def _to_recursive(batched: Any, device: torch.device) -> Any: if isinstance(batched, torch.Tensor): @@ -29,8 +44,7 @@ def batch_states( the object which will be given as input to the model. """ features = [phi(s) for s in states] - # return concat_examples(features, device=device) - collated_features = default_collate(features) + collated_features = collate(batch, collate_fn_map=pfrl_default_collate_fn_map) if isinstance(features[0], tuple): collated_features = tuple(collated_features) return _to_recursive(collated_features, device) From a481682d4e03bd5c54d33492d30d93f9bf97326d Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Wed, 17 Apr 2024 00:33:22 -0600 Subject: [PATCH 2/3] fixes collation --- pfrl/utils/batch_states.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pfrl/utils/batch_states.py b/pfrl/utils/batch_states.py index 8f419c183..9f7c281fe 100644 --- a/pfrl/utils/batch_states.py +++ b/pfrl/utils/batch_states.py @@ -1,10 +1,10 @@ -from typing import Any, Callable, Sequence +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union +import copy import numpy as np import torch -from torch.utils.data._utils.collate import collate -from torch.utils.data._utils.collate import default_collate_fn_map +from torch.utils.data._utils.collate import collate, default_collate_fn_map, np_str_obj_array_pattern def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): @@ -16,7 +16,7 @@ def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, T raise TypeError(default_collate_err_msg_format.format(elem.dtype)) return collate([torch.tensor(b) for b in batch], collate_fn_map=collate_fn_map) - +pfrl_default_collate_fn_map = copy.deepcopy(default_collate_fn_map) pfrl_default_collate_fn_map[np.ndarray] = collate_numpy_array_fn def _to_recursive(batched: Any, device: torch.device) -> Any: From f28670d7723774a26c0bae6989c1a4c76933319e Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Wed, 17 Apr 2024 00:46:37 -0600 Subject: [PATCH 3/3] fixes typo --- pfrl/utils/batch_states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pfrl/utils/batch_states.py b/pfrl/utils/batch_states.py index 9f7c281fe..5a924be09 100644 --- a/pfrl/utils/batch_states.py +++ b/pfrl/utils/batch_states.py @@ -44,7 +44,7 @@ def batch_states( the object which will be given as input to the model. """ features = [phi(s) for s in states] - collated_features = collate(batch, collate_fn_map=pfrl_default_collate_fn_map) + collated_features = collate(features, collate_fn_map=pfrl_default_collate_fn_map) if isinstance(features[0], tuple): collated_features = tuple(collated_features) return _to_recursive(collated_features, device)