From 5b0379cfefe25483dbb650e34ade29c2e679b925 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Feb 2026 10:41:46 +0800 Subject: [PATCH 1/7] unbind jagged tensor Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 1 + .../storage/managers/simple_backend_manager.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index c0c5058..2376fb4 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -394,6 +394,7 @@ def _generate_values(data: TensorDict) -> list[Tensor]: list[Tensor]: Flattened list of tensors, e.g., [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] """ + # TODO: unbind jagged tensor first return [row_data for field in sorted(data.keys()) for row_data in data[field]] @staticmethod diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index f142069..a79be63 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -461,7 +461,17 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) return results for fname in data.keys(): - result = itemgetter(*batch_indexes)(data[fname]) + field_data = data[fname] + + # For nested tensors, itemgetter with multiple indexes is extremely slow + # because it requires repeated indexing operations. Unbinding first and then + # using itemgetter on the list is much faster + if isinstance(field_data, torch.Tensor) and field_data.layout == torch.jagged: + field_list = field_data.unbind() + result = itemgetter(*batch_indexes)(field_list) + else: + result = itemgetter(*batch_indexes)(field_data) + if not isinstance(result, tuple): result = (result,) results[fname] = list(result) From 0e1462a3508ab335ef0d51d7a474321c97677855 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Feb 2026 11:06:16 +0800 Subject: [PATCH 2/7] optimize for KVStorageManager Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 2376fb4..cdbc06a 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -394,8 +394,16 @@ def _generate_values(data: TensorDict) -> list[Tensor]: list[Tensor]: Flattened list of tensors, e.g., [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] """ - # TODO: unbind jagged tensor first - return [row_data for field in sorted(data.keys()) for row_data in data[field]] + results: list[Tensor] = [] + for field in sorted(data.keys()): + field_data = data[field] + # For jagged tensors, iterate over unbind() list (views, not copies) + # This is much faster than direct iteration over nested tensor + if isinstance(field_data, Tensor) and field_data.layout == torch.jagged: + results.extend(field_data.unbind()) + else: + results.extend(field_data) + return results @staticmethod def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None: From 3f32bea56216b43965c75c6ec03624c2213ccdca Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Feb 2026 11:28:26 +0800 Subject: [PATCH 3/7] fix review comments Signed-off-by: 0oshowero0 --- .../managers/simple_backend_manager.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index a79be63..5198162 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -27,6 +27,7 @@ import zmq from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict +from torch import Tensor from transfer_queue.metadata import BatchMeta from transfer_queue.storage.managers.base import TransferQueueStorageManager @@ -201,10 +202,23 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping ) + # unbind jagged tensor + results: dict = {} + for field in sorted(data.keys()): + field_data = data[field] + + # For jagged tensors, unbind() first to accelerate indexing process + if isinstance(field_data, Tensor) and field_data.layout == torch.jagged: + results[field] = field_data.unbind() + else: + results[field] = field_data + # send data to each storage unit tasks = [ self._put_to_single_storage_unit( - meta_group.get_local_indexes(), _filter_storage_data(meta_group, data), target_storage_unit=storage_id + meta_group.get_local_indexes(), + _filter_storage_data(meta_group, results), + target_storage_unit=storage_id, ) for storage_id, meta_group in storage_meta_groups.items() ] @@ -221,8 +235,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: per_field_shapes[global_idx] = {} # For each field, extract dtype and shape for each sample - for field in data.keys(): - for i, data_item in enumerate(data[field]): + for field in results.keys(): + for i, data_item in enumerate(results[field]): global_idx = metadata.global_indexes[i] per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None @@ -234,7 +248,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: # notify controller that new data is ready await self.notify_data_update( - partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes + partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes ) @dynamic_storage_manager_socket(socket_name="put_get_socket") @@ -432,11 +446,11 @@ def close(self) -> None: super().close() -def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) -> dict[str, Any]: +def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: dict) -> dict[str, Any]: """Filter batch-aligned data from a TensorDict using batch indexes from a StorageMetaGroup. This helper extracts a subset of items from each field in ``data`` according to the batch indexes stored in ``storage_meta_group``. The same indexes are applied to every - field in the input ``TensorDict`` so that the returned samples remain aligned across + field in the input dict so that the returned samples remain aligned across fields. Args: @@ -444,8 +458,8 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) a sequence of batch indexes via :meth:`get_batch_indexes`. Each index refers to a position along the batch dimension of the tensors stored in ``data``. - data: A :class:`tensordict.TensorDict` containing batched data fields. All - fields are expected to be indexable by the batch indexes returned by + data: A dict containing batched data fields. All fields are expected to + be indexable by the batch indexes returned by ``storage_meta_group.get_batch_indexes()``. Returns: dict[str, Any]: A dictionary mapping each field name in ``data`` to a list @@ -462,15 +476,7 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) for fname in data.keys(): field_data = data[fname] - - # For nested tensors, itemgetter with multiple indexes is extremely slow - # because it requires repeated indexing operations. Unbinding first and then - # using itemgetter on the list is much faster - if isinstance(field_data, torch.Tensor) and field_data.layout == torch.jagged: - field_list = field_data.unbind() - result = itemgetter(*batch_indexes)(field_list) - else: - result = itemgetter(*batch_indexes)(field_data) + result = itemgetter(*batch_indexes)(field_data) if not isinstance(result, tuple): result = (result,) From c735a44d4e82bc1971020b1f7edd69b1e9d0d2a8 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Feb 2026 11:54:32 +0800 Subject: [PATCH 4/7] fix Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 4 +--- transfer_queue/storage/managers/simple_backend_manager.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index cdbc06a..2e9ce30 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -397,9 +397,7 @@ def _generate_values(data: TensorDict) -> list[Tensor]: results: list[Tensor] = [] for field in sorted(data.keys()): field_data = data[field] - # For jagged tensors, iterate over unbind() list (views, not copies) - # This is much faster than direct iteration over nested tensor - if isinstance(field_data, Tensor) and field_data.layout == torch.jagged: + if isinstance(field_data, Tensor) and field_data.is_nested: results.extend(field_data.unbind()) else: results.extend(field_data) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 5198162..875c8d0 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -447,7 +447,7 @@ def close(self) -> None: def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: dict) -> dict[str, Any]: - """Filter batch-aligned data from a TensorDict using batch indexes from a StorageMetaGroup. + """Filter batch-aligned data from a dict using batch indexes from a StorageMetaGroup. This helper extracts a subset of items from each field in ``data`` according to the batch indexes stored in ``storage_meta_group``. The same indexes are applied to every field in the input dict so that the returned samples remain aligned across From 4844b596eadf800ad363b235b79e245b7a6a20ac Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Feb 2026 16:19:50 +0800 Subject: [PATCH 5/7] fix Signed-off-by: 0oshowero0 --- transfer_queue/metadata.py | 13 ++++++++++++- .../storage/managers/simple_backend_manager.py | 6 ++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 18a6014..2e35447 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -26,6 +26,7 @@ import torch from tensordict import TensorDict from tensordict.tensorclass import NonTensorData, NonTensorStack +from torch import Tensor from transfer_queue.utils.enum_utils import ProductionStatus @@ -815,6 +816,15 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> production_status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED + # unbind nested tensor + results: dict = {} + for field in sorted(tensor_dict.keys()): + field_data = tensor_dict[field] + if isinstance(field_data, Tensor) and field_data.is_nested: + results[field] = field_data.unbind() + else: + results[field] = field_data + all_fields = [ { name: FieldMeta( @@ -823,7 +833,8 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> shape=getattr(value, "shape", None), production_status=production_status, ) - for name, value in tensor_dict[idx].items() + for name in results.keys() + for value in results[name][idx] } for idx in range(batch_size) ] diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 875c8d0..9bea43e 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -202,13 +202,11 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping ) - # unbind jagged tensor + # unbind nested tensor results: dict = {} for field in sorted(data.keys()): field_data = data[field] - - # For jagged tensors, unbind() first to accelerate indexing process - if isinstance(field_data, Tensor) and field_data.layout == torch.jagged: + if isinstance(field_data, Tensor) and field_data.is_nested: results[field] = field_data.unbind() else: results[field] = field_data From 0e85e0debba9179701d561ebde6847bef0f67888 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Feb 2026 16:31:46 +0800 Subject: [PATCH 6/7] fix Signed-off-by: 0oshowero0 --- tutorial/04_understanding_controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorial/04_understanding_controller.py b/tutorial/04_understanding_controller.py index 746de14..73a0874 100644 --- a/tutorial/04_understanding_controller.py +++ b/tutorial/04_understanding_controller.py @@ -69,7 +69,7 @@ def demonstrate_partition_isolation(): train_data = TensorDict( { "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), - "labels": torch.tensor([0, 1]), + "labels": torch.tensor([[0], [1]]), }, batch_size=2, ) @@ -81,7 +81,7 @@ def demonstrate_partition_isolation(): val_data = TensorDict( { "input_ids": torch.tensor([[7, 8, 9], [10, 11, 12]]), - "labels": torch.tensor([2, 3]), + "labels": torch.tensor([[2], [3]]), }, batch_size=2, ) From 1a09d4c5544cc59e968260f2a7346b8bc9c40e12 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Feb 2026 17:36:48 +0800 Subject: [PATCH 7/7] fix Signed-off-by: 0oshowero0 --- transfer_queue/metadata.py | 24 +++++++++---------- .../managers/simple_backend_manager.py | 4 ++-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 2e35447..64134bf 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -818,26 +818,24 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> # unbind nested tensor results: dict = {} - for field in sorted(tensor_dict.keys()): + for field in tensor_dict.keys(): field_data = tensor_dict[field] - if isinstance(field_data, Tensor) and field_data.is_nested: + if batch_size > 1 and isinstance(field_data, Tensor) and field_data.is_nested: results[field] = field_data.unbind() else: results[field] = field_data - all_fields = [ - { - name: FieldMeta( - name=name, - dtype=getattr(value, "dtype", None), - shape=getattr(value, "shape", None), + all_fields = [] + for idx in range(batch_size): + dict_of_field_meta = {} + for field_name in results.keys(): + dict_of_field_meta[field_name] = FieldMeta( + name=field_name, + dtype=getattr(results[field_name][idx], "dtype", None), + shape=getattr(results[field_name][idx], "shape", None), production_status=production_status, ) - for name in results.keys() - for value in results[name][idx] - } - for idx in range(batch_size) - ] + all_fields.append(dict_of_field_meta) return all_fields diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 9bea43e..b658ba4 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -204,9 +204,9 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: # unbind nested tensor results: dict = {} - for field in sorted(data.keys()): + for field in data.keys(): field_data = data[field] - if isinstance(field_data, Tensor) and field_data.is_nested: + if data.batch_size[0] > 1 and isinstance(field_data, Tensor) and field_data.is_nested: results[field] = field_data.unbind() else: results[field] = field_data