Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData, NonTensorStack
from torch import Tensor

from transfer_queue.utils.enum_utils import ProductionStatus

Expand Down Expand Up @@ -815,18 +816,26 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) ->

production_status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED

all_fields = [
{
name: FieldMeta(
name=name,
dtype=getattr(value, "dtype", None),
shape=getattr(value, "shape", None),
# unbind nested tensor
results: dict = {}
for field in tensor_dict.keys():
field_data = tensor_dict[field]
if batch_size > 1 and isinstance(field_data, Tensor) and field_data.is_nested:
results[field] = field_data.unbind()
else:
results[field] = field_data

all_fields = []
for idx in range(batch_size):
dict_of_field_meta = {}
for field_name in results.keys():
dict_of_field_meta[field_name] = FieldMeta(
name=field_name,
dtype=getattr(results[field_name][idx], "dtype", None),
shape=getattr(results[field_name][idx], "shape", None),
production_status=production_status,
)
for name, value in tensor_dict[idx].items()
}
for idx in range(batch_size)
]
all_fields.append(dict_of_field_meta)

return all_fields

Expand Down
9 changes: 8 additions & 1 deletion transfer_queue/storage/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,14 @@ def _generate_values(data: TensorDict) -> list[Tensor]:
list[Tensor]: Flattened list of tensors, e.g.,
[data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...]
"""
return [row_data for field in sorted(data.keys()) for row_data in data[field]]
results: list[Tensor] = []
for field in sorted(data.keys()):
field_data = data[field]
if isinstance(field_data, Tensor) and field_data.is_nested:
results.extend(field_data.unbind())
else:
results.extend(field_data)
return results

@staticmethod
def _shutdown_executor(thread_executor: Optional[ThreadPoolExecutor]) -> None:
Expand Down
34 changes: 24 additions & 10 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import zmq
from omegaconf import DictConfig
from tensordict import NonTensorStack, TensorDict
from torch import Tensor

from transfer_queue.metadata import BatchMeta
from transfer_queue.storage.managers.base import TransferQueueStorageManager
Expand Down Expand Up @@ -201,10 +202,21 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
)

# unbind nested tensor
results: dict = {}
for field in data.keys():
field_data = data[field]
if data.batch_size[0] > 1 and isinstance(field_data, Tensor) and field_data.is_nested:
results[field] = field_data.unbind()
else:
results[field] = field_data

# send data to each storage unit
tasks = [
self._put_to_single_storage_unit(
meta_group.get_local_indexes(), _filter_storage_data(meta_group, data), target_storage_unit=storage_id
meta_group.get_local_indexes(),
_filter_storage_data(meta_group, results),
target_storage_unit=storage_id,
)
for storage_id, meta_group in storage_meta_groups.items()
]
Expand All @@ -221,8 +233,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
per_field_shapes[global_idx] = {}

# For each field, extract dtype and shape for each sample
for field in data.keys():
for i, data_item in enumerate(data[field]):
for field in results.keys():
for i, data_item in enumerate(results[field]):
global_idx = metadata.global_indexes[i]
per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None
Expand All @@ -234,7 +246,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:

# notify controller that new data is ready
await self.notify_data_update(
partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
)

@dynamic_storage_manager_socket(socket_name="put_get_socket")
Expand Down Expand Up @@ -432,20 +444,20 @@ def close(self) -> None:
super().close()


def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict) -> dict[str, Any]:
"""Filter batch-aligned data from a TensorDict using batch indexes from a StorageMetaGroup.
def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: dict) -> dict[str, Any]:
"""Filter batch-aligned data from a dict using batch indexes from a StorageMetaGroup.
This helper extracts a subset of items from each field in ``data`` according to the
batch indexes stored in ``storage_meta_group``. The same indexes are applied to every
field in the input ``TensorDict`` so that the returned samples remain aligned across
field in the input dict so that the returned samples remain aligned across
fields.

Args:
storage_meta_group: A :class:`StorageMetaGroup` instance that provides
a sequence of batch indexes via :meth:`get_batch_indexes`. Each index
refers to a position along the batch dimension of the tensors stored
in ``data``.
data: A :class:`tensordict.TensorDict` containing batched data fields. All
fields are expected to be indexable by the batch indexes returned by
data: A dict containing batched data fields. All fields are expected to
be indexable by the batch indexes returned by
``storage_meta_group.get_batch_indexes()``.
Returns:
dict[str, Any]: A dictionary mapping each field name in ``data`` to a list
Expand All @@ -461,7 +473,9 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict)
return results

for fname in data.keys():
result = itemgetter(*batch_indexes)(data[fname])
field_data = data[fname]
result = itemgetter(*batch_indexes)(field_data)

if not isinstance(result, tuple):
result = (result,)
results[fname] = list(result)
Expand Down
4 changes: 2 additions & 2 deletions tutorial/04_understanding_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def demonstrate_partition_isolation():
train_data = TensorDict(
{
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
"labels": torch.tensor([0, 1]),
"labels": torch.tensor([[0], [1]]),
},
batch_size=2,
)
Expand All @@ -81,7 +81,7 @@ def demonstrate_partition_isolation():
val_data = TensorDict(
{
"input_ids": torch.tensor([[7, 8, 9], [10, 11, 12]]),
"labels": torch.tensor([2, 3]),
"labels": torch.tensor([[2], [3]]),
},
batch_size=2,
)
Expand Down