diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 18a6014..64134bf 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -26,6 +26,7 @@ import torch from tensordict import TensorDict from tensordict.tensorclass import NonTensorData, NonTensorStack +from torch import Tensor from transfer_queue.utils.enum_utils import ProductionStatus @@ -815,18 +816,26 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> production_status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED - all_fields = [ - { - name: FieldMeta( - name=name, - dtype=getattr(value, "dtype", None), - shape=getattr(value, "shape", None), + # unbind nested tensor + results: dict = {} + for field in tensor_dict.keys(): + field_data = tensor_dict[field] + if batch_size > 1 and isinstance(field_data, Tensor) and field_data.is_nested: + results[field] = field_data.unbind() + else: + results[field] = field_data + + all_fields = [] + for idx in range(batch_size): + dict_of_field_meta = {} + for field_name in results.keys(): + dict_of_field_meta[field_name] = FieldMeta( + name=field_name, + dtype=getattr(results[field_name][idx], "dtype", None), + shape=getattr(results[field_name][idx], "shape", None), production_status=production_status, ) - for name, value in tensor_dict[idx].items() - } - for idx in range(batch_size) - ] + all_fields.append(dict_of_field_meta) return all_fields diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index c0c5058..2e9ce30 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -394,7 +394,14 @@ def _generate_values(data: TensorDict) -> list[Tensor]: list[Tensor]: Flattened list of tensors, e.g., [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] """ - return [row_data for field in sorted(data.keys()) for row_data in data[field]] + results: list[Tensor] = [] + for field in sorted(data.keys()): + field_data = data[field] + if isinstance(field_data, Tensor) and field_data.is_nested: + results.extend(field_data.unbind()) + else: + results.extend(field_data) + return results @staticmethod def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None: diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index f142069..b658ba4 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -27,6 +27,7 @@ import zmq from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict +from torch import Tensor from transfer_queue.metadata import BatchMeta from transfer_queue.storage.managers.base import TransferQueueStorageManager @@ -201,10 +202,21 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping ) + # unbind nested tensor + results: dict = {} + for field in data.keys(): + field_data = data[field] + if data.batch_size[0] > 1 and isinstance(field_data, Tensor) and field_data.is_nested: + results[field] = field_data.unbind() + else: + results[field] = field_data + # send data to each storage unit tasks = [ self._put_to_single_storage_unit( - meta_group.get_local_indexes(), _filter_storage_data(meta_group, data), target_storage_unit=storage_id + meta_group.get_local_indexes(), + _filter_storage_data(meta_group, results), + target_storage_unit=storage_id, ) for storage_id, meta_group in storage_meta_groups.items() ] @@ -221,8 +233,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: per_field_shapes[global_idx] = {} # For each field, extract dtype and shape for each sample - for field in data.keys(): - for i, data_item in enumerate(data[field]): + for field in results.keys(): + for i, data_item in enumerate(results[field]): global_idx = metadata.global_indexes[i] per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None @@ -234,7 +246,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: # notify controller that new data is ready await self.notify_data_update( - partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes + partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes ) @dynamic_storage_manager_socket(socket_name="put_get_socket") @@ -432,11 +444,11 @@ def close(self) -> None: super().close() -def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) -> dict[str, Any]: - """Filter batch-aligned data from a TensorDict using batch indexes from a StorageMetaGroup. +def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: dict) -> dict[str, Any]: + """Filter batch-aligned data from a dict using batch indexes from a StorageMetaGroup. This helper extracts a subset of items from each field in ``data`` according to the batch indexes stored in ``storage_meta_group``. The same indexes are applied to every - field in the input ``TensorDict`` so that the returned samples remain aligned across + field in the input dict so that the returned samples remain aligned across fields. Args: @@ -444,8 +456,8 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) a sequence of batch indexes via :meth:`get_batch_indexes`. Each index refers to a position along the batch dimension of the tensors stored in ``data``. - data: A :class:`tensordict.TensorDict` containing batched data fields. All - fields are expected to be indexable by the batch indexes returned by + data: A dict containing batched data fields. All fields are expected to + be indexable by the batch indexes returned by ``storage_meta_group.get_batch_indexes()``. Returns: dict[str, Any]: A dictionary mapping each field name in ``data`` to a list @@ -461,7 +473,9 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) return results for fname in data.keys(): - result = itemgetter(*batch_indexes)(data[fname]) + field_data = data[fname] + result = itemgetter(*batch_indexes)(field_data) + if not isinstance(result, tuple): result = (result,) results[fname] = list(result) diff --git a/tutorial/04_understanding_controller.py b/tutorial/04_understanding_controller.py index 746de14..73a0874 100644 --- a/tutorial/04_understanding_controller.py +++ b/tutorial/04_understanding_controller.py @@ -69,7 +69,7 @@ def demonstrate_partition_isolation(): train_data = TensorDict( { "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), - "labels": torch.tensor([0, 1]), + "labels": torch.tensor([[0], [1]]), }, batch_size=2, ) @@ -81,7 +81,7 @@ def demonstrate_partition_isolation(): val_data = TensorDict( { "input_ids": torch.tensor([[7, 8, 9], [10, 11, 12]]), - "labels": torch.tensor([2, 3]), + "labels": torch.tensor([[2], [3]]), }, batch_size=2, )