diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index 1700d55..6b2afb5 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -33,13 +33,11 @@ parent_dir = Path(__file__).resolve().parent.parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - AsyncTransferQueueClient, - SimpleStorageUnit, - TransferQueueController, - process_zmq_server_info, -) +from transfer_queue import TransferQueueClient # noqa: E402 +from transfer_queue.controller import TransferQueueController # noqa: E402 +from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402 from transfer_queue.utils.common import get_placement_group # noqa: E402 +from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -309,7 +307,7 @@ def initialize_system(self, config_dict): self.tq_config = OmegaConf.merge(tq_internal_conf, self.tq_config) # Client Init - self.data_system_client = AsyncTransferQueueClient( + self.data_system_client = TransferQueueClient( client_id="Trainer", controller_info=self.data_system_controller_info ) self.data_system_client.initialize_storage_manager( diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 4e7025f..a8afaaf 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -420,15 +420,20 @@ def test_cross_shard_complex_update(e2e_client): "Region 30-39 tensor_f32 should match original Put B" ) - # 9. Verify new fields exist in update region - extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"] - update_region_meta = poll_for_meta( - client, partition_id, extended_fields, 20, "update_region_task", mode="force_fetch" + # 9. Verify new fields exist in update region (indices 10-29 only have new fields). + # Build extended_meta from full_meta (which has valid _custom_backend_meta/_su_id) + # by selecting the subset of samples whose global_indexes match meta_update. + # Using meta_update directly would fail because it was derived from alloc_meta + # before put(), so its _custom_backend_meta lacks _su_id. + update_gis = set(meta_update.global_indexes) + update_positions_in_full = [i for i, gi in enumerate(full_meta.global_indexes) if gi in update_gis] + update_meta_with_backend = full_meta.select_samples(update_positions_in_full) + extended_meta = update_meta_with_backend.with_data_fields( + base_fields + ["new_extra_tensor", "new_extra_non_tensor"] ) - if update_region_meta is not None and update_region_meta.size > 0: - update_region_data = client.get_data(update_region_meta) - assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist" - assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist" + update_region_data = client.get_data(extended_meta) + assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist" + assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist" finally: client.clear_partition(partition_id) diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 5187e16..0a45653 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -17,6 +17,7 @@ from pathlib import Path from unittest.mock import AsyncMock, Mock, patch +import numpy as np import pytest import pytest_asyncio import torch @@ -27,7 +28,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 from transfer_queue.storage import AsyncSimpleStorageManager # noqa: E402 from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402 from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo # noqa: E402 @@ -79,11 +80,6 @@ async def mock_async_storage_manager(): manager.controller_handshake_socket = None manager.zmq_context = None - # Add mapping functions - storage_unit_keys = list(storage_unit_infos.keys()) - manager.global_index_storage_unit_mapping = lambda x: storage_unit_keys[x % len(storage_unit_keys)] - manager.global_index_local_index_mapping = lambda x: x // len(storage_unit_keys) - # Mock essential methods manager._connect_to_controller = mock_connect @@ -100,41 +96,35 @@ async def test_async_storage_manager_initialization(mock_async_storage_manager): assert "storage_0" in manager.storage_unit_infos assert "storage_1" in manager.storage_unit_infos - # Test mapping functions - assert manager.global_index_storage_unit_mapping(0) == "storage_0" - assert manager.global_index_storage_unit_mapping(1) == "storage_1" - assert manager.global_index_local_index_mapping(0) == 0 - assert manager.global_index_local_index_mapping(3) == 1 - @pytest.mark.asyncio async def test_async_storage_manager_mock_operations(mock_async_storage_manager): """Test AsyncSimpleStorageManager operations with mocked ZMQ.""" manager = mock_async_storage_manager - # Create test metadata - sample_metas = [ - SampleMeta( - partition_id="0", - global_index=0, - fields={ - "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)), - }, - ), - SampleMeta( - partition_id="0", - global_index=1, - fields={ - "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)), - }, - ), - ] - batch_meta = BatchMeta(samples=sample_metas) + # Create test metadata using columnar API + batch_meta = BatchMeta( + global_indexes=[0, 1], + partition_ids=["0", "0"], + field_schema={ + "test_field": { + "dtype": torch.float32, + "shape": (2,), + "is_nested": False, + "is_non_tensor": False, + } + }, + production_status=np.ones(2, dtype=np.int8), + _custom_backend_meta={ + 0: {"_su_id": "storage_0"}, + 1: {"_su_id": "storage_1"}, + }, + ) # Create test data test_data = TensorDict( { - "test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], + "test_field": torch.stack([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]), }, batch_size=2, ) @@ -163,92 +153,6 @@ async def test_async_storage_manager_mock_operations(mock_async_storage_manager) await manager.clear_data(batch_meta) -@pytest.mark.asyncio -async def test_async_storage_manager_mapping_functions(): - """Test AsyncSimpleStorageManager mapping functions.""" - - # Mock storage unit infos - storage_unit_infos = { - "storage_0": ZMQServerInfo( - role=TransferQueueRole.STORAGE, - id="storage_0", - ip="127.0.0.1", - ports={"put_get_socket": 12345}, - ), - "storage_1": ZMQServerInfo( - role=TransferQueueRole.STORAGE, - id="storage_1", - ip="127.0.0.1", - ports={"put_get_socket": 12346}, - ), - "storage_2": ZMQServerInfo( - role=TransferQueueRole.STORAGE, - id="storage_2", - ip="127.0.0.1", - ports={"put_get_socket": 12347}, - ), - } - - # Mock controller info - controller_info = ZMQServerInfo( - role=TransferQueueRole.CONTROLLER, - id="controller_0", - ip="127.0.0.1", - ports={"handshake_socket": 12348, "data_status_update_socket": 12349}, - ) - - config = { - "zmq_info": storage_unit_infos, - } - - # Mock ZMQ operations - with ( - patch("transfer_queue.storage.managers.base.create_zmq_socket") as mock_create_socket, - patch("zmq.Poller") as mock_poller, - ): - # Create mock socket with proper sync methods - mock_socket = Mock() - mock_socket.connect = Mock() # sync method - mock_socket.send = Mock() # sync method - mock_create_socket.return_value = mock_socket - - # Mock poller with sync methods - mock_poller_instance = Mock() - mock_poller_instance.register = Mock() # sync method - # Return mock socket in poll to simulate handshake response - mock_poller_instance.poll = Mock(return_value=[(mock_socket, zmq.POLLIN)]) # sync method - mock_poller.return_value = mock_poller_instance - - # Mock handshake response - handshake_response = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, - sender_id="controller_0", - body={"message": "Handshake successful"}, - ) - mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) - - # Create manager - manager = AsyncSimpleStorageManager(controller_info, config) - - # Test round-robin mapping for 3 storage units - # global_index -> storage_unit mapping: 0->storage_0, 1->storage_1, 2->storage_2, - # 3->storage_0, 4->storage_1, ... - assert manager.global_index_storage_unit_mapping(0) == "storage_0" - assert manager.global_index_storage_unit_mapping(1) == "storage_1" - assert manager.global_index_storage_unit_mapping(2) == "storage_2" - assert manager.global_index_storage_unit_mapping(3) == "storage_0" - assert manager.global_index_storage_unit_mapping(4) == "storage_1" - assert manager.global_index_storage_unit_mapping(5) == "storage_2" - - # global_index -> local_index mapping: global_index // num_storage_units - assert manager.global_index_local_index_mapping(0) == 0 - assert manager.global_index_local_index_mapping(1) == 0 - assert manager.global_index_local_index_mapping(2) == 0 - assert manager.global_index_local_index_mapping(3) == 1 - assert manager.global_index_local_index_mapping(4) == 1 - assert manager.global_index_local_index_mapping(5) == 1 - - @pytest.mark.asyncio async def test_async_storage_manager_error_handling(): """Test AsyncSimpleStorageManager error handling.""" @@ -310,22 +214,26 @@ async def test_async_storage_manager_error_handling(): manager._clear_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock CLEAR error")) manager.notify_data_update = AsyncMock() - # Create test metadata - sample_metas = [ - SampleMeta( - partition_id="0", - global_index=0, - fields={ - "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)), - }, - ), - ] - batch_meta = BatchMeta(samples=sample_metas) + # Create test metadata using columnar API + batch_meta = BatchMeta( + global_indexes=[0], + partition_ids=["0"], + field_schema={ + "test_field": { + "dtype": torch.float32, + "shape": (2,), + "is_nested": False, + "is_non_tensor": False, + } + }, + production_status=np.ones(1, dtype=np.int8), + _custom_backend_meta={0: {"_su_id": "storage_0"}}, + ) # Create test data test_data = TensorDict( { - "test_field": [torch.tensor([1.0, 2.0])], + "test_field": torch.tensor([[1.0, 2.0]]), }, batch_size=1, ) @@ -340,3 +248,192 @@ async def test_async_storage_manager_error_handling(): # Note: clear_data uses return_exceptions=True, so it doesn't raise exceptions directly # Instead, we can verify that the clear operation was attempted await manager.clear_data(batch_meta) # Should not raise due to return_exceptions=True + + +@pytest.mark.asyncio +async def test_put_data_notifies_su_id(): + """put_data 调用 notify_data_update 时必须传入 custom_backend_meta 含 _su_id.""" + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19000}, + ), + "storage_1": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_1", + ip="127.0.0.1", + ports={"put_get_socket": 19001}, + ), + } + + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_manager" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + manager._put_to_single_storage_unit = AsyncMock() + notify_mock = AsyncMock() + manager.notify_data_update = notify_mock + + batch_meta = BatchMeta( + global_indexes=[0, 1, 2, 3], + partition_ids=["p0"] * 4, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(4, dtype=np.int8), + ) + data = TensorDict({"f": torch.randn(4, 2)}, batch_size=4) + + await manager.put_data(data, batch_meta) + + # notify_data_update 必须被调用且 custom_backend_meta 不为 None + notify_mock.assert_awaited_once() + call_kwargs = notify_mock.call_args + custom_backend_meta = call_kwargs.kwargs.get("custom_backend_meta") or ( + call_kwargs.args[-1] if call_kwargs.args else None + ) + assert custom_backend_meta is not None, "custom_backend_meta 未传入 notify_data_update" + # 每个 gi 都应有 _su_id + for gi in [0, 1, 2, 3]: + assert gi in custom_backend_meta, f"gi={gi} 不在 custom_backend_meta" + assert "_su_id" in custom_backend_meta[gi], f"gi={gi} 缺少 _su_id" + assert custom_backend_meta[gi]["_su_id"] in storage_unit_infos + + +@pytest.mark.asyncio +async def test_put_data_no_batch_counter(): + """put_data 不应存在 _batch_counter 属性(已删除).""" + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19002}, + ), + } + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_manager_2" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + assert not hasattr(manager, "_batch_counter"), "_batch_counter 应已删除" + + +@pytest.mark.asyncio +async def test_get_data_routes_from_custom_backend_meta(): + """get_data 应从 metadata._custom_backend_meta 读取 _su_id 做路由,不重算.""" + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19010}, + ), + "storage_1": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_1", + ip="127.0.0.1", + ports={"put_get_socket": 19011}, + ), + } + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_get" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + # gi=0,1 -> storage_0; gi=2,3 -> storage_1(通过 _custom_backend_meta 指定) + batch_meta = BatchMeta( + global_indexes=[0, 1, 2, 3], + partition_ids=["p0"] * 4, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(4, dtype=np.int8), + _custom_backend_meta={ + 0: {"_su_id": "storage_0"}, + 1: {"_su_id": "storage_0"}, + 2: {"_su_id": "storage_1"}, + 3: {"_su_id": "storage_1"}, + }, + ) + + # Mock _get_from_single_storage_unit to record which su_id and gi were requested + called_with: dict[str, list] = {} + + async def fake_get(gi_list, fields, target_storage_unit=None, **kwargs): + su = target_storage_unit + called_with[su] = list(gi_list) + tensors = [torch.zeros(2) for _ in gi_list] + return gi_list, fields, {"f": tensors}, b"" + + manager._get_from_single_storage_unit = fake_get + + await manager.get_data(batch_meta) + + assert "storage_0" in called_with, "storage_0 未被 get 调用" + assert "storage_1" in called_with, "storage_1 未被 get 调用" + assert set(called_with["storage_0"]) == {0, 1} + assert set(called_with["storage_1"]) == {2, 3} + + +@pytest.mark.asyncio +async def test_clear_data_routes_from_custom_backend_meta(): + """clear_data 应从 metadata._custom_backend_meta 读取 _su_id 做路由.""" + storage_unit_infos = { + "storage_0": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_0", + ip="127.0.0.1", + ports={"put_get_socket": 19020}, + ), + "storage_1": ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="storage_1", + ip="127.0.0.1", + ports={"put_get_socket": 19021}, + ), + } + with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"): + manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager) + manager.storage_manager_id = "test_clear" + manager.storage_unit_infos = storage_unit_infos + manager.controller_info = None + manager.data_status_update_socket = None + manager.controller_handshake_socket = None + manager.zmq_context = None + + batch_meta = BatchMeta( + global_indexes=[0, 1, 2, 3], + partition_ids=["p0"] * 4, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(4, dtype=np.int8), + _custom_backend_meta={ + 0: {"_su_id": "storage_0"}, + 1: {"_su_id": "storage_0"}, + 2: {"_su_id": "storage_1"}, + 3: {"_su_id": "storage_1"}, + }, + ) + + called_with: dict[str, list] = {} + + async def fake_clear(gi_list, target_storage_unit=None, **kwargs): + called_with[target_storage_unit] = list(gi_list) + + manager._clear_single_storage_unit = fake_clear + + await manager.clear_data(batch_meta) + + assert set(called_with.get("storage_0", [])) == {0, 1} + assert set(called_with.get("storage_1", [])) == {2, 3} diff --git a/tests/test_client.py b/tests/test_client.py index 5d308d8..adf8127 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -31,10 +31,8 @@ from transfer_queue import TransferQueueClient # noqa: E402 from transfer_queue.metadata import ( # noqa: E402 BatchMeta, - FieldMeta, - SampleMeta, ) -from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole # noqa: E402 +from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402 from transfer_queue.utils.zmq_utils import ( # noqa: E402 ZMQMessage, ZMQRequestType, @@ -172,24 +170,16 @@ def _mock_batch_meta(self, request_body): batch_size = request_body.get("batch_size", 1) data_fields = request_body.get("data_fields", []) - samples = [] - for i in range(batch_size): - fields = [] - for field_name in data_fields: - field_meta = FieldMeta( - name=field_name, - dtype=None, - shape=None, - production_status=ProductionStatus.NOT_PRODUCED, - ) - fields.append(field_meta) - sample = SampleMeta( - partition_id="0", - global_index=i, - fields={field.name: field for field in fields}, - ) - samples.append(sample) - metadata = BatchMeta(samples=samples) + # Build columnar field_schema + field_schema = { + fname: {"dtype": None, "shape": None, "is_nested": False, "is_non_tensor": False} for fname in data_fields + } + + metadata = BatchMeta( + global_indexes=list(range(batch_size)), + partition_ids=["0"] * batch_size, + field_schema=field_schema, + ) return {"metadata": metadata} @@ -199,39 +189,31 @@ def _mock_kv_retrieve_keys(self, request_body): create = request_body.get("create", False) partition_id = request_body.get("partition_id", "") - # Initialize key tracking if not exists if not hasattr(self, "_kv_partition_keys"): self._kv_partition_keys = {} - # Generate global indexes for the keys start_index = self._get_next_kv_index(partition_id) global_indexes = list(range(start_index, start_index + len(keys))) - # Create metadata for each key - samples = [] - for i, key in enumerate(keys): - field_meta = FieldMeta( - name="data", - dtype=torch.float32, - shape=torch.Size([1, 10]), - production_status=ProductionStatus.READY_FOR_CONSUME, - ) - sample = SampleMeta( - partition_id=partition_id, - global_index=global_indexes[i], - fields={"data": field_meta}, - ) - samples.append(sample) - - metadata = BatchMeta(samples=samples) + # Build columnar BatchMeta for KV interface + field_schema = { + "data": {"dtype": "torch.float32", "shape": [1, 10], "is_nested": False, "is_non_tensor": False} + } + import numpy as np + + production_status = np.ones(len(global_indexes), dtype=np.int8) + metadata = BatchMeta( + global_indexes=global_indexes, + partition_ids=[partition_id] * len(global_indexes), + field_schema=field_schema, + production_status=production_status, + ) - # Store keys for this partition (only when create=True) if create: if partition_id not in self._kv_partition_keys: self._kv_partition_keys[partition_id] = [] self._kv_partition_keys[partition_id].extend(keys) - # Update the next index for this partition if global_indexes: self._update_kv_index(partition_id, global_indexes[-1] + 1) @@ -808,7 +790,7 @@ async def test_async_clear_samples_with_empty_metadata(client_setup): client, _, _ = client_setup # Create empty BatchMeta - metadata = BatchMeta(samples=[]) + metadata = BatchMeta(global_indexes=[], partition_ids=[], field_schema={}) # The clear operation should complete without raising an exception # because the mock storage manager is configured to handle this diff --git a/tests/test_controller.py b/tests/test_controller.py index 3528d1a..e0220e2 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) from transfer_queue.controller import TransferQueueController # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @pytest.fixture(scope="function") @@ -67,13 +66,9 @@ def test_controller_with_single_partition(self, ray_setup): ) assert metadata.global_indexes == list(range(gbs * num_n_samples)) - assert metadata.samples[0].partition_id == "train_0" - assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) - assert sum([int(sample.fields.get("attention_mask").production_status) for sample in metadata.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) + assert metadata.partition_ids[0] == "train_0" + # In insert mode, production_status should be all zeros (NOT_PRODUCED) + assert metadata.production_status is not None and all(metadata.production_status == 0) partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id)) assert partition_index_range == list(range(gbs * num_n_samples)) @@ -158,7 +153,7 @@ def test_controller_with_single_partition(self, ray_setup): ) assert gen_meta.global_indexes == list(range(gbs * num_n_samples)) - assert gen_meta.samples[0].partition_id == "train_0" + assert gen_meta.partition_ids[0] == "train_0" assert gen_meta.field_names == ["prompt_ids"] partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples)) @@ -187,7 +182,8 @@ def test_controller_with_single_partition(self, ray_setup): ) ) assert clear_meta.global_indexes == list(range(gbs * num_n_samples)) - assert [sample.fields for sample in clear_meta.samples] == [{}] * (gbs * num_n_samples) + # In insert mode with no fields, field_schema should be empty + assert clear_meta.field_schema == {} or clear_meta.field_names == [] print("✓ Clear metadata correct") # Test clear_partition @@ -456,13 +452,9 @@ def test_controller_with_multi_partitions(self, ray_setup): part1_index_range = gbs_1 * num_n_samples_1 part2_index_range = gbs_2 * num_n_samples_2 assert val_metadata.global_indexes == list(range(part1_index_range, part2_index_range + part1_index_range)) - assert val_metadata.samples[0].partition_id == "val_0" - assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in val_metadata.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) - assert sum( - [int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples] - ) == int(ProductionStatus.NOT_PRODUCED) + assert val_metadata.partition_ids[0] == "val_0" + # In insert mode, production_status should be all zeros (NOT_PRODUCED) + assert val_metadata.production_status is not None and all(val_metadata.production_status == 0) partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2)) assert partition_index_range == list(range(part1_index_range, part2_index_range + part1_index_range)) @@ -536,13 +528,9 @@ def test_controller_with_multi_partitions(self, ray_setup): ) ) assert metadata_2.global_indexes == list(range(32)) + list(range(48, 80)) - assert metadata_2.samples[0].partition_id == "train_1" - assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata_2.samples]) == int( - ProductionStatus.NOT_PRODUCED - ) - assert sum( - [int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples] - ) == int(ProductionStatus.NOT_PRODUCED) + assert metadata_2.partition_ids[0] == "train_1" + # In insert mode, production_status should be all zeros (NOT_PRODUCED) + assert metadata_2.production_status is not None and all(metadata_2.production_status == 0) partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3)) assert partition_index_range == list(range(32)) + list(range(48, 80)) print("✓ Correctly assign partition_3") @@ -884,12 +872,9 @@ def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup): tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) ) - # Verify production status is available - assert len(retrieved_metadata.samples) == len(keys) - for sample in retrieved_metadata.samples: - assert "data" in sample.fields - assert sample.fields["data"].dtype == "torch.float32" - assert sample.fields["data"].shape == (64,) + # Verify production status is available (columnar API) + assert len(retrieved_metadata.global_indexes) == len(keys) + assert "data" in retrieved_metadata.field_schema print("✓ kv_retrieve_keys works with production status") diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index ec5b29c..6320f23 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -25,33 +25,35 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 def get_meta(data, global_indexes=None): if not global_indexes: - global_indexes = range(data.batch_size[0]) - samples = [] - for sample_id in range(data.batch_size[0]): - fields_dict = {} - for field_name in data.keys(): - tensor = data[field_name][sample_id] - field_meta = FieldMeta( - name=field_name, - dtype=tensor.dtype if isinstance(tensor, torch.Tensor) else None, - shape=tensor.shape if isinstance(tensor, torch.Tensor) else None, - production_status=ProductionStatus.READY_FOR_CONSUME, - ) - fields_dict[field_name] = field_meta - sample = SampleMeta( - partition_id=0, - global_index=global_indexes[sample_id], - fields=fields_dict, - ) - samples.append(sample) - metadata = BatchMeta(samples=samples) + global_indexes = list(range(data.batch_size[0])) + + # Build columnar field_schema from the data + field_schema = {} + for field_name in data.keys(): + tensor = data[field_name][0] + field_schema[field_name] = { + "dtype": tensor.dtype if isinstance(tensor, torch.Tensor) else type(tensor), + "shape": tensor.shape if isinstance(tensor, torch.Tensor) else None, + "is_nested": False, + "is_non_tensor": not isinstance(tensor, torch.Tensor), + } + + import numpy as np + + production_status = np.ones(len(global_indexes), dtype=np.int8) + + metadata = BatchMeta( + global_indexes=list(global_indexes), + partition_ids=["0"] * len(global_indexes), + field_schema=field_schema, + production_status=production_status, + ) return metadata @@ -196,14 +198,13 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_backend_meta(tes def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_data): """Test _get_shape_type_custom_backend_meta_list returns correct custom_backend_meta when provided.""" - # Add custom_backend_meta to metadata - custom_backend_meta = { - 8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, - 9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, - 10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, - } + # Add custom_backend_meta to metadata (columnar: list aligned with global_indexes [8, 9, 10]) metadata = test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + metadata._custom_backend_meta = [ + {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # global_index=8 + {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, # global_index=9 + {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, # global_index=10 + ] shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) @@ -224,14 +225,13 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_d def test_get_shape_type_custom_backend_meta_list_with_partial_custom_backend_meta(test_data): """Test _get_shape_type_custom_backend_meta_list handles partial custom_backend_meta correctly.""" - # Add custom_backend_meta only for some global_indexes and fields - custom_backend_meta = { - 8: {"text": {"key1": "value1"}}, # Only text field - # global_index 9 has no custom_backend_meta - 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only - } + # Add custom_backend_meta only for some fields (columnar: list aligned with global_indexes [8, 9, 10]) metadata = test_data["metadata"] - metadata._custom_backend_meta.update(custom_backend_meta) + metadata._custom_backend_meta = [ + {"text": {"key1": "value1"}}, # global_index=8: only text field + {}, # global_index=9: no custom_backend_meta + {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # global_index=10: label and mask only + ] shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 2a129b5..7d88eba 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -13,1038 +13,199 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for TransferQueue metadata module - Learning Examples.""" +"""Unit tests for TransferQueue metadata module - Columnar BatchMeta + KVBatchMeta.""" import sys from pathlib import Path +import numpy as np import pytest import torch -from tensordict import TensorDict -from tensordict.tensorclass import NonTensorStack # Setup path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, KVBatchMeta, SampleMeta # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 +from transfer_queue.metadata import BatchMeta, KVBatchMeta # noqa: E402 +# ============================================================================== +# Columnar BatchMeta Tests +# ============================================================================== -class TestFieldMeta: - """FieldMeta learning examples.""" - def test_field_meta_is_ready(self): - """Test the is_ready property based on production status.""" - field_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME - ) - assert field_ready.is_ready is True - - field_not_ready = FieldMeta( - name="test_field", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.NOT_PRODUCED - ) - assert field_not_ready.is_ready is False - - -class TestSampleMeta: - """SampleMeta learning examples.""" - - def test_sample_meta_union(self): - """Example: Union fields from two samples with matching global indexes.""" - # Create first sample - fields1 = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1) +class TestBatchMetaColumnar: + """Columnar BatchMeta using field_schema + production_status (numpy array).""" - # Create second sample with additional fields - fields2 = { - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), + def _make_batch(self, batch_size=3, field_names=None): + """Helper: create a simple columnar BatchMeta.""" + if field_names is None: + field_names = ["field_a", "field_b"] + field_schema = { + fname: {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False} + for fname in field_names } - sample2 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields2) - - # Union samples - result = sample1.union(sample2) - - # Result contains all fields from both samples - assert "field1" in result.fields - assert "field2" in result.fields # From sample2 - assert "field3" in result.fields - - def test_sample_meta_union_validation_error(self): - """Example: Union validation catches mismatched global indexes.""" - sample1 = SampleMeta( - partition_id="partition_0", - global_index=0, - fields={"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))}, - ) - - sample2 = SampleMeta( - partition_id="partition_0", - global_index=1, # Different global index - fields={"field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,))}, + production_status = np.ones(batch_size, dtype=np.int8) + return BatchMeta( + global_indexes=list(range(batch_size)), + partition_ids=["partition_0"] * batch_size, + field_schema=field_schema, + production_status=production_status, ) - with pytest.raises(ValueError) as exc_info: - sample1.union(sample2, validate=True) - assert "Global indexes" in str(exc_info.value) - - def test_sample_meta_add_fields(self): - """Example: Add new fields to a sample.""" - initial_fields = { - "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=initial_fields) - - new_fields = { - "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(3,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - sample.add_fields(new_fields) - - assert "field1" in sample.fields - assert "field2" in sample.fields - assert sample.is_ready is True - - def test_sample_meta_select_fields(self): - """Example: Select specific fields from a sample.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - - # Select only field1 and field3 - selected_sample = sample.select_fields(["field1", "field3"]) - - assert "field1" in selected_sample.fields - assert "field3" in selected_sample.fields - assert "field2" not in selected_sample.fields - # Original sample is unchanged - assert len(sample.fields) == 3 - # Selected sample has correct metadata - assert selected_sample.fields["field1"].dtype == torch.float32 - assert selected_sample.fields["field1"].shape == (2,) - assert selected_sample.global_index == 0 - assert selected_sample.partition_id == "partition_0" - - def test_sample_meta_select_fields_with_nonexistent_fields(self): - """Example: Select fields ignores non-existent field names.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - - # Try to select a field that doesn't exist - selected_sample = sample.select_fields(["field1", "nonexistent_field"]) - - # Only existing field is selected - assert "field1" in selected_sample.fields - assert "nonexistent_field" not in selected_sample.fields - assert "field2" not in selected_sample.fields - - def test_sample_meta_select_fields_empty_list(self): - """Example: Select with empty field list returns sample with no fields.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - - # Select with empty list - selected_sample = sample.select_fields([]) - - assert len(selected_sample.fields) == 0 - assert selected_sample.global_index == 0 - assert selected_sample.partition_id == "partition_0" - - -class TestBatchMeta: - """BatchMeta learning examples - Core Operations.""" - - def test_batch_meta_chunk(self): - """Example: Split a batch into multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [SampleMeta(partition_id="partition_0", global_index=i, fields=fields) for i in range(10)] + def test_basic_init(self): + """Test basic columnar BatchMeta initialization.""" + batch = self._make_batch() + assert len(batch) == 3 + assert batch.global_indexes == [0, 1, 2] + assert batch.partition_ids == ["partition_0", "partition_0", "partition_0"] + assert "field_a" in batch.field_schema + assert "field_b" in batch.field_schema + assert batch.field_names == ["field_a", "field_b"] + + def test_production_status_vector(self): + """Test that production_status is accessible per sample.""" + batch = self._make_batch() + assert batch.production_status is not None + assert len(batch.production_status) == 3 + assert all(batch.production_status == 1) + + def test_chunk(self): + """Test splitting a batch into chunks.""" batch = BatchMeta( - samples=samples, - custom_meta={i: {"uid": i} for i in range(10)}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)}, + global_indexes=list(range(10)), + partition_ids=["partition_0"] * 10, + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(10, dtype=np.int8), + custom_meta=[{"uid": i} for i in range(10)], + _custom_backend_meta=[{"f": {"key": i}} for i in range(10)], ) - - # Chunk into 3 parts chunks = batch.chunk(3) - assert len(chunks) == 3 - assert len(chunks[0]) == 4 # First chunk gets extra element + # First chunk gets extra element (ceil division) + assert len(chunks[0]) == 4 assert len(chunks[1]) == 3 assert len(chunks[2]) == 3 - - # validate custom_meta is chunked - assert 0 in chunks[0].custom_meta - assert 1 in chunks[0].custom_meta - assert 2 in chunks[0].custom_meta - assert 3 in chunks[0].custom_meta - assert 4 not in chunks[0].custom_meta - assert 4 in chunks[1].custom_meta - - # validate _custom_backend_meta is chunked - assert 0 in chunks[0]._custom_backend_meta - assert 1 in chunks[0]._custom_backend_meta - assert 2 in chunks[0]._custom_backend_meta - assert 3 in chunks[0]._custom_backend_meta - assert 4 not in chunks[0]._custom_backend_meta - assert 4 in chunks[1]._custom_backend_meta - - def test_batch_meta_chunk_by_partition(self): - """Example: Split a batch into multiple chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields) for i in range(10)] + # custom_meta is chunked correctly (positional) + assert chunks[0].custom_meta[0] == {"uid": 0} + assert chunks[0].custom_meta[3] == {"uid": 3} + assert len(chunks[0].custom_meta) == 4 + assert chunks[1].custom_meta[0] == {"uid": 4} + + def test_chunk_by_partition(self): + """Test splitting by partition_id.""" batch = BatchMeta( - samples=samples, - custom_meta={i + 10: {"uid": i + 10} for i in range(10)}, - _custom_backend_meta={i + 10: {"test_field": {"dtype": torch.float32}} for i in range(10)}, + global_indexes=[10, 11, 12, 13], + partition_ids=["part_A", "part_B", "part_A", "part_B"], + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, ) - - # Chunk according to partition_id chunks = batch.chunk_by_partition() - - assert len(chunks) == 4 - assert len(chunks[0]) == 3 - assert chunks[0].partition_ids == ["partition_0", "partition_0", "partition_0"] - assert chunks[0].global_indexes == [10, 14, 18] - assert len(chunks[1]) == 3 - assert chunks[1].partition_ids == ["partition_1", "partition_1", "partition_1"] - assert chunks[1].global_indexes == [11, 15, 19] - assert len(chunks[2]) == 2 - assert chunks[2].partition_ids == ["partition_2", "partition_2"] - assert chunks[2].global_indexes == [12, 16] - assert len(chunks[3]) == 2 - assert chunks[3].partition_ids == ["partition_3", "partition_3"] - assert chunks[3].global_indexes == [13, 17] - - # validate custom_meta is chunked - assert 10 in chunks[0].custom_meta - assert 14 in chunks[0].custom_meta - assert 18 in chunks[0].custom_meta - assert 11 not in chunks[0].custom_meta - assert 11 in chunks[1].custom_meta - - # validate _custom_backend_meta is chunked - assert 10 in chunks[0]._custom_backend_meta - assert 14 in chunks[0]._custom_backend_meta - assert 18 in chunks[0]._custom_backend_meta - assert 11 not in chunks[0]._custom_backend_meta - assert 11 in chunks[1]._custom_backend_meta - - def test_batch_meta_init_validation_error_different_field_names(self): - """Example: Init validation catches samples with different field names.""" - # Create first sample with field1 - fields1 = {"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))} - sample1 = SampleMeta(partition_id="partition_0", global_index=0, fields=fields1) - - # Create second sample with field2 - fields2 = {"field2": FieldMeta(name="field2", dtype=torch.float32, shape=(2,))} - sample2 = SampleMeta(partition_id="partition_0", global_index=1, fields=fields2) - - # Attempt to create BatchMeta with samples having different field names - with pytest.raises(ValueError) as exc_info: - BatchMeta(samples=[sample1, sample2]) - assert "All samples in BatchMeta must have the same field_names." in str(exc_info.value) - - def test_batch_meta_concat(self): - """Example: Concatenate multiple batches.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - # Create two batches - batch1 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [0, 1]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [0, 1]}, - ) - + assert len(chunks) == 2 + part_ids = [c.partition_ids[0] for c in chunks] + assert "part_A" in part_ids + assert "part_B" in part_ids + + def test_concat(self): + """Test concatenating two batches.""" + batch1 = self._make_batch(batch_size=2) batch2 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - SampleMeta(partition_id="partition_0", global_index=3, fields=fields), - ], - custom_meta={i: {"uid": i} for i in [2, 3]}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [2, 3]}, + global_indexes=[2, 3], + partition_ids=["partition_0", "partition_0"], + field_schema=batch1.field_schema, + production_status=np.ones(2, dtype=np.int8), ) - - # Concatenate batches result = BatchMeta.concat([batch1, batch2]) - assert len(result) == 4 assert result.global_indexes == [0, 1, 2, 3] - assert result.custom_meta == {i: {"uid": i} for i in [0, 1, 2, 3]} - assert result._custom_backend_meta == {i: {"test_field": {"dtype": torch.float32}} for i in [0, 1, 2, 3]} - - def test_batch_meta_concat_with_tensor_extra_info(self): - """Example: Concat handles tensor extra_info by concatenating along dim=0.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["tensor"] = torch.randn(3, 4) - batch1.extra_info["scalar"] = torch.tensor(1.0) - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["tensor"] = torch.randn(3, 4) - batch2.extra_info["scalar"] = torch.tensor(2.0) - - result = BatchMeta.concat([batch1, batch2]) - - # Tensors are concatenated along dim=0 - assert result.extra_info["tensor"].shape == (6, 4) - # Scalars are stacked - assert result.extra_info["scalar"].shape == (2,) - - def test_batch_meta_concat_with_non_tensor_stack(self): - """Example: Concat handles NonTensorStack extra_info.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["non_tensor"] = NonTensorStack(1, 2, 3) - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["non_tensor"] = NonTensorStack(4, 5, 6) - - result = BatchMeta.concat([batch1, batch2]) - - # NonTensorStack is stacked - assert isinstance(result.extra_info["non_tensor"], NonTensorStack) - assert result.extra_info["non_tensor"].batch_size == torch.Size([2, 3]) - - def test_batch_meta_concat_with_list_extra_info(self): - """Example: Concat handles list extra_info by flattening.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["list"] = [1, 2, 3] - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["list"] = [4, 5, 6] - - result = BatchMeta.concat([batch1, batch2]) - - # Lists are flattened - assert result.extra_info["list"] == [1, 2, 3, 4, 5, 6] - - def test_batch_meta_concat_with_mixed_types(self): - """Example: Concat handles mixed extra_info types correctly.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["tensor"] = torch.randn(3, 4) - batch1.extra_info["list"] = [1, 2, 3] - batch1.extra_info["string"] = "hello" - batch1.extra_info["non_tensor"] = NonTensorStack(1, 2, 3) - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["tensor"] = torch.randn(3, 4) - batch2.extra_info["list"] = [4, 5] - batch2.extra_info["string"] = "world" - batch2.extra_info["non_tensor"] = NonTensorStack(4, 5, 6) - - result = BatchMeta.concat([batch1, batch2]) - - # Each type is handled appropriately - assert result.extra_info["tensor"].shape == (6, 4) # Concatenated - assert result.extra_info["list"] == [1, 2, 3, 4, 5] # Flattened - assert result.extra_info["string"] == "world" # Last value wins - assert isinstance(result.extra_info["non_tensor"], NonTensorStack) # Stacked - - def test_batch_meta_union(self): - """Example: Union two batches with matching global indexes.""" - fields1 = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - fields2 = { - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), - } - - batch1 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields1), - SampleMeta(partition_id="partition_0", global_index=9, fields=fields1), - ], - _custom_backend_meta={ - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}} for i in [8, 9] - }, - ) - batch1.extra_info["info1"] = "value1" - - batch2 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=8, fields=fields2), - SampleMeta(partition_id="partition_0", global_index=9, fields=fields2), - ], - _custom_backend_meta={ - i: {"field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} for i in [8, 9] - }, - ) - batch2.extra_info["info2"] = "value2" - - result = batch1.union(batch2) - - assert len(result) == 2 - # All fields are present - for sample in result.samples: - assert "field1" in sample.fields - assert "field2" in sample.fields - assert "field3" in sample.fields - # Extra info is merged - assert result.extra_info["info1"] == "value1" - assert result.extra_info["info2"] == "value2" - - # _custom_backend_meta is merged - assert result._custom_backend_meta == { - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} - for i in [8, 9] - } - def test_batch_meta_union_validation(self): - """Example: Union validation catches mismatched conditions.""" - fields = {"test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,))} - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - - batch2 = BatchMeta( - samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), # Different size - ] - ) - - with pytest.raises(ValueError) as exc_info: - batch1.union(batch2, validate=True) - assert "Batch sizes do not match" in str(exc_info.value) - - def test_batch_meta_reorder(self): - """Example: Reorder samples in a batch.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Reorder to [2, 0, 1] - batch.reorder([2, 0, 1]) - - assert batch.global_indexes == [6, 4, 5] - # Batch indexes are updated - assert batch.samples[0].batch_index == 0 - assert batch.samples[1].batch_index == 1 - assert batch.samples[2].batch_index == 2 - - def test_batch_meta_add_fields(self): - """Example: Add fields from TensorDict to all samples.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Create TensorDict with new fields - tensor_dict = TensorDict({"new_field1": torch.randn(2, 3), "new_field2": torch.randn(2, 5)}, batch_size=[2]) - - batch.add_fields(tensor_dict) - - # Fields are added to all samples - for sample in batch.samples: - assert "new_field1" in sample.fields - assert "new_field2" in sample.fields - assert sample.is_ready is True - - def test_batch_meta_select_fields(self): - """Example: Select specific fields from all samples in a batch.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - "field3": FieldMeta(name="field3", dtype=torch.bool, shape=(4,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] + def test_custom_meta_update(self): + """Test update_custom_meta method.""" + batch = self._make_batch(batch_size=2) + batch.update_custom_meta([{"tag": "alpha"}, {"tag": "beta"}]) + assert batch.custom_meta[0]["tag"] == "alpha" + assert batch.custom_meta[1]["tag"] == "beta" + + def test_custom_backend_meta(self): + """Test _custom_backend_meta attribute.""" + batch = self._make_batch(batch_size=2) + batch._custom_backend_meta[0]["field_a"] = {"storage_key": "abc"} + assert batch._custom_backend_meta[0]["field_a"]["storage_key"] == "abc" + + def test_size_property(self): + """Test size == len property.""" + batch = self._make_batch(batch_size=5) + assert batch.size == 5 + assert len(batch) == 5 + + def test_add_fields_empty_batch_is_non_tensor_unknown(self): + """add_fields with empty field value leaves is_non_tensor as None (unknown). + + When a field has zero samples, we cannot determine the field type from data. + is_non_tensor must not default to False (which would incorrectly imply Tensor). + """ + from tensordict import TensorDict + + batch = BatchMeta.empty() + # TensorDict with an empty tensor of batch_size=0 + empty_td = TensorDict({"empty_field": torch.empty(0, 2)}, batch_size=0) + batch.add_fields(empty_td) + assert batch.field_schema["empty_field"]["is_non_tensor"] is None + + def test_to_dict_preserves_empty_per_sample_shapes(self): + """to_dict must preserve per_sample_shapes even when it is an empty list. + + An empty list [] is falsy in Python; the old `if meta.get(...)` check + would silently drop it. The fix uses `is not None`. + """ + # Use batch_size=0 so per_sample_shapes=[] is a valid (0-length) list. batch = BatchMeta( - samples=samples, - extra_info={"test_key": "test_value"}, - _custom_backend_meta={ - i: { - "field1": {"dtype": torch.float32}, - "field2": {"dtype": torch.int64}, - "field3": {"dtype": torch.bool}, + global_indexes=[], + partition_ids=[], + field_schema={ + "f": { + "dtype": torch.float32, + "shape": None, + "is_nested": True, + "is_non_tensor": False, + "per_sample_shapes": [], # valid for batch_size=0, but falsy } - for i in [0, 1] }, + production_status=np.zeros(0, dtype=np.int8), ) + d = batch.to_dict() + assert "per_sample_shapes" in d["field_schema"]["f"], "per_sample_shapes=[] was dropped by to_dict" + assert d["field_schema"]["f"]["per_sample_shapes"] == [] - # Select only field1 and field3 - selected_batch = batch.select_fields(["field1", "field3"]) - - # Check all samples have correct fields - assert len(selected_batch) == 2 - for sample in selected_batch.samples: - assert "field1" in sample.fields - assert "field3" in sample.fields - assert "field2" not in sample.fields - # Original batch is unchanged - assert len(batch.samples[0].fields) == 3 - # Extra info is preserved - assert selected_batch.extra_info["test_key"] == "test_value" - # Global indexes are preserved - assert selected_batch.global_indexes == [0, 1] - - # _custom_backend_meta is selected - assert "field1" in selected_batch._custom_backend_meta[0] - assert "field2" not in selected_batch._custom_backend_meta[0] - assert "field3" in selected_batch._custom_backend_meta[0] - assert "field1" in selected_batch._custom_backend_meta[1] - assert "field2" not in selected_batch._custom_backend_meta[1] - assert "field3" in selected_batch._custom_backend_meta[1] - - def test_batch_meta_select_fields_with_nonexistent_fields(self): - """Example: Select fields ignores non-existent field names in batch.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Try to select fields including non-existent ones - selected_batch = batch.select_fields(["field1", "nonexistent_field"]) - - # Only existing fields are selected - for sample in selected_batch.samples: - assert "field1" in sample.fields - assert "nonexistent_field" not in sample.fields - assert "field2" not in sample.fields - - def test_batch_meta_select_fields_empty_list(self): - """Example: Select with empty field list returns batch with no fields.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select with empty list - selected_batch = batch.select_fields([]) - - assert len(selected_batch) == 2 - for sample in selected_batch.samples: - assert len(sample.fields) == 0 - # Global indexes are preserved - assert selected_batch.global_indexes == [0, 1] - - def test_batch_meta_select_fields_single_sample(self): - """Example: Select fields works correctly for batch with single sample.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - sample = SampleMeta(partition_id="partition_0", global_index=0, fields=fields) - batch = BatchMeta(samples=[sample]) - - # Select only field2 - selected_batch = batch.select_fields(["field2"]) - - assert len(selected_batch) == 1 - assert "field2" in selected_batch.samples[0].fields - assert "field1" not in selected_batch.samples[0].fields - - def test_batch_meta_select_fields_preserves_field_metadata(self): - """Example: Selected fields preserve their original metadata.""" - fields = { - "field1": FieldMeta( - name="field1", dtype=torch.float32, shape=(2, 3), production_status=ProductionStatus.READY_FOR_CONSUME - ), - "field2": FieldMeta( - name="field2", dtype=torch.int64, shape=(5,), production_status=ProductionStatus.NOT_PRODUCED - ), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select field1 - selected_batch = batch.select_fields(["field1"]) - selected_field = selected_batch.samples[0].fields["field1"] - - assert selected_field.dtype == torch.float32 - assert selected_field.shape == (2, 3) - assert selected_field.production_status == ProductionStatus.READY_FOR_CONSUME - assert selected_field.name == "field1" - - def test_batch_meta_select_samples(self): - """Example: Select specific samples from a batch.""" - fields = { - "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)), - "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - SampleMeta(partition_id="partition_0", global_index=7, fields=fields), - ] - batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) - - # Select samples at indices [0, 2] - selected_batch = batch.select_samples([0, 2]) # This will select the first two samples with global_index=4/5 - - # Check number of samples - assert len(selected_batch) == 2 - # Check global indexes - assert selected_batch.global_indexes == [4, 6] - # Check fields are preserved - for sample in selected_batch.samples: - assert "field1" in sample.fields - assert "field2" in sample.fields - # Original batch is unchanged - assert len(batch) == 4 - # Extra info is preserved - assert selected_batch.extra_info["test_key"] == "test_value" - - def test_batch_meta_select_samples_all_indices(self): - """Example: Select all samples using complete index list.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=4, fields=fields), - SampleMeta(partition_id="partition_0", global_index=5, fields=fields), - SampleMeta(partition_id="partition_0", global_index=6, fields=fields), - ] - batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) - - # Select all samples - selected_batch = batch.select_samples([0, 1, 2]) - - # All samples are selected - assert len(selected_batch) == 3 - assert selected_batch.global_indexes == [4, 5, 6] - # Extra info is preserved - assert selected_batch.extra_info["test_key"] == "test_value" - - def test_batch_meta_select_samples_single_sample(self): - """Example: Select a single sample from batch.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select only the middle sample - selected_batch = batch.select_samples([1]) - - assert len(selected_batch) == 1 - assert selected_batch.global_indexes == [1] - assert selected_batch.samples[0].batch_index == 0 # New batch index - - def test_batch_meta_select_samples_empty_list(self): - """Example: Select with empty list returns empty batch.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) - - # Select with empty list - selected_batch = batch.select_samples([]) - - assert len(selected_batch) == 0 - assert selected_batch.global_indexes == [] - # Extra info is still preserved - assert selected_batch.extra_info["test_key"] == "test_value" - - def test_batch_meta_select_samples_reverse_order(self): - """Example: Select samples in reverse order.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Select samples in reverse order - selected_batch = batch.select_samples([2, 1, 0]) - - assert len(selected_batch) == 3 - assert selected_batch.global_indexes == [2, 1, 0] - # Batch indexes are re-assigned - assert selected_batch.samples[0].global_index == 2 - assert selected_batch.samples[1].global_index == 1 - assert selected_batch.samples[2].global_index == 0 - - def test_batch_meta_select_samples_with_extra_info(self): - """Example: Select samples preserves all extra info types.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Add various extra info types - batch.extra_info["tensor"] = torch.randn(3, 4) - batch.extra_info["string"] = "test_string" - batch.extra_info["number"] = 42 - batch.extra_info["list"] = [1, 2, 3] - - # Select one sample - selected_batch = batch.select_samples([0]) - - # All extra info is preserved - assert "tensor" in selected_batch.extra_info - assert selected_batch.extra_info["string"] == "test_string" - assert selected_batch.extra_info["number"] == 42 - assert selected_batch.extra_info["list"] == [1, 2, 3] - - # ===================================================== - # Custom Meta Tests - # ===================================================== - def test_batch_meta_update_custom_meta(self): - """Test update_custom_meta adds metadata for different global indices.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Initial custom_meta for sample 0 - batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.9 - assert result[1]["sample_score"] == 0.1 - - def test_batch_meta_update_custom_meta_overwrites(self): - """Test update_custom_meta overwrites existing metadata at same key.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Initial custom_meta - batch.update_custom_meta([{"sample_score": 0.9, "quality": "high"}]) - - # Update with new value for same field - dict.update replaces - batch.update_custom_meta([{"sample_score": 0.1, "quality": "low"}]) + # Verify round-trip via from_dict + restored = BatchMeta.from_dict(d) + assert restored.field_schema["f"].get("per_sample_shapes") == [] - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.1 - assert result[0]["quality"] == "low" + def test_concat_extra_info_scalar_raises_type_error(self): + """concat raises TypeError for scalar extra_info values (no merge strategy). - def test_batch_meta_update_custom_meta_with_none(self): - """Test update_custom_meta with None does nothing.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set initial value - batch.update_custom_meta([{"sample_score": 0.9}]) - - # Update with None should not change anything - batch.update_custom_meta(None) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.9 - - def test_batch_meta_clear_custom_meta(self): - """Test clear_custom_meta removes all custom metadata.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set custom_meta - batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) - - # Clear all - batch.clear_custom_meta() - - result = batch.get_all_custom_meta() - assert result == [{}, {}] - - def test_batch_meta_get_all_custom_meta_returns_deep_copy(self): - """Test get_all_custom_meta returns a deep copy.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - custom_meta = [{"sample_score": 0.9, "nested": {"value": 1}}] - batch.update_custom_meta(custom_meta) - - # Get all custom_meta - result = batch.get_all_custom_meta() - - # Verify it's a deep copy - modifying result should not affect original - result[0]["sample_score"] = 0.1 - result[0]["nested"]["value"] = 999 - - original = batch.get_all_custom_meta() - assert original[0]["sample_score"] == 0.9 - assert original[0]["nested"]["value"] == 1 - - def test_batch_meta_get_all_custom_meta_empty(self): - """Test get_all_custom_meta with no custom_meta returns empty dict.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - result = batch.get_all_custom_meta() - assert result == [{}] - - def test_batch_meta_custom_meta_with_nested_data(self): - """Test custom_meta supports nested dictionary data.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - nested_meta = { - "model_info": {"name": "llama", "version": "7b", "config": {"hidden_size": 4096, "num_layers": 32}}, - "tags": ["training", "inference"], - } - batch.update_custom_meta([nested_meta]) - - result = batch.get_all_custom_meta() - assert result[0]["model_info"]["name"] == "llama" - assert result[0]["model_info"]["version"] == "7b" - assert result[0]["model_info"]["config"]["hidden_size"] == 4096 - assert result[0]["tags"] == ["training", "inference"] - - # ===================================================== - # Extra Info Methods Tests - # ===================================================== - - def test_batch_meta_update_extra_info(self): - """Test update_extra_info adds multiple values.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - batch = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - - # Update with multiple values - batch.update_extra_info({"key1": "value1", "key2": "value2", "key3": "value3"}) - - # Verify all exist - assert "key1" in batch.extra_info - assert "key2" in batch.extra_info - assert "key3" in batch.extra_info - assert batch.extra_info["key1"] == "value1" - assert batch.extra_info["key2"] == "value2" - - def test_batch_meta_extra_info_preserved_in_operations(self): - """Test extra_info is preserved in batch operations.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["test_key1"] = "test_value" - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["test_key2"] = "test_value_2" - - result = BatchMeta.concat([batch1, batch2]) - - # Extra info is preserved - assert "test_key1" in result.extra_info - - def test_batch_meta_extra_info_with_concat(self): - """Test extra_info handling in concat with mixed types.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch1.extra_info["string"] = "hello" - batch1.extra_info["number"] = 42 - - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) - batch2.extra_info["string"] = "world" - batch2.extra_info["number"] = 100 - - result = BatchMeta.concat([batch1, batch2]) - - # String: last value wins - assert result.extra_info["string"] == "world" - - -class TestEdgeCases: - """Edge cases and important boundaries.""" - - def test_batch_meta_chunk_with_more_chunks_than_samples(self): - """Example: Chunking when chunks > samples produces empty chunks.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # 5 chunks for 2 samples - chunks = batch.chunk(5) - - assert len(chunks) == 5 - # First 2 chunks have samples - assert len(chunks[0]) == 1 - assert len(chunks[1]) == 1 - # Last 3 chunks are empty - assert len(chunks[2]) == 0 - assert len(chunks[3]) == 0 - assert len(chunks[4]) == 0 - - def test_batch_meta_concat_with_empty_batches(self): - """Example: Concat handles empty batches gracefully.""" - fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) - } - - batch1 = BatchMeta(samples=[]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch3 = BatchMeta(samples=[]) - - # Empty batches are filtered out - result = BatchMeta.concat([batch1, batch2, batch3]) - assert len(result) == 1 - assert result.global_indexes == [0] - - def test_batch_meta_concat_validation_error(self): - """Example: Concat validation catches field name mismatches.""" - fields1 = {"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))} - fields2 = {"field2": FieldMeta(name="field2", dtype=torch.float32, shape=(2,))} - - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields1)]) + Scalars like int/str do not have a defined concatenation strategy. + The error message must not say 'mixed types' when types are uniform. + """ + batch1 = BatchMeta( + global_indexes=[0], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"step": 1}, + ) + batch2 = BatchMeta( + global_indexes=[1], + partition_ids=["p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (1,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(1, dtype=np.int8), + extra_info={"step": 2}, + ) + with pytest.raises(TypeError, match="no defined merge strategy"): + BatchMeta.concat([batch1, batch2]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields2)]) - with pytest.raises(ValueError) as exc_info: - BatchMeta.concat([batch1, batch2], validate=True) - assert "Field names do not match" in str(exc_info.value) +# ============================================================================== +# KVBatchMeta Tests (all migrated from main with no modification) +# ============================================================================== class TestKVBatchMeta: diff --git a/tests/test_ray_p2p.py b/tests/test_ray_p2p.py index 4bd54a7..e958b84 100644 --- a/tests/test_ray_p2p.py +++ b/tests/test_ray_p2p.py @@ -17,6 +17,7 @@ import time from pathlib import Path +import numpy as np import ray import torch from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -26,7 +27,7 @@ sys.path.append(str(parent_dir)) from transfer_queue.client import TransferQueueClient # noqa: E402 -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402 from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory # noqa: E402 from transfer_queue.utils.zmq_utils import ZMQServerInfo # noqa: E402 @@ -115,17 +116,19 @@ def generate_data( batch_size=batch_size, ) - samples = [ - SampleMeta( - global_index=i, - partition_id=partition_id, - fields={ - "input_ids": FieldMeta(name="input_ids", dtype=torch.float32, shape=(seq_len,)), - }, - ) - for i in range(batch_size) - ] - meta = BatchMeta(samples=samples) + meta = BatchMeta( + global_indexes=list(range(batch_size)), + partition_ids=[partition_id] * batch_size, + field_schema={ + "input_ids": { + "dtype": torch.float32, + "shape": (seq_len,), + "is_nested": False, + "is_non_tensor": False, + } + }, + production_status=np.zeros(batch_size, dtype=np.int8), + ) self.data = data self.meta = meta diff --git a/tests/test_serial_utils_on_cpu.py b/tests/test_serial_utils_on_cpu.py index 316bbd9..3f498c8 100644 --- a/tests/test_serial_utils_on_cpu.py +++ b/tests/test_serial_utils_on_cpu.py @@ -16,6 +16,7 @@ import sys from pathlib import Path +import numpy as np import pytest import torch from tensordict import TensorDict @@ -76,6 +77,8 @@ def test_zmq_msg_serialization(): encoded_msg = msg.serialize() decoded_msg = ZMQMessage.deserialize(encoded_msg) assert decoded_msg.request_type == msg.request_type + # TensorDict converts numpy arrays to Tensors on insertion, + # so decoding yields a Tensor (not np.ndarray). assert torch.allclose(decoded_msg.body["data"]["numpy_array"], msg.body["data"]["numpy_array"]) assert torch.allclose(decoded_msg.body["data"]["normal_tensor"], msg.body["data"]["normal_tensor"]) assert msg.body["data"]["nested_tensor"].layout == decoded_msg.body["data"]["nested_tensor"].layout @@ -883,13 +886,12 @@ def test_numpy_object_array_dicts(self): assert orig == decoded def test_numpy_numeric_arrays_zero_copy(self): - """Test that numeric numpy arrays use zero-copy path.""" + """Test that numeric numpy arrays use zero-copy path and return np.ndarray.""" import numpy as np encoder = MsgpackEncoder() decoder = MsgpackDecoder() - # These should use zero-copy (torch.from_numpy + tensor encoding) numeric_dtypes = [ np.float32, np.float64, @@ -910,14 +912,17 @@ def test_numpy_numeric_arrays_zero_copy(self): serialized = encoder.encode(arr) - # Zero-copy should produce multiple buffers (metadata + tensor buffer) + # Zero-copy must produce multiple buffers (metadata + data buffer) assert len(serialized) > 1, f"Expected zero-copy for dtype {dtype}" deserialized = decoder.decode(serialized) - # Deserialized as torch.Tensor (due to zero-copy path) - assert isinstance(deserialized, torch.Tensor) - assert torch.allclose(deserialized, torch.from_numpy(arr)) + # After the fix: deserialized must be np.ndarray, not torch.Tensor + assert isinstance(deserialized, np.ndarray), ( + f"Expected np.ndarray but got {type(deserialized)} for dtype={dtype}" + ) + assert deserialized.dtype == arr.dtype + assert np.array_equal(deserialized, arr) def test_numpy_object_array_in_zmq_message(self): """Test numpy object array inside ZMQMessage.""" @@ -978,3 +983,134 @@ def test_numpy_bytes_array(self): deserialized = decoder.decode(serialized) assert np.array_equal(deserialized, bytes_arr) + + +# ============================================================================ +# Numpy Native Serialization Tests (CUSTOM_TYPE_NUMPY) +# ============================================================================ +class TestNumpyNativeSerialization: + """Test that numpy arrays are serialized/deserialized natively. + + After the fix, numeric numpy arrays must: + 1. Round-trip as np.ndarray (not torch.Tensor). + 2. Preserve dtype and shape exactly. + 3. Use zero-copy (len(serialized) > 1). + 4. Produce correct values. + """ + + @pytest.mark.parametrize( + "dtype", + [ + # Numeric / bool / complex (original coverage) + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.bool_, + np.complex64, + np.complex128, + # Extended types now also covered via exclusion-based check + np.datetime64, # kind='M', stored as int64 + np.timedelta64, # kind='m', stored as int64 + np.dtype("S10"), # kind='S', fixed-length bytes + ], + ) + def test_numpy_roundtrip_preserves_type(self, dtype): + """All buffer-compatible ndarrays must come back as np.ndarray, not torch.Tensor.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + dtype = np.dtype(dtype) # normalise in case a dtype instance was passed + if dtype == np.dtype("bool"): + arr = np.array([True, False, True, True], dtype=dtype) + elif dtype.kind == "c": # complex + arr = np.array([1 + 2j, 3 + 4j], dtype=dtype) + elif dtype.kind == "M": # datetime64 + arr = np.array(["2024-01", "2024-02"], dtype=dtype) + elif dtype.kind == "m": # timedelta64 + arr = np.array([1, 2], dtype=dtype) + elif dtype.kind == "S": # fixed-length bytes + arr = np.array([b"hello", b"world"], dtype=dtype) + elif np.issubdtype(dtype, np.integer): + arr = np.array([1, 2, 3, 4], dtype=dtype) + else: + arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype) + + serialized = encoder.encode(arr) + deserialized = decoder.decode(serialized) + + assert isinstance(deserialized, np.ndarray), f"Expected np.ndarray, got {type(deserialized)} for dtype={dtype}" + assert deserialized.dtype == arr.dtype + assert deserialized.shape == arr.shape + assert np.array_equal(deserialized, arr) + + def test_numpy_zero_copy_uses_multiple_buffers(self): + """Zero-copy path must produce len(serialized) > 1.""" + encoder = MsgpackEncoder() + arr = np.arange(100, dtype=np.float32) + serialized = encoder.encode(arr) + assert len(serialized) > 1, "Expected zero-copy (aux buffer) for float32 ndarray" + + def test_numpy_non_contiguous_roundtrip(self): + """Non-C-contiguous arrays must be made contiguous before serialization.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + base = np.arange(100, dtype=np.float64).reshape(10, 10) + arr = base[::2, ::2] # non-contiguous view + assert not arr.flags["C_CONTIGUOUS"] + + serialized = encoder.encode(arr) + deserialized = decoder.decode(serialized) + + assert isinstance(deserialized, np.ndarray) + assert np.array_equal(deserialized, arr) + + def test_numpy_multidim_shape_preserved(self): + """Shape must survive a round-trip for multi-dimensional arrays.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + arr = np.arange(60, dtype=np.int32).reshape(3, 4, 5) + serialized = encoder.encode(arr) + deserialized = decoder.decode(serialized) + + assert isinstance(deserialized, np.ndarray) + assert deserialized.shape == (3, 4, 5) + assert np.array_equal(deserialized, arr) + + def test_numpy_empty_array_roundtrip(self): + """Empty arrays must round-trip correctly.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + arr = np.empty((0,), dtype=np.float32) + serialized = encoder.encode(arr) + deserialized = decoder.decode(serialized) + + assert isinstance(deserialized, np.ndarray) + assert deserialized.shape == (0,) + assert deserialized.dtype == np.float32 + + def test_numpy_object_array_still_uses_pickle(self): + """Object arrays (kind='O' or hasobject) must fall back to pickle.""" + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() + + # dtype=object — kind 'O', cannot be viewed as a contiguous byte buffer + arr = np.array(["a", "b", "c"], dtype=object) + serialized = encoder.encode(arr) + + # Pickle-fallback produces a single buffer (no aux tensor buffer appended) + assert len(serialized) == 1, "Object array should not use zero-copy path" + + deserialized = decoder.decode(serialized) + assert isinstance(deserialized, np.ndarray) + assert np.array_equal(deserialized, arr) diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index ed43e41..d9aebff 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -399,25 +399,71 @@ def test_storage_unit_data_direct(): storage_data = StorageUnitData(storage_size=10) - # Test put_data field_data = { "log_probs": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], "rewards": [torch.tensor([10.0]), torch.tensor([20.0])], } + # local_keys = gi values (e.g., 0 and 1) storage_data.put_data(field_data, [0, 1]) - # Test get_data result = storage_data.get_data(["log_probs", "rewards"], [0, 1]) assert "log_probs" in result assert "rewards" in result assert len(result["log_probs"]) == 2 assert len(result["rewards"]) == 2 - # Test single index get result_single = storage_data.get_data(["log_probs"], [0]) - assert torch.allclose(result_single["log_probs"][0], torch.tensor([1.0, 2.0])) + torch.testing.assert_close(result_single["log_probs"][0], torch.tensor([1.0, 2.0])) - # Test clear + # clear: key is removed (not set to None) storage_data.clear([0]) - result_after_clear = storage_data.get_data(["log_probs"], [0]) - assert result_after_clear["log_probs"][0] is None + assert 0 not in storage_data.field_data["log_probs"] # key gone + assert 1 in storage_data.field_data["log_probs"] # other key intact + + +def test_storage_unit_data_dict_key(): + """StorageUnitData dict-key: gi 直接作为 key,clear 真正释放内存.""" + from transfer_queue.storage.simple_backend import StorageUnitData + + storage = StorageUnitData(storage_size=4) + + # put_data: 用 gi 列表 [10, 11] 作为 local_keys + storage.put_data( + {"f": [torch.tensor([1.0]), torch.tensor([2.0])]}, + local_keys=[10, 11], + ) + assert len(storage.field_data["f"]) == 2 + + # get_data: 通过 gi 读取 + result = storage.get_data(["f"], local_keys=[10, 11]) + torch.testing.assert_close(result["f"][0], torch.tensor([1.0])) + torch.testing.assert_close(result["f"][1], torch.tensor([2.0])) + + # clear: 真正删除 key,不是置 None + storage.clear(keys=[10]) + assert 10 not in storage.field_data["f"] + assert 11 in storage.field_data["f"] + + # capacity check: storage_size=4,已有 1 条,再放 4 条应失败 + with pytest.raises(ValueError, match="Storage capacity exceeded"): + storage.put_data( + {"f": [torch.tensor([i * 1.0]) for i in range(4)]}, + local_keys=[20, 21, 22, 23], + ) + + +def test_storage_unit_data_partial_consume_safety(): + """部分消费后写入复用 gi,不应覆盖未消费数据.""" + from transfer_queue.storage.simple_backend import StorageUnitData + + storage = StorageUnitData(storage_size=4) + storage.put_data({"f": [torch.tensor([0.0]), torch.tensor([1.0])]}, local_keys=[0, 1]) + + storage.clear(keys=[1]) # 只清除 gi=1 + assert 0 in storage.field_data["f"] + assert 1 not in storage.field_data["f"] + + # 复用 gi=1 写入新数据,不影响 gi=0 + storage.put_data({"f": [torch.tensor([9.0])]}, local_keys=[1]) + torch.testing.assert_close(storage.field_data["f"][0], torch.tensor([0.0])) + torch.testing.assert_close(storage.field_data["f"][1], torch.tensor([9.0])) diff --git a/tests/test_yuanrong_client_zero_copy.py b/tests/test_yuanrong_client_zero_copy.py index 3048ec5..b93fd32 100644 --- a/tests/test_yuanrong_client_zero_copy.py +++ b/tests/test_yuanrong_client_zero_copy.py @@ -21,6 +21,8 @@ import pytest import torch +pytest.importorskip("yr") + parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index f7cd239..06bc0be 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -25,6 +25,7 @@ from typing import Any, Optional from uuid import uuid4 +import numpy as np import ray import torch import zmq @@ -34,11 +35,9 @@ from transfer_queue.metadata import ( BatchMeta, - FieldMeta, - SampleMeta, ) from transfer_queue.sampler import BaseSampler, SequentialSampler -from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole +from transfer_queue.utils.enum_utils import TransferQueueRole from transfer_queue.utils.perf_utils import IntervalPerfMonitor from transfer_queue.utils.zmq_utils import ( ZMQMessage, @@ -227,6 +226,9 @@ class DataPartitionStatus: field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: shape} + # O(F) schema cache: field_name -> {dtype, shape, is_nested, is_non_tensor} + # Updated eagerly in _update_field_metadata; used by get_field_schema() for O(1) per-field lookup. + field_schema_cache: dict[str, dict[str, Any]] = field(default_factory=dict) field_custom_backend_meta: dict[int, dict[str, Any]] = field( default_factory=dict ) # global_idx -> {field: custom_backend_meta} @@ -483,12 +485,34 @@ def _update_field_metadata( if global_idx not in self.field_dtypes: self.field_dtypes[global_idx] = {} self.field_dtypes[global_idx].update(dtype_value[i]) + # Update field_schema_cache with new dtype info + for fname, dtype in dtype_value[i].items(): + if fname not in self.field_schema_cache: + self.field_schema_cache[fname] = { + "dtype": dtype, + "shape": None, + "is_nested": False, + "is_non_tensor": False, + } + elif self.field_schema_cache[fname].get("dtype") is None: + self.field_schema_cache[fname]["dtype"] = dtype # Only create and update shape mapping if a shape value was provided if shape_value[i] is not None: if global_idx not in self.field_shapes: self.field_shapes[global_idx] = {} self.field_shapes[global_idx].update(shape_value[i]) + # Update field_schema_cache with new shape info + for fname, shape in shape_value[i].items(): + if fname not in self.field_schema_cache: + self.field_schema_cache[fname] = { + "dtype": None, + "shape": shape, + "is_nested": False, + "is_non_tensor": False, + } + elif self.field_schema_cache[fname].get("shape") is None: + self.field_schema_cache[fname]["shape"] = shape # Only create and update custom_backend_meta mapping if a custom_backend_meta value was provided if custom_backend_meta_value[i] is not None: @@ -669,13 +693,23 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: # ==================== Metadata Methods ==================== - def get_field_dtype(self, global_index: int, field_name: str) -> Optional[Any]: - """Get dtype for a specific sample and field.""" - return self.field_dtypes.get(global_index, {}).get(field_name) - - def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: - """Get shape for a specific sample and field.""" - return self.field_shapes.get(global_index, {}).get(field_name) + def get_field_schema(self, field_names: list[str]) -> dict[str, dict[str, Any]]: + """Return field_schema for the requested fields from the O(F) cache. + + Complexity: O(F) — one dict-lookup per field, no full scan of per-sample maps. + The cache is populated eagerly in _update_field_metadata() at put time. + """ + schema = {} + for fname in field_names: + cached = self.field_schema_cache.get(fname) + if cached is not None: + schema[fname] = { + "dtype": cached.get("dtype"), + "shape": cached.get("shape"), + "is_nested": cached.get("is_nested", False), + "is_non_tensor": cached.get("is_non_tensor", False), + } + return schema def get_field_custom_backend_meta( self, global_indices: list[int], field_names: list[str] @@ -700,11 +734,23 @@ def get_field_custom_backend_meta( {0: {'field_a': {'meta1': 'xxx'}, 'field_b': {'meta1': 'xxx'}}, 1: {...}} """ return { - idx: {f: v for f, v in self.field_custom_backend_meta[idx].items() if f in field_names} + idx: { + f: v + for f, v in self.field_custom_backend_meta[idx].items() + if f.startswith("_") or f in field_names # keep special keys like _su_id + } for idx in global_indices if idx in self.field_custom_backend_meta } + def get_field_dtype(self, global_index: int, field_name: str) -> Optional[Any]: + """Get the dtype for a specific (global_index, field_name) pair.""" + return self.field_dtypes.get(global_index, {}).get(field_name, None) + + def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: + """Get the shape for a specific (global_index, field_name) pair.""" + return self.field_shapes.get(global_index, {}).get(field_name, None) + def get_custom_meta(self, global_indices: list[int]) -> dict[int, dict]: """ Get custom_meta for multiple samples. @@ -1292,6 +1338,8 @@ def generate_batch_meta( """ Generate BatchMeta for specific samples in a partition. + O(F) optimized version that uses field_schema instead of per-sample metadata. + This function is responsible only for metadata generation and does not modify consumption state. State management is handled by the calling function. @@ -1314,55 +1362,56 @@ def generate_batch_meta( if mode not in ["fetch", "insert", "force_fetch"]: raise ValueError(f"Invalid mode: {mode}") - # Generate sample metadata - samples = [] - for global_index in batch_global_indexes: - fields = {} - for field_name in data_fields: - # Determine production status - if mode == "fetch": - production_status = ProductionStatus.READY_FOR_CONSUME - dtype = partition.get_field_dtype(global_index, field_name) - shape = partition.get_field_shape(global_index, field_name) - elif mode == "insert": - production_status = ProductionStatus.NOT_PRODUCED - dtype = None - shape = None - elif mode == "force_fetch": - field_index = partition.field_name_mapping.get(field_name) - if ( - field_index is not None - and partition.production_status is not None - and partition.production_status[global_index, field_index] == 1 - ): - production_status = ProductionStatus.READY_FOR_CONSUME - dtype = partition.get_field_dtype(global_index, field_name) - shape = partition.get_field_shape(global_index, field_name) - else: - production_status = ProductionStatus.NOT_PRODUCED - dtype = None - shape = None - - fields[field_name] = FieldMeta( - name=field_name, - dtype=dtype, - shape=shape, - production_status=production_status, - ) + batch_size = len(batch_global_indexes) - sample = SampleMeta( - partition_id=partition_id, - global_index=global_index, - fields=fields, - ) - samples.append(sample) + field_schema = partition.get_field_schema(data_fields) - custom_meta = partition.get_custom_meta(batch_global_indexes) + # In insert mode, create placeholder schema for unregistered fields so that + # metadata.field_names is complete and update_production_status() can recognize them. + if mode == "insert": + for fname in data_fields: + if fname not in field_schema: + field_schema[fname] = { + "dtype": None, + "shape": None, + "is_nested": False, + "is_non_tensor": False, + } + + if mode == "fetch": + production_status = np.ones(batch_size, dtype=np.int8) + elif mode == "insert": + production_status = np.zeros(batch_size, dtype=np.int8) + else: # force_fetch + production_status = np.zeros(batch_size, dtype=np.int8) + if partition.production_status is not None and data_fields: + field_indices = [ + partition.field_name_mapping.get(fname) + for fname in data_fields + if fname in partition.field_name_mapping + ] + if field_indices: + for i, global_idx in enumerate(batch_global_indexes): + if global_idx < partition.production_status.shape[0]: + sample_status = partition.production_status[global_idx, field_indices] + if torch.all(sample_status == 1): + production_status[i] = 1 + + custom_meta_dict = partition.get_custom_meta(batch_global_indexes) custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) - batch_meta = BatchMeta(samples=samples) - batch_meta.update_custom_meta([custom_meta.get(idx, {}) for idx in batch_meta.global_indexes]) - batch_meta._custom_backend_meta.update(custom_backend_meta) + # Convert controller dict[int, dict] → BatchMeta list[dict] (aligned with batch_global_indexes) + custom_meta_list = [custom_meta_dict.get(gi, {}) for gi in batch_global_indexes] + custom_backend_meta_list = [custom_backend_meta.get(gi, {}) for gi in batch_global_indexes] + + batch_meta = BatchMeta( + global_indexes=batch_global_indexes, + partition_ids=[partition_id] * batch_size, + field_schema=field_schema, + production_status=production_status, + custom_meta=custom_meta_list, + _custom_backend_meta=custom_backend_meta_list, + ) return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 64134bf..a6582b9 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -26,9 +26,6 @@ import torch from tensordict import TensorDict from tensordict.tensorclass import NonTensorData, NonTensorStack -from torch import Tensor - -from transfer_queue.utils.enum_utils import ProductionStatus logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -40,241 +37,189 @@ logger.addHandler(handler) -# TODO: Add UT for metadata operations -@dataclass -class FieldMeta: - """Records the metadata of a single data field (name, dtype, shape, etc.).""" - - name: str - dtype: Optional[Any] # Data type (e.g., torch.float32, numpy.float32) - shape: Optional[Any] # Data shape (e.g., torch.Size([3, 224, 224]), (3, 224, 224)) - production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED - - def __str__(self) -> str: - return ( - f"FieldMeta(name='{self.name}', dtype={self.dtype}, " - f"shape={self.shape}, production_status={self.production_status})" - ) - - @property - def is_ready(self) -> bool: - """Check if this field is ready for consumption""" - return self.production_status == ProductionStatus.READY_FOR_CONSUME - - @classmethod - def from_dict(cls, data: dict) -> "FieldMeta": - """Create FieldMeta from dictionary.""" - return cls( - name=data["name"], - dtype=data["dtype"], - shape=data["shape"], - production_status=ProductionStatus(str(data["production_status"])) - if isinstance(data["production_status"], int | str) - else data["production_status"], - ) - +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- -@dataclass -class SampleMeta: - """Records the metadata of a single data sample (stored as a row in the data system).""" - partition_id: str # Partition id, used for data versioning - global_index: int # Global row index, uniquely identifies a data sample - fields: dict[str, FieldMeta] # Fields of interest for this sample +def _parse_dtype(dtype_str: str) -> Any: + """Parse a dtype string produced by to_dict() back to a dtype object. - def __post_init__(self): - """Initialize is_ready property based on field readiness""" - # Check if all fields are ready and update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - - def __str__(self) -> str: - return f"SampleMeta(partition_id={self.partition_id}, global_index={self.global_index})" - - @property - def field_names(self) -> list[str]: - """Get list of field names for this sample""" - return list(self.fields.keys()) + Supports torch.dtype strings (e.g. "torch.float32") and numpy dtype + strings (e.g. "float64", "int8"). Falls back to returning the raw + string if parsing fails so that unknown types are not silently dropped. + """ + if dtype_str is None: + return None + # torch.dtype: repr is "torch." + if dtype_str.startswith("torch."): + name = dtype_str[len("torch.") :] + dtype = getattr(torch, name, None) + if isinstance(dtype, torch.dtype): + return dtype + # numpy dtype + try: + return np.dtype(dtype_str) + except TypeError: + pass + # Fallback: return as-is (e.g. plain Python type repr like "") + return dtype_str + + +class _SampleView: + """Lazy read-only view of a single sample row in a columnar BatchMeta.""" + + __slots__ = ("_batch", "_idx") + + def __init__(self, batch: "BatchMeta", idx: int) -> None: + self._batch = batch + self._idx = idx @property - def batch_index(self) -> int: - """Get the batch index of this sample (to be set by BatchMeta)""" - return getattr(self, "_batch_index", -1) - - def get_field_by_name(self, name: str) -> Optional[FieldMeta]: - """Get FieldMeta by field name""" - return self.fields.get(name) + def fields(self) -> dict: + """Read-only access to field_schema: batch.samples[i].fields['a'] -> field meta dict.""" + return self._batch.field_schema - def has_field(self, name: str) -> bool: - """Check if this sample has a specific field""" - return name in self.fields - def is_field_ready(self, field_name: str) -> bool: - """Check if a specific field is ready for consumption""" - field = self.fields.get(field_name) - return field.is_ready if field else False - - def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta": - """ - Add new fields to this sample. New fields will be initialized with given dtype, shape - and production_status (if provided). If not provided, default values (None, None, READY_FOR_CONSUME) - will be used. This modifies the sample in-place to include the new fields. - """ - self.fields = _union_fields(self.fields, fields) - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self - - def select_fields(self, field_names: list[str]) -> "SampleMeta": - """ - Select specific fields from this sample. - This will construct a new SampleMeta instance containing only the specified fields. - - Args: - field_names (list[str]): List of field names to retain. +class _SampleViewList: + """Lazy indexable list returned by BatchMeta.samples. - Returns: - SampleMeta: A new SampleMeta instance containing only the specified fields. - """ - selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} - - # construct new SampleMeta instance - # TODO(tianyi): (maybe) move _custom_backend_meta and custom_meta to FieldMeta level? - selected_sample_meta = SampleMeta( - fields=selected_fields, - partition_id=self.partition_id, - global_index=self.global_index, - ) - - return selected_sample_meta - - def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta": - """ - Create a union of this sample's fields with another sample's fields. - Assume both samples have the same global index. If fields overlap, the - fields in this sample will be replaced by the other sample's fields. - - Args: - other: Another SampleMeta to union with - validate: Whether to validate union conditions - - Returns: - New SampleMeta with unioned fields (None if validation fails) - """ - if validate: - if self.global_index != other.global_index: - raise ValueError( - f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union." - ) + Supports: indexing (samples[i]), len(), and iteration. + """ - # Merge fields - self.fields = _union_fields(self.fields, other.fields) + __slots__ = ("_batch",) - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self + def __init__(self, batch: "BatchMeta") -> None: + self._batch = batch - @property - def is_ready(self) -> bool: - """Check if all fields in this sample are ready for consumption""" - return getattr(self, "_is_ready", False) + def __len__(self) -> int: + return len(self._batch.global_indexes) - @property - def production_status(self) -> dict[str, ProductionStatus]: - """Get production status for all fields (backward compatibility)""" - return {name: field.production_status for name, field in self.fields.items()} + def __getitem__(self, idx: int) -> _SampleView: + return _SampleView(self._batch, idx) - @classmethod - def from_dict(cls, data: dict) -> "SampleMeta": - """Create SampleMeta from dictionary.""" - fields = { - name: FieldMeta.from_dict(field_data) if isinstance(field_data, dict) else field_data - for name, field_data in data["fields"].items() - } - return cls( - partition_id=data["partition_id"], - global_index=data["global_index"], - fields=fields, - ) + def __iter__(self): + return (_SampleView(self._batch, i) for i in range(len(self))) @dataclass class BatchMeta: - """Records the metadata of a batch of data samples.""" - - samples: list[SampleMeta] + """Records the metadata of a batch of data samples with optimized field-level schema. + + This is the O(BxF) optimized version that stores field metadata at the field level + instead of per-sample, reducing storage from O(B*F) to O(F). + + Attributes: + global_indexes: List of global sample indices in this batch. + partition_ids: List of partition IDs corresponding to each sample. + field_schema: Field-level metadata {field_name: {dtype, shape, is_nested, is_non_tensor, per_sample_shapes}}. + production_status: Vectorized production status, shape (B,) where B is batch size. + extra_info: Additional batch-level information. + custom_meta: Per-sample user-defined metadata, list aligned with global_indexes. + _custom_backend_meta: Per-sample per-field storage backend metadata, list aligned with global_indexes. + """ - # external meta for non-sample level information + global_indexes: list[int] + partition_ids: list[str] + # O(F) field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}} + field_schema: dict[str, dict[str, Any]] = dataclasses.field(default_factory=dict) + # O(B) vectorized production status; always np.ndarray after __post_init__ (never None) + production_status: np.ndarray = dataclasses.field(default=None, repr=False) # type: ignore[assignment] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - - # user-defined meta for each sample - custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) - - # internal meta for different storage backends in per-sample per-field level - _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) + # user-defined meta for each sample (sample-level), list aligned with global_indexes + custom_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list) + # internal meta for different storage backends (per-sample per-field level), list aligned with global_indexes + _custom_backend_meta: list[dict[str, Any]] = dataclasses.field(default_factory=list) def __post_init__(self): """Initialize all computed properties during initialization""" - self.samples = copy.deepcopy(self.samples) + self.global_indexes = copy.deepcopy(self.global_indexes) + self.partition_ids = copy.deepcopy(self.partition_ids) + self.field_schema = copy.deepcopy(self.field_schema) self.extra_info = copy.deepcopy(self.extra_info) - # Basic properties - object.__setattr__(self, "_size", len(self.samples)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) + # Validation + if len(self.global_indexes) != len(self.partition_ids): + raise ValueError( + f"Length mismatch: global_indexes has {len(self.global_indexes)}, " + f"partition_ids has {len(self.partition_ids)}" + ) + + batch_size = len(self.global_indexes) - # Pre-compute all list properties for better performance - if self.samples: - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly + if self.production_status is not None: + if isinstance(self.production_status, np.ndarray): + self.production_status = self.production_status.copy() + elif isinstance(self.production_status, torch.Tensor): + self.production_status = self.production_status.numpy().copy() + elif isinstance(self.production_status, list): + self.production_status = np.array(self.production_status, dtype=np.int8) - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) + if len(self.production_status) != batch_size: + raise ValueError(f"production_status length {len(self.production_status)} != batch_size {batch_size}") + else: + # Default: all NOT_PRODUCED (including empty batches) + self.production_status = np.zeros(batch_size, dtype=np.int8) + + for field_name, meta in self.field_schema.items(): + if meta.get("per_sample_shapes") is not None: + if len(meta["per_sample_shapes"]) != batch_size: + raise ValueError( + f"Field '{field_name}' per_sample_shapes length {len(meta['per_sample_shapes'])} " + f"!= batch_size {batch_size}" + ) - # check if all samples have the same field names - first_sample_field_names = sorted(self.samples[0].field_names) - if not all(sorted(sample.field_names) == first_sample_field_names for sample in self.samples): - raise ValueError("All samples in BatchMeta must have the same field_names.") - object.__setattr__(self, "_field_names", first_sample_field_names) + self._size = batch_size + self._field_names = sorted(self.field_schema.keys()) - object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) + is_ready = batch_size > 0 and bool(np.all(self.production_status == 1)) + self._is_ready = is_ready - # filter custom_meta and _custom_backend_meta - self.custom_meta = copy.deepcopy( - {k: self.custom_meta[k] for k in self.global_indexes if k in self.custom_meta} - ) - self._custom_backend_meta = copy.deepcopy( - {k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta} - ) + # Validate or initialize columnar custom_meta / _custom_backend_meta + if not self.custom_meta: + self.custom_meta = [{} for _ in range(batch_size)] else: - self.custom_meta = {} - self._custom_backend_meta = {} - object.__setattr__(self, "_global_indexes", []) - object.__setattr__(self, "_field_names", []) - object.__setattr__(self, "_partition_ids", []) + self.custom_meta = copy.deepcopy(self.custom_meta) + if len(self.custom_meta) != batch_size: + raise ValueError(f"custom_meta length {len(self.custom_meta)} != batch_size {batch_size}") + if not self._custom_backend_meta: + self._custom_backend_meta = [{} for _ in range(batch_size)] + else: + self._custom_backend_meta = copy.deepcopy(self._custom_backend_meta) + if len(self._custom_backend_meta) != batch_size: + raise ValueError( + f"_custom_backend_meta length {len(self._custom_backend_meta)} != batch_size {batch_size}" + ) @property def size(self) -> int: """Return the number of samples in this batch""" return getattr(self, "_size", 0) - @property - def global_indexes(self) -> list[int]: - """Get all global indexes in this batch""" - return getattr(self, "_global_indexes", []) - @property def field_names(self) -> list[str]: """Get all unique field names in this batch""" return getattr(self, "_field_names", []) + @property + def samples(self) -> _SampleViewList: + """Lazy per-sample view: supports samples[i].fields['a'], len(samples), for s in samples.""" + return _SampleViewList(self) + @property def is_ready(self) -> bool: """Check if all samples in this batch are ready for consumption""" - # TODO: get ready status from controller realtime return getattr(self, "_is_ready", False) - @property - def partition_ids(self) -> list[str]: - """Get partition ids for all samples in this batch as a list (one per sample)""" - return getattr(self, "_partition_ids", []) + # ==================== Extra Info Methods ==================== + + def get_extra_info(self, key: str, default: Any = None) -> Any: + """Get extra info by key""" + return self.extra_info.get(key, default) + + def set_extra_info(self, key: str, value: Any) -> None: + """Set extra info by key""" + self.extra_info[key] = value def get_all_extra_info(self) -> dict[str, Any]: """Get all extra_info as a dictionary (deep copy for immutability). @@ -285,49 +230,44 @@ def get_all_extra_info(self) -> dict[str, Any]: return copy.deepcopy(self.extra_info) def update_extra_info(self, info_dict: dict[str, Any]) -> None: - """ - Update extra_info with multiple key-value pairs. - - This method updates the extra_info dictionary with the provided key-value pairs. - Existing keys will be overwritten with new values. + """Update extra_info with multiple key-value pairs. Args: info_dict: Dictionary of key-value pairs to add/update in extra_info """ self.extra_info.update(info_dict) - def clear_extra_info(self) -> None: - """ - Clear all extra_info. + def remove_extra_info(self, key: str) -> Any: + """Remove extra info by key and return its value""" + return self.extra_info.pop(key, None) - This method removes all key-value pairs from the extra_info dictionary. - """ + def clear_extra_info(self) -> None: + """Clear all extra_info.""" self.extra_info.clear() + def has_extra_info(self, key: str) -> bool: + """Check if extra info contains a specific key""" + return key in self.extra_info + + # ==================== Custom Meta Methods (User Layer) ==================== + def get_all_custom_meta(self) -> list[dict[str, Any]]: - """ - Get all custom_meta as a list of dictionary. + """Get all custom_meta as a list of dictionary (one per sample, in global_indexes order). Returns: A deep copy of the custom_meta list """ - custom_meta = [self.custom_meta.get(i, {}) for i in self.global_indexes] - return copy.deepcopy(custom_meta) + return copy.deepcopy(self.custom_meta) def update_custom_meta(self, custom_meta: list[dict[str, Any]]): - """ - Update custom_meta with a list of dictionary of custom metadata. - - This method updates the custom_meta dictionary with the provided metadata. - Existing keys will be overwritten with new values. + """Update custom_meta with a list of dictionary of custom metadata. Args: - custom_meta: list of custom_meta dictionary + custom_meta: list of custom_meta dictionary (one per sample, in global_indexes order) Raises: - ValueError: If the length of custom_meta cannot match the batch size + ValueError: If the length of custom_meta does not match the batch size """ - if custom_meta is None: return @@ -336,142 +276,192 @@ def update_custom_meta(self, custom_meta: list[dict[str, Any]]): f"The length of custom_meta list {len(custom_meta)} must match the batch size: {self.size}" ) - custom_meta_dict: dict[int, dict[str, Any]] = { - self.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta)) - } - - self.custom_meta.update(custom_meta_dict) + for i, meta in enumerate(custom_meta): + self.custom_meta[i].update(meta) def clear_custom_meta(self) -> None: - """ - Clear all custom_meta. + """Clear all custom_meta.""" + self.custom_meta = [{} for _ in range(self.size)] - This method removes all entries from the custom_meta dictionary. - """ - self.custom_meta.clear() + # ==================== Core BatchMeta Operations ==================== def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": - """ - Add new fields from a TensorDict to all samples in this batch. - This modifies each sample in-place to include the new fields. + """Add new fields from a TensorDict to all samples in this batch. + This modifies the batch in-place to include the new fields. Args: tensor_dict (TensorDict): The input TensorDict containing new fields. set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True. """ - fields = _extract_field_metas(tensor_dict, set_all_ready) - - if fields: - if len(self.samples) != len(fields): - raise ValueError(f"add_fields length mismatch: samples={len(self.samples)} vs fields={len(fields)}") - for idx, sample in enumerate(self.samples): - sample.add_fields(fields=fields[idx]) - - # Update batch-level fields cache - if self.samples: - object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) + batch_size = tensor_dict.batch_size[0] + if batch_size != self.size: + raise ValueError(f"add_fields batch size mismatch: self.size={self.size} vs tensor_dict={batch_size}") + + for name, value in tensor_dict.items(): + # Determine if this is a nested tensor + is_nested = isinstance(value, torch.Tensor) and value.is_nested + + first_item = None + if is_nested: + unbound = value.unbind() + first_item = unbound[0] if unbound else None + else: + first_item = value[0] if len(value) > 0 else None + + # Determine if this is non-tensor data. + # When first_item is None (empty field), we cannot determine type—leave as None. + is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else None + + field_meta = { + "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "shape": getattr(first_item, "shape", None) if not is_nested else None, + "is_nested": is_nested, + "is_non_tensor": is_non_tensor, + } + + # For nested tensors, record per-sample shapes + if is_nested: + field_meta["per_sample_shapes"] = [tuple(t.shape) for t in value.unbind()] + + self.field_schema[name] = field_meta + + if set_all_ready: + self.production_status[:] = 1 + + self._field_names = sorted(self.field_schema.keys()) + + self._is_ready = self.size > 0 and bool(np.all(self.production_status == 1)) + return self - def select_samples(self, indexes: list[int]) -> "BatchMeta": - """ - Select specific samples from this batch. + def select_samples(self, sample_indices: list[int]) -> "BatchMeta": + """Select specific samples from this batch. This will construct a new BatchMeta instance containing only the specified samples. Args: - indexes (list[int]): List of indexes (relative to this batch, not global_indexes) - to retain. + sample_indices (list[int]): List of sample indices (relative to this batch) to retain. Returns: BatchMeta: A new BatchMeta instance containing only the specified samples. """ + if any(i < 0 or i >= self.size for i in sample_indices): + raise ValueError(f"Sample indices must be in range [0, {self.size})") - if any(i < 0 or i >= len(self.samples) for i in indexes): - raise ValueError(f"Sample indices must be in range [0, {len(self.samples)})") + new_global_indexes = [self.global_indexes[i] for i in sample_indices] + new_partition_ids = [self.partition_ids[i] for i in sample_indices] - selected_samples = [self.samples[i] for i in indexes] + # Select production_status + new_production_status = self.production_status[sample_indices] - global_indexes = [self.global_indexes[i] for i in indexes] - selected_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - selected_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } + new_field_schema = {} + for fname, meta in self.field_schema.items(): + new_meta = copy.deepcopy(meta) + if meta.get("per_sample_shapes") is not None: + new_meta["per_sample_shapes"] = [meta["per_sample_shapes"][i] for i in sample_indices] + new_field_schema[fname] = new_meta + + new_custom_meta = [copy.deepcopy(self.custom_meta[i]) for i in sample_indices] - # construct new BatchMeta instance - selected_batch_meta = BatchMeta( - samples=selected_samples, + new_custom_backend_meta = [copy.deepcopy(self._custom_backend_meta[i]) for i in sample_indices] + + return BatchMeta( + global_indexes=new_global_indexes, + partition_ids=new_partition_ids, + field_schema=new_field_schema, + production_status=new_production_status, extra_info=self.extra_info, - custom_meta=selected_custom_meta, - _custom_backend_meta=selected_custom_backend_meta, + custom_meta=new_custom_meta, + _custom_backend_meta=new_custom_backend_meta, ) - return selected_batch_meta - def select_fields(self, field_names: list[str]) -> "BatchMeta": - """ - Select specific fields from all samples in this batch. + """Select specific fields from all samples in this batch. This will construct a new BatchMeta instance containing only the specified fields. Args: field_names (list[str]): List of field names to retain. Returns: - BatchMeta: A new BatchMeta instance containing only the specified fields from all samples. + BatchMeta: A new BatchMeta instance containing only the specified fields. """ - # select fields for each SampleMeta - new_samples = [sample.select_fields(field_names=field_names) for sample in self.samples] - - # select fields in _custom_backend_meta - selected_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta: - custom_backend_meta_idx = self._custom_backend_meta[idx] - - selected_custom_backend_meta[idx] = { - field: custom_backend_meta_idx[field] for field in field_names if field in custom_backend_meta_idx - } - - # construct new BatchMeta instance - new_batch_meta = BatchMeta( - samples=new_samples, + new_field_schema = {} + for fname in field_names: + if fname in self.field_schema: + new_field_schema[fname] = copy.deepcopy(self.field_schema[fname]) + + selected_custom_backend_meta = [ + {f: v for f, v in m.items() if f.startswith("_") or f in field_names} for m in self._custom_backend_meta + ] + + return BatchMeta( + global_indexes=self.global_indexes, + partition_ids=self.partition_ids, + field_schema=new_field_schema, + production_status=self.production_status.copy(), extra_info=self.extra_info, custom_meta=self.custom_meta, _custom_backend_meta=selected_custom_backend_meta, ) - return new_batch_meta + def with_data_fields(self, field_names: list[str]) -> "BatchMeta": + """Return a new BatchMeta with the given data fields, replacing the current field_schema. - def __len__(self) -> int: - """Return the number of samples in this batch.""" - return len(self.samples) + Unlike ``select_fields``, this method allows specifying field names that are not + yet present in the current ``field_schema`` (e.g. fields added by a subsequent + ``put`` call on a subset of samples). Unknown fields are included in the new + ``field_schema`` with an empty metadata dict so that ``get_data`` can retrieve + them from the storage backend. - def __getitem__(self, item): - if isinstance(item, int | np.integer): - sample_meta = self.samples[item] if self.samples else [] - global_idx = self.global_indexes[item] + Args: + field_names (list[str]): List of field names to request. May include fields + not present in the current ``field_schema``. - if global_idx in self.custom_meta: - custom_meta = {global_idx: self.custom_meta[global_idx]} + Returns: + BatchMeta: A new BatchMeta instance whose ``field_schema`` contains exactly + the requested fields (existing metadata is preserved where available). + """ + new_field_schema = {} + for fname in field_names: + if fname in self.field_schema: + new_field_schema[fname] = copy.deepcopy(self.field_schema[fname]) else: - custom_meta = {} + # Unknown field — include with empty schema so get_data can fetch it. + new_field_schema[fname] = {} - if global_idx in self._custom_backend_meta: - custom_backend_meta = {global_idx: self._custom_backend_meta[global_idx]} - else: - custom_backend_meta = {} + selected_custom_backend_meta = [ + {f: v for f, v in m.items() if f.startswith("_") or f in field_names} for m in self._custom_backend_meta + ] - return BatchMeta( - samples=[sample_meta], - extra_info=self.extra_info, - custom_meta=custom_meta, - _custom_backend_meta=custom_backend_meta, - ) + return BatchMeta( + global_indexes=self.global_indexes, + partition_ids=self.partition_ids, + field_schema=new_field_schema, + production_status=self.production_status.copy(), + extra_info=self.extra_info, + custom_meta=self.custom_meta, + _custom_backend_meta=selected_custom_backend_meta, + ) + + def __len__(self) -> int: + """Return the number of samples in this batch.""" + return self.size + + def __getitem__(self, item) -> "BatchMeta": + if isinstance(item, int | np.integer): + if item < 0: + item += self.size + if item < 0 or item >= self.size: + raise IndexError("BatchMeta index out of range") + return self.select_samples([item]) + elif isinstance(item, slice): + start, stop, step = item.indices(self.size) + indices = list(range(start, stop, step)) + return self.select_samples(indices) else: - raise TypeError(f"Indexing with {type(item)} is not supported now!") + raise TypeError(f"Indexing with {type(item)} is not supported.") def chunk(self, chunks: int) -> list["BatchMeta"]: - """ - Split this batch into smaller chunks. + """Split this batch into smaller chunks. Args: chunks: number of chunks @@ -480,7 +470,7 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: List of smaller BatchMeta chunks """ chunk_list = [] - n = len(self.samples) + n = self.size if n < chunks: logger.warning( @@ -494,47 +484,57 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: start = 0 for i in range(chunks): - # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) current_chunk_size = base_size + 1 if i < remainder else base_size end = start + current_chunk_size - chunk_samples = self.samples[start:end] - global_indexes = self.global_indexes[start:end] - chunk_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} - chunk_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta - } - chunk = BatchMeta( - samples=chunk_samples, - extra_info=self.extra_info, - custom_meta=chunk_custom_meta, - _custom_backend_meta=chunk_custom_backend_meta, - ) + indices = list(range(start, end)) + chunk = self.select_samples(indices) chunk_list.append(chunk) start = end return chunk_list - def chunk_by_partition( - self, - ) -> list["BatchMeta"]: - """ - Split this batch into smaller chunks according to partition_ids. + def chunk_by_partition(self) -> list["BatchMeta"]: + """Split this batch into smaller chunks according to partition_ids. Return: List of smaller BatchMeta chunks, each chunk has samples with identical partition_id """ - grouped_indexes = defaultdict(list) - for partition_id, indexes in zip(self.partition_ids, range(self.size), strict=False): + for partition_id, indexes in zip(self.partition_ids, range(self.size), strict=True): grouped_indexes[partition_id].append(indexes) chunk_list = [self.select_samples(idx) for idx in grouped_indexes.values()] - return chunk_list + def union(self, other: "BatchMeta") -> "BatchMeta": + """Return the union of this BatchMeta and another BatchMeta. + Samples with global_indexes already present in this batch are ignored from the other batch. + + Args: + other: The other BatchMeta to merge with. + + Returns: + BatchMeta: A new merged BatchMeta. + """ + if not other or other.size == 0: + return self + if self.size == 0: + return other + + self_indexes = set(self.global_indexes) + unique_indices_in_other = [i for i, idx in enumerate(other.global_indexes) if idx not in self_indexes] + + if not unique_indices_in_other: + return self + + if len(unique_indices_in_other) == other.size: + return BatchMeta.concat([self, other]) + + other_unique = other.select_samples(unique_indices_in_other) + return BatchMeta.concat([self, other_unique]) + @classmethod def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": - """ - Concatenate multiple BatchMeta chunks into one large batch. + """Concatenate multiple BatchMeta chunks into one large batch. Args: data: List of BatchMeta chunks to concatenate @@ -548,14 +548,14 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": """ if not data: logger.warning("Try to concat empty BatchMeta chunks. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta.empty() # skip empty chunks - data = [chunk for chunk in data if chunk and len(chunk.samples) > 0] + data = [chunk for chunk in data if chunk and chunk.size > 0] if len(data) == 0: logger.warning("No valid BatchMeta chunks to concatenate. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) + return BatchMeta.empty() if validate: base_fields = data[0].field_names @@ -564,25 +564,62 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": if chunk.field_names != base_fields: raise ValueError("Error: Field names do not match for concatenation.") - # Combine all samples - all_samples = list(itertools.chain.from_iterable(chunk.samples for chunk in data)) + # Validate field_schema dtype and is_nested consistency across chunks + for fname in base_fields: + base_meta = data[0].field_schema.get(fname, {}) + base_dtype = base_meta.get("dtype") + base_is_nested = base_meta.get("is_nested", False) + for i, chunk in enumerate(data[1:], start=1): + chunk_meta = chunk.field_schema.get(fname, {}) + if chunk_meta.get("dtype") != base_dtype: + raise ValueError( + f"Field '{fname}' dtype mismatch in concat: " + f"chunk[0]={base_dtype}, chunk[{i}]={chunk_meta.get('dtype')}" + ) + if chunk_meta.get("is_nested", False) != base_is_nested: + raise ValueError( + f"Field '{fname}' is_nested mismatch in concat: " + f"chunk[0]={base_is_nested}, chunk[{i}]={chunk_meta.get('is_nested', False)}" + ) + + all_global_indexes = list(itertools.chain.from_iterable(chunk.global_indexes for chunk in data)) + all_partition_ids = list(itertools.chain.from_iterable(chunk.partition_ids for chunk in data)) + + all_production_status = np.concatenate([chunk.production_status for chunk in data]) + + all_field_schema: dict[str, dict[str, Any]] = {} + first_chunk = data[0] + for fname, meta in first_chunk.field_schema.items(): + all_field_schema[fname] = { + "dtype": meta.get("dtype"), + "shape": meta.get("shape"), + "is_nested": meta.get("is_nested", False), + "is_non_tensor": meta.get("is_non_tensor", False), + } + if any(chunk.field_schema.get(fname, {}).get("per_sample_shapes") for chunk in data): + all_shapes = [] + for chunk in data: + chunk_meta = chunk.field_schema.get(fname, {}) + chunk_shapes = chunk_meta.get("per_sample_shapes") + if chunk_shapes: + all_shapes.extend(chunk_shapes) + else: + all_shapes.extend([None] * chunk.size) + all_field_schema[fname]["per_sample_shapes"] = all_shapes + + all_custom_meta: list[dict[str, Any]] = [] + all_custom_backend_meta: list[dict[str, Any]] = [] + for chunk in data: + all_custom_meta.extend(chunk.custom_meta) + all_custom_backend_meta.extend(chunk._custom_backend_meta) # Merge all extra_info dictionaries from the chunks merged_extra_info = dict() - merged_custom_meta = dict() - merged_custom_backend_meta = dict() values_by_key = defaultdict(list) for chunk in data: - # For the sample-level custom_meta and field-level _custom_backend_meta, we directly update the dict. - merged_custom_meta.update(chunk.custom_meta) - merged_custom_backend_meta.update(chunk._custom_backend_meta) - for key, value in chunk.extra_info.items(): values_by_key[key].append(value) - - # For the batch-level extra_info, we concat the tensor/NonTensorStack/NonTensorData/list - # objects to prevent information losses. for key, values in values_by_key.items(): if all(isinstance(v, torch.Tensor) for v in values): try: @@ -601,161 +638,52 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": elif all(isinstance(v, list) for v in values): merged_extra_info[key] = list(itertools.chain.from_iterable(values)) else: - merged_extra_info[key] = values[-1] - - return BatchMeta( - samples=all_samples, - extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, - ) - - def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: - """ - Create a union of this batch's fields with another batch's fields. - Assume both batches have the same global indices and matching partition_ids for all samples. - If fields overlap, the fields in this batch will be replaced by the other batch's fields. - - Args: - other: Another BatchMeta to union with - validate: Whether to validate union conditions - - Returns: - New BatchMeta with unioned fields - - Raises: - ValueError: If validation fails (e.g., batch sizes or global indexes do not match) - """ - if validate: - if self.size != other.size: - raise ValueError("Error: Batch sizes do not match for union.") - - self_global_indexes = sorted(self.global_indexes) - other_global_indexes = sorted(other.global_indexes) - if self_global_indexes != other_global_indexes: - raise ValueError("Error: Global indexes do not match for union.") - - if self.partition_ids != other.partition_ids: - raise ValueError("Error: Partition IDs do not match for union.") - - # Create a mapping from global_index to SampleMeta in the other batch - other_sample_map = {sample.global_index: sample for sample in other.samples} - - # Merge samples - merged_samples = [] - for sample in self.samples: - if sample.global_index in other_sample_map: - other_sample = other_sample_map[sample.global_index] - merged_sample = sample.union(other_sample, validate=validate) - merged_samples.append(merged_sample) - else: - merged_samples.append(sample) - - # Merge extra info dictionaries - merged_extra_info = {**self.extra_info, **other.extra_info} - - # Merge custom_meta dictionaries - merged_custom_meta = {**self.custom_meta, **other.custom_meta} - - # Merge custom_backend_meta dictionaries - merged_custom_backend_meta = {} - for idx in self.global_indexes: - if idx in self._custom_backend_meta and idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx], **other._custom_backend_meta[idx]} - elif idx in self._custom_backend_meta: - merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx]} - elif idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = {**other._custom_backend_meta[idx]} + raise TypeError( + f"BatchMeta.concat: extra_info key '{key}' has type(s) " + f"{[type(v).__name__ for v in values]} which cannot be concatenated. " + f"Only list-like iterables support concat: list, torch.Tensor, " + f"NonTensorStack/NonTensorData. Scalar types (int, str, float, dict, etc.) " + f"have no defined merge strategy—handle '{key}' manually before calling concat." + ) return BatchMeta( - samples=merged_samples, + global_indexes=all_global_indexes, + partition_ids=all_partition_ids, + field_schema=all_field_schema, + production_status=all_production_status, extra_info=merged_extra_info, - custom_meta=merged_custom_meta, - _custom_backend_meta=merged_custom_backend_meta, + custom_meta=all_custom_meta, + _custom_backend_meta=all_custom_backend_meta, ) - def reorder(self, indexes: list[int]): - """ - Reorder the SampleMeta in the BatchMeta according to the given indexes (must equal to the length of samples). - - The operation is performed in-place, modifying the current BatchMeta's SampleMeta order. - - To select a subset of samples or repeat specific samples, please use the non-inplace method select_samples(). - - Args: - indexes : list[int] - A list of integers specifying the new order of SampleMeta. Each integer - represents the current index of the SampleMeta in the BatchMeta. + def reorder(self, indices: list[int]): + """Reorder the samples in the BatchMeta according to the given indices. + The operation is performed in-place. """ + if len(indices) != self.size: + raise ValueError(f"Indices length {len(indices)} mismatch batch size {self.size}") - if len(indexes) != self.size: - raise ValueError( - f"Attempted to reorder with indexes length {len(indexes)} that does not match samples length " - f"{self.size}. Please use non-inplace method select_samples() instead if you want to " - f"select a subset of samples or repeat specific samples." - ) - - if len(set(indexes)) != self.size: - raise ValueError( - f"Indexes={indexes} contain duplicates. Please use non-inplace method " - f"select_samples() instead if you want to select a subset of samples or repeat specific samples." - ) - - if any(i < 0 or i >= len(self.samples) for i in indexes): - raise ValueError(f"Reorder indexes must be in the range [0, {self.size}).") + if len(set(indices)) != self.size: + raise ValueError("Indices contain duplicates") - # Reorder the samples - reordered_samples = [self.samples[i] for i in indexes] - object.__setattr__(self, "samples", reordered_samples) + if any(i < 0 or i >= self.size for i in indices): + raise ValueError(f"Reorder indices must be in range [0, {self.size})") - # Update necessary attributes - self._update_after_reorder() + self.global_indexes = [self.global_indexes[i] for i in indices] + self.partition_ids = [self.partition_ids[i] for i in indices] - def _update_after_reorder(self) -> None: - """Update related attributes specifically for the reorder operation""" - # Update batch_index for each sample - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) + self.production_status = self.production_status[indices] - # Update cached index lists - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) - object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) + for fname, meta in self.field_schema.items(): + if meta.get("per_sample_shapes") is not None: + meta["per_sample_shapes"] = [meta["per_sample_shapes"][i] for i in indices] - # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder - - @classmethod - def from_samples( - cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None - ) -> "BatchMeta": - """ - Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects. - - Args: - samples: A single SampleMeta or a list of SampleMeta objects - extra_info: Optional additional information to store with the batch - - Returns: - BatchMeta instance containing the provided sample(s) - - Example: - >>> sample_meta = SampleMeta(...) - >>> batch_meta = BatchMeta.from_samples(sample_meta) - - >>> sample_metas = [sample1, sample2, sample3] - >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"}) - """ - if extra_info is None: - extra_info = {} - - if isinstance(samples, SampleMeta): - samples = [samples] - - return cls(samples=samples, extra_info=extra_info) + self.custom_meta = [self.custom_meta[i] for i in indices] + self._custom_backend_meta = [self._custom_backend_meta[i] for i in indices] @classmethod def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": - """ - Create an empty BatchMeta with no samples. + """Create an empty BatchMeta with no samples. Args: extra_info: Optional additional information to store with the batch @@ -768,78 +696,88 @@ def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": """ if extra_info is None: extra_info = {} - return cls(samples=[], extra_info=extra_info, custom_meta={}, _custom_backend_meta={}) + return cls( + global_indexes=[], + partition_ids=[], + field_schema={}, + production_status=None, + extra_info=extra_info, + custom_meta=[], + _custom_backend_meta=[], + ) def __str__(self): - sample_strs = ", ".join(str(sample) for sample in self.samples) return ( f"BatchMeta(size={self.size}, field_names={self.field_names}, is_ready={self.is_ready}, " - f"samples=[{sample_strs}], extra_info={self.extra_info})" + f"global_indexes={self.global_indexes}, extra_info={self.extra_info})" ) + def to_dict(self) -> dict: + """Convert BatchMeta to dict for serialization. + + dtype is explicitly serialized as a string (e.g. "torch.float32", "float64") so + that from_dict() can reconstruct it without relying on pickle to transparently + round-trip torch.dtype / numpy.dtype objects. + """ + serialized_schema = {} + for fname, meta in self.field_schema.items(): + dtype = meta.get("dtype") + serialized_schema[fname] = { + "dtype": str(dtype) if dtype is not None else None, + "shape": list(meta["shape"]) if meta.get("shape") else None, + "is_nested": meta.get("is_nested", False), + "is_non_tensor": meta.get("is_non_tensor", False), + } + if meta.get("per_sample_shapes") is not None: + serialized_schema[fname]["per_sample_shapes"] = [list(s) for s in meta["per_sample_shapes"]] + + return { + "global_indexes": self.global_indexes, + "partition_ids": self.partition_ids, + "field_schema": serialized_schema, + "production_status": self.production_status.tolist() if self.production_status is not None else [], + "extra_info": self.extra_info, + "custom_meta": self.custom_meta, + "_custom_backend_meta": self._custom_backend_meta, + } + @classmethod def from_dict(cls, data: dict) -> "BatchMeta": - """Create BatchMeta from dictionary.""" - samples = [ - SampleMeta.from_dict(sample_data) if isinstance(sample_data, dict) else sample_data - for sample_data in data["samples"] - ] + """Create BatchMeta from dictionary. + + dtype is stored as a string and decoded back to torch.dtype / numpy.dtype here. + """ + field_schema = {} + for fname, meta in data.get("field_schema", {}).items(): + dtype_str = meta.get("dtype") + dtype = _parse_dtype(dtype_str) if dtype_str is not None else None + field_schema[fname] = { + "dtype": dtype, + "shape": tuple(meta["shape"]) if meta.get("shape") else None, + "is_nested": meta.get("is_nested", False), + "is_non_tensor": meta.get("is_non_tensor", False), + } + if meta.get("per_sample_shapes") is not None: + field_schema[fname]["per_sample_shapes"] = [tuple(s) for s in meta["per_sample_shapes"]] + + ps_data = data.get("production_status") + production_status: np.ndarray = ( + np.array(ps_data, dtype=np.int8) + if ps_data is not None + else np.zeros(len(data["global_indexes"]), dtype=np.int8) + ) + return cls( - samples=samples, + global_indexes=data["global_indexes"], + partition_ids=data["partition_ids"], + field_schema=field_schema, + production_status=production_status, extra_info=data.get("extra_info", {}), - custom_meta=data.get("custom_meta", {}), - _custom_backend_meta=data.get("_custom_backend_meta", {}), + custom_meta=data.get("custom_meta", []), + _custom_backend_meta=data.get("_custom_backend_meta", []), ) -def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]: - """Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2.""" - for name in fields2.keys(): - fields1[name] = fields2[name] - return fields1 - - -def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]: - """ - Extract field metas from a TensorDict. If data in tensor_dict does not have dtype or shape attribute, - the corresponding dtype or shape will be set to None. - - Args: - tensor_dict (TensorDict): The input TensorDict. - set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. - Otherwise, set to NOT_PRODUCED. Default is True. - - Returns: - all_fields (list[dict[str, FieldMeta]]): A list of dictionaries containing field metadata. - """ - batch_size = tensor_dict.batch_size[0] - - production_status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED - - # unbind nested tensor - results: dict = {} - for field in tensor_dict.keys(): - field_data = tensor_dict[field] - if batch_size > 1 and isinstance(field_data, Tensor) and field_data.is_nested: - results[field] = field_data.unbind() - else: - results[field] = field_data - - all_fields = [] - for idx in range(batch_size): - dict_of_field_meta = {} - for field_name in results.keys(): - dict_of_field_meta[field_name] = FieldMeta( - name=field_name, - dtype=getattr(results[field_name][idx], "dtype", None), - shape=getattr(results[field_name][idx], "shape", None), - production_status=production_status, - ) - all_fields.append(dict_of_field_meta) - - return all_fields - - # ==================== KV Interface Metadata ==================== @dataclass class KVBatchMeta: @@ -889,9 +827,7 @@ def __str__(self): return f"KVBatchMeta(size={self.size}, field_names={self.fields}, extra_info={self.extra_info})" def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": - """ - Select specific keys from this batch. - This will construct a new KVBatchMeta instance containing only the specified keys. + """Select specific keys from this batch. Args: keys_to_select (list[str]): List of keys to retain. @@ -903,7 +839,6 @@ def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": ValueError: If duplicate keys exist in input param `keys_to_select`. RuntimeError: If `keys_to_select` contains keys that do not exist in this batch. """ - if len(set(keys_to_select)) != len(keys_to_select): raise ValueError("Contain duplicate keys.") @@ -925,14 +860,13 @@ def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": ) def reorder(self, indexes: list[int]): - """ - Reorder the samples in this batch according to the specified indexes. + """Reorder the samples in this batch according to the specified indexes. The operation is performed in-place. Args: indexes : list[int] - A list of integers specifying the new order of SampleMeta. + A list of integers specifying the new order of samples. Raises: ValueError: If the size of input `indexes` does not match with the batch size. @@ -951,8 +885,7 @@ def reorder(self, indexes: list[int]): self.tags = [self.tags[i] for i in indexes] def chunk(self, chunks: int) -> list["KVBatchMeta"]: - """ - Split this batch into smaller chunks. + """Split this batch into smaller chunks. Args: chunks: number of chunks @@ -960,7 +893,6 @@ def chunk(self, chunks: int) -> list["KVBatchMeta"]: Return: List of smaller KVBatchMeta chunks """ - chunk_list = [] if self.size < chunks: logger.warning( @@ -974,7 +906,6 @@ def chunk(self, chunks: int) -> list["KVBatchMeta"]: start = 0 for i in range(chunks): - # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) current_chunk_size = base_size + 1 if i < remainder else base_size end = start + current_chunk_size chunk_keys = self.keys[start:end] @@ -994,8 +925,7 @@ def chunk(self, chunks: int) -> list["KVBatchMeta"]: @classmethod def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta": - """ - Concatenate multiple KVBatchMeta chunks into one large batch. + """Concatenate multiple KVBatchMeta chunks into one large batch. Args: data: List of KVBatchMeta chunks to concatenate diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 2e9ce30..75a21e4 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -14,7 +14,6 @@ # limitations under the License. import asyncio -import copy import itertools import logging import os @@ -77,10 +76,8 @@ def _connect_to_controller(self) -> None: raise ValueError(f"controller_info should be ZMQServerInfo, but got {type(self.controller_info)}") try: - # create zmq context self.zmq_context = zmq.Context() - # create zmq sockets for handshake and data status update self.controller_handshake_socket = create_zmq_socket( self.zmq_context, zmq.DEALER, @@ -94,7 +91,6 @@ def _connect_to_controller(self) -> None: assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket")) - # do handshake with controller self._do_handshake_with_controller() except Exception as e: @@ -118,7 +114,6 @@ def _do_handshake_with_controller(self) -> None: ) poller.register(self.controller_handshake_socket, zmq.POLLIN) - # Initial handshake request send self._send_handshake_requests() start_time = time.time() @@ -128,7 +123,6 @@ def _do_handshake_with_controller(self) -> None: not is_connected # Only one controller to connect to and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT ): - # Check for timeout and retransmission current_time = time.time() if pending_connection: if ( @@ -210,7 +204,6 @@ async def notify_data_update( shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. custom_backend_meta: Per-field custom_meta for each sample, in {global_index: {field: custom_meta}} format. """ - # Create zmq poller for notifying data update information if not self.controller_info: logger.warning(f"No controller connected for storage manager {self.storage_manager_id}") @@ -256,7 +249,6 @@ async def notify_data_update( self.data_status_update_socket.send_multipart(request_msg) - # Make sure controller successfully receives data status update information. response_received: bool = False start_time = time.time() @@ -360,7 +352,6 @@ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): super().__init__(controller_info, config) self.storage_client = StorageClientFactory.create(client_name, config) self._multi_threads_executor: Optional[ThreadPoolExecutor] = None - # Register a cleanup function: automatically invoke shutdown when the instance is garbage collected. self._executor_finalizer = weakref.finalize(self, self._shutdown_executor, self._multi_threads_executor) @staticmethod @@ -500,7 +491,6 @@ def process_field(field_idx: int): # Prioritize processing fields with larger tensor sizes to improve parallel efficiency field_sizes = [] for i in range(num_fields): - # Estimate size based on the first value _first_value = values[i * num_samples] if isinstance(_first_value, torch.Tensor): size = _first_value.nelement() * _first_value.element_size() @@ -521,6 +511,8 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): Extract the expected shape, dtype, and custom_backend_meta for each field-sample pair in metadata. The order matches the key/value order: sorted by field name, then by global index. + O(F) optimized version that uses field_schema instead of per-sample metadata. + Args: metadata (BatchMeta): Metadata containing sample and field information. Returns: @@ -530,30 +522,33 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): shapes = [] dtypes = [] custom_backend_meta_list = [] - all_custom_backend_meta = copy.deepcopy(metadata._custom_backend_meta) + num_samples = len(metadata) + for field_name in sorted(metadata.field_names): - for index in range(len(metadata)): - field = metadata.samples[index].get_field_by_name(field_name) - assert field is not None, f"Field {field_name} not found in sample {index}" - shapes.append(field.shape) - dtypes.append(field.dtype) - global_index = metadata.global_indexes[index] - custom_backend_meta_list.append(all_custom_backend_meta.get(global_index, {}).get(field_name, None)) + field_meta = metadata.field_schema.get(field_name, {}) + field_shape = field_meta.get("shape") + field_dtype = field_meta.get("dtype") + per_sample_shapes = field_meta.get("per_sample_shapes") + + for index in range(num_samples): + if per_sample_shapes is not None: + shapes.append(per_sample_shapes[index]) + else: + shapes.append(field_shape) + dtypes.append(field_dtype) + custom_backend_meta_list.append(metadata._custom_backend_meta[index].get(field_name, None)) return shapes, dtypes, custom_backend_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ Store tensor data in the backend storage and notify the controller. - Serializes the input tensors, stores them using the storage client, - extracts per-sample dtype and shape information, and sends a notification - to the controller that new data is available. + O(F) optimized version that extracts field-level schema instead of per-sample metadata. """ if not metadata.field_names: logger.warning("Attempted to put data, but metadata contains no fields.") return - # For each field, extract dtype and shape for each sample num_samples = len(metadata.global_indexes) if num_samples == 0: return @@ -562,37 +557,40 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: values = self._generate_values(data) loop = asyncio.get_event_loop() - # put to storage backends custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) - per_field_dtypes: dict[int, dict[str, Any]] = {} - per_field_shapes: dict[int, dict[str, Any]] = {} + # O(F): Extract field-level schema by sampling the first item + field_schema: dict[str, dict[str, Any]] = {} + for field_name, field_data in data.items(): + first_item = field_data[0] if len(field_data) > 0 else None - # Initialize the data structure for each global index - for global_idx in metadata.global_indexes: - per_field_dtypes[global_idx] = {} - per_field_shapes[global_idx] = {} + is_nested = isinstance(field_data, torch.Tensor) and field_data.is_nested - for field_name, field_data in data.items(): - for i in range(num_samples): - data_item = field_data[i] - global_idx = metadata.global_indexes[i] - per_field_dtypes[global_idx][field_name] = ( - getattr(data_item, "dtype", None) if isinstance(data_item, Tensor) else None - ) - per_field_shapes[global_idx][field_name] = ( - getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None - ) + is_non_tensor = not isinstance(first_item, Tensor) if first_item is not None else False + + field_meta = { + "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "shape": getattr(first_item, "shape", None) + if not is_nested and isinstance(first_item, Tensor) + else None, + "is_nested": is_nested, + "is_non_tensor": is_non_tensor, + } + + if is_nested: + field_meta["per_sample_shapes"] = [tuple(t.shape) for t in field_data.unbind()] + + field_schema[field_name] = field_meta - # Prepare per-field custom_backend_meta if available per_field_custom_backend_meta: dict[int, dict[str, Any]] = {} if custom_backend_meta: if len(custom_backend_meta) != len(keys): raise ValueError( f"Length of custom_backend_meta ({len(custom_backend_meta)}) does not match expected ({len(keys)})" ) - # custom meta is a flat list aligned with keys/values - # Use itertools.product to eliminate nested loops + # Build gi→position index for O(1) lookup + gi_to_pos = {gi: i for i, gi in enumerate(metadata.global_indexes)} + for global_idx in metadata.global_indexes: per_field_custom_backend_meta[global_idx] = {} @@ -605,13 +603,36 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: strict=True, ): per_field_custom_backend_meta[global_idx][field_name] = meta_value - metadata._custom_backend_meta.update(per_field_custom_backend_meta) + # Also update columnar _custom_backend_meta + metadata._custom_backend_meta[gi_to_pos[global_idx]][field_name] = meta_value # Get current data partition id - # Note: Currently we only support putting to & getting data from a single data partition simultaneously, - # but in the future we may support putting to & getting data from multiple data partitions concurrently. - partition_id = metadata.samples[0].partition_id - # notify controller that new data is ready + partition_id = metadata.partition_ids[0] + + # Build per-sample per-field dtypes and shapes for controller notification + per_field_dtypes: dict[int, dict[str, Any]] = {} + per_field_shapes: dict[int, dict[str, Any]] = {} + for field_name, field_data in data.items(): + first_item = field_data[0] if len(field_data) > 0 else None + is_nested = isinstance(field_data, torch.Tensor) and field_data.is_nested + field_dtype = getattr(first_item, "dtype", type(first_item) if first_item is not None else None) + field_shape = ( + getattr(first_item, "shape", None) if not is_nested and isinstance(first_item, Tensor) else None + ) + # Pre-compute unbind once to avoid O(B²) repeated calls inside the loop + unbound = field_data.unbind() if is_nested else None + + for i, global_idx in enumerate(metadata.global_indexes): + if global_idx not in per_field_dtypes: + per_field_dtypes[global_idx] = {} + per_field_shapes[global_idx] = {} + per_field_dtypes[global_idx][field_name] = field_dtype + if is_nested: + assert unbound is not None # is_nested=True implies unbind() was called + per_field_shapes[global_idx][field_name] = tuple(unbound[i].shape) + else: + per_field_shapes[global_idx][field_name] = field_shape + await self.notify_data_update( partition_id, list(data.keys()), diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index b658ba4..ce82286 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -17,6 +17,7 @@ import logging import os import warnings +from collections import defaultdict from collections.abc import Mapping from functools import wraps from operator import itemgetter @@ -27,13 +28,10 @@ import zmq from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict -from torch import Tensor from transfer_queue.metadata import BatchMeta from transfer_queue.storage.managers.base import TransferQueueStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory -from transfer_queue.storage.simple_backend import StorageMetaGroup -from transfer_queue.utils.common import get_env_bool from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket logger = logging.getLogger(__name__) @@ -48,8 +46,6 @@ TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds -TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) - @TransferQueueStorageManagerFactory.register("SimpleStorage") class AsyncSimpleStorageManager(TransferQueueStorageManager): @@ -78,7 +74,6 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): raise ValueError("AsyncSimpleStorageManager requires non-empty 'zmq_info' in config.") self.storage_unit_infos = self._register_servers(server_infos) - self._build_storage_mapping_functions() def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"): """Register and validate server information. @@ -107,16 +102,6 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn return server_infos_transform - def _build_storage_mapping_functions(self): - """Build mapping functions for global index to storage unit and local index. - - Creates round-robin mapping functions to distribute data across storage units. - """ - self.global_index_storage_unit_mapping = lambda x: list(self.storage_unit_infos.keys())[ - x % len(self.storage_unit_infos) - ] - self.global_index_local_index_mapping = lambda x: x // len(self.storage_unit_infos) - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. @staticmethod def dynamic_storage_manager_socket(socket_name: str): @@ -190,6 +175,9 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ Send data to remote StorageUnit based on metadata. + Optimized version using TensorDict slicing and unified async processing. + Complexity: O(F) for schema extraction + O(S) for data distribution. + Args: data: TensorDict containing the data to store. metadata: BatchMeta containing storage location information. @@ -197,58 +185,131 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: logger.debug(f"[{self.storage_manager_id}]: receive put_data request, putting {metadata.size} samples.") - # group samples by storage unit - storage_meta_groups = build_storage_meta_groups( - metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping - ) - - # unbind nested tensor - results: dict = {} - for field in data.keys(): - field_data = data[field] - if data.batch_size[0] > 1 and isinstance(field_data, Tensor) and field_data.is_nested: - results[field] = field_data.unbind() - else: - results[field] = field_data - - # send data to each storage unit - tasks = [ - self._put_to_single_storage_unit( - meta_group.get_local_indexes(), - _filter_storage_data(meta_group, results), - target_storage_unit=storage_id, + storage_unit_keys = list(self.storage_unit_infos.keys()) + num_units = len(storage_unit_keys) + batch_size = metadata.size + + if batch_size == 0: + return + + chunk_size = (batch_size + num_units - 1) // num_units + + field_schema = self._extract_field_schema(data) + + su_to_gis: dict[str, list[int]] = {} + tasks = [] + for unit_idx, storage_id in enumerate(storage_unit_keys): + start = unit_idx * chunk_size + end = min((unit_idx + 1) * chunk_size, batch_size) + if start >= batch_size or start >= end: + continue + gi_slice = metadata.global_indexes[start:end] + su_to_gis[storage_id] = list(gi_slice) + tasks.append( + self._prepare_and_send_to_unit( + unit_idx=unit_idx, + storage_id=storage_id, + chunk_size=chunk_size, + batch_size=batch_size, + start_offset=0, # fixed; no cross-batch rotation + num_units=num_units, + data=data, + metadata=metadata, + ) ) - for storage_id, meta_group in storage_meta_groups.items() - ] + await asyncio.gather(*tasks) - # Gather per-field dtype and shape information for each field - # global_indexes, local_indexes, and field_data correspond one-to-one - per_field_dtypes: dict[int, dict[str, Any]] = {} - per_field_shapes: dict[int, dict[str, Any]] = {} - - # Initialize the data structure for each global index - for global_idx in metadata.global_indexes: - per_field_dtypes[global_idx] = {} - per_field_shapes[global_idx] = {} - - # For each field, extract dtype and shape for each sample - for field in results.keys(): - for i, data_item in enumerate(results[field]): - global_idx = metadata.global_indexes[i] - per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None - per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None - - # Get current data partition id - # Note: Currently we only support putting to & getting data from a single data partition simultaneously, - # but in the future we may support putting to & getting data from multiple data partitions concurrently. - partition_id = metadata.samples[0].partition_id - - # notify controller that new data is ready + partition_id = metadata.partition_ids[0] + dtypes_for_notify = { + gi: {fname: fmeta.get("dtype") for fname, fmeta in field_schema.items()} for gi in metadata.global_indexes + } + shapes_for_notify = { + gi: {fname: fmeta.get("shape") for fname, fmeta in field_schema.items()} for gi in metadata.global_indexes + } + backend_meta = {gi: {"_su_id": storage_id} for storage_id, gi_list in su_to_gis.items() for gi in gi_list} await self.notify_data_update( - partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes + partition_id, + list(data.keys()), + metadata.global_indexes, + dtypes_for_notify, + shapes_for_notify, + custom_backend_meta=backend_meta, ) + async def _prepare_and_send_to_unit( + self, + unit_idx: int, + storage_id: str, + chunk_size: int, + batch_size: int, + start_offset: int, + num_units: int, + data: TensorDict, + metadata: BatchMeta, + ) -> None: + """Prepare data slice and send to a single storage unit. + + All operations use O(1) slicing. Returns early if this unit has no data assigned. + """ + rotated_idx = (unit_idx - start_offset) % num_units + start = rotated_idx * chunk_size + end = min((rotated_idx + 1) * chunk_size, batch_size) + + if start >= batch_size or start >= end: + return + + # global_index is used directly as dict key in storage, no local_index conversion needed + global_indexes_slice = metadata.global_indexes[start:end] + local_indexes = list(global_indexes_slice) + + storage_data = {} + for fname in data.keys(): + field_data = data[fname] + if isinstance(field_data, torch.Tensor) and field_data.is_nested: + # CPU NestedTensor does not support slicing; unbind first then index + unbound = field_data.unbind() + storage_data[fname] = unbound[start:end] + else: + storage_data[fname] = field_data[start:end] + + await self._put_to_single_storage_unit(local_indexes, storage_data, target_storage_unit=storage_id) + + def _extract_field_schema(self, data: TensorDict) -> dict[str, dict[str, Any]]: + """Extract field-level schema from TensorDict. O(F) complexity.""" + field_schema: dict[str, dict[str, Any]] = {} + + for field_name in data.keys(): + field_data = data[field_name] + + # NestedTensor does not support len()/indexing; check is_nested then unbind + is_tensor = isinstance(field_data, torch.Tensor) + is_nested = is_tensor and field_data.is_nested + + if is_nested: + unbound = field_data.unbind() + first_item = unbound[0] if unbound else None + elif is_tensor: + first_item = field_data[0] if field_data.shape[0] > 0 else None + else: + first_item = field_data[0] if len(field_data) > 0 else None + + is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else False + + field_meta = { + "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "shape": getattr(first_item, "shape", None) if is_tensor and not is_nested else None, + "is_nested": is_nested, + "is_non_tensor": is_non_tensor, + } + + if is_nested: + field_meta["per_sample_shapes"] = [tuple(t.shape) for t in unbound] + + field_schema[field_name] = field_meta + + return field_schema + @dynamic_storage_manager_socket(socket_name="put_get_socket") async def _put_to_single_storage_unit( self, @@ -286,6 +347,9 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: """ Retrieve data from remote StorageUnit based on metadata. + Routes to each SU using the _su_id recorded in metadata._custom_backend_meta + at put time. No re-computation of block allocation. + Args: metadata: BatchMeta that contains metadata for data retrieval. @@ -295,20 +359,25 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: logger.debug(f"[{self.storage_manager_id}]: receive get_data request, getting {metadata.size} samples.") - # group samples by storage unit - storage_meta_groups = build_storage_meta_groups( - metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping - ) + if metadata.size == 0: + return TensorDict({}, batch_size=0) + + groups: dict[str, list[int]] = defaultdict(list) + for i, gi in enumerate(metadata.global_indexes): + backend_meta = metadata._custom_backend_meta[i] + if not backend_meta or "_su_id" not in backend_meta: + raise RuntimeError( + f"get_data: missing _su_id for global_index {gi} in _custom_backend_meta. " + f"Make sure put_data was called before get_data." + ) + groups[backend_meta["_su_id"]].append(gi) - # retrieve data tasks = [ - self._get_from_single_storage_unit(meta_group, target_storage_unit=storage_id) - for storage_id, meta_group in storage_meta_groups.items() + self._get_from_single_storage_unit(gi_list, metadata.field_names, target_storage_unit=su_id) + for su_id, gi_list in groups.items() ] - results = await asyncio.gather(*tasks) - # post-process data segments to generate a batch of data merged_data: dict[int, dict[str, torch.Tensor]] = {} for global_indexes, fields, data_from_single_storage_unit, messages in results: field_getter = itemgetter(*fields) @@ -329,7 +398,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: ordered_data[field] = [merged_data[global_idx][field] for global_idx in metadata.global_indexes] # In the final packing stage we intentionally perform a memory copy through torch.stack and as_nested_tensor. - # This detaches the received tensors from the original zero‑copy buffers, + # This detaches the received tensors from the original zero-copy buffers, # gives them their own lifetime, and ensures the resulting tensors are writable. tensor_data = { field: ( @@ -350,17 +419,18 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: @dynamic_storage_manager_socket(socket_name="put_get_socket") async def _get_from_single_storage_unit( - self, storage_meta_group: StorageMetaGroup, target_storage_unit: str, socket: zmq.Socket = None + self, + gi_list: list[int], + fields: list[str], + target_storage_unit: str, + socket: zmq.Socket = None, ): - global_indexes = storage_meta_group.get_global_indexes() - local_indexes = storage_meta_group.get_local_indexes() - fields = storage_meta_group.get_field_names() - + """Get data from a single SU by gi keys.""" request_msg = ZMQMessage.create( request_type=ZMQRequestType.GET_DATA, # type: ignore[arg-type] sender_id=self.storage_manager_id, receiver_id=target_storage_unit, - body={"local_indexes": local_indexes, "fields": fields}, + body={"local_indexes": gi_list, "fields": fields}, ) try: await socket.send_multipart(request_msg.serialize()) @@ -372,7 +442,7 @@ async def _get_from_single_storage_unit( # We need to return messages to get_data() since the zero-copy deserialization directly points to the # memory of messages object. storage_unit_data = response_msg.body["data"] - return global_indexes, fields, storage_unit_data, messages + return gi_list, fields, storage_unit_data, messages else: raise RuntimeError( f"Failed to get data from storage unit {target_storage_unit}: " @@ -384,21 +454,29 @@ async def _get_from_single_storage_unit( async def clear_data(self, metadata: BatchMeta) -> None: """Clear data in remote StorageUnit. + Routes to each SU using the _su_id recorded in metadata._custom_backend_meta. + Args: metadata: BatchMeta that contains metadata for data clearing. """ logger.debug(f"[{self.storage_manager_id}]: receive clear_data request, clearing {metadata.size} samples.") - # group samples by storage unit - storage_meta_groups = build_storage_meta_groups( - metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping - ) + if metadata.size == 0: + return + + groups: dict[str, list[int]] = defaultdict(list) + for i, gi in enumerate(metadata.global_indexes): + backend_meta = metadata._custom_backend_meta[i] + if not backend_meta or "_su_id" not in backend_meta: + raise RuntimeError( + f"clear_data: missing _su_id for global_index {gi} in _custom_backend_meta. " + f"Make sure put_data was called before clear_data." + ) + groups[backend_meta["_su_id"]].append(gi) - # clear data tasks = [ - self._clear_single_storage_unit(meta_group.get_local_indexes(), target_storage_unit=storage_id) - for storage_id, meta_group in storage_meta_groups.items() + self._clear_single_storage_unit(gi_list, target_storage_unit=su_id) for su_id, gi_list in groups.items() ] results = await asyncio.gather(*tasks, return_exceptions=True) @@ -442,115 +520,3 @@ def get_zmq_server_info(self) -> dict[str, ZMQServerInfo]: def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" super().close() - - -def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: dict) -> dict[str, Any]: - """Filter batch-aligned data from a dict using batch indexes from a StorageMetaGroup. - This helper extracts a subset of items from each field in ``data`` according to the - batch indexes stored in ``storage_meta_group``. The same indexes are applied to every - field in the input dict so that the returned samples remain aligned across - fields. - - Args: - storage_meta_group: A :class:`StorageMetaGroup` instance that provides - a sequence of batch indexes via :meth:`get_batch_indexes`. Each index - refers to a position along the batch dimension of the tensors stored - in ``data``. - data: A dict containing batched data fields. All fields are expected to - be indexable by the batch indexes returned by - ``storage_meta_group.get_batch_indexes()``. - Returns: - dict[str, Any]: A dictionary mapping each field name in ``data`` to a list - of items selected at the requested batch indexes. The order of items in - each list matches the order of ``storage_meta_group.get_batch_indexes()``. - """ - - # We use dict here instead of TensorDict to avoid unnecessary TensorDict overhead - results: dict[str, Any] = {} - batch_indexes = storage_meta_group.get_batch_indexes() - - if not batch_indexes: - return results - - for fname in data.keys(): - field_data = data[fname] - result = itemgetter(*batch_indexes)(field_data) - - if not isinstance(result, tuple): - result = (result,) - results[fname] = list(result) - - if not TQ_ZERO_COPY_SERIALIZATION: - # Explicitly copy tensor slices to prevent pickling the whole tensor for every storage unit. - # The tensors may still be contiguous, so we cannot use .contiguous() to trigger copy from parent tensors. - results[fname] = [item.clone() if isinstance(item, torch.Tensor) else item for item in results[fname]] - return results - - -def build_storage_meta_groups( - batch_meta: BatchMeta, - global_index_storage_unit_mapping: Callable, - global_index_local_index_mapping: Callable, -) -> dict[str, StorageMetaGroup]: - """Build storage meta groups from batch metadata for distributed storage. - - This function is the starting point of the data distribution workflow. It analyzes - BatchMeta containing SampleMeta objects (originating from client requests) and - groups them by target storage unit based on their global_index. - - Key Data Flow: - 1. BatchMeta contains SampleMeta objects with batch_index (original TensorDict position) - 2. Each SampleMeta is assigned to a storage unit using global_index mapping - 3. Local storage positions are calculated for each sample - 4. Results in StorageMetaGroup objects ready for transfer operations - - Args: - batch_meta: BatchMeta containing SampleMeta objects from client request. - Each SampleMeta has: - - batch_index: Position in original TensorDict (0-based) - - global_index: Global unique identifier across all storage - global_index_storage_unit_mapping: Function to map global_index to storage_unit_id. - Example: lambda x: storage_unit_ids[x % num_storage_units] (round-robin distribution) - global_index_local_index_mapping: Function to map global_index to local_index. - Example: lambda x: x // num_storage_units (local position within storage unit) - - Returns: - Dictionary mapping storage_unit_id to StorageMetaGroup, where each group contains: - - storage_id: Target storage unit identifier - - sample_metas: List of SampleMeta objects assigned to this unit - - local_indexes: List of storage positions for each sample - - Example: - >>> # Input: BatchMeta with samples at global_indexes [10, 11, 12] - >>> # 3 storage units available: storage_0, storage_1, storage_2 - >>> batch_meta = BatchMeta(samples=[ - ... SampleMeta(batch_index=0, global_index=10), # Original position 0 - ... SampleMeta(batch_index=1, global_index=11), # Original position 1 - ... SampleMeta(batch_index=2, global_index=12) # Original position 2 - ... ]) - >>> groups = build_storage_meta_groups( - ... batch_meta, - ... lambda x: f"storage_{x % 3}", # 10->storage_1, 11->storage_2, 12->storage_0 - ... lambda x: x // 3 # 10->3, 11->3, 12->4 - ... ) - >>> groups["storage_1"].sample_metas[0].batch_index # 0 - original TensorDict position - >>> groups["storage_1"].sample_metas[0].local_index # 3 - storage position - - Note: - This function preserves the crucial batch_index information that links each - SampleMeta back to its original position in the client's TensorDict. - This batch_index is later used by _add_field_data() to extract - the correct data items for storage. - """ - storage_meta_groups: dict[str, StorageMetaGroup] = {} - - for sample in batch_meta.samples: - storage_id = global_index_storage_unit_mapping(sample.global_index) - local_index = global_index_local_index_mapping(sample.global_index) - if storage_id not in storage_meta_groups: - storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id) - - # Use add_sample_meta to store SampleMeta references directly - storage_meta_groups[storage_id].add_sample_meta(sample, local_index) - - return storage_meta_groups diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index 37dcca4..be72edb 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -17,7 +17,6 @@ import logging import os from dataclasses import dataclass -from operator import itemgetter from threading import Thread from typing import Any from uuid import uuid4 @@ -26,7 +25,6 @@ import zmq from ray.util import get_node_ip_address -from transfer_queue.metadata import SampleMeta from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads from transfer_queue.utils.enum_utils import TransferQueueRole from transfer_queue.utils.perf_utils import IntervalPerfMonitor @@ -48,100 +46,76 @@ class StorageUnitData: """Storage unit for managing 2D data structure (samples × fields). - This class provides efficient storage and retrieval of data in a 2D matrix format - where rows represent samples (indexed by local_index) and columns represent fields. - Each field contains a list of data items indexed by their local position. + Uses dict-based storage keyed by global_index (gi) instead of pre-allocated list. + This allows O(1) insert/delete without index translation and avoids capacity bloat. Data Structure Example: - ┌─────────────┬─────────────┬─────────────┬─────────┐ - │ local_index │ field_name1 │ field_name2 │ ... │ - ├─────────────┼─────────────┼─────────────┼─────────┤ - │ 0 │ item1 │ item2 │ ... │ - │ 1 │ item3 │ item4 │ ... │ - │ 2 │ item5 │ item6 │ ... │ - └─────────────┴─────────────┴─────────────┴─────────┘ + field_data = { + "field_name1": {gi0: item1, gi3: item2, ...}, + "field_name2": {gi0: item3, gi3: item4, ...}, + } """ def __init__(self, storage_size: int): - # Dict containing field names and corresponding data in the field - # Format: {"field_name": [data_at_index_0, data_at_index_1, ...]} - self.field_data: dict[str, list] = {} - - # Maximum number of elements stored in storage unit + # field_name -> {gi: data} nested dict + self.field_data: dict[str, dict] = {} + # Capacity upper bound (not pre-allocated list length) self.storage_size = storage_size - def get_data(self, fields: list[str], local_indexes: list[int]) -> dict[str, list]: - """ - Get data from storage unit according to given fields and local_indexes. + def get_data(self, fields: list[str], local_keys: list) -> dict[str, list]: + """Get data by gi keys. Args: fields: Field names used for getting data. - local_indexes: Local indexes used for getting data. + local_keys: Global indexes used as dict keys. Returns: dict with field names as keys, corresponding data list as values. """ result: dict[str, list] = {} - for field in fields: - # Validate field name if field not in self.field_data: raise ValueError( - f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}" + f"StorageUnitData get_data: field '{field}' not found. Available: {list(self.field_data.keys())}" ) - - if len(local_indexes) == 1: - gathered_item = self.field_data[field][local_indexes[0]] - result[field] = [gathered_item] - - else: - gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) - - result[field] = gathered_items - + try: + result[field] = [self.field_data[field][k] for k in local_keys] + except KeyError as e: + raise KeyError(f"StorageUnitData get_data: key {e} not found in field '{field}'") from e return result - def put_data(self, field_data: dict[str, Any], local_indexes: list[int]) -> None: - """ - Put or update data into storage unit according to given field_data and local_indexes. + def put_data(self, field_data: dict[str, Any], local_keys: list) -> None: + """Put data into storage. local_keys are global_indexes used as dict keys. Args: - field_data: Dict with field names as keys, corresponding data in the field as values. - local_indexes: Local indexes used for putting data. + field_data: Dict with field names as keys, data list as values. + local_keys: Global indexes to use as dict keys. """ - + # Capacity is enforced per unique sample key, not counted per-field + existing_keys: set = set() + for fd in self.field_data.values(): + existing_keys.update(fd.keys()) + new_global_keys = [k for k in local_keys if k not in existing_keys] + if len(existing_keys) + len(new_global_keys) > self.storage_size: + raise ValueError( + f"Storage capacity exceeded: {len(existing_keys)} existing + " + f"{len(new_global_keys)} new > {self.storage_size}" + ) for f, values in field_data.items(): if f not in self.field_data: - self.field_data[f] = [None] * self.storage_size - - for i, idx in enumerate(local_indexes): - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) + self.field_data[f] = {} + for key, val in zip(local_keys, values, strict=False): + self.field_data[f][key] = val - self.field_data[f][idx] = values[i] - - def clear(self, local_indexes: list[int]) -> None: - """ - Clear data at specified local_indexes by setting all related fields to None. + def clear(self, keys: list[int]) -> None: + """Remove data at given global index keys, immediately freeing memory. Args: - local_indexes: local_indexes to clear. + keys: Global indexes to remove. """ - # Validate local_indexes - for idx in local_indexes: - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData clear operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) - - # Clear data at specified local_indexes for f in self.field_data: - for idx in local_indexes: - self.field_data[f][idx] = None + for key in keys: + self.field_data[f].pop(key, None) @ray.remote(num_cpus=1) @@ -241,7 +215,7 @@ def _process_put_get(self) -> None: response_msg = self._handle_clear(request_msg) else: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, + request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Storage unit id #{self.storage_unit_id} " @@ -250,7 +224,7 @@ def _process_put_get(self) -> None: ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_ERROR, + request_type=ZMQRequestType.PUT_GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Storage unit id #{self.storage_unit_id} occur error in processing " @@ -280,13 +254,15 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: # After put operation finish, send a message to the client response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.storage_unit_id, body={} + request_type=ZMQRequestType.PUT_DATA_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={}, ) return response_msg except Exception as e: return ZMQMessage.create( - request_type=ZMQRequestType.PUT_ERROR, + request_type=ZMQRequestType.PUT_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to put data into storage unit id " @@ -314,7 +290,7 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: result_data = self.storage_data.get_data(fields, local_indexes) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA_RESPONSE, + request_type=ZMQRequestType.GET_DATA_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "data": result_data, @@ -322,7 +298,7 @@ def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_ERROR, + request_type=ZMQRequestType.GET_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to get data from storage unit id #{self.storage_unit_id}, " @@ -350,13 +326,13 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: self.storage_data.clear(local_indexes) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, + request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."}, ) except Exception as e: response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_ERROR, + request_type=ZMQRequestType.CLEAR_DATA_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ "message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, " @@ -376,48 +352,52 @@ def get_zmq_server_info(self) -> ZMQServerInfo: @dataclass class StorageMetaGroup: - """ - Represents a group of samples stored in the same storage unit. - Used to organize samples by their storage_id for efficient client operations. - """ + """Group of metadata for a specific storage unit.""" storage_id: str - sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list) - local_indexes: list[int] = dataclasses.field(default_factory=list) + global_indexes: list[int] = dataclasses.field(default_factory=list) + partition_ids: list[str] = dataclasses.field(default_factory=list) + batch_indexes: list[int] = dataclasses.field(default_factory=list) # Original TensorDict positions + field_names: list[str] = dataclasses.field(default_factory=list) # Field names from BatchMeta - def add_sample_meta(self, sample_meta: SampleMeta, local_index: int) -> None: - """Add a SampleMeta object to this storage group""" - self.sample_metas.append(sample_meta) - self.local_indexes.append(local_index) + def add_meta(self, global_index: int, partition_id: str, batch_index: int | None = None): + """Add metadata to the group. - def get_batch_indexes(self) -> list[int]: - """Get all internal indexes from stored SampleMeta objects""" - return [meta.batch_index for meta in self.sample_metas] + Args: + global_index: Global unique index across all storage + partition_id: Partition identifier + batch_index: Original position in input TensorDict (optional) + """ + self.global_indexes.append(global_index) + self.partition_ids.append(partition_id) + if batch_index is not None: + self.batch_indexes.append(batch_index) def get_global_indexes(self) -> list[int]: - """Get all global indexes from stored SampleMeta objects""" - return [meta.global_index for meta in self.sample_metas] + """Get all global indexes from stored samples""" + return self.global_indexes + + def get_storage_keys(self) -> list[int]: + """Return global indexes used as storage dict keys.""" + return self.global_indexes - def get_local_indexes(self) -> list[int]: - """Get all local indexes from stored SampleMeta objects""" - return self.local_indexes + def get_batch_indexes(self) -> list[int]: + """Get original TensorDict position indexes for _filter_storage_data.""" + return self.batch_indexes def get_field_names(self) -> list[str]: - """Get all unique field names from stored SampleMeta objects""" - all_fields: set[str] = set() - for meta in self.sample_metas: - all_fields.update(meta.fields.keys()) - return list(all_fields) + """Get all field names for this storage group.""" + return self.field_names @property def size(self) -> int: """Number of samples in this storage meta group""" - return len(self.sample_metas) + return len(self.global_indexes) @property def is_empty(self) -> bool: """Check if this storage meta group is empty""" - return len(self.sample_metas) == 0 + return len(self.global_indexes) == 0 def __len__(self) -> int: """Number of samples in this storage meta group""" diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index aa74009..04139e2 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -35,6 +35,8 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_TENSOR = 3 # For tensor with buffer reference CUSTOM_TYPE_NESTED_TENSOR = 4 # For nested tensor (strided or jagged) +CUSTOM_TYPE_BATCHMETA = 5 # For BatchMeta serialization +CUSTOM_TYPE_NUMPY = 6 # For numpy ndarray with buffer reference bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame @@ -69,6 +71,9 @@ def aux_buffers(self) -> list[bytestr]: def encode(self, obj: Any) -> Sequence[bytestr]: """Encode a given object to a byte array.""" + # Pre-process to convert BatchMeta to Ext; msgspec auto-serializes dataclasses and won't call enc_hook for them. + obj = self._preprocess_for_batchmeta(obj) + bufs: list[bytestr] = [b""] token = _encoder_aux_buffers.set(bufs) try: @@ -81,6 +86,24 @@ def encode(self, obj: Any) -> Sequence[bytestr]: finally: _encoder_aux_buffers.reset(token) + def _preprocess_for_batchmeta(self, obj: Any) -> Any: + """Recursively preprocess object to convert BatchMeta to Ext. + + This is necessary because msgspec auto-serializes dataclasses and + won't call enc_hook for them. + """ + from transfer_queue.metadata import BatchMeta + + if isinstance(obj, BatchMeta): + return self._encode_batchmeta(obj) + elif isinstance(obj, dict): + return {k: self._preprocess_for_batchmeta(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._preprocess_for_batchmeta(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(self._preprocess_for_batchmeta(item) for item in obj) + return obj + def enc_hook(self, obj: Any) -> Any: """Custom encoding hook for types msgspec doesn't natively support. @@ -88,6 +111,9 @@ def enc_hook(self, obj: Any) -> Any: - torch.Tensor: Extract buffer, store metadata - TensorDict: Convert to dict structure for recursive processing - numpy.ndarray: Convert to tensor for unified handling + + Note: BatchMeta is handled by _preprocess_for_batchmeta() before encode() is called, + so it will never reach this hook. """ if isinstance(obj, torch.Tensor): return self._encode_tensor(obj) @@ -96,17 +122,15 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, TensorDictBase): return self._encode_tensordict(obj) - # Handle numpy arrays by converting to tensor - # Only numeric dtypes are supported by torch.from_numpy: - # f=float, i=signed int, u=unsigned int, b=bool, c=complex + # Numpy arrays: serialize natively unless the dtype contains Python objects. if isinstance(obj, np.ndarray): - if obj.dtype.kind in ("f", "i", "u", "b", "c"): + if obj.dtype.kind != "O" and not obj.dtype.hasobject: try: - return self._encode_tensor(torch.from_numpy(obj)) - except (TypeError, RuntimeError): - # Fallback to pickle for unsupported dtypes (e.g., float16 on some platforms) + return self._encode_numpy(obj) + except (TypeError, RuntimeError, ValueError): + # Fallback to pickle for platforms that don't support the view pass - # For object arrays, strings, or other unsupported types, use pickle + # Only true object arrays (or structured dtypes with object fields) reach here return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) if isinstance(obj, FunctionType): @@ -116,6 +140,14 @@ def enc_hook(self, obj: Any) -> Any: # Fallback to pickle for unknown types return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) + def _encode_batchmeta(self, obj: Any) -> msgpack.Ext: + """Encode BatchMeta for serialization. + + BatchMeta is small, so we serialize it via pickle (which handles torch.dtype natively). + """ + meta_dict = obj.to_dict() + return msgpack.Ext(CUSTOM_TYPE_BATCHMETA, pickle.dumps(meta_dict, protocol=pickle.HIGHEST_PROTOCOL)) + def _encode_tensordict(self, obj: Any) -> dict: """Convert TensorDict to a dict structure for recursive msgpack processing. @@ -134,16 +166,7 @@ def _encode_tensordict(self, obj: Any) -> dict: } def _encode_tensor(self, obj: torch.Tensor) -> msgpack.Ext: - """Encode tensor with zero-copy buffer extraction. - - Features: - - Auto GPU->CPU conversion - - Auto contiguous conversion - - Direct memoryview extraction via uint8 view (for BFloat16 support) - - Nested tensors: unbind and serialize each sub-tensor with zero-copy - - Returns Ext type so decoding goes through ext_hook (which has buffer access). - """ + """Encode tensor with zero-copy buffer extraction (handles GPU, non-contiguous, nested).""" assert len(self.aux_buffers) > 0 # Handle nested tensors (strided or jagged) via unbind @@ -218,6 +241,20 @@ def _encode_regular_tensor(self, obj: torch.Tensor) -> msgpack.Ext: meta = (dtype, tuple(obj.shape), idx) return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(meta, protocol=pickle.HIGHEST_PROTOCOL)) + def _encode_numpy(self, obj: np.ndarray) -> msgpack.Ext: + """Encode numpy array with zero-copy buffer extraction.""" + # Ensure C-contiguous layout; no-op when already contiguous + if not obj.flags["C_CONTIGUOUS"]: + obj = np.ascontiguousarray(obj) + + # Byte-level view as uint8 then ravel → 1-D C-contiguous raw-bytes array + buf = memoryview(obj.view(np.uint8).ravel()) + idx = len(self.aux_buffers) + self.aux_buffers.append(buf) + + meta = (str(obj.dtype), tuple(obj.shape), idx) + return msgpack.Ext(CUSTOM_TYPE_NUMPY, pickle.dumps(meta, protocol=pickle.HIGHEST_PROTOCOL)) + class MsgpackDecoder: """Decoder with custom torch tensor and numpy array serialization. @@ -307,6 +344,19 @@ def _decode_nested_tensor(self, nested_meta: dict) -> torch.Tensor: else: # strided return torch.nested.as_nested_tensor(sub_tensors, layout=torch.strided) + def _decode_numpy(self, meta: tuple) -> np.ndarray: + """Decode numpy array from (dtype_str, shape, buffer_idx) tuple.""" + dtype_str, shape, idx = meta + buffer = self.aux_buffers[idx] + np_dtype = np.dtype(dtype_str) + + if not buffer: # empty array + return np.empty(shape, dtype=np_dtype) + + # Reconstruct from raw bytes: uint8 view → reinterpret as original dtype + arr = np.frombuffer(buffer, dtype=np.uint8) + return arr.view(np_dtype).reshape(shape) + def ext_hook(self, code: int, data: memoryview) -> Any: """Custom decoding hook for types msgspec doesn't natively support. @@ -314,6 +364,7 @@ def ext_hook(self, code: int, data: memoryview) -> Any: - torch.Tensor: Extract buffer, store metadata - TensorDict: Convert to dict structure for recursive processing - numpy.ndarray: Convert to tensor for unified handling + - BatchMeta: Reconstruct from pickle """ if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) @@ -325,6 +376,14 @@ def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_NESTED_TENSOR: nested_meta = pickle.loads(data) return self._decode_nested_tensor(nested_meta) + if code == CUSTOM_TYPE_BATCHMETA: + from transfer_queue.metadata import BatchMeta + + meta_dict = pickle.loads(data) + return BatchMeta.from_dict(meta_dict) + if code == CUSTOM_TYPE_NUMPY: + meta = pickle.loads(data) + return self._decode_numpy(meta) raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index eaaf65e..625d85d 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -26,9 +26,6 @@ import ray import zmq -from transfer_queue.utils.common import ( - get_env_bool, -) from transfer_queue.utils.enum_utils import ExplicitEnum, TransferQueueRole from transfer_queue.utils.serial_utils import _decoder, _encoder @@ -42,9 +39,10 @@ logger.addHandler(handler) -bytestr: TypeAlias = bytes | bytearray | memoryview +# 0xC1 is permanently reserved (invalid) in msgpack spec — safe to use as pickle fallback sentinel. +_PICKLE_FALLBACK_SENTINEL = b"\xc1\xfe\xed" -TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) +bytestr: TypeAlias = bytes | bytearray | memoryview class ZMQRequestType(ExplicitEnum): @@ -168,43 +166,45 @@ def create( ) def serialize(self) -> list: - """ - Serialize message using unified MsgpackEncoder or pickle. - Returns: list[bytestr] - [msgpack_header, *tensor_buffers] or [bytes] - """ - if TQ_ZERO_COPY_SERIALIZATION: - msg_dict = { - "request_type": self.request_type.value, # Enum -> str for msgpack - "sender_id": self.sender_id, - "receiver_id": self.receiver_id, - "request_id": self.request_id, - "timestamp": self.timestamp, - "body": self.body, - } + """Serialize using zero-copy msgpack; falls back to pickle for unsupported types.""" + msg_dict = { + "request_type": self.request_type.value, # Enum -> str for msgpack + "sender_id": self.sender_id, + "receiver_id": self.receiver_id, + "request_id": self.request_id, + "timestamp": self.timestamp, + "body": self.body, + } + try: return list(_encoder.encode(msg_dict)) - else: - return [pickle.dumps(self)] + except (TypeError, ValueError) as e: + # Pickle fallback is a normal degradation path (e.g. body contains torch.dtype objects). + # Log at INFO so operators are aware but not alarmed; use WARNING only for unexpected errors. + logger.info( + "ZMQMessage.serialize: msgpack encoding unsupported type (%s), using pickle fallback.", + type(e).__name__, + ) + return [_PICKLE_FALLBACK_SENTINEL, pickle.dumps(self)] @classmethod def deserialize(cls, frames: list) -> "ZMQMessage": - """ - Deserialize message using unified MsgpackDecoder or pickle. - """ + """Deserialize: choose decoding path based on the first frame marker (zero-copy or pickle fallback).""" if not frames: raise ValueError("Empty frames received") - if TQ_ZERO_COPY_SERIALIZATION: - msg_dict = _decoder.decode(frames) - return cls( - request_type=ZMQRequestType(msg_dict["request_type"]), - sender_id=msg_dict["sender_id"], - receiver_id=msg_dict["receiver_id"], - body=msg_dict["body"], - request_id=msg_dict["request_id"], - timestamp=msg_dict["timestamp"], - ) - else: - return pickle.loads(frames[0]) + # pickle fallback path: serialize() sets frame[0] to _PICKLE_FALLBACK_SENTINEL on failure. + if len(frames) >= 2 and frames[0] == _PICKLE_FALLBACK_SENTINEL: + return pickle.loads(frames[1]) + + msg_dict = _decoder.decode(frames) + return cls( + request_type=ZMQRequestType(msg_dict["request_type"]), + sender_id=msg_dict["sender_id"], + receiver_id=msg_dict["receiver_id"], + body=msg_dict["body"], + request_id=msg_dict["request_id"], + timestamp=msg_dict["timestamp"], + ) def get_free_port() -> str: diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py index 2c81941..b8eb5d5 100644 --- a/tutorial/03_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -36,6 +36,7 @@ ) +import numpy as np # noqa: E402 import ray # noqa: E402 import torch # noqa: E402 from tensordict import TensorDict # noqa: E402 @@ -45,129 +46,161 @@ sys.path.append(str(parent_dir)) import transfer_queue as tq # noqa: E402 -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 +from transfer_queue.metadata import BatchMeta # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" -def demonstrate_field_meta(): +def demonstrate_batch_meta_schema(): """ - Demonstrate FieldMeta - specific data fields of each training sample. + Demonstrate BatchMeta field_schema - field-level metadata for the batch. + After the columnar refactoring, field metadata is stored once at the batch level + (not per-sample). This is the O(F) optimized representation. """ print("=" * 80) - print("FieldMeta - Specific data fields of each training sample") + print("BatchMeta field_schema - Field-Level Metadata (O(F) columnar storage)") print("=" * 80) - print("FieldMeta represents a single field in ONE sample:") - print("- name: Field identifier ('Prompt', 'Response', etc.)") + print("field_schema stores metadata for each field ONCE per batch (not per sample):") print("- dtype: Data type (torch.float32, torch.int64, etc.)") print("- shape: Shape of ONE sample's data (NO batch dimension)") - print("- production_status: Whether data is ready (has been produced and written to the TQ backend)") - - # Example 1: Create a field for input_ids - print("[Example 1] Manually creating FieldMeta for input_ids...") - input_ids_field = FieldMeta( - name="input_ids", - dtype=torch.int64, - shape=(512,), # Sequence length for ONE sample - production_status=ProductionStatus.READY_FOR_CONSUME, + print("- is_nested: Whether the field uses nested/ragged tensors") + print("- is_non_tensor: Whether the field is non-tensor data") + + # Example 1: Create a field schema entry for input_ids (analogous to old FieldMeta) + print("[Example 1] Creating field schema entry for input_ids...") + batch = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["train_0"] * 3, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, ) - print(f"✓ Created: {input_ids_field}") - print(f" Is ready: {input_ids_field.is_ready}") + print("✓ Created: BatchMeta with field 'input_ids'") + print(f" input_ids schema: {batch.field_schema['input_ids']}") + print(f" Is ready: {batch.is_ready}") print(" Note: Shape (512,) means ONE sample has 512 tokens (no batch dimension)") - # Example 2: Create a field for attention_mask - print("[Example 2] Creating FieldMeta for attention_mask...") - attention_mask_field = FieldMeta( - name="attention_mask", - dtype=torch.int64, - shape=(512,), # Sequence length for ONE sample - production_status=ProductionStatus.NOT_PRODUCED, + # Example 2: Create a field schema entry for attention_mask + print("[Example 2] Creating field schema entry for attention_mask...") + batch2 = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["train_0"] * 3, + field_schema={ + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, ) - print(f"✓ Created: {attention_mask_field}") - print(f" Is ready: {attention_mask_field.is_ready}") + print("✓ Created: BatchMeta with field 'attention_mask'") + print(f" attention_mask schema: {batch2.field_schema['attention_mask']}") + print(f" Is ready: {batch2.is_ready}") - # Example 3: Check field readiness + # Example 3: Check field readiness via is_ready and production_status print("[Example 3] Checking field readiness...") - print(f" input_ids ready: {input_ids_field.is_ready}") - print(f" attention_mask ready: {attention_mask_field.is_ready}") - - -def demonstrate_sample_meta(): + ready_batch = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["train_0"] * 3, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.array([1, 1, 1], dtype="int8"), # 1 = READY_FOR_CONSUME + ) + print(f" input_ids field exists: {'input_ids' in ready_batch.field_schema}") + print(f" attention_mask field exists: {'attention_mask' in ready_batch.field_schema}") + print(f" not-ready batch is_ready: {batch.is_ready}") + print(f" ready batch is_ready: {ready_batch.is_ready}") + + # Example 4: Access per-sample view and individual field schema by key + print("[Example 4] Accessing sample view and individual field by key...") + view = ready_batch.samples[0] + print(f" batch.samples[0].fields -> {view.fields}") + print(f" batch.samples[0].fields['input_ids'] -> {view.fields['input_ids']}") + print(f" batch.samples[0].fields['input_ids']['dtype'] -> {view.fields['input_ids']['dtype']}") + print(" Note: view.fields returns the shared field_schema dict (same for all samples)") + print(f" For partition_id: use batch.partition_ids[0] = '{ready_batch.partition_ids[0]}'") + print(f" For global_index: use batch.global_indexes[0] = {ready_batch.global_indexes[0]}") + + +def demonstrate_batch_meta_construction(): """ - Demonstrate SampleMeta - describes a single data sample. + Demonstrate how to construct BatchMeta directly and operate on it + (analogous to old SampleMeta operations: add_fields, select_fields, union). """ print("=" * 80) - print("SampleMeta - Describing a Single Data Sample") + print("BatchMeta Construction & Operations") print("=" * 80) - print("SampleMeta represents ONE data sample:") - print("- partition_id: Which partition the sample belongs to") - print("- global_index: Unique identifier across ALL partitions") - print("- fields: Dict of FieldMeta objects (describing each field of THIS sample)") - - # Example 1: Manually create a sample - print("[Example 1] Creating a SampleMeta...") - fields = { - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), - } - sample = SampleMeta(partition_id="train_0", global_index=0, fields=fields) - print(f"✓ Created: {sample}") - print(f" Partition: {sample.partition_id}") - print(f" Global index: {sample.global_index}") - print(f" Fields: {sample.field_names}") - print(f" Is ready: {sample.is_ready}") - - # Example 2: Manually add fields to a sample - print("[Example 2] Adding fields to a sample...") - new_fields = { - "responses": FieldMeta("responses", torch.int64, (128,)), - "log_probs": FieldMeta("log_probs", torch.float32, (128,)), - } - sample.add_fields(new_fields) - print(f"✓ Added fields: {list(new_fields.keys())}") - print(f" Now has fields: {sample.field_names}") - print(f" Is ready: {sample.is_ready}") - - # Example 3: Select specific fields + print("BatchMeta now uses a columnar layout:") + print("- global_indexes: list[int] - unique IDs across ALL partitions") + print("- partition_ids: list[str] - which partition each sample belongs to") + print("- field_schema: dict[str, dict] - field metadata (stored ONCE, not per-sample)") + + # Example 1: Manually create a BatchMeta (analogous to old SampleMeta construction) + print("[Example 1] Creating a BatchMeta with input_ids and attention_mask...") + batch = BatchMeta( + global_indexes=[0, 1, 2, 3, 4], + partition_ids=["train_0"] * 5, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + }, + ) + print(f"✓ Created: {len(batch)} samples") + print(f" Partition IDs: {batch.partition_ids}") + print(f" Global indexes: {batch.global_indexes}") + print(f" Fields: {batch.field_names}") + print(f" Is ready: {batch.is_ready}") + + # Example 2: add_fields - add new fields from real tensor data (analogous to sample.add_fields) + print("[Example 2] Adding new fields via add_fields(TensorDict)...") + new_data = TensorDict( + { + "responses": torch.randint(0, 1000, (5, 128)), + "log_probs": torch.randn(5, 128), + }, + batch_size=5, + ) + batch.add_fields(new_data) # infers dtype/shape from actual tensors, sets all ready + print("✓ Added fields: ['responses', 'log_probs']") + print(f" Now has fields: {batch.field_names}") + print(f" Is ready: {batch.is_ready} (add_fields sets all to READY_FOR_CONSUME by default)") + + # Example 3: select_fields - select specific fields (analogous to sample.select_fields) print("[Example 3] Selecting specific fields...") - selected_sample = sample.select_fields(["input_ids", "responses"]) - print(f"✓ Selected fields: {selected_sample.field_names}") - print(f" Original fields: {sample.field_names}") - - # Example 4: Union two samples - print("[Example 4] Unioning two samples...") - print(" IMPORTANT: Union requires samples to have IDENTICAL partition_id and global_index!") - sample1 = SampleMeta( - partition_id="train_0", - global_index=5, - fields={ - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), + selected = batch.select_fields(["input_ids", "responses"]) + print(f"✓ Selected fields: {selected.field_names}") + print(f" Original fields: {batch.field_names}") + + # Example 4: union - merge two batches with different global_indexes (new columnar semantics) + print("[Example 4] Unioning two batches with different global_indexes...") + print(" IMPORTANT: new union semantics = concat unique samples (by global_index)") + batch_a = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["train_0"] * 3, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, }, ) - sample2 = SampleMeta( - partition_id="train_0", - global_index=5, # Same global index! - fields={ - "responses": FieldMeta("responses", torch.int64, (128,)), - "log_probs": FieldMeta("log_probs", torch.float32, (128,)), + batch_b = BatchMeta( + global_indexes=[2, 3, 4], # global_index=2 overlaps with batch_a + partition_ids=["train_0"] * 3, + field_schema={ + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, }, ) - print(f" Sample1: partition={sample1.partition_id}, global_index={sample1.global_index}") - print(f" Sample2: partition={sample2.partition_id}, global_index={sample2.global_index}") + print(f" BatchA indexes: {batch_a.global_indexes}") + print(f" BatchB indexes: {batch_b.global_indexes}") + unioned = batch_a.union(batch_b) + print(f"✓ Union result indexes: {unioned.global_indexes} (global_index=2 deduplicated)") - try: - unioned = sample1.union(sample2) - print("✓ Union successful!") - print(f" Unioned fields: {unioned.field_names}") - print(f" Global index preserved: {unioned.global_index}") - except ValueError as e: - print(f"✗ Union failed: {e}") + # Example 5: Empty BatchMeta + print("[Example 5] Creating an empty BatchMeta (for initializing before data arrives)...") + empty = BatchMeta.empty() + print(f"✓ Empty BatchMeta: size={empty.size}, is_ready={empty.is_ready}") def demonstrate_batch_meta(): @@ -175,35 +208,45 @@ def demonstrate_batch_meta(): Demonstrate BatchMeta - describes a batch of samples with operations. """ print("=" * 80) - print("BatchMeta - Describing a Batch of Samples") + print("BatchMeta - Operations on Batch") print("=" * 80) - print("BatchMeta represents a collection of samples:") - print("- samples: List of SampleMeta objects") - print("- extra_info: Additional batch-level information") - print("- Provides operations: chunk, concat, union, select, reorder") + print("BatchMeta represents a collection of data samples with operations:") + print("- global_indexes: list of global sample indices") + print("- partition_ids: list of partition IDs per sample") + print("- field_schema: field-level metadata (stored once at batch level)") + print("- Operations: chunk, concat, union, select_samples, select_fields, reorder") + + # Helper to create a BatchMeta + def make_batch(global_indexes, fields=None): + if fields is None: + fields = ["input_ids", "attention_mask", "responses"] + schema = { + "input_ids": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "attention_mask": {"dtype": torch.int64, "shape": (512,), "is_nested": False, "is_non_tensor": False}, + "responses": {"dtype": torch.int64, "shape": (128,), "is_nested": False, "is_non_tensor": False}, + } + return BatchMeta( + global_indexes=global_indexes, + partition_ids=["train_0"] * len(global_indexes), + field_schema={k: v for k, v in schema.items() if k in fields}, + ) - # Example 1: Manually create a batch + # Example 1: Create a batch print("[Example 1] Creating a BatchMeta...") - fields = { - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), - "responses": FieldMeta("responses", torch.int64, (128,)), - } - samples = [SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(5)] - batch = BatchMeta(samples=samples) + batch = make_batch(list(range(5))) print(f"✓ Created batch with {len(batch)} samples") print(f" Global indexes: {batch.global_indexes}") print(f" Field names: {batch.field_names}") print(f" Size: {batch.size}") - # Example 2: Add extra_info + # Example 2: Add extra_info (batch-level) print("[Example 2] Adding batch-level information through extra_info...") - print("Note: The extra info will not be stored into TransferQueueController.") batch.extra_info["epoch"] = 1 batch.extra_info["batch_idx"] = 0 print(f"✓ Extra info: {batch.get_all_extra_info()}") + # Example 3: update_custom_meta (list aligned with global_indexes) print("[Example 3] Adding sample-level information through custom_meta...") batch.update_custom_meta( [ @@ -216,82 +259,55 @@ def demonstrate_batch_meta(): ) print(f"✓ Custom meta: {batch.get_all_custom_meta()}") - # Example 4: Chunk a batch + # Example 4: Chunk print("[Example 4] Chunking a batch into parts...") chunks = batch.chunk(3) print(f"✓ Split into {len(chunks)} chunks:") for i, chunk in enumerate(chunks): print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") - # Example 5: Select specific fields + # Example 5: select_fields print("[Example 5] Selecting specific fields...") selected_batch = batch.select_fields(["input_ids", "responses"]) print(f"✓ Selected fields: {selected_batch.field_names}") print(f" Original fields: {batch.field_names}") - # Example 6: Select specific samples + # Example 6: select_samples print("[Example 6] Selecting specific samples...") selected_samples = batch.select_samples([0, 2, 4]) - print(f"✓ Selected samples at indexes: {selected_samples.global_indexes}") + print(f"✓ Selected samples at batch indices [0,2,4]: global_indexes={selected_samples.global_indexes}") - # Example 7: Reorder samples + # Example 7: reorder print("[Example 7] Reordering samples...") print(f" Original order: {batch.global_indexes}") batch.reorder([4, 3, 2, 1, 0]) print(f" After reorder: {batch.global_indexes}") - # Example 8: Concat batches + # Example 8: concat print("[Example 8] Concatenating batches...") - batch1 = BatchMeta(samples=[SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(3)]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(3, 6)]) + batch1 = make_batch(list(range(3))) + batch2 = make_batch(list(range(3, 6))) concatenated = BatchMeta.concat([batch1, batch2]) print(f"✓ Concatenated {len(batch1)} + {len(batch2)} = {len(concatenated)} samples") print(f" Global indexes: {concatenated.global_indexes}") - print(" Note: concat combines multiple batches into one (same structure)") - - # Example 9: Union batches - print("[Example 9] Unioning batches (different fields, same samples)...") - batch_with_input = BatchMeta( - samples=[ - SampleMeta( - partition_id="train_0", - global_index=i, - fields={ - "input_ids": FieldMeta("input_ids", torch.int64, (512,)), - "attention_mask": FieldMeta("attention_mask", torch.int64, (512,)), - }, - ) - for i in range(3) - ] - ) - batch_with_output = BatchMeta( - samples=[ - SampleMeta( - partition_id="train_0", - global_index=i, - fields={ - "responses": FieldMeta("responses", torch.int64, (128,)), - "log_probs": FieldMeta("log_probs", torch.float32, (128,)), - }, - ) - for i in range(3) - ] - ) - print(f" Batch1 has fields: {batch_with_input.field_names}") - print(f" Batch2 has fields: {batch_with_output.field_names}") - print(f" Both have same samples (global_indexes: {batch_with_input.global_indexes})") - - unioned_batch = batch_with_input.union(batch_with_output) - print("✓ Union successful!") - print(f" Unioned fields: {unioned_batch.field_names}") - print(" Note: union merges fields from two batches with SAME samples (same global_indexes)") + print(" Note: concat combines multiple batches with SAME field structure into one larger batch") + + # Example 9: union (new semantics: concat unique samples, dedup by global_index) + print("[Example 9] Unioning batches with overlapping global_indexes...") + batch_a = make_batch(list(range(3)), fields=["input_ids", "attention_mask"]) # indexes [0,1,2] + batch_b = make_batch(list(range(2, 5)), fields=["input_ids", "attention_mask"]) # indexes [2,3,4] — 2 overlaps! + print(f" BatchA fields: {batch_a.field_names}, indexes: {batch_a.global_indexes}") + print(f" BatchB fields: {batch_b.field_names}, indexes: {batch_b.global_indexes}") + unioned = batch_a.union(batch_b) + print(f"✓ Unioned: {unioned.global_indexes} (global_index=2 deduplicated, result: [0,1,2,3,4])") + print(" Note: union keeps self's copy when global_index overlaps") print("=" * 80) print("concat vs union:") - print(" - concat: Combines multiple batches with SAME structure into one larger batch") - print(" Example: batch1[0,1,2] + batch2[3,4,5] = batch[0,1,2,3,4,5]") - print(" - union: Merges fields from two batches with IDENTICAL samples") - print(" Example: batch1[0,1] with fields A + batch2[0,1] with fields B = batch[0,1] with fields A+B") + print(" - concat: Combines multiple batches with SAME field structure into one larger batch") + print(" Example: batch[0,1,2] concat batch[3,4,5] = batch[0,1,2,3,4,5]") + print(" - union: Merges two batches, deduplicating by global_index (keeps self's copy)") + print(" Example: batch[0,1,2] union batch[2,3,4] = batch[0,1,2,3,4]") print("=" * 80) @@ -351,8 +367,8 @@ def demonstrate_real_workflow(): print(f" Number of samples: {len(batch_meta)}") print(f" Global indexes: {batch_meta.global_indexes}") print(f" Field names: {batch_meta.field_names}") - print(f" Partition ID: {batch_meta.samples[0].partition_id}") - print(f" Sample structure: {batch_meta.samples[0]}") + print(f" Partition IDs: {batch_meta.partition_ids}") + print(f" Sample view: fields={batch_meta.samples[0].fields}") print(f" Custom Meta: {batch_meta.get_all_custom_meta()}") print("[Step 4] Retrieve samples with specific fields..") @@ -397,24 +413,24 @@ def main(): This script introduces the metadata system in TransferQueue, which tracks the structure and state of data: - 1. FieldMeta - Describes a single field (name, dtype, shape, production status) - 2. SampleMeta - Describes a single data sample (partition_id, global_index, fields) - 3. BatchMeta - Describes a batch of samples (collection of SampleMeta with operations) + 1. BatchMeta - The central metadata object for a collection of data samples. + Uses a columnar layout: field metadata is stored ONCE at the batch level (O(F)), + not per-sample (was O(B×F) in the old design). Key Concepts: - - Metadata tracks data structure without storing actual data - - User can set their own custom metadata into BatchMeta, and use TQ controller to store them. - - BatchMeta provides operations: chunk, concat, union, select, reorder... - - Metadata is lightweight and can be passed around efficiently - - Union requires samples to have identical partition_id and global_index - """ + - BatchMeta stores global_indexes, partition_ids, and field_schema directly + - field_schema: dict[field_name, {dtype, shape, is_nested, is_non_tensor}] + - custom_meta: list[dict] aligned with global_indexes (one dict per sample) + - Metadata operations: chunk, concat, union, select_fields, select_samples, reorder + - batch.samples[i] returns a lazy view with .fields -> field_schema (read-only) + """ ) ) print("=" * 80) try: - demonstrate_field_meta() - demonstrate_sample_meta() + demonstrate_batch_meta_schema() + demonstrate_batch_meta_construction() demonstrate_batch_meta() demonstrate_real_workflow() @@ -422,12 +438,11 @@ def main(): print("Tutorial Complete!") print("=" * 80) print("Key Takeaways:") - print("1. FieldMeta describes individual data fields (NO batch dimension in shape)") - print("2. SampleMeta describes a single data sample") - print("3. BatchMeta manages collections of samples with operations") - print("4. Metadata operations: chunk, concat, union, select, reorder... You can retrieve subsets easily!") - print("5. extra_info is in batch-level, and custom_meta is in sample-level.") - print("6. You can put custom_meta into TQ controller, so you can retrieve them from anywhere!") + print("1. BatchMeta uses columnar storage: field metadata stored once, not per-sample") + print("2. Construct BatchMeta with: BatchMeta(global_indexes=[...], partition_ids=[...], field_schema={...})") + print("3. BatchMeta operations: chunk, concat, union, select_fields, select_samples, reorder") + print("4. extra_info is batch-level; custom_meta is sample-level (list[dict])") + print("5. Store custom_meta via TQ controller: tq_client.set_custom_meta(batch_meta)") # Cleanup ray.shutdown()