From 0e87f97549a09fba2a96ff42ba6088d575e21be7 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 6 Feb 2026 23:56:40 +0800 Subject: [PATCH 01/34] add kv interface Signed-off-by: 0oshowero0 --- README.md | 6 +- pyproject.toml | 3 +- requirements.txt | 3 +- tests/test_async_simple_storage_manager.py | 15 +- tests/test_client.py | 270 +++- tests/test_controller.py | 202 +++ tests/test_controller_data_partitions.py | 59 + tests/test_kv_interface.py | 539 ++++++++ tests/test_kv_storage_manager.py | 116 +- tests/test_metadata.py | 114 +- transfer_queue/__init__.py | 58 +- transfer_queue/client.py | 579 ++++++--- transfer_queue/config.yaml | 31 + transfer_queue/controller.py | 167 ++- .../dataloader/streaming_dataset.py | 46 +- transfer_queue/interface.py | 1136 +++++++++++++++++ transfer_queue/metadata.py | 51 +- transfer_queue/storage/managers/base.py | 41 +- transfer_queue/storage/managers/factory.py | 36 +- .../storage/managers/mooncake_manager.py | 7 +- .../managers/simple_backend_manager.py | 23 +- .../storage/managers/yuanrong_manager.py | 7 +- transfer_queue/utils/common.py | 68 +- transfer_queue/utils/zmq_utils.py | 39 + tutorial/01_core_components.py | 92 +- tutorial/02_kv_interface.py | 205 +++ ...ta_concepts.py => 03_metadata_concepts.py} | 68 +- ...ller.py => 04_understanding_controller.py} | 101 +- ...custom_sampler.py => 05_custom_sampler.py} | 77 +- ...taloader.py => 06_streaming_dataloader.py} | 80 +- 30 files changed, 3455 insertions(+), 784 deletions(-) create mode 100644 tests/test_kv_interface.py create mode 100644 transfer_queue/config.yaml create mode 100644 transfer_queue/interface.py create mode 100644 tutorial/02_kv_interface.py rename tutorial/{02_metadata_concepts.py => 03_metadata_concepts.py} (89%) rename tutorial/{03_understanding_controller.py => 04_understanding_controller.py} (76%) rename tutorial/{04_custom_sampler.py => 05_custom_sampler.py} (83%) rename tutorial/{05_streaming_dataloader.py => 06_streaming_dataloader.py} (84%) diff --git a/README.md b/README.md index b69176d..2c7d4f0 100644 --- a/README.md +++ b/README.md @@ -73,10 +73,10 @@ This class encapsulates the core interaction logic within the TransferQueue syst Currently, we support the following storage backends: -- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability. +- SimpleStorage: A basic CPU memory storage with minimal data format constraints and easy usability. - [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem) (beta, [#PR107](https://github.com/TransferQueue/TransferQueue/pull/107), [#PR96](https://github.com/TransferQueue/TransferQueue/pull/96)): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD. -- [Mooncake Store](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. -- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. +- [MooncakeStore](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. 
+- [RayRDT](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management. diff --git a/pyproject.toml b/pyproject.toml index f853970..35d6524 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,5 +126,6 @@ yuanrong = [ # This is the rough equivalent of package_data={'': ['version/*']} [tool.setuptools.package-data] transfer_queue = [ - "version/*", + "version/*", + "*.yaml" ] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1da090c..8cfccac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ pyzmq hydra-core numpy<2.0.0 msgspec -psutil \ No newline at end of file +psutil +omegaconf \ No newline at end of file diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py index 3949a0d..5187e16 100644 --- a/tests/test_async_simple_storage_manager.py +++ b/tests/test_async_simple_storage_manager.py @@ -62,8 +62,7 @@ async def mock_async_storage_manager(): ) config = { - "storage_unit_infos": storage_unit_infos, - "controller_info": controller_info, + "zmq_info": storage_unit_infos, } # Mock the handshake process entirely to avoid ZMQ complexity @@ -199,8 +198,7 @@ async def test_async_storage_manager_mapping_functions(): ) config = { - "storage_unit_infos": storage_unit_infos, - "controller_info": controller_info, + "zmq_info": storage_unit_infos, } # Mock ZMQ operations @@ -230,7 +228,7 @@ async def test_async_storage_manager_mapping_functions(): mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) # Create manager - manager = AsyncSimpleStorageManager(config) + manager = AsyncSimpleStorageManager(controller_info, config) # Test round-robin mapping for 3 storage units # global_index -> storage_unit mapping: 0->storage_0, 1->storage_1, 2->storage_2, @@ -266,7 +264,7 @@ async def test_async_storage_manager_error_handling(): } # Mock controller info - controller_infos = ZMQServerInfo( + controller_info = ZMQServerInfo( role=TransferQueueRole.CONTROLLER, id="controller_0", ip="127.0.0.1", @@ -274,8 +272,7 @@ async def test_async_storage_manager_error_handling(): ) config = { - "storage_unit_infos": storage_unit_infos, - "controller_info": controller_infos, + "zmq_info": storage_unit_infos, } # Mock ZMQ operations @@ -305,7 +302,7 @@ async def test_async_storage_manager_error_handling(): mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize()) # Create manager - manager = AsyncSimpleStorageManager(config) + manager = AsyncSimpleStorageManager(controller_info, config) # Mock operations that raise exceptions manager._put_to_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock PUT error")) diff --git a/tests/test_client.py b/tests/test_client.py index 38f140b..5d3360f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -34,7 +34,7 @@ FieldMeta, SampleMeta, ) -from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402 +from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole # noqa: E402 from transfer_queue.utils.zmq_utils import ( # noqa: E402 ZMQMessage, ZMQRequestType, @@ -144,6 +144,12 @@ def _handle_requests(self): 
"message": "Consumption reset successfully", } response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE + elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: + response_body = self._mock_kv_retrieve_keys(request_msg.body) + response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE + elif request_msg.request_type == ZMQRequestType.KV_LIST: + response_body = self._mock_kv_list(request_msg.body) + response_type = ZMQRequestType.KV_LIST_RESPONSE else: response_body = {"error": f"Unknown request type: {request_msg.request_type}"} response_type = ZMQRequestType.CLEAR_META_RESPONSE @@ -174,7 +180,7 @@ def _mock_batch_meta(self, request_body): name=field_name, dtype=None, shape=None, - production_status=0, + production_status=ProductionStatus.NOT_PRODUCED, ) fields.append(field_meta) sample = SampleMeta( @@ -187,6 +193,80 @@ def _mock_batch_meta(self, request_body): return {"metadata": metadata} + def _mock_kv_retrieve_keys(self, request_body): + """Mock KV retrieve keys response.""" + keys = request_body.get("keys", []) + create = request_body.get("create", False) + partition_id = request_body.get("partition_id", "") + + # Initialize key tracking if not exists + if not hasattr(self, "_kv_partition_keys"): + self._kv_partition_keys = {} + + # Generate global indexes for the keys + start_index = self._get_next_kv_index(partition_id) + global_indexes = list(range(start_index, start_index + len(keys))) + + # Create metadata for each key + samples = [] + for i, key in enumerate(keys): + field_meta = FieldMeta( + name="data", + dtype=torch.float32, + shape=torch.Size([1, 10]), + production_status=ProductionStatus.READY_FOR_CONSUME, + ) + sample = SampleMeta( + partition_id=partition_id, + global_index=global_indexes[i], + fields={"data": field_meta}, + ) + samples.append(sample) + + metadata = BatchMeta(samples=samples) + + # Store keys for this partition (only when create=True) + if create: + if partition_id not in self._kv_partition_keys: + self._kv_partition_keys[partition_id] = [] + self._kv_partition_keys[partition_id].extend(keys) + + # Update the next index for this partition + if global_indexes: + self._update_kv_index(partition_id, global_indexes[-1] + 1) + + return {"metadata": metadata} + + def _mock_kv_list(self, request_body): + """Mock KV list response.""" + partition_id = request_body.get("partition_id", "") + + # Initialize key tracking if not exists + if not hasattr(self, "_kv_partition_keys"): + self._kv_partition_keys = {} + + # Return cached keys for this partition + keys = self._kv_partition_keys.get(partition_id, []) + return {"keys": keys, "custom_meta": [{} for _ in range(len(keys))]} + + def _get_next_kv_index(self, partition_id): + """Get next available index for KV keys in partition.""" + if not hasattr(self, "_kv_index_map"): + self._kv_index_map = {} + if partition_id not in self._kv_index_map: + self._kv_index_map[partition_id] = 0 + # Also initialize key tracking + if not hasattr(self, "_kv_partition_keys"): + self._kv_partition_keys = {} + self._kv_partition_keys[partition_id] = [] + return self._kv_index_map[partition_id] + + def _update_kv_index(self, partition_id, next_index): + """Update next available index for KV keys.""" + if not hasattr(self, "_kv_index_map"): + self._kv_index_map = {} + self._kv_index_map[partition_id] = next_index + def stop(self): self.running = False time.sleep(0.2) # Give thread time to stop @@ -318,9 +398,9 @@ def client_setup(mock_controller, mock_storage): ): config = { "controller_info": mock_controller.zmq_server_info, - 
"storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info}, + "zmq_info": {mock_storage.storage_id: mock_storage.zmq_server_info}, } - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations async def mock_put_data(data, metadata): @@ -411,9 +491,9 @@ def test_single_controller_multiple_storages(): ): config = { "controller_info": controller.zmq_server_info, - "storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages}, + "zmq_info": {s.storage_id: s.zmq_server_info for s in storages}, } - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations async def mock_put_data(data, metadata): @@ -850,10 +930,10 @@ def test_set_custom_meta_sync(self, client_setup): metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0") # Set custom_meta on the metadata metadata.update_custom_meta( - { - 0: {"input_ids": {"token_count": 100}}, - 1: {"input_ids": {"token_count": 120}}, - } + [ + {"input_ids": {"token_count": 100}}, + {"input_ids": {"token_count": 120}}, + ] ) # Call set_custom_meta with metadata (BatchMeta) @@ -869,12 +949,174 @@ async def test_set_custom_meta_async(self, client_setup): metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0") # Set custom_meta on the metadata metadata.update_custom_meta( - { - 0: {"input_ids": {"token_count": 100}}, - 1: {"input_ids": {"token_count": 120}}, - } + [ + {"input_ids": {"token_count": 100}}, + {"input_ids": {"token_count": 120}}, + ] ) # Call async_set_custom_meta with metadata (BatchMeta) await client.async_set_custom_meta(metadata) print("✓ async_set_custom_meta async method works") + + +# ===================================================== +# KV Interface Tests +# ===================================================== + + +class TestClientKVInterface: + """Tests for client KV interface methods.""" + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_single(self, client_setup): + """Test async_kv_retrieve_keys with single key.""" + client, _, _ = client_setup + + # Test async_kv_retrieve_keys with single key + metadata = await client.async_kv_retrieve_keys( + keys="test_key_1", + partition_id="test_partition", + create=True, + ) + + # Verify metadata structure + assert metadata is not None + assert hasattr(metadata, "global_indexes") + assert hasattr(metadata, "size") + assert metadata.size == 1 + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_multiple(self, client_setup): + """Test async_kv_retrieve_keys with multiple keys.""" + client, _, _ = client_setup + + # Test async_kv_retrieve_keys with multiple keys + keys = ["key_a", "key_b", "key_c"] + metadata = await client.async_kv_retrieve_keys( + keys=keys, + partition_id="test_partition", + create=True, + ) + + # Verify metadata structure + assert metadata is not None + assert hasattr(metadata, "global_indexes") + assert hasattr(metadata, "size") + assert metadata.size == 3 + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_create_false(self, client_setup): + """Test async_kv_retrieve_keys with create=False (retrieve existing keys).""" + client, _, _ = client_setup + + # create some keys + await 
client.async_kv_retrieve_keys( + keys="existing_key", + partition_id="existing_partition", + create=True, + ) + + # Then retrieve them with create=False + metadata = await client.async_kv_retrieve_keys( + keys="existing_key", + partition_id="existing_partition", + create=False, + ) + + # Verify metadata structure + assert metadata is not None + assert metadata.size == 1 + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_invalid_keys_type(self, client_setup): + """Test async_kv_retrieve_keys raises error with invalid keys type.""" + client, _, _ = client_setup + + # Test with invalid keys type (not string or list) + with pytest.raises(TypeError): + await client.async_kv_retrieve_keys( + keys=123, # Invalid type + partition_id="test_partition", + create=True, + ) + + @pytest.mark.asyncio + async def test_async_kv_list_empty_partition(self, client_setup): + """Test async_kv_list with empty partition.""" + client, _, _ = client_setup + + # Test async_kv_list with empty partition + keys, custom_meta = await client.async_kv_list(partition_id="empty_partition") + + # Should return empty list for partition with no keys + assert keys == [] + assert custom_meta == [] + + @pytest.mark.asyncio + async def test_async_kv_list_with_keys(self, client_setup): + """Test async_kv_list returns keys after they are registered.""" + client, mock_controller, _ = client_setup + + # First register some keys + await client.async_kv_retrieve_keys( + keys=["key_1", "key_2"], + partition_id="kv_partition", + create=True, + ) + + # Then list them + keys, custom_meta = await client.async_kv_list(partition_id="kv_partition") + + # Verify keys are returned + assert keys is not None + assert len(keys) >= 2 + assert "key_1" in keys + assert "key_2" in keys + assert custom_meta == [{}, {}] + + @pytest.mark.asyncio + async def test_async_kv_list_multiple_partitions(self, client_setup): + """Test async_kv_list with multiple partitions.""" + client, _, _ = client_setup + + # Create keys in different partitions + await client.async_kv_retrieve_keys( + keys="partition_a_key", + partition_id="partition_a", + create=True, + ) + await client.async_kv_retrieve_keys( + keys="partition_b_key", + partition_id="partition_b", + create=True, + ) + + # List keys for each partition + keys_a, custom_meta_a = await client.async_kv_list(partition_id="partition_a") + keys_b, custom_meta_b = await client.async_kv_list(partition_id="partition_b") + + # Verify keys are isolated per partition + assert "partition_a_key" in keys_a + assert "partition_b_key" not in keys_a + assert "partition_b_key" in keys_b + assert "partition_a_key" not in keys_b + assert custom_meta_a == [{}] + assert custom_meta_b == [{}] + + def test_kv_retrieve_keys_type_validation(self, client_setup): + """Test synchronous kv_retrieve_keys type validation.""" + import asyncio + + client, _, _ = client_setup + + # Test with non-string element in list + async def test_invalid_list(): + with pytest.raises(TypeError): + await client.async_kv_retrieve_keys( + keys=["valid_key", 123], # Invalid: 123 is not a string + partition_id="test_partition", + create=True, + ) + + asyncio.run(test_invalid_list()) diff --git a/tests/test_controller.py b/tests/test_controller.py index 09f5778..f3a5c26 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -760,3 +760,205 @@ def test_controller_with_custom_meta(self, ray_setup): # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) + + +class TestTransferQueueControllerKvInterface: + """End-to-end 
tests for TransferQueueController KV interface functionality. + + Tests for kv_retrieve_keys method that supports key-value interface operations + across the controller and partition layers. + """ + + def test_controller_kv_retrieve_keys_create_mode(self, ray_setup): + """Test kv_retrieve_keys with create=True creates new keys in partition.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_test_partition" + + # Retrieve keys with create=True - should create partition and keys + keys = ["key_a", "key_b", "key_c"] + metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + + # Verify partition was created + partitions = ray.get(tq_controller.list_partitions.remote()) + assert partition_id in partitions + + # Verify metadata contains correct number of global_indexes + assert len(metadata.global_indexes) == len(keys) + + # Verify partition has keys_mapping + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert "key_a" in partition.keys_mapping + assert "key_b" in partition.keys_mapping + assert "key_c" in partition.keys_mapping + assert metadata.global_indexes[0] == partition.keys_mapping["key_a"] + assert metadata.global_indexes[1] == partition.keys_mapping["key_b"] + assert metadata.global_indexes[2] == partition.keys_mapping["key_c"] + assert partition.revert_keys_mapping[metadata.global_indexes[0]] == "key_a" + assert partition.revert_keys_mapping[metadata.global_indexes[1]] == "key_b" + assert partition.revert_keys_mapping[metadata.global_indexes[2]] == "key_c" + + print("✓ kv_retrieve_keys with create=True creates keys correctly") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_existing_keys(self, ray_setup): + """Test kv_retrieve_keys retrieves existing keys correctly.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_existing_test" + + # First, create some keys + keys = ["existing_key_1", "existing_key_2"] + ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + + # Retrieve the same keys again (should return existing) + retrieved_metadata = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + ) + + # Verify the same global_indexes are returned + assert len(retrieved_metadata.global_indexes) == len(keys) + + print("✓ kv_retrieve_keys retrieves existing keys correctly") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_non_existent_without_create(self, ray_setup): + """Test kv_retrieve_keys raises error for non-existent keys without create.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_nonexistent_test" + + # Create partition first + ray.get(tq_controller.kv_retrieve_keys.remote(keys=["initial_key"], partition_id=partition_id, create=True)) + + # Try to retrieve non-existent key without create + batch_meta = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=["nonexistent_key"], partition_id=partition_id, create=False) + ) + assert batch_meta.size == 0 + + print("✓ kv_retrieve_keys return an empty BatchMeta for non-existent keys without create") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_empty_partition_without_create(self, ray_setup): + """Test kv_retrieve_keys raises error for non-existent partition without 
create.""" + tq_controller = TransferQueueController.remote() + partition_id = "nonexistent_partition" + + batch_meta = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=["key_1"], partition_id=partition_id, create=False) + ) + assert batch_meta.size == 0 + + print("✓ kv_retrieve_keys return an empty BatchMeta for non-existent partition_id without create") + + def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup): + """Test kv_retrieve_keys works with production status update.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_production_test" + + # Create keys + keys = ["sample_1", "sample_2", "sample_3"] + metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + global_indexes = metadata.global_indexes + + # Update production status + dtypes = {idx: {"data": "torch.float32"} for idx in global_indexes} + shapes = {idx: {"data": (64,)} for idx in global_indexes} + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=global_indexes, + field_names=["data"], + dtypes=dtypes, + shapes=shapes, + ) + ) + assert success + + # Retrieve keys again (should include production info) + retrieved_metadata = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + ) + + # Verify production status is available + assert len(retrieved_metadata.samples) == len(keys) + for sample in retrieved_metadata.samples: + assert "data" in sample.fields + assert sample.fields["data"].dtype == "torch.float32" + assert sample.fields["data"].shape == (64,) + + print("✓ kv_retrieve_keys works with production status") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup): + """Test kv_retrieve_keys preserves custom_meta through retrieve.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_custom_meta_test" + + # Create keys + keys = ["key_1", "key_2"] + metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + + # Set custom_meta + custom_meta = { + partition_id: { + metadata.global_indexes[0]: {"score": 0.9, "tag": "A"}, + metadata.global_indexes[1]: {"score": 0.8, "tag": "B"}, + } + } + ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta)) + + # Retrieve keys and verify custom_meta + retrieved_metadata = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + ) + + # Verify custom_meta is preserved + all_custom_meta = retrieved_metadata.get_all_custom_meta() + assert len(all_custom_meta) == 2 + assert all_custom_meta[metadata.global_indexes[0]]["score"] == 0.9 + assert all_custom_meta[metadata.global_indexes[1]]["tag"] == "B" + + print("✓ kv_retrieve_keys preserves custom_meta") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_interface_multiple_partitions(self, ray_setup): + """Test KV interface works correctly across multiple partitions.""" + tq_controller = TransferQueueController.remote() + + # Create keys in partition 1 + partition_1 = "partition_kv_1" + keys_1 = ["p1_key_a", "p1_key_b"] + ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_1, partition_id=partition_1, create=True)) + + # Create keys in partition 2 + partition_2 = "partition_kv_2" + keys_2 = ["p2_key_x", "p2_key_y", "p2_key_z"] + 
ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_2, partition_id=partition_2, create=True)) + + # Verify partitions are isolated + partition_1_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_1)) + partition_2_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_2)) + + assert "p1_key_a" in partition_1_snapshot.keys_mapping + assert "p1_key_b" in partition_1_snapshot.keys_mapping + assert "p2_key_x" in partition_2_snapshot.keys_mapping + assert "p2_key_z" in partition_2_snapshot.keys_mapping + + # Verify cross-partition access is isolated + assert "p2_key_x" not in partition_1_snapshot.keys_mapping + assert "p1_key_a" not in partition_2_snapshot.keys_mapping + + print("✓ KV interface maintains partition isolation") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_1)) + ray.get(tq_controller.clear_partition.remote(partition_2)) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 31478ba..6eba7a0 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -941,3 +941,62 @@ def test_custom_meta_cleared_with_data(self): result = partition.get_custom_meta([0, 1]) assert 0 not in result assert 1 in result # Sample 1 should still have custom_meta + + +class TestDataPartitionStatusKvInterface: + """Unit tests for DataPartitionStatus KV interface functionality. + + Tests for the keys_mapping and kv_retrieve_keys methods that support + key-value interface operations within a partition. + """ + + def test_kv_retrieve_keys_with_existing_keys(self): + """Test kv_retrieve_keys returns correct global_indexes for existing keys.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + # Simulate keys being registered (as would happen during kv_put) + partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2} + + # Retrieve keys + global_indexes = partition.kv_retrieve_keys(["key_a", "key_b", "key_c"]) + + assert global_indexes == [0, 1, 2] + + def test_kv_retrieve_keys_with_nonexistent_keys(self): + """Test kv_retrieve_keys returns None for keys that don't exist.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + # Simulate some keys being registered + partition.keys_mapping = {"existing_key": 5} + + # Retrieve mixed existing and non-existing keys + global_indexes = partition.kv_retrieve_keys(["existing_key", "nonexistent_key"]) + + assert global_indexes == [5, None] + + def test_kv_retrieve_keys_empty_list(self): + """Test kv_retrieve_keys handles empty key list.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + global_indexes = partition.kv_retrieve_keys([]) + + assert global_indexes == [] + + def test_kv_retrieve_keys_partial_match(self): + """Test kv_retrieve_keys with partial key matches.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30} + + # Request only some of the keys + global_indexes = partition.kv_retrieve_keys(["key_1", "key_3"]) + + assert global_indexes == [10, 30] diff --git a/tests/test_kv_interface.py b/tests/test_kv_interface.py new file mode 100644 index 0000000..24d74af --- /dev/null +++ b/tests/test_kv_interface.py @@ 
-0,0 +1,539 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for kv interface in transfer_queue.interface.""" + +import asyncio +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import torch +from tensordict import TensorDict + +# Setup path +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 + + +def create_batch_meta(global_indexes, fields_data): + """Helper to create BatchMeta for testing.""" + samples = [] + for sample_id, global_idx in enumerate(global_indexes): + fields_dict = {} + for field_name, tensor in fields_data.items(): + field_meta = FieldMeta( + name=field_name, + dtype=tensor.dtype, + shape=tensor.shape, + production_status=ProductionStatus.READY_FOR_CONSUME, + ) + fields_dict[field_name] = field_meta + sample = SampleMeta( + partition_id="test_partition", + global_index=global_idx, + fields=fields_dict, + ) + samples.append(sample) + return BatchMeta(samples=samples) + + +class TestKVPut: + """Tests for kv_put function.""" + + def test_kv_put_with_fields(self): + """Test kv_put with fields parameter.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + tensor_data = TensorDict( + {"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])}, + batch_size=[1], + ) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_put + + kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"}) + + # Verify kv_retrieve_keys was called + mock_client.kv_retrieve_keys.assert_called_once_with(keys=["test_key"], partition_id="partition_1", create=True) + + # Verify update_custom_meta was called + mock_batch_meta.update_custom_meta.assert_called_once() + + # Verify put was called + mock_client.put.assert_called_once() + + def test_kv_put_with_dict_fields(self): + """Test kv_put converts dict to TensorDict correctly.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_put + + # Test with dict containing tensor + kv_put( + key="test_key", + partition_id="partition_1", + fields={"text": torch.tensor([1, 2, 3])}, + tag=None, + ) + + # Verify put was called + mock_client.put.assert_called_once() + call_args = mock_client.put.call_args + fields_arg = call_args[0][0] + assert "text" in fields_arg + # The dict should be converted to TensorDict + assert isinstance(fields_arg, TensorDict) + + def 
test_kv_put_with_tag_only(self): + """Test kv_put with only tag parameter (no fields).""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_put + + kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9}) + + # Verify put was NOT called (only set_custom_meta) + mock_client.put.assert_not_called() + mock_client.set_custom_meta.assert_called_once_with(mock_batch_meta) + + def test_kv_put_raises_error_without_fields_and_tag(self): + """Test kv_put raises ValueError when neither fields nor tag provided.""" + mock_client = MagicMock() + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_put + + with pytest.raises(ValueError, match="Please provide at least one parameter"): + kv_put(key="test_key", partition_id="partition_1", fields=None, tag=None) + + +class TestKVBatchPut: + """Tests for kv_batch_put function.""" + + def test_kv_batch_put_success(self): + """Test kv_batch_put with valid inputs.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + batch_data = TensorDict( + { + "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), + "label": torch.tensor([0, 1, 2]), + }, + batch_size=[3], + ) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_batch_put + + keys = ["key1", "key2", "key3"] + tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}] + + kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags) + + mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True) + mock_batch_meta.update_custom_meta.assert_called_once_with(tags) + mock_client.put.assert_called_once() + + def test_kv_batch_put_tags_length_mismatch(self): + """Test kv_batch_put raises error when tags length doesn't match keys.""" + mock_client = MagicMock() + + batch_data = TensorDict( + { + "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), + "label": torch.tensor([0, 1, 2]), + }, + batch_size=[3], + ) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_batch_put + + keys = ["key1", "key2", "key3"] + tags = [{"tag": "v1"}, {"tag": "v2"}] # Only 2 tags for 3 keys + + with pytest.raises(ValueError, match="does not match length of tags"): + kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags) + + +class TestKVGet: + """Tests for kv_get function.""" + + def test_kv_get_single_key(self): + """Test kv_get with single key.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])}) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_get + + kv_get(keys="test_key", partition_id="partition_1") + + # keys is passed directly (not wrapped in list) for single key + mock_client.kv_retrieve_keys.assert_called_once_with(keys="test_key", partition_id="partition_1", create=False) + 
mock_client.get_data.assert_called_once_with(mock_batch_meta) + + def test_kv_get_multiple_keys(self): + """Test kv_get with multiple keys.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])}) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_get + + keys = ["key1", "key2", "key3"] + kv_get(keys=keys, partition_id="partition_1") + + mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=False) + + def test_kv_get_with_fields(self): + """Test kv_get with specific fields.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_batch_meta.select_fields = MagicMock(return_value=mock_batch_meta) + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + mock_client.get_data.return_value = TensorDict({"text": torch.tensor([1, 2])}) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_get + + kv_get(keys="test_key", partition_id="partition_1", fields="text") + + mock_batch_meta.select_fields.assert_called_once_with(["text"]) + + +class TestKVList: + """Tests for kv_list function.""" + + def test_kv_list_with_keys(self): + """Test kv_list returns keys and custom_meta.""" + mock_client = MagicMock() + mock_client.kv_list.return_value = ["key1", "key2", "key3"] + mock_batch_meta = MagicMock() + mock_batch_meta.global_indexes = [0, 1, 2] + mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}}) + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_list + + keys, custom_meta = kv_list(partition_id="partition_1") + + assert keys == ["key1", "key2", "key3"] + assert len(custom_meta) == 3 + + def test_kv_list_empty_partition(self): + """Test kv_list returns None when partition is empty.""" + mock_client = MagicMock() + mock_client.kv_list.return_value = [], [] + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_list + + keys, custom_meta = kv_list(partition_id="empty_partition") + + assert keys == [] + assert custom_meta == [] + + +class TestKVClear: + """Tests for kv_clear function.""" + + def test_kv_clear_single_key(self): + """Test kv_clear with single key.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_batch_meta.size = 1 + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_clear + + kv_clear(keys="test_key", partition_id="partition_1") + + mock_client.kv_retrieve_keys.assert_called_once_with( + keys=["test_key"], partition_id="partition_1", create=False + ) + mock_client.clear_samples.assert_called_once_with(mock_batch_meta) + + def test_kv_clear_multiple_keys(self): + """Test kv_clear with multiple keys.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_batch_meta.size = 3 + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", 
return_value=mock_client): + from transfer_queue.interface import kv_clear + + kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1") + + mock_client.kv_retrieve_keys.assert_called_once_with( + keys=["key1", "key2", "key3"], partition_id="partition_1", create=False + ) + mock_client.clear_samples.assert_called_once() + + +class TestAsyncKVPut: + """Tests for async_kv_put function.""" + + def test_async_kv_put_with_fields(self): + """Test async_kv_put with fields parameter.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) + mock_client.async_put = AsyncMock() + mock_client.async_set_custom_meta = AsyncMock() + + tensor_data = TensorDict( + {"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])}, + batch_size=[1], + ) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_put + + asyncio.run( + async_kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"}) + ) + + mock_client.async_kv_retrieve_keys.assert_called_once_with( + keys=["test_key"], partition_id="partition_1", create=True + ) + mock_batch_meta.update_custom_meta.assert_called_once() + mock_client.async_put.assert_called_once() + + def test_async_kv_put_with_tag_only(self): + """Test async_kv_put with only tag (no fields).""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) + mock_client.async_put = AsyncMock() + mock_client.async_set_custom_meta = AsyncMock() + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_put + + asyncio.run(async_kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9})) + + mock_client.async_put.assert_not_called() + mock_client.async_set_custom_meta.assert_called_once_with(mock_batch_meta) + + +class TestAsyncKVBatchPut: + """Tests for async_kv_batch_put function.""" + + def test_async_kv_batch_put_success(self): + """Test async_kv_batch_put with valid inputs.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) + mock_client.async_put = AsyncMock() + mock_client.async_set_custom_meta = AsyncMock() + + batch_data = TensorDict( + { + "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), + "label": torch.tensor([0, 1, 2]), + }, + batch_size=[3], + ) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_batch_put + + keys = ["key1", "key2", "key3"] + tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}] + + asyncio.run(async_kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags)) + + mock_client.async_kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True) + mock_batch_meta.update_custom_meta.assert_called_once_with(tags) + mock_client.async_put.assert_called_once() + + +class TestAsyncKVGet: + """Tests for async_kv_get function.""" + + def test_async_kv_get_single_key(self): + """Test async_kv_get with single key.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) + mock_client.async_get_data = 
AsyncMock(return_value=TensorDict({"data": torch.tensor([1, 2])})) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_get + + asyncio.run(async_kv_get(keys="test_key", partition_id="partition_1")) + + # keys is passed directly (not wrapped in list) for single key + mock_client.async_kv_retrieve_keys.assert_called_once_with( + keys="test_key", partition_id="partition_1", create=False + ) + mock_client.async_get_data.assert_called_once_with(mock_batch_meta) + + +class TestAsyncKVList: + """Tests for async_kv_list function.""" + + def test_async_kv_list_with_keys(self): + """Test async_kv_list returns keys and custom_meta.""" + mock_client = MagicMock() + mock_client.async_kv_list = AsyncMock(return_value=["key1", "key2", "key3"]) + mock_batch_meta = MagicMock() + mock_batch_meta.global_indexes = [0, 1, 2] + mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}}) + mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_list + + keys, custom_meta = asyncio.run(async_kv_list(partition_id="partition_1")) + + assert keys == ["key1", "key2", "key3"] + assert len(custom_meta) == 3 + + def test_async_kv_list_empty_partition(self): + """Test async_kv_list returns None when partition is empty.""" + mock_client = MagicMock() + mock_client.async_kv_list = AsyncMock(return_value=[]) + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_list + + keys, custom_meta = asyncio.run(async_kv_list(partition_id="empty_partition")) + + assert keys is None + assert custom_meta is None + + +class TestAsyncKVClear: + """Tests for async_kv_clear function.""" + + def test_async_kv_clear_single_key(self): + """Test async_kv_clear with single key.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_batch_meta.size = 1 + mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) + mock_client.async_clear_samples = AsyncMock() + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_clear + + asyncio.run(async_kv_clear(keys="test_key", partition_id="partition_1")) + + mock_client.async_kv_retrieve_keys.assert_called_once_with( + keys=["test_key"], partition_id="partition_1", create=False + ) + mock_client.async_clear_samples.assert_called_once_with(mock_batch_meta) + + def test_async_kv_clear_multiple_keys(self): + """Test async_kv_clear with multiple keys.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_batch_meta.size = 3 + mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) + mock_client.async_clear_samples = AsyncMock() + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import async_kv_clear + + asyncio.run(async_kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1")) + + mock_client.async_kv_retrieve_keys.assert_called_once() + mock_client.async_clear_samples.assert_called_once() + + +class TestKVInterfaceDictConversion: + """Tests for dict to TensorDict conversion in kv_put.""" + + def test_kv_put_with_nontensor_value(self): + """Test kv_put 
converts non-tensor values using NonTensorStack.""" + mock_client = MagicMock() + mock_batch_meta = MagicMock() + mock_client.kv_retrieve_keys.return_value = mock_batch_meta + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_put + + # Test with non-tensor value (like a string or list) + kv_put( + key="test_key", + partition_id="partition_1", + fields={"meta": {"key": "value"}}, + tag=None, + ) + + # Verify put was called + mock_client.put.assert_called_once() + call_args = mock_client.put.call_args + fields_arg = call_args[0][0] + # The dict should be converted to TensorDict + assert isinstance(fields_arg, TensorDict) + assert "meta" in fields_arg + + def test_kv_put_rejects_nested_tensor(self): + """Test kv_put raises ValueError for nested tensors (requires batch_put).""" + mock_client = MagicMock() + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_put + + nested_tensor = torch.nested.nested_tensor([[1, 2], [3, 4]]) + + with pytest.raises(ValueError, match="Please use.*kv_batch_put"): + kv_put( + key="test_key", + partition_id="partition_1", + fields={"nested": nested_tensor}, + tag=None, + ) + + def test_kv_put_invalid_fields_type(self): + """Test kv_put raises ValueError for invalid fields type.""" + mock_client = MagicMock() + + with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): + from transfer_queue.interface import kv_put + + with pytest.raises(ValueError, match="field can only be dict or TensorDict"): + kv_put( + key="test_key", + partition_id="partition_1", + fields="invalid_string", + tag=None, + ) diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 41296dd..ec5b29c 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -27,6 +27,7 @@ from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402 +from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 def get_meta(data, global_indexes=None): @@ -41,7 +42,7 @@ def get_meta(data, global_indexes=None): name=field_name, dtype=tensor.dtype if isinstance(tensor, torch.Tensor) else None, shape=tensor.shape if isinstance(tensor, torch.Tensor) else None, - production_status=1, + production_status=ProductionStatus.READY_FOR_CONSUME, ) fields_dict[field_name] = field_meta sample = SampleMeta( @@ -117,7 +118,7 @@ def test_merge_tensors_to_tensordict(mock_create, test_data): mock_client = MagicMock() mock_create.return_value = mock_client - manager = KVStorageManager(test_data["cfg"]) + manager = KVStorageManager(controller_info=MagicMock(), config=test_data["cfg"]) assert manager.storage_client is mock_client assert manager._multi_threads_executor is None @@ -163,8 +164,8 @@ def test_merge_tensors_to_tensordict(mock_create, test_data): assert complex_tensordict[key] == complex_data[key] -def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data): - """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_meta.""" +def test_get_shape_type_custom_backend_meta_list_without_custom_backend_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_backend_meta.""" shapes, dtypes, custom_backend_meta_list = 
KVStorageManager._get_shape_type_custom_backend_meta_list( test_data["metadata"] ) @@ -185,7 +186,7 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data): torch.Size([2]), # text[2] ] expected_dtypes = [torch.int64] * (len(test_data["field_names"]) * len(test_data["global_indexes"])) - # No custom_meta provided, so all should be None + # No custom_backend_meta provided, so all should be None expected_custom_backend_meta = [None] * (len(test_data["field_names"]) * len(test_data["global_indexes"])) assert shapes == expected_shapes @@ -193,9 +194,9 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data): assert custom_backend_meta_list == expected_custom_backend_meta -def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): - """Test _get_shape_type_custom_meta_list returns correct custom_meta when provided.""" - # Add custom_meta to metadata +def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list returns correct custom_backend_meta when provided.""" + # Add custom_backend_meta to metadata custom_backend_meta = { 8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, 9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, @@ -206,8 +207,8 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) - # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index - expected_custom_meta = [ + # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index + expected_custom_backend_meta = [ {"key2": "value2"}, # label, global_index=8 {"key5": "value5"}, # label, global_index=9 {"key8": "value8"}, # label, global_index=10 @@ -218,15 +219,15 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): {"key4": "value4"}, # text, global_index=9 {"key7": "value7"}, # text, global_index=10 ] - assert custom_backend_meta_list == expected_custom_meta + assert custom_backend_meta_list == expected_custom_backend_meta -def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_data): - """Test _get_shape_type_custom_backend_meta_list handles partial custom_meta correctly.""" - # Add custom_meta only for some global_indexes and fields +def test_get_shape_type_custom_backend_meta_list_with_partial_custom_backend_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list handles partial custom_backend_meta correctly.""" + # Add custom_backend_meta only for some global_indexes and fields custom_backend_meta = { 8: {"text": {"key1": "value1"}}, # Only text field - # global_index 9 has no custom_meta + # global_index 9 has no custom_backend_meta 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only } metadata = test_data["metadata"] @@ -234,17 +235,17 @@ def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_d shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) - # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index + # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index expected_custom_backend_meta = [ - None, # label, global_index=8 (not in custom_meta) - None, # 
label, global_index=9 (not in custom_meta) + None, # label, global_index=8 (not in custom_backend_meta) + None, # label, global_index=9 (not in custom_backend_meta) {"key2": "value2"}, # label, global_index=10 - None, # mask, global_index=8 (not in custom_meta) - None, # mask, global_index=9 (not in custom_meta) + None, # mask, global_index=8 (not in custom_backend_meta) + None, # mask, global_index=9 (not in custom_backend_meta) {"key3": "value3"}, # mask, global_index=10 {"key1": "value1"}, # text, global_index=8 - None, # text, global_index=9 (not in custom_meta) - None, # text, global_index=10 (not in custom_meta for text) + None, # text, global_index=9 (not in custom_backend_meta) + None, # text, global_index=10 (not in custom_backend_meta for text) ] assert custom_backend_meta_list == expected_custom_backend_meta @@ -279,13 +280,13 @@ def test_data_for_put_data(): @patch.object(KVStorageManager, "_connect_to_controller", lambda self: None) @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) -def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_for_put_data): - """Test that put_data correctly processes custom_meta returned by storage client.""" +def test_put_data_with_custom_backend_meta_from_storage_client(mock_notify, test_data_for_put_data): + """Test that put_data correctly processes custom_backend_meta returned by storage client.""" # Create a mock storage client mock_storage_client = MagicMock() - # Simulate storage client returning custom_meta (one per key) + # Simulate storage client returning custom_backend_meta (one per key) # Keys order: label[0,1,2], text[0,1,2] (sorted by field name) - mock_custom_meta = [ + mock_custom_backend_meta = [ {"storage_key": "0@label"}, {"storage_key": "1@label"}, {"storage_key": "2@label"}, @@ -293,12 +294,12 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo {"storage_key": "1@text"}, {"storage_key": "2@text"}, ] - mock_storage_client.put.return_value = mock_custom_meta + mock_storage_client.put.return_value = mock_custom_backend_meta # Create manager with mocked dependencies - config = {"controller_info": MagicMock(), "client_name": "MockClient"} + config = {"client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"])) @@ -314,64 +315,65 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo assert keys == expected_keys assert len(values) == 6 - # Verify notify_data_update was called with correct custom_meta structure + # Verify notify_data_update was called with correct custom_backend_meta structure mock_notify.assert_called_once() notify_call_args = mock_notify.call_args - per_field_custom_meta = notify_call_args[0][5] # 6th positional argument + per_field_custom_backend_meta = notify_call_args[0][5] # 6th positional argument - # Verify custom_meta is structured correctly: {global_index: {field: meta}} - assert 0 in per_field_custom_meta - assert 1 in per_field_custom_meta - assert 2 in per_field_custom_meta + # Verify custom_backend_meta is structured correctly: {global_index: {field: meta}} + assert 0 in per_field_custom_backend_meta + assert 1 in per_field_custom_backend_meta + assert 2 in per_field_custom_backend_meta - assert 
per_field_custom_meta[0]["label"] == {"storage_key": "0@label"} - assert per_field_custom_meta[0]["text"] == {"storage_key": "0@text"} - assert per_field_custom_meta[1]["label"] == {"storage_key": "1@label"} - assert per_field_custom_meta[1]["text"] == {"storage_key": "1@text"} - assert per_field_custom_meta[2]["label"] == {"storage_key": "2@label"} - assert per_field_custom_meta[2]["text"] == {"storage_key": "2@text"} + assert per_field_custom_backend_meta[0]["label"] == {"storage_key": "0@label"} + assert per_field_custom_backend_meta[0]["text"] == {"storage_key": "0@text"} + assert per_field_custom_backend_meta[1]["label"] == {"storage_key": "1@label"} + assert per_field_custom_backend_meta[1]["text"] == {"storage_key": "1@text"} + assert per_field_custom_backend_meta[2]["label"] == {"storage_key": "2@label"} + assert per_field_custom_backend_meta[2]["text"] == {"storage_key": "2@text"} - # Verify metadata was updated with custom_meta - all_custom_meta = test_data_for_put_data["metadata"].get_all_custom_meta() - assert all_custom_meta[0]["label"] == {"storage_key": "0@label"} - assert all_custom_meta[2]["text"] == {"storage_key": "2@text"} + # Verify metadata was updated with custom_backend_meta + all_custom_backend_meta = test_data_for_put_data["metadata"]._custom_backend_meta + assert len(all_custom_backend_meta) == 3 + assert all_custom_backend_meta[0]["label"] == {"storage_key": "0@label"} + assert all_custom_backend_meta[2]["text"] == {"storage_key": "2@text"} @patch.object(KVStorageManager, "_connect_to_controller", lambda self: None) @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) -def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data): - """Test that put_data works correctly when storage client returns no custom_meta.""" - # Create a mock storage client that returns None for custom_meta +def test_put_data_without_custom_backend_meta(mock_notify, test_data_for_put_data): + """Test that put_data works correctly when storage client returns no custom_backend_meta.""" + # Create a mock storage client that returns None for custom_backend_meta mock_storage_client = MagicMock() mock_storage_client.put.return_value = None # Create manager with mocked dependencies config = {"controller_info": MagicMock(), "client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"])) - # Verify notify_data_update was called with empty dict for custom_meta + # Verify notify_data_update was called with empty dict for custom_backend_meta mock_notify.assert_called_once() notify_call_args = mock_notify.call_args - per_field_custom_meta = notify_call_args[0][5] # 6th positional argument - assert per_field_custom_meta == {} + per_field_custom_backend_meta = notify_call_args[0][5] # 6th positional argument + assert per_field_custom_backend_meta == {} @patch.object(KVStorageManager, "_connect_to_controller", lambda self: None) -def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_data): - """Test that put_data raises ValueError when custom_meta length doesn't match keys.""" - # Create a mock storage client that returns mismatched custom_meta length +def test_put_data_custom_backend_meta_length_mismatch_raises_error(test_data_for_put_data): + """Test that put_data raises 
ValueError when custom_backend_meta length doesn't match keys.""" + # Create a mock storage client that returns mismatched custom_backend_meta length mock_storage_client = MagicMock() - # Return only 3 custom_meta entries when 6 are expected + # Return only 3 custom_backend_meta entries when 6 are expected mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}] # Create manager with mocked dependencies config = {"controller_info": MagicMock(), "client_name": "MockClient"} with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client): - manager = KVStorageManager(config) + manager = KVStorageManager(controller_info=MagicMock(), config=config) # Run put_data and expect ValueError with pytest.raises(ValueError) as exc_info: diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 2bbf40c..6db10a6 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -788,64 +788,6 @@ def test_batch_meta_select_samples_with_extra_info(self): # ===================================================== # Custom Meta Tests # ===================================================== - - def test_batch_meta_set_custom_meta_basic(self): - """Test set_custom_meta sets metadata for a sample by global_index.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set custom_meta for sample 0 - batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"}) - - result = batch.get_all_custom_meta() - assert 0 in result - assert result[0]["sample_score"] == 0.9 - assert result[0]["quality"] == "high" - # Sample 1 should not have custom_meta - assert 1 not in result - - def test_batch_meta_set_custom_meta_overwrites(self): - """Test set_custom_meta overwrites existing metadata.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set initial custom_meta - batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"}) - - # Overwrite with new custom_meta - batch.set_custom_meta(0, {"sample_score": 0.1, "quality": "low"}) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.1 - assert result[0]["quality"] == "low" - - def test_batch_meta_set_custom_meta_invalid_global_index(self): - """Test set_custom_meta raises error for invalid global_index.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Try to set with non-existent global index - with pytest.raises(ValueError) as exc_info: - batch.set_custom_meta(999, {"sample_score": 0.9}) - assert "not found in global_indexes" in str(exc_info.value) - def test_batch_meta_update_custom_meta(self): """Test update_custom_meta adds metadata for different global indices.""" fields = { @@ -859,10 +801,7 @@ def test_batch_meta_update_custom_meta(self): batch = BatchMeta(samples=samples) # Initial custom_meta for sample 0 - batch.update_custom_meta({0: {"sample_score": 0.9}}) - - # Update with metadata for sample 1 - 
batch.update_custom_meta({1: {"sample_score": 0.1}}) + batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) result = batch.get_all_custom_meta() assert result[0]["sample_score"] == 0.9 @@ -879,10 +818,10 @@ def test_batch_meta_update_custom_meta_overwrites(self): batch = BatchMeta(samples=samples) # Initial custom_meta - batch.update_custom_meta({0: {"sample_score": 0.9, "quality": "high"}}) + batch.update_custom_meta([{"sample_score": 0.9, "quality": "high"}]) # Update with new value for same field - dict.update replaces - batch.update_custom_meta({0: {"sample_score": 0.1, "quality": "low"}}) + batch.update_custom_meta([{"sample_score": 0.1, "quality": "low"}]) result = batch.get_all_custom_meta() assert result[0]["sample_score"] == 0.1 @@ -899,7 +838,7 @@ def test_batch_meta_update_custom_meta_with_none(self): batch = BatchMeta(samples=samples) # Set initial value - batch.update_custom_meta({0: {"sample_score": 0.9}}) + batch.update_custom_meta([{"sample_score": 0.9}]) # Update with None should not change anything batch.update_custom_meta(None) @@ -907,40 +846,6 @@ def test_batch_meta_update_custom_meta_with_none(self): result = batch.get_all_custom_meta() assert result[0]["sample_score"] == 0.9 - def test_batch_meta_update_custom_meta_with_empty_dict(self): - """Test update_custom_meta with empty dict does nothing.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set initial value - batch.update_custom_meta({0: {"sample_score": 0.9}}) - - # Update with empty dict should not change anything - batch.update_custom_meta({}) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.9 - - def test_batch_meta_update_custom_meta_invalid_global_index(self): - """Test update_custom_meta raises error for invalid global_index.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Try to update with non-existent global index - with pytest.raises(ValueError) as exc_info: - batch.update_custom_meta({999: {"sample_score": 0.9}}) - assert "non-exist global_indexes" in str(exc_info.value) - def test_batch_meta_clear_custom_meta(self): """Test clear_custom_meta removes all custom metadata.""" fields = { @@ -953,14 +858,13 @@ def test_batch_meta_clear_custom_meta(self): batch = BatchMeta(samples=samples) # Set custom_meta - batch.set_custom_meta(0, {"sample_score": 0.9}) - batch.set_custom_meta(1, {"sample_score": 0.1}) + batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) # Clear all batch.clear_custom_meta() result = batch.get_all_custom_meta() - assert result == {} + assert result == [{}, {}] def test_batch_meta_get_all_custom_meta_returns_deep_copy(self): """Test get_all_custom_meta returns a deep copy.""" @@ -972,7 +876,7 @@ def test_batch_meta_get_all_custom_meta_returns_deep_copy(self): ] batch = BatchMeta(samples=samples) - custom_meta = {0: {"sample_score": 0.9, "nested": {"value": 1}}} + custom_meta = [{"sample_score": 0.9, "nested": {"value": 1}}] batch.update_custom_meta(custom_meta) # Get all custom_meta @@ -997,7 +901,7 @@ def test_batch_meta_get_all_custom_meta_empty(self): batch = BatchMeta(samples=samples) result = batch.get_all_custom_meta() - assert result == {} + assert result == 
[{}] def test_batch_meta_custom_meta_with_nested_data(self): """Test custom_meta supports nested dictionary data.""" @@ -1013,7 +917,7 @@ def test_batch_meta_custom_meta_with_nested_data(self): "model_info": {"name": "llama", "version": "7b", "config": {"hidden_size": 4096, "num_layers": 32}}, "tags": ["training", "inference"], } - batch.set_custom_meta(0, nested_meta) + batch.update_custom_meta([nested_meta]) result = batch.get_all_custom_meta() assert result[0]["model_info"]["name"] == "llama" diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index e925902..f0421c0 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -15,12 +15,35 @@ import os -from .client import ( - TransferQueueClient, - process_zmq_server_info, -) +from .client import TransferQueueClient from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset +from .interface import ( + async_clear_partition, + async_clear_samples, + async_get_data, + async_get_meta, + async_kv_batch_put, + async_kv_clear, + async_kv_get, + async_kv_list, + async_kv_put, + async_put, + async_set_custom_meta, + clear_partition, + clear_samples, + close, + get_data, + get_meta, + init, + kv_batch_put, + kv_clear, + kv_get, + kv_list, + kv_put, + put, + set_custom_meta, +) from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler @@ -28,9 +51,34 @@ from .sampler.sequential_sampler import SequentialSampler from .storage import SimpleStorageUnit from .utils.common import get_placement_group -from .utils.zmq_utils import ZMQServerInfo +from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info __all__ = [ + "init", + "get_meta", + "get_data", + "put", + "set_custom_meta", + "clear_samples", + "clear_partition", + "async_get_meta", + "async_get_data", + "async_put", + "async_set_custom_meta", + "async_clear_samples", + "async_clear_partition", + "close", + "kv_put", + "kv_batch_put", + "kv_get", + "kv_list", + "kv_clear", + "async_kv_put", + "async_kv_batch_put", + "async_kv_get", + "async_kv_list", + "async_kv_clear", +] + [ "TransferQueueClient", "StreamingDataset", "StreamingDataLoader", diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 24b4e55..f0c899b 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -18,23 +18,19 @@ import os import threading from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional from uuid import uuid4 -import ray import torch import zmq import zmq.asyncio from tensordict import TensorDict from torch import Tensor -from transfer_queue.controller import TransferQueueController from transfer_queue.metadata import ( BatchMeta, ) from transfer_queue.storage import ( - SimpleStorageUnit, - TransferQueueStorageManager, TransferQueueStorageManagerFactory, ) from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads @@ -95,11 +91,12 @@ def initialize_storage_manager( AsyncSimpleStorageManager, KVStorageManager (under development), etc. config: Configuration dictionary for the storage manager. 
For AsyncSimpleStorageManager, must contain the following required keys: - - controller_info: ZMQ server information about the controller - - storage_unit_infos: ZMQ server information about the storage units + - zmq_info: ZMQ server information about the storage units """ - self.storage_manager = TransferQueueStorageManagerFactory.create(manager_type, config) + self.storage_manager = TransferQueueStorageManagerFactory.create( + manager_type, controller_info=self._controller, config=config + ) # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. @staticmethod @@ -156,6 +153,7 @@ async def wrapper(self, *args, **kwargs): return decorator + # ==================== Basic API ==================== @dynamic_socket(socket_name="request_handle_socket") async def async_get_meta( self, @@ -218,7 +216,7 @@ async def async_get_meta( """ assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_META, + request_type=ZMQRequestType.GET_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -294,10 +292,13 @@ async def async_set_custom_meta( partition_custom_meta: dict[str, dict[int, dict]] = {pid: {} for pid in set(metadata.partition_ids)} for meta in metadata_chunks: - partition_custom_meta[meta.partition_ids[0]].update(meta.get_all_custom_meta()) + custom_meta = meta.get_all_custom_meta() + partition_custom_meta[meta.partition_ids[0]].update( + {meta.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta))} + ) request_msg = ZMQMessage.create( - request_type=ZMQRequestType.SET_CUSTOM_META, + request_type=ZMQRequestType.SET_CUSTOM_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -535,7 +536,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): """ request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_META, + request_type=ZMQRequestType.CLEAR_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={"global_indexes": metadata.global_indexes, "partition_ids": metadata.partition_ids}, @@ -563,7 +564,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta RuntimeError: If controller returns error response """ request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_PARTITION_META, + request_type=ZMQRequestType.GET_PARTITION_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={"partition_id": partition_id}, @@ -608,6 +609,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): if response_msg.request_type != ZMQRequestType.CLEAR_PARTITION_RESPONSE: raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") + # ==================== Status Query API ==================== @dynamic_socket(socket_name="request_handle_socket") async def async_get_consumption_status( self, @@ -641,7 +643,7 @@ async def async_get_consumption_status( assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_CONSUMPTION, + request_type=ZMQRequestType.GET_CONSUMPTION, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -678,7 +680,7 @@ async def async_get_production_status( partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]]: - """Get production status for current partition for specific fields. 
+ """Get production status for specific data fields and partition. Args: data_fields: Data fields to check production status for @@ -703,7 +705,7 @@ async def async_get_production_status( """ assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_PRODUCTION, + request_type=ZMQRequestType.GET_PRODUCTION, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -768,6 +770,41 @@ async def async_check_consumption_status( return False return torch.all(consumption_status == 1).item() + async def async_check_production_status( + self, + data_fields: list[str], + partition_id: str, + ) -> bool: + """Check if the all specific fields of samples for current partition are ready + (produced) for consumption. + + Args: + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for + + Returns: + bool: True if all samples have been produced and ready, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples are ready for consumption + >>> is_ready = asyncio.run(client.async_check_production_status( + ... data_fields=["input_ids", "attention_mask"], + ... partition_id="train_0" + ... )) + >>> print(f"All samples ready: {is_ready}") + """ + _, production_status = await self.async_get_production_status( + data_fields=data_fields, + partition_id=partition_id, + ) + + if production_status is None: + return False + return torch.all(production_status == 1).item() + @dynamic_socket(socket_name="request_handle_socket") async def async_reset_consumption( self, @@ -775,17 +812,22 @@ async def async_reset_consumption( task_name: Optional[str] = None, socket: Optional[zmq.asyncio.Socket] = None, ) -> bool: - """Reset consumption status for a partition, allowing data to be re-consumed. - This is useful for debugging scenarios where the same rollout data needs to be - trained multiple times without regenerating the data. + """Aynchronously reset consumption status for a partition. + + This allows the same data to be re-consumed, useful for debugging scenarios + where the same rollout data needs to be trained multiple times. + Args: partition_id: Partition id to reset consumption status for task_name: Name of the task to reset. If None, resets all tasks. socket: ZMQ async socket for message transmission (injected by decorator) + Returns: bool: True if reset was successful, False otherwise + Raises: RuntimeError: If communication fails or controller returns error response + Example: >>> # Reset consumption for train task to re-train on same data >>> success = asyncio.run(client.async_reset_consumption( @@ -799,7 +841,7 @@ async def async_reset_consumption( if task_name is not None: body["task_name"] = task_name request_msg = ZMQMessage.create( - request_type=ZMQRequestType.RESET_CONSUMPTION, + request_type=ZMQRequestType.RESET_CONSUMPTION, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body=body, @@ -825,41 +867,6 @@ async def async_reset_consumption( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e - async def async_check_production_status( - self, - data_fields: list[str], - partition_id: str, - ) -> bool: - """Check if the all specific fields of samples for current partition are ready - (produced) for consumption. 
- - Args: - data_fields: Data fields to check production status for - partition_id: Partition id to check production status for - - Returns: - bool: True if all samples have been produced and ready, False otherwise - - Raises: - RuntimeError: If communication fails or controller returns error response - - Example: - >>> # Check if all samples are ready for consumption - >>> is_ready = asyncio.run(client.async_check_production_status( - ... data_fields=["input_ids", "attention_mask"], - ... partition_id="train_0" - ... )) - >>> print(f"All samples ready: {is_ready}") - """ - _, production_status = await self.async_get_production_status( - data_fields=data_fields, - partition_id=partition_id, - ) - - if production_status is None: - return False - return torch.all(production_status == 1).item() - @dynamic_socket(socket_name="request_handle_socket") async def async_get_partition_list( self, @@ -872,9 +879,13 @@ async def async_get_partition_list( Returns: list[str]: List of partition ids managed by the controller + + Example: + >>> partition_ids = asyncio.run(client.get_partition_list()) + >>> print(f"Available partitions: {partition_ids}") """ request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_LIST_PARTITIONS, + request_type=ZMQRequestType.GET_LIST_PARTITIONS, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={}, @@ -901,6 +912,123 @@ async def async_get_partition_list( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e + # ==================== KV Interface API ==================== + @dynamic_socket(socket_name="request_handle_socket") + async def async_kv_retrieve_keys( + self, + keys: list[str] | str, + partition_id: str, + create: bool = False, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> BatchMeta: + """Asynchronously retrieve BatchMeta from the controller using user-specified keys. + + Args: + keys: List of keys to retrieve from the controller + partition_id: The ID of the logical partition to search for keys. + create: Whether to register new keys if not found. 
+ socket: ZMQ socket (injected by decorator) + + Returns: + metadata: BatchMeta of the corresponding keys + + Raises: + TypeError: If `keys` is not a list of string or a string + """ + + if isinstance(keys, str): + keys = [keys] + elif isinstance(keys, list): + if len(keys) < 1: + raise ValueError("Received an empty list as keys.") + # validate all the elements are str + if not all(isinstance(k, str) for k in keys): + raise TypeError("Not all element in `keys` are strings.") + else: + raise TypeError("Only string or list of strings are allowed as `keys`.") + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_RETRIEVE_KEYS, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={ + "keys": keys, + "partition_id": partition_id, + "create": create, + }, + ) + + try: + assert socket is not None + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} " + f"from controller {self._controller.id}" + ) + + if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE: + metadata = response_msg.body.get("metadata", BatchMeta.empty()) + return metadata + else: + raise RuntimeError( + f"[{self.client_id}]: Failed to retrieve keys from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e + + @dynamic_socket(socket_name="request_handle_socket") + async def async_kv_list( + self, + partition_id: str, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> tuple[list[str], list[dict]]: + """Asynchronously retrieve keys and custom_meta from the controller for partition. 
+ + Args: + partition_id: Partition to retrieve from the controller + socket: ZMQ socket (injected by decorator) + + Returns: + keys: list of keys in the partition + custom_meta: list of dict for custom_meta + """ + + if partition_id is None: + return [], [] + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_LIST, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={ + "partition_id": partition_id, + }, + ) + + try: + assert socket is not None + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}" + ) + + if response_msg.request_type == ZMQRequestType.KV_LIST_RESPONSE: + keys = response_msg.body.get("keys", []) + custom_meta = response_msg.body.get("custom_meta", []) + return keys, custom_meta + else: + raise RuntimeError( + f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e + def close(self) -> None: """Close the client and cleanup resources including storage manager.""" try: @@ -975,72 +1103,16 @@ def wrapper(*args, **kwargs): self._get_partition_list = _make_sync(self.async_get_partition_list) self._set_custom_meta = _make_sync(self.async_set_custom_meta) self._reset_consumption = _make_sync(self.async_reset_consumption) + self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys) + self._kv_list = _make_sync(self.async_kv_list) - def put( - self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None - ) -> BatchMeta: - """Synchronously write data to storage units based on metadata. - - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. - - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. - - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. - - Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. - partition_id: Target data partition id (required if metadata is not provided) - - Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) - - Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails - - Example: - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = client.get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", - ... 
) - >>> batch = client.get_data(batch_meta) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> client.put(data=output, metadata=batch_meta) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) - """ - return self._put(data=data, metadata=metadata, partition_id=partition_id) - + # ==================== Basic API ==================== def get_meta( self, data_fields: list[str], batch_size: int, partition_id: str, + mode: str = "fetch", task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, ) -> BatchMeta: @@ -1098,10 +1170,95 @@ def get_meta( data_fields=data_fields, batch_size=batch_size, partition_id=partition_id, + mode=mode, task_name=task_name, sampling_config=sampling_config, ) + def set_custom_meta(self, metadata: BatchMeta) -> None: + """Synchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Create batch with custom metadata + >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> client.set_custom_meta(batch_meta) + """ + + return self._set_custom_meta(metadata=metadata) + + def put( + self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None + ) -> BatchMeta: + """Synchronously write data to storage units based on metadata. + + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. + + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. + + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. + + Args: + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. 
+ partition_id: Target data partition id (required if metadata is not provided) + + Returns: + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) + + Raises: + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails + + Example: + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = client.get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", + ... ) + >>> batch = client.get_data(batch_meta) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> client.put(data=output, metadata=batch_meta) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling the async_put() + >>> # without metadata. + >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with + >>> # interleave the initial data if n_sample > 1 before calling the async_put(). + >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. + >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) + """ + return self._put(data=data, metadata=metadata, partition_id=partition_id) + def get_data(self, metadata: BatchMeta) -> TensorDict: """Synchronously fetch data from storage units and organize into TensorDict. @@ -1148,65 +1305,85 @@ def clear_samples(self, metadata: BatchMeta): """ return self._clear_samples(metadata=metadata) - def check_consumption_status(self, task_name: str, partition_id: str) -> bool: - """Synchronously check if all samples for a partition have been consumed by a specific task. + # ==================== Status Query API ==================== + def get_consumption_status( + self, + task_name: str, + partition_id: str, + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Synchronously get consumption status for a specific task and partition. Args: task_name: Name of the task to check consumption for partition_id: Partition id to check consumption status for Returns: - bool: True if all samples have been consumed by the task, False otherwise + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. Raises: RuntimeError: If communication fails or controller returns error response Example: - >>> # Check if all samples have been consumed - >>> is_consumed = client.check_consumption_status( + >>> global_index, consumption_status = client.get_consumption_status( ... task_name="generate_sequences", ... partition_id="train_0" ... 
) - >>> print(f"All samples consumed: {is_consumed}") + >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") """ - return self._check_consumption_status(task_name=task_name, partition_id=partition_id) + return self._get_consumption_status(task_name, partition_id) - def get_consumption_status( + def get_production_status( self, - task_name: str, + data_fields: list[str], partition_id: str, ) -> tuple[Optional[Tensor], Optional[Tensor]]: - """Synchronously get consumption status for a specific task and partition. + """Synchronously get production status for specific data fields and partition. Args: - task_name: Name of the task to check consumption for - partition_id: Partition id to check consumption status for + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for Returns: Tuple of: - Partition global index tensor - - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. + - Production status tensor for the specified fields. 1 for ready, 0 for not ready. + + Raises: + RuntimeError: If communication fails or controller returns error response Example: - >>> global_index, consumption_status = client.get_consumption_status( - ... task_name="generate_sequences", + >>> global_index, production_status = client.get_production_status( + ... data_fields=["input_ids", "attention_mask"], ... partition_id="train_0" ... ) - >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") + >>> print(f"Global index: {global_index}, Production status: {production_status}") """ - return self._get_consumption_status(task_name, partition_id) + return self._get_production_status(data_fields=data_fields, partition_id=partition_id) + + def check_consumption_status(self, task_name: str, partition_id: str) -> bool: + """Synchronously check if all samples for a partition have been consumed by a specific task. - def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: - """Synchronously reset consumption status for a partition. - This allows the same data to be re-consumed, useful for debugging scenarios - where the same rollout data needs to be trained multiple times. Args: - partition_id: Partition id to reset consumption status for - task_name: Name of the task to reset. If None, resets all tasks. + task_name: Name of the task to check consumption for + partition_id: Partition id to check consumption status for + Returns: - bool: True if reset was successful, False otherwise + bool: True if all samples have been consumed by the task, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples have been consumed + >>> is_consumed = client.check_consumption_status( + ... task_name="generate_sequences", + ... partition_id="train_0" + ... ) + >>> print(f"All samples consumed: {is_consumed}") """ - return self._reset_consumption(partition_id, task_name) + return self._check_consumption_status(task_name=task_name, partition_id=partition_id) def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. 
@@ -1215,6 +1392,9 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> data_fields: Data fields to check production status for partition_id: Partition id to check production status for + Returns: + bool: True if all samples have been produced and ready, False otherwise + Raises: RuntimeError: If communication fails or controller returns error response @@ -1228,30 +1408,31 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> """ return self._check_production_status(data_fields=data_fields, partition_id=partition_id) - def get_production_status( - self, - data_fields: list[str], - partition_id: str, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: - """Synchronously get production status for a specific data fields and partition. + def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: + """Synchronously reset consumption status for a partition. + + This allows the same data to be re-consumed, useful for debugging scenarios + where the same rollout data needs to be trained multiple times. Args: - data_fields: Data fields to check production status for - partition_id: Partition id to check production status for + partition_id: Partition id to reset consumption status for + task_name: Name of the task to reset. If None, resets all tasks. Returns: - Tuple of: - - Partition global index tensor - - Production status tensor for the specified fields. 1 for ready, 0 for not ready. + bool: True if reset was successful, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response Example: - >>> global_index, production_status = client.get_production_status( - ... data_fields=["input_ids", "attention_mask"], - ... partition_id="train_0" + >>> # Reset consumption for train task to re-train on same data + >>> success = client.reset_consumption( + ... partition_id="train_0", + ... task_name="train" ... ) - >>> print(f"Global index: {global_index}, Production status: {production_status}") + >>> print(f"Reset successful: {success}") """ - return self._get_production_status(data_fields=data_fields, partition_id=partition_id) + return self._reset_consumption(partition_id, task_name) def get_partition_list( self, @@ -1260,32 +1441,51 @@ def get_partition_list( Returns: list[str]: List of partition ids managed by the controller + + Example: + >>> partition_ids = client.get_partition_list() + >>> print(f"Available partitions: {partition_ids}") """ return self._get_partition_list() - def set_custom_meta(self, metadata: BatchMeta) -> None: - """Synchronously send custom metadata to the controller. - - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. + # ==================== KV Interface API ==================== + def kv_retrieve_keys( + self, + keys: list[str] | str, + partition_id: str, + create: bool = False, + ) -> BatchMeta: + """Synchronously retrieve BatchMeta from the controller using user-specified keys. Args: - metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. + keys: List of keys to retrieve from the controller + partition_id: The ID of the logical partition to search for keys. + create: Whether to register new keys if not found. 
+ + Returns: + metadata: BatchMeta of the corresponding keys Raises: - RuntimeError: If communication fails or controller returns error response + TypeError: If `keys` is not a list of string or a string + """ - Example: - >>> # Create batch with custom metadata - >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) - >>> client.set_custom_meta(batch_meta) + return self._kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) + + def kv_list( + self, + partition_id: str, + ) -> tuple[list[Optional[str]], list[Optional[dict]]]: + """Synchronously retrieve keys and custom_meta from the controller for partition. + + Args: + partition_id: Partition to retrieve from the controller + + Returns: + keys: list of keys in the partition + custom_meta: list of dict for custom_meta """ - return self._set_custom_meta(metadata=metadata) + return self._kv_list(partition_id=partition_id) def close(self) -> None: """Close the client and cleanup resources including event loop and thread.""" @@ -1304,36 +1504,3 @@ def close(self) -> None: logger.warning(f"[{self.client_id}]: Error closing event loop: {e}") super().close() - - -def process_zmq_server_info( - handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]] - | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"], -): # noqa: UP007 - """Extract ZMQ server information from handler objects. - - Args: - handlers: Dictionary of handler objects (controllers, storage managers, or storage units), - or a single handler object - - Returns: - If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information - If handlers is a single object: ZMQ server information for that object - - Examples: - >>> # Single handler - >>> controller = TransferQueueController.remote(...) - >>> info = process_zmq_server_info(controller) - >>> - >>> # Multiple handlers - >>> handlers = {"storage_0": storage_0, "storage_1": storage_1} - >>> info_dict = process_zmq_server_info(handlers)""" - # Handle single handler object case - if not isinstance(handlers, dict): - return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] - else: - # Handle dictionary case - server_info = {} - for name, handler in handlers.items(): - server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] - return server_info diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml new file mode 100644 index 0000000..9503afa --- /dev/null +++ b/transfer_queue/config.yaml @@ -0,0 +1,31 @@ +# This is the default configuration of TransferQueue. Users may modify the default value +# and use transfer_queue.init(conf) to overwrite the config entries. + +controller: + # User-defined sampler. User can pass sampler instance to overwrite this string config. + sampler: SequentialSampler + # Whether return an empty BatchMeta to prevent request blocking when no enough data is available + polling_mode: False + # ZMQ Server IP & Ports (automatically generated during init) + zmq_info: null + + +backend: + # Pluggable storage/transport backend of TransferQueue. Choose from: + # SimpleStorage, Yuanrong, MooncakeStore, ... 
+ storage_backend: SimpleStorage + + # For SimpleStorage: + SimpleStorage: + # Total number of samples + total_storage_size: 100000 + # Number of distributed storage units for SimpleStorage backend + num_data_storage_units: 2 + # ZMQ Server IP & Ports (automatically generated during init) + zmq_info: null + + # For Yuanrong: + # TODO + + # For MooncakeStore: + # TODO \ No newline at end of file diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index c91eaf7..6175863 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -28,6 +28,7 @@ import ray import torch import zmq +from omegaconf import DictConfig from ray.util import get_node_ip_address from torch import Tensor @@ -213,7 +214,7 @@ class DataPartitionStatus: # Each tensor tracks which samples have been consumed by that task consumption_status: dict[str, Tensor] = field(default_factory=dict) - # Sample metadata + # Global indexes global_indexes: set[int] = field( default_factory=set ) # set of global indexes that have been added to this partition @@ -232,6 +233,10 @@ class DataPartitionStatus: # User-defined metadata that may not apply to field level custom_meta: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {} + # User-defined Keys + keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx + revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key + # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. data_status_lock: Lock = field(default_factory=Lock) @@ -824,6 +829,8 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr self.field_shapes.pop(idx, None) self.field_custom_backend_meta.pop(idx, None) self.custom_meta.pop(idx, None) + self.keys_mapping.pop(self.revert_keys_mapping[idx], None) + self.revert_keys_mapping.pop(idx, None) except Exception as e: logger.error( @@ -831,6 +838,11 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr f"Attempted to clear global_indexes: {indexes_to_release}" ) + def kv_retrieve_keys(self, keys: list[str]) -> list[int | None]: + """Translate the user-specified keys to global_indexes""" + global_indexes = [self.keys_mapping.get(k, None) for k in keys] + return global_indexes + @ray.remote(num_cpus=1) class TransferQueueController: @@ -879,6 +891,7 @@ def __init__( self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}" self.polling_mode = polling_mode + self.tq_config = None # global config for TransferQueue system # Initialize ZMQ sockets for communication self._init_zmq_socket() @@ -979,6 +992,8 @@ def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int] Returns: List of indexes allocated to the partition """ + # Note: This includes the pre-allocated global_indexes for the partition. 
+ # i.e., partition.global_indexes + partition.pre_allocated_global_indexes return self.index_manager.get_indexes_for_partition(partition) # ==================== Data Production API ==================== @@ -1158,9 +1173,10 @@ def get_metadata( ) batch_global_indexes.extend(new_global_indexes) + # register global_indexes in partition + partition.global_indexes.update(batch_global_indexes) + else: - # TODO: separate this "clear" related logic into a separated mode - # clear metadata call passes empty data_fields batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) @@ -1312,7 +1328,7 @@ def generate_batch_meta( and partition.production_status is not None and partition.production_status[global_index, field_index] == 1 ): - production_status = ProductionStatus.NOT_PRODUCED + production_status = ProductionStatus.READY_FOR_CONSUME dtype = partition.get_field_dtype(global_index, field_name) shape = partition.get_field_shape(global_index, field_name) else: @@ -1338,7 +1354,7 @@ def generate_batch_meta( custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) batch_meta = BatchMeta(samples=samples) - batch_meta.update_custom_meta(custom_meta) + batch_meta.update_custom_meta([custom_meta.get(idx, {}) for idx in batch_meta.global_indexes]) batch_meta._custom_backend_meta.update(custom_backend_meta) return batch_meta @@ -1355,7 +1371,8 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): partition = self._get_partition(partition_id) if not partition: - raise ValueError(f"Partition {partition_id} not found") + logger.warning(f"Try to clear an non-existent partition {partition_id}!") + return global_indexes_range = list(self.index_manager.get_indexes_for_partition(partition_id)) partition.clear_data(global_indexes_range, clear_consumption) @@ -1372,13 +1389,13 @@ def reset_consumption(self, partition_id: str, task_name: Optional[str] = None): Args: partition_id: ID of the partition to reset consumption for task_name: Name of the task to reset. If None, resets all tasks. - Raises: - ValueError: If partition not found + """ logger.debug(f"[{self.controller_id}]: Resetting consumption for partition {partition_id}, task={task_name}") partition = self._get_partition(partition_id) if not partition: - raise ValueError(f"Partition {partition_id} not found") + logger.warning(f"Try to reset consumption of an non-existent partition {partition_id}!") + return partition.reset_consumption(task_name) def clear_meta( @@ -1430,6 +1447,71 @@ def clear_meta( # Release the specific indexes from index manager self.index_manager.release_indexes(partition_id, global_indexes_to_clear) + def kv_retrieve_keys( + self, + keys: list[str], + partition_id: str, + create: bool = False, + ) -> BatchMeta: + """ + Retrieve BatchMeta from the controller using a list of keys. + + Args: + keys: List of keys to retrieve from the controller + partition_id: Partition id to retrieve from the controller + create: Whether to register new keys if not found. 
+ + Returns: + metadata: BatchMeta of the requested keys + """ + + logger.debug(f"[{self.controller_id}]: Retrieve keys {keys} in partition {partition_id}") + + partition = self._get_partition(partition_id) + + if partition is None: + if not create: + logger.warning(f"Partition {partition_id} were not found in controller!") + return BatchMeta.empty() + else: + self.create_partition(partition_id) + partition = self._get_partition(partition_id) + + assert partition is not None + global_indexes = partition.kv_retrieve_keys(keys) + + none_indexes = [idx for idx, value in enumerate(global_indexes) if value is None] + if len(none_indexes) > 0: + if not create: + logger.warning(f"Keys {[keys[i] for i in none_indexes]} were not found in partition {partition_id}!") + return BatchMeta.empty() + else: + # create non-exist keys + batch_global_indexes = partition.activate_pre_allocated_indexes(len(none_indexes)) + + if len(batch_global_indexes) < len(none_indexes): + new_global_indexes = self.index_manager.allocate_indexes( + partition_id, count=(len(none_indexes) - len(batch_global_indexes)) + ) + batch_global_indexes.extend(new_global_indexes) + + # register global_indexes in partition + partition.global_indexes.update(batch_global_indexes) + + # register key-global_indexes mapping in partition + for i in range(len(none_indexes)): + global_indexes[none_indexes[i]] = batch_global_indexes[i] + partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i] + partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]] + + verified_global_indexes = [idx for idx in global_indexes if idx is not None] + assert len(verified_global_indexes) == len(keys) + + data_fields = list(partition.field_name_mapping.keys()) + metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch") + + return metadata + def _init_zmq_socket(self): """Initialize ZMQ sockets for communication.""" self.zmq_context = zmq.Context() @@ -1725,6 +1807,41 @@ def _process_request(self): body={"partition_ids": partition_ids}, ) + elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: + with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"): + keys = params["keys"] + partition_id = params["partition_id"] + create = params["create"] + + metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"metadata": metadata}, + ) + + elif request_msg.request_type == ZMQRequestType.KV_LIST: + with perf_monitor.measure(op_type="KV_LIST"): + partition_id = params["partition_id"] + partition = self._get_partition(partition_id) + if not partition: + keys = [] + custom_meta = [] + message = f"Partition {partition_id} not found for kv_list." 
+ logger.debug(f"[{self.controller_id}]: {message}") + else: + keys = list(partition.keys_mapping.keys()) + custom_meta = [partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys] + message = "Success" + + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_LIST_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"keys": keys, "custom_meta": custom_meta, message: message}, + ) + self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def _update_data_status(self): @@ -1772,3 +1889,35 @@ def _update_data_status(self): def get_zmq_server_info(self) -> ZMQServerInfo: """Get ZMQ server connection information.""" return self.zmq_server_info + + def store_config(self, conf: DictConfig) -> None: + """Store the global config of TransferQueue.""" + self.tq_config = conf + + def get_config(self) -> DictConfig: + """Retrieve the global config of TransferQueue.""" + return self.tq_config + + def register_sampler( + self, + sampler: BaseSampler | type[BaseSampler] = SequentialSampler, + ) -> None: + """ + Register a sampler instance or subclass after the controller is initialized. + + Args: + sampler: Sampler instance or sampler class to use for data sampling. + - If a BaseSampler instance is provided, it will be used directly + - If a BaseSampler subclass is provided, it will be instantiated + - Defaults to SequentialSampler for simple sequential sampling + - Example: sampler=GRPOGroupNSampler() (instance) + - Example: sampler=SequentialSampler (class) + """ + if isinstance(sampler, BaseSampler): + self.sampler = sampler + elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): + self.sampler = sampler() + else: + raise TypeError( + f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler" + ) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index eb2afe4..1de90c5 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -17,8 +17,10 @@ import os import time import uuid -from typing import Any, Callable, Iterator +import warnings +from typing import Callable, Iterator +from omegaconf import DictConfig from tensordict import TensorDict from torch.utils.data import IterableDataset @@ -68,7 +70,7 @@ class StreamingDataset(IterableDataset): def __init__( self, - config: dict[str, Any], + config: DictConfig, batch_size: int, micro_batch_size: int, data_fields: list[str], @@ -82,8 +84,8 @@ def __init__( Args: config: Configuration dictionary containing: - - controller_info: ZMQServerInfo for the TransferQueueController - - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager") + - controller.controller_info: ZMQServerInfo for the TransferQueueController + - backend.storage_backend: Storage backend type (e.g., "SimpleStorage") - Other backend-specific configuration batch_size: Batch size for data loading per iter. micro_batch_size: Number of samples per micro-batch. This is the batch size @@ -156,16 +158,44 @@ def _create_client(self): ValueError: If controller_info or storage_backend is missing or invalid. 
""" client_id = uuid.uuid4().hex[:8] - controller_info = self.config.get("controller_info", None) + + # TODO: DEPRECATE in future + controller_config = self.config.get("controller", None) + if controller_config: + controller_info = controller_config.get("zmq_info", None) + else: + controller_info = self.config.get("controller_info", None) + if controller_info: + warnings.warn( + "Config entry `controller_info` will be deprecated in 0.1.7, please " + "use `controller.zmq_info` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + if not controller_info or not isinstance(controller_info, ZMQServerInfo): - raise ValueError("Invalid or missing controller_info in config") + raise ValueError("Invalid or missing controller.zmq_info in config") + + backend_config = self.config.get("backend", None) + if not backend_config: + storage_backend = self.config.get("storage_backend", None) + backend_config = self.config + if storage_backend: + warnings.warn( + "Config entry `storage_backend` will be deprecated in 0.1.7, please " + "use `backend.storage_backend` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + else: + storage_backend = backend_config.get("storage_backend", None) + backend_config = self.config.backend[storage_backend] - storage_backend = self.config.get("storage_backend", None) if not storage_backend: raise ValueError("Missing storage_backend in config") self._tq_client = TransferQueueClient(client_id, controller_info) - self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=self.config) + self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config) def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: """Iterate over the dataset, yielding batches of data. diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py new file mode 100644 index 0000000..2585c35 --- /dev/null +++ b/transfer_queue/interface.py @@ -0,0 +1,1136 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib.resources as pkg_resources +import logging +import math +import os +import time +from typing import Any, Optional + +import ray +import torch +from omegaconf import DictConfig, OmegaConf +from tensordict import TensorDict +from tensordict.tensorclass import NonTensorStack + +from transfer_queue.client import TransferQueueClient +from transfer_queue.controller import TransferQueueController +from transfer_queue.metadata import BatchMeta +from transfer_queue.sampler import * # noqa: F401 +from transfer_queue.sampler import BaseSampler +from transfer_queue.storage.simple_backend import SimpleStorageUnit +from transfer_queue.utils.common import get_placement_group +from transfer_queue.utils.zmq_utils import process_zmq_server_info + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + +_TRANSFER_QUEUE_CLIENT: Any = None +_TRANSFER_QUEUE_STORAGE: Any = None + + +def _maybe_create_transferqueue_client( + conf: Optional[DictConfig] = None, +) -> TransferQueueClient: + global _TRANSFER_QUEUE_CLIENT + if _TRANSFER_QUEUE_CLIENT is None: + if conf is None: + raise ValueError("Missing config for initializing TransferQueueClient!") + pid = os.getpid() + _TRANSFER_QUEUE_CLIENT = TransferQueueClient( + client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info + ) + + backend_name = conf.backend.storage_backend + + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name]) + + return _TRANSFER_QUEUE_CLIENT + + +def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: + global _TRANSFER_QUEUE_STORAGE + + if _TRANSFER_QUEUE_STORAGE is None: + _TRANSFER_QUEUE_STORAGE = {} + if conf.backend.storage_backend == "SimpleStorage": + # initialize SimpleStorageUnit + num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units + total_storage_size = conf.backend.SimpleStorage.total_storage_size + storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) + + for storage_unit_rank in range(num_data_storage_units): + storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] + placement_group=storage_placement_group, + placement_group_bundle_index=storage_unit_rank, + name=f"TransferQueueStorageUnit#{storage_unit_rank}", + lifetime="detached", + ).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units)) + _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") + + storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) + backend_name = conf.backend.storage_backend + conf.backend[backend_name].zmq_info = storage_zmq_info + + return conf + + +def _init_from_existing() -> None: + """Initialize the TransferQueueClient from existing controller.""" + + controller = ray.get_actor("TransferQueueController") + logger.info("Found existing TransferQueueController instance. Connecting...") + + conf = None + while conf is None: + remote_conf = ray.get(controller.get_config.remote()) + if remote_conf is not None: + _maybe_create_transferqueue_client(remote_conf) + logger.info("TransferQueueClient initialized.") + return + + logger.debug("Waiting for controller to initialize... Retrying in 1s") + time.sleep(1) + + +def init(conf: Optional[DictConfig] = None) -> None: + """Initialize the TransferQueue system. 
+ + This function sets up the TransferQueue controller, distributed storage, and client. + It should be called once at the beginning of the program before any data operations. + + If a controller already exists (e.g., initialized by another process), this function + will retrieve the config from existing controller and initialize the TransferQueueClient. + In this case, the `conf` parameter will be ignored. + + Args: + conf: Optional configuration dictionary. If provided, it will be merged with + the default config from 'config.yaml'. This is only used for first-time + initializing. When connecting to an existing controller, this parameter + is ignored. + + Raises: + ValueError: If config is not valid or required configuration keys are missing. + + Example: + >>> # In process 0, node A + >>> import transfer_queue as tq + >>> tq.init() # Initialize the TransferQueue + >>> tq.put(...) # then you can use tq for data operations + >>> + >>> # In process 1, node B (with Ray connected to node A) + >>> import transfer_queue as tq + >>> tq.init() # This will only initialize a TransferQueueClient and link with existing TQ + >>> metadata = tq.get_meta(...) + >>> data = tq.get_data(metadata) + """ + try: + _init_from_existing() + except ValueError: + logger.info("No TransferQueueController found. Starting first-time initialization...") + else: + return + + # First-time initialize TransferQueue + + # create config + final_conf = OmegaConf.create({}, flags={"allow_objects": True}) + with pkg_resources.path("transfer_queue", "config.yaml") as p: + default_conf = OmegaConf.load(p) + final_conf = OmegaConf.merge(final_conf, default_conf) + if conf: + final_conf = OmegaConf.merge(final_conf, conf) + + # create controller + try: + sampler = final_conf.controller.sampler + if isinstance(sampler, BaseSampler): + # user pass a pre-initialized sampler instance + sampler = sampler + elif isinstance(sampler, type) and issubclass(sampler, BaseSampler): + # user pass a sampler class + sampler = sampler() + elif isinstance(sampler, str): + # user pass a sampler name str + # try to convert as sampler class + sampler = globals()[final_conf.controller.sampler] + except KeyError: + raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None + + try: + # Ray will make sure actor with same name can only be created once + controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined] + sampler=sampler, polling_mode=final_conf.controller.polling_mode + ) + logger.info("TransferQueueController has been created.") + except ValueError: + logger.info("Some other rank has initialized TransferQueueController. 
Try to connect to existing controller.") + _init_from_existing() + return + + controller_zmq_info = process_zmq_server_info(controller) + final_conf.controller.zmq_info = controller_zmq_info + + # create distributed storage backends + final_conf = _maybe_create_transferqueue_storage(final_conf) + + # store the config into controller + ray.get(controller.store_config.remote(final_conf)) + logger.info(f"TransferQueue config: {final_conf}") + + # create client + _maybe_create_transferqueue_client(final_conf) + + +# ==================== Basic API ==================== +def get_meta( + data_fields: list[str], + batch_size: int, + partition_id: str, + mode: str = "fetch", + task_name: Optional[str] = None, + sampling_config: Optional[dict[str, Any]] = None, +) -> BatchMeta: + """Synchronously fetch data metadata from the controller via ZMQ. + + Args: + data_fields: List of data field names to retrieve metadata for + batch_size: Number of samples to request in the batch + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users + task_name: Optional task name associated with the request + sampling_config: Optional sampling configuration for custom samplers. + + Returns: + BatchMeta: Metadata object containing data structure, sample information, and readiness status + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> # Example 1: Basic fetch metadata + >>> batch_meta = tq.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = tq.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... sampling_config={"n_samples_per_prompt": 4} + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = tq.get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... ) + >>> print(batch_meta.is_ready) # May be False if some samples not ready + """ + + tq_client = _maybe_create_transferqueue_client() + return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + + +def set_custom_meta(metadata: BatchMeta) -> None: + """Synchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. 
+
+    Raises:
+        RuntimeError: If communication fails or controller returns error response
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>>
+        >>> # Create batch with custom metadata
+        >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
+        >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}, {"score": 0.7}, {"score": 0.6}])
+        >>> tq.set_custom_meta(batch_meta)
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return tq_client.set_custom_meta(metadata)
+
+
+def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta:
+    """Synchronously write data to storage units based on metadata.
+
+    If metadata is not provided, it will be created automatically using insert mode
+    with the provided data fields and partition_id.
+
+    During put, the custom_meta in metadata will update the corresponding custom_meta in
+    TransferQueue Controller.
+
+    Note:
+        When using multiple workers for distributed execution, there may be data
+        ordering inconsistencies between workers during put operations.
+
+    Args:
+        data: Data to write as TensorDict
+        metadata: Records the metadata of a batch of data samples, containing index and
+            storage unit information. If None, metadata will be auto-generated.
+        partition_id: Target data partition id (required if metadata is not provided)
+
+    Returns:
+        BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+            metadata; will be updated in a future version to reflect the post-put state)
+
+    Raises:
+        ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+        RuntimeError: If storage operation fails
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>>
+        >>> batch_size = 4
+        >>> seq_len = 16
+        >>> current_partition_id = "train_0"
+        >>> # Example 1: Normal usage with existing metadata
+        >>> batch_meta = tq.get_meta(
+        ...     data_fields=["prompts", "attention_mask"],
+        ...     batch_size=batch_size,
+        ...     partition_id=current_partition_id,
+        ...     mode="fetch",
+        ...     task_name="generate_sequences",
+        ... )
+        >>> batch = tq.get_data(batch_meta)
+        >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+        >>> tq.put(data=output, metadata=batch_meta)
+        >>>
+        >>> # Example 2: Initial data insertion without pre-existing metadata
+        >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+        >>> # Please make sure the corresponding partition_id is empty before calling put()
+        >>> # without metadata.
+        >>> # Currently we only support putting all the data of the corresponding partition id at once.
+        >>> # You should repeat-interleave the initial data if n_samples > 1 before calling put().
+        >>> original_prompts = torch.randn(batch_size, seq_len)
+        >>> n_samples = 4
+        >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+        >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+        >>> # This will create metadata in "insert" mode internally.
+        >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id)
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return tq_client.put(data, metadata, partition_id)
+
+
+def get_data(metadata: BatchMeta) -> TensorDict:
+    """Synchronously fetch data from storage units and organize into TensorDict.
+
+    Args:
+        metadata: Batch metadata containing data location information and global indexes
+
+    Returns:
+        TensorDict containing:
+            - Requested data fields (e.g., "prompts", "attention_mask")
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>>
+        >>> batch_meta = tq.get_meta(
+        ...     data_fields=["prompts", "attention_mask"],
+        ...     batch_size=4,
+        ...     partition_id="train_0",
+        ...     mode="fetch",
+        ...     task_name="generate_sequences",
+        ... )
+        >>> batch = tq.get_data(batch_meta)
+        >>> print(batch)
+        >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return tq_client.get_data(metadata)
+
+
+def clear_partition(partition_id: str):
+    """Synchronously clear the whole partition from all storage units and the controller.
+
+    Args:
+        partition_id: The partition id to clear data for
+
+    Raises:
+        RuntimeError: If clear operation fails
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return tq_client.clear_partition(partition_id)
+
+
+def clear_samples(metadata: BatchMeta):
+    """Synchronously clear specific samples from all storage units and the controller.
+
+    Args:
+        metadata: The BatchMeta of the corresponding data to be cleared
+
+    Raises:
+        RuntimeError: If clear operation fails
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return tq_client.clear_samples(metadata)
+
+
+async def async_get_meta(
+    data_fields: list[str],
+    batch_size: int,
+    partition_id: str,
+    mode: str = "fetch",
+    task_name: Optional[str] = None,
+    sampling_config: Optional[dict[str, Any]] = None,
+) -> BatchMeta:
+    """Asynchronously fetch data metadata from the controller via ZMQ.
+
+    Args:
+        data_fields: List of data field names to retrieve metadata for
+        batch_size: Number of samples to request in the batch
+        partition_id: Current data partition id
+        mode: Data fetch mode. Options:
+            - 'fetch': Get ready data only
+            - 'force_fetch': Get data regardless of readiness (may return unready samples)
+            - 'insert': Internal usage - should not be used by users
+        task_name: Optional task name associated with the request
+        sampling_config: Optional sampling configuration for custom samplers.
+
+    Returns:
+        BatchMeta: Metadata object containing data structure, sample information, and readiness status
+
+    Raises:
+        RuntimeError: If communication fails or controller returns error response
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>>
+        >>> # Example 1: Basic fetch metadata
+        >>> batch_meta = asyncio.run(tq.async_get_meta(
+        ...     data_fields=["input_ids", "attention_mask"],
+        ...     batch_size=4,
+        ...     partition_id="train_0",
+        ...     mode="fetch",
+        ...     task_name="generate_sequences"
+        ... ))
+        >>> print(batch_meta.is_ready)  # True if all samples ready
+        >>>
+        >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
+        >>> batch_meta = asyncio.run(tq.async_get_meta(
+        ...     data_fields=["input_ids", "attention_mask"],
+        ...     batch_size=8,
+        ...     partition_id="train_0",
+        ...     mode="fetch",
+        ...     task_name="generate_sequences",
+        ...     sampling_config={"n_samples_per_prompt": 4}
+        ... ))
+        >>> print(batch_meta.is_ready)  # True if all samples ready
+        >>>
+        >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
+        >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
+        >>> batch_meta = asyncio.run(tq.async_get_meta(
+        ...     partition_id="train_0",  # optional
+        ...     mode="force_fetch",
+        ... ))
+        >>> print(batch_meta.is_ready)  # May be False if some samples not ready
+    """
+
+    tq_client = _maybe_create_transferqueue_client()
+    return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
+
+
+async def async_set_custom_meta(
+    metadata: BatchMeta,
+) -> None:
+    """
+    Asynchronously send custom metadata to the controller.
+
+    This method sends per-sample custom metadata (custom_meta) to the controller.
+    The custom_meta is stored in the controller and can be retrieved along with
+    the BatchMeta in subsequent get_meta calls.
+
+    Args:
+        metadata: BatchMeta containing the samples and their custom metadata to store.
+            The custom_meta should be set using BatchMeta.update_custom_meta() before
+            calling this method.
+
+    Raises:
+        RuntimeError: If communication fails or controller returns error response
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>>
+        >>> # Create batch with custom metadata
+        >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
+        >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}, {"score": 0.7}, {"score": 0.6}])
+        >>> asyncio.run(tq.async_set_custom_meta(batch_meta))
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return await tq_client.async_set_custom_meta(metadata)
+
+
+async def async_put(
+    data: TensorDict,
+    metadata: Optional[BatchMeta] = None,
+    partition_id: Optional[str] = None,
+) -> BatchMeta:
+    """Asynchronously write data to storage units based on metadata.
+
+    If metadata is not provided, it will be created automatically using insert mode
+    with the provided data fields and partition_id.
+
+    During put, the custom_meta in metadata will update the corresponding custom_meta in
+    TransferQueue Controller.
+
+    Note:
+        When using multiple workers for distributed execution, there may be data
+        ordering inconsistencies between workers during put operations.
+
+    Args:
+        data: Data to write as TensorDict
+        metadata: Records the metadata of a batch of data samples, containing index and
+            storage unit information. If None, metadata will be auto-generated.
+        partition_id: Target data partition id (required if metadata is not provided)
+
+    Returns:
+        BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+            metadata; will be updated in a future version to reflect the post-put state)
+
+    Raises:
+        ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+        RuntimeError: If storage operation fails
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>>
+        >>> batch_size = 4
+        >>> seq_len = 16
+        >>> current_partition_id = "train_0"
+        >>> # Example 1: Normal usage with existing metadata
+        >>> batch_meta = asyncio.run(tq.async_get_meta(
+        ...     data_fields=["prompts", "attention_mask"],
+        ...     batch_size=batch_size,
+        ...     partition_id=current_partition_id,
+        ...     mode="fetch",
+        ...     task_name="generate_sequences",
+        ... ))
+        >>> batch = asyncio.run(tq.async_get_data(batch_meta))
+        >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+        >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta))
+        >>>
+        >>> # Example 2: Initial data insertion without pre-existing metadata
+        >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+        >>> # Please make sure the corresponding partition_id is empty before calling async_put()
+        >>> # without metadata.
+        >>> # Currently we only support putting all the data of the corresponding partition id at once.
+        >>> # You should repeat-interleave the initial data if n_samples > 1 before calling async_put().
+        >>> original_prompts = torch.randn(batch_size, seq_len)
+        >>> n_samples = 4
+        >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+        >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+        >>> # This will create metadata in "insert" mode internally.
+        >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return await tq_client.async_put(data, metadata, partition_id)
+
+
+async def async_get_data(metadata: BatchMeta) -> TensorDict:
+    """Asynchronously fetch data from storage units and organize into TensorDict.
+
+    Args:
+        metadata: Batch metadata containing data location information and global indexes
+
+    Returns:
+        TensorDict containing:
+            - Requested data fields (e.g., "prompts", "attention_mask")
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>>
+        >>> batch_meta = asyncio.run(tq.async_get_meta(
+        ...     data_fields=["prompts", "attention_mask"],
+        ...     batch_size=4,
+        ...     partition_id="train_0",
+        ...     mode="fetch",
+        ...     task_name="generate_sequences",
+        ... ))
+        >>> batch = asyncio.run(tq.async_get_data(batch_meta))
+        >>> print(batch)
+        >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return await tq_client.async_get_data(metadata)
+
+
+# ==================== Data Operations API ====================
+
+
+async def async_clear_samples(metadata: BatchMeta):
+    """Asynchronously clear specific samples from all storage units and the controller.
+
+    Args:
+        metadata: The BatchMeta of the corresponding data to be cleared
+
+    Raises:
+        RuntimeError: If clear operation fails
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return await tq_client.async_clear_samples(metadata)
+
+
+async def async_clear_partition(partition_id: str):
+    """Asynchronously clear the whole partition from all storage units and the controller.
+
+    Args:
+        partition_id: The partition id to clear data for
+
+    Raises:
+        RuntimeError: If clear operation fails
+    """
+    tq_client = _maybe_create_transferqueue_client()
+    return await tq_client.async_clear_partition(partition_id)
+
+
+def close():
+    """Close the TransferQueue system.
+
+    This function cleans up the TransferQueue system, including:
+    - Closing the client and its associated resources
+    - Cleaning up distributed storage (only for the process that initialized it)
+    - Killing the controller actor
+
+    Note:
+        This function should be called when the TransferQueue system is no longer needed.
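+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> # ... put / get data ...
+        >>> tq.close()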
+    """
+    global _TRANSFER_QUEUE_CLIENT
+    global _TRANSFER_QUEUE_STORAGE
+    if _TRANSFER_QUEUE_CLIENT:
+        _TRANSFER_QUEUE_CLIENT.close()
+        _TRANSFER_QUEUE_CLIENT = None
+
+    try:
+        if _TRANSFER_QUEUE_STORAGE:
+            # only the process that did the first-time init can clean up the distributed storage
+            for storage in _TRANSFER_QUEUE_STORAGE.values():
+                ray.kill(storage)
+            _TRANSFER_QUEUE_STORAGE = None
+    except Exception:
+        pass
+
+    try:
+        controller = ray.get_actor("TransferQueueController")
+        ray.kill(controller)
+    except Exception:
+        pass
+
+
+# ==================== KV Interface API ====================
+def kv_put(
+    key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]]
+) -> None:
+    """Put a single key-value pair to TransferQueue.
+
+    This is a convenience method for putting data using a user-specified key
+    instead of BatchMeta. Internally, the key is translated to a BatchMeta
+    and the data is stored using the regular put mechanism.
+
+    Args:
+        key: User-specified key for the data sample (in row)
+        partition_id: Logical partition to store the data in
+        fields: Data fields to store. Can be a TensorDict or a dict of tensors.
+            Each key in `fields` will be treated as a column for the data sample.
+            If dict is provided, tensors will be unsqueezed to add batch dimension.
+        tag: Optional metadata tag to associate with the key
+
+    Raises:
+        ValueError: If neither fields nor tag is provided
+        ValueError: If nested tensors are provided (use kv_batch_put instead)
+        RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> import torch
+        >>> tq.init()
+        >>> # Put with both fields and tag
+        >>> tq.kv_put(
+        ...     key="sample_1",
+        ...     partition_id="train",
+        ...     fields={"input_ids": torch.tensor([1, 2, 3])},
+        ...     tag={"score": 0.95}
+        ... )
+    """
+    if fields is None and tag is None:
+        raise ValueError("Please provide at least one of `fields` or `tag`.")
+
+    tq_client = _maybe_create_transferqueue_client()
+
+    # 1. translate user-specified key to BatchMeta
+    batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
+
+    if batch_meta.size != 1:
+        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
+
+    # 2. register the user-specified tag to BatchMeta
+    if tag:
+        batch_meta.update_custom_meta([tag])
+
+    # 3. put data
+    if fields is not None:
+        if isinstance(fields, dict):
+            # TODO: consider whether to support this...
+            batch = {}
+            for field_name, value in fields.items():
+                if isinstance(value, torch.Tensor):
+                    if value.is_nested:
+                        raise ValueError("Please use (async)kv_batch_put for batch operation")
+                    batch[field_name] = value.unsqueeze(0)
+                else:
+                    batch[field_name] = NonTensorStack(value)
+            fields = TensorDict(batch, batch_size=[1])
+        elif not isinstance(fields, TensorDict):
+            raise ValueError("`fields` can only be a dict or TensorDict")
+
+        # custom_meta (tag) will be put to controller through the internal put process
+        tq_client.put(fields, batch_meta)
+    else:
+        # directly update custom_meta (tag) to controller
+        tq_client.set_custom_meta(batch_meta)
+
+
+def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]]) -> None:
+    """Put multiple key-value pairs to TransferQueue in batch.
+
+    This method stores multiple key-value pairs in a single operation, which is more
+    efficient than calling kv_put multiple times.
+
+    Args:
+        keys: List of user-specified keys for the data
+        partition_id: Logical partition to store the data in
+        fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
+        tags: List of metadata tags, one for each key
+
+    Raises:
+        ValueError: If neither `fields` nor `tags` is provided
+        ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
+        RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> from tensordict import TensorDict
+        >>> tq.init()
+        >>> keys = ["sample_1", "sample_2", "sample_3"]
+        >>> fields = TensorDict({
+        ...     "input_ids": torch.randn(3, 10),
+        ...     "attention_mask": torch.ones(3, 10),
+        ... }, batch_size=3)
+        >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
+        >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
+    """
+
+    if fields is None and tags is None:
+        raise ValueError("Please provide at least one of `fields` or `tags`.")
+
+    if fields is not None and fields.batch_size[0] != len(keys):
+        raise ValueError(
+            f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
+            f"batch_size {fields.batch_size[0]}"
+        )
+
+    tq_client = _maybe_create_transferqueue_client()
+
+    # 1. translate user-specified key to BatchMeta
+    batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
+
+    if batch_meta.size != len(keys):
+        raise RuntimeError(
+            f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
+        )
+
+    # 2. register the user-specified tags to BatchMeta
+    if tags:
+        if len(tags) != len(keys):
+            raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
+        batch_meta.update_custom_meta(tags)
+
+    # 3. put data
+    if fields is not None:
+        tq_client.put(fields, batch_meta)
+    else:
+        # directly update custom_meta (tags) to controller
+        tq_client.set_custom_meta(batch_meta)
+
+
+def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict:
+    """Get data from TransferQueue using user-specified keys.
+
+    This is a convenience method for retrieving data using keys instead of indexes.
+
+    Args:
+        keys: Single key or list of keys to retrieve
+        partition_id: Partition containing the keys
+        fields: Optional field(s) to retrieve. If None, retrieves all fields
+
+    Returns:
+        TensorDict with the requested data
+
+    Raises:
+        RuntimeError: If keys or partition are not found
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> # Get single key with all fields
+        >>> data = tq.kv_get(keys="sample_1", partition_id="train")
+        >>> # Get multiple keys with specific fields
+        >>> data = tq.kv_get(
+        ...     keys=["sample_1", "sample_2"],
+        ...     partition_id="train",
+        ...     fields="input_ids"
+        ... )
+    """
+    tq_client = _maybe_create_transferqueue_client()
+
+    batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+    if batch_meta.size == 0:
+        raise RuntimeError("keys or partition were not found!")
+
+    if fields is not None:
+        if isinstance(fields, str):
+            fields = [fields]
+        batch_meta = batch_meta.select_fields(fields)
+
+    data = tq_client.get_data(batch_meta)
+
+    return data
+
+
+def kv_list(partition_id: str) -> tuple[list[Optional[str]], list[Optional[dict[str, Any]]]]:
+    """List all keys and their metadata in a partition.
+
+    Args:
+        partition_id: Partition to list keys from
+
+    Returns:
+        Tuple of:
+            - List of keys in the partition
+            - List of custom metadata (tags) associated with each key
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> keys, tags = tq.kv_list(partition_id="train")
+        >>> print(f"Keys: {keys}")
+        >>> print(f"Tags: {tags}")
+    """
+    tq_client = _maybe_create_transferqueue_client()
+
+    keys, custom_meta = tq_client.kv_list(partition_id)
+
+    return keys, custom_meta
+
+
+def kv_clear(keys: list[str] | str, partition_id: str) -> None:
+    """Clear key-value pairs from TransferQueue.
+
+    This removes the specified keys and their associated data from both
+    the controller and storage units.
+
+    Args:
+        keys: Single key or list of keys to clear
+        partition_id: Partition containing the keys
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> # Clear single key
+        >>> tq.kv_clear(keys="sample_1", partition_id="train")
+        >>> # Clear multiple keys
+        >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
+    """
+
+    if isinstance(keys, str):
+        keys = [keys]
+
+    tq_client = _maybe_create_transferqueue_client()
+    batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+    if batch_meta.size > 0:
+        tq_client.clear_samples(batch_meta)
+
+
+# ==================== Async KV Interface API ====================
+async def async_kv_put(
+    key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]]
+) -> None:
+    """Asynchronously put a single key-value pair to TransferQueue.
+
+    This is a convenience method for putting data using a user-specified key
+    instead of BatchMeta. Internally, the key is translated to a BatchMeta
+    and the data is stored using the regular put mechanism.
+
+    Args:
+        key: User-specified key for the data sample (in row)
+        partition_id: Logical partition to store the data in
+        fields: Data fields to store. Can be a TensorDict or a dict of tensors.
+            Each key in `fields` will be treated as a column for the data sample.
+            If dict is provided, tensors will be unsqueezed to add batch dimension.
+        tag: Optional metadata tag to associate with the key
+
+    Raises:
+        ValueError: If neither fields nor tag is provided
+        ValueError: If nested tensors are provided (use kv_batch_put instead)
+        RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> import torch
+        >>> tq.init()
+        >>> # Put with both fields and tag
+        >>> await tq.async_kv_put(
+        ...     key="sample_1",
+        ...     partition_id="train",
+        ...     fields={"input_ids": torch.tensor([1, 2, 3])},
+        ...     tag={"score": 0.95}
+        ... )
+    """
+
+    if fields is None and tag is None:
+        raise ValueError("Please provide at least one of `fields` or `tag`.")
+
+    tq_client = _maybe_create_transferqueue_client()
+
+    # 1. translate user-specified key to BatchMeta
+    batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
+
+    if batch_meta.size != 1:
+        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
+
+    # 2. register the user-specified tag to BatchMeta
+    if tag:
+        batch_meta.update_custom_meta([tag])
+
+    # 3. put data
+    if fields is not None:
+        if isinstance(fields, dict):
+            # TODO: consider whether to support this...
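+            # Convert the plain dict into a single-sample TensorDict: tensors get a leading
+            # batch dimension of 1, and non-tensor values are wrapped in a NonTensorStack.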
+            batch = {}
+            for field_name, value in fields.items():
+                if isinstance(value, torch.Tensor):
+                    if value.is_nested:
+                        raise ValueError("Please use (async)kv_batch_put for batch operation")
+                    batch[field_name] = value.unsqueeze(0)
+                else:
+                    batch[field_name] = NonTensorStack(value)
+            fields = TensorDict(batch, batch_size=[1])
+        elif not isinstance(fields, TensorDict):
+            raise ValueError("`fields` can only be a dict or TensorDict")
+
+        # custom_meta (tag) will be put to controller through the put process
+        await tq_client.async_put(fields, batch_meta)
+    else:
+        # directly update custom_meta (tag) to controller
+        await tq_client.async_set_custom_meta(batch_meta)
+
+
+async def async_kv_batch_put(
+    keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]]
+) -> None:
+    """Asynchronously put multiple key-value pairs to TransferQueue in batch.
+
+    This method stores multiple key-value pairs in a single operation, which is more
+    efficient than calling kv_put multiple times.
+
+    Args:
+        keys: List of user-specified keys for the data
+        partition_id: Logical partition to store the data in
+        fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
+        tags: List of metadata tags, one for each key
+
+    Raises:
+        ValueError: If neither `fields` nor `tags` is provided
+        ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
+        RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> keys = ["sample_1", "sample_2", "sample_3"]
+        >>> fields = TensorDict({
+        ...     "input_ids": torch.randn(3, 10),
+        ...     "attention_mask": torch.ones(3, 10),
+        ... }, batch_size=3)
+        >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
+        >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
+    """
+
+    if fields is None and tags is None:
+        raise ValueError("Please provide at least one of `fields` or `tags`.")
+
+    if fields is not None and fields.batch_size[0] != len(keys):
+        raise ValueError(
+            f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
+            f"batch_size {fields.batch_size[0]}"
+        )
+
+    tq_client = _maybe_create_transferqueue_client()
+
+    # 1. translate user-specified key to BatchMeta
+    batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
+
+    if batch_meta.size != len(keys):
+        raise RuntimeError(
+            f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
+        )
+
+    # 2. register the user-specified tags to BatchMeta
+    if tags:
+        if len(tags) != len(keys):
+            raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
+        batch_meta.update_custom_meta(tags)
+
+    # 3. put data
+    if fields is not None:
+        await tq_client.async_put(fields, batch_meta)
+    else:
+        # directly update custom_meta (tags) to controller
+        await tq_client.async_set_custom_meta(batch_meta)
+
+
+async def async_kv_get(
+    keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None
+) -> TensorDict:
+    """Asynchronously get data from TransferQueue using user-specified keys.
+
+    This is a convenience method for retrieving data using keys instead of indexes.
+
+    Args:
+        keys: Single key or list of keys to retrieve
+        partition_id: Partition containing the keys
+        fields: Optional field(s) to retrieve. If None, retrieves all fields
+
+    Returns:
+        TensorDict with the requested data
+
+    Raises:
+        RuntimeError: If keys or partition are not found
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> # Get single key with all fields
+        >>> data = await tq.async_kv_get(keys="sample_1", partition_id="train")
+        >>> # Get multiple keys with specific fields
+        >>> data = await tq.async_kv_get(
+        ...     keys=["sample_1", "sample_2"],
+        ...     partition_id="train",
+        ...     fields="input_ids"
+        ... )
+    """
+    tq_client = _maybe_create_transferqueue_client()
+
+    batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+    if batch_meta.size == 0:
+        raise RuntimeError("keys or partition were not found!")
+
+    if fields is not None:
+        if isinstance(fields, str):
+            fields = [fields]
+        batch_meta = batch_meta.select_fields(fields)
+
+    data = await tq_client.async_get_data(batch_meta)
+
+    return data
+
+
+async def async_kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]:
+    """Asynchronously list all keys and their metadata in a partition.
+
+    Args:
+        partition_id: Partition to list keys from
+
+    Returns:
+        Tuple of:
+            - List of keys in the partition
+            - List of custom metadata (tags) associated with each key
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> keys, tags = await tq.async_kv_list(partition_id="train")
+        >>> print(f"Keys: {keys}")
+        >>> print(f"Tags: {tags}")
+    """
+    tq_client = _maybe_create_transferqueue_client()
+
+    keys, custom_meta = await tq_client.async_kv_list(partition_id)
+
+    return keys, custom_meta
+
+
+async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
+    """Asynchronously clear key-value pairs from TransferQueue.
+
+    This removes the specified keys and their associated data from both
+    the controller and storage units.
+
+    Args:
+        keys: Single key or list of keys to clear
+        partition_id: Partition containing the keys
+
+    Example:
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> # Clear single key
+        >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
+        >>> # Clear multiple keys
+        >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
+    """
+
+    if isinstance(keys, str):
+        keys = [keys]
+
+    tq_client = _maybe_create_transferqueue_client()
+    batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+    if batch_meta.size > 0:
+        await tq_client.async_clear_samples(batch_meta)
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index 5e25404..2abde1b 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -303,60 +303,43 @@ def clear_extra_info(self) -> None:
         """
         self.extra_info.clear()
 
-    def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None:
+    def get_all_custom_meta(self) -> list[dict[str, Any]]:
         """
-        Set custom_meta for a specific sample by global_index.
-
-        Custom metadata is user-defined per-sample metadata that can be stored
-        and retrieved along with the BatchMeta.
- - Args: - global_index: The global_index of the sample to set custom meta for - meta_dict: Dictionary containing custom metadata for the sample - - Raises: - ValueError: If the key is not in global_indexes - """ - - if global_index not in self.global_indexes: - raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") - - self.custom_meta[global_index] = copy.deepcopy(meta_dict) - - def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: - """ - Get all custom_meta as a dictionary. + Get all custom_meta as a list of dictionary. Returns: - A deep copy of the custom_meta dictionary + A deep copy of the custom_meta list """ - return copy.deepcopy(self.custom_meta) + custom_meta = [self.custom_meta.get(i, {}) for i in self.global_indexes] + return copy.deepcopy(custom_meta) - def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): + def update_custom_meta(self, custom_meta: list[dict[str, Any]]): """ - Update custom_meta with a dictionary of new metadata. + Update custom_meta with a list of dictionary of custom metadata. This method updates the custom_meta dictionary with the provided metadata. Existing keys will be overwritten with new values. Args: - new_meta: Dictionary of new metadata + custom_meta: list of custom_meta dictionary Raises: - ValueError: If any key in new_meta is not in global_indexes + ValueError: If the length of custom_meta cannot match the batch size """ - if new_meta is None: + if custom_meta is None: return - non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) - if non_exist_global_indexes: + if len(custom_meta) != self.size: raise ValueError( - f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} " - f"do not exist in this batch." + f"The length of custom_meta list {len(custom_meta)} must match the batch size: {self.size}" ) - self.custom_meta.update(new_meta) + custom_meta_dict: dict[int, dict[str, Any]] = { + self.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta)) + } + + self.custom_meta.update(custom_meta_dict) def clear_custom_meta(self) -> None: """ diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index c07ec1f..258dd91 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -28,6 +28,7 @@ import ray import torch import zmq +from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from torch import Tensor @@ -59,12 +60,10 @@ class TransferQueueStorageManager(ABC): """Base class for storage layer. 
It defines the interface for data operations and generally provides handshake & notification capabilities.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}" self.config = config - controller_info = config.get("controller_info") - assert controller_info is not None, "controller_info is required" - self.controller_info: ZMQServerInfo = controller_info + self.controller_info = controller_info self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None @@ -178,7 +177,7 @@ def _send_handshake_requests(self) -> None: """Send handshake request to controller.""" assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized" request_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE, + request_type=ZMQRequestType.HANDSHAKE, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "storage_manager_id": self.storage_manager_id, @@ -226,7 +225,7 @@ async def notify_data_update( poller.register(self.data_status_update_socket, zmq.POLLIN) request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "partition_id": partition_id, @@ -246,7 +245,7 @@ async def notify_data_update( ) except Exception as e: request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "message": f"Failed to notify data status update information from " @@ -351,14 +350,14 @@ class KVStorageManager(TransferQueueStorageManager): It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations. """ - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): """ Initialize the KVStorageManager with configuration. """ client_name = config.get("client_name", None) if client_name is None: raise ValueError("Missing client_name in config") - super().__init__(config) + super().__init__(controller_info, config) self.storage_client = StorageClientFactory.create(client_name, config) self._multi_threads_executor: Optional[ThreadPoolExecutor] = None # Register a cleanup function: automatically invoke shutdown when the instance is garbage collected. 
@@ -557,7 +556,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: loop = asyncio.get_event_loop() # put to storage backends - custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) + custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) per_field_dtypes: dict[int, dict[str, Any]] = {} per_field_shapes: dict[int, dict[str, Any]] = {} @@ -578,24 +577,26 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None ) - # Prepare per-field custom_meta if available - per_field_custom_meta: dict[int, dict[str, Any]] = {} - if custom_meta: - if len(custom_meta) != len(keys): - raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})") + # Prepare per-field custom_backend_meta if available + per_field_custom_backend_meta: dict[int, dict[str, Any]] = {} + if custom_backend_meta: + if len(custom_backend_meta) != len(keys): + raise ValueError( + f"Length of custom_backend_meta ({len(custom_backend_meta)}) does not match expected ({len(keys)})" + ) # custom meta is a flat list aligned with keys/values # Use itertools.product to eliminate nested loops for global_idx in metadata.global_indexes: - per_field_custom_meta[global_idx] = {} + per_field_custom_backend_meta[global_idx] = {} # TODO(tianyi): the order of custom meta is coupled with keys/values for (field_name, global_idx), meta_value in zip( itertools.product(sorted(metadata.field_names), metadata.global_indexes), - custom_meta, + custom_backend_meta, strict=True, ): - per_field_custom_meta[global_idx][field_name] = meta_value - metadata.update_custom_meta(per_field_custom_meta) + per_field_custom_backend_meta[global_idx][field_name] = meta_value + metadata._custom_backend_meta.update(per_field_custom_backend_meta) # Get current data partition id # Note: Currently we only support putting to & getting data from a single data partition simultaneously, @@ -608,7 +609,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: metadata.global_indexes, per_field_dtypes, per_field_shapes, - per_field_custom_meta, + per_field_custom_backend_meta, ) async def get_data(self, metadata: BatchMeta) -> TensorDict: diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py index e595ccd..04d2cd0 100644 --- a/transfer_queue/storage/managers/factory.py +++ b/transfer_queue/storage/managers/factory.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Any from transfer_queue.storage.managers.base import TransferQueueStorageManager +from transfer_queue.utils.zmq_utils import ZMQServerInfo class TransferQueueStorageManagerFactory: @@ -39,10 +41,34 @@ def decorator(manager_cls: type[TransferQueueStorageManager]): return decorator @classmethod - def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager: + def create( + cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any] + ) -> TransferQueueStorageManager: """Create and return a TransferQueueStorageManager instance.""" if manager_type not in cls._registry: - raise ValueError( - f"Unknown manager_type: {manager_type}. 
Supported managers include: {list(cls._registry.keys())}" - ) - return cls._registry[manager_type](config) + if manager_type == "AsyncSimpleStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "SimpleStorage" + elif manager_type == "MooncakeStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "MooncakeStore" + elif manager_type == "YuanrongStorageManager": + warnings.warn( + f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.", + category=DeprecationWarning, + stacklevel=2, + ) + manager_type = "Yuanrong" + else: + raise ValueError( + f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" + ) + return cls._registry[manager_type](controller_info, config) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index ca55566..9f6f93a 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -19,16 +19,17 @@ from transfer_queue.storage.managers.base import KVStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) -@TransferQueueStorageManagerFactory.register("MooncakeStorageManager") +@TransferQueueStorageManagerFactory.register("MooncakeStore") class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): # Required: Address of the HTTP metadata server (e.g., "localhost:8080") metadata_server = config.get("metadata_server", None) # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") @@ -45,4 +46,4 @@ def __init__(self, config: dict[str, Any]): config["client_name"] = "MooncakeStorageClient" elif client_name != "MooncakeStorageClient": raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'MooncakeStorageClient'") - super().__init__(config) + super().__init__(controller_info, config) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 5c6e68f..f142069 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -16,6 +16,7 @@ import asyncio import logging import os +import warnings from collections.abc import Mapping from functools import wraps from operator import itemgetter @@ -24,6 +25,7 @@ import torch import zmq +from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from transfer_queue.metadata import BatchMeta @@ -48,7 +50,7 @@ TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) -@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager") +@TransferQueueStorageManagerFactory.register("SimpleStorage") class AsyncSimpleStorageManager(TransferQueueStorageManager): """Asynchronous storage manager that handles multiple storage units. 
@@ -56,14 +58,23 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager): instances using ZMQ communication and dynamic socket management. """ - def __init__(self, config: dict[str, Any]): - super().__init__(config) + def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): + super().__init__(controller_info, config) self.config = config - server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("storage_unit_infos", None) + server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("zmq_info", None) if server_infos is None: - raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.") + server_infos = config.get("storage_unit_infos", None) + if server_infos is not None: + warnings.warn( + "The config entry `storage_unit_infos` will be deprecated in 0.1.7, please use `zmq_info` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + if server_infos is None: + raise ValueError("AsyncSimpleStorageManager requires non-empty 'zmq_info' in config.") self.storage_unit_infos = self._register_servers(server_infos) self._build_storage_mapping_functions() @@ -277,7 +288,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping ) - # retrive data + # retrieve data tasks = [ self._get_from_single_storage_unit(meta_group, target_storage_unit=storage_id) for storage_id, meta_group in storage_meta_groups.items() diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index bfb79e6..54ac094 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -19,6 +19,7 @@ from transfer_queue.storage.managers.base import KVStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory +from transfer_queue.utils.zmq_utils import ZMQServerInfo logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -30,11 +31,11 @@ logger.addHandler(handler) -@TransferQueueStorageManagerFactory.register("YuanrongStorageManager") +@TransferQueueStorageManagerFactory.register("Yuanrong") class YuanrongStorageManager(KVStorageManager): """Storage manager for Yuanrong backend.""" - def __init__(self, config: dict[str, Any]): + def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): host = config.get("host", None) port = config.get("port", None) client_name = config.get("client_name", None) @@ -48,4 +49,4 @@ def __init__(self, config: dict[str, Any]): config["client_name"] = "YuanrongStorageClient" elif client_name != "YuanrongStorageClient": raise ValueError(f"Invalid 'client_name': {client_name} in config. 
Expecting 'YuanrongStorageClient'") - super().__init__(config) + super().__init__(controller_info, config) diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py index e25f6b0..5192e9f 100644 --- a/transfer_queue/utils/common.py +++ b/transfer_queue/utils/common.py @@ -16,11 +16,13 @@ import logging import os from contextlib import contextmanager -from typing import Optional +from typing import Any, Optional +import numpy as np import psutil import ray import torch +from tensordict import NonTensorStack, TensorDict logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -98,3 +100,67 @@ def get_env_bool(env_key: str, default: bool = False) -> bool: true_values = {"true", "1", "yes", "y", "on"} return env_value_lower in true_values + + +def dict_to_tensordict(data: dict[str, Any]) -> TensorDict: + """ + Create a TensorDict from a dict of tensors and non_tensors. + """ + + batch = {} + + final_batch_size = None + tensor_batch_size = None + deterministic_tensor_batch_size = None + deterministic_non_tensor_batch_size = None + + for key, val in data.items(): + if isinstance(val, torch.Tensor): + if val.is_nested and val.layout == torch.strided: + # must use unbind for strided nested tensor + deterministic_tensor_batch_size = len(val.unbind()) + else: + tensor_batch_size = val.shape[0] + batch[key] = val + elif isinstance(val, np.ndarray): + batch[key] = val + tensor_batch_size = val.shape[0] + elif isinstance(val, str): + batch[key] = val + deterministic_non_tensor_batch_size = 1 + elif isinstance(val, list): + batch[key] = NonTensorStack(*val) + deterministic_non_tensor_batch_size = len(val) + else: + batch[key] = NonTensorStack(val) + deterministic_non_tensor_batch_size = 1 + + if deterministic_tensor_batch_size: + if deterministic_non_tensor_batch_size: + assert deterministic_non_tensor_batch_size == deterministic_tensor_batch_size + if final_batch_size: + assert final_batch_size == deterministic_tensor_batch_size + else: + final_batch_size = deterministic_tensor_batch_size + + if deterministic_non_tensor_batch_size: + if deterministic_tensor_batch_size: + assert deterministic_non_tensor_batch_size == tensor_batch_size + if final_batch_size: + assert final_batch_size == deterministic_non_tensor_batch_size + else: + final_batch_size = deterministic_non_tensor_batch_size + + if not final_batch_size: + raise RuntimeError("Cannot correctly determine batch_size for input.") + + if tensor_batch_size: + if tensor_batch_size != final_batch_size: + assert final_batch_size == 1 + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.unsqueeze(0) + elif isinstance(v, np.ndarray): + batch[k] = np.expand_dims(v, 0) + + return TensorDict(batch, batch_size=[final_batch_size]) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 1f6ed92..eaaf65e 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -23,6 +23,7 @@ from uuid import uuid4 import psutil +import ray import zmq from transfer_queue.utils.common import ( @@ -100,6 +101,12 @@ class ZMQRequestType(ExplicitEnum): NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK" NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR" + # KV_INTERFACE + KV_RETRIEVE_KEYS = "KV_RETRIEVE_KEYS" + KV_RETRIEVE_KEYS_RESPONSE = "KV_RETRIEVE_KEYS_RESPONSE" + KV_LIST = "KV_LIST" + KV_LIST_RESPONSE = "KV_LIST_RESPONSE" + class ZMQServerInfo: """ @@ -239,3 +246,35 @@ def create_zmq_socket( if identity is not None: 
socket.setsockopt(zmq.IDENTITY, identity) return socket + + +def process_zmq_server_info( + handlers: dict[Any, Any] | Any, +): # noqa: UP007 + """Extract ZMQ server information from handler objects. + + Args: + handlers: Dictionary of handler objects (controllers, storage managers, or storage units), + or a single handler object + + Returns: + If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information + If handlers is a single object: ZMQ server information for that object + + Examples: + >>> # Single handler + >>> controller = TransferQueueController.remote(...) + >>> info = process_zmq_server_info(controller) + >>> + >>> # Multiple handlers + >>> handlers = {"storage_0": storage_0, "storage_1": storage_1} + >>> info_dict = process_zmq_server_info(handlers)""" + # Handle single handler object case + if not isinstance(handlers, dict): + return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] + else: + # Handle dictionary case + server_info = {} + for name, handler in handlers.items(): + server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined] + return server_info diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 25b530c..59159a6 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -37,87 +37,23 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" - -def demonstrate_basic_setup(): - """ - Demonstrate the basic setup of TransferQueue with three core components. - """ - - # Initialize Ray - if not ray.is_initialized(): - ray.init() - - # Configuration - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - print("[Step 1] Creating Storage Backend (using default SimpleStorageUnit)...") - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - print(f" ✓ Created SimpleStorageUnit #{i}") - - print("[Step 2] Creating TransferQueueController...") - controller = TransferQueueController.remote() - print(" ✓ Controller created - manages data state") - - # Get server information - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - # Create Client (User-facing API) - print("[Step 3] Creating TransferQueueClient...") - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - print(" ✓ Client created - this is what users interact with!") - - # Initialize storage manager - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - print( - " ✓ Storage manager initialized. It is a class variable inside the client, acting as an adapter to " - "suit for various storage backends." 
- ) - - print("[Architecture Summary]") - print( - " - TransferQueueController: Tracking the production/consumption status as metadata (can define your own " - "data consumption logics)." - ) - print(" - SimpleStorageUnit: Distributed data storage that holds actual data (easily swap out by other backends).") - print(" - TransferQueueClient: User interface that allows you to put/get/clear data or metadata)") - - return controller, storage_units, client +if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") -def demonstrate_data_workflow(client): +def demonstrate_data_workflow(): """ Demonstrate basic data workflow: put → get → clear. """ @@ -148,12 +84,12 @@ def demonstrate_data_workflow(client): print(f" Created {data_batch.batch_size[0]} samples") partition_id = "tutorial_partition_0" - client.put(data=data_batch, partition_id=partition_id) + tq.put(data=data_batch, partition_id=partition_id) print(f" ✓ Data written to partition: {partition_id}") # Step 2: Get metadata print("[Step 2] Requesting data metadata...") - batch_meta = client.get_meta( + batch_meta = tq.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=data_batch.batch_size[0], partition_id=partition_id, @@ -164,7 +100,7 @@ def demonstrate_data_workflow(client): # Step 3: Get actual data print("[Step 3] Retrieving actual data...") - retrieved_data = client.get_data(batch_meta) + retrieved_data = tq.get_data(batch_meta) print(" ✓ Data retrieved successfully") print(f" Keys: {list(retrieved_data.keys())}") @@ -176,7 +112,7 @@ def demonstrate_data_workflow(client): # Step 5: Clear print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)") - client.clear_partition(partition_id=partition_id) + tq.clear_partition(partition_id=partition_id) print(" ✓ Partition cleared") @@ -189,16 +125,16 @@ def demonstrate_storage_backend_options(): print("=" * 80) print("TransferQueue supports multiple storage backends:") - print("1. SimpleStorageUnit (default)") + print("1. SimpleStorage (default)") print(" - In-memory storage, fast and simple") print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer") print(" - No extra dependencies, good for development and testing") - print("2. YuanrongStorage") + print("2. Yuanrong") print(" - Ascend native distributed storage solution") print(" - Hierarchical storage interfaces including HBM/DRAM/SSD") - print("3. MoonCakeStore (on the way)") + print("3. MooncakeStore (on the way)") print(" - Support multiple transmission protocols") print(" - RDMA between DRAM and HBM") @@ -234,10 +170,10 @@ def main(): try: print("Setting up TransferQueue...") - controller, storage_units, client = demonstrate_basic_setup() + tq.init() print("Demonstrating the user workflow...") - demonstrate_data_workflow(client) + demonstrate_data_workflow() demonstrate_storage_backend_options() @@ -253,7 +189,7 @@ def main(): print("3. You can swap out different storage backends easily") # Cleanup - client.close() + tq.close() ray.shutdown() print("\n✓ Cleanup complete") diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py new file mode 100644 index 0000000..59159a6 --- /dev/null +++ b/tutorial/02_kv_interface.py @@ -0,0 +1,205 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import textwrap +import warnings +from pathlib import Path + +warnings.filterwarnings( + action="ignore", + message=r"The PyTorch API of nested tensors is in prototype stage*", + category=UserWarning, + module=r"torch\.nested", +) + +warnings.filterwarnings( + action="ignore", + message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible " + r"devices env var if num_gpus=0 or num_gpus=None.*", + category=FutureWarning, + module=r"ray\._private\.worker", +) + + +import ray # noqa: E402 +import torch # noqa: E402 +from tensordict import TensorDict # noqa: E402 + +# Add the parent directory to the path +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +import transfer_queue as tq # noqa: E402 + +# Configure Ray +os.environ["RAY_DEDUP_LOGS"] = "0" +os.environ["RAY_DEBUG"] = "1" + +if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + +def demonstrate_data_workflow(): + """ + Demonstrate basic data workflow: put → get → clear. + """ + print("=" * 80) + print("Data Workflow Demo: put → get → clear") + print("=" * 80) + + # Step 1: Put data + print("[Step 1] Putting data into TransferQueue...") + + input_ids = torch.tensor( + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + ] + ) + attention_mask = torch.ones_like(input_ids) + + data_batch = TensorDict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + batch_size=input_ids.size(0), + ) + + print(f" Created {data_batch.batch_size[0]} samples") + partition_id = "tutorial_partition_0" + tq.put(data=data_batch, partition_id=partition_id) + print(f" ✓ Data written to partition: {partition_id}") + + # Step 2: Get metadata + print("[Step 2] Requesting data metadata...") + batch_meta = tq.get_meta( + data_fields=["input_ids", "attention_mask"], + batch_size=data_batch.batch_size[0], + partition_id=partition_id, + task_name="tutorial_task", + ) + print(f" ✓ Got metadata: {len(batch_meta)} samples") + print(f" Global indexes: {batch_meta.global_indexes}") + + # Step 3: Get actual data + print("[Step 3] Retrieving actual data...") + retrieved_data = tq.get_data(batch_meta) + print(" ✓ Data retrieved successfully") + print(f" Keys: {list(retrieved_data.keys())}") + + # Step 4: Verify + print("[Step 4] Verifying data integrity...") + assert torch.equal(retrieved_data["input_ids"], input_ids) + assert torch.equal(retrieved_data["attention_mask"], attention_mask) + print(" ✓ Data matches original!") + + # Step 5: Clear + print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)") + tq.clear_partition(partition_id=partition_id) + print(" ✓ Partition cleared") + + +def demonstrate_storage_backend_options(): + """ + Show different storage backend options. + """ + print("=" * 80) + print("Storage Backend Options") + print("=" * 80) + + print("TransferQueue supports multiple storage backends:") + print("1. 
SimpleStorage (default)") + print(" - In-memory storage, fast and simple") + print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer") + print(" - No extra dependencies, good for development and testing") + + print("2. Yuanrong") + print(" - Ascend native distributed storage solution") + print(" - Hierarchical storage interfaces including HBM/DRAM/SSD") + + print("3. MooncakeStore (on the way)") + print(" - Support multiple transmission protocols") + print(" - RDMA between DRAM and HBM") + + print("4. Ray RDT (on the way)") + print(" - Leverage Ray's distributed object store to store data") + + print("5. Custom Storage Backends") + print(" - Implement your own storage manager by inheriting from `TransferQueueStorageManager` base class") + print(" - For KV based storage, you only need to provide a storage client and integrate with `KVStorageManager`") + + +def main(): + print("=" * 80) + print( + textwrap.dedent( + """ + TransferQueue Tutorial 1: Core Components Introduction + + This script introduces the three core components of TransferQueue: + 1. TransferQueueController - Manages all the metadata and tracks the production and consumption states + 2. StorageBackend - Pluggable distributed storage backend that holds the actual data + 3. TransferQueueClient - Client interface for reading/writing data (user-facing API) + + Key Concepts: + - Data is organized into logical partitions (e.g., "train", "val") + - Each sample has multiple fields, with a global index for identification + - Controller maintains production/consumption state tracking + - Client is the main interface users interact with + """ + ) + ) + print("=" * 80) + + try: + print("Setting up TransferQueue...") + tq.init() + + print("Demonstrating the user workflow...") + demonstrate_data_workflow() + + demonstrate_storage_backend_options() + + print("=" * 80) + print("Tutorial Complete!") + print("=" * 80) + print("Key Takeaways:") + print("1. TransferQueue has 3 core components:") + print(" - Controller: Manages data production/consumption state") + print(" - StorageBackend: Persists actual data") + print(" - Client: User-facing API (what you use)") + print("2. Client is the main interface users interact with") + print("3. 
You can swap out different storage backends easily") + + # Cleanup + tq.close() + ray.shutdown() + print("\n✓ Cleanup complete") + + except Exception as e: + print(f"Error during tutorial: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tutorial/02_metadata_concepts.py b/tutorial/03_metadata_concepts.py similarity index 89% rename from tutorial/02_metadata_concepts.py rename to tutorial/03_metadata_concepts.py index 93e59ac..3a4a57d 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -38,19 +38,13 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @@ -211,10 +205,12 @@ def demonstrate_batch_meta(): print(f"✓ Extra info: {batch.get_all_extra_info()}") print("[Example 3] Adding sample-level information through custom_meta...") - batch.set_custom_meta( - global_index=0, meta_dict={"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"} + batch.update_custom_meta( + [ + {"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}, + ] ) - batch.update_custom_meta({1: {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}}) print(f"✓ Custom meta: {batch.get_all_custom_meta()}") # Example 4: Chunk a batch @@ -308,32 +304,8 @@ def demonstrate_real_workflow(): if not ray.is_initialized(): ray.init() - # Setup TransferQueue - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote() - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + # Initialize TransferQueue + tq.init() print("[Step 1] Putting data into TransferQueue...") input_ids = torch.randint(0, 1000, (8, 512)) @@ -348,23 +320,23 @@ def demonstrate_real_workflow(): ) partition_id = "demo_partition" - batch_meta = client.put(data=data_batch, partition_id=partition_id) + batch_meta = tq.put(data=data_batch, partition_id=partition_id) print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.") print("[Step 2] [Optional] Setting sample-level custom_meta...") - custom_meta = { - global_index: {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0} - for global_index in batch_meta.global_indexes - } + 
custom_meta = [ + {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0} + for _ in range(batch_meta.size) + ] batch_meta.update_custom_meta(custom_meta) print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}") - client.set_custom_meta(batch_meta) + tq.set_custom_meta(batch_meta) print("✓ Successful to store custom_meta into TQ controller. Now you can retrieve the custom_meta from anywhere.") print("[Step 3] Try to get metadata from TransferQueue from other places...") - batch_meta = client.get_meta( + batch_meta = tq.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=8, partition_id=partition_id, @@ -383,7 +355,7 @@ def demonstrate_real_workflow(): print("✓ Selected 'input_ids' field only:") print(f" New field names: {selected_meta.field_names}") print(f" Samples still have same global indexes: {selected_meta.global_indexes}") - retrieved_data = client.get_data(selected_meta) + retrieved_data = tq.get_data(selected_meta) print(f" Retrieved data keys: {list(retrieved_data.keys())}") print("[Step 5] Select specific samples from the retrieved BatchMeta...") @@ -391,7 +363,7 @@ def demonstrate_real_workflow(): print("✓ Selected samples at indices [0, 2, 4, 6]:") print(f" New global indexes: {partial_meta.global_indexes}") print(f" Number of samples: {len(partial_meta)}") - retrieved_data = client.get_data(partial_meta) + retrieved_data = tq.get_data(partial_meta) print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}") print("[Step 6] Demonstrate chunk operation...") @@ -399,12 +371,12 @@ def demonstrate_real_workflow(): print(f"✓ Chunked into {len(chunks)} parts:") for i, chunk in enumerate(chunks): print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") - chunk_data = client.get_data(chunk) + chunk_data = tq.get_data(chunk) print(f" Chunk {i}: Retrieved chunk data: {chunk_data}") # Cleanup - client.clear_partition(partition_id=partition_id) - client.close() + tq.clear_partition(partition_id=partition_id) + tq.close() ray.shutdown() print("✓ Partition cleared and resources cleaned up") diff --git a/tutorial/03_understanding_controller.py b/tutorial/04_understanding_controller.py similarity index 76% rename from tutorial/03_understanding_controller.py rename to tutorial/04_understanding_controller.py index 8b7e6d0..4ca426d 100644 --- a/tutorial/03_understanding_controller.py +++ b/tutorial/04_understanding_controller.py @@ -36,59 +36,19 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 # Configure Ray os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["RAY_DEBUG"] = "1" -def setup_transfer_queue(): - """Setup TransferQueue components.""" - if not ray.is_initialized(): - ray.init() - - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote() - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = 
TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, - ) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - return controller, storage_units, client - - def demonstrate_partition_isolation(): """Feature 1: Different partitions are isolated - data doesn't interfere.""" print("=" * 80) @@ -97,7 +57,10 @@ def demonstrate_partition_isolation(): print("\nDifferent partitions are completely isolated - data doesn't interfere between partitions") - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Partition 1: Training data print("\n[Partition 1] Putting training data...") @@ -108,7 +71,7 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - client.put(data=train_data, partition_id="train") + tq.put(data=train_data, partition_id="train") print(" ✓ Training data added to 'train' partition") # Partition 2: Validation data @@ -120,25 +83,23 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - client.put(data=val_data, partition_id="val") + tq.put(data=val_data, partition_id="val") print(" ✓ Validation data added to 'val' partition") # Get from train partition print("\n[Retrieving from 'train' partition]") - train_meta = client.get_meta( + train_meta = tq.get_meta( data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task" ) - retrieved_train_data = client.get_data(train_meta) + retrieved_train_data = tq.get_data(train_meta) print(f" ✓ Got BatchMeta={train_meta} from partition 'train'") print(f" ✓ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, labels={retrieved_train_data['labels']}") # Get from val partition print("\n[Retrieving from 'val' partition]") - val_meta = client.get_meta( - data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task" - ) - retrieved_val_data = client.get_data(val_meta) + val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task") + retrieved_val_data = tq.get_data(val_meta) print(f" ✓ Got BatchMeta={val_meta} from partition 'val'") print(f" ✓ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}") @@ -146,9 +107,9 @@ def demonstrate_partition_isolation(): print(" ✓ Data isolation: 'train' and 'val' partitions are completely independent") # Cleanup - client.clear_partition(partition_id="train") - client.clear_partition(partition_id="val") - client.close() + tq.clear_partition(partition_id="train") + tq.clear_partition(partition_id="val") + tq.close() ray.shutdown() @@ -160,7 +121,10 @@ def demonstrate_dynamic_expansion(): print("\nPartitions dynamically expand to accommodate new data (rows and columns)") - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Add first batch with 2 samples, 2 fields print("\n[Step 1] Adding initial data (2 samples, 2 fields)...") @@ -171,7 +135,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta1 = client.put(data=data1, partition_id="dynamic") + meta1 = tq.put(data=data1, partition_id="dynamic") print(" ✓ Added 2 samples") print(f" ✓ Got 
BatchMeta: {meta1} samples") @@ -184,9 +148,9 @@ def demonstrate_dynamic_expansion(): }, batch_size=3, ) - meta2 = client.put(data=data2, partition_id="dynamic") + meta2 = tq.put(data=data2, partition_id="dynamic") - all_meta = client.get_meta( + all_meta = tq.get_meta( data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task" ) print(" ✓ Added 3 more samples (total: 5)") @@ -201,7 +165,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta3 = client.put(data=data3, metadata=meta1) + meta3 = tq.put(data=data3, metadata=meta1) print(" ✓ Added 2 samples with new field 'field3'") print(f" ✓ Got BatchMeta: {meta3} for newly put data with new field") @@ -210,8 +174,8 @@ def demonstrate_dynamic_expansion(): print(" ✓ Columns auto-expand: Can add new fields anytime") # Cleanup - client.clear_partition(partition_id="dynamic") - client.close() + tq.clear_partition(partition_id="dynamic") + tq.close() ray.shutdown() @@ -221,7 +185,10 @@ def demonstrate_default_consumption_sample_strategy(): print("Feature 3: Default Sampling Strategy for Controller - No Duplicate, Sequential Samples") print("=" * 80) - controller, storage_units, client = setup_transfer_queue() + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() # Add 6 samples print("\n[Setup] Adding 6 samples...") @@ -231,22 +198,22 @@ def demonstrate_default_consumption_sample_strategy(): }, batch_size=6, ) - client.put(data=all_data, partition_id="sampling") + tq.put(data=all_data, partition_id="sampling") print(" ✓ 6 samples added") # First get - should get samples 0,1,2 print("\n[Task A, Get 1] Requesting 3 samples...") - meta1 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta1.global_indexes}") # Second get - should get samples 3,4,5 (no duplicates!) 
print("\n[Task A, Get 2] Requesting 3 more samples...") - meta2 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta2.global_indexes}") # Third get - should get samples 0,1 print("\n[Task B, Get 1] Requesting 2 samples...") - meta3 = client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") + meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -257,8 +224,8 @@ def demonstrate_default_consumption_sample_strategy(): print(" ✓ Third get (Task B): samples 0,1") # Cleanup - client.clear_partition(partition_id="sampling") - client.close() + tq.clear_partition(partition_id="sampling") + tq.close() ray.shutdown() diff --git a/tutorial/04_custom_sampler.py b/tutorial/05_custom_sampler.py similarity index 83% rename from tutorial/04_custom_sampler.py rename to tutorial/05_custom_sampler.py index 7bf13cd..c35b4e0 100644 --- a/tutorial/04_custom_sampler.py +++ b/tutorial/05_custom_sampler.py @@ -49,12 +49,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +import transfer_queue as tq # noqa: E402 from transfer_queue.sampler import BaseSampler # noqa: E402 @@ -171,36 +166,14 @@ def sample( def setup_transfer_queue_with_sampler(sampler): """Setup TransferQueue with custom sampler.""" if not ray.is_initialized(): - ray.init() + ray.init(namespace="TransferQueueTutorial") config = OmegaConf.create( - { - "global_batch_size": 8, - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(2): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) - - controller = TransferQueueController.remote(sampler=sampler) - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - - client = TransferQueueClient( - client_id="TutorialClient", - controller_info=controller_info, + {"controller": {"sampler": sampler}, "backend": {"SimpleStorage": {"num_data_storage_units": 2}}}, + flags={"allow_objects": True}, ) - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - return controller, storage_units, client + tq.init(config) def demonstrate_random_sampler_with_replacement(): @@ -211,7 +184,7 @@ def demonstrate_random_sampler_with_replacement(): print("\nSetup TransferQueue with RandomSamplerWithReplacement...") sampler = RandomSamplerWithReplacement() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 5 samples print("\n[Step 1] Adding 5 samples...") @@ -221,22 +194,22 @@ def demonstrate_random_sampler_with_replacement(): }, batch_size=5, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 5 samples added") # Get batch 1 (should get 2 random samples) print("\n[Step 2] Get batch 1 (2 samples)...") - meta1 = client.get_meta(data_fields=["input"], batch_size=2, 
partition_id="test", task_name="demo_task") + meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!) print("\n[Step 3] Get batch 2 (1 sample)...") - meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -246,8 +219,8 @@ def demonstrate_random_sampler_with_replacement(): print(f" ✓ All sampled: {all_sampled}") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() @@ -259,7 +232,7 @@ def demonstrate_random_sampler_without_replacement(): print("\nSetup TransferQueue with RandomSamplerWithoutReplacement...") sampler = RandomSamplerWithoutReplacement() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + setup_transfer_queue_with_sampler(sampler) # Add 6 samples print("\n[Step 1] Adding 6 samples...") @@ -269,22 +242,22 @@ def demonstrate_random_sampler_without_replacement(): }, batch_size=6, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 6 samples added") # Get batch 1 (should get 3 random samples without replacement) print("\n[Step 2] Get batch 1 (3 samples)...") - meta1 = client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") + meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should randomly get 1 sample that are different from previous batch) print("\n[Step 3] Get batch 2 (1 samples)...") - meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should randomly get 2 samples that are different from previous batch) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -294,8 +267,8 @@ def demonstrate_random_sampler_without_replacement(): print(f" ✓ Batch 3: {meta3.global_indexes} (none left)") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() @@ -307,7 +280,7 @@ def demonstrate_priority_sampler(): print("\nSetup TransferQueue with PrioritySampler...") sampler = PrioritySampler() - controller, storage_units, client = setup_transfer_queue_with_sampler(sampler) + 
setup_transfer_queue_with_sampler(sampler) # Add 8 samples print("\n[Step 1] Adding 8 samples...") @@ -317,7 +290,7 @@ def demonstrate_priority_sampler(): }, batch_size=8, ) - client.put(data=data, partition_id="test") + tq.put(data=data, partition_id="test") print(" ✓ 8 samples added") time.sleep(1) @@ -330,7 +303,7 @@ def demonstrate_priority_sampler(): print(f"Priority scores: {priority_scores}") # Get batch using priority sampling - meta1 = client.get_meta( + meta1 = tq.get_meta( data_fields=["input"], batch_size=1, partition_id="test", @@ -342,7 +315,7 @@ def demonstrate_priority_sampler(): # Get another batch print("\n[Step 3] Get another batch (2 samples)...") - meta2 = client.get_meta( + meta2 = tq.get_meta( data_fields=["input"], batch_size=2, partition_id="test", @@ -358,8 +331,8 @@ def demonstrate_priority_sampler(): print(f" ✓ Batch 2 high-priority indices: {[i for i in meta2.global_indexes if priority_scores[i] >= 0.1]}") # Cleanup - client.clear_partition(partition_id="test") - client.close() + tq.clear_partition(partition_id="test") + tq.close() ray.shutdown() diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py similarity index 84% rename from tutorial/05_streaming_dataloader.py rename to tutorial/06_streaming_dataloader.py index 916be00..4d92aa2 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/06_streaming_dataloader.py @@ -57,39 +57,25 @@ import ray # noqa: E402 import torch # noqa: E402 -from omegaconf import DictConfig, OmegaConf # noqa: E402 +from omegaconf import OmegaConf # noqa: E402 from tensordict import TensorDict # noqa: E402 # Add the parent directory to the path for imports parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) - +import transfer_queue as tq # noqa: E402 from transfer_queue import ( # noqa: E402 RankAwareSampler, - SimpleStorageUnit, StreamingDataLoader, StreamingDataset, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, ) def setup_transfer_queue(): """Setup TransferQueue components.""" if not ray.is_initialized(): - ray.init() - - config = OmegaConf.create( - { - "num_data_storage_units": 2, - } - ) - - storage_units = {} - for i in range(config["num_data_storage_units"]): - storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100) + ray.init(namespace="TransferQueueTutorial") print("[Setup]: Setup TransferQueue components") print( @@ -101,26 +87,23 @@ def setup_transfer_queue(): "TransferQueueController. In polling_mode, the controller will return empty BatchMeta when " "available data cannot meet the consumption requirements. User side need to retry later." 
) - controller = TransferQueueController.remote( - sampler=RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank - polling_mode=True, # Enable polling mode for streaming data retrieval - ) - - controller_info = process_zmq_server_info(controller) - storage_unit_infos = process_zmq_server_info(storage_units) - # Build the complete configuration - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) - tq_config.controller_info = controller_info - tq_config.storage_unit_infos = storage_unit_infos - config.storage_backend = "AsyncSimpleStorageManager" - config = OmegaConf.merge(tq_config, config) + config = OmegaConf.create( + { + "controller": { + "sampler": RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank + "polling_mode": True, # Enable polling mode for streaming data retrieval + }, + "backend": {"SimpleStorage": {"num_data_storage_units": 2}}, + }, + flags={"allow_objects": True}, + ) - return controller, storage_units, config + tq.init(config) @ray.remote(num_cpus=0.1) -def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): +def generate_worker(rank_id: int, num_samples: int = 20): """ Generate actor that produces training samples. @@ -129,7 +112,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): Args: rank_id: Unique identifier for this generator (used for sample indexing) - config: TransferQueue configuration num_samples: Number of samples to generate Note: @@ -137,13 +119,9 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): This ensures global uniqueness across all generator actors. """ # Create a client for interacting with TransferQueue - client = TransferQueueClient( - client_id=f"gen_worker_{rank_id}", - controller_info=config.controller_info, - ) - # Initialize the storage manager for this client - client.initialize_storage_manager(manager_type=config.storage_backend, config=config) + # Need to call tq.init() in each process + tq.init() # Generate and put samples into the queue for i in range(num_samples): @@ -159,7 +137,7 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue") # Put data into the specified partition - client.put(data, partition_id="train") + tq.put(data, partition_id="train") print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue") @@ -168,7 +146,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20): def update_worker( rank_id: int, dp_rank: int, - config: DictConfig, max_steps: int = 5, ): """ @@ -182,7 +159,6 @@ def update_worker( rank_id: Global rank identifier for logging and display purposes dp_rank: Data parallel rank ID that this worker belongs to The same Ranks receive the same data samples - config: TransferQueue configuration max_steps: Maximum number of batches to consume Returns: @@ -200,8 +176,15 @@ def update_worker( - batch_meta: Metadata for TransferQueue coordination (contains global_indexes) """ + # Need to call tq.init() in each process + tq.init() + # Step 1: Create StreamingDataset # This dataset integrates with TransferQueue and handles batch retrieval + + controller = ray.get_actor("TransferQueueController") + config = ray.get(controller.get_config.remote()) + dataset = StreamingDataset( config=config, batch_size=2, @@ -253,7 +236,7 @@ def update_worker( } -def start_all_generate_actors(config): +def start_all_generate_actors(): 
""" Launch generate_actors for producing training samples. """ @@ -261,12 +244,12 @@ def start_all_generate_actors(config): handlers = [] for i in range(num_workers): - handlers.append(generate_worker.remote(rank_id=i, config=config, num_samples=20)) + handlers.append(generate_worker.remote(rank_id=i, num_samples=20)) return handlers -def start_all_update_actors(config): +def start_all_update_actors(): """ Launch update_actors for consuming training samples. """ @@ -285,7 +268,6 @@ def start_all_update_actors(config): update_worker.remote( rank_id=rank_ids[i], dp_rank=dp_rank[i], - config=config, ) ) @@ -331,15 +313,15 @@ def main(): "global_batch_size to make sure consumers can accurately determine consumption status even before " "producers have generated the samples." ) - controller, storage_units, config = setup_transfer_queue() + setup_transfer_queue() # Step 2: Launch data generation actors print("\n[Phase 2] Starting data generation...") - generate_worker_handlers = start_all_generate_actors(config) + generate_worker_handlers = start_all_generate_actors() # Step 3: Launch data consumption actors print("\n[Phase 3] Starting data consumption...") - update_worker_handlers = start_all_update_actors(config) + update_worker_handlers = start_all_update_actors() # Wait for completion print("\n[Phase 4] Waiting for actors to complete...") From 4836b484ab36d412447caf2056f6611c6dc4657e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 00:06:07 +0800 Subject: [PATCH 02/34] todo: fix ut Signed-off-by: 0oshowero0 --- tests/test_kv_interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_kv_interface.py b/tests/test_kv_interface.py index 24d74af..6b9eee1 100644 --- a/tests/test_kv_interface.py +++ b/tests/test_kv_interface.py @@ -257,15 +257,15 @@ def test_kv_list_with_keys(self): def test_kv_list_empty_partition(self): """Test kv_list returns None when partition is empty.""" mock_client = MagicMock() - mock_client.kv_list.return_value = [], [] + mock_client.kv_list.return_value = [] with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): from transfer_queue.interface import kv_list keys, custom_meta = kv_list(partition_id="empty_partition") - assert keys == [] - assert custom_meta == [] + assert keys is None + assert custom_meta is None class TestKVClear: From c1cb09f82bd8c08ff9c58dd1b2ff00caf96d2c1a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 15:36:19 +0800 Subject: [PATCH 03/34] try add tutorial Signed-off-by: 0oshowero0 --- tutorial/02_kv_interface.py | 179 +++++++++++++----------- tutorial/03_metadata_concepts.py | 2 +- tutorial/04_understanding_controller.py | 2 +- tutorial/05_custom_sampler.py | 2 +- tutorial/06_streaming_dataloader.py | 2 +- 5 files changed, 98 insertions(+), 89 deletions(-) diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index 59159a6..b54b03a 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -53,97 +53,108 @@ ray.init(namespace="TransferQueueTutorial") -def demonstrate_data_workflow(): +def demonstrate_kv_api(): """ - Demonstrate basic data workflow: put → get → clear. 
+ Demonstrate xxxxx """ print("=" * 80) - print("Data Workflow Demo: put → get → clear") + print("Data xxxx") print("=" * 80) - # Step 1: Put data - print("[Step 1] Putting data into TransferQueue...") + # Step 1: Put single data sample + print("[Step 1] Putting single data into TransferQueue...") - input_ids = torch.tensor( + input_ids = torch.tensor([[1, 2, 3]]) + attention_mask = torch.ones(input_ids.size()) + + single_sample = TensorDict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + batch_size=input_ids.size(0), + ) + + print(f" Created single sample with multiple fields:{single_sample.keys()}.") + print(" Leveraging TransferQueue, we can provide fine-grained access into each single field of a sample (key).") + + partition_id = "Train" + key = [f"{uid}_{session_id}" for uid in [0] for session_id in [0]] + tag = [{"global_steps": 0, "status": "running", "model_version": 0}] + tq.kv_put(key, partition_id=partition_id, fields=single_sample, tag=tag) + print(f" ✓ Data put to partition: {partition_id}") + + # Step 2: Put data batch + print("[Step 2] Putting multiple data into TransferQueue...") + + batch_input_ids = torch.tensor( [ - [1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], + [13, 14, 15], ] ) - attention_mask = torch.ones_like(input_ids) + batch_attention_mask = torch.ones_like(batch_input_ids) data_batch = TensorDict( { - "input_ids": input_ids, - "attention_mask": attention_mask, + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, }, - batch_size=input_ids.size(0), + batch_size=batch_input_ids.size(0), ) - print(f" Created {data_batch.batch_size[0]} samples") - partition_id = "tutorial_partition_0" - tq.put(data=data_batch, partition_id=partition_id) - print(f" ✓ Data written to partition: {partition_id}") - - # Step 2: Get metadata - print("[Step 2] Requesting data metadata...") - batch_meta = tq.get_meta( - data_fields=["input_ids", "attention_mask"], - batch_size=data_batch.batch_size[0], - partition_id=partition_id, - task_name="tutorial_task", - ) - print(f" ✓ Got metadata: {len(batch_meta)} samples") - print(f" Global indexes: {batch_meta.global_indexes}") - - # Step 3: Get actual data - print("[Step 3] Retrieving actual data...") - retrieved_data = tq.get_data(batch_meta) - print(" ✓ Data retrieved successfully") - print(f" Keys: {list(retrieved_data.keys())}") - - # Step 4: Verify - print("[Step 4] Verifying data integrity...") - assert torch.equal(retrieved_data["input_ids"], input_ids) - assert torch.equal(retrieved_data["attention_mask"], attention_mask) - print(" ✓ Data matches original!") + partition_id = "Train" + keys = [f"{uid}_{session_id}" for uid in [1, 1, 1, 2] for session_id in [0, 1, 2, 0]] + tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))] - # Step 5: Clear - print("[Step 5] Clearing partition... 
(you may also use clear_samples() to clear specific samples)") - tq.clear_partition(partition_id=partition_id) - print(" ✓ Partition cleared") + print(f" Created {data_batch.batch_size[0]} samples, assigning keys: {keys}, tags: {tags}") + tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags) + print(f" ✓ Data put to partition: {partition_id}") + # Step 3: Append new data fields to existing samples + print("[Step 3] Putting multiple data into TransferQueue...") + batch_response = torch.tensor( + [ + [4, 5, 6], + [7, 8, 9], + ] + ) + data_batch = TensorDict( + { + "response": batch_response, + }, + batch_size=batch_response.size(0), + ) -def demonstrate_storage_backend_options(): - """ - Show different storage backend options. - """ - print("=" * 80) - print("Storage Backend Options") - print("=" * 80) + keys = [f"{uid}_{session_id}" for uid in [1, 2] for session_id in [1, 0]] + tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(keys))] - print("TransferQueue supports multiple storage backends:") - print("1. SimpleStorage (default)") - print(" - In-memory storage, fast and simple") - print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer") - print(" - No extra dependencies, good for development and testing") + tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags) - print("2. Yuanrong") - print(" - Ascend native distributed storage solution") - print(" - Hierarchical storage interfaces including HBM/DRAM/SSD") + # Step 4: Query all keys and tags + print("[Step 4] Query all the keys and tags from TransferQueue...") + all_keys, all_tags = tq.kv_list(partition_id=partition_id) + print(f" ✓ Got keys: {keys}") + print(f" Got tags: {tags}") - print("3. MooncakeStore (on the way)") - print(" - Support multiple transmission protocols") - print(" - RDMA between DRAM and HBM") + # Step 5: Get specific fields of values + print("[Step 5] ...") + retrieved_input_ids_data = tq.kv_get(all_keys, fields="input_ids") + print(" ✓ Data retrieved successfully") + print(f" {retrieved_input_ids_data}") - print("4. Ray RDT (on the way)") - print(" - Leverage Ray's distributed object store to store data") + # Step 6: Get all fields of values + print("[Step 5] ...") + retrieved_all_data = tq.kv_get(all_keys) + print(" ✓ Data retrieved successfully") + print(f" {retrieved_all_data}") - print("5. Custom Storage Backends") - print(" - Implement your own storage manager by inheriting from `TransferQueueStorageManager` base class") - print(" - For KV based storage, you only need to provide a storage client and integrate with `KVStorageManager`") + # Step 5: Clear + print("[Step 5] Clearing partition...") + tq.kv_clear(all_keys, partition_id=partition_id) + print(" ✓ Keys are cleared") def main(): @@ -151,18 +162,23 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 1: Core Components Introduction + TransferQueue Tutorial 2: Key-Value Semantic API - This script introduces the three core components of TransferQueue: - 1. TransferQueueController - Manages all the metadata and tracks the production and consumption states - 2. StorageBackend - Pluggable distributed storage backend that holds the actual data - 3. TransferQueueClient - Client interface for reading/writing data (user-facing API) + This script demonstrate the key-value semantic API of TransferQueue: + 1. kv_put & kv_batch_put - Put key-value pairs and custom tags into TransferQueue + 2. 
kv_get - Get values from TransferQueue according to user-specified keys + 3. kv_list - Get all the keys and tags from TransferQueue + 4. kv_clear - Delete the value and tags of given keys - Key Concepts: - - Data is organized into logical partitions (e.g., "train", "val") - - Each sample has multiple fields, with a global index for identification - - Controller maintains production/consumption state tracking - - Client is the main interface users interact with + Supported Features: + 1. Fine-grained access - user can put/get partial fields inside a data sample (key) + 2. Partition management - each logical partition manages their own key-value mapping + + Unsupported Features: + 1. Production & consumption management (user have to manually management through tags) + 2. User-defined sampler in TransferQueue controller (user need to do sampling by themselves through tags) + 3. Fully streamed data pipeline (TQ controller cannot determine which sample to dispatch to the consumers) + """ ) ) @@ -172,21 +188,14 @@ def main(): print("Setting up TransferQueue...") tq.init() - print("Demonstrating the user workflow...") - demonstrate_data_workflow() - - demonstrate_storage_backend_options() + print("Demonstrating the key-value semantic API...") + demonstrate_kv_api() print("=" * 80) print("Tutorial Complete!") print("=" * 80) print("Key Takeaways:") - print("1. TransferQueue has 3 core components:") - print(" - Controller: Manages data production/consumption state") - print(" - StorageBackend: Persists actual data") - print(" - Client: User-facing API (what you use)") - print("2. Client is the main interface users interact with") - print("3. You can swap out different storage backends easily") + print("1. ") # Cleanup tq.close() diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py index 3a4a57d..bbda5b9 100644 --- a/tutorial/03_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -387,7 +387,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 2: Metadata System + TransferQueue Tutorial 3: Metadata System This script introduces the metadata system in TransferQueue, which tracks the structure and state of data: diff --git a/tutorial/04_understanding_controller.py b/tutorial/04_understanding_controller.py index 4ca426d..fbeae21 100644 --- a/tutorial/04_understanding_controller.py +++ b/tutorial/04_understanding_controller.py @@ -235,7 +235,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 3: Understanding TransferQueueController + TransferQueue Tutorial 4: Understanding TransferQueueController This script demonstrates TransferQueueController's key features: diff --git a/tutorial/05_custom_sampler.py b/tutorial/05_custom_sampler.py index c35b4e0..bdca835 100644 --- a/tutorial/05_custom_sampler.py +++ b/tutorial/05_custom_sampler.py @@ -341,7 +341,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 4: Custom Sampler Development + TransferQueue Tutorial 5: Custom Sampler Development This script demonstrates how to develop custom samplers for TransferQueue. Samplers control HOW data is consumed from the queue. 
diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py index 4d92aa2..d95a3c8 100644 --- a/tutorial/06_streaming_dataloader.py +++ b/tutorial/06_streaming_dataloader.py @@ -289,7 +289,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 5: StreamingDataLoader for Distributed Training + TransferQueue Tutorial 6: StreamingDataLoader for Distributed Training This tutorial demonstrates the StreamingDataLoader interface for distributed training scenarios. It showcases how to use StreamingDataset and StreamingDataLoader From c53e5fa9161e78277ae55157b169fb34abd9e3c5 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 16:42:39 +0800 Subject: [PATCH 04/34] provide basic runnable tutorial Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 10 ++- tutorial/02_kv_interface.py | 164 ++++++++++++++++++++--------------- 2 files changed, 105 insertions(+), 69 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 6175863..6e46280 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1504,6 +1504,9 @@ def kv_retrieve_keys( partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i] partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]] + with partition.data_status_lock: + partition.ensure_samples_capacity(max(batch_global_indexes) + 1) + verified_global_indexes = [idx for idx in global_indexes if idx is not None] assert len(verified_global_indexes) == len(keys) @@ -1809,10 +1812,14 @@ def _process_request(self): elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"): + params = request_msg.body keys = params["keys"] partition_id = params["partition_id"] create = params["create"] + print("+++++++++++TQDEBUG+++++++++++++++++") + print(params) + metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) response_msg = ZMQMessage.create( request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE, @@ -1823,6 +1830,7 @@ def _process_request(self): elif request_msg.request_type == ZMQRequestType.KV_LIST: with perf_monitor.measure(op_type="KV_LIST"): + params = request_msg.body partition_id = params["partition_id"] partition = self._get_partition(partition_id) if not partition: @@ -1839,7 +1847,7 @@ def _process_request(self): request_type=ZMQRequestType.KV_LIST_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, - body={"keys": keys, "custom_meta": custom_meta, message: message}, + body={"keys": keys, "custom_meta": custom_meta, "message": message}, ) self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index b54b03a..921a02b 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -55,14 +55,15 @@ def demonstrate_kv_api(): """ - Demonstrate xxxxx + Demonstrate the Key-Value (KV) semantic API: + kv_put & kv_batch_put -> kv_get -> kv_list -> kv_clear """ print("=" * 80) - print("Data xxxx") + print("Key-Value Semantic API Demo: kv_put → kv_get → kv_list → kv_clear") print("=" * 80) - # Step 1: Put single data sample - print("[Step 1] Putting single data into TransferQueue...") + # Step 1: Put a single key-value pair with kv_put + print("[Step 1] Putting a single sample with kv_put...") input_ids = torch.tensor([[1, 2, 3]]) attention_mask = torch.ones(input_ids.size()) @@ -75,17 +76,18 @@ def demonstrate_kv_api(): 
batch_size=input_ids.size(0), ) - print(f" Created single sample with multiple fields:{single_sample.keys()}.") - print(" Leveraging TransferQueue, we can provide fine-grained access into each single field of a sample (key).") - partition_id = "Train" - key = [f"{uid}_{session_id}" for uid in [0] for session_id in [0]] - tag = [{"global_steps": 0, "status": "running", "model_version": 0}] - tq.kv_put(key, partition_id=partition_id, fields=single_sample, tag=tag) - print(f" ✓ Data put to partition: {partition_id}") + key = "0_0" # User-defined key: "{uid}_{session_id}" + tag = {"global_steps": 0, "status": "running", "model_version": 0} + + print(f" Created single sample with key: {key}, fields: {list(single_sample.keys())}, and tag: {tag}") + print(" Note: kv_put accepts a user-defined string key instead of auto-generated index") + + tq.kv_put(key=key, partition_id=partition_id, fields=single_sample, tag=tag) + print(f" ✓ kv_put: key='{key}', tag={tag}") - # Step 2: Put data batch - print("[Step 2] Putting multiple data into TransferQueue...") + # Step 2: Put multiple key-value pairs with kv_batch_put + print("\n[Step 2] Putting batch data with kv_batch_put...") batch_input_ids = torch.tensor( [ @@ -105,56 +107,66 @@ def demonstrate_kv_api(): batch_size=batch_input_ids.size(0), ) - partition_id = "Train" - keys = [f"{uid}_{session_id}" for uid in [1, 1, 1, 2] for session_id in [0, 1, 2, 0]] + keys = ["1_0", "1_1", "1_2", "2_0"] # 4 keys for 4 samples tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))] - print(f" Created {data_batch.batch_size[0]} samples, assigning keys: {keys}, tags: {tags}") - tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags) - print(f" ✓ Data put to partition: {partition_id}") + print(f" Created batch with {data_batch.batch_size[0]} samples") + print(f" Batch keys: {keys}") + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data_batch, tags=tags) + print(f" ✓ kv_batch_put: {len(keys)} samples written to partition '{partition_id}'") + + # Step 3: Append additional fields to existing samples + print("\n[Step 3] Appending new fields to existing samples...") - # Step 3: Append new data fields to existing samples - print("[Step 3] Putting multiple data into TransferQueue...") batch_response = torch.tensor( [ [4, 5, 6], [7, 8, 9], ] ) - data_batch = TensorDict( + response_batch = TensorDict( { "response": batch_response, }, batch_size=batch_response.size(0), ) - keys = [f"{uid}_{session_id}" for uid in [1, 2] for session_id in [1, 0]] - tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(keys))] + append_keys = ["1_1", "2_0"] # Appending to existing samples + append_tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(append_keys))] + print(f" Adding 'response' field to keys: {append_keys}") + tq.kv_batch_put(keys=append_keys, partition_id=partition_id, fields=response_batch, tags=append_tags) + print(" ✓ Field appended successfully (sample now has input_ids, attention_mask, and response)") - tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags) - - # Step 4: Query all keys and tags - print("[Step 4] Query all the keys and tags from TransferQueue...") + # Step 4: List all keys and tags in a partition + print("\n[Step 4] Listing all keys and tags in partition...") all_keys, all_tags = tq.kv_list(partition_id=partition_id) - print(f" ✓ Got keys: {keys}") - print(f" Got tags: {tags}") - - # Step 5: Get 
specific fields of values - print("[Step 5] ...") - retrieved_input_ids_data = tq.kv_get(all_keys, fields="input_ids") - print(" ✓ Data retrieved successfully") - print(f" {retrieved_input_ids_data}") - - # Step 6: Get all fields of values - print("[Step 5] ...") - retrieved_all_data = tq.kv_get(all_keys) - print(" ✓ Data retrieved successfully") - print(f" {retrieved_all_data}") - - # Step 5: Clear - print("[Step 5] Clearing partition...") - tq.kv_clear(all_keys, partition_id=partition_id) - print(" ✓ Keys are cleared") + print(f" Found {len(all_keys)} keys in partition '{partition_id}':") + for k, t in zip(all_keys, all_tags, strict=False): + print(f" - key='{k}', tag={t}") + + # Step 5: Retrieve specific fields using kv_get + print("\n[Step 5] Retrieving specific fields with kv_get...") + retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids") + print(f" Retrieved 'input_ids' field for all {len(all_keys)} samples:") + print(f" Shape: {retrieved_input_ids.batch_size}") + print(f" Values: {retrieved_input_ids['input_ids']}") + + # TODO: this will fail because only single sample has an extra fields... + # need to add additional check during kv_get to make sure other samples are correctly tackled + # # Step 6: Retrieve all fields using kv_get + # print("\n[Step 6] Retrieving all fields with kv_get...") + # retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id) + # print(f" Retrieved all fields for {len(all_keys)} samples:") + # print(f" Fields: {list(retrieved_all.keys())}") + + # Step 7: Clear specific keys + print("\n[Step 7] Clearing keys from partition...") + keys_to_clear = all_keys[:2] # Clear first 2 keys + tq.kv_clear(keys=keys_to_clear, partition_id=partition_id) + print(f" ✓ Cleared keys: {keys_to_clear}") + + remaining_keys, _ = tq.kv_list(partition_id=partition_id) + print(f" Remaining keys in partition: {remaining_keys}") def main(): @@ -162,23 +174,35 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 2: Key-Value Semantic API - - This script demonstrate the key-value semantic API of TransferQueue: - 1. kv_put & kv_batch_put - Put key-value pairs and custom tags into TransferQueue - 2. kv_get - Get values from TransferQueue according to user-specified keys - 3. kv_list - Get all the keys and tags from TransferQueue - 4. kv_clear - Delete the value and tags of given keys - - Supported Features: - 1. Fine-grained access - user can put/get partial fields inside a data sample (key) - 2. Partition management - each logical partition manages their own key-value mapping - - Unsupported Features: - 1. Production & consumption management (user have to manually management through tags) - 2. User-defined sampler in TransferQueue controller (user need to do sampling by themselves through tags) - 3. Fully streamed data pipeline (TQ controller cannot determine which sample to dispatch to the consumers) - + TransferQueue Tutorial 2: Key-Value (KV) Semantic API + + This tutorial demonstrates the KV semantic API, which provides a simpler + interface for data storage and retrieval using user-defined string keys + instead of auto-generated numeric indexes. + + Key Methods: + 1. kv_put - Put a single key-value pair with optional metadata tag + 2. kv_batch_put - Put multiple key-value pairs efficiently in batch + 3. kv_get - Retrieve data by key(s), optionally specifying fields + 4. kv_list - List all keys and their metadata tags in a partition + 5. 
kv_clear - Remove key-value pairs from storage + + Key Features: + ✓ User-defined keys - Use meaningful string keys instead of numeric indexes + ✓ Fine-grained access - Get/put individual fields within a sample + ✓ Partition management - Each partition maintains its own key-value mapping + ✓ Metadata tags - Attach custom metadata (status, scores, etc.) to samples + + Use Cases: + - Storing per-model-checkpoint states + - Managing evaluation results by sample ID + - Caching intermediate computation results + - Fine-grained data access without full BatchMeta management + + Limitations (vs Full API): + - No built-in production/consumption tracking (manage via tags) + - No Sampler-based sampling (implement sampling logic externally) + - Controller doesn't control streaming (manual key management required) """ ) ) @@ -188,19 +212,23 @@ def main(): print("Setting up TransferQueue...") tq.init() - print("Demonstrating the key-value semantic API...") + print("\nDemonstrating the KV semantic API...") demonstrate_kv_api() - print("=" * 80) + print("\n" + "=" * 80) print("Tutorial Complete!") print("=" * 80) - print("Key Takeaways:") - print("1. ") + print("\nKey Takeaways:") + print(" 1. KV API simplifies data access with user-defined string keys") + print(" 2. kv_batch_put is more efficient for bulk operations") + print(" 3. Use 'fields' parameter to get/put specific fields only") + print(" 4. Tags enable custom metadata for production status, scores, etc.") + print(" 5. Use kv_list to inspect partition contents") # Cleanup tq.close() ray.shutdown() - print("\n✓ Cleanup complete") + print("\nCleanup complete") except Exception as e: print(f"Error during tutorial: {e}") From 85567f3624c65c89eedbfb3ed364600c0931aa20 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 16:49:33 +0800 Subject: [PATCH 05/34] fix other tutorials Signed-off-by: 0oshowero0 --- tutorial/03_metadata_concepts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py index bbda5b9..126cc5b 100644 --- a/tutorial/03_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -209,6 +209,9 @@ def demonstrate_batch_meta(): [ {"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"}, {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@2", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@3", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@4", "session_id": "session@0", "model_version": "epoch@0"}, ] ) print(f"✓ Custom meta: {batch.get_all_custom_meta()}") From 063b6b1e36a0e0325582a820b8b3d7c7d15e3c60 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 17:39:55 +0800 Subject: [PATCH 06/34] update tutorial Signed-off-by: 0oshowero0 --- tutorial/02_kv_interface.py | 110 ++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 48 deletions(-) diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index 921a02b..fc2e96b 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -56,15 +56,16 @@ def demonstrate_kv_api(): """ Demonstrate the Key-Value (KV) semantic API: - kv_put & kv_batch_put -> kv_get -> kv_list -> kv_clear + kv_put & kv_batch_put -> kv_list -> kv_get -> kv_clear """ print("=" * 80) - print("Key-Value Semantic API Demo: kv_put → kv_get → kv_list → kv_clear") + print("Key-Value Semantic API Demo: kv_put/kv_batch_put → kv_list → kv_get → kv_clear") print("=" * 80) # Step 1: 
Put a single key-value pair with kv_put print("[Step 1] Putting a single sample with kv_put...") + # Define the data content (The "Value") input_ids = torch.tensor([[1, 2, 3]]) attention_mask = torch.ones(input_ids.size()) @@ -77,14 +78,16 @@ def demonstrate_kv_api(): ) partition_id = "Train" + # Use a meaningful string key instead of an auto-increment integer key = "0_0" # User-defined key: "{uid}_{session_id}" tag = {"global_steps": 0, "status": "running", "model_version": 0} - print(f" Created single sample with key: {key}, fields: {list(single_sample.keys())}, and tag: {tag}") - print(" Note: kv_put accepts a user-defined string key instead of auto-generated index") + print(f" Inserting Key: {key}") + print(f" Fields (Columns): {list(single_sample.keys())}") + print(f" Tag (Metadata): {tag}") tq.kv_put(key=key, partition_id=partition_id, fields=single_sample, tag=tag) - print(f" ✓ kv_put: key='{key}', tag={tag}") + print(" ✓ kv_put success.") # Step 2: Put multiple key-value pairs with kv_batch_put print("\n[Step 2] Putting batch data with kv_batch_put...") @@ -110,13 +113,14 @@ def demonstrate_kv_api(): keys = ["1_0", "1_1", "1_2", "2_0"] # 4 keys for 4 samples tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))] - print(f" Created batch with {data_batch.batch_size[0]} samples") - print(f" Batch keys: {keys}") + print(f" Inserting batch of {len(keys)} samples.") + print(f" Fields (Columns): {list(data_batch.keys())}") + print(f" Tag (Metadata): {tags}") tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data_batch, tags=tags) - print(f" ✓ kv_batch_put: {len(keys)} samples written to partition '{partition_id}'") + print(" ✓ kv_batch_put success.") # Step 3: Append additional fields to existing samples - print("\n[Step 3] Appending new fields to existing samples...") + print("\n[Step 3] Appending new fields (Columns) to existing samples...") batch_response = torch.tensor( [ @@ -131,25 +135,38 @@ def demonstrate_kv_api(): batch_size=batch_response.size(0), ) + # We only update subset of keys append_keys = ["1_1", "2_0"] # Appending to existing samples append_tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(append_keys))] - print(f" Adding 'response' field to keys: {append_keys}") + print(f" Target Keys: {append_keys}") + print(" New Field to Add is: 'response'") + print(f" The updated tags are: {append_tags}") tq.kv_batch_put(keys=append_keys, partition_id=partition_id, fields=response_batch, tags=append_tags) - print(" ✓ Field appended successfully (sample now has input_ids, attention_mask, and response)") + print(" ✓ Update success: Samples '1_1' and '2_0' now contain {input_ids, attention_mask, response}.") + + # Step 4: Only update tags through kv_put + print("\n[Step 4] Update existing tags without providing value...") + key = "0_0" + tag = {"global_steps": 0, "status": "finish", "model_version": 0} + print(f" Target Key: {key}") + print(f" The updated tag is: {tag}") + tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=tag) + print(f" ✓ Update success: Samples '0_0' now has tag as {tag}.") + + # Step 5: List all keys and tags in a partition + print("\n[Step 5] Listing all keys and tags in partition...") - # Step 4: List all keys and tags in a partition - print("\n[Step 4] Listing all keys and tags in partition...") all_keys, all_tags = tq.kv_list(partition_id=partition_id) print(f" Found {len(all_keys)} keys in partition '{partition_id}':") for k, t in zip(all_keys, all_tags, 
strict=False): - print(f" - key='{k}', tag={t}") + print(f" - key='{k}' | tag={t}") + + # Step 6: Retrieve specific fields using kv_get + print("\n[Step 6] Retrieving specific fields (Column) with kv_get...") + print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').") - # Step 5: Retrieve specific fields using kv_get - print("\n[Step 5] Retrieving specific fields with kv_get...") retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids") - print(f" Retrieved 'input_ids' field for all {len(all_keys)} samples:") - print(f" Shape: {retrieved_input_ids.batch_size}") - print(f" Values: {retrieved_input_ids['input_ids']}") + print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.") # TODO: this will fail because only single sample has an extra fields... # need to add additional check during kv_get to make sure other samples are correctly tackled @@ -159,9 +176,9 @@ def demonstrate_kv_api(): # print(f" Retrieved all fields for {len(all_keys)} samples:") # print(f" Fields: {list(retrieved_all.keys())}") - # Step 7: Clear specific keys - print("\n[Step 7] Clearing keys from partition...") - keys_to_clear = all_keys[:2] # Clear first 2 keys + # Step 8: Clear specific keys + print("\n[Step 8] Clearing keys from partition...") + keys_to_clear = all_keys[:2] # Delete the first 2 keys tq.kv_clear(keys=keys_to_clear, partition_id=partition_id) print(f" ✓ Cleared keys: {keys_to_clear}") @@ -176,33 +193,31 @@ def main(): """ TransferQueue Tutorial 2: Key-Value (KV) Semantic API - This tutorial demonstrates the KV semantic API, which provides a simpler - interface for data storage and retrieval using user-defined string keys - instead of auto-generated numeric indexes. + This tutorial demonstrates the KV semantic API, which provides a simple + interface for data storage and retrieval using user-defined string keys. Key Methods: - 1. kv_put - Put a single key-value pair with optional metadata tag - 2. kv_batch_put - Put multiple key-value pairs efficiently in batch - 3. kv_get - Retrieve data by key(s), optionally specifying fields - 4. kv_list - List all keys and their metadata tags in a partition - 5. kv_clear - Remove key-value pairs from storage + 1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag + 2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch + 3. (async_)kv_get - Retrieve samples (by keys), supporting column selection (by fields) + 4. (async_)kv_list - List keys and tags (metadata) in a partition + 5. (async_)kv_clear - Remove key-value pairs from storage Key Features: - ✓ User-defined keys - Use meaningful string keys instead of numeric indexes - ✓ Fine-grained access - Get/put individual fields within a sample - ✓ Partition management - Each partition maintains its own key-value mapping - ✓ Metadata tags - Attach custom metadata (status, scores, etc.) to samples + ✓ Redis-style Semantics - Familiar KV interface (Put/Get/List) for zero learning curve + ✓ Fine-grained Access - Update or retrieve specific fields (columns) within a key (row) without full op. 
+ ✓ Partition Isolation - Logical separation of storage namespaces + ✓ Metadata Tags - Lightweight metadata for status tracking + ✓ Pluggable Backends - Supports multiple backends Use Cases: - - Storing per-model-checkpoint states - - Managing evaluation results by sample ID - - Caching intermediate computation results - - Fine-grained data access without full BatchMeta management - - Limitations (vs Full API): - - No built-in production/consumption tracking (manage via tags) - - No Sampler-based sampling (implement sampling logic externally) - - Controller doesn't control streaming (manual key management required) + - Focusing on fine-grained data access where extreme streaming performance is non-essential + - Integration with external ReplayBuffer/single-controller that manage sample dispatching + + Limitations (vs low-level native APIs): + - No built-in production/consumption tracking: Users have to manually check status through tags. + - No built-in Sampler support: Must implement data dispatch by ReplayBuffer or single-controller externally. + - No fully streaming: Consumers must wait for single-controller to dispatch `keys`. """ ) ) @@ -219,11 +234,10 @@ def main(): print("Tutorial Complete!") print("=" * 80) print("\nKey Takeaways:") - print(" 1. KV API simplifies data access with user-defined string keys") - print(" 2. kv_batch_put is more efficient for bulk operations") - print(" 3. Use 'fields' parameter to get/put specific fields only") - print(" 4. Tags enable custom metadata for production status, scores, etc.") - print(" 5. Use kv_list to inspect partition contents") + print(" 1. KV API simplifies data access with Redis-style semantics") + print(" 2. Use 'fields' parameter to get/put specific fields only") + print(" 3. Tags enable custom metadata for production status, scores, etc.") + print(" 4. 
Use kv_list to inspect partition contents") # Cleanup tq.close() From 1a86e20b050385da769b3abae295faab05b7c227 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 18:23:59 +0800 Subject: [PATCH 07/34] fix Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 6e46280..1f430aa 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -829,8 +829,10 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr self.field_shapes.pop(idx, None) self.field_custom_backend_meta.pop(idx, None) self.custom_meta.pop(idx, None) - self.keys_mapping.pop(self.revert_keys_mapping[idx], None) - self.revert_keys_mapping.pop(idx, None) + + if idx in self.revert_keys_mapping: + self.keys_mapping.pop(self.revert_keys_mapping[idx], None) + self.revert_keys_mapping.pop(idx, None) except Exception as e: logger.error( @@ -1817,9 +1819,6 @@ def _process_request(self): partition_id = params["partition_id"] create = params["create"] - print("+++++++++++TQDEBUG+++++++++++++++++") - print(params) - metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) response_msg = ZMQMessage.create( request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE, From e1f6ecf7ec930c82291c7af4a0a5a8f8ef4a459b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 19:06:26 +0800 Subject: [PATCH 08/34] fix cornercase Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 10 +++++++++- transfer_queue/interface.py | 10 ++++++++-- tutorial/02_kv_interface.py | 27 ++++++++++++++------------- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 1f430aa..1f01482 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1512,7 +1512,15 @@ def kv_retrieve_keys( verified_global_indexes = [idx for idx in global_indexes if idx is not None] assert len(verified_global_indexes) == len(keys) - data_fields = list(partition.field_name_mapping.keys()) + # must fetch fields that the requested samples all have + col_mask = partition.production_status[verified_global_indexes, :].sum(dim=0).reshape(-1) == len( + verified_global_indexes + ) + data_fields = [] + for fname, col_idx in partition.field_name_mapping.items(): + if col_mask[col_idx]: + data_fields.append(fname) + metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch") return metadata diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 2585c35..ed325b8 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -813,6 +813,7 @@ def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] Raises: RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq @@ -838,8 +839,10 @@ def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] fields = [fields] batch_meta = batch_meta.select_fields(fields) - data = tq_client.get_data(batch_meta) + if not batch_meta.is_ready: + raise RuntimeError("Some fields are not ready in all the requested keys!") + data = tq_client.get_data(batch_meta) return data @@ -1052,6 +1055,7 @@ async def async_kv_get( Raises: RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) 
Example: >>> import transfer_queue as tq @@ -1077,8 +1081,10 @@ async def async_kv_get( fields = [fields] batch_meta = batch_meta.select_fields(fields) - data = await tq_client.async_get_data(batch_meta) + if not batch_meta.is_ready: + raise RuntimeError("Some fields are not ready in all the requested keys!") + data = await tq_client.async_get_data(batch_meta) return data diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index fc2e96b..2b309ca 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -146,12 +146,12 @@ def demonstrate_kv_api(): # Step 4: Only update tags through kv_put print("\n[Step 4] Update existing tags without providing value...") - key = "0_0" - tag = {"global_steps": 0, "status": "finish", "model_version": 0} - print(f" Target Key: {key}") - print(f" The updated tag is: {tag}") - tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=tag) - print(f" ✓ Update success: Samples '0_0' now has tag as {tag}.") + key_for_update_tags = "0_0" + tag_update = {"global_steps": 0, "status": "finish", "model_version": 0} + print(f" Target Key: {key_for_update_tags}") + print(f" The updated tag is: {tag_update}") + tq.kv_put(key=key_for_update_tags, partition_id=partition_id, fields=None, tag=tag_update) + print(f" ✓ Update success: Samples '0_0' now has tag as {tag_update}.") # Step 5: List all keys and tags in a partition print("\n[Step 5] Listing all keys and tags in partition...") @@ -168,13 +168,14 @@ def demonstrate_kv_api(): retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids") print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.") - # TODO: this will fail because only single sample has an extra fields... 
- # need to add additional check during kv_get to make sure other samples are correctly tackled - # # Step 6: Retrieve all fields using kv_get - # print("\n[Step 6] Retrieving all fields with kv_get...") - # retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id) - # print(f" Retrieved all fields for {len(all_keys)} samples:") - # print(f" Fields: {list(retrieved_all.keys())}") + # # Step 7: Retrieve all fields using kv_get + print("\n[Step 7] Retrieving all fields with kv_get...") + retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id) + print(f" Retrieved all fields for {all_keys}:") + print(f" Fields: {list(retrieved_all.keys())}") + print( + f" Note: We cannot retrieve fields {list(response_batch.keys())}, since they only available in {append_keys}" + ) # Step 8: Clear specific keys print("\n[Step 8] Clearing keys from partition...") From f1c31ee2d5a6e725835760cc05522454a6d2e0ab Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 20:55:36 +0800 Subject: [PATCH 09/34] minor fix Signed-off-by: 0oshowero0 --- tutorial/03_metadata_concepts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py index 126cc5b..91752d8 100644 --- a/tutorial/03_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -305,7 +305,7 @@ def demonstrate_real_workflow(): # Initialize Ray if not ray.is_initialized(): - ray.init() + ray.init(namespace="TransferQueueTutorial") # Initialize TransferQueue tq.init() From 3ecf4d394e5a6cfd8193685f4dcd85d819e147c5 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 21:20:16 +0800 Subject: [PATCH 10/34] add ut Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 580 +++++++++++++++++++++++++++++ tests/test_kv_interface.py | 539 --------------------------- 2 files changed, 580 insertions(+), 539 deletions(-) create mode 100644 tests/e2e/test_kv_interface_e2e.py delete mode 100644 tests/test_kv_interface.py diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py new file mode 100644 index 0000000..458befa --- /dev/null +++ b/tests/e2e/test_kv_interface_e2e.py @@ -0,0 +1,580 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for KV interface in transfer_queue.interface. + +This test module validates the KV interface functionality by: +1. Using external interfaces (kv_put, kv_batch_put, kv_get, kv_list, kv_clear) for read/write +2. 
Verifying correctness by calling TransferQueueController's internal methods directly +""" + +import os +import sys +from pathlib import Path + +import pytest +import ray +import torch +from tensordict import TensorDict + +# Add parent directory to path +parent_dir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(parent_dir)) + +import transfer_queue as tq # noqa: E402 + +# Configure Ray for tests +os.environ["RAY_DEDUP_LOGS"] = "0" + + +@pytest.fixture(scope="module") +def ray_init(): + """Initialize Ray for the test module.""" + if not ray.is_initialized(): + ray.init(namespace="TestKVInterfaceE2E") + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def tq_system(ray_init): + """Initialize TransferQueue system for the test module.""" + tq.init() + yield + tq.close() + + +@pytest.fixture +def controller(tq_system): + """Get the TransferQueueController actor for direct verification.""" + controller = ray.get_actor("TransferQueueController") + yield controller + + +@pytest.fixture(autouse=True) +def cleanup_partition(controller): + """Cleanup partition after each test.""" + yield + try: + ray.get(controller.clear_partition.remote("test_partition")) + except Exception: + pass + + +def get_controller_partition(controller, partition_id: str): + """Get partition snapshot from controller for verification.""" + return ray.get(controller.get_partition_snapshot.remote(partition_id)) + + +def assert_tensor_equal(tensor_a, tensor_b, msg=""): + """Assert two tensors are equal.""" + assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + + +def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""): + """Assert two tensors are close.""" + assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close" + + +class TestKVPutE2E: + """End-to-end tests for kv_put functionality.""" + + def test_kv_put_single_sample_with_fields_and_tag(self, controller): + """Test putting a single sample with fields and tag.""" + partition_id = "test_partition" + key = "sample_0" + # Use 1D tensors - kv_put with dict will auto-unsqueeze to add batch dimension + input_ids = torch.tensor([1, 2, 3]) + attention_mask = torch.ones(3) + tag = {"global_steps": 0, "status": "running"} + + # Put data using interface + tq.kv_put( + key=key, + partition_id=partition_id, + fields={"input_ids": input_ids, "attention_mask": attention_mask}, + tag=tag, + ) + + # Verify via controller internal state + partition = get_controller_partition(controller, partition_id) + assert partition is not None, "Partition should exist" + + # Check key->global_index mapping + assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping" + global_idx = partition.keys_mapping[key] + assert global_idx in partition.global_indexes, f"Global index {global_idx} should be in global_indexes" + + # Check custom_meta (tag) + assert global_idx in partition.custom_meta, f"Custom meta should exist for global index {global_idx}" + assert partition.custom_meta[global_idx]["global_steps"] == 0 + assert partition.custom_meta[global_idx]["status"] == "running" + + # Check production status - fields should be marked as produced + assert "input_ids" in partition.field_name_mapping, "input_ids field should be registered" + assert "attention_mask" in partition.field_name_mapping, "attention_mask field should be registered" + input_ids_col_idx = partition.field_name_mapping["input_ids"] + assert 
partition.production_status[global_idx, input_ids_col_idx] == 1, "input_ids should be marked as produced" + + # Retrieve and verify data via kv_get - tensors will have batch dimension + retrieved = tq.kv_get(keys=key, partition_id=partition_id) + assert "input_ids" in retrieved.keys() + assert "attention_mask" in retrieved.keys() + # After unsqueeze, tensors become 2D [batch_size=1, original_size] + expected_input_ids = input_ids.unsqueeze(0) + expected_attention_mask = attention_mask.unsqueeze(0) + assert_tensor_equal(retrieved["input_ids"], expected_input_ids) + assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) + + def test_kv_put_update_tag_only(self, controller): + """Test updating only tag without providing fields.""" + partition_id = "test_partition" + key = "sample_1" + + # First put with fields - use TensorDict to avoid unsqueeze + single_data = TensorDict({"value": torch.tensor([[10]])}, batch_size=1) + tq.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1}) + + # Update only tag + new_tag = {"version": 2, "status": "updated"} + tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=new_tag) + + # Verify via controller + partition = get_controller_partition(controller, partition_id) + global_idx = partition.keys_mapping[key] + assert partition.custom_meta[global_idx]["version"] == 2 + assert partition.custom_meta[global_idx]["status"] == "updated" + + # Data should still be accessible + retrieved = tq.kv_get(keys=key, partition_id=partition_id) + assert_tensor_equal(retrieved["value"], torch.tensor([[10]])) + + def test_kv_put_with_dict_fields(self, controller): + """Test kv_put with dict fields (auto-converted to TensorDict).""" + partition_id = "test_partition" + key = "sample_2" + + # Put with dict fields - will be auto-unsqueezed + tq.kv_put( + key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"} + ) + + # Verify - retrieved data will have batch dimension + retrieved = tq.kv_get(keys=key, partition_id=partition_id) + expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed + assert_tensor_equal(retrieved["data"], expected) + + +class TestKVBatchPutE2E: + """End-to-end tests for kv_batch_put functionality.""" + + def test_kv_batch_put_multiple_samples(self, controller): + """Test batch putting multiple samples.""" + partition_id = "test_partition" + keys = ["batch_0", "batch_1", "batch_2", "batch_3"] + batch_input_ids = torch.tensor( + [ + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + [13, 14, 15], + ] + ) + batch_attention_mask = torch.ones_like(batch_input_ids) + + fields = TensorDict( + { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + }, + batch_size=4, + ) + + tags = [{"idx": i, "batch": True} for i in range(4)] + + # Batch put using interface + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + + # Verify via controller + partition = get_controller_partition(controller, partition_id) + assert partition is not None + + # All keys should be registered + for key in keys: + assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping" + + # Verify tags + for i, key in enumerate(keys): + global_idx = partition.keys_mapping[key] + assert partition.custom_meta[global_idx]["idx"] == i + assert partition.custom_meta[global_idx]["batch"] is True + + # Verify all data via kv_get + retrieved = tq.kv_get(keys=keys, partition_id=partition_id) + assert_tensor_equal(retrieved["input_ids"], batch_input_ids) + 
assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask) + + def test_kv_batch_put_partial_update(self, controller): + """Test adding new fields to existing samples.""" + partition_id = "test_partition" + keys = ["partial_0", "partial_1"] + + # First put initial data + initial_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2], [3, 4]]), + }, + batch_size=2, + ) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=initial_data, tags=[{"v": 1}, {"v": 1}]) + + # Add new fields to subset of keys + new_fields = TensorDict( + { + "response": torch.tensor([[5, 6]]), # Only for 1 sample + }, + batch_size=1, + ) + tq.kv_batch_put(keys=[keys[1]], partition_id=partition_id, fields=new_fields, tags=[{"v": 2}]) + + # Verify via controller - only keys[1] should have response field + partition = get_controller_partition(controller, partition_id) + global_idx_1 = partition.keys_mapping[keys[1]] + + # Check that fields were added + assert "response" in partition.field_name_mapping + response_col_idx = partition.field_name_mapping["response"] + + # keys[0] should NOT have response marked as produced + global_idx_0 = partition.keys_mapping[keys[0]] + assert partition.production_status[global_idx_0, response_col_idx] == 0, "Keys[0] should not have response" + + # keys[1] should have response marked as produced + assert partition.production_status[global_idx_1, response_col_idx] == 1, "Keys[1] should have response" + + +class TestKVGetE2E: + """End-to-end tests for kv_get functionality.""" + + def test_kv_get_single_key(self, controller): + """Test getting data for a single key.""" + partition_id = "test_partition" + key = "get_single" + # Use TensorDict to avoid auto-unsqueeze issue with dict input + expected_data = torch.tensor([[100, 200, 300]]) + fields = TensorDict({"data": expected_data}, batch_size=1) + + tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) + + retrieved = tq.kv_get(keys=key, partition_id=partition_id) + assert_tensor_equal(retrieved["data"], expected_data) + + def test_kv_get_multiple_keys(self, controller): + """Test getting data for multiple keys.""" + partition_id = "test_partition" + keys = ["get_multi_0", "get_multi_1", "get_multi_2"] + expected_data = torch.tensor([[1, 2], [3, 4], [5, 6]]) + + fields = TensorDict({"data": expected_data}, batch_size=3) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) + + retrieved = tq.kv_get(keys=keys, partition_id=partition_id) + assert_tensor_equal(retrieved["data"], expected_data) + + def test_kv_get_specific_fields(self, controller): + """Test getting only specific fields.""" + partition_id = "test_partition" + key = "get_fields" + # Use TensorDict to avoid auto-unsqueeze issue + input_ids = torch.tensor([[1, 2, 3]]) + attention_mask = torch.ones(1, 3) + response = torch.tensor([[10, 20]]) + + fields = TensorDict( + {"input_ids": input_ids, "attention_mask": attention_mask, "response": response}, batch_size=1 + ) + + # Put all fields + tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) + + # Get only input_ids + retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields="input_ids") + assert "input_ids" in retrieved.keys() + assert "attention_mask" not in retrieved.keys() + assert "response" not in retrieved.keys() + assert_tensor_equal(retrieved["input_ids"], input_ids) + + # Get multiple specific fields + retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"]) + assert "input_ids" in 
retrieved.keys() + assert "response" in retrieved.keys() + assert "attention_mask" not in retrieved.keys() + + def test_kv_get_nonexistent_key(self, controller): + """Test that getting data for non-existent key returns empty result.""" + partition_id = "test_partition" + + # Try to get data for a key that doesn't exist - should return empty or raise error + try: + retrieved = tq.kv_get(keys="nonexistent_key", partition_id=partition_id) + # If it returns, it should be empty + assert retrieved.batch_size[0] == 0 + except RuntimeError as e: + # Or it might raise an error about keys not found + assert "not found" in str(e).lower() or "empty" in str(e).lower() + + +class TestKVListE2E: + """End-to-end tests for kv_list functionality.""" + + def test_kv_list_all_keys(self, controller): + """Test listing all keys in a partition.""" + partition_id = "test_partition" + keys = ["list_0", "list_1", "list_2"] + + for i, key in enumerate(keys): + tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i}) + + # List all keys + listed_keys, tags = tq.kv_list(partition_id=partition_id) + + assert len(listed_keys) == 3 + for key in keys: + assert key in listed_keys + + # Verify tags match + for i, (key, tag) in enumerate(zip(listed_keys, tags, strict=False)): + assert tag["id"] == i + + def test_kv_list_empty_partition(self): + """Test listing empty partition.""" + partition_id = "test_partition_empty" + + keys, tags = tq.kv_list(partition_id=partition_id) + + assert len(keys) == 0 + assert len(tags) == 0 + + +class TestKVClearE2E: + """End-to-end tests for kv_clear functionality.""" + + def test_kv_clear_single_key(self, controller): + """Test clearing a single key.""" + partition_id = "test_partition" + key = "clear_single" + other_key = "clear_other" + + tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[1]])}, tag={"id": "single"}) + tq.kv_put(key=other_key, partition_id=partition_id, fields={"data": torch.tensor([[2]])}, tag={"id": "other"}) + + # Clear single key + tq.kv_clear(keys=key, partition_id=partition_id) + + # Verify via kv_list + listed_keys, _ = tq.kv_list(partition_id=partition_id) + assert key not in listed_keys + assert other_key in listed_keys + + # Verify via controller - key should be removed + partition = get_controller_partition(controller, partition_id) + assert key not in partition.keys_mapping + + def test_kv_clear_multiple_keys(self, controller): + """Test clearing multiple keys.""" + partition_id = "test_partition" + keys = ["clear_multi_0", "clear_multi_1", "clear_multi_2", "clear_multi_3"] + + for i, key in enumerate(keys): + tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag=None) + + # Clear first 2 keys + tq.kv_clear(keys=keys[:2], partition_id=partition_id) + + # Verify + listed_keys, _ = tq.kv_list(partition_id=partition_id) + assert len(listed_keys) == 2 + assert keys[0] not in listed_keys + assert keys[1] not in listed_keys + assert keys[2] in listed_keys + assert keys[3] in listed_keys + + +class TestKVTagsE2E: + """End-to-end tests for tag functionality.""" + + def test_tag_preservation_across_operations(self, controller): + """Test that tags are preserved and updated correctly.""" + partition_id = "test_partition" + key = "tag_test" + + # Put with initial tag + tq.kv_put( + key=key, + partition_id=partition_id, + fields={"data": torch.tensor([[1]])}, + tag={"version": 1, "status": "init"}, + ) + + # Update with new tag (keeping version incrementing) + 
tq.kv_put(key=key, partition_id=partition_id, fields=None, tag={"version": 2, "status": "updated"}) + + # Verify tag is updated + partition = get_controller_partition(controller, partition_id) + global_idx = partition.keys_mapping[key] + assert partition.custom_meta[global_idx]["version"] == 2 + assert partition.custom_meta[global_idx]["status"] == "updated" + + def test_tag_retrieval_via_kv_list(self): + """Test retrieving tags via kv_list.""" + partition_id = "test_partition" + keys = ["tag_list_0", "tag_list_1", "tag_list_2"] + + expected_tags = [ + {"score": 0.9, "label": "A"}, + {"score": 0.85, "label": "B"}, + {"score": 0.95, "label": "C"}, + ] + + for key, tag in zip(keys, expected_tags, strict=False): + tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[1]])}, tag=tag) + + # List and verify tags + listed_keys, tags = tq.kv_list(partition_id=partition_id) + + for key, expected_tag in zip(keys, expected_tags, strict=False): + assert key in listed_keys + idx = listed_keys.index(key) + assert tags[idx] == expected_tag + + +class TestKVE2ECornerCases: + """End-to-end tests for corner cases.""" + + def test_key_to_global_index_mapping_consistency(self, controller): + """Test that key->global_index mapping is consistent across operations.""" + partition_id = "test_partition" + keys = ["map_0", "map_1", "map_2", "map_3"] + + # Put all keys + tq.kv_batch_put( + keys=keys, + partition_id=partition_id, + fields=TensorDict({"data": torch.randn(4, 5)}, batch_size=4), + tags=[{"i": i} for i in range(4)], + ) + + # Verify mapping consistency via controller + partition = get_controller_partition(controller, partition_id) + + for key in keys: + assert key in partition.keys_mapping + global_idx = partition.keys_mapping[key] + assert global_idx in partition.revert_keys_mapping + assert partition.revert_keys_mapping[global_idx] == key + + def test_field_expansion_across_samples(self, controller): + """Test that new fields can be added across samples.""" + partition_id = "test_partition" + keys = ["expand_0", "expand_1"] + + # Put initial fields + tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_a": torch.tensor([[1]])}, tag=None) + + # Add new field to first key + tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None) + + # Add different field to second key + tq.kv_put(key=keys[1], partition_id=partition_id, fields={"field_c": torch.tensor([[3]])}, tag=None) + + # Verify field expansion in controller + partition = get_controller_partition(controller, partition_id) + + # All fields should be registered + assert "field_a" in partition.field_name_mapping + assert "field_b" in partition.field_name_mapping + assert "field_c" in partition.field_name_mapping + + def test_empty_tag_list(self): + """Test operations with empty tags.""" + partition_id = "test_partition" + key = "empty_tag" + + # Use 1D tensor - will be auto-unsqueezed to 2D + tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([1])}, tag={}) + + # Should work and data should be retrievable - will be 2D after unsqueeze + retrieved = tq.kv_get(keys=key, partition_id=partition_id) + assert_tensor_equal(retrieved["data"], torch.tensor([[1]])) + + def test_large_batch_put_and_get(self): + """Test putting and getting a large batch of samples.""" + partition_id = "test_partition" + num_samples = 100 + keys = [f"large_{i}" for i in range(num_samples)] + + # Create batch data + data = TensorDict( + { + "input_ids": torch.randn(num_samples, 10), + 
"attention_mask": torch.ones(num_samples, 10), + }, + batch_size=num_samples, + ) + + tags = [{"idx": i} for i in range(num_samples)] + + # Batch put + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data, tags=tags) + + # Batch get all + retrieved = tq.kv_get(keys=keys, partition_id=partition_id) + + assert retrieved["input_ids"].shape == (num_samples, 10) + assert retrieved["attention_mask"].shape == (num_samples, 10) + + # Verify specific samples + assert_tensor_equal(retrieved["input_ids"][0], data["input_ids"][0]) + assert_tensor_equal(retrieved["input_ids"][99], data["input_ids"][99]) + + def test_controller_partition_synchronization(self, controller): + """Test that controller partition state is synchronized with operations.""" + partition_id = "test_partition" + key = "sync_test" + + # Put data + tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[42]])}, tag={"sync": True}) + + # Get snapshot before clear + partition_before = get_controller_partition(controller, partition_id) + global_idx = partition_before.keys_mapping[key] + assert partition_before.production_status[global_idx, partition_before.field_name_mapping["x"]] == 1 + + # Clear + tq.kv_clear(keys=key, partition_id=partition_id) + + # Get snapshot after clear + partition_after = get_controller_partition(controller, partition_id) + assert key not in partition_after.keys_mapping + + +def run_tests(): + """Run all e2e tests manually for debugging.""" + pytest.main([__file__, "-v", "-s"]) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_kv_interface.py b/tests/test_kv_interface.py deleted file mode 100644 index 6b9eee1..0000000 --- a/tests/test_kv_interface.py +++ /dev/null @@ -1,539 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Unit tests for kv interface in transfer_queue.interface.""" - -import asyncio -import sys -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -import torch -from tensordict import TensorDict - -# Setup path -parent_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(parent_dir)) - -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 -from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 - - -def create_batch_meta(global_indexes, fields_data): - """Helper to create BatchMeta for testing.""" - samples = [] - for sample_id, global_idx in enumerate(global_indexes): - fields_dict = {} - for field_name, tensor in fields_data.items(): - field_meta = FieldMeta( - name=field_name, - dtype=tensor.dtype, - shape=tensor.shape, - production_status=ProductionStatus.READY_FOR_CONSUME, - ) - fields_dict[field_name] = field_meta - sample = SampleMeta( - partition_id="test_partition", - global_index=global_idx, - fields=fields_dict, - ) - samples.append(sample) - return BatchMeta(samples=samples) - - -class TestKVPut: - """Tests for kv_put function.""" - - def test_kv_put_with_fields(self): - """Test kv_put with fields parameter.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - tensor_data = TensorDict( - {"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])}, - batch_size=[1], - ) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_put - - kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"}) - - # Verify kv_retrieve_keys was called - mock_client.kv_retrieve_keys.assert_called_once_with(keys=["test_key"], partition_id="partition_1", create=True) - - # Verify update_custom_meta was called - mock_batch_meta.update_custom_meta.assert_called_once() - - # Verify put was called - mock_client.put.assert_called_once() - - def test_kv_put_with_dict_fields(self): - """Test kv_put converts dict to TensorDict correctly.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_put - - # Test with dict containing tensor - kv_put( - key="test_key", - partition_id="partition_1", - fields={"text": torch.tensor([1, 2, 3])}, - tag=None, - ) - - # Verify put was called - mock_client.put.assert_called_once() - call_args = mock_client.put.call_args - fields_arg = call_args[0][0] - assert "text" in fields_arg - # The dict should be converted to TensorDict - assert isinstance(fields_arg, TensorDict) - - def test_kv_put_with_tag_only(self): - """Test kv_put with only tag parameter (no fields).""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_put - - kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9}) - - # Verify put was NOT called (only set_custom_meta) - mock_client.put.assert_not_called() - mock_client.set_custom_meta.assert_called_once_with(mock_batch_meta) - - def test_kv_put_raises_error_without_fields_and_tag(self): - """Test 
kv_put raises ValueError when neither fields nor tag provided.""" - mock_client = MagicMock() - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_put - - with pytest.raises(ValueError, match="Please provide at least one parameter"): - kv_put(key="test_key", partition_id="partition_1", fields=None, tag=None) - - -class TestKVBatchPut: - """Tests for kv_batch_put function.""" - - def test_kv_batch_put_success(self): - """Test kv_batch_put with valid inputs.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - batch_data = TensorDict( - { - "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), - "label": torch.tensor([0, 1, 2]), - }, - batch_size=[3], - ) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_batch_put - - keys = ["key1", "key2", "key3"] - tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}] - - kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags) - - mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True) - mock_batch_meta.update_custom_meta.assert_called_once_with(tags) - mock_client.put.assert_called_once() - - def test_kv_batch_put_tags_length_mismatch(self): - """Test kv_batch_put raises error when tags length doesn't match keys.""" - mock_client = MagicMock() - - batch_data = TensorDict( - { - "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), - "label": torch.tensor([0, 1, 2]), - }, - batch_size=[3], - ) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_batch_put - - keys = ["key1", "key2", "key3"] - tags = [{"tag": "v1"}, {"tag": "v2"}] # Only 2 tags for 3 keys - - with pytest.raises(ValueError, match="does not match length of tags"): - kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags) - - -class TestKVGet: - """Tests for kv_get function.""" - - def test_kv_get_single_key(self): - """Test kv_get with single key.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])}) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_get - - kv_get(keys="test_key", partition_id="partition_1") - - # keys is passed directly (not wrapped in list) for single key - mock_client.kv_retrieve_keys.assert_called_once_with(keys="test_key", partition_id="partition_1", create=False) - mock_client.get_data.assert_called_once_with(mock_batch_meta) - - def test_kv_get_multiple_keys(self): - """Test kv_get with multiple keys.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])}) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_get - - keys = ["key1", "key2", "key3"] - kv_get(keys=keys, partition_id="partition_1") - - mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=False) - - def test_kv_get_with_fields(self): - 
"""Test kv_get with specific fields.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_batch_meta.select_fields = MagicMock(return_value=mock_batch_meta) - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - mock_client.get_data.return_value = TensorDict({"text": torch.tensor([1, 2])}) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_get - - kv_get(keys="test_key", partition_id="partition_1", fields="text") - - mock_batch_meta.select_fields.assert_called_once_with(["text"]) - - -class TestKVList: - """Tests for kv_list function.""" - - def test_kv_list_with_keys(self): - """Test kv_list returns keys and custom_meta.""" - mock_client = MagicMock() - mock_client.kv_list.return_value = ["key1", "key2", "key3"] - mock_batch_meta = MagicMock() - mock_batch_meta.global_indexes = [0, 1, 2] - mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}}) - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_list - - keys, custom_meta = kv_list(partition_id="partition_1") - - assert keys == ["key1", "key2", "key3"] - assert len(custom_meta) == 3 - - def test_kv_list_empty_partition(self): - """Test kv_list returns None when partition is empty.""" - mock_client = MagicMock() - mock_client.kv_list.return_value = [] - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_list - - keys, custom_meta = kv_list(partition_id="empty_partition") - - assert keys is None - assert custom_meta is None - - -class TestKVClear: - """Tests for kv_clear function.""" - - def test_kv_clear_single_key(self): - """Test kv_clear with single key.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_batch_meta.size = 1 - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_clear - - kv_clear(keys="test_key", partition_id="partition_1") - - mock_client.kv_retrieve_keys.assert_called_once_with( - keys=["test_key"], partition_id="partition_1", create=False - ) - mock_client.clear_samples.assert_called_once_with(mock_batch_meta) - - def test_kv_clear_multiple_keys(self): - """Test kv_clear with multiple keys.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_batch_meta.size = 3 - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_clear - - kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1") - - mock_client.kv_retrieve_keys.assert_called_once_with( - keys=["key1", "key2", "key3"], partition_id="partition_1", create=False - ) - mock_client.clear_samples.assert_called_once() - - -class TestAsyncKVPut: - """Tests for async_kv_put function.""" - - def test_async_kv_put_with_fields(self): - """Test async_kv_put with fields parameter.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) - mock_client.async_put = AsyncMock() - mock_client.async_set_custom_meta = AsyncMock() - - tensor_data = TensorDict( - 
{"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])}, - batch_size=[1], - ) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_put - - asyncio.run( - async_kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"}) - ) - - mock_client.async_kv_retrieve_keys.assert_called_once_with( - keys=["test_key"], partition_id="partition_1", create=True - ) - mock_batch_meta.update_custom_meta.assert_called_once() - mock_client.async_put.assert_called_once() - - def test_async_kv_put_with_tag_only(self): - """Test async_kv_put with only tag (no fields).""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) - mock_client.async_put = AsyncMock() - mock_client.async_set_custom_meta = AsyncMock() - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_put - - asyncio.run(async_kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9})) - - mock_client.async_put.assert_not_called() - mock_client.async_set_custom_meta.assert_called_once_with(mock_batch_meta) - - -class TestAsyncKVBatchPut: - """Tests for async_kv_batch_put function.""" - - def test_async_kv_batch_put_success(self): - """Test async_kv_batch_put with valid inputs.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) - mock_client.async_put = AsyncMock() - mock_client.async_set_custom_meta = AsyncMock() - - batch_data = TensorDict( - { - "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), - "label": torch.tensor([0, 1, 2]), - }, - batch_size=[3], - ) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_batch_put - - keys = ["key1", "key2", "key3"] - tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}] - - asyncio.run(async_kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags)) - - mock_client.async_kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True) - mock_batch_meta.update_custom_meta.assert_called_once_with(tags) - mock_client.async_put.assert_called_once() - - -class TestAsyncKVGet: - """Tests for async_kv_get function.""" - - def test_async_kv_get_single_key(self): - """Test async_kv_get with single key.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) - mock_client.async_get_data = AsyncMock(return_value=TensorDict({"data": torch.tensor([1, 2])})) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_get - - asyncio.run(async_kv_get(keys="test_key", partition_id="partition_1")) - - # keys is passed directly (not wrapped in list) for single key - mock_client.async_kv_retrieve_keys.assert_called_once_with( - keys="test_key", partition_id="partition_1", create=False - ) - mock_client.async_get_data.assert_called_once_with(mock_batch_meta) - - -class TestAsyncKVList: - """Tests for async_kv_list function.""" - - def test_async_kv_list_with_keys(self): - """Test async_kv_list returns keys and custom_meta.""" - mock_client = MagicMock() - 
mock_client.async_kv_list = AsyncMock(return_value=["key1", "key2", "key3"]) - mock_batch_meta = MagicMock() - mock_batch_meta.global_indexes = [0, 1, 2] - mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}}) - mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_list - - keys, custom_meta = asyncio.run(async_kv_list(partition_id="partition_1")) - - assert keys == ["key1", "key2", "key3"] - assert len(custom_meta) == 3 - - def test_async_kv_list_empty_partition(self): - """Test async_kv_list returns None when partition is empty.""" - mock_client = MagicMock() - mock_client.async_kv_list = AsyncMock(return_value=[]) - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_list - - keys, custom_meta = asyncio.run(async_kv_list(partition_id="empty_partition")) - - assert keys is None - assert custom_meta is None - - -class TestAsyncKVClear: - """Tests for async_kv_clear function.""" - - def test_async_kv_clear_single_key(self): - """Test async_kv_clear with single key.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_batch_meta.size = 1 - mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) - mock_client.async_clear_samples = AsyncMock() - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_clear - - asyncio.run(async_kv_clear(keys="test_key", partition_id="partition_1")) - - mock_client.async_kv_retrieve_keys.assert_called_once_with( - keys=["test_key"], partition_id="partition_1", create=False - ) - mock_client.async_clear_samples.assert_called_once_with(mock_batch_meta) - - def test_async_kv_clear_multiple_keys(self): - """Test async_kv_clear with multiple keys.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_batch_meta.size = 3 - mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta) - mock_client.async_clear_samples = AsyncMock() - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import async_kv_clear - - asyncio.run(async_kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1")) - - mock_client.async_kv_retrieve_keys.assert_called_once() - mock_client.async_clear_samples.assert_called_once() - - -class TestKVInterfaceDictConversion: - """Tests for dict to TensorDict conversion in kv_put.""" - - def test_kv_put_with_nontensor_value(self): - """Test kv_put converts non-tensor values using NonTensorStack.""" - mock_client = MagicMock() - mock_batch_meta = MagicMock() - mock_client.kv_retrieve_keys.return_value = mock_batch_meta - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_put - - # Test with non-tensor value (like a string or list) - kv_put( - key="test_key", - partition_id="partition_1", - fields={"meta": {"key": "value"}}, - tag=None, - ) - - # Verify put was called - mock_client.put.assert_called_once() - call_args = mock_client.put.call_args - fields_arg = call_args[0][0] - # The dict should be converted to TensorDict - assert isinstance(fields_arg, TensorDict) - assert "meta" in fields_arg - - def 
test_kv_put_rejects_nested_tensor(self): - """Test kv_put raises ValueError for nested tensors (requires batch_put).""" - mock_client = MagicMock() - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_put - - nested_tensor = torch.nested.nested_tensor([[1, 2], [3, 4]]) - - with pytest.raises(ValueError, match="Please use.*kv_batch_put"): - kv_put( - key="test_key", - partition_id="partition_1", - fields={"nested": nested_tensor}, - tag=None, - ) - - def test_kv_put_invalid_fields_type(self): - """Test kv_put raises ValueError for invalid fields type.""" - mock_client = MagicMock() - - with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client): - from transfer_queue.interface import kv_put - - with pytest.raises(ValueError, match="field can only be dict or TensorDict"): - kv_put( - key="test_key", - partition_id="partition_1", - fields="invalid_string", - tag=None, - ) From 4d479ab896128dc7c01ec5cf4f140bc7652b6765 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 21:38:55 +0800 Subject: [PATCH 11/34] try fix tutorial ci Signed-off-by: 0oshowero0 --- .github/workflows/tutorial-check.yml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml index 1202e21..ed8a2ba 100644 --- a/.github/workflows/tutorial-check.yml +++ b/.github/workflows/tutorial-check.yml @@ -36,4 +36,11 @@ jobs: - name: Run tutorials run: | export TQ_NUM_THREADS=2 - for file in tutorial/*.py; do python3 "$file"; done \ No newline at end of file + export RAY_DEDUP_LOGS=0 + ray stop --force || true + for file in tutorial/*.py; do + echo "================ Running $file ================" + python3 "$file" + ray stop --force + sleep 5 + done \ No newline at end of file From 5556278548237f84207741475f696ef799dc8de1 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 21:54:35 +0800 Subject: [PATCH 12/34] try only run 06 Signed-off-by: 0oshowero0 --- .github/workflows/tutorial-check.yml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml index ed8a2ba..0c72b76 100644 --- a/.github/workflows/tutorial-check.yml +++ b/.github/workflows/tutorial-check.yml @@ -37,10 +37,4 @@ jobs: run: | export TQ_NUM_THREADS=2 export RAY_DEDUP_LOGS=0 - ray stop --force || true - for file in tutorial/*.py; do - echo "================ Running $file ================" - python3 "$file" - ray stop --force - sleep 5 - done \ No newline at end of file + python tutorial/06_streaming_dataloader.py \ No newline at end of file From 4f6ce684df0d2e83cefc917b257ddea5fc9e637a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 21:59:46 +0800 Subject: [PATCH 13/34] try add sleep in tutorial Signed-off-by: 0oshowero0 --- tutorial/06_streaming_dataloader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py index d95a3c8..c0131ef 100644 --- a/tutorial/06_streaming_dataloader.py +++ b/tutorial/06_streaming_dataloader.py @@ -319,6 +319,8 @@ def main(): print("\n[Phase 2] Starting data generation...") generate_worker_handlers = start_all_generate_actors() + time.sleep(10) + # Step 3: Launch data consumption actors print("\n[Phase 3] Starting data consumption...") update_worker_handlers = start_all_update_actors() From 
e7ea8e29680b72a2650478c049b8dfceb3bb1570 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 22:20:53 +0800 Subject: [PATCH 14/34] fix race condition Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 2 ++ tutorial/06_streaming_dataloader.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 1f01482..ea803f5 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -547,6 +547,8 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te ) if mask: + with self.data_status_lock: + self.ensure_samples_capacity(max(partition_global_index) + 1) consumption_status = consumption_status[partition_global_index] return partition_global_index, consumption_status diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py index c0131ef..d95a3c8 100644 --- a/tutorial/06_streaming_dataloader.py +++ b/tutorial/06_streaming_dataloader.py @@ -319,8 +319,6 @@ def main(): print("\n[Phase 2] Starting data generation...") generate_worker_handlers = start_all_generate_actors() - time.sleep(10) - # Step 3: Launch data consumption actors print("\n[Phase 3] Starting data consumption...") update_worker_handlers = start_all_update_actors() From 29f5a1ed5e340742314c72d57cf4f4f44b2ce86e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 22:24:07 +0800 Subject: [PATCH 15/34] fix race condition Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index ea803f5..538ba21 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -540,8 +540,6 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8) # Get consumption status for requested task - consumption_status = self.consumption_status[task_name] - partition_global_index = torch.tensor( sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long ) @@ -549,8 +547,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te if mask: with self.data_status_lock: self.ensure_samples_capacity(max(partition_global_index) + 1) - consumption_status = consumption_status[partition_global_index] - + consumption_status = self.consumption_status[task_name][partition_global_index] + else: + consumption_status = self.consumption_status[task_name] return partition_global_index, consumption_status def reset_consumption(self, task_name: Optional[str] = None): From e1615715574b22b9bb81570f0005c223f12fb403 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 7 Feb 2026 22:26:35 +0800 Subject: [PATCH 16/34] recover ci Signed-off-by: 0oshowero0 --- .github/workflows/tutorial-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml index 0c72b76..c24cab8 100644 --- a/.github/workflows/tutorial-check.yml +++ b/.github/workflows/tutorial-check.yml @@ -37,4 +37,4 @@ jobs: run: | export TQ_NUM_THREADS=2 export RAY_DEDUP_LOGS=0 - python tutorial/06_streaming_dataloader.py \ No newline at end of file + for file in tutorial/*.py; do python3 "$file"; done \ No newline at end of file From 928edda0f8408e17ec25c9b64af9b91e8359b1e8 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 8 Feb 2026 12:36:11 +0800 Subject: [PATCH 
17/34] fix batchmeta deserial Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index f0c899b..2af07a7 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -970,6 +970,7 @@ async def async_kv_retrieve_keys( if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE: metadata = response_msg.body.get("metadata", BatchMeta.empty()) + metadata = BatchMeta.from_dict(metadata) if isinstance(metadata, dict) else metadata return metadata else: raise RuntimeError( From ac9779505aa9345b47d1eafb697a7db88f8d358e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 8 Feb 2026 13:17:22 +0800 Subject: [PATCH 18/34] update kv interface test Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 275 ++++++++++++----------------- transfer_queue/__init__.py | 8 +- transfer_queue/interface.py | 12 +- tutorial/02_kv_interface.py | 18 +- 4 files changed, 131 insertions(+), 182 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 458befa..3918bbe 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -16,7 +16,7 @@ """End-to-end tests for KV interface in transfer_queue.interface. This test module validates the KV interface functionality by: -1. Using external interfaces (kv_put, kv_batch_put, kv_get, kv_list, kv_clear) for read/write +1. Using external interfaces (kv_put, kv_batch_put, kv_batch_get, kv_list, kv_clear) for read/write 2. Verifying correctness by calling TransferQueueController's internal methods directly """ @@ -92,10 +92,44 @@ def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""): class TestKVPutE2E: """End-to-end tests for kv_put functionality.""" + def test_kv_put_with_dict_fields(self, controller): + """Test kv_put with dict fields (auto-converted to TensorDict).""" + partition_id = "test_partition" + key = "sample_0" + + # Put with dict fields - will be auto-unsqueezed + tq.kv_put( + key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"} + ) + + # Verify - retrieved data will have batch dimension + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed + assert_tensor_equal(retrieved["data"], expected) + + def test_kv_put_with_tensordict_fields(self, controller): + """Test kv_put with tensordict fields.""" + partition_id = "test_partition" + key = "sample_1" + + tensordict_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3, 4]]), + }, + batch_size=1, + ) + # Put with dict fields - will be auto-unsqueezed + tq.kv_put(key=key, partition_id=partition_id, fields=tensordict_data, tag={"type": "tensordict_test"}) + + # Verify - retrieved data will have batch dimension + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed + assert_tensor_equal(retrieved["input_ids"], expected) + def test_kv_put_single_sample_with_fields_and_tag(self, controller): """Test putting a single sample with fields and tag.""" partition_id = "test_partition" - key = "sample_0" + key = "sample_2" # Use 1D tensors - kv_put with dict will auto-unsqueeze to add batch dimension input_ids = torch.tensor([1, 2, 3]) attention_mask = torch.ones(3) @@ -129,8 +163,8 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller): input_ids_col_idx = 
partition.field_name_mapping["input_ids"] assert partition.production_status[global_idx, input_ids_col_idx] == 1, "input_ids should be marked as produced" - # Retrieve and verify data via kv_get - tensors will have batch dimension - retrieved = tq.kv_get(keys=key, partition_id=partition_id) + # Retrieve and verify data via kv_batch_get - tensors will have batch dimension + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) assert "input_ids" in retrieved.keys() assert "attention_mask" in retrieved.keys() # After unsqueeze, tensors become 2D [batch_size=1, original_size] @@ -142,9 +176,9 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller): def test_kv_put_update_tag_only(self, controller): """Test updating only tag without providing fields.""" partition_id = "test_partition" - key = "sample_1" + key = "sample_3" - # First put with fields - use TensorDict to avoid unsqueeze + # First put with fields - use TensorDict as another example single_data = TensorDict({"value": torch.tensor([[10]])}, batch_size=1) tq.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1}) @@ -159,23 +193,42 @@ def test_kv_put_update_tag_only(self, controller): assert partition.custom_meta[global_idx]["status"] == "updated" # Data should still be accessible - retrieved = tq.kv_get(keys=key, partition_id=partition_id) + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["value"], torch.tensor([[10]])) - def test_kv_put_with_dict_fields(self, controller): - """Test kv_put with dict fields (auto-converted to TensorDict).""" + def test_kv_put_partial_update(self, controller): + """Test adding new fields to existing sample.""" partition_id = "test_partition" - key = "sample_2" + key = "sample_4" - # Put with dict fields - will be auto-unsqueezed - tq.kv_put( - key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"} + # First put initial data + initial_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3, 4]]), + }, + batch_size=1, ) + tq.kv_put(key=key, partition_id=partition_id, fields=initial_data, tag={"v": 1}) - # Verify - retrieved data will have batch dimension - retrieved = tq.kv_get(keys=key, partition_id=partition_id) - expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed - assert_tensor_equal(retrieved["data"], expected) + # Add new fields to subset of keys + new_fields = TensorDict( + { + "response": torch.tensor([[5, 6]]), + }, + batch_size=1, + ) + tq.kv_put(key=key, partition_id=partition_id, fields=new_fields, tag={"v": 2}) + + # Verify via controller - only keys[1] should have response field + partition = get_controller_partition(controller, partition_id) + global_idx = partition.keys_mapping[key] + + # Check that fields were added + assert "response" in partition.field_name_mapping + response_col_idx = partition.field_name_mapping["response"] + + # key should have response marked as produced + assert partition.production_status[global_idx, response_col_idx] == 1, "Key should have response" class TestKVBatchPutE2E: @@ -222,8 +275,8 @@ def test_kv_batch_put_multiple_samples(self, controller): assert partition.custom_meta[global_idx]["idx"] == i assert partition.custom_meta[global_idx]["batch"] is True - # Verify all data via kv_get - retrieved = tq.kv_get(keys=keys, partition_id=partition_id) + # Verify all data via kv_batch_get + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) assert_tensor_equal(retrieved["input_ids"], 
batch_input_ids) assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask) @@ -267,9 +320,9 @@ def test_kv_batch_put_partial_update(self, controller): class TestKVGetE2E: - """End-to-end tests for kv_get functionality.""" + """End-to-end tests for kv_batch_get functionality.""" - def test_kv_get_single_key(self, controller): + def test_kv_batch_get_single_key(self, controller): """Test getting data for a single key.""" partition_id = "test_partition" key = "get_single" @@ -279,10 +332,10 @@ def test_kv_get_single_key(self, controller): tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) - retrieved = tq.kv_get(keys=key, partition_id=partition_id) + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) - def test_kv_get_multiple_keys(self, controller): + def test_kv_batch_get_multiple_keys(self, controller): """Test getting data for multiple keys.""" partition_id = "test_partition" keys = ["get_multi_0", "get_multi_1", "get_multi_2"] @@ -291,11 +344,25 @@ def test_kv_get_multiple_keys(self, controller): fields = TensorDict({"data": expected_data}, batch_size=3) tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) - retrieved = tq.kv_get(keys=keys, partition_id=partition_id) + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + assert_tensor_equal(retrieved["data"], expected_data) + + def test_kv_batch_get_partial_keys(self, controller): + """Test getting data for partial keys.""" + partition_id = "test_partition" + keys = ["get_multi_3", "get_multi_4", "get_multi_5"] + partial_keys = ["get_multi_3", "get_multi_5"] + input_data = torch.tensor([[1, 2], [3, 4], [5, 6]]) + expected_data = torch.tensor([[1, 2], [5, 6]]) + + fields = TensorDict({"data": input_data}, batch_size=3) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) + + retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) - def test_kv_get_specific_fields(self, controller): - """Test getting only specific fields.""" + def test_kv_batch_get_partial_fields(self, controller): + """Test getting only partial fields.""" partition_id = "test_partition" key = "get_fields" # Use TensorDict to avoid auto-unsqueeze issue @@ -311,25 +378,27 @@ def test_kv_get_specific_fields(self, controller): tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) # Get only input_ids - retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields="input_ids") + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields="input_ids") assert "input_ids" in retrieved.keys() assert "attention_mask" not in retrieved.keys() assert "response" not in retrieved.keys() assert_tensor_equal(retrieved["input_ids"], input_ids) # Get multiple specific fields - retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"]) + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"]) assert "input_ids" in retrieved.keys() assert "response" in retrieved.keys() assert "attention_mask" not in retrieved.keys() + assert_tensor_equal(retrieved["input_ids"], input_ids) + assert_tensor_equal(retrieved["response"], response) - def test_kv_get_nonexistent_key(self, controller): + def test_kv_batch_get_nonexistent_key(self, controller): """Test that getting data for non-existent key returns empty result.""" partition_id = 
"test_partition" # Try to get data for a key that doesn't exist - should return empty or raise error try: - retrieved = tq.kv_get(keys="nonexistent_key", partition_id=partition_id) + retrieved = tq.kv_batch_get(keys="nonexistent_key", partition_id=partition_id) # If it returns, it should be empty assert retrieved.batch_size[0] == 0 except RuntimeError as e: @@ -392,6 +461,7 @@ def test_kv_clear_single_key(self, controller): # Verify via controller - key should be removed partition = get_controller_partition(controller, partition_id) assert key not in partition.keys_mapping + assert other_key in partition.keys_mapping def test_kv_clear_multiple_keys(self, controller): """Test clearing multiple keys.""" @@ -413,79 +483,9 @@ def test_kv_clear_multiple_keys(self, controller): assert keys[3] in listed_keys -class TestKVTagsE2E: - """End-to-end tests for tag functionality.""" - - def test_tag_preservation_across_operations(self, controller): - """Test that tags are preserved and updated correctly.""" - partition_id = "test_partition" - key = "tag_test" - - # Put with initial tag - tq.kv_put( - key=key, - partition_id=partition_id, - fields={"data": torch.tensor([[1]])}, - tag={"version": 1, "status": "init"}, - ) - - # Update with new tag (keeping version incrementing) - tq.kv_put(key=key, partition_id=partition_id, fields=None, tag={"version": 2, "status": "updated"}) - - # Verify tag is updated - partition = get_controller_partition(controller, partition_id) - global_idx = partition.keys_mapping[key] - assert partition.custom_meta[global_idx]["version"] == 2 - assert partition.custom_meta[global_idx]["status"] == "updated" - - def test_tag_retrieval_via_kv_list(self): - """Test retrieving tags via kv_list.""" - partition_id = "test_partition" - keys = ["tag_list_0", "tag_list_1", "tag_list_2"] - - expected_tags = [ - {"score": 0.9, "label": "A"}, - {"score": 0.85, "label": "B"}, - {"score": 0.95, "label": "C"}, - ] - - for key, tag in zip(keys, expected_tags, strict=False): - tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[1]])}, tag=tag) - - # List and verify tags - listed_keys, tags = tq.kv_list(partition_id=partition_id) - - for key, expected_tag in zip(keys, expected_tags, strict=False): - assert key in listed_keys - idx = listed_keys.index(key) - assert tags[idx] == expected_tag - - class TestKVE2ECornerCases: """End-to-end tests for corner cases.""" - def test_key_to_global_index_mapping_consistency(self, controller): - """Test that key->global_index mapping is consistent across operations.""" - partition_id = "test_partition" - keys = ["map_0", "map_1", "map_2", "map_3"] - - # Put all keys - tq.kv_batch_put( - keys=keys, - partition_id=partition_id, - fields=TensorDict({"data": torch.randn(4, 5)}, batch_size=4), - tags=[{"i": i} for i in range(4)], - ) - - # Verify mapping consistency via controller - partition = get_controller_partition(controller, partition_id) - - for key in keys: - assert key in partition.keys_mapping - global_idx = partition.keys_mapping[key] - assert global_idx in partition.revert_keys_mapping - assert partition.revert_keys_mapping[global_idx] == key - def test_field_expansion_across_samples(self, controller): """Test that new fields can be added across samples.""" partition_id = "test_partition" @@ -498,77 +498,26 @@ def test_field_expansion_across_samples(self, controller): tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None) # Add different field to second key - tq.kv_put(key=keys[1], 
partition_id=partition_id, fields={"field_c": torch.tensor([[3]])}, tag=None) + tq.kv_put( + key=keys[1], + partition_id=partition_id, + fields={"field_a": torch.tensor([[3]]), "field_c": torch.tensor([[4]])}, + tag=None, + ) # Verify field expansion in controller partition = get_controller_partition(controller, partition_id) - # All fields should be registered + # All fields should be registered, but only samples with the actual fields are labeled as READY_FOR_CONSUME assert "field_a" in partition.field_name_mapping assert "field_b" in partition.field_name_mapping assert "field_c" in partition.field_name_mapping - def test_empty_tag_list(self): - """Test operations with empty tags.""" - partition_id = "test_partition" - key = "empty_tag" - - # Use 1D tensor - will be auto-unsqueezed to 2D - tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([1])}, tag={}) - - # Should work and data should be retrievable - will be 2D after unsqueeze - retrieved = tq.kv_get(keys=key, partition_id=partition_id) - assert_tensor_equal(retrieved["data"], torch.tensor([[1]])) - - def test_large_batch_put_and_get(self): - """Test putting and getting a large batch of samples.""" - partition_id = "test_partition" - num_samples = 100 - keys = [f"large_{i}" for i in range(num_samples)] - - # Create batch data - data = TensorDict( - { - "input_ids": torch.randn(num_samples, 10), - "attention_mask": torch.ones(num_samples, 10), - }, - batch_size=num_samples, - ) - - tags = [{"idx": i} for i in range(num_samples)] - - # Batch put - tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data, tags=tags) - - # Batch get all - retrieved = tq.kv_get(keys=keys, partition_id=partition_id) - - assert retrieved["input_ids"].shape == (num_samples, 10) - assert retrieved["attention_mask"].shape == (num_samples, 10) - - # Verify specific samples - assert_tensor_equal(retrieved["input_ids"][0], data["input_ids"][0]) - assert_tensor_equal(retrieved["input_ids"][99], data["input_ids"][99]) - - def test_controller_partition_synchronization(self, controller): - """Test that controller partition state is synchronized with operations.""" - partition_id = "test_partition" - key = "sync_test" - - # Put data - tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[42]])}, tag={"sync": True}) - - # Get snapshot before clear - partition_before = get_controller_partition(controller, partition_id) - global_idx = partition_before.keys_mapping[key] - assert partition_before.production_status[global_idx, partition_before.field_name_mapping["x"]] == 1 - - # Clear - tq.kv_clear(keys=key, partition_id=partition_id) - - # Get snapshot after clear - partition_after = get_controller_partition(controller, partition_id) - assert key not in partition_after.keys_mapping + # We can only fetch "field_a" because not all requested keys has other fields + data = tq.kv_batch_get(keys=keys, partition_id=partition_id) + assert "field_a" in data + assert "field_b" not in data + assert "field_c" not in data def run_tests(): diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index f0421c0..26f1225 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -23,9 +23,9 @@ async_clear_samples, async_get_data, async_get_meta, + async_kv_batch_get, async_kv_batch_put, async_kv_clear, - async_kv_get, async_kv_list, async_kv_put, async_put, @@ -36,9 +36,9 @@ get_data, get_meta, init, + kv_batch_get, kv_batch_put, kv_clear, - kv_get, kv_list, kv_put, put, @@ -70,12 +70,12 @@ "close", "kv_put", 
"kv_batch_put", - "kv_get", + "kv_batch_get", "kv_list", "kv_clear", "async_kv_put", "async_kv_batch_put", - "async_kv_get", + "async_kv_batch_get", "async_kv_list", "async_kv_clear", ] + [ diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index ed325b8..49b706f 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -798,7 +798,7 @@ def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: l tq_client.set_custom_meta(batch_meta) -def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict: +def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict: """Get data from TransferQueue using user-specified keys. This is a convenience method for retrieving data using keys instead of indexes. @@ -819,9 +819,9 @@ def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] >>> import transfer_queue as tq >>> tq.init() >>> # Get single key with all fields - >>> data = tq.kv_get(key="sample_1", partition_id="train") + >>> data = tq.kv_batch_get(key="sample_1", partition_id="train") >>> # Get multiple keys with specific fields - >>> data = tq.kv_get( + >>> data = tq.kv_batch_get( ... keys=["sample_1", "sample_2"], ... partition_id="train", ... fields="input_ids" @@ -1038,7 +1038,7 @@ async def async_kv_batch_put( await tq_client.async_set_custom_meta(batch_meta) -async def async_kv_get( +async def async_kv_batch_get( keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None ) -> TensorDict: """Asynchronously get data from TransferQueue using user-specified keys. @@ -1061,9 +1061,9 @@ async def async_kv_get( >>> import transfer_queue as tq >>> tq.init() >>> # Get single key with all fields - >>> data = await tq.async_kv_get(key="sample_1", partition_id="train") + >>> data = await tq.async_kv_batch_get(key="sample_1", partition_id="train") >>> # Get multiple keys with specific fields - >>> data = await tq.async_kv_get( + >>> data = await tq.async_kv_batch_get( ... keys=["sample_1", "sample_2"], ... partition_id="train", ... 
fields="input_ids" diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index 2b309ca..9f8f3a7 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -56,10 +56,10 @@ def demonstrate_kv_api(): """ Demonstrate the Key-Value (KV) semantic API: - kv_put & kv_batch_put -> kv_list -> kv_get -> kv_clear + kv_put & kv_batch_put -> kv_list -> kv_batch_get -> kv_clear """ print("=" * 80) - print("Key-Value Semantic API Demo: kv_put/kv_batch_put → kv_list → kv_get → kv_clear") + print("Key-Value Semantic API Demo: kv_put/kv_batch_put → kv_list → kv_batch_get → kv_clear") print("=" * 80) # Step 1: Put a single key-value pair with kv_put @@ -161,16 +161,16 @@ def demonstrate_kv_api(): for k, t in zip(all_keys, all_tags, strict=False): print(f" - key='{k}' | tag={t}") - # Step 6: Retrieve specific fields using kv_get - print("\n[Step 6] Retrieving specific fields (Column) with kv_get...") + # Step 6: Retrieve specific fields using kv_batch_get + print("\n[Step 6] Retrieving specific fields (Column) with kv_batch_get...") print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').") - retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids") + retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, fields="input_ids") print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.") - # # Step 7: Retrieve all fields using kv_get - print("\n[Step 7] Retrieving all fields with kv_get...") - retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id) + # # Step 7: Retrieve all fields using kv_batch_get + print("\n[Step 7] Retrieving all fields with kv_batch_get...") + retrieved_all = tq.kv_batch_get(keys=all_keys, partition_id=partition_id) print(f" Retrieved all fields for {all_keys}:") print(f" Fields: {list(retrieved_all.keys())}") print( @@ -200,7 +200,7 @@ def main(): Key Methods: 1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag 2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch - 3. (async_)kv_get - Retrieve samples (by keys), supporting column selection (by fields) + 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields) 4. (async_)kv_list - List keys and tags (metadata) in a partition 5. (async_)kv_clear - Remove key-value pairs from storage From 99b7fc995bd0f1aaa8e5d4018acc751a6ca377c7 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 8 Feb 2026 15:54:52 +0800 Subject: [PATCH 19/34] update readme Signed-off-by: 0oshowero0 --- README.md | 79 ++++++++++++++++++++++--------------- tutorial/02_kv_interface.py | 6 +-- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 2c7d4f0..ba9e54d 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo

🔄 Updates

- - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for details. + - **Feb 8, 2026**: 🔥 The initialization and usage is greatly simplified by high-level APIs [PR#26](https://github.com/Ascend/TransferQueue/pull/26), [PR#28](https://github.com/Ascend/TransferQueue/pull/28). You can now use a Redis-style API to take advantage of most of the advanced features provided by TransferQueue! + - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details. - **Dec 30, 2025**: **TransferQueue x verl** integration is tested with the DAPO algorithm at scale **(64 nodes, 1024 cards)**. It significantly optimizes host memory utilization and accelerates data transfers. Stay tuned for more details! - **Dec 20, 2025**: 🔥 The official [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) is released! Feel free to check it out. - **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data. @@ -91,26 +92,55 @@ This data structure design is motivated by the computational characteristics of

-### User Interface: Asynchronous & Synchronous Client -To simplify the usage of TransferQueue, we have encapsulated this process into `TransferQueueClient`. The client provides both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their framework. +### User Interface: High-Level & Low-Level APIs -We also experimentally provide a `StreamingDataLoader` interface as a standard PyTorch DataLoader. Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks. -This interface simplifies TransferQueue's integration, ensuring seamless compatibility with existing training workflows. Please refer to our [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for more details. +| Level | Tier | Style | Fine-Grained Access | Streaming | Sampler | Multiple-Backends | +|---|---|---|---|---|---|---| +| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✓ | ✗ | ✗ | ✓ | +| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ✓ |✓ | ✓ | ✓ | +| Low | **TransferQueueClient** | Metadata-based | ✓ | ✓ | ✓ | ✓ | -

🔥 Showcases

-### General Usage +#### Key-Value based API + +To simplify the usage of TransferQueue, we have provided a Redis-style high-level API that can enjoy most of the advanced features provided by TransferQueue ([#PR28](https://github.com/Ascend/TransferQueue/pull/28)). + +**Methods** + +- **(async_)kv_put**: Insert/Update a multi-column sample by key, with optional metadata tag +- **(async_)kv_batch_put**: Put multiple key-value pairs efficiently in batch +- **(async_)kv_batch_get**: Retrieve samples (by keys), supporting column selection (by fields) +- **(async_)kv_list**: List keys and tags (metadata) in a partition +- **(async_)kv_clear**: Remove key-value pairs from storage + +**Key Features** + +- **Redis-style Semantics**: Familiar KV interface (Put/Get/List) for zero learning curve +- **Fine-grained Access**: Update or retrieve specific fields (columns) within a key (row) without full op. +- **Partition Isolation**: Logical separation of storage namespaces +- **Metadata Tags**: Lightweight metadata for status tracking +- **Pluggable Backends**: Supports multiple backends + +#### StreamingDataLoader API -The primary interaction points are `AsyncTransferQueueClient` and `TransferQueueClient`, serving as the communication interface with the TransferQueue system. +Designed as a drop-in replacement for the standard PyTorch `DataLoader`, this API allows each rank to automatically consume data without single-controller intervention. -Core interfaces: +In this scenario, `TransferQueueController` serves as a side-controller for data dispatching, with user-defined `Sampler` class to organize dataflow. +It encapsulates the complex scheduling and data transfer logic required for various parallelism strategies, seamlessly integrating TransferQueue into existing training workflows and simplifying the development of disaggregated frameworks. -- `(async_)get_meta(data_fields: list[str], batch_size:int, partition_id: str, mode: str, task_name:str, sampling_config: Optional[dict[str, Any]]) -> BatchMeta` -- `(async_)get_data(metadata: BatchMeta) -> TensorDict` -- `(async_)put(data: TensorDict, metadata: Optional[BatchMeta], partition_id: Optional[str])` -- `(async_)clear_partition(partition_id: str)` and `(async_)clear_samples(metadata: BatchMeta)` +See [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for more details. -**Refer to our [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) for detailed examples.** +#### Low-Level Native API + +The native interface of TransferQueue are implemented in `TransferQueueClient`. It offers maximum flexibility through native, atomic operations. + +Developers can leverage `TransferQueueClient` directly to implement advanced features that require fine-grained control and fully streamed data scheduling, as illustrated in the following tutorials: +- [tutorial/03_metadata_concepts.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/03_metadata_concepts.py) +- [tutorial/04_understanding_controller.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_understanding_controller.py) +- [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py) + + +
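
> [Editorial illustration — not part of the patch above.] To make the high-level "Key-Value based API" tier described above concrete, here is a minimal usage sketch assembled from the docstrings and the `tutorial/02_kv_interface.py` / e2e tests that appear elsewhere in this series. It assumes `tq.init()` brings up a default local controller/storage deployment; the partition name `"train"` and the sample keys are illustrative placeholders, not fixed API values.

```python
import torch
from tensordict import TensorDict

import transfer_queue as tq

tq.init()  # assumed default local setup; deployment options may differ

# Write a single sample; plain dict fields are auto-converted to a TensorDict
# (tensor values gain a leading batch dimension of 1)
tq.kv_put(
    key="sample_0",
    partition_id="train",
    fields={"input_ids": torch.tensor([1, 2, 3, 4])},
    tag={"version": 1},
)

# Write several samples at once; the TensorDict batch size must match len(keys)
tq.kv_batch_put(
    keys=["sample_1", "sample_2"],
    partition_id="train",
    fields=TensorDict({"input_ids": torch.tensor([[5, 6], [7, 8]])}, batch_size=2),
    tags=[{"version": 1}, {"version": 1}],
)

# Read back only selected fields (column selection) for a subset of keys
batch = tq.kv_batch_get(
    keys=["sample_1", "sample_2"],
    partition_id="train",
    fields="input_ids",
)

# Remove the samples when they are no longer needed, then shut down the client
tq.kv_clear(keys=["sample_0", "sample_1", "sample_2"], partition_id="train")
tq.close()
```

`kv_list(partition_id=...)` rounds out the interface by reporting the stored keys together with their tags; note that PATCH 24/34 later in this series appears to revise its return value to group keys and tags per partition.
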

🔥 Showcases

### Collocated Example @@ -131,7 +161,7 @@ You may refer to the [recipe](https://github.com/Ascend/TransferQueue/tree/dev/r ### Disaggregated Example -We have implemented a series of PRs ([#4](https://github.com/Ascend/TransferQueue/pull/4), [#7](https://github.com/Ascend/TransferQueue/pull/7), [#9](https://github.com/Ascend/TransferQueue/pull/9), [#16](https://github.com/Ascend/TransferQueue/pull/16)) to establish a **standardized, fully-streamed distributed** workflow via TransferQueue. +We have experimentally implemented a **standardized, fully-streamed distributed** workflow via TransferQueue. By leveraging the `RankAwareSampler` and `StreamingDataLoader` interfaces, we achieve a **streamlined micro-batch-level producer-consumer pipeline**. This design eliminates the need to manually determine data dispatching logic across varying parallelism strategies—a typical complexity in the single-controller paradigm—thereby greatly simplifying framework design. @@ -186,7 +216,7 @@ pip install TransferQueue

-> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community! +> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorage` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community! For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#). @@ -250,7 +280,7 @@ batch_meta = client.get_meta( ) ``` -**Refer to [tutorial/04_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_custom_sampler.py) for more details.** +**Refer to [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py) for more details.** ### How to integrate a new storage backend @@ -299,21 +329,6 @@ pip install pre-commit pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always ``` -

🛣️ Roadmap

- -- [x] Support data rewrite for partial rollout & agentic post-training -- [x] Provide a general storage abstraction layer `TransferQueueStorageManager` to manage distributed storage units, which simplifies `Client` design and makes it possible to introduce different storage backends ([PR#66](https://github.com/TransferQueue/TransferQueue/pull/66), [issue#72](https://github.com/TransferQueue/TransferQueue/issues/72)) -- [x] Implement `AsyncSimpleStorageManager` as the default storage backend based on the `TransferQueueStorageManager` abstraction -- [x] Provide a `KVStorageManager` to cover all the KV based storage backends ([PR#96](https://github.com/TransferQueue/TransferQueue/pull/96)) -- [x] Support topic-based data partitioning to maintain train/val/test data simultaneously ([PR#98](https://github.com/TransferQueue/TransferQueue/pull/98)) -- [x] Release the first stable version through PyPI -- [ ] Support disaggregated framework (each rank retrieves its own data without going through a centralized node) -- [ ] Provide a `StreamingDataLoader` interface for disaggregated framework -- [ ] Support load-balancing and dynamic batching -- [x] Support high-performance storage backends for RDMA transmission (e.g., [Mooncake Store](https://github.com/kvcache-ai/Mooncake), [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html)...) -- [x] High-performance serialization and deserialization -- [ ] More documentation, examples and tutorials -

📑 Citation

Please kindly cite our paper if you find this repo is useful: diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index 9f8f3a7..51c994b 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -200,7 +200,7 @@ def main(): Key Methods: 1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag 2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch - 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields) + 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields) 4. (async_)kv_list - List keys and tags (metadata) in a partition 5. (async_)kv_clear - Remove key-value pairs from storage @@ -216,9 +216,9 @@ def main(): - Integration with external ReplayBuffer/single-controller that manage sample dispatching Limitations (vs low-level native APIs): - - No built-in production/consumption tracking: Users have to manually check status through tags. + - No built-in production/consumption tracking: Users must manually check status via tags externally. - No built-in Sampler support: Must implement data dispatch by ReplayBuffer or single-controller externally. - - No fully streaming: Consumers must wait for single-controller to dispatch `keys`. + - Not fully streaming: Consumers must wait for single-controller to dispatch `keys`. """ ) ) From 8e80d6df97961c97f0fe6ec31b6dc7acbde1b0be Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 8 Feb 2026 16:11:13 +0800 Subject: [PATCH 20/34] fix comments Signed-off-by: 0oshowero0 --- tests/test_controller.py | 4 +- transfer_queue/client.py | 6 +-- transfer_queue/interface.py | 18 ++++----- transfer_queue/utils/common.py | 68 +--------------------------------- 4 files changed, 15 insertions(+), 81 deletions(-) diff --git a/tests/test_controller.py b/tests/test_controller.py index f3a5c26..243bc19 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -922,8 +922,8 @@ def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup): # Verify custom_meta is preserved all_custom_meta = retrieved_metadata.get_all_custom_meta() assert len(all_custom_meta) == 2 - assert all_custom_meta[metadata.global_indexes[0]]["score"] == 0.9 - assert all_custom_meta[metadata.global_indexes[1]]["tag"] == "B" + assert all_custom_meta[0]["score"] == 0.9 + assert all_custom_meta[1]["tag"] == "B" print("✓ kv_retrieve_keys preserves custom_meta") diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 2af07a7..302215e 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -943,7 +943,7 @@ async def async_kv_retrieve_keys( raise ValueError("Received an empty list as keys.") # validate all the elements are str if not all(isinstance(k, str) for k in keys): - raise TypeError("Not all element in `keys` are strings.") + raise TypeError("Not all elements in `keys` are strings.") else: raise TypeError("Only string or list of strings are allowed as `keys`.") @@ -1028,7 +1028,7 @@ async def async_kv_list( f"{response_msg.body.get('message', 'Unknown error')}" ) except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e + raise RuntimeError(f"[{self.client_id}]: Error in kv_list: {str(e)}") from e def close(self) -> None: """Close the client and cleanup resources including storage manager.""" @@ -1475,7 +1475,7 @@ def kv_retrieve_keys( def kv_list( self, partition_id: str, - ) -> tuple[list[Optional[str]], 
list[Optional[dict]]]: + ) -> tuple[list[str], list[dict]]: """Synchronously retrieve keys and custom_meta from the controller for partition. Args: diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 49b706f..905566e 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -717,13 +717,13 @@ def kv_put( if isinstance(fields, dict): # TODO: consider whether to support this... batch = {} - for key, value in fields.items(): + for field_name, value in fields.items(): if isinstance(value, torch.Tensor): if value.is_nested: raise ValueError("Please use (async)kv_batch_put for batch operation") - batch[key] = value.unsqueeze(0) + batch[field_name] = value.unsqueeze(0) else: - batch[key] = NonTensorStack(value) + batch[field_name] = NonTensorStack(value) fields = TensorDict(batch, batch_size=[1]) elif not isinstance(fields, TensorDict): raise ValueError("field can only be dict or TensorDict") @@ -768,7 +768,7 @@ def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: l if fields is None and tags is None: raise ValueError("Please provide at least one parameter of fields or tag.") - if fields.batch_size[0] != len(keys): + if fields is not None and fields.batch_size[0] != len(keys): raise ValueError( f"`keys` with length {len(keys)} does not match the `fields` TensorDict with " f"batch_size {fields.batch_size[0]}" @@ -846,7 +846,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list return data -def kv_list(partition_id: str) -> tuple[list[Optional[str]], list[Optional[dict[str, Any]]]]: +def kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]: """List all keys and their metadata in a partition. Args: @@ -956,13 +956,13 @@ async def async_kv_put( if isinstance(fields, dict): # TODO: consider whether to support this... 
batch = {} - for key, value in fields.items(): + for field_name, value in fields.items(): if isinstance(value, torch.Tensor): if value.is_nested: raise ValueError("Please use (async)kv_batch_put for batch operation") - batch[key] = value.unsqueeze(0) + batch[field_name] = value.unsqueeze(0) else: - batch[key] = NonTensorStack(value) + batch[field_name] = NonTensorStack(value) fields = TensorDict(batch, batch_size=[1]) elif not isinstance(fields, TensorDict): raise ValueError("field can only be dict or TensorDict") @@ -1008,7 +1008,7 @@ async def async_kv_batch_put( if fields is None and tags is None: raise ValueError("Please provide at least one parameter of fields or tag.") - if fields.batch_size[0] != len(keys): + if fields is not None and fields.batch_size[0] != len(keys): raise ValueError( f"`keys` with length {len(keys)} does not match the `fields` TensorDict with " f"batch_size {fields.batch_size[0]}" diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py index 5192e9f..e25f6b0 100644 --- a/transfer_queue/utils/common.py +++ b/transfer_queue/utils/common.py @@ -16,13 +16,11 @@ import logging import os from contextlib import contextmanager -from typing import Any, Optional +from typing import Optional -import numpy as np import psutil import ray import torch -from tensordict import NonTensorStack, TensorDict logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -100,67 +98,3 @@ def get_env_bool(env_key: str, default: bool = False) -> bool: true_values = {"true", "1", "yes", "y", "on"} return env_value_lower in true_values - - -def dict_to_tensordict(data: dict[str, Any]) -> TensorDict: - """ - Create a TensorDict from a dict of tensors and non_tensors. - """ - - batch = {} - - final_batch_size = None - tensor_batch_size = None - deterministic_tensor_batch_size = None - deterministic_non_tensor_batch_size = None - - for key, val in data.items(): - if isinstance(val, torch.Tensor): - if val.is_nested and val.layout == torch.strided: - # must use unbind for strided nested tensor - deterministic_tensor_batch_size = len(val.unbind()) - else: - tensor_batch_size = val.shape[0] - batch[key] = val - elif isinstance(val, np.ndarray): - batch[key] = val - tensor_batch_size = val.shape[0] - elif isinstance(val, str): - batch[key] = val - deterministic_non_tensor_batch_size = 1 - elif isinstance(val, list): - batch[key] = NonTensorStack(*val) - deterministic_non_tensor_batch_size = len(val) - else: - batch[key] = NonTensorStack(val) - deterministic_non_tensor_batch_size = 1 - - if deterministic_tensor_batch_size: - if deterministic_non_tensor_batch_size: - assert deterministic_non_tensor_batch_size == deterministic_tensor_batch_size - if final_batch_size: - assert final_batch_size == deterministic_tensor_batch_size - else: - final_batch_size = deterministic_tensor_batch_size - - if deterministic_non_tensor_batch_size: - if deterministic_tensor_batch_size: - assert deterministic_non_tensor_batch_size == tensor_batch_size - if final_batch_size: - assert final_batch_size == deterministic_non_tensor_batch_size - else: - final_batch_size = deterministic_non_tensor_batch_size - - if not final_batch_size: - raise RuntimeError("Cannot correctly determine batch_size for input.") - - if tensor_batch_size: - if tensor_batch_size != final_batch_size: - assert final_batch_size == 1 - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - batch[k] = v.unsqueeze(0) - elif isinstance(v, np.ndarray): - batch[k] = np.expand_dims(v, 0) 
- - return TensorDict(batch, batch_size=[final_batch_size]) From b6b7d4b51c11d9b9f7a7728ea4b4daf6d8bad337 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 8 Feb 2026 17:07:41 +0800 Subject: [PATCH 21/34] fix comment Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 538ba21..03e3631 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -545,6 +545,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te ) if mask: + if partition_global_index.numel() == 0: + empty_status = self.consumption_status[task_name].new_zeros(0) + return partition_global_index, empty_status with self.data_status_lock: self.ensure_samples_capacity(max(partition_global_index) + 1) consumption_status = self.consumption_status[task_name][partition_global_index] From d3dd0147523268d26b54e74ac358d12396fb2163 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 8 Feb 2026 17:16:43 +0800 Subject: [PATCH 22/34] fix more minor comments Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 8 ++++---- transfer_queue/interface.py | 34 +++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 302215e..2a0f3f1 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -270,7 +270,7 @@ async def async_set_custom_meta( Example: >>> # Create batch with custom metadata >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) >>> asyncio.run(client.async_set_custom_meta(batch_meta)) """ assert socket is not None @@ -1193,8 +1193,8 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: Example: >>> # Create batch with custom metadata - >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=2, ...) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) >>> client.set_custom_meta(batch_meta) """ @@ -1271,7 +1271,7 @@ def get_data(self, metadata: BatchMeta) -> TensorDict: - Requested data fields (e.g., "prompts", "attention_mask") Example: - >>> batch_meta = client.get_data( + >>> batch_meta = client.get_meta( ... data_fields=["prompts", "attention_mask"], ... batch_size=4, ... partition_id="train_0", diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 905566e..75fedbd 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -282,8 +282,8 @@ def set_custom_meta(metadata: BatchMeta) -> None: >>> tq.init() >>> >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) >>> tq.set_custom_meta(batch_meta) """ tq_client = _maybe_create_transferqueue_client() @@ -367,7 +367,7 @@ def get_data(metadata: BatchMeta) -> TensorDict: >>> import transfer_queue as tq >>> tq.init() >>> - >>> batch_meta = tq.get_data( + >>> batch_meta = tq.get_meta( ... 
data_fields=["prompts", "attention_mask"], ... batch_size=4, ... partition_id="train_0", @@ -496,8 +496,8 @@ async def async_set_custom_meta( >>> tq.init() >>> >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) >>> asyncio.run(tq.async_set_custom_meta(batch_meta)) """ tq_client = _maybe_create_transferqueue_client() @@ -664,7 +664,10 @@ def close(): # ==================== KV Interface API ==================== def kv_put( - key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]] + key: str, + partition_id: str, + fields: Optional[TensorDict | dict[str, Any]] = None, + tag: Optional[dict[str, Any]] = None, ) -> None: """Put a single key-value pair to TransferQueue. @@ -735,7 +738,9 @@ def kv_put( tq_client.set_custom_meta(batch_meta) -def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]]) -> None: +def kv_batch_put( + keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None +) -> None: """Put multiple key-value pairs to TransferQueue in batch. This method stores multiple key-value pairs in a single operation, which is more @@ -819,7 +824,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list >>> import transfer_queue as tq >>> tq.init() >>> # Get single key with all fields - >>> data = tq.kv_batch_get(key="sample_1", partition_id="train") + >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train") >>> # Get multiple keys with specific fields >>> data = tq.kv_batch_get( ... keys=["sample_1", "sample_2"], @@ -885,7 +890,7 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None: >>> import transfer_queue as tq >>> tq.init() >>> # Clear single key - >>> tq.kv_clear(key="sample_1", partition_id="train") + >>> tq.kv_clear(keys="sample_1", partition_id="train") >>> # Clear multiple keys >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train") """ @@ -902,7 +907,10 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None: # ==================== KV Interface API ==================== async def async_kv_put( - key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]] + key: str, + partition_id: str, + fields: Optional[TensorDict | dict[str, Any]] = None, + tag: Optional[dict[str, Any]] = None, ) -> None: """Asynchronously put a single key-value pair to TransferQueue. @@ -975,7 +983,7 @@ async def async_kv_put( async def async_kv_batch_put( - keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]] + keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None ) -> None: """Asynchronously put multiple key-value pairs to TransferQueue in batch. @@ -1061,7 +1069,7 @@ async def async_kv_batch_get( >>> import transfer_queue as tq >>> tq.init() >>> # Get single key with all fields - >>> data = await tq.async_kv_batch_get(key="sample_1", partition_id="train") + >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train") >>> # Get multiple keys with specific fields >>> data = await tq.async_kv_batch_get( ... 
keys=["sample_1", "sample_2"], @@ -1127,7 +1135,7 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: >>> import transfer_queue as tq >>> tq.init() >>> # Clear single key - >>> await tq.async_kv_clear(key="sample_1", partition_id="train") + >>> await tq.async_kv_clear(keys="sample_1", partition_id="train") >>> # Clear multiple keys >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train") """ From 1eda4e931ffe0cef2874c431a9cdf0292db8b3de Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 8 Feb 2026 18:53:44 +0800 Subject: [PATCH 23/34] fix comments and add FIXME comments Signed-off-by: 0oshowero0 --- README.md | 2 +- transfer_queue/client.py | 8 ++++---- transfer_queue/interface.py | 8 ++++---- transfer_queue/storage/managers/base.py | 4 +++- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ba9e54d..150c666 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo

🔄 Updates

- **Feb 8, 2026**: 🔥 The initialization and usage is greatly simplified by high-level APIs [PR#26](https://github.com/Ascend/TransferQueue/pull/26), [PR#28](https://github.com/Ascend/TransferQueue/pull/28). You can now use a Redis-style API to take advantage of most of the advanced features provided by TransferQueue! - - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details. + - **Jan 28, 2026**: We experimentally introduce `StreamingDataLoader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details. - **Dec 30, 2025**: **TransferQueue x verl** integration is tested with the DAPO algorithm at scale **(64 nodes, 1024 cards)**. It significantly optimizes host memory utilization and accelerates data transfers. Stay tuned for more details! - **Dec 20, 2025**: 🔥 The official [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) is released! Feel free to check it out. - **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data. diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 2a0f3f1..470d57d 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -260,8 +260,8 @@ async def async_set_custom_meta( Args: metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. + The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. socket: ZMQ async socket for message transmission (injected by decorator) Raises: @@ -1185,8 +1185,8 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: Args: metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. + The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. Raises: RuntimeError: If communication fails or controller returns error response diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 75fedbd..42e8839 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -271,8 +271,8 @@ def set_custom_meta(metadata: BatchMeta) -> None: Args: metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. + The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. Raises: RuntimeError: If communication fails or controller returns error response @@ -484,8 +484,8 @@ async def async_set_custom_meta( Args: metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. 
+ The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. socket: ZMQ async socket for message transmission (injected by decorator) Raises: diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 258dd91..c0c5058 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -589,7 +589,9 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: for global_idx in metadata.global_indexes: per_field_custom_backend_meta[global_idx] = {} - # TODO(tianyi): the order of custom meta is coupled with keys/values + # FIXME(tianyi): the order of custom backend meta is coupled with keys/values + # FIXME: if put_data is called to partially update/add new fields, the current + # implementation will cause custom_backend_meta losses or mismatch! for (field_name, global_idx), meta_value in zip( itertools.product(sorted(metadata.field_names), metadata.global_indexes), custom_backend_meta, From 620edfe0bd0922aa6d500b04d90d97e34f6c03bb Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 13:00:34 +0800 Subject: [PATCH 24/34] improve kv_list Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 97 ++++++++++++++++++++++++------ tests/test_client.py | 48 ++++++--------- transfer_queue/client.py | 57 ++++++++++++------ transfer_queue/controller.py | 28 ++++++--- transfer_queue/interface.py | 77 +++++++++++++++++------- tutorial/02_kv_interface.py | 14 +++-- 6 files changed, 218 insertions(+), 103 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 3918bbe..71e9833 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -69,7 +69,9 @@ def cleanup_partition(controller): """Cleanup partition after each test.""" yield try: - ray.get(controller.clear_partition.remote("test_partition")) + partition_ids = ray.get(controller.list_partitions.remote()) + for partition_id in partition_ids: + ray.get(controller.clear_partition.remote(partition_id)) except Exception: pass @@ -409,8 +411,8 @@ def test_kv_batch_get_nonexistent_key(self, controller): class TestKVListE2E: """End-to-end tests for kv_list functionality.""" - def test_kv_list_all_keys(self, controller): - """Test listing all keys in a partition.""" + def test_kv_list_single_partition(self, controller): + """Test listing all keys and tags in single partition.""" partition_id = "test_partition" keys = ["list_0", "list_1", "list_2"] @@ -418,24 +420,81 @@ def test_kv_list_all_keys(self, controller): tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i}) # List all keys - listed_keys, tags = tq.kv_list(partition_id=partition_id) + partition_info = tq.kv_list(partition_id=partition_id) - assert len(listed_keys) == 3 + assert len(partition_info.keys()) == 1 + assert "test_partition" in partition_info.keys() + assert len(partition_info["test_partition"]) == 3 for key in keys: - assert key in listed_keys + assert key in partition_info["test_partition"] # Verify tags match - for i, (key, tag) in enumerate(zip(listed_keys, tags, strict=False)): + for i, (key, tag) in enumerate(partition_info["test_partition"].items()): assert tag["id"] == i + def test_kv_list_all_partitions(self, controller): + """Test listing keys and tags in all partitions.""" + partition_id = ["test_partition0", "test_partition1", "test_partition2"] + + keys_partition0 = ["list_0", "list_1", "list_2"] + 
keys_partition1 = ["list_0", "list_1", "list_2"] # deliberately set same keys + keys_partition2 = ["list_3", "list_4", "list_5", "list_6"] + + fields_partition0 = TensorDict({"data": torch.tensor([[0], [1], [2]])}, batch_size=3) + fields_partition1 = TensorDict({"data": torch.tensor([[3], [4], [5]])}, batch_size=3) + fields_partition2 = TensorDict({"data": torch.tensor([[6], [7], [8], [9]])}, batch_size=4) + + tags_partition0 = [{"id": i} for i in range(3)] + tags_partition1 = [{"id": i + 3} for i in range(3)] + tags_partition2 = [{"id": i + 6} for i in range(4)] + + # Put to TQ + tq.kv_batch_put( + keys=keys_partition0, partition_id=partition_id[0], fields=fields_partition0, tags=tags_partition0 + ) + tq.kv_batch_put( + keys=keys_partition1, partition_id=partition_id[1], fields=fields_partition1, tags=tags_partition1 + ) + tq.kv_batch_put( + keys=keys_partition2, partition_id=partition_id[2], fields=fields_partition2, tags=tags_partition2 + ) + + # List all keys + partition_info = tq.kv_list() + + # Verify all partitions are exist + assert len(partition_info.keys()) == 3 + assert "test_partition0" in partition_info.keys() + assert "test_partition1" in partition_info.keys() + assert "test_partition2" in partition_info.keys() + + assert len(partition_info["test_partition0"]) == 3 + for key in keys_partition0: + assert key in partition_info["test_partition0"] + + assert len(partition_info["test_partition1"]) == 3 + for key in keys_partition1: + assert key in partition_info["test_partition1"] + + assert len(partition_info["test_partition2"]) == 4 + for key in keys_partition2: + assert key in partition_info["test_partition2"] + + # Verify tags match + for i, (key, tag) in enumerate(partition_info["test_partition0"].items()): + assert tag["id"] == i + for i, (key, tag) in enumerate(partition_info["test_partition1"].items()): + assert tag["id"] == i + 3 + for i, (key, tag) in enumerate(partition_info["test_partition2"].items()): + assert tag["id"] == i + 6 + def test_kv_list_empty_partition(self): """Test listing empty partition.""" partition_id = "test_partition_empty" - keys, tags = tq.kv_list(partition_id=partition_id) + partition_info = tq.kv_list(partition_id=partition_id) - assert len(keys) == 0 - assert len(tags) == 0 + assert len(partition_info) == 0 class TestKVClearE2E: @@ -454,9 +513,9 @@ def test_kv_clear_single_key(self, controller): tq.kv_clear(keys=key, partition_id=partition_id) # Verify via kv_list - listed_keys, _ = tq.kv_list(partition_id=partition_id) - assert key not in listed_keys - assert other_key in listed_keys + partition_info = tq.kv_list(partition_id=partition_id) + assert key not in partition_info[partition_id] + assert other_key in partition_info[partition_id] # Verify via controller - key should be removed partition = get_controller_partition(controller, partition_id) @@ -475,12 +534,12 @@ def test_kv_clear_multiple_keys(self, controller): tq.kv_clear(keys=keys[:2], partition_id=partition_id) # Verify - listed_keys, _ = tq.kv_list(partition_id=partition_id) - assert len(listed_keys) == 2 - assert keys[0] not in listed_keys - assert keys[1] not in listed_keys - assert keys[2] in listed_keys - assert keys[3] in listed_keys + partition_info = tq.kv_list(partition_id=partition_id) + assert len(partition_info[partition_id]) == 2 + assert keys[0] not in partition_info[partition_id] + assert keys[1] not in partition_info[partition_id] + assert keys[2] in partition_info[partition_id] + assert keys[3] in partition_info[partition_id] class TestKVE2ECornerCases: diff --git 
a/tests/test_client.py b/tests/test_client.py index 5d3360f..5d308d8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -239,7 +239,7 @@ def _mock_kv_retrieve_keys(self, request_body): def _mock_kv_list(self, request_body): """Mock KV list response.""" - partition_id = request_body.get("partition_id", "") + partition_id = request_body.get("partition_id", None) # Initialize key tracking if not exists if not hasattr(self, "_kv_partition_keys"): @@ -247,7 +247,8 @@ def _mock_kv_list(self, request_body): # Return cached keys for this partition keys = self._kv_partition_keys.get(partition_id, []) - return {"keys": keys, "custom_meta": [{} for _ in range(len(keys))]} + + return {"partition_info": {partition_id: {k: {} for k in keys}}, "message": "success"} def _get_next_kv_index(self, partition_id): """Get next available index for KV keys in partition.""" @@ -1041,18 +1042,6 @@ async def test_async_kv_retrieve_keys_invalid_keys_type(self, client_setup): create=True, ) - @pytest.mark.asyncio - async def test_async_kv_list_empty_partition(self, client_setup): - """Test async_kv_list with empty partition.""" - client, _, _ = client_setup - - # Test async_kv_list with empty partition - keys, custom_meta = await client.async_kv_list(partition_id="empty_partition") - - # Should return empty list for partition with no keys - assert keys == [] - assert custom_meta == [] - @pytest.mark.asyncio async def test_async_kv_list_with_keys(self, client_setup): """Test async_kv_list returns keys after they are registered.""" @@ -1066,14 +1055,13 @@ async def test_async_kv_list_with_keys(self, client_setup): ) # Then list them - keys, custom_meta = await client.async_kv_list(partition_id="kv_partition") + partition_info = await client.async_kv_list(partition_id="kv_partition") # Verify keys are returned - assert keys is not None - assert len(keys) >= 2 - assert "key_1" in keys - assert "key_2" in keys - assert custom_meta == [{}, {}] + assert len(partition_info["kv_partition"]) >= 2 + assert "key_1" in partition_info["kv_partition"] + assert "key_2" in partition_info["kv_partition"] + assert list(partition_info["kv_partition"].values()) == [{}, {}] @pytest.mark.asyncio async def test_async_kv_list_multiple_partitions(self, client_setup): @@ -1093,16 +1081,20 @@ async def test_async_kv_list_multiple_partitions(self, client_setup): ) # List keys for each partition - keys_a, custom_meta_a = await client.async_kv_list(partition_id="partition_a") - keys_b, custom_meta_b = await client.async_kv_list(partition_id="partition_b") + partition_a = await client.async_kv_list(partition_id="partition_a") + partition_b = await client.async_kv_list(partition_id="partition_b") # Verify keys are isolated per partition - assert "partition_a_key" in keys_a - assert "partition_b_key" not in keys_a - assert "partition_b_key" in keys_b - assert "partition_a_key" not in keys_b - assert custom_meta_a == [{}] - assert custom_meta_b == [{}] + assert "partition_a" in partition_a + assert "partition_b" in partition_b + assert "partition_a" not in partition_b + assert "partition_b" not in partition_a + assert "partition_a_key" in partition_a["partition_a"] + assert "partition_b_key" not in partition_a["partition_a"] + assert "partition_b_key" in partition_b["partition_b"] + assert "partition_a_key" not in partition_b["partition_b"] + assert list(partition_a["partition_a"].values()) == [{}] + assert list(partition_b["partition_b"].values()) == [{}] def test_kv_retrieve_keys_type_validation(self, client_setup): """Test synchronous 
kv_retrieve_keys type validation.""" diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 470d57d..5b0c1e1 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -983,23 +983,32 @@ async def async_kv_retrieve_keys( @dynamic_socket(socket_name="request_handle_socket") async def async_kv_list( self, - partition_id: str, + partition_id: Optional[str] = None, socket: Optional[zmq.asyncio.Socket] = None, - ) -> tuple[list[str], list[dict]]: - """Asynchronously retrieve keys and custom_meta from the controller for partition. + ) -> dict[str, dict[str, Any]]: + """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions. Args: - partition_id: Partition to retrieve from the controller + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. socket: ZMQ socket (injected by decorator) Returns: - keys: list of keys in the partition - custom_meta: list of dict for custom_meta + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } """ - if partition_id is None: - return [], [] - request_msg = ZMQMessage.create( request_type=ZMQRequestType.KV_LIST, # type: ignore[arg-type] sender_id=self.client_id, @@ -1019,9 +1028,8 @@ async def async_kv_list( ) if response_msg.request_type == ZMQRequestType.KV_LIST_RESPONSE: - keys = response_msg.body.get("keys", []) - custom_meta = response_msg.body.get("custom_meta", []) - return keys, custom_meta + partition_info = response_msg.body.get("partition_info", {}) + return partition_info else: raise RuntimeError( f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: " @@ -1474,16 +1482,29 @@ def kv_retrieve_keys( def kv_list( self, - partition_id: str, - ) -> tuple[list[str], list[dict]]: - """Synchronously retrieve keys and custom_meta from the controller for partition. + partition_id: Optional[str] = None, + ) -> dict[str, dict[str, Any]]: + """Synchronously retrieve keys and custom_meta from the controller for one or all partitions. Args: - partition_id: Partition to retrieve from the controller + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. + socket: ZMQ socket (injected by decorator) Returns: - keys: list of keys in the partition - custom_meta: list of dict for custom_meta + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } """ return self._kv_list(partition_id=partition_id) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 03e3631..3dcc1b5 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1843,22 +1843,30 @@ def _process_request(self): with perf_monitor.measure(op_type="KV_LIST"): params = request_msg.body partition_id = params["partition_id"] - partition = self._get_partition(partition_id) - if not partition: - keys = [] - custom_meta = [] - message = f"Partition {partition_id} not found for kv_list." 
- logger.debug(f"[{self.controller_id}]: {message}") + if partition_id is None: + partition_id = list(self.partitions.keys()) else: - keys = list(partition.keys_mapping.keys()) - custom_meta = [partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys] - message = "Success" + partition_id = [partition_id] + + message = "success" + partition_info = {} + for pid in partition_id: + partition = self._get_partition(pid) + if partition: + keys = list(partition.keys_mapping.keys()) + single_partition_info = { + k: partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys + } + partition_info[pid] = single_partition_info + else: + # this only happens when params["partition_id"] is not None + message = f"partition {pid} does not exist" response_msg = ZMQMessage.create( request_type=ZMQRequestType.KV_LIST_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, - body={"keys": keys, "custom_meta": custom_meta, "message": message}, + body={"partition_info": partition_info, "message": message}, ) self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 42e8839..32670de 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -851,29 +851,45 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list return data -def kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]: - """List all keys and their metadata in a partition. +def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: + """List all keys and their metadata in one or all partitions. Args: - partition_id: Partition to list keys from + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. Returns: - Tuple of: - - List of keys in the partition - - List of custom metadata (tags) associated with each key + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } Example: >>> import transfer_queue as tq >>> tq.init() - >>> keys, tags = tq.kv_list(partition_id="train") - >>> print(f"Keys: {keys}") - >>> print(f"Tags: {tags}") + >>> # Case 1: Retrieve a specific partition + >>> partitions = tq.kv_list(partition_id="train") + >>> print(f"Keys: {list(partitions['train'].keys())}") + >>> print(f"Tags: {list(partitions['train'].values())}") + >>> # Case 2: Retrieve all partitions + >>> all_partitions = tq.kv_list() + >>> for pid, keys in all_partitions.items(): + >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ tq_client = _maybe_create_transferqueue_client() - keys, custom_meta = tq_client.kv_list(partition_id) + partition_info = tq_client.kv_list(partition_id) - return keys, custom_meta + return partition_info def kv_clear(keys: list[str] | str, partition_id: str) -> None: @@ -1096,29 +1112,46 @@ async def async_kv_batch_get( return data -async def async_kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]: - """Asynchronously list all keys and their metadata in a partition. +async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: + """Asynchronously list all keys and their metadata in one or all partitions. Args: - partition_id: Partition to list keys from + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. 
Returns: - Tuple of: - - List of keys in the partition - - List of custom metadata (tags) associated with each key + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } + Example: >>> import transfer_queue as tq >>> tq.init() - >>> keys, tags = await tq.async_kv_list(partition_id="train") - >>> print(f"Keys: {keys}") - >>> print(f"Tags: {tags}") + >>> # Case 1: Retrieve a specific partition + >>> partitions = await tq.async_kv_list(partition_id="train") + >>> print(f"Keys: {list(partitions['train'].keys())}") + >>> print(f"Tags: {list(partitions['train'].values())}") + >>> # Case 2: Retrieve all partitions + >>> all_partitions = await tq.async_kv_list() + >>> for pid, keys in all_partitions.items(): + >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ tq_client = _maybe_create_transferqueue_client() - keys, custom_meta = await tq_client.async_kv_list(partition_id) + partition_info = await tq_client.async_kv_list(partition_id) - return keys, custom_meta + return partition_info async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index 51c994b..376ebfb 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -156,15 +156,17 @@ def demonstrate_kv_api(): # Step 5: List all keys and tags in a partition print("\n[Step 5] Listing all keys and tags in partition...") - all_keys, all_tags = tq.kv_list(partition_id=partition_id) - print(f" Found {len(all_keys)} keys in partition '{partition_id}':") - for k, t in zip(all_keys, all_tags, strict=False): - print(f" - key='{k}' | tag={t}") + partition_info = tq.kv_list() + print(f" Found {len(partition_info.keys())} partitions: '{list(partition_info.keys())}'") + for pid, keys_and_tags in partition_info.items(): + for k, t in keys_and_tags.items(): + print(f"Partition: {pid}, - key='{k}' | tag={t}") # Step 6: Retrieve specific fields using kv_batch_get print("\n[Step 6] Retrieving specific fields (Column) with kv_batch_get...") print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').") + all_keys = list(partition_info[partition_id].keys()) retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, fields="input_ids") print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.") @@ -183,8 +185,8 @@ def demonstrate_kv_api(): tq.kv_clear(keys=keys_to_clear, partition_id=partition_id) print(f" ✓ Cleared keys: {keys_to_clear}") - remaining_keys, _ = tq.kv_list(partition_id=partition_id) - print(f" Remaining keys in partition: {remaining_keys}") + partition_info_after_clear = tq.kv_list(partition_id=partition_id) + print(f" Remaining keys in partition: {list(partition_info_after_clear[partition_id].keys())}") def main(): From 3678eac0f9a9e54d4293a9f65303e73fcbe01f18 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 15:01:23 +0800 Subject: [PATCH 25/34] provide KVBatchMeta Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 29 +++--- transfer_queue/metadata.py | 202 +++++++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+), 14 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 26f1225..60c8f1d 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -44,7 +44,7 @@ put, set_custom_meta, ) -from .metadata import BatchMeta 
+from .metadata import BatchMeta, KVBatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler from .sampler.rank_aware_sampler import RankAwareSampler @@ -55,18 +55,6 @@ __all__ = [ "init", - "get_meta", - "get_data", - "put", - "set_custom_meta", - "clear_samples", - "clear_partition", - "async_get_meta", - "async_get_data", - "async_put", - "async_set_custom_meta", - "async_clear_samples", - "async_clear_partition", "close", "kv_put", "kv_batch_put", @@ -78,11 +66,24 @@ "async_kv_batch_get", "async_kv_list", "async_kv_clear", + "KVBatchMeta", ] + [ + "get_meta", + "get_data", + "put", + "set_custom_meta", + "clear_samples", + "clear_partition", + "async_get_meta", + "async_get_data", + "async_put", + "async_set_custom_meta", + "async_clear_samples", + "async_clear_partition", + "BatchMeta", "TransferQueueClient", "StreamingDataset", "StreamingDataLoader", - "BatchMeta", "TransferQueueController", "SimpleStorageUnit", "ZMQServerInfo", diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 2abde1b..eb2520c 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -829,3 +829,205 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> ] return all_fields + + +# ==================== KV Interface Metadata ==================== +@dataclass +class KVBatchMeta: + """Records the metadata for KV interface.""" + + # keys of each sample + keys: list[str] = dataclasses.field(default_factory=list) + + # sample-level tags + tags: list[dict] = dataclasses.field(default_factory=list) + + # [optional] partition_id of this batch + partition_id: Optional[str] = None + + # [optional] fields of each sample + fields: list[str] = dataclasses.field(default_factory=list) + + # [optional] external information for batch-level information + extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + """Validate all the variables""" + if len(self.keys) != len(self.tags): + raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}") + if len(self.keys) != len(self.partition_id): + raise ValueError( + f"keys and partition_ids must have same length, but got {len(self.keys)} and {len(self.partition_ids)}" + ) + if len(self.keys) != len(set(self.keys)): + raise ValueError("Got duplicated keys.") + if len(self.fields) != len(set(self.fields)): + raise ValueError("Got duplicated fields.") + + # deepcopy to prevent unexpected behavior after chunk/concat + self.tags = copy.deepcopy(self.tags) + self.extra_info = copy.deepcopy(self.extra_info) + + object.__setattr__(self, "_size", len(self.keys)) + + @property + def size(self) -> int: + """Return the number of samples in this batch""" + return getattr(self, "_size", 0) + + def __len__(self) -> int: + """Return the number of samples in this batch.""" + return len(self.keys) + + def __str__(self): + return f"KVBatchMeta(size={self.size}, field_names={self.fields}, extra_info={self.extra_info})" + + def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": + """ + Select specific keys from this batch. + This will construct a new KVBatchMeta instance containing only the specified keys. + + Args: + keys_to_select (list[str]): List of keys to retain. + + Returns: + KVBatchMeta: A new KVBatchMeta instance containing only the specified keys. + + Raises: + ValueError: If duplicate keys exist in input param `keys_to_select`. 
+ RuntimeError: If `keys_to_select` contains keys that do not exist in this batch. + """ + + if len(set(keys_to_select)) != len(keys_to_select): + raise ValueError("Contain duplicate keys.") + + non_exist_keys = set(keys_to_select) - set(self.keys) + if len(non_exist_keys) > 0: + raise RuntimeError(f"Keys {non_exist_keys} not found in current batch.") + + _keys_to_idx = {key: idx for idx, key in enumerate(self.keys)} + + loc_idx = [_keys_to_idx[k] for k in keys_to_select] + tags = [self.tags[i] for i in loc_idx] + + return KVBatchMeta( + keys=keys_to_select, + tags=tags, + partition_id=self.partition_id, + fields=self.fields, + extra_info=self.extra_info, + ) + + def reorder(self, indexes: list[int]): + """ + Reorder the samples in this batch according to the specified indexes. + + The operation is performed in-place. + + Args: + indexes : list[int] + A list of integers specifying the new order of SampleMeta. + + Raises: + ValueError: If the size of input `indexes` does not match with the batch size. + ValueError: If duplicate indexes exist in input param `indexes`. + """ + if len(indexes) != self.size: + raise ValueError( + f"Attempted to reorder with indexes length {len(indexes)} that does not match " + f"the batch size {self.size}." + ) + + if len(set(indexes)) != len(indexes): + raise ValueError("Contain duplicate indexes.") + + self.keys = [self.keys[i] for i in indexes] + self.tags = [self.tags[i] for i in indexes] + + def chunk(self, chunks: int) -> list["KVBatchMeta"]: + """ + Split this batch into smaller chunks. + + Args: + chunks: number of chunks + + Return: + List of smaller KVBatchMeta chunks + """ + + chunk_list = [] + if self.size < chunks: + logger.warning( + f"Chunk size {chunks} > number of samples in this batch {self.size}, this will return some " + f"empty KVBatchMeta chunks." + ) + + # Calculate the base size and remainder of each chunk + base_size = self.size // chunks + remainder = self.size % chunks + + start = 0 + for i in range(chunks): + # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) + current_chunk_size = base_size + 1 if i < remainder else base_size + end = start + current_chunk_size + chunk_keys = self.keys[start:end] + chunk_tags = self.tags[start:end] + + chunk = KVBatchMeta( + keys=chunk_keys, + tags=chunk_tags, + partition_id=self.partition_id, + fields=self.fields, + extra_info=self.extra_info, + ) + chunk_list.append(chunk) + start = end + + return chunk_list + + @classmethod + def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta": + """ + Concatenate multiple KVBatchMeta chunks into one large batch. + + Args: + data: List of KVBatchMeta chunks to concatenate + + Returns: + Concatenated KVBatchMeta + + Raises: + ValueError: If validation fails (e.g., field names do not match) + """ + if not data: + logger.warning("Try to concat empty KVBatchMeta chunks. Returning empty KVBatchMeta.") + return KVBatchMeta() + + # skip empty chunks + data = [chunk for chunk in data if chunk and chunk.size > 0] + + if len(data) == 0: + logger.warning("No valid KVBatchMeta chunks to concatenate. 
Returning empty KVBatchMeta.") + return KVBatchMeta() + + base_fields = data[0].fields + base_fields_set = set(base_fields) + base_partition_id = data[0].partition_id + + all_keys = [] + all_tags = [] + all_extra_info = {} + for chunk in data: + if set(chunk.fields) != base_fields_set: + raise ValueError("Field names do not match for concatenation.") + if chunk.partition_id != base_partition_id: + raise ValueError("Partition do not match for concatenation.") + + all_keys.extend(chunk.keys) + all_tags.extend(chunk.tags) + all_extra_info.update(chunk.extra_info) + + return KVBatchMeta( + keys=all_keys, tags=all_tags, partition_id=base_partition_id, fields=base_fields, extra_info=all_extra_info + ) From 42a78cf5f2e27bbffe454a0c526ea7bc31620071 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 15:18:38 +0800 Subject: [PATCH 26/34] add KVBatchMeta tests Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 289 ++++++++++++++++++++++++++++++++++++- transfer_queue/metadata.py | 4 - 2 files changed, 288 insertions(+), 5 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 6db10a6..23ba56d 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -27,7 +27,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta, FieldMeta, KVBatchMeta, SampleMeta # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @@ -1045,3 +1045,290 @@ def test_batch_meta_concat_validation_error(self): with pytest.raises(ValueError) as exc_info: BatchMeta.concat([batch1, batch2], validate=True) assert "Field names do not match" in str(exc_info.value) + + +class TestKVBatchMeta: + """KVBatchMeta Tests""" + + def test_kv_batch_meta_basic_init(self): + """Example: Basic KVBatchMeta initialization.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{"sample_id": 0}, {"sample_id": 1}, {"sample_id": 2}], + partition_id="partition_0", + fields=["field1", "field2"], + ) + + assert kv_meta.size == 3 + assert len(kv_meta) == 3 + assert kv_meta.keys == ["key1", "key2", "key3"] + assert kv_meta.partition_id == "partition_0" + assert kv_meta.fields == ["field1", "field2"] + + def test_kv_batch_meta_empty_init(self): + """Example: Empty KVBatchMeta initialization.""" + kv_meta = KVBatchMeta() + + assert kv_meta.size == 0 + assert len(kv_meta) == 0 + assert kv_meta.keys == [] + assert kv_meta.tags == [] + assert kv_meta.partition_id is None + assert kv_meta.fields == [] + + def test_kv_batch_meta_init_validation_keys_tags_mismatch(self): + """Example: Init validation catches keys and tags length mismatch.""" + with pytest.raises(ValueError) as exc_info: + KVBatchMeta( + keys=["key1", "key2"], + tags=[{"sample_id": 0}], # Only one tag + ) + assert "keys and tags must have same length" in str(exc_info.value) + + def test_kv_batch_meta_init_validation_duplicate_keys(self): + """Example: Init validation catches duplicate keys.""" + with pytest.raises(ValueError) as exc_info: + KVBatchMeta( + keys=["key1", "key1"], + tags=[{"sample_id": 0}, {"sample_id": 1}], + partition_id="partition_0", + ) + assert "Got duplicated keys" in str(exc_info.value) + + def test_kv_batch_meta_init_validation_duplicate_fields(self): + """Example: Init validation catches duplicate fields.""" + with pytest.raises(ValueError) as exc_info: + KVBatchMeta( + keys=["key1"], + tags=[{"sample_id": 0}], + 
partition_id="partition_0", + fields=["field1", "field1"], + ) + assert "Got duplicated fields" in str(exc_info.value) + + def test_kv_batch_meta_select_keys(self): + """Example: Select specific keys from KVBatchMeta.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}], + partition_id="partition_0", + fields=["field1", "field2"], + extra_info={"test": "value"}, + ) + + selected = kv_meta.select_keys(["key1", "key3"]) + + assert selected.keys == ["key1", "key3"] + assert selected.tags == [{"idx": 0}, {"idx": 2}] + assert selected.partition_id == "partition_0" + assert selected.fields == ["field1", "field2"] + assert selected.extra_info == {"test": "value"} + + def test_kv_batch_meta_select_keys_validation_duplicate(self): + """Example: Select keys validation catches duplicate keys in input.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(ValueError) as exc_info: + kv_meta.select_keys(["key1", "key1"]) + assert "Contain duplicate keys" in str(exc_info.value) + + def test_kv_batch_meta_select_keys_validation_nonexistent(self): + """Example: Select keys validation catches non-existent keys.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(RuntimeError) as exc_info: + kv_meta.select_keys(["key1", "nonexistent"]) + assert "not found in current batch" in str(exc_info.value) + + def test_kv_batch_meta_reorder(self): + """Example: Reorder samples in KVBatchMeta.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}], + ) + + kv_meta.reorder([2, 0, 1]) + + assert kv_meta.keys == ["key3", "key1", "key2"] + assert kv_meta.tags == [{"idx": 2}, {"idx": 0}, {"idx": 1}] + + def test_kv_batch_meta_reorder_validation_size_mismatch(self): + """Example: Reorder validation catches size mismatch.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(ValueError) as exc_info: + kv_meta.reorder([0, 1]) # Only 2 indexes for 3 samples + assert "does not match" in str(exc_info.value) + + def test_kv_batch_meta_reorder_validation_duplicate_indexes(self): + """Example: Reorder validation catches duplicate indexes.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(ValueError) as exc_info: + kv_meta.reorder([0, 0, 1]) # Duplicate index 0 + assert "Contain duplicate indexes" in str(exc_info.value) + + def test_kv_batch_meta_chunk(self): + """Example: Split KVBatchMeta into multiple chunks.""" + kv_meta = KVBatchMeta( + keys=[f"key{i}" for i in range(10)], + tags=[{"idx": i} for i in range(10)], + partition_id="partition_0", + fields=["field1"], + extra_info={"test": "value"}, + ) + + chunks = kv_meta.chunk(3) + + assert len(chunks) == 3 + assert len(chunks[0]) == 4 # First chunk gets extra element + assert len(chunks[1]) == 3 + assert len(chunks[2]) == 3 + + # Verify partition_id and fields are preserved + assert chunks[0].partition_id == "partition_0" + assert chunks[0].fields == ["field1"] + assert chunks[0].extra_info == {"test": "value"} + + # Verify keys and tags are correctly chunked + assert chunks[0].keys == ["key0", "key1", "key2", "key3"] + assert chunks[0].tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}] + assert chunks[1].keys == ["key4", "key5", "key6"] + assert chunks[1].tags == [{"idx": 4}, {"idx": 5}, {"idx": 6}] + + def 
test_kv_batch_meta_chunk_with_more_chunks_than_samples(self): + """Example: Chunking when chunks > samples produces empty chunks.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2"], + tags=[{"idx": 0}, {"idx": 1}], + ) + + chunks = kv_meta.chunk(5) + + assert len(chunks) == 5 + assert len(chunks[0]) == 1 + assert len(chunks[1]) == 1 + assert len(chunks[2]) == 0 + assert len(chunks[3]) == 0 + assert len(chunks[4]) == 0 + + def test_kv_batch_meta_concat(self): + """Example: Concatenate multiple KVBatchMeta chunks.""" + kv_meta1 = KVBatchMeta( + keys=["key0", "key1"], + tags=[{"idx": 0}, {"idx": 1}], + partition_id="partition_0", + fields=["field1"], + extra_info={"test": "value1"}, + ) + + kv_meta2 = KVBatchMeta( + keys=["key2", "key3"], + tags=[{"idx": 2}, {"idx": 3}], + partition_id="partition_0", + fields=["field1"], + extra_info={"test": "value2"}, + ) + + result = KVBatchMeta.concat([kv_meta1, kv_meta2]) + + assert result.size == 4 + assert result.keys == ["key0", "key1", "key2", "key3"] + assert result.tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}] + assert result.partition_id == "partition_0" + assert result.fields == ["field1"] + + def test_kv_batch_meta_concat_with_empty_chunks(self): + """Example: Concat handles empty KVBatchMeta chunks gracefully.""" + kv_meta1 = KVBatchMeta() + kv_meta2 = KVBatchMeta(keys=["key0"], tags=[{"idx": 0}]) + kv_meta3 = KVBatchMeta() + + result = KVBatchMeta.concat([kv_meta1, kv_meta2, kv_meta3]) + + assert result.size == 1 + assert result.keys == ["key0"] + assert result.tags == [{"idx": 0}] + + def test_kv_batch_meta_concat_validation_field_mismatch(self): + """Example: Concat validation catches field name mismatches.""" + kv_meta1 = KVBatchMeta( + keys=["key0"], + tags=[{}], + fields=["field1"], + ) + kv_meta2 = KVBatchMeta( + keys=["key1"], + tags=[{}], + fields=["field2"], # Different field + ) + + with pytest.raises(ValueError) as exc_info: + KVBatchMeta.concat([kv_meta1, kv_meta2]) + assert "Field names do not match" in str(exc_info.value) + + def test_kv_batch_meta_concat_validation_partition_mismatch(self): + """Example: Concat validation catches partition_id mismatches.""" + kv_meta1 = KVBatchMeta( + keys=["key0"], + tags=[{}], + partition_id="partition_0", + ) + kv_meta2 = KVBatchMeta( + keys=["key1"], + tags=[{}], + partition_id="partition_1", # Different partition + ) + + with pytest.raises(ValueError) as exc_info: + KVBatchMeta.concat([kv_meta1, kv_meta2]) + assert "Partition do not match" in str(exc_info.value) + + def test_kv_batch_meta_concat_empty_list(self): + """Example: Concat with empty list returns empty KVBatchMeta.""" + result = KVBatchMeta.concat([]) + + assert result.size == 0 + assert result.keys == [] + assert result.tags == [] + + def test_kv_batch_meta_deepcopy_tags(self): + """Example: Tags are deep copied to prevent mutation.""" + original_tags = [{"data": [1, 2, 3]}] + kv_meta = KVBatchMeta( + keys=["key1"], + tags=original_tags, + ) + + # Modify the tag in the KVBatchMeta + kv_meta.tags[0]["data"].append(4) + + # Original should not be modified + assert original_tags[0]["data"] == [1, 2, 3] + + def test_kv_batch_meta_deepcopy_extra_info(self): + """Example: Extra info is deep copied to prevent mutation.""" + original_extra = {"nested": {"value": 1}} + kv_meta = KVBatchMeta( + keys=["key1"], + tags=[{}], + extra_info=original_extra, + ) + + # Modify extra_info + kv_meta.extra_info["nested"]["value"] = 999 + + # Original should not be modified + assert original_extra["nested"]["value"] == 1 diff --git 
a/transfer_queue/metadata.py b/transfer_queue/metadata.py index eb2520c..eec0058 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -855,10 +855,6 @@ def __post_init__(self): """Validate all the variables""" if len(self.keys) != len(self.tags): raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}") - if len(self.keys) != len(self.partition_id): - raise ValueError( - f"keys and partition_ids must have same length, but got {len(self.keys)} and {len(self.partition_ids)}" - ) if len(self.keys) != len(set(self.keys)): raise ValueError("Got duplicated keys.") if len(self.fields) != len(set(self.fields)): From b301c740d507607bfcf2184472498a574cbfa66c Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 16:34:43 +0800 Subject: [PATCH 27/34] fix comments Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 2 +- transfer_queue/controller.py | 6 +++--- transfer_queue/interface.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 5b0c1e1..b06f90c 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -812,7 +812,7 @@ async def async_reset_consumption( task_name: Optional[str] = None, socket: Optional[zmq.asyncio.Socket] = None, ) -> bool: - """Aynchronously reset consumption status for a partition. + """Asynchronously reset consumption status for a partition. This allows the same data to be re-consumed, useful for debugging scenarios where the same rollout data needs to be trained multiple times. diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 3dcc1b5..3c7c3f9 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -988,19 +988,19 @@ def list_partitions(self) -> list[str]: # ==================== Partition Index Management API ==================== - def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]: + def get_partition_index_range(self, partition_id: str) -> list[int]: """ Get all indexes for a specific partition. Args: - partition: Partition identifier + partition_id: Partition identifier Returns: List of indexes allocated to the partition """ # Note: This includes the pre-allocated global_indexes for the partition. # i.e., partition.global_indexes + partition.pre_allocated_global_indexes - return self.index_manager.get_indexes_for_partition(partition) + return self.index_manager.get_indexes_for_partition(partition_id) # ==================== Data Production API ==================== diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 32670de..70ce7c2 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -701,7 +701,7 @@ def kv_put( ... ) """ if fields is None and tag is None: - raise ValueError("Please provide at least one parameter of fields or tag.") + raise ValueError("Please provide at least one parameter of `fields` or `tag`.") tq_client = _maybe_create_transferqueue_client() @@ -712,7 +712,7 @@ def kv_put( raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") # 2. register the user-specified tag to BatchMeta - if tag: + if tag is not None: batch_meta.update_custom_meta([tag]) # 3. 
put data @@ -1030,7 +1030,7 @@ async def async_kv_batch_put( """ if fields is None and tags is None: - raise ValueError("Please provide at least one parameter of fields or tag.") + raise ValueError("Please provide at least one parameter of `fields` or `tags`.") if fields is not None and fields.batch_size[0] != len(keys): raise ValueError( @@ -1049,7 +1049,7 @@ async def async_kv_batch_put( ) # 2. register the user-specified tags to BatchMeta - if tags: + if tags is not None: if len(tags) != len(keys): raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") batch_meta.update_custom_meta(tags) From f53cad973867b3fa0f7383b217f23ab91f7569d9 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 17:22:22 +0800 Subject: [PATCH 28/34] change default value to None in KVBatchMeta Signed-off-by: 0oshowero0 --- transfer_queue/metadata.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index eec0058..473a487 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -846,10 +846,10 @@ class KVBatchMeta: partition_id: Optional[str] = None # [optional] fields of each sample - fields: list[str] = dataclasses.field(default_factory=list) + fields: Optional[list[str]] = None # [optional] external information for batch-level information - extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) + extra_info: Optional[dict[str, Any]] = None def __post_init__(self): """Validate all the variables""" @@ -1008,22 +1008,31 @@ def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta": return KVBatchMeta() base_fields = data[0].fields - base_fields_set = set(base_fields) + if base_fields is not None: + base_fields_set = set(base_fields) + else: + base_fields_set = set() + base_partition_id = data[0].partition_id all_keys = [] all_tags = [] all_extra_info = {} for chunk in data: - if set(chunk.fields) != base_fields_set: + if chunk.fields is not None and set(chunk.fields) != base_fields_set: raise ValueError("Field names do not match for concatenation.") if chunk.partition_id != base_partition_id: raise ValueError("Partition do not match for concatenation.") all_keys.extend(chunk.keys) all_tags.extend(chunk.tags) - all_extra_info.update(chunk.extra_info) + if chunk.extra_info is not None: + all_extra_info.update(chunk.extra_info) return KVBatchMeta( - keys=all_keys, tags=all_tags, partition_id=base_partition_id, fields=base_fields, extra_info=all_extra_info + keys=all_keys, + tags=all_tags, + partition_id=base_partition_id, + fields=base_fields, + extra_info=all_extra_info if all_extra_info else None, ) From f9a9830e74cc6797577a0a55b466ef0bc1feb4f9 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 17:28:06 +0800 Subject: [PATCH 29/34] update Signed-off-by: 0oshowero0 --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 150c666..8131b94 100644 --- a/README.md +++ b/README.md @@ -95,10 +95,10 @@ This data structure design is motivated by the computational characteristics of ### User Interface: High-Level & Low-Level APIs | Level | Tier | Style | Fine-Grained Access | Streaming | Sampler | Multiple-Backends | -|---|---|---|---|---|---|---| -| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✓ | ✗ | ✗ | ✓ | -| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ✓ |✓ | ✓ | ✓ | -| Low | **TransferQueueClient** | Metadata-based | ✓ | ✓ | ✓ | ✓ | 
+|---|---|---|---|------------------|---|---| +| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✓ | ○ | ✗ | ✓ | +| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ✓ | ✓ | ✓ | ✓ | +| Low | **TransferQueueClient** | Metadata-based | ✓ | ✓ | ✓ | ✓ | #### Key-Value based API From f49741d853cdf31ca868184dfd6c25c2450207f2 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 17:32:18 +0800 Subject: [PATCH 30/34] change function order Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 1364 +++++++++++++++++------------------ 1 file changed, 681 insertions(+), 683 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 70ce7c2..fb6b5f7 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -107,6 +107,7 @@ def _init_from_existing() -> None: time.sleep(1) +# ==================== Initialization API ==================== def init(conf: Optional[DictConfig] = None) -> None: """Initialize the TransferQueue system. @@ -196,480 +197,306 @@ def init(conf: Optional[DictConfig] = None) -> None: _maybe_create_transferqueue_client(final_conf) -# ==================== Basic API ==================== -def get_meta( - data_fields: list[str], - batch_size: int, - partition_id: str, - mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, -) -> BatchMeta: - """Synchronously fetch data metadata from the controller via ZMQ. - - Args: - data_fields: List of data field names to retrieve metadata for - batch_size: Number of samples to request in the batch - partition_id: Current data partition id - mode: Data fetch mode. Options: - - 'fetch': Get ready data only - - 'force_fetch': Get data regardless of readiness (may return unready samples) - - 'insert': Internal usage - should not be used by users - task_name: Optional task name associated with the request - sampling_config: Optional sampling configuration for custom samplers. - - Returns: - BatchMeta: Metadata object containing data structure, sample information, and readiness status - - Raises: - RuntimeError: If communication fails or controller returns error response - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> # Example 1: Basic fetch metadata - >>> batch_meta = tq.get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences" - ... ) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = tq.get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=8, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... sampling_config={"n_samples_per_prompt": 4} - ... ) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = tq.get_meta( - ... partition_id="train_0", # optional - ... mode="force_fetch", - ... 
) - >>> print(batch_meta.is_ready) # May be False if some samples not ready - """ - - tq_client = _maybe_create_transferqueue_client() - return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) - - -def set_custom_meta(metadata: BatchMeta) -> None: - """Synchronously send custom metadata to the controller. - - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. - - Args: - metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() - before calling this method. +def close(): + """Close the TransferQueue system. - Raises: - RuntimeError: If communication fails or controller returns error response + This function cleans up the TransferQueue system, including: + - Closing the client and its associated resources + - Cleaning up distributed storage (only for the process that initialized it) + - Killing the controller actor - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) - >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) - >>> tq.set_custom_meta(batch_meta) + Note: + This function should be called when the TransferQueue system is no longer needed. """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.set_custom_meta(metadata) + global _TRANSFER_QUEUE_CLIENT + global _TRANSFER_QUEUE_STORAGE + if _TRANSFER_QUEUE_CLIENT: + _TRANSFER_QUEUE_CLIENT.close() + _TRANSFER_QUEUE_CLIENT = None + try: + if _TRANSFER_QUEUE_STORAGE: + # only the process that do first-time init can clean the distributed storage + for storage in _TRANSFER_QUEUE_STORAGE.values(): + ray.kill(storage) + _TRANSFER_QUEUE_STORAGE = None + except Exception: + pass -def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: - """Synchronously write data to storage units based on metadata. + try: + controller = ray.get_actor("TransferQueueController") + ray.kill(controller) + except Exception: + pass - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. +# ==================== High-Level KV Interface API ==================== +def kv_put( + key: str, + partition_id: str, + fields: Optional[TensorDict | dict[str, Any]] = None, + tag: Optional[dict[str, Any]] = None, +) -> None: + """Put a single key-value pair to TransferQueue. - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. + This is a convenience method for putting data using a user-specified key + instead of BatchMeta. Internally, the key is translated to a BatchMeta + and the data is stored using the regular put mechanism. Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. 
- partition_id: Target data partition id (required if metadata is not provided) - - Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) + key: User-specified key for the data sample (in row) + partition_id: Logical partition to store the data in + fields: Data fields to store. Can be a TensorDict or a dict of tensors. + Each key in `fields` will be treated as a column for the data sample. + If dict is provided, tensors will be unsqueezed to add batch dimension. + tag: Optional metadata tag to associate with the key Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = tq.get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", - ... ) - >>> batch = tq.get_data(batch_meta) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> tq.put(data=output, metadata=batch_meta) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id) - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.put(data, metadata, partition_id) - - -def get_data(metadata: BatchMeta) -> TensorDict: - """Synchronously fetch data from storage units and organize into TensorDict. - - Args: - metadata: Batch metadata containing data location information and global indexes - - Returns: - TensorDict containing: - - Requested data fields (e.g., "prompts", "attention_mask") + ValueError: If neither fields nor tag is provided + ValueError: If nested tensors are provided (use kv_batch_put instead) + RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` Example: >>> import transfer_queue as tq + >>> import torch >>> tq.init() - >>> - >>> batch_meta = tq.get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", + >>> # Put with both fields and tag + >>> tq.kv_put( + ... key="sample_1", + ... partition_id="train", + ... fields={"input_ids": torch.tensor([1, 2, 3])}, + ... tag={"score": 0.95} ... 
) - >>> batch = tq.get_data(batch_meta) - >>> print(batch) - >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.get_data(metadata) - - -def clear_partition(partition_id: str): - """Synchronously clear the whole partition from all storage units and the controller. - - Args: - partition_id: The partition id to clear data for - - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.clear_partition(partition_id) - - -def clear_samples(metadata: BatchMeta): - """Synchronously clear specific samples from all storage units and the controller. - - Args: - metadata: The BatchMeta of the corresponding data to be cleared + if fields is None and tag is None: + raise ValueError("Please provide at least one parameter of `fields` or `tag`.") - Raises: - RuntimeError: If clear operation fails - """ tq_client = _maybe_create_transferqueue_client() - return tq_client.clear_samples(metadata) - -async def async_get_meta( - data_fields: list[str], - batch_size: int, - partition_id: str, - mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, -) -> BatchMeta: - """Asynchronously fetch data metadata from the controller via ZMQ. - - Args: - data_fields: List of data field names to retrieve metadata for - batch_size: Number of samples to request in the batch - partition_id: Current data partition id - mode: Data fetch mode. Options: - - 'fetch': Get ready data only - - 'force_fetch': Get data regardless of readiness (may return unready samples) - - 'insert': Internal usage - should not be used by users - task_name: Optional task name associated with the request - sampling_config: Optional sampling configuration for custom samplers. + # 1. translate user-specified key to BatchMeta + batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) - Returns: - BatchMeta: Metadata object containing data structure, sample information, and readiness status + if batch_meta.size != 1: + raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") - Raises: - RuntimeError: If communication fails or controller returns error response + # 2. register the user-specified tag to BatchMeta + if tag is not None: + batch_meta.update_custom_meta([tag]) - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> # Example 1: Basic fetch metadata - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences" - ... )) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=8, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... )) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... partition_id="train_0", # optional - ... 
mode="force_fetch", - ... )) - >>> print(batch_meta.is_ready) # May be False if some samples not ready - """ + # 3. put data + if fields is not None: + if isinstance(fields, dict): + # TODO: consider whether to support this... + batch = {} + for field_name, value in fields.items(): + if isinstance(value, torch.Tensor): + if value.is_nested: + raise ValueError("Please use (async)kv_batch_put for batch operation") + batch[field_name] = value.unsqueeze(0) + else: + batch[field_name] = NonTensorStack(value) + fields = TensorDict(batch, batch_size=[1]) + elif not isinstance(fields, TensorDict): + raise ValueError("field can only be dict or TensorDict") - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + # custom_meta (tag) will be put to controller through the internal put process + tq_client.put(fields, batch_meta) + else: + # directly update custom_meta (tag) to controller + tq_client.set_custom_meta(batch_meta) -async def async_set_custom_meta( - metadata: BatchMeta, +def kv_batch_put( + keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None ) -> None: - """ - Asynchronously send custom metadata to the controller. + """Put multiple key-value pairs to TransferQueue in batch. - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. + This method stores multiple key-value pairs in a single operation, which is more + efficient than calling kv_put multiple times. Args: - metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() - before calling this method. - socket: ZMQ async socket for message transmission (injected by decorator) + keys: List of user-specified keys for the data + partition_id: Logical partition to store the data in + fields: TensorDict containing data for all keys. Must have batch_size == len(keys) + tags: List of metadata tags, one for each key Raises: - RuntimeError: If communication fails or controller returns error response + ValueError: If neither `fields` nor `tags` is provided + ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict + RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` Example: >>> import transfer_queue as tq + >>> from tensordict import TensorDict >>> tq.init() - >>> - >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) - >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) - >>> asyncio.run(tq.async_set_custom_meta(batch_meta)) + >>> keys = ["sample_1", "sample_2", "sample_3"] + >>> fields = TensorDict({ + ... "input_ids": torch.randn(3, 10), + ... "attention_mask": torch.ones(3, 10), + ... 
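
# Illustrative equivalence sketch (not part of this patch): when kv_put receives a
# plain dict, each tensor is assumed to gain a leading batch dimension of 1, so the
# call has roughly the same effect as passing a batch_size=[1] TensorDict yourself.
import torch
from tensordict import TensorDict
import transfer_queue as tq

tq.kv_put(key="s1", partition_id="train", fields={"input_ids": torch.tensor([1, 2, 3])})
# ...is roughly equivalent to:
tq.kv_put(
    key="s1",
    partition_id="train",
    fields=TensorDict({"input_ids": torch.tensor([[1, 2, 3]])}, batch_size=[1]),
)
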
}, batch_size=3) + >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] + >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) """ + + if fields is None and tags is None: + raise ValueError("Please provide at least one parameter of fields or tag.") + + if fields is not None and fields.batch_size[0] != len(keys): + raise ValueError( + f"`keys` with length {len(keys)} does not match the `fields` TensorDict with " + f"batch_size {fields.batch_size[0]}" + ) + tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_set_custom_meta(metadata) + # 1. translate user-specified key to BatchMeta + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) -async def async_put( - data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, -) -> BatchMeta: - """Asynchronously write data to storage units based on metadata. + if batch_meta.size != len(keys): + raise RuntimeError( + f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!" + ) - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. + # 2. register the user-specified tags to BatchMeta + if tags: + if len(tags) != len(keys): + raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") + batch_meta.update_custom_meta(tags) - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. + # 3. put data + if fields is not None: + tq_client.put(fields, batch_meta) + else: + # directly update custom_meta (tags) to controller + tq_client.set_custom_meta(batch_meta) - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. + +def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict: + """Get data from TransferQueue using user-specified keys. + + This is a convenience method for retrieving data using keys instead of indexes. Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. - partition_id: Target data partition id (required if metadata is not provided) + keys: Single key or list of keys to retrieve + partition_id: Partition containing the keys + fields: Optional field(s) to retrieve. If None, retrieves all fields Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) + TensorDict with the requested data Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails + RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq >>> tq.init() - >>> - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", - ... 
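
# Illustrative sketch (not part of this patch): kv_batch_put can also be called with
# `tags` only, to register or refresh per-key custom metadata without writing any fields.
import transfer_queue as tq

tq.kv_batch_put(
    keys=["sample_1", "sample_2"],
    partition_id="train",
    tags=[{"score": 0.9}, {"score": 0.1}],
)
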
)) - >>> batch = asyncio.run(tq.async_get_data(batch_meta)) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta)) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) + >>> # Get single key with all fields + >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train") + >>> # Get multiple keys with specific fields + >>> data = tq.kv_batch_get( + ... keys=["sample_1", "sample_2"], + ... partition_id="train", + ... fields="input_ids" + ... ) """ tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_put(data, metadata, partition_id) + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) -async def async_get_data(metadata: BatchMeta) -> TensorDict: - """Asynchronously fetch data from storage units and organize into TensorDict. + if batch_meta.size == 0: + raise RuntimeError("keys or partition were not found!") + + if fields is not None: + if isinstance(fields, str): + fields = [fields] + batch_meta = batch_meta.select_fields(fields) + + if not batch_meta.is_ready: + raise RuntimeError("Some fields are not ready in all the requested keys!") + + data = tq_client.get_data(batch_meta) + return data + + +def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: + """List all keys and their metadata in one or all partitions. Args: - metadata: Batch metadata containing data location information and global indexes + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. Returns: - TensorDict containing: - - Requested data fields (e.g., "prompts", "attention_mask") + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } Example: >>> import transfer_queue as tq >>> tq.init() - >>> - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... 
)) - >>> batch = asyncio.run(tq.async_get_data(batch_meta)) - >>> print(batch) - >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes + >>> # Case 1: Retrieve a specific partition + >>> partitions = tq.kv_list(partition_id="train") + >>> print(f"Keys: {list(partitions['train'].keys())}") + >>> print(f"Tags: {list(partitions['train'].values())}") + >>> # Case 2: Retrieve all partitions + >>> all_partitions = tq.kv_list() + >>> for pid, keys in all_partitions.items(): + >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_get_data(metadata) - - -# ==================== Data Operations API ==================== + partition_info = tq_client.kv_list(partition_id) -async def async_clear_samples(metadata: BatchMeta): - """Asynchronously clear specific samples from all storage units and the controller. - - Args: - metadata: The BatchMeta of the corresponding data to be cleared + return partition_info - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_clear_samples(metadata) +def kv_clear(keys: list[str] | str, partition_id: str) -> None: + """Clear key-value pairs from TransferQueue. -async def async_clear_partition(partition_id: str): - """Asynchronously clear the whole partition from all storage units and the controller. + This removes the specified keys and their associated data from both + the controller and storage units. Args: - partition_id: The partition id to clear data for + keys: Single key or list of keys to clear + partition_id: Partition containing the keys - Raises: - RuntimeError: If clear operation fails + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # Clear single key + >>> tq.kv_clear(keys="sample_1", partition_id="train") + >>> # Clear multiple keys + >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train") """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_clear_partition(partition_id) - - -def close(): - """Close the TransferQueue system. - This function cleans up the TransferQueue system, including: - - Closing the client and its associated resources - - Cleaning up distributed storage (only for the process that initialized it) - - Killing the controller actor - - Note: - This function should be called when the TransferQueue system is no longer needed. 
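
# Illustrative housekeeping sketch (not part of this patch): list the keys of a
# partition with kv_list() and then drop unwanted ones with kv_clear(). This assumes
# the samples were tagged with a "score" entry; any tag name works the same way.
import transfer_queue as tq

keys_by_partition = tq.kv_list(partition_id="train")
stale = [k for k, tag in keys_by_partition["train"].items() if tag.get("score", 1.0) < 0.5]
if stale:
    tq.kv_clear(keys=stale, partition_id="train")
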
- """ - global _TRANSFER_QUEUE_CLIENT - global _TRANSFER_QUEUE_STORAGE - if _TRANSFER_QUEUE_CLIENT: - _TRANSFER_QUEUE_CLIENT.close() - _TRANSFER_QUEUE_CLIENT = None + if isinstance(keys, str): + keys = [keys] - try: - if _TRANSFER_QUEUE_STORAGE: - # only the process that do first-time init can clean the distributed storage - for storage in _TRANSFER_QUEUE_STORAGE.values(): - ray.kill(storage) - _TRANSFER_QUEUE_STORAGE = None - except Exception: - pass + tq_client = _maybe_create_transferqueue_client() + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) - try: - controller = ray.get_actor("TransferQueueController") - ray.kill(controller) - except Exception: - pass + if batch_meta.size > 0: + tq_client.clear_samples(batch_meta) # ==================== KV Interface API ==================== -def kv_put( +async def async_kv_put( key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]] = None, tag: Optional[dict[str, Any]] = None, ) -> None: - """Put a single key-value pair to TransferQueue. + """Asynchronously put a single key-value pair to TransferQueue. This is a convenience method for putting data using a user-specified key instead of BatchMeta. Internally, the key is translated to a BatchMeta @@ -693,26 +520,27 @@ def kv_put( >>> import torch >>> tq.init() >>> # Put with both fields and tag - >>> tq.kv_put( + >>> await tq.async_kv_put( ... key="sample_1", ... partition_id="train", ... fields={"input_ids": torch.tensor([1, 2, 3])}, ... tag={"score": 0.95} - ... ) + ... )) """ + if fields is None and tag is None: - raise ValueError("Please provide at least one parameter of `fields` or `tag`.") + raise ValueError("Please provide at least one parameter of fields or tag.") tq_client = _maybe_create_transferqueue_client() # 1. translate user-specified key to BatchMeta - batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) + batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) if batch_meta.size != 1: raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") # 2. register the user-specified tag to BatchMeta - if tag is not None: + if tag: batch_meta.update_custom_meta([tag]) # 3. put data @@ -731,17 +559,17 @@ def kv_put( elif not isinstance(fields, TensorDict): raise ValueError("field can only be dict or TensorDict") - # custom_meta (tag) will be put to controller through the internal put process - tq_client.put(fields, batch_meta) + # custom_meta (tag) will be put to controller through the put process + await tq_client.async_put(fields, batch_meta) else: # directly update custom_meta (tag) to controller - tq_client.set_custom_meta(batch_meta) + await tq_client.async_set_custom_meta(batch_meta) -def kv_batch_put( +async def async_kv_batch_put( keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None ) -> None: - """Put multiple key-value pairs to TransferQueue in batch. + """Asynchronously put multiple key-value pairs to TransferQueue in batch. This method stores multiple key-value pairs in a single operation, which is more efficient than calling kv_put multiple times. @@ -759,7 +587,6 @@ def kv_batch_put( Example: >>> import transfer_queue as tq - >>> from tensordict import TensorDict >>> tq.init() >>> keys = ["sample_1", "sample_2", "sample_3"] >>> fields = TensorDict({ @@ -767,11 +594,11 @@ def kv_batch_put( ... 
"attention_mask": torch.ones(3, 10), ... }, batch_size=3) >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] - >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) + >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) """ if fields is None and tags is None: - raise ValueError("Please provide at least one parameter of fields or tag.") + raise ValueError("Please provide at least one parameter of `fields` or `tags`.") if fields is not None and fields.batch_size[0] != len(keys): raise ValueError( @@ -782,7 +609,7 @@ def kv_batch_put( tq_client = _maybe_create_transferqueue_client() # 1. translate user-specified key to BatchMeta - batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) if batch_meta.size != len(keys): raise RuntimeError( @@ -790,394 +617,565 @@ def kv_batch_put( ) # 2. register the user-specified tags to BatchMeta - if tags: + if tags is not None: if len(tags) != len(keys): raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") batch_meta.update_custom_meta(tags) - # 3. put data - if fields is not None: - tq_client.put(fields, batch_meta) - else: - # directly update custom_meta (tags) to controller - tq_client.set_custom_meta(batch_meta) + # 3. put data + if fields is not None: + await tq_client.async_put(fields, batch_meta) + else: + # directly update custom_meta (tags) to controller + await tq_client.async_set_custom_meta(batch_meta) + + +async def async_kv_batch_get( + keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None +) -> TensorDict: + """Asynchronously get data from TransferQueue using user-specified keys. + + This is a convenience method for retrieving data using keys instead of indexes. + + Args: + keys: Single key or list of keys to retrieve + partition_id: Partition containing the keys + fields: Optional field(s) to retrieve. If None, retrieves all fields + + Returns: + TensorDict with the requested data + + Raises: + RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # Get single key with all fields + >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train") + >>> # Get multiple keys with specific fields + >>> data = await tq.async_kv_batch_get( + ... keys=["sample_1", "sample_2"], + ... partition_id="train", + ... fields="input_ids" + ... ) + """ + tq_client = _maybe_create_transferqueue_client() + + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + + if batch_meta.size == 0: + raise RuntimeError("keys or partition were not found!") + + if fields is not None: + if isinstance(fields, str): + fields = [fields] + batch_meta = batch_meta.select_fields(fields) + + if not batch_meta.is_ready: + raise RuntimeError("Some fields are not ready in all the requested keys!") + + data = await tq_client.async_get_data(batch_meta) + return data + + +async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: + """Asynchronously list all keys and their metadata in one or all partitions. + + Args: + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. 
+ + Returns: + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } + + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # Case 1: Retrieve a specific partition + >>> partitions = await tq.async_kv_list(partition_id="train") + >>> print(f"Keys: {list(partitions['train'].keys())}") + >>> print(f"Tags: {list(partitions['train'].values())}") + >>> # Case 2: Retrieve all partitions + >>> all_partitions = await tq.async_kv_list() + >>> for pid, keys in all_partitions.items(): + >>> print(f"Partition: {pid}, Key count: {len(keys)}") + """ + tq_client = _maybe_create_transferqueue_client() + + partition_info = await tq_client.async_kv_list(partition_id) + + return partition_info + + +async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: + """Asynchronously clear key-value pairs from TransferQueue. + + This removes the specified keys and their associated data from both + the controller and storage units. + + Args: + keys: Single key or list of keys to clear + partition_id: Partition containing the keys + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # Clear single key + >>> await tq.async_kv_clear(keys="sample_1", partition_id="train") + >>> # Clear multiple keys + >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train") + """ + + if isinstance(keys, str): + keys = [keys] + + tq_client = _maybe_create_transferqueue_client() + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + + if batch_meta.size > 0: + await tq_client.async_clear_samples(batch_meta) + + +# ==================== Low-Level Native API ==================== +def get_meta( + data_fields: list[str], + batch_size: int, + partition_id: str, + mode: str = "fetch", + task_name: Optional[str] = None, + sampling_config: Optional[dict[str, Any]] = None, +) -> BatchMeta: + """Synchronously fetch data metadata from the controller via ZMQ. + + Args: + data_fields: List of data field names to retrieve metadata for + batch_size: Number of samples to request in the batch + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users + task_name: Optional task name associated with the request + sampling_config: Optional sampling configuration for custom samplers. + + Returns: + BatchMeta: Metadata object containing data structure, sample information, and readiness status + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> # Example 1: Basic fetch metadata + >>> batch_meta = tq.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = tq.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... sampling_config={"n_samples_per_prompt": 4} + ... 
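
# Illustrative sketch (not part of this patch): the async_kv_* coroutines above can be
# driven from synchronous code with asyncio.run(); inside a running event loop, await
# them directly instead.
import asyncio
import torch
import transfer_queue as tq

tq.init()
asyncio.run(tq.async_kv_put(key="sample_1", partition_id="train",
                            fields={"input_ids": torch.tensor([1, 2, 3])}))
data = asyncio.run(tq.async_kv_batch_get(keys="sample_1", partition_id="train"))
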
) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = tq.get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... ) + >>> print(batch_meta.is_ready) # May be False if some samples not ready + """ + + tq_client = _maybe_create_transferqueue_client() + return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + + +def set_custom_meta(metadata: BatchMeta) -> None: + """Synchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> + >>> # Create batch with custom metadata + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) + >>> tq.set_custom_meta(batch_meta) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.set_custom_meta(metadata) + + +def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: + """Synchronously write data to storage units based on metadata. + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. -def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict: - """Get data from TransferQueue using user-specified keys. + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. - This is a convenience method for retrieving data using keys instead of indexes. + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. Args: - keys: Single key or list of keys to retrieve - partition_id: Partition containing the keys - fields: Optional field(s) to retrieve. If None, retrieves all fields + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. 
+ partition_id: Target data partition id (required if metadata is not provided) Returns: - TensorDict with the requested data + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) Raises: - RuntimeError: If keys or partition are not found - RuntimeError: If empty fields exist in any key (sample) + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails Example: >>> import transfer_queue as tq >>> tq.init() - >>> # Get single key with all fields - >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train") - >>> # Get multiple keys with specific fields - >>> data = tq.kv_batch_get( - ... keys=["sample_1", "sample_2"], - ... partition_id="train", - ... fields="input_ids" + >>> + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = tq.get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", ... ) + >>> batch = tq.get_data(batch_meta) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> tq.put(data=output, metadata=batch_meta) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling the async_put() + >>> # without metadata. + >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with + >>> # interleave the initial data if n_sample > 1 before calling the async_put(). + >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. + >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id) """ tq_client = _maybe_create_transferqueue_client() - - batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) - - if batch_meta.size == 0: - raise RuntimeError("keys or partition were not found!") - - if fields is not None: - if isinstance(fields, str): - fields = [fields] - batch_meta = batch_meta.select_fields(fields) - - if not batch_meta.is_ready: - raise RuntimeError("Some fields are not ready in all the requested keys!") - - data = tq_client.get_data(batch_meta) - return data + return tq_client.put(data, metadata, partition_id) -def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: - """List all keys and their metadata in one or all partitions. +def get_data(metadata: BatchMeta) -> TensorDict: + """Synchronously fetch data from storage units and organize into TensorDict. Args: - partition_id: The specific partition_id to query. - If None (default), returns keys from all partitions. + metadata: Batch metadata containing data location information and global indexes Returns: - A nested dictionary mapping partition IDs to their keys and metadata. - - Structure: - { - "partition_id": { - "key_name": { - "tag1": , - ... (other metadata) - }, - ..., - }, - ... 
- } + TensorDict containing: + - Requested data fields (e.g., "prompts", "attention_mask") Example: >>> import transfer_queue as tq >>> tq.init() - >>> # Case 1: Retrieve a specific partition - >>> partitions = tq.kv_list(partition_id="train") - >>> print(f"Keys: {list(partitions['train'].keys())}") - >>> print(f"Tags: {list(partitions['train'].values())}") - >>> # Case 2: Retrieve all partitions - >>> all_partitions = tq.kv_list() - >>> for pid, keys in all_partitions.items(): - >>> print(f"Partition: {pid}, Key count: {len(keys)}") + >>> + >>> batch_meta = tq.get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... ) + >>> batch = tq.get_data(batch_meta) + >>> print(batch) + >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes """ tq_client = _maybe_create_transferqueue_client() + return tq_client.get_data(metadata) - partition_info = tq_client.kv_list(partition_id) - return partition_info +def clear_partition(partition_id: str): + """Synchronously clear the whole partition from all storage units and the controller. + Args: + partition_id: The partition id to clear data for -def kv_clear(keys: list[str] | str, partition_id: str) -> None: - """Clear key-value pairs from TransferQueue. + Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.clear_partition(partition_id) - This removes the specified keys and their associated data from both - the controller and storage units. + +def clear_samples(metadata: BatchMeta): + """Synchronously clear specific samples from all storage units and the controller. Args: - keys: Single key or list of keys to clear - partition_id: Partition containing the keys + metadata: The BatchMeta of the corresponding data to be cleared - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> # Clear single key - >>> tq.kv_clear(keys="sample_1", partition_id="train") - >>> # Clear multiple keys - >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train") + Raises: + RuntimeError: If clear operation fails """ - - if isinstance(keys, str): - keys = [keys] - tq_client = _maybe_create_transferqueue_client() - batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) - - if batch_meta.size > 0: - tq_client.clear_samples(batch_meta) + return tq_client.clear_samples(metadata) -# ==================== KV Interface API ==================== -async def async_kv_put( - key: str, +async def async_get_meta( + data_fields: list[str], + batch_size: int, partition_id: str, - fields: Optional[TensorDict | dict[str, Any]] = None, - tag: Optional[dict[str, Any]] = None, -) -> None: - """Asynchronously put a single key-value pair to TransferQueue. - - This is a convenience method for putting data using a user-specified key - instead of BatchMeta. Internally, the key is translated to a BatchMeta - and the data is stored using the regular put mechanism. + mode: str = "fetch", + task_name: Optional[str] = None, + sampling_config: Optional[dict[str, Any]] = None, +) -> BatchMeta: + """Asynchronously fetch data metadata from the controller via ZMQ. Args: - key: User-specified key for the data sample (in row) - partition_id: Logical partition to store the data in - fields: Data fields to store. Can be a TensorDict or a dict of tensors. 
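
# Illustrative cleanup sketch (not part of this patch): clear_samples() removes only
# the samples described by a BatchMeta, while clear_partition() drops everything stored
# under a partition id, on both the controller and the storage units.
import transfer_queue as tq

meta = tq.get_meta(data_fields=["prompts"], batch_size=4, partition_id="train_0")
tq.clear_samples(meta)          # drop just these samples
tq.clear_partition("train_0")   # or drop the whole partition
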
- Each key in `fields` will be treated as a column for the data sample. - If dict is provided, tensors will be unsqueezed to add batch dimension. - tag: Optional metadata tag to associate with the key + data_fields: List of data field names to retrieve metadata for + batch_size: Number of samples to request in the batch + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users + task_name: Optional task name associated with the request + sampling_config: Optional sampling configuration for custom samplers. + + Returns: + BatchMeta: Metadata object containing data structure, sample information, and readiness status Raises: - ValueError: If neither fields nor tag is provided - ValueError: If nested tensors are provided (use kv_batch_put instead) - RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` + RuntimeError: If communication fails or controller returns error response Example: >>> import transfer_queue as tq - >>> import torch >>> tq.init() - >>> # Put with both fields and tag - >>> await tq.async_kv_put( - ... key="sample_1", - ... partition_id="train", - ... fields={"input_ids": torch.tensor([1, 2, 3])}, - ... tag={"score": 0.95} + >>> + >>> # Example 1: Basic fetch metadata + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" ... )) - """ - - if fields is None and tag is None: - raise ValueError("Please provide at least one parameter of fields or tag.") - - tq_client = _maybe_create_transferqueue_client() - - # 1. translate user-specified key to BatchMeta - batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) - - if batch_meta.size != 1: - raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") - - # 2. register the user-specified tag to BatchMeta - if tag: - batch_meta.update_custom_meta([tag]) - - # 3. put data - if fields is not None: - if isinstance(fields, dict): - # TODO: consider whether to support this... - batch = {} - for field_name, value in fields.items(): - if isinstance(value, torch.Tensor): - if value.is_nested: - raise ValueError("Please use (async)kv_batch_put for batch operation") - batch[field_name] = value.unsqueeze(0) - else: - batch[field_name] = NonTensorStack(value) - fields = TensorDict(batch, batch_size=[1]) - elif not isinstance(fields, TensorDict): - raise ValueError("field can only be dict or TensorDict") + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... 
)) + >>> print(batch_meta.is_ready) # May be False if some samples not ready + """ - # custom_meta (tag) will be put to controller through the put process - await tq_client.async_put(fields, batch_meta) - else: - # directly update custom_meta (tag) to controller - await tq_client.async_set_custom_meta(batch_meta) + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) -async def async_kv_batch_put( - keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None +async def async_set_custom_meta( + metadata: BatchMeta, ) -> None: - """Asynchronously put multiple key-value pairs to TransferQueue in batch. + """ + Asynchronously send custom metadata to the controller. - This method stores multiple key-value pairs in a single operation, which is more - efficient than calling kv_put multiple times. + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. Args: - keys: List of user-specified keys for the data - partition_id: Logical partition to store the data in - fields: TensorDict containing data for all keys. Must have batch_size == len(keys) - tags: List of metadata tags, one for each key + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. + socket: ZMQ async socket for message transmission (injected by decorator) Raises: - ValueError: If neither `fields` nor `tags` is provided - ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict - RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` + RuntimeError: If communication fails or controller returns error response Example: >>> import transfer_queue as tq >>> tq.init() - >>> keys = ["sample_1", "sample_2", "sample_3"] - >>> fields = TensorDict({ - ... "input_ids": torch.randn(3, 10), - ... "attention_mask": torch.ones(3, 10), - ... }, batch_size=3) - >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] - >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) + >>> + >>> # Create batch with custom metadata + >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) + >>> asyncio.run(tq.async_set_custom_meta(batch_meta)) """ - - if fields is None and tags is None: - raise ValueError("Please provide at least one parameter of `fields` or `tags`.") - - if fields is not None and fields.batch_size[0] != len(keys): - raise ValueError( - f"`keys` with length {len(keys)} does not match the `fields` TensorDict with " - f"batch_size {fields.batch_size[0]}" - ) - tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_set_custom_meta(metadata) - # 1. translate user-specified key to BatchMeta - batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) - - if batch_meta.size != len(keys): - raise RuntimeError( - f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!" - ) - - # 2. 
register the user-specified tags to BatchMeta - if tags is not None: - if len(tags) != len(keys): - raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") - batch_meta.update_custom_meta(tags) - # 3. put data - if fields is not None: - await tq_client.async_put(fields, batch_meta) - else: - # directly update custom_meta (tags) to controller - await tq_client.async_set_custom_meta(batch_meta) +async def async_put( + data: TensorDict, + metadata: Optional[BatchMeta] = None, + partition_id: Optional[str] = None, +) -> BatchMeta: + """Asynchronously write data to storage units based on metadata. + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. -async def async_kv_batch_get( - keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None -) -> TensorDict: - """Asynchronously get data from TransferQueue using user-specified keys. + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. - This is a convenience method for retrieving data using keys instead of indexes. + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. Args: - keys: Single key or list of keys to retrieve - partition_id: Partition containing the keys - fields: Optional field(s) to retrieve. If None, retrieves all fields + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. + partition_id: Target data partition id (required if metadata is not provided) Returns: - TensorDict with the requested data + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) Raises: - RuntimeError: If keys or partition are not found - RuntimeError: If empty fields exist in any key (sample) + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails Example: >>> import transfer_queue as tq >>> tq.init() - >>> # Get single key with all fields - >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train") - >>> # Get multiple keys with specific fields - >>> data = await tq.async_kv_batch_get( - ... keys=["sample_1", "sample_2"], - ... partition_id="train", - ... fields="input_ids" - ... ) + >>> + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> batch = asyncio.run(tq.async_get_data(batch_meta)) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta)) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling the async_put() + >>> # without metadata. 
+ >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with + >>> # interleave the initial data if n_sample > 1 before calling the async_put(). + >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. + >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) """ tq_client = _maybe_create_transferqueue_client() - - batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) - - if batch_meta.size == 0: - raise RuntimeError("keys or partition were not found!") - - if fields is not None: - if isinstance(fields, str): - fields = [fields] - batch_meta = batch_meta.select_fields(fields) - - if not batch_meta.is_ready: - raise RuntimeError("Some fields are not ready in all the requested keys!") - - data = await tq_client.async_get_data(batch_meta) - return data + return await tq_client.async_put(data, metadata, partition_id) -async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: - """Asynchronously list all keys and their metadata in one or all partitions. +async def async_get_data(metadata: BatchMeta) -> TensorDict: + """Asynchronously fetch data from storage units and organize into TensorDict. Args: - partition_id: The specific partition_id to query. - If None (default), returns keys from all partitions. + metadata: Batch metadata containing data location information and global indexes Returns: - A nested dictionary mapping partition IDs to their keys and metadata. - - Structure: - { - "partition_id": { - "key_name": { - "tag1": , - ... (other metadata) - }, - ..., - }, - ... - } - + TensorDict containing: + - Requested data fields (e.g., "prompts", "attention_mask") Example: >>> import transfer_queue as tq >>> tq.init() - >>> # Case 1: Retrieve a specific partition - >>> partitions = await tq.async_kv_list(partition_id="train") - >>> print(f"Keys: {list(partitions['train'].keys())}") - >>> print(f"Tags: {list(partitions['train'].values())}") - >>> # Case 2: Retrieve all partitions - >>> all_partitions = await tq.async_kv_list() - >>> for pid, keys in all_partitions.items(): - >>> print(f"Partition: {pid}, Key count: {len(keys)}") + >>> + >>> batch_meta = asyncio.run(tq.async_get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... )) + >>> batch = asyncio.run(tq.async_get_data(batch_meta)) + >>> print(batch) + >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes """ tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_get_data(metadata) - partition_info = await tq_client.async_kv_list(partition_id) - return partition_info +async def async_clear_samples(metadata: BatchMeta): + """Asynchronously clear specific samples from all storage units and the controller. + Args: + metadata: The BatchMeta of the corresponding data to be cleared -async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: - """Asynchronously clear key-value pairs from TransferQueue. 
+ Raises: + RuntimeError: If clear operation fails + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_clear_samples(metadata) - This removes the specified keys and their associated data from both - the controller and storage units. + +async def async_clear_partition(partition_id: str): + """Asynchronously clear the whole partition from all storage units and the controller. Args: - keys: Single key or list of keys to clear - partition_id: Partition containing the keys + partition_id: The partition id to clear data for - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> # Clear single key - >>> await tq.async_kv_clear(keys="sample_1", partition_id="train") - >>> # Clear multiple keys - >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train") + Raises: + RuntimeError: If clear operation fails """ - - if isinstance(keys, str): - keys = [keys] - tq_client = _maybe_create_transferqueue_client() - batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) - - if batch_meta.size > 0: - await tq_client.async_clear_samples(batch_meta) + return await tq_client.async_clear_partition(partition_id) From 1bac20f2de639b0b6b3b04c4bc4691546cb4be85 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 21:03:28 +0800 Subject: [PATCH 31/34] simplify interface Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 94 +++----- transfer_queue/interface.py | 434 +----------------------------------- 2 files changed, 43 insertions(+), 485 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 60c8f1d..4732e84 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -16,84 +16,64 @@ import os from .client import TransferQueueClient -from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset from .interface import ( - async_clear_partition, - async_clear_samples, - async_get_data, - async_get_meta, async_kv_batch_get, async_kv_batch_put, async_kv_clear, async_kv_list, async_kv_put, - async_put, - async_set_custom_meta, - clear_partition, - clear_samples, close, - get_data, - get_meta, + get_client, init, kv_batch_get, kv_batch_put, kv_clear, kv_list, kv_put, - put, - set_custom_meta, ) from .metadata import BatchMeta, KVBatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler from .sampler.rank_aware_sampler import RankAwareSampler from .sampler.sequential_sampler import SequentialSampler -from .storage import SimpleStorageUnit -from .utils.common import get_placement_group -from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info -__all__ = [ - "init", - "close", - "kv_put", - "kv_batch_put", - "kv_batch_get", - "kv_list", - "kv_clear", - "async_kv_put", - "async_kv_batch_put", - "async_kv_batch_get", - "async_kv_list", - "async_kv_clear", - "KVBatchMeta", -] + [ - "get_meta", - "get_data", - "put", - "set_custom_meta", - "clear_samples", - "clear_partition", - "async_get_meta", - "async_get_data", - "async_put", - "async_set_custom_meta", - "async_clear_samples", - "async_clear_partition", - "BatchMeta", - "TransferQueueClient", - "StreamingDataset", - "StreamingDataLoader", - "TransferQueueController", - "SimpleStorageUnit", - "ZMQServerInfo", - "process_zmq_server_info", - "get_placement_group", - "BaseSampler", - "GRPOGroupNSampler", - "SequentialSampler", - "RankAwareSampler", -] +__all__ = ( + [ + # High-Level KV 
Interface + "init", + "close", + "kv_put", + "kv_batch_put", + "kv_batch_get", + "kv_list", + "kv_clear", + "async_kv_put", + "async_kv_batch_put", + "async_kv_batch_get", + "async_kv_list", + "async_kv_clear", + "KVBatchMeta", + ] + + [ + # High-Level StreamingDataLoader Interface + "StreamingDataset", + "StreamingDataLoader", + ] + + [ + # Low-Level Native Interface + "get_client", + "BatchMeta", + "TransferQueueClient", + ] + + [ + # Sampler + "BaseSampler", + "GRPOGroupNSampler", + "SequentialSampler", + "RankAwareSampler", + ] +) version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index fb6b5f7..733df24 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -28,7 +28,6 @@ from transfer_queue.client import TransferQueueClient from transfer_queue.controller import TransferQueueController -from transfer_queue.metadata import BatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler from transfer_queue.storage.simple_backend import SimpleStorageUnit @@ -752,430 +751,9 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: # ==================== Low-Level Native API ==================== -def get_meta( - data_fields: list[str], - batch_size: int, - partition_id: str, - mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, -) -> BatchMeta: - """Synchronously fetch data metadata from the controller via ZMQ. - - Args: - data_fields: List of data field names to retrieve metadata for - batch_size: Number of samples to request in the batch - partition_id: Current data partition id - mode: Data fetch mode. Options: - - 'fetch': Get ready data only - - 'force_fetch': Get data regardless of readiness (may return unready samples) - - 'insert': Internal usage - should not be used by users - task_name: Optional task name associated with the request - sampling_config: Optional sampling configuration for custom samplers. - - Returns: - BatchMeta: Metadata object containing data structure, sample information, and readiness status - - Raises: - RuntimeError: If communication fails or controller returns error response - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> # Example 1: Basic fetch metadata - >>> batch_meta = tq.get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences" - ... ) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = tq.get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=8, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... sampling_config={"n_samples_per_prompt": 4} - ... ) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = tq.get_meta( - ... partition_id="train_0", # optional - ... mode="force_fetch", - ... 
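
# Illustrative sketch (not part of this patch): with the module-level native API removed
# from interface.py, low-level access is assumed to go through get_client(), which hands
# back the shared TransferQueueClient created by tq.init(); exact semantics are defined
# in interface.py.
import transfer_queue as tq

tq.init()
client = tq.get_client()
meta = client.get_meta(data_fields=["prompts"], batch_size=4, partition_id="train_0")
batch = client.get_data(meta)
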
) - >>> print(batch_meta.is_ready) # May be False if some samples not ready - """ - - tq_client = _maybe_create_transferqueue_client() - return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) - - -def set_custom_meta(metadata: BatchMeta) -> None: - """Synchronously send custom metadata to the controller. - - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. - - Args: - metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() - before calling this method. - - Raises: - RuntimeError: If communication fails or controller returns error response - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) - >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) - >>> tq.set_custom_meta(batch_meta) - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.set_custom_meta(metadata) - - -def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: - """Synchronously write data to storage units based on metadata. - - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. - - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. - - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. - - Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. - partition_id: Target data partition id (required if metadata is not provided) - - Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) - - Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = tq.get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", - ... ) - >>> batch = tq.get_data(batch_meta) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> tq.put(data=output, metadata=batch_meta) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). 
- >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id) - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.put(data, metadata, partition_id) - - -def get_data(metadata: BatchMeta) -> TensorDict: - """Synchronously fetch data from storage units and organize into TensorDict. - - Args: - metadata: Batch metadata containing data location information and global indexes - - Returns: - TensorDict containing: - - Requested data fields (e.g., "prompts", "attention_mask") - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> batch_meta = tq.get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... ) - >>> batch = tq.get_data(batch_meta) - >>> print(batch) - >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.get_data(metadata) - - -def clear_partition(partition_id: str): - """Synchronously clear the whole partition from all storage units and the controller. - - Args: - partition_id: The partition id to clear data for - - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.clear_partition(partition_id) - - -def clear_samples(metadata: BatchMeta): - """Synchronously clear specific samples from all storage units and the controller. - - Args: - metadata: The BatchMeta of the corresponding data to be cleared - - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.clear_samples(metadata) - - -async def async_get_meta( - data_fields: list[str], - batch_size: int, - partition_id: str, - mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, -) -> BatchMeta: - """Asynchronously fetch data metadata from the controller via ZMQ. - - Args: - data_fields: List of data field names to retrieve metadata for - batch_size: Number of samples to request in the batch - partition_id: Current data partition id - mode: Data fetch mode. Options: - - 'fetch': Get ready data only - - 'force_fetch': Get data regardless of readiness (may return unready samples) - - 'insert': Internal usage - should not be used by users - task_name: Optional task name associated with the request - sampling_config: Optional sampling configuration for custom samplers. - - Returns: - BatchMeta: Metadata object containing data structure, sample information, and readiness status - - Raises: - RuntimeError: If communication fails or controller returns error response - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> # Example 1: Basic fetch metadata - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences" - ... 
)) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=8, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... )) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... partition_id="train_0", # optional - ... mode="force_fetch", - ... )) - >>> print(batch_meta.is_ready) # May be False if some samples not ready - """ - - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) - - -async def async_set_custom_meta( - metadata: BatchMeta, -) -> None: - """ - Asynchronously send custom metadata to the controller. - - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. - - Args: - metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() - before calling this method. - socket: ZMQ async socket for message transmission (injected by decorator) - - Raises: - RuntimeError: If communication fails or controller returns error response - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...) - >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) - >>> asyncio.run(tq.async_set_custom_meta(batch_meta)) - """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_set_custom_meta(metadata) - - -async def async_put( - data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, -) -> BatchMeta: - """Asynchronously write data to storage units based on metadata. - - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. - - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. - - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. - - Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. 
- partition_id: Target data partition id (required if metadata is not provided) - - Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) - - Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", - ... )) - >>> batch = asyncio.run(tq.async_get_data(batch_meta)) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta)) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) - """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_put(data, metadata, partition_id) - - -async def async_get_data(metadata: BatchMeta) -> TensorDict: - """Asynchronously fetch data from storage units and organize into TensorDict. - - Args: - metadata: Batch metadata containing data location information and global indexes - - Returns: - TensorDict containing: - - Requested data fields (e.g., "prompts", "attention_mask") - - Example: - >>> import transfer_queue as tq - >>> tq.init() - >>> - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... )) - >>> batch = asyncio.run(tq.async_get_data(batch_meta)) - >>> print(batch) - >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes - """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_get_data(metadata) - - -async def async_clear_samples(metadata: BatchMeta): - """Asynchronously clear specific samples from all storage units and the controller. - - Args: - metadata: The BatchMeta of the corresponding data to be cleared - - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_clear_samples(metadata) - - -async def async_clear_partition(partition_id: str): - """Asynchronously clear the whole partition from all storage units and the controller. 
- - Args: - partition_id: The partition id to clear data for - - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_clear_partition(partition_id) +# For low-level API support, please refer to transfer_queue/client.py for details. +def get_client(): + global _TRANSFER_QUEUE_CLIENT + if _TRANSFER_QUEUE_CLIENT is None: + raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!") + return _TRANSFER_QUEUE_CLIENT From 3830bcf57868e90cbef14282caeeeafb555a585b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 21:11:00 +0800 Subject: [PATCH 32/34] fix ut Signed-off-by: 0oshowero0 --- tests/test_controller.py | 2 +- tests/test_metadata.py | 2 +- tests/test_simple_storage_unit.py | 2 +- transfer_queue/metadata.py | 5 +++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_controller.py b/tests/test_controller.py index 243bc19..3528d1a 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -28,7 +28,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -from transfer_queue import TransferQueueController # noqa: E402 +from transfer_queue.controller import TransferQueueController # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 23ba56d..2a129b5 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1074,7 +1074,7 @@ def test_kv_batch_meta_empty_init(self): assert kv_meta.keys == [] assert kv_meta.tags == [] assert kv_meta.partition_id is None - assert kv_meta.fields == [] + assert kv_meta.fields is None def test_kv_batch_meta_init_validation_keys_tags_mismatch(self): """Example: Init validation catches keys and tags length mismatch.""" diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index 043b506..ed43e41 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -27,7 +27,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import SimpleStorageUnit # noqa: E402 +from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402 from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # noqa: E402 diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 473a487..056a146 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -857,8 +857,9 @@ def __post_init__(self): raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}") if len(self.keys) != len(set(self.keys)): raise ValueError("Got duplicated keys.") - if len(self.fields) != len(set(self.fields)): - raise ValueError("Got duplicated fields.") + if self.fields is not None: + if len(self.fields) != len(set(self.fields)): + raise ValueError("Got duplicated fields.") # deepcopy to prevent unexpected behavior after chunk/concat self.tags = copy.deepcopy(self.tags) From 5ec124b0173cc74294e4c6069fee8867242715a8 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 21:18:40 +0800 Subject: [PATCH 33/34] fix tutorial Signed-off-by: 0oshowero0 --- tutorial/01_core_components.py | 9 ++--- tutorial/03_metadata_concepts.py | 16 +++++---- tutorial/04_understanding_controller.py | 44 +++++++++++++++---------- tutorial/05_custom_sampler.py | 34 +++++++++++-------- tutorial/06_streaming_dataloader.py | 4 ++- 5 files 
changed, 63 insertions(+), 44 deletions(-) diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 59159a6..9a66b8e 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -63,6 +63,7 @@ def demonstrate_data_workflow(): # Step 1: Put data print("[Step 1] Putting data into TransferQueue...") + tq_client = tq.get_client() input_ids = torch.tensor( [ @@ -84,12 +85,12 @@ def demonstrate_data_workflow(): print(f" Created {data_batch.batch_size[0]} samples") partition_id = "tutorial_partition_0" - tq.put(data=data_batch, partition_id=partition_id) + tq_client.put(data=data_batch, partition_id=partition_id) print(f" ✓ Data written to partition: {partition_id}") # Step 2: Get metadata print("[Step 2] Requesting data metadata...") - batch_meta = tq.get_meta( + batch_meta = tq_client.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=data_batch.batch_size[0], partition_id=partition_id, @@ -100,7 +101,7 @@ def demonstrate_data_workflow(): # Step 3: Get actual data print("[Step 3] Retrieving actual data...") - retrieved_data = tq.get_data(batch_meta) + retrieved_data = tq_client.get_data(batch_meta) print(" ✓ Data retrieved successfully") print(f" Keys: {list(retrieved_data.keys())}") @@ -112,7 +113,7 @@ def demonstrate_data_workflow(): # Step 5: Clear print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)") - tq.clear_partition(partition_id=partition_id) + tq_client.clear_partition(partition_id=partition_id) print(" ✓ Partition cleared") diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py index 91752d8..2c81941 100644 --- a/tutorial/03_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -310,6 +310,8 @@ def demonstrate_real_workflow(): # Initialize TransferQueue tq.init() + tq_client = tq.get_client() + print("[Step 1] Putting data into TransferQueue...") input_ids = torch.randint(0, 1000, (8, 512)) attention_mask = torch.ones(8, 512) @@ -323,7 +325,7 @@ def demonstrate_real_workflow(): ) partition_id = "demo_partition" - batch_meta = tq.put(data=data_batch, partition_id=partition_id) + batch_meta = tq_client.put(data=data_batch, partition_id=partition_id) print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.") print("[Step 2] [Optional] Setting sample-level custom_meta...") @@ -335,11 +337,11 @@ def demonstrate_real_workflow(): batch_meta.update_custom_meta(custom_meta) print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}") - tq.set_custom_meta(batch_meta) + tq_client.set_custom_meta(batch_meta) print("✓ Successful to store custom_meta into TQ controller. 
Now you can retrieve the custom_meta from anywhere.") print("[Step 3] Try to get metadata from TransferQueue from other places...") - batch_meta = tq.get_meta( + batch_meta = tq_client.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=8, partition_id=partition_id, @@ -358,7 +360,7 @@ def demonstrate_real_workflow(): print("✓ Selected 'input_ids' field only:") print(f" New field names: {selected_meta.field_names}") print(f" Samples still have same global indexes: {selected_meta.global_indexes}") - retrieved_data = tq.get_data(selected_meta) + retrieved_data = tq_client.get_data(selected_meta) print(f" Retrieved data keys: {list(retrieved_data.keys())}") print("[Step 5] Select specific samples from the retrieved BatchMeta...") @@ -366,7 +368,7 @@ def demonstrate_real_workflow(): print("✓ Selected samples at indices [0, 2, 4, 6]:") print(f" New global indexes: {partial_meta.global_indexes}") print(f" Number of samples: {len(partial_meta)}") - retrieved_data = tq.get_data(partial_meta) + retrieved_data = tq_client.get_data(partial_meta) print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}") print("[Step 6] Demonstrate chunk operation...") @@ -374,11 +376,11 @@ def demonstrate_real_workflow(): print(f"✓ Chunked into {len(chunks)} parts:") for i, chunk in enumerate(chunks): print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") - chunk_data = tq.get_data(chunk) + chunk_data = tq_client.get_data(chunk) print(f" Chunk {i}: Retrieved chunk data: {chunk_data}") # Cleanup - tq.clear_partition(partition_id=partition_id) + tq_client.clear_partition(partition_id=partition_id) tq.close() ray.shutdown() print("✓ Partition cleared and resources cleaned up") diff --git a/tutorial/04_understanding_controller.py b/tutorial/04_understanding_controller.py index fbeae21..746de14 100644 --- a/tutorial/04_understanding_controller.py +++ b/tutorial/04_understanding_controller.py @@ -62,6 +62,8 @@ def demonstrate_partition_isolation(): tq.init() + tq_client = tq.get_client() + # Partition 1: Training data print("\n[Partition 1] Putting training data...") train_data = TensorDict( @@ -71,7 +73,7 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - tq.put(data=train_data, partition_id="train") + tq_client.put(data=train_data, partition_id="train") print(" ✓ Training data added to 'train' partition") # Partition 2: Validation data @@ -83,23 +85,25 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - tq.put(data=val_data, partition_id="val") + tq_client.put(data=val_data, partition_id="val") print(" ✓ Validation data added to 'val' partition") # Get from train partition print("\n[Retrieving from 'train' partition]") - train_meta = tq.get_meta( + train_meta = tq_client.get_meta( data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task" ) - retrieved_train_data = tq.get_data(train_meta) + retrieved_train_data = tq_client.get_data(train_meta) print(f" ✓ Got BatchMeta={train_meta} from partition 'train'") print(f" ✓ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, labels={retrieved_train_data['labels']}") # Get from val partition print("\n[Retrieving from 'val' partition]") - val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task") - retrieved_val_data = tq.get_data(val_meta) + val_meta = tq_client.get_meta( + data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task" + ) + retrieved_val_data = 
tq_client.get_data(val_meta) print(f" ✓ Got BatchMeta={val_meta} from partition 'val'") print(f" ✓ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}") @@ -107,8 +111,8 @@ def demonstrate_partition_isolation(): print(" ✓ Data isolation: 'train' and 'val' partitions are completely independent") # Cleanup - tq.clear_partition(partition_id="train") - tq.clear_partition(partition_id="val") + tq_client.clear_partition(partition_id="train") + tq_client.clear_partition(partition_id="val") tq.close() ray.shutdown() @@ -126,6 +130,8 @@ def demonstrate_dynamic_expansion(): tq.init() + tq_client = tq.get_client() + # Add first batch with 2 samples, 2 fields print("\n[Step 1] Adding initial data (2 samples, 2 fields)...") data1 = TensorDict( @@ -135,7 +141,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta1 = tq.put(data=data1, partition_id="dynamic") + meta1 = tq_client.put(data=data1, partition_id="dynamic") print(" ✓ Added 2 samples") print(f" ✓ Got BatchMeta: {meta1} samples") @@ -148,9 +154,9 @@ def demonstrate_dynamic_expansion(): }, batch_size=3, ) - meta2 = tq.put(data=data2, partition_id="dynamic") + meta2 = tq_client.put(data=data2, partition_id="dynamic") - all_meta = tq.get_meta( + all_meta = tq_client.get_meta( data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task" ) print(" ✓ Added 3 more samples (total: 5)") @@ -165,7 +171,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta3 = tq.put(data=data3, metadata=meta1) + meta3 = tq_client.put(data=data3, metadata=meta1) print(" ✓ Added 2 samples with new field 'field3'") print(f" ✓ Got BatchMeta: {meta3} for newly put data with new field") @@ -174,7 +180,7 @@ def demonstrate_dynamic_expansion(): print(" ✓ Columns auto-expand: Can add new fields anytime") # Cleanup - tq.clear_partition(partition_id="dynamic") + tq_client.clear_partition(partition_id="dynamic") tq.close() ray.shutdown() @@ -190,6 +196,8 @@ def demonstrate_default_consumption_sample_strategy(): tq.init() + tq_client = tq.get_client() + # Add 6 samples print("\n[Setup] Adding 6 samples...") all_data = TensorDict( @@ -198,22 +206,22 @@ def demonstrate_default_consumption_sample_strategy(): }, batch_size=6, ) - tq.put(data=all_data, partition_id="sampling") + tq_client.put(data=all_data, partition_id="sampling") print(" ✓ 6 samples added") # First get - should get samples 0,1,2 print("\n[Task A, Get 1] Requesting 3 samples...") - meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta1 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta1.global_indexes}") # Second get - should get samples 3,4,5 (no duplicates!) 
print("\n[Task A, Get 2] Requesting 3 more samples...") - meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta2 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" ✓ Got samples: {meta2.global_indexes}") # Third get - should get samples 0,1 print("\n[Task B, Get 1] Requesting 2 samples...") - meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") + meta3 = tq_client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -224,7 +232,7 @@ def demonstrate_default_consumption_sample_strategy(): print(" ✓ Third get (Task B): samples 0,1") # Cleanup - tq.clear_partition(partition_id="sampling") + tq_client.clear_partition(partition_id="sampling") tq.close() ray.shutdown() diff --git a/tutorial/05_custom_sampler.py b/tutorial/05_custom_sampler.py index bdca835..e30487b 100644 --- a/tutorial/05_custom_sampler.py +++ b/tutorial/05_custom_sampler.py @@ -186,6 +186,8 @@ def demonstrate_random_sampler_with_replacement(): sampler = RandomSamplerWithReplacement() setup_transfer_queue_with_sampler(sampler) + tq_client = tq.get_client() + # Add 5 samples print("\n[Step 1] Adding 5 samples...") data = TensorDict( @@ -194,22 +196,22 @@ def demonstrate_random_sampler_with_replacement(): }, batch_size=5, ) - tq.put(data=data, partition_id="test") + tq_client.put(data=data, partition_id="test") print(" ✓ 5 samples added") # Get batch 1 (should get 2 random samples) print("\n[Step 2] Get batch 1 (2 samples)...") - meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta1 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!) print("\n[Step 3] Get batch 2 (1 sample)...") - meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!) 
print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -219,7 +221,7 @@ def demonstrate_random_sampler_with_replacement(): print(f" ✓ All sampled: {all_sampled}") # Cleanup - tq.clear_partition(partition_id="test") + tq_client.clear_partition(partition_id="test") tq.close() ray.shutdown() @@ -234,6 +236,8 @@ def demonstrate_random_sampler_without_replacement(): sampler = RandomSamplerWithoutReplacement() setup_transfer_queue_with_sampler(sampler) + tq_client = tq.get_client() + # Add 6 samples print("\n[Step 1] Adding 6 samples...") data = TensorDict( @@ -242,22 +246,22 @@ def demonstrate_random_sampler_without_replacement(): }, batch_size=6, ) - tq.put(data=data, partition_id="test") + tq_client.put(data=data, partition_id="test") print(" ✓ 6 samples added") # Get batch 1 (should get 3 random samples without replacement) print("\n[Step 2] Get batch 1 (3 samples)...") - meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") + meta1 = tq_client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta1.global_indexes}") # Get batch 2 (should randomly get 1 sample that are different from previous batch) print("\n[Step 3] Get batch 2 (1 samples)...") - meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta2.global_indexes}") # Get batch 3 (should randomly get 2 samples that are different from previous batch) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" ✓ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -267,7 +271,7 @@ def demonstrate_random_sampler_without_replacement(): print(f" ✓ Batch 3: {meta3.global_indexes} (none left)") # Cleanup - tq.clear_partition(partition_id="test") + tq_client.clear_partition(partition_id="test") tq.close() ray.shutdown() @@ -282,6 +286,8 @@ def demonstrate_priority_sampler(): sampler = PrioritySampler() setup_transfer_queue_with_sampler(sampler) + tq_client = tq.get_client() + # Add 8 samples print("\n[Step 1] Adding 8 samples...") data = TensorDict( @@ -290,7 +296,7 @@ def demonstrate_priority_sampler(): }, batch_size=8, ) - tq.put(data=data, partition_id="test") + tq_client.put(data=data, partition_id="test") print(" ✓ 8 samples added") time.sleep(1) @@ -303,7 +309,7 @@ def demonstrate_priority_sampler(): print(f"Priority scores: {priority_scores}") # Get batch using priority sampling - meta1 = tq.get_meta( + meta1 = tq_client.get_meta( data_fields=["input"], batch_size=1, partition_id="test", @@ -315,7 +321,7 @@ def demonstrate_priority_sampler(): # Get another batch print("\n[Step 3] Get another batch (2 samples)...") - meta2 = tq.get_meta( + meta2 = tq_client.get_meta( data_fields=["input"], batch_size=2, partition_id="test", @@ -331,7 +337,7 @@ def demonstrate_priority_sampler(): print(f" ✓ Batch 2 high-priority indices: {[i for i in 
meta2.global_indexes if priority_scores[i] >= 0.1]}") # Cleanup - tq.clear_partition(partition_id="test") + tq_client.clear_partition(partition_id="test") tq.close() ray.shutdown() diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py index d95a3c8..c60b4e3 100644 --- a/tutorial/06_streaming_dataloader.py +++ b/tutorial/06_streaming_dataloader.py @@ -123,6 +123,8 @@ def generate_worker(rank_id: int, num_samples: int = 20): # Need to call tq.init() in each process tq.init() + tq_client = tq.get_client() + # Generate and put samples into the queue for i in range(num_samples): # Create unique sequence ID for this sample @@ -137,7 +139,7 @@ def generate_worker(rank_id: int, num_samples: int = 20): print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue") # Put data into the specified partition - tq.put(data, partition_id="train") + tq_client.put(data, partition_id="train") print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue") From e9e5872a01630a69b6438c1e66c4500ef88af417 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 9 Feb 2026 21:25:42 +0800 Subject: [PATCH 34/34] fix pre-commit Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 733df24..8e4f8c3 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -753,7 +753,7 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: # ==================== Low-Level Native API ==================== # For low-level API support, please refer to transfer_queue/client.py for details. def get_client(): - global _TRANSFER_QUEUE_CLIENT + """Get a TransferQueueClient for using low-level API""" if _TRANSFER_QUEUE_CLIENT is None: raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!") return _TRANSFER_QUEUE_CLIENT
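
For reference, the usage pattern these patches converge on — module-level helpers for the KV and streaming interfaces, with the low-level native calls reached through the client returned by `tq.get_client()` — looks roughly like the sketch below. This is an illustrative outline assembled from the tutorial diffs above (tutorial/01 and tutorial/03–05), not part of the patch series itself; it assumes the same single-node setup the tutorials use, where `tq.init()` brings up the controller and storage, and the field names, partition id, and task name are placeholders.

    # Minimal sketch of the post-patch low-level flow (see tutorial/01_core_components.py).
    import torch
    from tensordict import TensorDict

    import transfer_queue as tq

    tq.init()
    tq_client = tq.get_client()  # raises RuntimeError if called before tq.init()

    batch = TensorDict(
        {
            "input_ids": torch.randint(0, 1000, (4, 16)),
            "attention_mask": torch.ones(4, 16),
        },
        batch_size=4,
    )

    partition_id = "demo_partition"
    tq_client.put(data=batch, partition_id=partition_id)   # write samples, get BatchMeta back
    meta = tq_client.get_meta(                              # metadata for ready samples
        data_fields=["input_ids", "attention_mask"],
        batch_size=4,
        partition_id=partition_id,
        task_name="demo_task",
    )
    data = tq_client.get_data(meta)                         # retrieve the actual tensors
    tq_client.clear_partition(partition_id=partition_id)    # clean up the partition
    tq.close()

The same client exposes the async variants (`async_put`, `async_get_meta`, `async_get_data`, ...) that were previously re-exported at module level, so existing call sites only need the extra `tq.get_client()` line shown in the tutorial diffs.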