From 0e87f97549a09fba2a96ff42ba6088d575e21be7 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Fri, 6 Feb 2026 23:56:40 +0800
Subject: [PATCH 01/34] add kv interface
Signed-off-by: 0oshowero0
---
README.md | 6 +-
pyproject.toml | 3 +-
requirements.txt | 3 +-
tests/test_async_simple_storage_manager.py | 15 +-
tests/test_client.py | 270 +++-
tests/test_controller.py | 202 +++
tests/test_controller_data_partitions.py | 59 +
tests/test_kv_interface.py | 539 ++++++++
tests/test_kv_storage_manager.py | 116 +-
tests/test_metadata.py | 114 +-
transfer_queue/__init__.py | 58 +-
transfer_queue/client.py | 579 ++++++---
transfer_queue/config.yaml | 31 +
transfer_queue/controller.py | 167 ++-
.../dataloader/streaming_dataset.py | 46 +-
transfer_queue/interface.py | 1136 +++++++++++++++++
transfer_queue/metadata.py | 51 +-
transfer_queue/storage/managers/base.py | 41 +-
transfer_queue/storage/managers/factory.py | 36 +-
.../storage/managers/mooncake_manager.py | 7 +-
.../managers/simple_backend_manager.py | 23 +-
.../storage/managers/yuanrong_manager.py | 7 +-
transfer_queue/utils/common.py | 68 +-
transfer_queue/utils/zmq_utils.py | 39 +
tutorial/01_core_components.py | 92 +-
tutorial/02_kv_interface.py | 205 +++
...ta_concepts.py => 03_metadata_concepts.py} | 68 +-
...ller.py => 04_understanding_controller.py} | 101 +-
...custom_sampler.py => 05_custom_sampler.py} | 77 +-
...taloader.py => 06_streaming_dataloader.py} | 80 +-
30 files changed, 3455 insertions(+), 784 deletions(-)
create mode 100644 tests/test_kv_interface.py
create mode 100644 transfer_queue/config.yaml
create mode 100644 transfer_queue/interface.py
create mode 100644 tutorial/02_kv_interface.py
rename tutorial/{02_metadata_concepts.py => 03_metadata_concepts.py} (89%)
rename tutorial/{03_understanding_controller.py => 04_understanding_controller.py} (76%)
rename tutorial/{04_custom_sampler.py => 05_custom_sampler.py} (83%)
rename tutorial/{05_streaming_dataloader.py => 06_streaming_dataloader.py} (84%)
diff --git a/README.md b/README.md
index b69176d..2c7d4f0 100644
--- a/README.md
+++ b/README.md
@@ -73,10 +73,10 @@ This class encapsulates the core interaction logic within the TransferQueue syst
Currently, we support the following storage backends:
-- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability.
+- SimpleStorage: A basic CPU memory storage with minimal data format constraints and easy usability.
- [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem) (beta, [#PR107](https://github.com/TransferQueue/TransferQueue/pull/107), [#PR96](https://github.com/TransferQueue/TransferQueue/pull/96)): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD.
-- [Mooncake Store](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
-- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors.
+- [MooncakeStore](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
+- [RayRDT](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): A new Ray feature that allows objects to be stored and passed directly between Ray actors.
Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management.
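As context for the backend names above: elsewhere in this patch the storage manager configuration key changes from `storage_unit_infos` to `zmq_info`, and the manager type string changes from `AsyncSimpleStorageManager` to `SimpleStorage` (see the tests/test_client.py hunks below). A minimal selection sketch under those assumptions, where `client`, `controller`, and `storages` are illustrative stand-ins for the fixtures used in the tests and client construction is omitted:

```python
# Sketch only; mirrors the updated calls in tests/test_client.py.
config = {
    "controller_info": controller.zmq_server_info,
    # key renamed from "storage_unit_infos" in this patch
    "zmq_info": {s.storage_id: s.zmq_server_info for s in storages},
}
# manager type string renamed from "AsyncSimpleStorageManager" to "SimpleStorage"
client.initialize_storage_manager(manager_type="SimpleStorage", config=config)
```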
diff --git a/pyproject.toml b/pyproject.toml
index f853970..35d6524 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -126,5 +126,6 @@ yuanrong = [
# This is the rough equivalent of package_data={'': ['version/*']}
[tool.setuptools.package-data]
transfer_queue = [
- "version/*",
+ "version/*",
+ "*.yaml"
]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 1da090c..8cfccac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ pyzmq
hydra-core
numpy<2.0.0
msgspec
-psutil
\ No newline at end of file
+psutil
+omegaconf
\ No newline at end of file
diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py
index 3949a0d..5187e16 100644
--- a/tests/test_async_simple_storage_manager.py
+++ b/tests/test_async_simple_storage_manager.py
@@ -62,8 +62,7 @@ async def mock_async_storage_manager():
)
config = {
- "storage_unit_infos": storage_unit_infos,
- "controller_info": controller_info,
+ "zmq_info": storage_unit_infos,
}
# Mock the handshake process entirely to avoid ZMQ complexity
@@ -199,8 +198,7 @@ async def test_async_storage_manager_mapping_functions():
)
config = {
- "storage_unit_infos": storage_unit_infos,
- "controller_info": controller_info,
+ "zmq_info": storage_unit_infos,
}
# Mock ZMQ operations
@@ -230,7 +228,7 @@ async def test_async_storage_manager_mapping_functions():
mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize())
# Create manager
- manager = AsyncSimpleStorageManager(config)
+ manager = AsyncSimpleStorageManager(controller_info, config)
# Test round-robin mapping for 3 storage units
# global_index -> storage_unit mapping: 0->storage_0, 1->storage_1, 2->storage_2,
@@ -266,7 +264,7 @@ async def test_async_storage_manager_error_handling():
}
# Mock controller info
- controller_infos = ZMQServerInfo(
+ controller_info = ZMQServerInfo(
role=TransferQueueRole.CONTROLLER,
id="controller_0",
ip="127.0.0.1",
@@ -274,8 +272,7 @@ async def test_async_storage_manager_error_handling():
)
config = {
- "storage_unit_infos": storage_unit_infos,
- "controller_info": controller_infos,
+ "zmq_info": storage_unit_infos,
}
# Mock ZMQ operations
@@ -305,7 +302,7 @@ async def test_async_storage_manager_error_handling():
mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize())
# Create manager
- manager = AsyncSimpleStorageManager(config)
+ manager = AsyncSimpleStorageManager(controller_info, config)
# Mock operations that raise exceptions
manager._put_to_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock PUT error"))
diff --git a/tests/test_client.py b/tests/test_client.py
index 38f140b..5d3360f 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -34,7 +34,7 @@
FieldMeta,
SampleMeta,
)
-from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402
+from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole # noqa: E402
from transfer_queue.utils.zmq_utils import ( # noqa: E402
ZMQMessage,
ZMQRequestType,
@@ -144,6 +144,12 @@ def _handle_requests(self):
"message": "Consumption reset successfully",
}
response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE
+ elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
+ response_body = self._mock_kv_retrieve_keys(request_msg.body)
+ response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE
+ elif request_msg.request_type == ZMQRequestType.KV_LIST:
+ response_body = self._mock_kv_list(request_msg.body)
+ response_type = ZMQRequestType.KV_LIST_RESPONSE
else:
response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
response_type = ZMQRequestType.CLEAR_META_RESPONSE
@@ -174,7 +180,7 @@ def _mock_batch_meta(self, request_body):
name=field_name,
dtype=None,
shape=None,
- production_status=0,
+ production_status=ProductionStatus.NOT_PRODUCED,
)
fields.append(field_meta)
sample = SampleMeta(
@@ -187,6 +193,80 @@ def _mock_batch_meta(self, request_body):
return {"metadata": metadata}
+ def _mock_kv_retrieve_keys(self, request_body):
+ """Mock KV retrieve keys response."""
+ keys = request_body.get("keys", [])
+ create = request_body.get("create", False)
+ partition_id = request_body.get("partition_id", "")
+
+ # Initialize key tracking if not exists
+ if not hasattr(self, "_kv_partition_keys"):
+ self._kv_partition_keys = {}
+
+ # Generate global indexes for the keys
+ start_index = self._get_next_kv_index(partition_id)
+ global_indexes = list(range(start_index, start_index + len(keys)))
+
+ # Create metadata for each key
+ samples = []
+ for i, key in enumerate(keys):
+ field_meta = FieldMeta(
+ name="data",
+ dtype=torch.float32,
+ shape=torch.Size([1, 10]),
+ production_status=ProductionStatus.READY_FOR_CONSUME,
+ )
+ sample = SampleMeta(
+ partition_id=partition_id,
+ global_index=global_indexes[i],
+ fields={"data": field_meta},
+ )
+ samples.append(sample)
+
+ metadata = BatchMeta(samples=samples)
+
+ # Store keys for this partition (only when create=True)
+ if create:
+ if partition_id not in self._kv_partition_keys:
+ self._kv_partition_keys[partition_id] = []
+ self._kv_partition_keys[partition_id].extend(keys)
+
+ # Update the next index for this partition
+ if global_indexes:
+ self._update_kv_index(partition_id, global_indexes[-1] + 1)
+
+ return {"metadata": metadata}
+
+ def _mock_kv_list(self, request_body):
+ """Mock KV list response."""
+ partition_id = request_body.get("partition_id", "")
+
+ # Initialize key tracking if not exists
+ if not hasattr(self, "_kv_partition_keys"):
+ self._kv_partition_keys = {}
+
+ # Return cached keys for this partition
+ keys = self._kv_partition_keys.get(partition_id, [])
+ return {"keys": keys, "custom_meta": [{} for _ in range(len(keys))]}
+
+ def _get_next_kv_index(self, partition_id):
+ """Get next available index for KV keys in partition."""
+ if not hasattr(self, "_kv_index_map"):
+ self._kv_index_map = {}
+ if partition_id not in self._kv_index_map:
+ self._kv_index_map[partition_id] = 0
+ # Also initialize key tracking
+ if not hasattr(self, "_kv_partition_keys"):
+ self._kv_partition_keys = {}
+ self._kv_partition_keys[partition_id] = []
+ return self._kv_index_map[partition_id]
+
+ def _update_kv_index(self, partition_id, next_index):
+ """Update next available index for KV keys."""
+ if not hasattr(self, "_kv_index_map"):
+ self._kv_index_map = {}
+ self._kv_index_map[partition_id] = next_index
+
def stop(self):
self.running = False
time.sleep(0.2) # Give thread time to stop
@@ -318,9 +398,9 @@ def client_setup(mock_controller, mock_storage):
):
config = {
"controller_info": mock_controller.zmq_server_info,
- "storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info},
+ "zmq_info": {mock_storage.storage_id: mock_storage.zmq_server_info},
}
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
+ client.initialize_storage_manager(manager_type="SimpleStorage", config=config)
# Mock all storage manager methods to avoid real ZMQ operations
async def mock_put_data(data, metadata):
@@ -411,9 +491,9 @@ def test_single_controller_multiple_storages():
):
config = {
"controller_info": controller.zmq_server_info,
- "storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages},
+ "zmq_info": {s.storage_id: s.zmq_server_info for s in storages},
}
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
+ client.initialize_storage_manager(manager_type="SimpleStorage", config=config)
# Mock all storage manager methods to avoid real ZMQ operations
async def mock_put_data(data, metadata):
@@ -850,10 +930,10 @@ def test_set_custom_meta_sync(self, client_setup):
metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
# Set custom_meta on the metadata
metadata.update_custom_meta(
- {
- 0: {"input_ids": {"token_count": 100}},
- 1: {"input_ids": {"token_count": 120}},
- }
+ [
+ {"input_ids": {"token_count": 100}},
+ {"input_ids": {"token_count": 120}},
+ ]
)
# Call set_custom_meta with metadata (BatchMeta)
@@ -869,12 +949,174 @@ async def test_set_custom_meta_async(self, client_setup):
metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
# Set custom_meta on the metadata
metadata.update_custom_meta(
- {
- 0: {"input_ids": {"token_count": 100}},
- 1: {"input_ids": {"token_count": 120}},
- }
+ [
+ {"input_ids": {"token_count": 100}},
+ {"input_ids": {"token_count": 120}},
+ ]
)
# Call async_set_custom_meta with metadata (BatchMeta)
await client.async_set_custom_meta(metadata)
print("✓ async_set_custom_meta async method works")
+
+
+# =====================================================
+# KV Interface Tests
+# =====================================================
+
+
+class TestClientKVInterface:
+ """Tests for client KV interface methods."""
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_single(self, client_setup):
+ """Test async_kv_retrieve_keys with single key."""
+ client, _, _ = client_setup
+
+ # Test async_kv_retrieve_keys with single key
+ metadata = await client.async_kv_retrieve_keys(
+ keys="test_key_1",
+ partition_id="test_partition",
+ create=True,
+ )
+
+ # Verify metadata structure
+ assert metadata is not None
+ assert hasattr(metadata, "global_indexes")
+ assert hasattr(metadata, "size")
+ assert metadata.size == 1
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_multiple(self, client_setup):
+ """Test async_kv_retrieve_keys with multiple keys."""
+ client, _, _ = client_setup
+
+ # Test async_kv_retrieve_keys with multiple keys
+ keys = ["key_a", "key_b", "key_c"]
+ metadata = await client.async_kv_retrieve_keys(
+ keys=keys,
+ partition_id="test_partition",
+ create=True,
+ )
+
+ # Verify metadata structure
+ assert metadata is not None
+ assert hasattr(metadata, "global_indexes")
+ assert hasattr(metadata, "size")
+ assert metadata.size == 3
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_create_false(self, client_setup):
+ """Test async_kv_retrieve_keys with create=False (retrieve existing keys)."""
+ client, _, _ = client_setup
+
+        # First, create some keys
+ await client.async_kv_retrieve_keys(
+ keys="existing_key",
+ partition_id="existing_partition",
+ create=True,
+ )
+
+ # Then retrieve them with create=False
+ metadata = await client.async_kv_retrieve_keys(
+ keys="existing_key",
+ partition_id="existing_partition",
+ create=False,
+ )
+
+ # Verify metadata structure
+ assert metadata is not None
+ assert metadata.size == 1
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_invalid_keys_type(self, client_setup):
+ """Test async_kv_retrieve_keys raises error with invalid keys type."""
+ client, _, _ = client_setup
+
+ # Test with invalid keys type (not string or list)
+ with pytest.raises(TypeError):
+ await client.async_kv_retrieve_keys(
+ keys=123, # Invalid type
+ partition_id="test_partition",
+ create=True,
+ )
+
+ @pytest.mark.asyncio
+ async def test_async_kv_list_empty_partition(self, client_setup):
+ """Test async_kv_list with empty partition."""
+ client, _, _ = client_setup
+
+ # Test async_kv_list with empty partition
+ keys, custom_meta = await client.async_kv_list(partition_id="empty_partition")
+
+ # Should return empty list for partition with no keys
+ assert keys == []
+ assert custom_meta == []
+
+ @pytest.mark.asyncio
+ async def test_async_kv_list_with_keys(self, client_setup):
+ """Test async_kv_list returns keys after they are registered."""
+ client, mock_controller, _ = client_setup
+
+ # First register some keys
+ await client.async_kv_retrieve_keys(
+ keys=["key_1", "key_2"],
+ partition_id="kv_partition",
+ create=True,
+ )
+
+ # Then list them
+ keys, custom_meta = await client.async_kv_list(partition_id="kv_partition")
+
+ # Verify keys are returned
+ assert keys is not None
+ assert len(keys) >= 2
+ assert "key_1" in keys
+ assert "key_2" in keys
+ assert custom_meta == [{}, {}]
+
+ @pytest.mark.asyncio
+ async def test_async_kv_list_multiple_partitions(self, client_setup):
+ """Test async_kv_list with multiple partitions."""
+ client, _, _ = client_setup
+
+ # Create keys in different partitions
+ await client.async_kv_retrieve_keys(
+ keys="partition_a_key",
+ partition_id="partition_a",
+ create=True,
+ )
+ await client.async_kv_retrieve_keys(
+ keys="partition_b_key",
+ partition_id="partition_b",
+ create=True,
+ )
+
+ # List keys for each partition
+ keys_a, custom_meta_a = await client.async_kv_list(partition_id="partition_a")
+ keys_b, custom_meta_b = await client.async_kv_list(partition_id="partition_b")
+
+ # Verify keys are isolated per partition
+ assert "partition_a_key" in keys_a
+ assert "partition_b_key" not in keys_a
+ assert "partition_b_key" in keys_b
+ assert "partition_a_key" not in keys_b
+ assert custom_meta_a == [{}]
+ assert custom_meta_b == [{}]
+
+ def test_kv_retrieve_keys_type_validation(self, client_setup):
+ """Test synchronous kv_retrieve_keys type validation."""
+ import asyncio
+
+ client, _, _ = client_setup
+
+ # Test with non-string element in list
+ async def test_invalid_list():
+ with pytest.raises(TypeError):
+ await client.async_kv_retrieve_keys(
+ keys=["valid_key", 123], # Invalid: 123 is not a string
+ partition_id="test_partition",
+ create=True,
+ )
+
+ asyncio.run(test_invalid_list())
diff --git a/tests/test_controller.py b/tests/test_controller.py
index 09f5778..f3a5c26 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -760,3 +760,205 @@ def test_controller_with_custom_meta(self, ray_setup):
# Clean up
ray.get(tq_controller.clear_partition.remote(partition_id))
+
+
+class TestTransferQueueControllerKvInterface:
+ """End-to-end tests for TransferQueueController KV interface functionality.
+
+    Tests for the kv_retrieve_keys method, which supports key-value interface operations
+ across the controller and partition layers.
+ """
+
+ def test_controller_kv_retrieve_keys_create_mode(self, ray_setup):
+ """Test kv_retrieve_keys with create=True creates new keys in partition."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_test_partition"
+
+ # Retrieve keys with create=True - should create partition and keys
+ keys = ["key_a", "key_b", "key_c"]
+ metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+
+ # Verify partition was created
+ partitions = ray.get(tq_controller.list_partitions.remote())
+ assert partition_id in partitions
+
+ # Verify metadata contains correct number of global_indexes
+ assert len(metadata.global_indexes) == len(keys)
+
+ # Verify partition has keys_mapping
+ partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
+ assert "key_a" in partition.keys_mapping
+ assert "key_b" in partition.keys_mapping
+ assert "key_c" in partition.keys_mapping
+ assert metadata.global_indexes[0] == partition.keys_mapping["key_a"]
+ assert metadata.global_indexes[1] == partition.keys_mapping["key_b"]
+ assert metadata.global_indexes[2] == partition.keys_mapping["key_c"]
+ assert partition.revert_keys_mapping[metadata.global_indexes[0]] == "key_a"
+ assert partition.revert_keys_mapping[metadata.global_indexes[1]] == "key_b"
+ assert partition.revert_keys_mapping[metadata.global_indexes[2]] == "key_c"
+
+ print("✓ kv_retrieve_keys with create=True creates keys correctly")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_existing_keys(self, ray_setup):
+ """Test kv_retrieve_keys retrieves existing keys correctly."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_existing_test"
+
+ # First, create some keys
+ keys = ["existing_key_1", "existing_key_2"]
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+
+ # Retrieve the same keys again (should return existing)
+ retrieved_metadata = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False)
+ )
+
+ # Verify the same global_indexes are returned
+ assert len(retrieved_metadata.global_indexes) == len(keys)
+
+ print("✓ kv_retrieve_keys retrieves existing keys correctly")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_non_existent_without_create(self, ray_setup):
+ """Test kv_retrieve_keys raises error for non-existent keys without create."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_nonexistent_test"
+
+ # Create partition first
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=["initial_key"], partition_id=partition_id, create=True))
+
+ # Try to retrieve non-existent key without create
+ batch_meta = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=["nonexistent_key"], partition_id=partition_id, create=False)
+ )
+ assert batch_meta.size == 0
+
+ print("✓ kv_retrieve_keys return an empty BatchMeta for non-existent keys without create")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_empty_partition_without_create(self, ray_setup):
+ """Test kv_retrieve_keys raises error for non-existent partition without create."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "nonexistent_partition"
+
+ batch_meta = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=["key_1"], partition_id=partition_id, create=False)
+ )
+ assert batch_meta.size == 0
+
+ print("✓ kv_retrieve_keys return an empty BatchMeta for non-existent partition_id without create")
+
+ def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup):
+ """Test kv_retrieve_keys works with production status update."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_production_test"
+
+ # Create keys
+ keys = ["sample_1", "sample_2", "sample_3"]
+ metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+ global_indexes = metadata.global_indexes
+
+ # Update production status
+ dtypes = {idx: {"data": "torch.float32"} for idx in global_indexes}
+ shapes = {idx: {"data": (64,)} for idx in global_indexes}
+ success = ray.get(
+ tq_controller.update_production_status.remote(
+ partition_id=partition_id,
+ global_indexes=global_indexes,
+ field_names=["data"],
+ dtypes=dtypes,
+ shapes=shapes,
+ )
+ )
+ assert success
+
+ # Retrieve keys again (should include production info)
+ retrieved_metadata = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False)
+ )
+
+ # Verify production status is available
+ assert len(retrieved_metadata.samples) == len(keys)
+ for sample in retrieved_metadata.samples:
+ assert "data" in sample.fields
+ assert sample.fields["data"].dtype == "torch.float32"
+ assert sample.fields["data"].shape == (64,)
+
+ print("✓ kv_retrieve_keys works with production status")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup):
+ """Test kv_retrieve_keys preserves custom_meta through retrieve."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_custom_meta_test"
+
+ # Create keys
+ keys = ["key_1", "key_2"]
+ metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+
+ # Set custom_meta
+ custom_meta = {
+ partition_id: {
+ metadata.global_indexes[0]: {"score": 0.9, "tag": "A"},
+ metadata.global_indexes[1]: {"score": 0.8, "tag": "B"},
+ }
+ }
+ ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta))
+
+ # Retrieve keys and verify custom_meta
+ retrieved_metadata = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False)
+ )
+
+ # Verify custom_meta is preserved
+ all_custom_meta = retrieved_metadata.get_all_custom_meta()
+ assert len(all_custom_meta) == 2
+ assert all_custom_meta[metadata.global_indexes[0]]["score"] == 0.9
+ assert all_custom_meta[metadata.global_indexes[1]]["tag"] == "B"
+
+ print("✓ kv_retrieve_keys preserves custom_meta")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_interface_multiple_partitions(self, ray_setup):
+ """Test KV interface works correctly across multiple partitions."""
+ tq_controller = TransferQueueController.remote()
+
+ # Create keys in partition 1
+ partition_1 = "partition_kv_1"
+ keys_1 = ["p1_key_a", "p1_key_b"]
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_1, partition_id=partition_1, create=True))
+
+ # Create keys in partition 2
+ partition_2 = "partition_kv_2"
+ keys_2 = ["p2_key_x", "p2_key_y", "p2_key_z"]
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_2, partition_id=partition_2, create=True))
+
+ # Verify partitions are isolated
+ partition_1_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_1))
+ partition_2_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_2))
+
+ assert "p1_key_a" in partition_1_snapshot.keys_mapping
+ assert "p1_key_b" in partition_1_snapshot.keys_mapping
+ assert "p2_key_x" in partition_2_snapshot.keys_mapping
+ assert "p2_key_z" in partition_2_snapshot.keys_mapping
+
+ # Verify cross-partition access is isolated
+ assert "p2_key_x" not in partition_1_snapshot.keys_mapping
+ assert "p1_key_a" not in partition_2_snapshot.keys_mapping
+
+ print("✓ KV interface maintains partition isolation")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_1))
+ ray.get(tq_controller.clear_partition.remote(partition_2))
diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py
index 31478ba..6eba7a0 100644
--- a/tests/test_controller_data_partitions.py
+++ b/tests/test_controller_data_partitions.py
@@ -941,3 +941,62 @@ def test_custom_meta_cleared_with_data(self):
result = partition.get_custom_meta([0, 1])
assert 0 not in result
assert 1 in result # Sample 1 should still have custom_meta
+
+
+class TestDataPartitionStatusKvInterface:
+ """Unit tests for DataPartitionStatus KV interface functionality.
+
+ Tests for the keys_mapping and kv_retrieve_keys methods that support
+ key-value interface operations within a partition.
+ """
+
+ def test_kv_retrieve_keys_with_existing_keys(self):
+ """Test kv_retrieve_keys returns correct global_indexes for existing keys."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ # Simulate keys being registered (as would happen during kv_put)
+ partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2}
+
+ # Retrieve keys
+ global_indexes = partition.kv_retrieve_keys(["key_a", "key_b", "key_c"])
+
+ assert global_indexes == [0, 1, 2]
+
+ def test_kv_retrieve_keys_with_nonexistent_keys(self):
+ """Test kv_retrieve_keys returns None for keys that don't exist."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ # Simulate some keys being registered
+ partition.keys_mapping = {"existing_key": 5}
+
+ # Retrieve mixed existing and non-existing keys
+ global_indexes = partition.kv_retrieve_keys(["existing_key", "nonexistent_key"])
+
+ assert global_indexes == [5, None]
+
+ def test_kv_retrieve_keys_empty_list(self):
+ """Test kv_retrieve_keys handles empty key list."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ global_indexes = partition.kv_retrieve_keys([])
+
+ assert global_indexes == []
+
+ def test_kv_retrieve_keys_partial_match(self):
+ """Test kv_retrieve_keys with partial key matches."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30}
+
+ # Request only some of the keys
+ global_indexes = partition.kv_retrieve_keys(["key_1", "key_3"])
+
+ assert global_indexes == [10, 30]
diff --git a/tests/test_kv_interface.py b/tests/test_kv_interface.py
new file mode 100644
index 0000000..24d74af
--- /dev/null
+++ b/tests/test_kv_interface.py
@@ -0,0 +1,539 @@
+# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2025 The TransferQueue Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for kv interface in transfer_queue.interface."""
+
+import asyncio
+import sys
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+import torch
+from tensordict import TensorDict
+
+# Setup path
+parent_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(parent_dir))
+
+from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
+from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
+
+
+def create_batch_meta(global_indexes, fields_data):
+ """Helper to create BatchMeta for testing."""
+ samples = []
+ for sample_id, global_idx in enumerate(global_indexes):
+ fields_dict = {}
+ for field_name, tensor in fields_data.items():
+ field_meta = FieldMeta(
+ name=field_name,
+ dtype=tensor.dtype,
+ shape=tensor.shape,
+ production_status=ProductionStatus.READY_FOR_CONSUME,
+ )
+ fields_dict[field_name] = field_meta
+ sample = SampleMeta(
+ partition_id="test_partition",
+ global_index=global_idx,
+ fields=fields_dict,
+ )
+ samples.append(sample)
+ return BatchMeta(samples=samples)
+
+
+class TestKVPut:
+ """Tests for kv_put function."""
+
+ def test_kv_put_with_fields(self):
+ """Test kv_put with fields parameter."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ tensor_data = TensorDict(
+ {"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])},
+ batch_size=[1],
+ )
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_put
+
+ kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"})
+
+ # Verify kv_retrieve_keys was called
+ mock_client.kv_retrieve_keys.assert_called_once_with(keys=["test_key"], partition_id="partition_1", create=True)
+
+ # Verify update_custom_meta was called
+ mock_batch_meta.update_custom_meta.assert_called_once()
+
+ # Verify put was called
+ mock_client.put.assert_called_once()
+
+ def test_kv_put_with_dict_fields(self):
+ """Test kv_put converts dict to TensorDict correctly."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_put
+
+ # Test with dict containing tensor
+ kv_put(
+ key="test_key",
+ partition_id="partition_1",
+ fields={"text": torch.tensor([1, 2, 3])},
+ tag=None,
+ )
+
+ # Verify put was called
+ mock_client.put.assert_called_once()
+ call_args = mock_client.put.call_args
+ fields_arg = call_args[0][0]
+ assert "text" in fields_arg
+ # The dict should be converted to TensorDict
+ assert isinstance(fields_arg, TensorDict)
+
+ def test_kv_put_with_tag_only(self):
+ """Test kv_put with only tag parameter (no fields)."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_put
+
+ kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9})
+
+ # Verify put was NOT called (only set_custom_meta)
+ mock_client.put.assert_not_called()
+ mock_client.set_custom_meta.assert_called_once_with(mock_batch_meta)
+
+ def test_kv_put_raises_error_without_fields_and_tag(self):
+ """Test kv_put raises ValueError when neither fields nor tag provided."""
+ mock_client = MagicMock()
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_put
+
+ with pytest.raises(ValueError, match="Please provide at least one parameter"):
+ kv_put(key="test_key", partition_id="partition_1", fields=None, tag=None)
+
+
+class TestKVBatchPut:
+ """Tests for kv_batch_put function."""
+
+ def test_kv_batch_put_success(self):
+ """Test kv_batch_put with valid inputs."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ batch_data = TensorDict(
+ {
+ "text": torch.tensor([[1, 2], [3, 4], [5, 6]]),
+ "label": torch.tensor([0, 1, 2]),
+ },
+ batch_size=[3],
+ )
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_batch_put
+
+ keys = ["key1", "key2", "key3"]
+ tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}]
+
+ kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags)
+
+ mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True)
+ mock_batch_meta.update_custom_meta.assert_called_once_with(tags)
+ mock_client.put.assert_called_once()
+
+ def test_kv_batch_put_tags_length_mismatch(self):
+ """Test kv_batch_put raises error when tags length doesn't match keys."""
+ mock_client = MagicMock()
+
+ batch_data = TensorDict(
+ {
+ "text": torch.tensor([[1, 2], [3, 4], [5, 6]]),
+ "label": torch.tensor([0, 1, 2]),
+ },
+ batch_size=[3],
+ )
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_batch_put
+
+ keys = ["key1", "key2", "key3"]
+ tags = [{"tag": "v1"}, {"tag": "v2"}] # Only 2 tags for 3 keys
+
+ with pytest.raises(ValueError, match="does not match length of tags"):
+ kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags)
+
+
+class TestKVGet:
+ """Tests for kv_get function."""
+
+ def test_kv_get_single_key(self):
+ """Test kv_get with single key."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+ mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])})
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_get
+
+ kv_get(keys="test_key", partition_id="partition_1")
+
+ # keys is passed directly (not wrapped in list) for single key
+ mock_client.kv_retrieve_keys.assert_called_once_with(keys="test_key", partition_id="partition_1", create=False)
+ mock_client.get_data.assert_called_once_with(mock_batch_meta)
+
+ def test_kv_get_multiple_keys(self):
+ """Test kv_get with multiple keys."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+ mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])})
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_get
+
+ keys = ["key1", "key2", "key3"]
+ kv_get(keys=keys, partition_id="partition_1")
+
+ mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=False)
+
+ def test_kv_get_with_fields(self):
+ """Test kv_get with specific fields."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_batch_meta.select_fields = MagicMock(return_value=mock_batch_meta)
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+ mock_client.get_data.return_value = TensorDict({"text": torch.tensor([1, 2])})
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_get
+
+ kv_get(keys="test_key", partition_id="partition_1", fields="text")
+
+ mock_batch_meta.select_fields.assert_called_once_with(["text"])
+
+
+class TestKVList:
+ """Tests for kv_list function."""
+
+ def test_kv_list_with_keys(self):
+ """Test kv_list returns keys and custom_meta."""
+ mock_client = MagicMock()
+ mock_client.kv_list.return_value = ["key1", "key2", "key3"]
+ mock_batch_meta = MagicMock()
+ mock_batch_meta.global_indexes = [0, 1, 2]
+ mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}})
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_list
+
+ keys, custom_meta = kv_list(partition_id="partition_1")
+
+ assert keys == ["key1", "key2", "key3"]
+ assert len(custom_meta) == 3
+
+ def test_kv_list_empty_partition(self):
+ """Test kv_list returns None when partition is empty."""
+ mock_client = MagicMock()
+ mock_client.kv_list.return_value = [], []
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_list
+
+ keys, custom_meta = kv_list(partition_id="empty_partition")
+
+ assert keys == []
+ assert custom_meta == []
+
+
+class TestKVClear:
+ """Tests for kv_clear function."""
+
+ def test_kv_clear_single_key(self):
+ """Test kv_clear with single key."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_batch_meta.size = 1
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_clear
+
+ kv_clear(keys="test_key", partition_id="partition_1")
+
+ mock_client.kv_retrieve_keys.assert_called_once_with(
+ keys=["test_key"], partition_id="partition_1", create=False
+ )
+ mock_client.clear_samples.assert_called_once_with(mock_batch_meta)
+
+ def test_kv_clear_multiple_keys(self):
+ """Test kv_clear with multiple keys."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_batch_meta.size = 3
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_clear
+
+ kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1")
+
+ mock_client.kv_retrieve_keys.assert_called_once_with(
+ keys=["key1", "key2", "key3"], partition_id="partition_1", create=False
+ )
+ mock_client.clear_samples.assert_called_once()
+
+
+class TestAsyncKVPut:
+ """Tests for async_kv_put function."""
+
+ def test_async_kv_put_with_fields(self):
+ """Test async_kv_put with fields parameter."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
+ mock_client.async_put = AsyncMock()
+ mock_client.async_set_custom_meta = AsyncMock()
+
+ tensor_data = TensorDict(
+ {"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])},
+ batch_size=[1],
+ )
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_put
+
+ asyncio.run(
+ async_kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"})
+ )
+
+ mock_client.async_kv_retrieve_keys.assert_called_once_with(
+ keys=["test_key"], partition_id="partition_1", create=True
+ )
+ mock_batch_meta.update_custom_meta.assert_called_once()
+ mock_client.async_put.assert_called_once()
+
+ def test_async_kv_put_with_tag_only(self):
+ """Test async_kv_put with only tag (no fields)."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
+ mock_client.async_put = AsyncMock()
+ mock_client.async_set_custom_meta = AsyncMock()
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_put
+
+ asyncio.run(async_kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9}))
+
+ mock_client.async_put.assert_not_called()
+ mock_client.async_set_custom_meta.assert_called_once_with(mock_batch_meta)
+
+
+class TestAsyncKVBatchPut:
+ """Tests for async_kv_batch_put function."""
+
+ def test_async_kv_batch_put_success(self):
+ """Test async_kv_batch_put with valid inputs."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
+ mock_client.async_put = AsyncMock()
+ mock_client.async_set_custom_meta = AsyncMock()
+
+ batch_data = TensorDict(
+ {
+ "text": torch.tensor([[1, 2], [3, 4], [5, 6]]),
+ "label": torch.tensor([0, 1, 2]),
+ },
+ batch_size=[3],
+ )
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_batch_put
+
+ keys = ["key1", "key2", "key3"]
+ tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}]
+
+ asyncio.run(async_kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags))
+
+ mock_client.async_kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True)
+ mock_batch_meta.update_custom_meta.assert_called_once_with(tags)
+ mock_client.async_put.assert_called_once()
+
+
+class TestAsyncKVGet:
+ """Tests for async_kv_get function."""
+
+ def test_async_kv_get_single_key(self):
+ """Test async_kv_get with single key."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
+ mock_client.async_get_data = AsyncMock(return_value=TensorDict({"data": torch.tensor([1, 2])}))
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_get
+
+ asyncio.run(async_kv_get(keys="test_key", partition_id="partition_1"))
+
+ # keys is passed directly (not wrapped in list) for single key
+ mock_client.async_kv_retrieve_keys.assert_called_once_with(
+ keys="test_key", partition_id="partition_1", create=False
+ )
+ mock_client.async_get_data.assert_called_once_with(mock_batch_meta)
+
+
+class TestAsyncKVList:
+ """Tests for async_kv_list function."""
+
+ def test_async_kv_list_with_keys(self):
+ """Test async_kv_list returns keys and custom_meta."""
+ mock_client = MagicMock()
+ mock_client.async_kv_list = AsyncMock(return_value=["key1", "key2", "key3"])
+ mock_batch_meta = MagicMock()
+ mock_batch_meta.global_indexes = [0, 1, 2]
+ mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}})
+ mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_list
+
+ keys, custom_meta = asyncio.run(async_kv_list(partition_id="partition_1"))
+
+ assert keys == ["key1", "key2", "key3"]
+ assert len(custom_meta) == 3
+
+ def test_async_kv_list_empty_partition(self):
+ """Test async_kv_list returns None when partition is empty."""
+ mock_client = MagicMock()
+ mock_client.async_kv_list = AsyncMock(return_value=[])
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_list
+
+ keys, custom_meta = asyncio.run(async_kv_list(partition_id="empty_partition"))
+
+ assert keys is None
+ assert custom_meta is None
+
+
+class TestAsyncKVClear:
+ """Tests for async_kv_clear function."""
+
+ def test_async_kv_clear_single_key(self):
+ """Test async_kv_clear with single key."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_batch_meta.size = 1
+ mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
+ mock_client.async_clear_samples = AsyncMock()
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_clear
+
+ asyncio.run(async_kv_clear(keys="test_key", partition_id="partition_1"))
+
+ mock_client.async_kv_retrieve_keys.assert_called_once_with(
+ keys=["test_key"], partition_id="partition_1", create=False
+ )
+ mock_client.async_clear_samples.assert_called_once_with(mock_batch_meta)
+
+ def test_async_kv_clear_multiple_keys(self):
+ """Test async_kv_clear with multiple keys."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_batch_meta.size = 3
+ mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
+ mock_client.async_clear_samples = AsyncMock()
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import async_kv_clear
+
+ asyncio.run(async_kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1"))
+
+ mock_client.async_kv_retrieve_keys.assert_called_once()
+ mock_client.async_clear_samples.assert_called_once()
+
+
+class TestKVInterfaceDictConversion:
+ """Tests for dict to TensorDict conversion in kv_put."""
+
+ def test_kv_put_with_nontensor_value(self):
+ """Test kv_put converts non-tensor values using NonTensorStack."""
+ mock_client = MagicMock()
+ mock_batch_meta = MagicMock()
+ mock_client.kv_retrieve_keys.return_value = mock_batch_meta
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_put
+
+ # Test with non-tensor value (like a string or list)
+ kv_put(
+ key="test_key",
+ partition_id="partition_1",
+ fields={"meta": {"key": "value"}},
+ tag=None,
+ )
+
+ # Verify put was called
+ mock_client.put.assert_called_once()
+ call_args = mock_client.put.call_args
+ fields_arg = call_args[0][0]
+ # The dict should be converted to TensorDict
+ assert isinstance(fields_arg, TensorDict)
+ assert "meta" in fields_arg
+
+ def test_kv_put_rejects_nested_tensor(self):
+ """Test kv_put raises ValueError for nested tensors (requires batch_put)."""
+ mock_client = MagicMock()
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_put
+
+ nested_tensor = torch.nested.nested_tensor([[1, 2], [3, 4]])
+
+ with pytest.raises(ValueError, match="Please use.*kv_batch_put"):
+ kv_put(
+ key="test_key",
+ partition_id="partition_1",
+ fields={"nested": nested_tensor},
+ tag=None,
+ )
+
+ def test_kv_put_invalid_fields_type(self):
+ """Test kv_put raises ValueError for invalid fields type."""
+ mock_client = MagicMock()
+
+ with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
+ from transfer_queue.interface import kv_put
+
+ with pytest.raises(ValueError, match="field can only be dict or TensorDict"):
+ kv_put(
+ key="test_key",
+ partition_id="partition_1",
+ fields="invalid_string",
+ tag=None,
+ )
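The mocked tests above exercise the public helpers added in `transfer_queue/interface.py`. A minimal usage sketch that mirrors those call signatures; it assumes a TransferQueue deployment is already running and reachable, and that the module-level client is created internally (as the patched `_maybe_create_transferqueue_client` suggests):

```python
# Sketch only; argument names follow the signatures exercised in the tests above.
import torch
from tensordict import TensorDict
from transfer_queue.interface import kv_batch_put, kv_clear, kv_get, kv_list, kv_put

# Single-sample put: fields may be a TensorDict or a plain dict of tensors.
kv_put(key="sample_0", partition_id="train",
       fields={"text": torch.tensor([1, 2, 3])}, tag={"score": 0.9})

# Batched put: fields batched along the first dimension, one tag per key.
batch = TensorDict({"text": torch.tensor([[1, 2], [3, 4]])}, batch_size=[2])
kv_batch_put(keys=["sample_1", "sample_2"], partition_id="train",
             fields=batch, tags=[{"v": 1}, {"v": 2}])

data = kv_get(keys=["sample_0", "sample_1"], partition_id="train")
keys, tags = kv_list(partition_id="train")
kv_clear(keys=keys, partition_id="train")
```

The async variants (`async_kv_put`, `async_kv_batch_put`, `async_kv_get`, `async_kv_list`, `async_kv_clear`) take the same arguments and are awaited.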
diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py
index 41296dd..ec5b29c 100644
--- a/tests/test_kv_storage_manager.py
+++ b/tests/test_kv_storage_manager.py
@@ -27,6 +27,7 @@
from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402
+from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
def get_meta(data, global_indexes=None):
@@ -41,7 +42,7 @@ def get_meta(data, global_indexes=None):
name=field_name,
dtype=tensor.dtype if isinstance(tensor, torch.Tensor) else None,
shape=tensor.shape if isinstance(tensor, torch.Tensor) else None,
- production_status=1,
+ production_status=ProductionStatus.READY_FOR_CONSUME,
)
fields_dict[field_name] = field_meta
sample = SampleMeta(
@@ -117,7 +118,7 @@ def test_merge_tensors_to_tensordict(mock_create, test_data):
mock_client = MagicMock()
mock_create.return_value = mock_client
- manager = KVStorageManager(test_data["cfg"])
+ manager = KVStorageManager(controller_info=MagicMock(), config=test_data["cfg"])
assert manager.storage_client is mock_client
assert manager._multi_threads_executor is None
@@ -163,8 +164,8 @@ def test_merge_tensors_to_tensordict(mock_create, test_data):
assert complex_tensordict[key] == complex_data[key]
-def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data):
- """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_meta."""
+def test_get_shape_type_custom_backend_meta_list_without_custom_backend_meta(test_data):
+ """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_backend_meta."""
shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(
test_data["metadata"]
)
@@ -185,7 +186,7 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data):
torch.Size([2]), # text[2]
]
expected_dtypes = [torch.int64] * (len(test_data["field_names"]) * len(test_data["global_indexes"]))
- # No custom_meta provided, so all should be None
+ # No custom_backend_meta provided, so all should be None
expected_custom_backend_meta = [None] * (len(test_data["field_names"]) * len(test_data["global_indexes"]))
assert shapes == expected_shapes
@@ -193,9 +194,9 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data):
assert custom_backend_meta_list == expected_custom_backend_meta
-def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data):
- """Test _get_shape_type_custom_meta_list returns correct custom_meta when provided."""
- # Add custom_meta to metadata
+def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_data):
+ """Test _get_shape_type_custom_backend_meta_list returns correct custom_backend_meta when provided."""
+ # Add custom_backend_meta to metadata
custom_backend_meta = {
8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}},
9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}},
@@ -206,8 +207,8 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data):
shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata)
- # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index
- expected_custom_meta = [
+ # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index
+ expected_custom_backend_meta = [
{"key2": "value2"}, # label, global_index=8
{"key5": "value5"}, # label, global_index=9
{"key8": "value8"}, # label, global_index=10
@@ -218,15 +219,15 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data):
{"key4": "value4"}, # text, global_index=9
{"key7": "value7"}, # text, global_index=10
]
- assert custom_backend_meta_list == expected_custom_meta
+ assert custom_backend_meta_list == expected_custom_backend_meta
-def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_data):
- """Test _get_shape_type_custom_backend_meta_list handles partial custom_meta correctly."""
- # Add custom_meta only for some global_indexes and fields
+def test_get_shape_type_custom_backend_meta_list_with_partial_custom_backend_meta(test_data):
+ """Test _get_shape_type_custom_backend_meta_list handles partial custom_backend_meta correctly."""
+ # Add custom_backend_meta only for some global_indexes and fields
custom_backend_meta = {
8: {"text": {"key1": "value1"}}, # Only text field
- # global_index 9 has no custom_meta
+ # global_index 9 has no custom_backend_meta
10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only
}
metadata = test_data["metadata"]
@@ -234,17 +235,17 @@ def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_d
shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata)
- # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index
+ # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index
expected_custom_backend_meta = [
- None, # label, global_index=8 (not in custom_meta)
- None, # label, global_index=9 (not in custom_meta)
+ None, # label, global_index=8 (not in custom_backend_meta)
+ None, # label, global_index=9 (not in custom_backend_meta)
{"key2": "value2"}, # label, global_index=10
- None, # mask, global_index=8 (not in custom_meta)
- None, # mask, global_index=9 (not in custom_meta)
+ None, # mask, global_index=8 (not in custom_backend_meta)
+ None, # mask, global_index=9 (not in custom_backend_meta)
{"key3": "value3"}, # mask, global_index=10
{"key1": "value1"}, # text, global_index=8
- None, # text, global_index=9 (not in custom_meta)
- None, # text, global_index=10 (not in custom_meta for text)
+ None, # text, global_index=9 (not in custom_backend_meta)
+ None, # text, global_index=10 (not in custom_backend_meta for text)
]
assert custom_backend_meta_list == expected_custom_backend_meta
@@ -279,13 +280,13 @@ def test_data_for_put_data():
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
-def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_for_put_data):
- """Test that put_data correctly processes custom_meta returned by storage client."""
+def test_put_data_with_custom_backend_meta_from_storage_client(mock_notify, test_data_for_put_data):
+ """Test that put_data correctly processes custom_backend_meta returned by storage client."""
# Create a mock storage client
mock_storage_client = MagicMock()
- # Simulate storage client returning custom_meta (one per key)
+ # Simulate storage client returning custom_backend_meta (one per key)
# Keys order: label[0,1,2], text[0,1,2] (sorted by field name)
- mock_custom_meta = [
+ mock_custom_backend_meta = [
{"storage_key": "0@label"},
{"storage_key": "1@label"},
{"storage_key": "2@label"},
@@ -293,12 +294,12 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
{"storage_key": "1@text"},
{"storage_key": "2@text"},
]
- mock_storage_client.put.return_value = mock_custom_meta
+ mock_storage_client.put.return_value = mock_custom_backend_meta
# Create manager with mocked dependencies
- config = {"controller_info": MagicMock(), "client_name": "MockClient"}
+ config = {"client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
- manager = KVStorageManager(config)
+ manager = KVStorageManager(controller_info=MagicMock(), config=config)
# Run put_data
asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))
@@ -314,64 +315,65 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
assert keys == expected_keys
assert len(values) == 6
- # Verify notify_data_update was called with correct custom_meta structure
+ # Verify notify_data_update was called with correct custom_backend_meta structure
mock_notify.assert_called_once()
notify_call_args = mock_notify.call_args
- per_field_custom_meta = notify_call_args[0][5] # 6th positional argument
+ per_field_custom_backend_meta = notify_call_args[0][5] # 6th positional argument
- # Verify custom_meta is structured correctly: {global_index: {field: meta}}
- assert 0 in per_field_custom_meta
- assert 1 in per_field_custom_meta
- assert 2 in per_field_custom_meta
+ # Verify custom_backend_meta is structured correctly: {global_index: {field: meta}}
+ assert 0 in per_field_custom_backend_meta
+ assert 1 in per_field_custom_backend_meta
+ assert 2 in per_field_custom_backend_meta
- assert per_field_custom_meta[0]["label"] == {"storage_key": "0@label"}
- assert per_field_custom_meta[0]["text"] == {"storage_key": "0@text"}
- assert per_field_custom_meta[1]["label"] == {"storage_key": "1@label"}
- assert per_field_custom_meta[1]["text"] == {"storage_key": "1@text"}
- assert per_field_custom_meta[2]["label"] == {"storage_key": "2@label"}
- assert per_field_custom_meta[2]["text"] == {"storage_key": "2@text"}
+ assert per_field_custom_backend_meta[0]["label"] == {"storage_key": "0@label"}
+ assert per_field_custom_backend_meta[0]["text"] == {"storage_key": "0@text"}
+ assert per_field_custom_backend_meta[1]["label"] == {"storage_key": "1@label"}
+ assert per_field_custom_backend_meta[1]["text"] == {"storage_key": "1@text"}
+ assert per_field_custom_backend_meta[2]["label"] == {"storage_key": "2@label"}
+ assert per_field_custom_backend_meta[2]["text"] == {"storage_key": "2@text"}
- # Verify metadata was updated with custom_meta
- all_custom_meta = test_data_for_put_data["metadata"].get_all_custom_meta()
- assert all_custom_meta[0]["label"] == {"storage_key": "0@label"}
- assert all_custom_meta[2]["text"] == {"storage_key": "2@text"}
+ # Verify metadata was updated with custom_backend_meta
+ all_custom_backend_meta = test_data_for_put_data["metadata"]._custom_backend_meta
+ assert len(all_custom_backend_meta) == 3
+ assert all_custom_backend_meta[0]["label"] == {"storage_key": "0@label"}
+ assert all_custom_backend_meta[2]["text"] == {"storage_key": "2@text"}
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
-def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
- """Test that put_data works correctly when storage client returns no custom_meta."""
- # Create a mock storage client that returns None for custom_meta
+def test_put_data_without_custom_backend_meta(mock_notify, test_data_for_put_data):
+ """Test that put_data works correctly when storage client returns no custom_backend_meta."""
+ # Create a mock storage client that returns None for custom_backend_meta
mock_storage_client = MagicMock()
mock_storage_client.put.return_value = None
# Create manager with mocked dependencies
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
- manager = KVStorageManager(config)
+ manager = KVStorageManager(controller_info=MagicMock(), config=config)
# Run put_data
asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))
- # Verify notify_data_update was called with empty dict for custom_meta
+ # Verify notify_data_update was called with empty dict for custom_backend_meta
mock_notify.assert_called_once()
notify_call_args = mock_notify.call_args
- per_field_custom_meta = notify_call_args[0][5] # 6th positional argument
- assert per_field_custom_meta == {}
+ per_field_custom_backend_meta = notify_call_args[0][5] # 6th positional argument
+ assert per_field_custom_backend_meta == {}
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
-def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_data):
- """Test that put_data raises ValueError when custom_meta length doesn't match keys."""
- # Create a mock storage client that returns mismatched custom_meta length
+def test_put_data_custom_backend_meta_length_mismatch_raises_error(test_data_for_put_data):
+ """Test that put_data raises ValueError when custom_backend_meta length doesn't match keys."""
+ # Create a mock storage client that returns mismatched custom_backend_meta length
mock_storage_client = MagicMock()
- # Return only 3 custom_meta entries when 6 are expected
+ # Return only 3 custom_backend_meta entries when 6 are expected
mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]
# Create manager with mocked dependencies
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
- manager = KVStorageManager(config)
+ manager = KVStorageManager(controller_info=MagicMock(), config=config)
# Run put_data and expect ValueError
with pytest.raises(ValueError) as exc_info:
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 2bbf40c..6db10a6 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -788,64 +788,6 @@ def test_batch_meta_select_samples_with_extra_info(self):
# =====================================================
# Custom Meta Tests
# =====================================================
-
- def test_batch_meta_set_custom_meta_basic(self):
- """Test set_custom_meta sets metadata for a sample by global_index."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Set custom_meta for sample 0
- batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"})
-
- result = batch.get_all_custom_meta()
- assert 0 in result
- assert result[0]["sample_score"] == 0.9
- assert result[0]["quality"] == "high"
- # Sample 1 should not have custom_meta
- assert 1 not in result
-
- def test_batch_meta_set_custom_meta_overwrites(self):
- """Test set_custom_meta overwrites existing metadata."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Set initial custom_meta
- batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"})
-
- # Overwrite with new custom_meta
- batch.set_custom_meta(0, {"sample_score": 0.1, "quality": "low"})
-
- result = batch.get_all_custom_meta()
- assert result[0]["sample_score"] == 0.1
- assert result[0]["quality"] == "low"
-
- def test_batch_meta_set_custom_meta_invalid_global_index(self):
- """Test set_custom_meta raises error for invalid global_index."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Try to set with non-existent global index
- with pytest.raises(ValueError) as exc_info:
- batch.set_custom_meta(999, {"sample_score": 0.9})
- assert "not found in global_indexes" in str(exc_info.value)
-
def test_batch_meta_update_custom_meta(self):
"""Test update_custom_meta adds metadata for different global indices."""
fields = {
@@ -859,10 +801,7 @@ def test_batch_meta_update_custom_meta(self):
batch = BatchMeta(samples=samples)
# Initial custom_meta for sample 0
- batch.update_custom_meta({0: {"sample_score": 0.9}})
-
- # Update with metadata for sample 1
- batch.update_custom_meta({1: {"sample_score": 0.1}})
+ batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}])
result = batch.get_all_custom_meta()
assert result[0]["sample_score"] == 0.9
@@ -879,10 +818,10 @@ def test_batch_meta_update_custom_meta_overwrites(self):
batch = BatchMeta(samples=samples)
# Initial custom_meta
- batch.update_custom_meta({0: {"sample_score": 0.9, "quality": "high"}})
+ batch.update_custom_meta([{"sample_score": 0.9, "quality": "high"}])
# Update with new value for same field - dict.update replaces
- batch.update_custom_meta({0: {"sample_score": 0.1, "quality": "low"}})
+ batch.update_custom_meta([{"sample_score": 0.1, "quality": "low"}])
result = batch.get_all_custom_meta()
assert result[0]["sample_score"] == 0.1
@@ -899,7 +838,7 @@ def test_batch_meta_update_custom_meta_with_none(self):
batch = BatchMeta(samples=samples)
# Set initial value
- batch.update_custom_meta({0: {"sample_score": 0.9}})
+ batch.update_custom_meta([{"sample_score": 0.9}])
# Update with None should not change anything
batch.update_custom_meta(None)
@@ -907,40 +846,6 @@ def test_batch_meta_update_custom_meta_with_none(self):
result = batch.get_all_custom_meta()
assert result[0]["sample_score"] == 0.9
- def test_batch_meta_update_custom_meta_with_empty_dict(self):
- """Test update_custom_meta with empty dict does nothing."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Set initial value
- batch.update_custom_meta({0: {"sample_score": 0.9}})
-
- # Update with empty dict should not change anything
- batch.update_custom_meta({})
-
- result = batch.get_all_custom_meta()
- assert result[0]["sample_score"] == 0.9
-
- def test_batch_meta_update_custom_meta_invalid_global_index(self):
- """Test update_custom_meta raises error for invalid global_index."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Try to update with non-existent global index
- with pytest.raises(ValueError) as exc_info:
- batch.update_custom_meta({999: {"sample_score": 0.9}})
- assert "non-exist global_indexes" in str(exc_info.value)
-
def test_batch_meta_clear_custom_meta(self):
"""Test clear_custom_meta removes all custom metadata."""
fields = {
@@ -953,14 +858,13 @@ def test_batch_meta_clear_custom_meta(self):
batch = BatchMeta(samples=samples)
# Set custom_meta
- batch.set_custom_meta(0, {"sample_score": 0.9})
- batch.set_custom_meta(1, {"sample_score": 0.1})
+ batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}])
# Clear all
batch.clear_custom_meta()
result = batch.get_all_custom_meta()
- assert result == {}
+ assert result == [{}, {}]
def test_batch_meta_get_all_custom_meta_returns_deep_copy(self):
"""Test get_all_custom_meta returns a deep copy."""
@@ -972,7 +876,7 @@ def test_batch_meta_get_all_custom_meta_returns_deep_copy(self):
]
batch = BatchMeta(samples=samples)
- custom_meta = {0: {"sample_score": 0.9, "nested": {"value": 1}}}
+ custom_meta = [{"sample_score": 0.9, "nested": {"value": 1}}]
batch.update_custom_meta(custom_meta)
# Get all custom_meta
@@ -997,7 +901,7 @@ def test_batch_meta_get_all_custom_meta_empty(self):
batch = BatchMeta(samples=samples)
result = batch.get_all_custom_meta()
- assert result == {}
+ assert result == [{}]
def test_batch_meta_custom_meta_with_nested_data(self):
"""Test custom_meta supports nested dictionary data."""
@@ -1013,7 +917,7 @@ def test_batch_meta_custom_meta_with_nested_data(self):
"model_info": {"name": "llama", "version": "7b", "config": {"hidden_size": 4096, "num_layers": 32}},
"tags": ["training", "inference"],
}
- batch.set_custom_meta(0, nested_meta)
+ batch.update_custom_meta([nested_meta])
result = batch.get_all_custom_meta()
assert result[0]["model_info"]["name"] == "llama"
diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py
index e925902..f0421c0 100644
--- a/transfer_queue/__init__.py
+++ b/transfer_queue/__init__.py
@@ -15,12 +15,35 @@
import os
-from .client import (
- TransferQueueClient,
- process_zmq_server_info,
-)
+from .client import TransferQueueClient
from .controller import TransferQueueController
from .dataloader import StreamingDataLoader, StreamingDataset
+from .interface import (
+ async_clear_partition,
+ async_clear_samples,
+ async_get_data,
+ async_get_meta,
+ async_kv_batch_put,
+ async_kv_clear,
+ async_kv_get,
+ async_kv_list,
+ async_kv_put,
+ async_put,
+ async_set_custom_meta,
+ clear_partition,
+ clear_samples,
+ close,
+ get_data,
+ get_meta,
+ init,
+ kv_batch_put,
+ kv_clear,
+ kv_get,
+ kv_list,
+ kv_put,
+ put,
+ set_custom_meta,
+)
from .metadata import BatchMeta
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
@@ -28,9 +51,34 @@
from .sampler.sequential_sampler import SequentialSampler
from .storage import SimpleStorageUnit
from .utils.common import get_placement_group
-from .utils.zmq_utils import ZMQServerInfo
+from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info
__all__ = [
+ "init",
+ "get_meta",
+ "get_data",
+ "put",
+ "set_custom_meta",
+ "clear_samples",
+ "clear_partition",
+ "async_get_meta",
+ "async_get_data",
+ "async_put",
+ "async_set_custom_meta",
+ "async_clear_samples",
+ "async_clear_partition",
+ "close",
+ "kv_put",
+ "kv_batch_put",
+ "kv_get",
+ "kv_list",
+ "kv_clear",
+ "async_kv_put",
+ "async_kv_batch_put",
+ "async_kv_get",
+ "async_kv_list",
+ "async_kv_clear",
+] + [
"TransferQueueClient",
"StreamingDataset",
"StreamingDataLoader",
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 24b4e55..f0c899b 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -18,23 +18,19 @@
import os
import threading
from functools import wraps
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional
from uuid import uuid4
-import ray
import torch
import zmq
import zmq.asyncio
from tensordict import TensorDict
from torch import Tensor
-from transfer_queue.controller import TransferQueueController
from transfer_queue.metadata import (
BatchMeta,
)
from transfer_queue.storage import (
- SimpleStorageUnit,
- TransferQueueStorageManager,
TransferQueueStorageManagerFactory,
)
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
@@ -95,11 +91,12 @@ def initialize_storage_manager(
AsyncSimpleStorageManager, KVStorageManager (under development), etc.
config: Configuration dictionary for the storage manager.
For AsyncSimpleStorageManager, must contain the following required keys:
- - controller_info: ZMQ server information about the controller
- - storage_unit_infos: ZMQ server information about the storage units
+ - zmq_info: ZMQ server information about the storage units
"""
- self.storage_manager = TransferQueueStorageManagerFactory.create(manager_type, config)
+ self.storage_manager = TransferQueueStorageManagerFactory.create(
+ manager_type, controller_info=self._controller, config=config
+ )
# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
@staticmethod
@@ -156,6 +153,7 @@ async def wrapper(self, *args, **kwargs):
return decorator
+ # ==================== Basic API ====================
@dynamic_socket(socket_name="request_handle_socket")
async def async_get_meta(
self,
@@ -218,7 +216,7 @@ async def async_get_meta(
"""
assert socket is not None
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_META,
+ request_type=ZMQRequestType.GET_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -294,10 +292,13 @@ async def async_set_custom_meta(
partition_custom_meta: dict[str, dict[int, dict]] = {pid: {} for pid in set(metadata.partition_ids)}
for meta in metadata_chunks:
- partition_custom_meta[meta.partition_ids[0]].update(meta.get_all_custom_meta())
+ custom_meta = meta.get_all_custom_meta()
+ partition_custom_meta[meta.partition_ids[0]].update(
+ {meta.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta))}
+ )
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.SET_CUSTOM_META,
+ request_type=ZMQRequestType.SET_CUSTOM_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -535,7 +536,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
"""
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.CLEAR_META,
+ request_type=ZMQRequestType.CLEAR_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={"global_indexes": metadata.global_indexes, "partition_ids": metadata.partition_ids},
@@ -563,7 +564,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
RuntimeError: If controller returns error response
"""
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_PARTITION_META,
+ request_type=ZMQRequestType.GET_PARTITION_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={"partition_id": partition_id},
@@ -608,6 +609,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
if response_msg.request_type != ZMQRequestType.CLEAR_PARTITION_RESPONSE:
raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")
+ # ==================== Status Query API ====================
@dynamic_socket(socket_name="request_handle_socket")
async def async_get_consumption_status(
self,
@@ -641,7 +643,7 @@ async def async_get_consumption_status(
assert socket is not None
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_CONSUMPTION,
+ request_type=ZMQRequestType.GET_CONSUMPTION, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -678,7 +680,7 @@ async def async_get_production_status(
partition_id: str,
socket: Optional[zmq.asyncio.Socket] = None,
) -> tuple[Optional[Tensor], Optional[Tensor]]:
- """Get production status for current partition for specific fields.
+ """Get production status for specific data fields and partition.
Args:
data_fields: Data fields to check production status for
@@ -703,7 +705,7 @@ async def async_get_production_status(
"""
assert socket is not None
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_PRODUCTION,
+ request_type=ZMQRequestType.GET_PRODUCTION, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -768,6 +770,41 @@ async def async_check_consumption_status(
return False
return torch.all(consumption_status == 1).item()
+ async def async_check_production_status(
+ self,
+ data_fields: list[str],
+ partition_id: str,
+ ) -> bool:
+ """Check if the all specific fields of samples for current partition are ready
+ (produced) for consumption.
+
+ Args:
+ data_fields: Data fields to check production status for
+ partition_id: Partition id to check production status for
+
+ Returns:
+ bool: True if all samples have been produced and ready, False otherwise
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> # Check if all samples are ready for consumption
+ >>> is_ready = asyncio.run(client.async_check_production_status(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... partition_id="train_0"
+ ... ))
+ >>> print(f"All samples ready: {is_ready}")
+ """
+ _, production_status = await self.async_get_production_status(
+ data_fields=data_fields,
+ partition_id=partition_id,
+ )
+
+ if production_status is None:
+ return False
+ return torch.all(production_status == 1).item()
+
@dynamic_socket(socket_name="request_handle_socket")
async def async_reset_consumption(
self,
@@ -775,17 +812,22 @@ async def async_reset_consumption(
task_name: Optional[str] = None,
socket: Optional[zmq.asyncio.Socket] = None,
) -> bool:
- """Reset consumption status for a partition, allowing data to be re-consumed.
- This is useful for debugging scenarios where the same rollout data needs to be
- trained multiple times without regenerating the data.
+ """Aynchronously reset consumption status for a partition.
+
+ This allows the same data to be re-consumed, useful for debugging scenarios
+ where the same rollout data needs to be trained multiple times.
+
Args:
partition_id: Partition id to reset consumption status for
task_name: Name of the task to reset. If None, resets all tasks.
socket: ZMQ async socket for message transmission (injected by decorator)
+
Returns:
bool: True if reset was successful, False otherwise
+
Raises:
RuntimeError: If communication fails or controller returns error response
+
Example:
>>> # Reset consumption for train task to re-train on same data
>>> success = asyncio.run(client.async_reset_consumption(
@@ -799,7 +841,7 @@ async def async_reset_consumption(
if task_name is not None:
body["task_name"] = task_name
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.RESET_CONSUMPTION,
+ request_type=ZMQRequestType.RESET_CONSUMPTION, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body=body,
@@ -825,41 +867,6 @@ async def async_reset_consumption(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e
- async def async_check_production_status(
- self,
- data_fields: list[str],
- partition_id: str,
- ) -> bool:
- """Check if the all specific fields of samples for current partition are ready
- (produced) for consumption.
-
- Args:
- data_fields: Data fields to check production status for
- partition_id: Partition id to check production status for
-
- Returns:
- bool: True if all samples have been produced and ready, False otherwise
-
- Raises:
- RuntimeError: If communication fails or controller returns error response
-
- Example:
- >>> # Check if all samples are ready for consumption
- >>> is_ready = asyncio.run(client.async_check_production_status(
- ... data_fields=["input_ids", "attention_mask"],
- ... partition_id="train_0"
- ... ))
- >>> print(f"All samples ready: {is_ready}")
- """
- _, production_status = await self.async_get_production_status(
- data_fields=data_fields,
- partition_id=partition_id,
- )
-
- if production_status is None:
- return False
- return torch.all(production_status == 1).item()
-
@dynamic_socket(socket_name="request_handle_socket")
async def async_get_partition_list(
self,
@@ -872,9 +879,13 @@ async def async_get_partition_list(
Returns:
list[str]: List of partition ids managed by the controller
+
+ Example:
+ >>> partition_ids = asyncio.run(client.async_get_partition_list())
+ >>> print(f"Available partitions: {partition_ids}")
"""
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_LIST_PARTITIONS,
+ request_type=ZMQRequestType.GET_LIST_PARTITIONS, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={},
@@ -901,6 +912,123 @@ async def async_get_partition_list(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e
+ # ==================== KV Interface API ====================
+ @dynamic_socket(socket_name="request_handle_socket")
+ async def async_kv_retrieve_keys(
+ self,
+ keys: list[str] | str,
+ partition_id: str,
+ create: bool = False,
+ socket: Optional[zmq.asyncio.Socket] = None,
+ ) -> BatchMeta:
+ """Asynchronously retrieve BatchMeta from the controller using user-specified keys.
+
+ Args:
+ keys: List of keys to retrieve from the controller
+ partition_id: The ID of the logical partition to search for keys.
+ create: Whether to register new keys if not found.
+ socket: ZMQ socket (injected by decorator)
+
+ Returns:
+ metadata: BatchMeta of the corresponding keys
+
+ Raises:
+ TypeError: If `keys` is not a string or a list of strings
+ """
+
+ if isinstance(keys, str):
+ keys = [keys]
+ elif isinstance(keys, list):
+ if len(keys) < 1:
+ raise ValueError("Received an empty list as keys.")
+ # validate all the elements are str
+ if not all(isinstance(k, str) for k in keys):
+ raise TypeError("Not all element in `keys` are strings.")
+ else:
+ raise TypeError("Only string or list of strings are allowed as `keys`.")
+
+ request_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_RETRIEVE_KEYS, # type: ignore[arg-type]
+ sender_id=self.client_id,
+ receiver_id=self._controller.id,
+ body={
+ "keys": keys,
+ "partition_id": partition_id,
+ "create": create,
+ },
+ )
+
+ try:
+ assert socket is not None
+ await socket.send_multipart(request_msg.serialize())
+ response_serialized = await socket.recv_multipart()
+ response_msg = ZMQMessage.deserialize(response_serialized)
+ logger.debug(
+ f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} "
+ f"from controller {self._controller.id}"
+ )
+
+ if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE:
+ metadata = response_msg.body.get("metadata", BatchMeta.empty())
+ return metadata
+ else:
+ raise RuntimeError(
+ f"[{self.client_id}]: Failed to retrieve keys from controller {self._controller.id}: "
+ f"{response_msg.body.get('message', 'Unknown error')}"
+ )
+ except Exception as e:
+ raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
+
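# Illustrative usage sketch (not part of this patch); the key and partition names are placeholders.
#   meta = asyncio.run(client.async_kv_retrieve_keys(
#       keys=["sample_0", "sample_1"], partition_id="train_0", create=True))
#   meta is a BatchMeta covering both keys; with create=True, missing keys are registered first.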
+ @dynamic_socket(socket_name="request_handle_socket")
+ async def async_kv_list(
+ self,
+ partition_id: str,
+ socket: Optional[zmq.asyncio.Socket] = None,
+ ) -> tuple[list[str], list[dict]]:
+ """Asynchronously retrieve keys and custom_meta from the controller for partition.
+
+ Args:
+ partition_id: Partition to retrieve from the controller
+ socket: ZMQ socket (injected by decorator)
+
+ Returns:
+ keys: list of keys in the partition
+ custom_meta: list of custom_meta dicts, one per key
+ """
+
+ if partition_id is None:
+ return [], []
+
+ request_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_LIST, # type: ignore[arg-type]
+ sender_id=self.client_id,
+ receiver_id=self._controller.id,
+ body={
+ "partition_id": partition_id,
+ },
+ )
+
+ try:
+ assert socket is not None
+ await socket.send_multipart(request_msg.serialize())
+ response_serialized = await socket.recv_multipart()
+ response_msg = ZMQMessage.deserialize(response_serialized)
+ logger.debug(
+ f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}"
+ )
+
+ if response_msg.request_type == ZMQRequestType.KV_LIST_RESPONSE:
+ keys = response_msg.body.get("keys", [])
+ custom_meta = response_msg.body.get("custom_meta", [])
+ return keys, custom_meta
+ else:
+ raise RuntimeError(
+ f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: "
+ f"{response_msg.body.get('message', 'Unknown error')}"
+ )
+ except Exception as e:
+ raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
+
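# Illustrative usage sketch (not part of this patch): listing the user-registered keys of a partition.
#   keys, custom_meta = asyncio.run(client.async_kv_list(partition_id="train_0"))
#   keys and custom_meta are parallel lists, one entry per registered key.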
def close(self) -> None:
"""Close the client and cleanup resources including storage manager."""
try:
@@ -975,72 +1103,16 @@ def wrapper(*args, **kwargs):
self._get_partition_list = _make_sync(self.async_get_partition_list)
self._set_custom_meta = _make_sync(self.async_set_custom_meta)
self._reset_consumption = _make_sync(self.async_reset_consumption)
+ self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys)
+ self._kv_list = _make_sync(self.async_kv_list)
- def put(
- self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
- ) -> BatchMeta:
- """Synchronously write data to storage units based on metadata.
-
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
-
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
-
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
-
- Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
-
- Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
-
- Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
-
- Example:
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = client.get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... )
- >>> batch = client.get_data(batch_meta)
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> client.put(data=output, metadata=batch_meta)
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
- """
- return self._put(data=data, metadata=metadata, partition_id=partition_id)
-
+ # ==================== Basic API ====================
def get_meta(
self,
data_fields: list[str],
batch_size: int,
partition_id: str,
+ mode: str = "fetch",
task_name: Optional[str] = None,
sampling_config: Optional[dict[str, Any]] = None,
) -> BatchMeta:
@@ -1098,10 +1170,95 @@ def get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
+ mode=mode,
task_name=task_name,
sampling_config=sampling_config,
)
+ def set_custom_meta(self, metadata: BatchMeta) -> None:
+ """Synchronously send custom metadata to the controller.
+
+ This method sends per-sample custom metadata (custom_meta) to the controller.
+ The custom_meta is stored in the controller and can be retrieved along with
+ the BatchMeta in subsequent get_meta calls.
+
+ Args:
+ metadata: BatchMeta containing the samples and their custom metadata to store.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> # Create batch with custom metadata
+ >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
+ >>> client.set_custom_meta(batch_meta)
+ """
+
+ return self._set_custom_meta(metadata=metadata)
+
+ def put(
+ self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
+ ) -> BatchMeta:
+ """Synchronously write data to storage units based on metadata.
+
+ If metadata is not provided, it will be created automatically using insert mode
+ with the provided data fields and partition_id.
+
+ During put, the custom_meta in metadata will update the corresponding custom_meta in
+ TransferQueue Controller.
+
+ Note:
+ When using multiple workers for distributed execution, there may be data
+ ordering inconsistencies between workers during put operations.
+
+ Args:
+ data: Data to write as TensorDict
+ metadata: Records the metadata of a batch of data samples, containing index and
+ storage unit information. If None, metadata will be auto-generated.
+ partition_id: Target data partition id (required if metadata is not provided)
+
+ Returns:
+ BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+ metadata; will be updated in a future version to reflect the post-put state)
+
+ Raises:
+ ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+ RuntimeError: If storage operation fails
+
+ Example:
+ >>> batch_size = 4
+ >>> seq_len = 16
+ >>> current_partition_id = "train_0"
+ >>> # Example 1: Normal usage with existing metadata
+ >>> batch_meta = client.get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=batch_size,
+ ... partition_id=current_partition_id,
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... )
+ >>> batch = client.get_data(batch_meta)
+ >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+ >>> client.put(data=output, metadata=batch_meta)
+ >>>
+ >>> # Example 2: Initial data insertion without pre-existing metadata
+ >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+ >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
+ >>> # without metadata.
+ >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
+ >>> # interleave the initial data if n_sample > 1 before calling the async_put().
+ >>> original_prompts = torch.randn(batch_size, seq_len)
+ >>> n_samples = 4
+ >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+ >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+ >>> # This will create metadata in "insert" mode internally.
+ >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
+ """
+ return self._put(data=data, metadata=metadata, partition_id=partition_id)
+
def get_data(self, metadata: BatchMeta) -> TensorDict:
"""Synchronously fetch data from storage units and organize into TensorDict.
@@ -1148,65 +1305,85 @@ def clear_samples(self, metadata: BatchMeta):
"""
return self._clear_samples(metadata=metadata)
- def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
- """Synchronously check if all samples for a partition have been consumed by a specific task.
+ # ==================== Status Query API ====================
+ def get_consumption_status(
+ self,
+ task_name: str,
+ partition_id: str,
+ ) -> tuple[Optional[Tensor], Optional[Tensor]]:
+ """Synchronously get consumption status for a specific task and partition.
Args:
task_name: Name of the task to check consumption for
partition_id: Partition id to check consumption status for
Returns:
- bool: True if all samples have been consumed by the task, False otherwise
+ Tuple of:
+ - Partition global index tensor
+ - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
Raises:
RuntimeError: If communication fails or controller returns error response
Example:
- >>> # Check if all samples have been consumed
- >>> is_consumed = client.check_consumption_status(
+ >>> global_index, consumption_status = client.get_consumption_status(
... task_name="generate_sequences",
... partition_id="train_0"
... )
- >>> print(f"All samples consumed: {is_consumed}")
+ >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
"""
- return self._check_consumption_status(task_name=task_name, partition_id=partition_id)
+ return self._get_consumption_status(task_name, partition_id)
- def get_consumption_status(
+ def get_production_status(
self,
- task_name: str,
+ data_fields: list[str],
partition_id: str,
) -> tuple[Optional[Tensor], Optional[Tensor]]:
- """Synchronously get consumption status for a specific task and partition.
+ """Synchronously get production status for specific data fields and partition.
Args:
- task_name: Name of the task to check consumption for
- partition_id: Partition id to check consumption status for
+ data_fields: Data fields to check production status for
+ partition_id: Partition id to check production status for
Returns:
Tuple of:
- Partition global index tensor
- - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
+ - Production status tensor for the specified fields. 1 for ready, 0 for not ready.
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
Example:
- >>> global_index, consumption_status = client.get_consumption_status(
- ... task_name="generate_sequences",
+ >>> global_index, production_status = client.get_production_status(
+ ... data_fields=["input_ids", "attention_mask"],
... partition_id="train_0"
... )
- >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
+ >>> print(f"Global index: {global_index}, Production status: {production_status}")
"""
- return self._get_consumption_status(task_name, partition_id)
+ return self._get_production_status(data_fields=data_fields, partition_id=partition_id)
+
+ def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
+ """Synchronously check if all samples for a partition have been consumed by a specific task.
- def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
- """Synchronously reset consumption status for a partition.
- This allows the same data to be re-consumed, useful for debugging scenarios
- where the same rollout data needs to be trained multiple times.
Args:
- partition_id: Partition id to reset consumption status for
- task_name: Name of the task to reset. If None, resets all tasks.
+ task_name: Name of the task to check consumption for
+ partition_id: Partition id to check consumption status for
+
Returns:
- bool: True if reset was successful, False otherwise
+ bool: True if all samples have been consumed by the task, False otherwise
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> # Check if all samples have been consumed
+ >>> is_consumed = client.check_consumption_status(
+ ... task_name="generate_sequences",
+ ... partition_id="train_0"
+ ... )
+ >>> print(f"All samples consumed: {is_consumed}")
"""
- return self._reset_consumption(partition_id, task_name)
+ return self._check_consumption_status(task_name=task_name, partition_id=partition_id)
def check_production_status(self, data_fields: list[str], partition_id: str) -> bool:
"""Synchronously check if all samples for a partition are ready (produced) for consumption.
@@ -1215,6 +1392,9 @@ def check_production_status(self, data_fields: list[str], partition_id: str) ->
data_fields: Data fields to check production status for
partition_id: Partition id to check production status for
+ Returns:
+ bool: True if all samples have been produced and ready, False otherwise
+
Raises:
RuntimeError: If communication fails or controller returns error response
@@ -1228,30 +1408,31 @@ def check_production_status(self, data_fields: list[str], partition_id: str) ->
"""
return self._check_production_status(data_fields=data_fields, partition_id=partition_id)
- def get_production_status(
- self,
- data_fields: list[str],
- partition_id: str,
- ) -> tuple[Optional[Tensor], Optional[Tensor]]:
- """Synchronously get production status for a specific data fields and partition.
+ def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
+ """Synchronously reset consumption status for a partition.
+
+ This allows the same data to be re-consumed, useful for debugging scenarios
+ where the same rollout data needs to be trained multiple times.
Args:
- data_fields: Data fields to check production status for
- partition_id: Partition id to check production status for
+ partition_id: Partition id to reset consumption status for
+ task_name: Name of the task to reset. If None, resets all tasks.
Returns:
- Tuple of:
- - Partition global index tensor
- - Production status tensor for the specified fields. 1 for ready, 0 for not ready.
+ bool: True if reset was successful, False otherwise
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
Example:
- >>> global_index, production_status = client.get_production_status(
- ... data_fields=["input_ids", "attention_mask"],
- ... partition_id="train_0"
+ >>> # Reset consumption for train task to re-train on same data
+ >>> success = client.reset_consumption(
+ ... partition_id="train_0",
+ ... task_name="train"
... )
- >>> print(f"Global index: {global_index}, Production status: {production_status}")
+ >>> print(f"Reset successful: {success}")
"""
- return self._get_production_status(data_fields=data_fields, partition_id=partition_id)
+ return self._reset_consumption(partition_id, task_name)
def get_partition_list(
self,
@@ -1260,32 +1441,51 @@ def get_partition_list(
Returns:
list[str]: List of partition ids managed by the controller
+
+ Example:
+ >>> partition_ids = client.get_partition_list()
+ >>> print(f"Available partitions: {partition_ids}")
"""
return self._get_partition_list()
- def set_custom_meta(self, metadata: BatchMeta) -> None:
- """Synchronously send custom metadata to the controller.
-
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
+ # ==================== KV Interface API ====================
+ def kv_retrieve_keys(
+ self,
+ keys: list[str] | str,
+ partition_id: str,
+ create: bool = False,
+ ) -> BatchMeta:
+ """Synchronously retrieve BatchMeta from the controller using user-specified keys.
Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ keys: List of keys to retrieve from the controller
+ partition_id: The ID of the logical partition to search for keys.
+ create: Whether to register new keys if not found.
+
+ Returns:
+ metadata: BatchMeta of the corresponding keys
Raises:
- RuntimeError: If communication fails or controller returns error response
+ TypeError: If `keys` is not a string or a list of strings
+ """
- Example:
- >>> # Create batch with custom metadata
- >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
- >>> client.set_custom_meta(batch_meta)
+ return self._kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create)
+
+ def kv_list(
+ self,
+ partition_id: str,
+ ) -> tuple[list[str], list[dict]]:
+ """Synchronously retrieve keys and custom_meta from the controller for partition.
+
+ Args:
+ partition_id: Partition to retrieve from the controller
+
+ Returns:
+ keys: list of keys in the partition
+ custom_meta: list of custom_meta dicts, one per key
"""
- return self._set_custom_meta(metadata=metadata)
+ return self._kv_list(partition_id=partition_id)
def close(self) -> None:
"""Close the client and cleanup resources including event loop and thread."""
@@ -1304,36 +1504,3 @@ def close(self) -> None:
logger.warning(f"[{self.client_id}]: Error closing event loop: {e}")
super().close()
-
-
-def process_zmq_server_info(
- handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]]
- | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"],
-): # noqa: UP007
- """Extract ZMQ server information from handler objects.
-
- Args:
- handlers: Dictionary of handler objects (controllers, storage managers, or storage units),
- or a single handler object
-
- Returns:
- If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information
- If handlers is a single object: ZMQ server information for that object
-
- Examples:
- >>> # Single handler
- >>> controller = TransferQueueController.remote(...)
- >>> info = process_zmq_server_info(controller)
- >>>
- >>> # Multiple handlers
- >>> handlers = {"storage_0": storage_0, "storage_1": storage_1}
- >>> info_dict = process_zmq_server_info(handlers)"""
- # Handle single handler object case
- if not isinstance(handlers, dict):
- return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
- else:
- # Handle dictionary case
- server_info = {}
- for name, handler in handlers.items():
- server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
- return server_info
diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml
new file mode 100644
index 0000000..9503afa
--- /dev/null
+++ b/transfer_queue/config.yaml
@@ -0,0 +1,31 @@
+# This is the default configuration of TransferQueue. Users may modify the default value
+# and use transfer_queue.init(conf) to overwrite the config entries.
+
+controller:
+ # User-defined sampler. User can pass sampler instance to overwrite this string config.
+ sampler: SequentialSampler
+ # Whether to return an empty BatchMeta to prevent request blocking when not enough data is available
+ polling_mode: False
+ # ZMQ Server IP & Ports (automatically generated during init)
+ zmq_info: null
+
+
+backend:
+ # Pluggable storage/transport backend of TransferQueue. Choose from:
+ # SimpleStorage, Yuanrong, MooncakeStore, ...
+ storage_backend: SimpleStorage
+
+ # For SimpleStorage:
+ SimpleStorage:
+ # Total number of samples
+ total_storage_size: 100000
+ # Number of distributed storage units for SimpleStorage backend
+ num_data_storage_units: 2
+ # ZMQ Server IP & Ports (automatically generated during init)
+ zmq_info: null
+
+ # For Yuanrong:
+ # TODO
+
+ # For MooncakeStore:
+ # TODO
\ No newline at end of file
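A minimal sketch of how this default config could be loaded and overridden before being handed to transfer_queue.init(), as the header comment suggests; the package-resource path and merge pattern below are assumptions, not part of this patch:

from importlib import resources

from omegaconf import OmegaConf

with resources.files("transfer_queue").joinpath("config.yaml").open("r") as f:
    default_cfg = OmegaConf.load(f)

# Override a backend entry, then pass the merged config to the (assumed) init entry point.
overrides = OmegaConf.create({"backend": {"SimpleStorage": {"num_data_storage_units": 4}}})
cfg = OmegaConf.merge(default_cfg, overrides)
# transfer_queue.init(cfg)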
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index c91eaf7..6175863 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -28,6 +28,7 @@
import ray
import torch
import zmq
+from omegaconf import DictConfig
from ray.util import get_node_ip_address
from torch import Tensor
@@ -213,7 +214,7 @@ class DataPartitionStatus:
# Each tensor tracks which samples have been consumed by that task
consumption_status: dict[str, Tensor] = field(default_factory=dict)
- # Sample metadata
+ # Global indexes
global_indexes: set[int] = field(
default_factory=set
) # set of global indexes that have been added to this partition
@@ -232,6 +233,10 @@ class DataPartitionStatus:
# User-defined metadata that may not apply to field level
custom_meta: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {}
+ # User-defined Keys
+ keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx
+ revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key
+
# Threading lock for concurrency control; only for preventing mask operation error when expanding production_status.
# No need to strictly lock for every read/write operation since freshness is not critical.
data_status_lock: Lock = field(default_factory=Lock)
@@ -824,6 +829,8 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr
self.field_shapes.pop(idx, None)
self.field_custom_backend_meta.pop(idx, None)
self.custom_meta.pop(idx, None)
+ self.keys_mapping.pop(self.revert_keys_mapping.get(idx), None)  # tolerate indexes that have no user key
+ self.revert_keys_mapping.pop(idx, None)
except Exception as e:
logger.error(
@@ -831,6 +838,11 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr
f"Attempted to clear global_indexes: {indexes_to_release}"
)
+ def kv_retrieve_keys(self, keys: list[str]) -> list[int | None]:
+ """Translate the user-specified keys to global_indexes"""
+ global_indexes = [self.keys_mapping.get(k, None) for k in keys]
+ return global_indexes
+
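# Behaviour sketch (illustrative, not part of this patch), assuming kv_retrieve_keys on the
# controller has already registered {"prompt_a": 0, "prompt_b": 1} into keys_mapping:
#   partition.kv_retrieve_keys(["prompt_a", "missing"])  ->  [0, None]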
@ray.remote(num_cpus=1)
class TransferQueueController:
@@ -879,6 +891,7 @@ def __init__(
self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}"
self.polling_mode = polling_mode
+ self.tq_config = None # global config for TransferQueue system
# Initialize ZMQ sockets for communication
self._init_zmq_socket()
@@ -979,6 +992,8 @@ def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]
Returns:
List of indexes allocated to the partition
"""
+ # Note: This includes the pre-allocated global_indexes for the partition.
+ # i.e., partition.global_indexes + partition.pre_allocated_global_indexes
return self.index_manager.get_indexes_for_partition(partition)
# ==================== Data Production API ====================
@@ -1158,9 +1173,10 @@ def get_metadata(
)
batch_global_indexes.extend(new_global_indexes)
+ # register global_indexes in partition
+ partition.global_indexes.update(batch_global_indexes)
+
else:
- # TODO: separate this "clear" related logic into a separated mode
- # clear metadata call passes empty data_fields
batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)
return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)
@@ -1312,7 +1328,7 @@ def generate_batch_meta(
and partition.production_status is not None
and partition.production_status[global_index, field_index] == 1
):
- production_status = ProductionStatus.NOT_PRODUCED
+ production_status = ProductionStatus.READY_FOR_CONSUME
dtype = partition.get_field_dtype(global_index, field_name)
shape = partition.get_field_shape(global_index, field_name)
else:
@@ -1338,7 +1354,7 @@ def generate_batch_meta(
custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields)
batch_meta = BatchMeta(samples=samples)
- batch_meta.update_custom_meta(custom_meta)
+ batch_meta.update_custom_meta([custom_meta.get(idx, {}) for idx in batch_meta.global_indexes])
batch_meta._custom_backend_meta.update(custom_backend_meta)
return batch_meta
@@ -1355,7 +1371,8 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True):
partition = self._get_partition(partition_id)
if not partition:
- raise ValueError(f"Partition {partition_id} not found")
+ logger.warning(f"Try to clear an non-existent partition {partition_id}!")
+ return
global_indexes_range = list(self.index_manager.get_indexes_for_partition(partition_id))
partition.clear_data(global_indexes_range, clear_consumption)
@@ -1372,13 +1389,13 @@ def reset_consumption(self, partition_id: str, task_name: Optional[str] = None):
Args:
partition_id: ID of the partition to reset consumption for
task_name: Name of the task to reset. If None, resets all tasks.
- Raises:
- ValueError: If partition not found
+
"""
logger.debug(f"[{self.controller_id}]: Resetting consumption for partition {partition_id}, task={task_name}")
partition = self._get_partition(partition_id)
if not partition:
- raise ValueError(f"Partition {partition_id} not found")
+ logger.warning(f"Try to reset consumption of an non-existent partition {partition_id}!")
+ return
partition.reset_consumption(task_name)
def clear_meta(
@@ -1430,6 +1447,71 @@ def clear_meta(
# Release the specific indexes from index manager
self.index_manager.release_indexes(partition_id, global_indexes_to_clear)
+ def kv_retrieve_keys(
+ self,
+ keys: list[str],
+ partition_id: str,
+ create: bool = False,
+ ) -> BatchMeta:
+ """
+ Retrieve BatchMeta from the controller using a list of keys.
+
+ Args:
+ keys: List of keys to retrieve from the controller
+ partition_id: Partition id to retrieve from the controller
+ create: Whether to register new keys if not found.
+
+ Returns:
+ metadata: BatchMeta of the requested keys
+ """
+
+ logger.debug(f"[{self.controller_id}]: Retrieve keys {keys} in partition {partition_id}")
+
+ partition = self._get_partition(partition_id)
+
+ if partition is None:
+ if not create:
+ logger.warning(f"Partition {partition_id} were not found in controller!")
+ return BatchMeta.empty()
+ else:
+ self.create_partition(partition_id)
+ partition = self._get_partition(partition_id)
+
+ assert partition is not None
+ global_indexes = partition.kv_retrieve_keys(keys)
+
+ none_indexes = [idx for idx, value in enumerate(global_indexes) if value is None]
+ if len(none_indexes) > 0:
+ if not create:
+ logger.warning(f"Keys {[keys[i] for i in none_indexes]} were not found in partition {partition_id}!")
+ return BatchMeta.empty()
+ else:
+ # create non-existent keys
+ batch_global_indexes = partition.activate_pre_allocated_indexes(len(none_indexes))
+
+ if len(batch_global_indexes) < len(none_indexes):
+ new_global_indexes = self.index_manager.allocate_indexes(
+ partition_id, count=(len(none_indexes) - len(batch_global_indexes))
+ )
+ batch_global_indexes.extend(new_global_indexes)
+
+ # register global_indexes in partition
+ partition.global_indexes.update(batch_global_indexes)
+
+ # register key-global_indexes mapping in partition
+ for i in range(len(none_indexes)):
+ global_indexes[none_indexes[i]] = batch_global_indexes[i]
+ partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i]
+ partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]]
+
+ verified_global_indexes = [idx for idx in global_indexes if idx is not None]
+ assert len(verified_global_indexes) == len(keys)
+
+ data_fields = list(partition.field_name_mapping.keys())
+ metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")
+
+ return metadata
+
def _init_zmq_socket(self):
"""Initialize ZMQ sockets for communication."""
self.zmq_context = zmq.Context()
@@ -1725,6 +1807,41 @@ def _process_request(self):
body={"partition_ids": partition_ids},
)
+ elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
+ with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"):
+ keys = params["keys"]
+ partition_id = params["partition_id"]
+ create = params["create"]
+
+ metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create)
+ response_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE,
+ sender_id=self.controller_id,
+ receiver_id=request_msg.sender_id,
+ body={"metadata": metadata},
+ )
+
+ elif request_msg.request_type == ZMQRequestType.KV_LIST:
+ with perf_monitor.measure(op_type="KV_LIST"):
+ partition_id = params["partition_id"]
+ partition = self._get_partition(partition_id)
+ if not partition:
+ keys = []
+ custom_meta = []
+ message = f"Partition {partition_id} not found for kv_list."
+ logger.debug(f"[{self.controller_id}]: {message}")
+ else:
+ keys = list(partition.keys_mapping.keys())
+ custom_meta = [partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys]
+ message = "Success"
+
+ response_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_LIST_RESPONSE,
+ sender_id=self.controller_id,
+ receiver_id=request_msg.sender_id,
+ body={"keys": keys, "custom_meta": custom_meta, message: message},
+ )
+
self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
def _update_data_status(self):
@@ -1772,3 +1889,35 @@ def _update_data_status(self):
def get_zmq_server_info(self) -> ZMQServerInfo:
"""Get ZMQ server connection information."""
return self.zmq_server_info
+
+ def store_config(self, conf: DictConfig) -> None:
+ """Store the global config of TransferQueue."""
+ self.tq_config = conf
+
+ def get_config(self) -> DictConfig:
+ """Retrieve the global config of TransferQueue."""
+ return self.tq_config
+
+ def register_sampler(
+ self,
+ sampler: BaseSampler | type[BaseSampler] = SequentialSampler,
+ ) -> None:
+ """
+ Register a sampler instance or subclass after the controller is initialized.
+
+ Args:
+ sampler: Sampler instance or sampler class to use for data sampling.
+ - If a BaseSampler instance is provided, it will be used directly
+ - If a BaseSampler subclass is provided, it will be instantiated
+ - Defaults to SequentialSampler for simple sequential sampling
+ - Example: sampler=GRPOGroupNSampler() (instance)
+ - Example: sampler=SequentialSampler (class)
+ """
+ if isinstance(sampler, BaseSampler):
+ self.sampler = sampler
+ elif isinstance(sampler, type) and issubclass(sampler, BaseSampler):
+ self.sampler = sampler()
+ else:
+ raise TypeError(
+ f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler"
+ )
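Since TransferQueueController is a Ray actor, register_sampler would normally be invoked through an actor handle. A minimal sketch, assuming `controller` is an existing actor handle obtained from your own setup code:

import ray

from transfer_queue import GRPOGroupNSampler

# Either a BaseSampler subclass (instantiated internally) or an instance may be passed.
ray.get(controller.register_sampler.remote(GRPOGroupNSampler))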
diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py
index eb2afe4..1de90c5 100644
--- a/transfer_queue/dataloader/streaming_dataset.py
+++ b/transfer_queue/dataloader/streaming_dataset.py
@@ -17,8 +17,10 @@
import os
import time
import uuid
-from typing import Any, Callable, Iterator
+import warnings
+from typing import Callable, Iterator
+from omegaconf import DictConfig
from tensordict import TensorDict
from torch.utils.data import IterableDataset
@@ -68,7 +70,7 @@ class StreamingDataset(IterableDataset):
def __init__(
self,
- config: dict[str, Any],
+ config: DictConfig,
batch_size: int,
micro_batch_size: int,
data_fields: list[str],
@@ -82,8 +84,8 @@ def __init__(
Args:
config: Configuration dictionary containing:
- - controller_info: ZMQServerInfo for the TransferQueueController
- - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager")
+ - controller.controller_info: ZMQServerInfo for the TransferQueueController
+ - backend.storage_backend: Storage backend type (e.g., "SimpleStorage")
- Other backend-specific configuration
batch_size: Batch size for data loading per iter.
micro_batch_size: Number of samples per micro-batch. This is the batch size
@@ -156,16 +158,44 @@ def _create_client(self):
ValueError: If controller_info or storage_backend is missing or invalid.
"""
client_id = uuid.uuid4().hex[:8]
- controller_info = self.config.get("controller_info", None)
+
+ # TODO: DEPRECATE in future
+ controller_config = self.config.get("controller", None)
+ if controller_config:
+ controller_info = controller_config.get("zmq_info", None)
+ else:
+ controller_info = self.config.get("controller_info", None)
+ if controller_info:
+ warnings.warn(
+ "Config entry `controller_info` will be deprecated in 0.1.7, please "
+ "use `controller.zmq_info` instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+
if not controller_info or not isinstance(controller_info, ZMQServerInfo):
- raise ValueError("Invalid or missing controller_info in config")
+ raise ValueError("Invalid or missing controller.zmq_info in config")
+
+ backend_config = self.config.get("backend", None)
+ if not backend_config:
+ storage_backend = self.config.get("storage_backend", None)
+ backend_config = self.config
+ if storage_backend:
+ warnings.warn(
+ "Config entry `storage_backend` will be deprecated in 0.1.7, please "
+ "use `backend.storage_backend` instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ else:
+ storage_backend = backend_config.get("storage_backend", None)
+ backend_config = self.config.backend[storage_backend]
- storage_backend = self.config.get("storage_backend", None)
if not storage_backend:
raise ValueError("Missing storage_backend in config")
self._tq_client = TransferQueueClient(client_id, controller_info)
- self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=self.config)
+ self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)
def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
"""Iterate over the dataset, yielding batches of data.
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
new file mode 100644
index 0000000..2585c35
--- /dev/null
+++ b/transfer_queue/interface.py
@@ -0,0 +1,1136 @@
+# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2025 The TransferQueue Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as pkg_resources
+import logging
+import math
+import os
+import time
+from typing import Any, Optional
+
+import ray
+import torch
+from omegaconf import DictConfig, OmegaConf
+from tensordict import TensorDict
+from tensordict.tensorclass import NonTensorStack
+
+from transfer_queue.client import TransferQueueClient
+from transfer_queue.controller import TransferQueueController
+from transfer_queue.metadata import BatchMeta
+from transfer_queue.sampler import * # noqa: F401
+from transfer_queue.sampler import BaseSampler
+from transfer_queue.storage.simple_backend import SimpleStorageUnit
+from transfer_queue.utils.common import get_placement_group
+from transfer_queue.utils.zmq_utils import process_zmq_server_info
+
+logger = logging.getLogger(__name__)
+logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
+
+_TRANSFER_QUEUE_CLIENT: Any = None
+_TRANSFER_QUEUE_STORAGE: Any = None
+
+
+def _maybe_create_transferqueue_client(
+ conf: Optional[DictConfig] = None,
+) -> TransferQueueClient:
+ global _TRANSFER_QUEUE_CLIENT
+ if _TRANSFER_QUEUE_CLIENT is None:
+ if conf is None:
+ raise ValueError("Missing config for initializing TransferQueueClient!")
+ pid = os.getpid()
+ _TRANSFER_QUEUE_CLIENT = TransferQueueClient(
+ client_id=f"TransferQueueClient_{pid}", controller_info=conf.controller.zmq_info
+ )
+
+ backend_name = conf.backend.storage_backend
+
+ _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=backend_name, config=conf.backend[backend_name])
+
+ return _TRANSFER_QUEUE_CLIENT
+
+
+def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig:
+ global _TRANSFER_QUEUE_STORAGE
+
+ if _TRANSFER_QUEUE_STORAGE is None:
+ _TRANSFER_QUEUE_STORAGE = {}
+ if conf.backend.storage_backend == "SimpleStorage":
+ # initialize SimpleStorageUnit
+ num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units
+ total_storage_size = conf.backend.SimpleStorage.total_storage_size
+ storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1)
+
+ for storage_unit_rank in range(num_data_storage_units):
+ storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined]
+ placement_group=storage_placement_group,
+ placement_group_bundle_index=storage_unit_rank,
+ name=f"TransferQueueStorageUnit#{storage_unit_rank}",
+ lifetime="detached",
+ ).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units))
+ _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node
+ logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.")
+
+ storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE)
+ backend_name = conf.backend.storage_backend
+ conf.backend[backend_name].zmq_info = storage_zmq_info
+
+ return conf
+
+
+def _init_from_existing() -> None:
+    """Initialize the TransferQueueClient from an existing controller."""
+
+ controller = ray.get_actor("TransferQueueController")
+ logger.info("Found existing TransferQueueController instance. Connecting...")
+
+    while True:
+ remote_conf = ray.get(controller.get_config.remote())
+ if remote_conf is not None:
+ _maybe_create_transferqueue_client(remote_conf)
+ logger.info("TransferQueueClient initialized.")
+ return
+
+ logger.debug("Waiting for controller to initialize... Retrying in 1s")
+ time.sleep(1)
+
+
+def init(conf: Optional[DictConfig] = None) -> None:
+ """Initialize the TransferQueue system.
+
+ This function sets up the TransferQueue controller, distributed storage, and client.
+ It should be called once at the beginning of the program before any data operations.
+
+ If a controller already exists (e.g., initialized by another process), this function
+    will retrieve the config from the existing controller and initialize the TransferQueueClient.
+ In this case, the `conf` parameter will be ignored.
+
+ Args:
+ conf: Optional configuration dictionary. If provided, it will be merged with
+            the default config from 'config.yaml'. This is only used for first-time
+            initialization. When connecting to an existing controller, this parameter
+ is ignored.
+
+ Raises:
+ ValueError: If config is not valid or required configuration keys are missing.
+
+ Example:
+ >>> # In process 0, node A
+ >>> import transfer_queue as tq
+ >>> tq.init() # Initialize the TransferQueue
+ >>> tq.put(...) # then you can use tq for data operations
+ >>>
+ >>> # In process 1, node B (with Ray connected to node A)
+ >>> import transfer_queue as tq
+ >>> tq.init() # This will only initialize a TransferQueueClient and link with existing TQ
+ >>> metadata = tq.get_meta(...)
+ >>> data = tq.get_data(metadata)
+ """
+ try:
+ _init_from_existing()
+ except ValueError:
+ logger.info("No TransferQueueController found. Starting first-time initialization...")
+ else:
+ return
+
+ # First-time initialize TransferQueue
+
+ # create config
+ final_conf = OmegaConf.create({}, flags={"allow_objects": True})
+ with pkg_resources.path("transfer_queue", "config.yaml") as p:
+ default_conf = OmegaConf.load(p)
+ final_conf = OmegaConf.merge(final_conf, default_conf)
+ if conf:
+ final_conf = OmegaConf.merge(final_conf, conf)
+
+ # create controller
+ try:
+ sampler = final_conf.controller.sampler
+        if isinstance(sampler, BaseSampler):
+            # user passed a pre-initialized sampler instance; use it as-is
+            pass
+        elif isinstance(sampler, type) and issubclass(sampler, BaseSampler):
+            # user passed a sampler class
+            sampler = sampler()
+        elif isinstance(sampler, str):
+            # user passed a sampler name; resolve it to a sampler class
+            sampler = globals()[sampler]
+ except KeyError:
+ raise ValueError(f"Could not find sampler {final_conf.controller.sampler}") from None
+
+ try:
+ # Ray will make sure actor with same name can only be created once
+ controller = TransferQueueController.options(name="TransferQueueController", lifetime="detached").remote( # type: ignore[attr-defined]
+ sampler=sampler, polling_mode=final_conf.controller.polling_mode
+ )
+ logger.info("TransferQueueController has been created.")
+ except ValueError:
+        logger.info("Some other rank has initialized TransferQueueController. Trying to connect to the existing controller.")
+ _init_from_existing()
+ return
+
+ controller_zmq_info = process_zmq_server_info(controller)
+ final_conf.controller.zmq_info = controller_zmq_info
+
+ # create distributed storage backends
+ final_conf = _maybe_create_transferqueue_storage(final_conf)
+
+ # store the config into controller
+ ray.get(controller.store_config.remote(final_conf))
+ logger.info(f"TransferQueue config: {final_conf}")
+
+ # create client
+ _maybe_create_transferqueue_client(final_conf)
+
+
+# ==================== Basic API ====================
+def get_meta(
+ data_fields: list[str],
+ batch_size: int,
+ partition_id: str,
+ mode: str = "fetch",
+ task_name: Optional[str] = None,
+ sampling_config: Optional[dict[str, Any]] = None,
+) -> BatchMeta:
+ """Synchronously fetch data metadata from the controller via ZMQ.
+
+ Args:
+ data_fields: List of data field names to retrieve metadata for
+ batch_size: Number of samples to request in the batch
+ partition_id: Current data partition id
+ mode: Data fetch mode. Options:
+ - 'fetch': Get ready data only
+ - 'force_fetch': Get data regardless of readiness (may return unready samples)
+ - 'insert': Internal usage - should not be used by users
+ task_name: Optional task name associated with the request
+ sampling_config: Optional sampling configuration for custom samplers.
+
+ Returns:
+ BatchMeta: Metadata object containing data structure, sample information, and readiness status
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> # Example 1: Basic fetch metadata
+ >>> batch_meta = tq.get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences"
+ ... )
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
+ >>> batch_meta = tq.get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=8,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... sampling_config={"n_samples_per_prompt": 4}
+ ... )
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
+ >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
+ >>> batch_meta = tq.get_meta(
+ ... partition_id="train_0", # optional
+ ... mode="force_fetch",
+ ... )
+ >>> print(batch_meta.is_ready) # May be False if some samples not ready
+ """
+
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
+
+
+def set_custom_meta(metadata: BatchMeta) -> None:
+ """Synchronously send custom metadata to the controller.
+
+ This method sends per-sample custom metadata (custom_meta) to the controller.
+ The custom_meta is stored in the controller and can be retrieved along with
+ the BatchMeta in subsequent get_meta calls.
+
+ Args:
+ metadata: BatchMeta containing the samples and their custom metadata to store.
+                  The custom_meta should be set using BatchMeta.update_custom_meta()
+                  before calling this method.
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> # Create batch with custom metadata
+ >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
+        >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}, {"score": 0.7}, {"score": 0.6}])
+ >>> tq.set_custom_meta(batch_meta)
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.set_custom_meta(metadata)
+
+
+def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta:
+ """Synchronously write data to storage units based on metadata.
+
+ If metadata is not provided, it will be created automatically using insert mode
+ with the provided data fields and partition_id.
+
+ During put, the custom_meta in metadata will update the corresponding custom_meta in
+ TransferQueue Controller.
+
+ Note:
+ When using multiple workers for distributed execution, there may be data
+ ordering inconsistencies between workers during put operations.
+
+ Args:
+ data: Data to write as TensorDict
+ metadata: Records the metadata of a batch of data samples, containing index and
+ storage unit information. If None, metadata will be auto-generated.
+ partition_id: Target data partition id (required if metadata is not provided)
+
+ Returns:
+ BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+ metadata; will be updated in a future version to reflect the post-put state)
+
+ Raises:
+ ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+ RuntimeError: If storage operation fails
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> batch_size = 4
+ >>> seq_len = 16
+ >>> current_partition_id = "train_0"
+ >>> # Example 1: Normal usage with existing metadata
+ >>> batch_meta = tq.get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=batch_size,
+ ... partition_id=current_partition_id,
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... )
+ >>> batch = tq.get_data(batch_meta)
+ >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+ >>> tq.put(data=output, metadata=batch_meta)
+ >>>
+ >>> # Example 2: Initial data insertion without pre-existing metadata
+ >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+        >>> # Please make sure the corresponding partition_id is empty before calling put()
+        >>> # without metadata.
+        >>> # Currently, we only support putting all data for the corresponding partition_id at once. You should
+        >>> # repeat-interleave the initial data if n_samples > 1 before calling put().
+ >>> original_prompts = torch.randn(batch_size, seq_len)
+ >>> n_samples = 4
+ >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+ >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+ >>> # This will create metadata in "insert" mode internally.
+ >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id)
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.put(data, metadata, partition_id)
+
+
+def get_data(metadata: BatchMeta) -> TensorDict:
+ """Synchronously fetch data from storage units and organize into TensorDict.
+
+ Args:
+ metadata: Batch metadata containing data location information and global indexes
+
+ Returns:
+ TensorDict containing:
+ - Requested data fields (e.g., "prompts", "attention_mask")
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+        >>> batch_meta = tq.get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... )
+ >>> batch = tq.get_data(batch_meta)
+ >>> print(batch)
+ >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.get_data(metadata)
+
+
+def clear_partition(partition_id: str):
+ """Synchronously clear the whole partition from all storage units and the controller.
+
+ Args:
+ partition_id: The partition id to clear data for
+
+ Raises:
+ RuntimeError: If clear operation fails
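+
+    Example (illustrative sketch; assumes the partition was populated earlier):
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> tq.clear_partition("train_0")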
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.clear_partition(partition_id)
+
+
+def clear_samples(metadata: BatchMeta):
+ """Synchronously clear specific samples from all storage units and the controller.
+
+ Args:
+ metadata: The BatchMeta of the corresponding data to be cleared
+
+ Raises:
+ RuntimeError: If clear operation fails
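+
+    Example (illustrative sketch; assumes a batch was fetched earlier):
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> batch_meta = tq.get_meta(data_fields=["prompts"], batch_size=4, partition_id="train_0")
+        >>> tq.clear_samples(batch_meta)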
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.clear_samples(metadata)
+
+
+async def async_get_meta(
+ data_fields: list[str],
+ batch_size: int,
+ partition_id: str,
+ mode: str = "fetch",
+ task_name: Optional[str] = None,
+ sampling_config: Optional[dict[str, Any]] = None,
+) -> BatchMeta:
+ """Asynchronously fetch data metadata from the controller via ZMQ.
+
+ Args:
+ data_fields: List of data field names to retrieve metadata for
+ batch_size: Number of samples to request in the batch
+ partition_id: Current data partition id
+ mode: Data fetch mode. Options:
+ - 'fetch': Get ready data only
+ - 'force_fetch': Get data regardless of readiness (may return unready samples)
+ - 'insert': Internal usage - should not be used by users
+ task_name: Optional task name associated with the request
+ sampling_config: Optional sampling configuration for custom samplers.
+
+ Returns:
+ BatchMeta: Metadata object containing data structure, sample information, and readiness status
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> # Example 1: Basic fetch metadata
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences"
+ ... ))
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=8,
+ ... partition_id="train_0",
+ ... mode="fetch",
+        ...     task_name="generate_sequences",
+        ...     sampling_config={"n_samples_per_prompt": 4}
+        ... ))
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
+ >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... partition_id="train_0", # optional
+ ... mode="force_fetch",
+ ... ))
+ >>> print(batch_meta.is_ready) # May be False if some samples not ready
+ """
+
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
+
+
+async def async_set_custom_meta(
+ metadata: BatchMeta,
+) -> None:
+ """
+ Asynchronously send custom metadata to the controller.
+
+ This method sends per-sample custom metadata (custom_meta) to the controller.
+ The custom_meta is stored in the controller and can be retrieved along with
+ the BatchMeta in subsequent get_meta calls.
+
+ Args:
+ metadata: BatchMeta containing the samples and their custom metadata to store.
+                  The custom_meta should be set using BatchMeta.update_custom_meta()
+                  before calling this method.
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> # Create batch with custom metadata
+ >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
+        >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}, {"score": 0.7}, {"score": 0.6}])
+ >>> asyncio.run(tq.async_set_custom_meta(batch_meta))
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_set_custom_meta(metadata)
+
+
+async def async_put(
+ data: TensorDict,
+ metadata: Optional[BatchMeta] = None,
+ partition_id: Optional[str] = None,
+) -> BatchMeta:
+ """Asynchronously write data to storage units based on metadata.
+
+ If metadata is not provided, it will be created automatically using insert mode
+ with the provided data fields and partition_id.
+
+ During put, the custom_meta in metadata will update the corresponding custom_meta in
+ TransferQueue Controller.
+
+ Note:
+ When using multiple workers for distributed execution, there may be data
+ ordering inconsistencies between workers during put operations.
+
+ Args:
+ data: Data to write as TensorDict
+ metadata: Records the metadata of a batch of data samples, containing index and
+ storage unit information. If None, metadata will be auto-generated.
+ partition_id: Target data partition id (required if metadata is not provided)
+
+ Returns:
+ BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+ metadata; will be updated in a future version to reflect the post-put state)
+
+ Raises:
+ ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+ RuntimeError: If storage operation fails
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> batch_size = 4
+ >>> seq_len = 16
+ >>> current_partition_id = "train_0"
+ >>> # Example 1: Normal usage with existing metadata
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=batch_size,
+ ... partition_id=current_partition_id,
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... ))
+ >>> batch = asyncio.run(tq.async_get_data(batch_meta))
+ >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+ >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta))
+ >>>
+ >>> # Example 2: Initial data insertion without pre-existing metadata
+ >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+ >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
+ >>> # without metadata.
+        >>> # Currently, we only support putting all data for the corresponding partition_id at once. You should
+        >>> # repeat-interleave the initial data if n_samples > 1 before calling async_put().
+ >>> original_prompts = torch.randn(batch_size, seq_len)
+ >>> n_samples = 4
+ >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+ >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+ >>> # This will create metadata in "insert" mode internally.
+ >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_put(data, metadata, partition_id)
+
+
+async def async_get_data(metadata: BatchMeta) -> TensorDict:
+ """Asynchronously fetch data from storage units and organize into TensorDict.
+
+ Args:
+ metadata: Batch metadata containing data location information and global indexes
+
+ Returns:
+ TensorDict containing:
+ - Requested data fields (e.g., "prompts", "attention_mask")
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... ))
+ >>> batch = asyncio.run(tq.async_get_data(batch_meta))
+ >>> print(batch)
+ >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_get_data(metadata)
+
+
+# ==================== Data Operations API ====================
+
+
+async def async_clear_samples(metadata: BatchMeta):
+ """Asynchronously clear specific samples from all storage units and the controller.
+
+ Args:
+ metadata: The BatchMeta of the corresponding data to be cleared
+
+ Raises:
+ RuntimeError: If clear operation fails
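+
+    Example (illustrative sketch; assumes a batch was fetched earlier):
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> batch_meta = await tq.async_get_meta(data_fields=["prompts"], batch_size=4, partition_id="train_0")
+        >>> await tq.async_clear_samples(batch_meta)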
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_clear_samples(metadata)
+
+
+async def async_clear_partition(partition_id: str):
+ """Asynchronously clear the whole partition from all storage units and the controller.
+
+ Args:
+ partition_id: The partition id to clear data for
+
+ Raises:
+ RuntimeError: If clear operation fails
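+
+    Example (illustrative sketch):
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> await tq.async_clear_partition("train_0")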
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_clear_partition(partition_id)
+
+
+def close():
+ """Close the TransferQueue system.
+
+ This function cleans up the TransferQueue system, including:
+ - Closing the client and its associated resources
+ - Cleaning up distributed storage (only for the process that initialized it)
+ - Killing the controller actor
+
+ Note:
+ This function should be called when the TransferQueue system is no longer needed.
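+
+    Example (illustrative sketch):
+        >>> import transfer_queue as tq
+        >>> tq.init()
+        >>> # ... data operations ...
+        >>> tq.close()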
+ """
+ global _TRANSFER_QUEUE_CLIENT
+ global _TRANSFER_QUEUE_STORAGE
+ if _TRANSFER_QUEUE_CLIENT:
+ _TRANSFER_QUEUE_CLIENT.close()
+ _TRANSFER_QUEUE_CLIENT = None
+
+ try:
+ if _TRANSFER_QUEUE_STORAGE:
+            # only the process that did the first-time init can clean up the distributed storage
+ for storage in _TRANSFER_QUEUE_STORAGE.values():
+ ray.kill(storage)
+ _TRANSFER_QUEUE_STORAGE = None
+ except Exception:
+ pass
+
+ try:
+ controller = ray.get_actor("TransferQueueController")
+ ray.kill(controller)
+ except Exception:
+ pass
+
+
+# ==================== KV Interface API ====================
+def kv_put(
+ key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]]
+) -> None:
+ """Put a single key-value pair to TransferQueue.
+
+ This is a convenience method for putting data using a user-specified key
+ instead of BatchMeta. Internally, the key is translated to a BatchMeta
+ and the data is stored using the regular put mechanism.
+
+ Args:
+        key: User-specified key for the data sample (a single row)
+        partition_id: Logical partition to store the data in
+        fields: Data fields to store. Can be a TensorDict or a dict of tensors.
+            Each key in `fields` will be treated as a column for the data sample.
+            If a dict is provided, tensors will be unsqueezed to add a batch dimension.
+ tag: Optional metadata tag to associate with the key
+
+ Raises:
+ ValueError: If neither fields nor tag is provided
+ ValueError: If nested tensors are provided (use kv_batch_put instead)
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> import torch
+ >>> tq.init()
+ >>> # Put with both fields and tag
+ >>> tq.kv_put(
+ ... key="sample_1",
+ ... partition_id="train",
+ ... fields={"input_ids": torch.tensor([1, 2, 3])},
+ ... tag={"score": 0.95}
+ ... )
+ """
+ if fields is None and tag is None:
+ raise ValueError("Please provide at least one parameter of fields or tag.")
+
+ tq_client = _maybe_create_transferqueue_client()
+
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
+
+ if batch_meta.size != 1:
+ raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
+
+ # 2. register the user-specified tag to BatchMeta
+ if tag:
+ batch_meta.update_custom_meta([tag])
+
+ # 3. put data
+ if fields is not None:
+ if isinstance(fields, dict):
+ # TODO: consider whether to support this...
+ batch = {}
+            for field_name, value in fields.items():
+                if isinstance(value, torch.Tensor):
+                    if value.is_nested:
+                        raise ValueError("Please use (async_)kv_batch_put for batch operations")
+                    batch[field_name] = value.unsqueeze(0)
+                else:
+                    batch[field_name] = NonTensorStack(value)
+ fields = TensorDict(batch, batch_size=[1])
+ elif not isinstance(fields, TensorDict):
+            raise ValueError("`fields` can only be a dict or TensorDict")
+
+ # custom_meta (tag) will be put to controller through the internal put process
+ tq_client.put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tag) to controller
+ tq_client.set_custom_meta(batch_meta)
+
+
+def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]]) -> None:
+ """Put multiple key-value pairs to TransferQueue in batch.
+
+ This method stores multiple key-value pairs in a single operation, which is more
+ efficient than calling kv_put multiple times.
+
+ Args:
+ keys: List of user-specified keys for the data
+ partition_id: Logical partition to store the data in
+ fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
+ tags: List of metadata tags, one for each key
+
+ Raises:
+ ValueError: If neither `fields` nor `tags` is provided
+ ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> from tensordict import TensorDict
+ >>> tq.init()
+ >>> keys = ["sample_1", "sample_2", "sample_3"]
+ >>> fields = TensorDict({
+ ... "input_ids": torch.randn(3, 10),
+ ... "attention_mask": torch.ones(3, 10),
+ ... }, batch_size=3)
+ >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
+ >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
+ """
+
+    if fields is None and tags is None:
+        raise ValueError("Please provide at least one parameter of fields or tags.")
+
+    if fields is not None and fields.batch_size[0] != len(keys):
+ raise ValueError(
+ f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
+ f"batch_size {fields.batch_size[0]}"
+ )
+
+ tq_client = _maybe_create_transferqueue_client()
+
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
+
+ if batch_meta.size != len(keys):
+ raise RuntimeError(
+ f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
+ )
+
+ # 2. register the user-specified tags to BatchMeta
+ if tags:
+ if len(tags) != len(keys):
+ raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
+ batch_meta.update_custom_meta(tags)
+
+ # 3. put data
+ if fields is not None:
+ tq_client.put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tags) to controller
+ tq_client.set_custom_meta(batch_meta)
+
+
+def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict:
+ """Get data from TransferQueue using user-specified keys.
+
+ This is a convenience method for retrieving data using keys instead of indexes.
+
+ Args:
+ keys: Single key or list of keys to retrieve
+ partition_id: Partition containing the keys
+ fields: Optional field(s) to retrieve. If None, retrieves all fields
+
+ Returns:
+ TensorDict with the requested data
+
+ Raises:
+ RuntimeError: If keys or partition are not found
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Get single key with all fields
+        >>> data = tq.kv_get(keys="sample_1", partition_id="train")
+ >>> # Get multiple keys with specific fields
+ >>> data = tq.kv_get(
+ ... keys=["sample_1", "sample_2"],
+ ... partition_id="train",
+ ... fields="input_ids"
+ ... )
+ """
+ tq_client = _maybe_create_transferqueue_client()
+
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+ if batch_meta.size == 0:
+ raise RuntimeError("keys or partition were not found!")
+
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ batch_meta = batch_meta.select_fields(fields)
+
+ data = tq_client.get_data(batch_meta)
+
+ return data
+
+
+def kv_list(partition_id: str) -> tuple[list[Optional[str]], list[Optional[dict[str, Any]]]]:
+ """List all keys and their metadata in a partition.
+
+ Args:
+ partition_id: Partition to list keys from
+
+ Returns:
+ Tuple of:
+ - List of keys in the partition
+ - List of custom metadata (tags) associated with each key
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> keys, tags = tq.kv_list(partition_id="train")
+ >>> print(f"Keys: {keys}")
+ >>> print(f"Tags: {tags}")
+ """
+ tq_client = _maybe_create_transferqueue_client()
+
+ keys, custom_meta = tq_client.kv_list(partition_id)
+
+ return keys, custom_meta
+
+
+def kv_clear(keys: list[str] | str, partition_id: str) -> None:
+ """Clear key-value pairs from TransferQueue.
+
+ This removes the specified keys and their associated data from both
+ the controller and storage units.
+
+ Args:
+ keys: Single key or list of keys to clear
+ partition_id: Partition containing the keys
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Clear single key
+        >>> tq.kv_clear(keys="sample_1", partition_id="train")
+ >>> # Clear multiple keys
+ >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
+ """
+
+ if isinstance(keys, str):
+ keys = [keys]
+
+ tq_client = _maybe_create_transferqueue_client()
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+ if batch_meta.size > 0:
+ tq_client.clear_samples(batch_meta)
+
+
+# ==================== Async KV Interface API ====================
+async def async_kv_put(
+ key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]]
+) -> None:
+ """Asynchronously put a single key-value pair to TransferQueue.
+
+ This is a convenience method for putting data using a user-specified key
+ instead of BatchMeta. Internally, the key is translated to a BatchMeta
+ and the data is stored using the regular put mechanism.
+
+ Args:
+        key: User-specified key for the data sample (a single row)
+        partition_id: Logical partition to store the data in
+        fields: Data fields to store. Can be a TensorDict or a dict of tensors.
+            Each key in `fields` will be treated as a column for the data sample.
+            If a dict is provided, tensors will be unsqueezed to add a batch dimension.
+ tag: Optional metadata tag to associate with the key
+
+ Raises:
+ ValueError: If neither fields nor tag is provided
+ ValueError: If nested tensors are provided (use kv_batch_put instead)
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> import torch
+ >>> tq.init()
+ >>> # Put with both fields and tag
+ >>> await tq.async_kv_put(
+ ... key="sample_1",
+ ... partition_id="train",
+ ... fields={"input_ids": torch.tensor([1, 2, 3])},
+ ... tag={"score": 0.95}
+        ... )
+ """
+
+ if fields is None and tag is None:
+ raise ValueError("Please provide at least one parameter of fields or tag.")
+
+ tq_client = _maybe_create_transferqueue_client()
+
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
+
+ if batch_meta.size != 1:
+ raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
+
+ # 2. register the user-specified tag to BatchMeta
+ if tag:
+ batch_meta.update_custom_meta([tag])
+
+ # 3. put data
+ if fields is not None:
+ if isinstance(fields, dict):
+ # TODO: consider whether to support this...
+ batch = {}
+            for field_name, value in fields.items():
+                if isinstance(value, torch.Tensor):
+                    if value.is_nested:
+                        raise ValueError("Please use (async_)kv_batch_put for batch operations")
+                    batch[field_name] = value.unsqueeze(0)
+                else:
+                    batch[field_name] = NonTensorStack(value)
+ fields = TensorDict(batch, batch_size=[1])
+ elif not isinstance(fields, TensorDict):
+            raise ValueError("`fields` can only be a dict or TensorDict")
+
+ # custom_meta (tag) will be put to controller through the put process
+ await tq_client.async_put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tag) to controller
+ await tq_client.async_set_custom_meta(batch_meta)
+
+
+async def async_kv_batch_put(
+ keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]]
+) -> None:
+ """Asynchronously put multiple key-value pairs to TransferQueue in batch.
+
+ This method stores multiple key-value pairs in a single operation, which is more
+ efficient than calling kv_put multiple times.
+
+ Args:
+ keys: List of user-specified keys for the data
+ partition_id: Logical partition to store the data in
+ fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
+ tags: List of metadata tags, one for each key
+
+ Raises:
+ ValueError: If neither `fields` nor `tags` is provided
+ ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> keys = ["sample_1", "sample_2", "sample_3"]
+ >>> fields = TensorDict({
+ ... "input_ids": torch.randn(3, 10),
+ ... "attention_mask": torch.ones(3, 10),
+ ... }, batch_size=3)
+ >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
+ >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
+ """
+
+    if fields is None and tags is None:
+        raise ValueError("Please provide at least one parameter of fields or tags.")
+
+    if fields is not None and fields.batch_size[0] != len(keys):
+ raise ValueError(
+ f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
+ f"batch_size {fields.batch_size[0]}"
+ )
+
+ tq_client = _maybe_create_transferqueue_client()
+
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
+
+ if batch_meta.size != len(keys):
+ raise RuntimeError(
+ f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
+ )
+
+ # 2. register the user-specified tags to BatchMeta
+ if tags:
+ if len(tags) != len(keys):
+ raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
+ batch_meta.update_custom_meta(tags)
+
+ # 3. put data
+ if fields is not None:
+ await tq_client.async_put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tags) to controller
+ await tq_client.async_set_custom_meta(batch_meta)
+
+
+async def async_kv_get(
+ keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None
+) -> TensorDict:
+ """Asynchronously get data from TransferQueue using user-specified keys.
+
+ This is a convenience method for retrieving data using keys instead of indexes.
+
+ Args:
+ keys: Single key or list of keys to retrieve
+ partition_id: Partition containing the keys
+ fields: Optional field(s) to retrieve. If None, retrieves all fields
+
+ Returns:
+ TensorDict with the requested data
+
+ Raises:
+ RuntimeError: If keys or partition are not found
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Get single key with all fields
+        >>> data = await tq.async_kv_get(keys="sample_1", partition_id="train")
+ >>> # Get multiple keys with specific fields
+ >>> data = await tq.async_kv_get(
+ ... keys=["sample_1", "sample_2"],
+ ... partition_id="train",
+ ... fields="input_ids"
+ ... )
+ """
+ tq_client = _maybe_create_transferqueue_client()
+
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+ if batch_meta.size == 0:
+ raise RuntimeError("keys or partition were not found!")
+
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ batch_meta = batch_meta.select_fields(fields)
+
+ data = await tq_client.async_get_data(batch_meta)
+
+ return data
+
+
+async def async_kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]:
+ """Asynchronously list all keys and their metadata in a partition.
+
+ Args:
+ partition_id: Partition to list keys from
+
+ Returns:
+ Tuple of:
+ - List of keys in the partition
+ - List of custom metadata (tags) associated with each key
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> keys, tags = await tq.async_kv_list(partition_id="train")
+ >>> print(f"Keys: {keys}")
+ >>> print(f"Tags: {tags}")
+ """
+ tq_client = _maybe_create_transferqueue_client()
+
+ keys, custom_meta = await tq_client.async_kv_list(partition_id)
+
+ return keys, custom_meta
+
+
+async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
+ """Asynchronously clear key-value pairs from TransferQueue.
+
+ This removes the specified keys and their associated data from both
+ the controller and storage units.
+
+ Args:
+ keys: Single key or list of keys to clear
+ partition_id: Partition containing the keys
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Clear single key
+        >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
+ >>> # Clear multiple keys
+ >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
+ """
+
+ if isinstance(keys, str):
+ keys = [keys]
+
+ tq_client = _maybe_create_transferqueue_client()
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+ if batch_meta.size > 0:
+ await tq_client.async_clear_samples(batch_meta)
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index 5e25404..2abde1b 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -303,60 +303,43 @@ def clear_extra_info(self) -> None:
"""
self.extra_info.clear()
- def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None:
+ def get_all_custom_meta(self) -> list[dict[str, Any]]:
"""
- Set custom_meta for a specific sample by global_index.
-
- Custom metadata is user-defined per-sample metadata that can be stored
- and retrieved along with the BatchMeta.
-
- Args:
- global_index: The global_index of the sample to set custom meta for
- meta_dict: Dictionary containing custom metadata for the sample
-
- Raises:
- ValueError: If the key is not in global_indexes
- """
-
- if global_index not in self.global_indexes:
- raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.")
-
- self.custom_meta[global_index] = copy.deepcopy(meta_dict)
-
- def get_all_custom_meta(self) -> dict[int, dict[str, Any]]:
- """
- Get all custom_meta as a dictionary.
+ Get all custom_meta as a list of dictionary.
Returns:
- A deep copy of the custom_meta dictionary
+ A deep copy of the custom_meta list
"""
- return copy.deepcopy(self.custom_meta)
+ custom_meta = [self.custom_meta.get(i, {}) for i in self.global_indexes]
+ return copy.deepcopy(custom_meta)
- def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]):
+ def update_custom_meta(self, custom_meta: list[dict[str, Any]]):
"""
- Update custom_meta with a dictionary of new metadata.
+ Update custom_meta with a list of dictionary of custom metadata.
This method updates the custom_meta dictionary with the provided metadata.
Existing keys will be overwritten with new values.
Args:
- new_meta: Dictionary of new metadata
+ custom_meta: list of custom_meta dictionary
Raises:
- ValueError: If any key in new_meta is not in global_indexes
+ ValueError: If the length of custom_meta cannot match the batch size
"""
- if new_meta is None:
+ if custom_meta is None:
return
- non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes)
- if non_exist_global_indexes:
+ if len(custom_meta) != self.size:
raise ValueError(
- f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} "
- f"do not exist in this batch."
+ f"The length of custom_meta list {len(custom_meta)} must match the batch size: {self.size}"
)
- self.custom_meta.update(new_meta)
+ custom_meta_dict: dict[int, dict[str, Any]] = {
+ self.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta))
+ }
+
+ self.custom_meta.update(custom_meta_dict)
def clear_custom_meta(self) -> None:
"""
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py
index c07ec1f..258dd91 100644
--- a/transfer_queue/storage/managers/base.py
+++ b/transfer_queue/storage/managers/base.py
@@ -28,6 +28,7 @@
import ray
import torch
import zmq
+from omegaconf import DictConfig
from tensordict import NonTensorStack, TensorDict
from torch import Tensor
@@ -59,12 +60,10 @@ class TransferQueueStorageManager(ABC):
"""Base class for storage layer. It defines the interface for data operations and
generally provides handshake & notification capabilities."""
- def __init__(self, config: dict[str, Any]):
+ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig):
self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}"
self.config = config
- controller_info = config.get("controller_info")
- assert controller_info is not None, "controller_info is required"
- self.controller_info: ZMQServerInfo = controller_info
+ self.controller_info = controller_info
self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None
self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None
@@ -178,7 +177,7 @@ def _send_handshake_requests(self) -> None:
"""Send handshake request to controller."""
assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized"
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.HANDSHAKE,
+ request_type=ZMQRequestType.HANDSHAKE, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"storage_manager_id": self.storage_manager_id,
@@ -226,7 +225,7 @@ async def notify_data_update(
poller.register(self.data_status_update_socket, zmq.POLLIN)
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"partition_id": partition_id,
@@ -246,7 +245,7 @@ async def notify_data_update(
)
except Exception as e:
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"message": f"Failed to notify data status update information from "
@@ -351,14 +350,14 @@ class KVStorageManager(TransferQueueStorageManager):
It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations.
"""
- def __init__(self, config: dict[str, Any]):
+ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]):
"""
Initialize the KVStorageManager with configuration.
"""
client_name = config.get("client_name", None)
if client_name is None:
raise ValueError("Missing client_name in config")
- super().__init__(config)
+ super().__init__(controller_info, config)
self.storage_client = StorageClientFactory.create(client_name, config)
self._multi_threads_executor: Optional[ThreadPoolExecutor] = None
# Register a cleanup function: automatically invoke shutdown when the instance is garbage collected.
@@ -557,7 +556,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
loop = asyncio.get_event_loop()
# put to storage backends
- custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)
+ custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)
per_field_dtypes: dict[int, dict[str, Any]] = {}
per_field_shapes: dict[int, dict[str, Any]] = {}
@@ -578,24 +577,26 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None
)
- # Prepare per-field custom_meta if available
- per_field_custom_meta: dict[int, dict[str, Any]] = {}
- if custom_meta:
- if len(custom_meta) != len(keys):
- raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})")
+ # Prepare per-field custom_backend_meta if available
+ per_field_custom_backend_meta: dict[int, dict[str, Any]] = {}
+ if custom_backend_meta:
+ if len(custom_backend_meta) != len(keys):
+ raise ValueError(
+ f"Length of custom_backend_meta ({len(custom_backend_meta)}) does not match expected ({len(keys)})"
+ )
# custom meta is a flat list aligned with keys/values
# Use itertools.product to eliminate nested loops
for global_idx in metadata.global_indexes:
- per_field_custom_meta[global_idx] = {}
+ per_field_custom_backend_meta[global_idx] = {}
# TODO(tianyi): the order of custom meta is coupled with keys/values
for (field_name, global_idx), meta_value in zip(
itertools.product(sorted(metadata.field_names), metadata.global_indexes),
- custom_meta,
+ custom_backend_meta,
strict=True,
):
- per_field_custom_meta[global_idx][field_name] = meta_value
- metadata.update_custom_meta(per_field_custom_meta)
+ per_field_custom_backend_meta[global_idx][field_name] = meta_value
+ metadata._custom_backend_meta.update(per_field_custom_backend_meta)
# Get current data partition id
# Note: Currently we only support putting to & getting data from a single data partition simultaneously,
@@ -608,7 +609,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
metadata.global_indexes,
per_field_dtypes,
per_field_shapes,
- per_field_custom_meta,
+ per_field_custom_backend_meta,
)
async def get_data(self, metadata: BatchMeta) -> TensorDict:
diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py
index e595ccd..04d2cd0 100644
--- a/transfer_queue/storage/managers/factory.py
+++ b/transfer_queue/storage/managers/factory.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import warnings
from typing import Any
from transfer_queue.storage.managers.base import TransferQueueStorageManager
+from transfer_queue.utils.zmq_utils import ZMQServerInfo
class TransferQueueStorageManagerFactory:
@@ -39,10 +41,34 @@ def decorator(manager_cls: type[TransferQueueStorageManager]):
return decorator
@classmethod
- def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager:
+ def create(
+ cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any]
+ ) -> TransferQueueStorageManager:
"""Create and return a TransferQueueStorageManager instance."""
if manager_type not in cls._registry:
- raise ValueError(
- f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
- )
- return cls._registry[manager_type](config)
+ if manager_type == "AsyncSimpleStorageManager":
+ warnings.warn(
+ f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ manager_type = "SimpleStorage"
+ elif manager_type == "MooncakeStorageManager":
+ warnings.warn(
+ f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ manager_type = "MooncakeStore"
+ elif manager_type == "YuanrongStorageManager":
+ warnings.warn(
+ f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ manager_type = "Yuanrong"
+ else:
+ raise ValueError(
+ f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
+ )
+ return cls._registry[manager_type](controller_info, config)
diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py
index ca55566..9f6f93a 100644
--- a/transfer_queue/storage/managers/mooncake_manager.py
+++ b/transfer_queue/storage/managers/mooncake_manager.py
@@ -19,16 +19,17 @@
from transfer_queue.storage.managers.base import KVStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
+from transfer_queue.utils.zmq_utils import ZMQServerInfo
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
-@TransferQueueStorageManagerFactory.register("MooncakeStorageManager")
+@TransferQueueStorageManagerFactory.register("MooncakeStore")
class MooncakeStorageManager(KVStorageManager):
"""Storage manager for MooncakeStorage backend."""
- def __init__(self, config: dict[str, Any]):
+ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]):
# Required: Address of the HTTP metadata server (e.g., "localhost:8080")
metadata_server = config.get("metadata_server", None)
# Required: Address of the master server RPC endpoint (e.g., "localhost:8081")
@@ -45,4 +46,4 @@ def __init__(self, config: dict[str, Any]):
config["client_name"] = "MooncakeStorageClient"
elif client_name != "MooncakeStorageClient":
raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'MooncakeStorageClient'")
- super().__init__(config)
+ super().__init__(controller_info, config)
diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py
index 5c6e68f..f142069 100644
--- a/transfer_queue/storage/managers/simple_backend_manager.py
+++ b/transfer_queue/storage/managers/simple_backend_manager.py
@@ -16,6 +16,7 @@
import asyncio
import logging
import os
+import warnings
from collections.abc import Mapping
from functools import wraps
from operator import itemgetter
@@ -24,6 +25,7 @@
import torch
import zmq
+from omegaconf import DictConfig
from tensordict import NonTensorStack, TensorDict
from transfer_queue.metadata import BatchMeta
@@ -48,7 +50,7 @@
TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False)
-@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager")
+@TransferQueueStorageManagerFactory.register("SimpleStorage")
class AsyncSimpleStorageManager(TransferQueueStorageManager):
"""Asynchronous storage manager that handles multiple storage units.
@@ -56,14 +58,23 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
instances using ZMQ communication and dynamic socket management.
"""
- def __init__(self, config: dict[str, Any]):
- super().__init__(config)
+ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig):
+ super().__init__(controller_info, config)
self.config = config
- server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("storage_unit_infos", None)
+ server_infos: ZMQServerInfo | dict[str, ZMQServerInfo] | None = config.get("zmq_info", None)
if server_infos is None:
- raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.")
+ server_infos = config.get("storage_unit_infos", None)
+ if server_infos is not None:
+ warnings.warn(
+ "The config entry `storage_unit_infos` will be deprecated in 0.1.7, please use `zmq_info` instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+
+ if server_infos is None:
+ raise ValueError("AsyncSimpleStorageManager requires non-empty 'zmq_info' in config.")
self.storage_unit_infos = self._register_servers(server_infos)
self._build_storage_mapping_functions()
@@ -277,7 +288,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:
metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
)
- # retrive data
+ # retrieve data
tasks = [
self._get_from_single_storage_unit(meta_group, target_storage_unit=storage_id)
for storage_id, meta_group in storage_meta_groups.items()
diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py
index bfb79e6..54ac094 100644
--- a/transfer_queue/storage/managers/yuanrong_manager.py
+++ b/transfer_queue/storage/managers/yuanrong_manager.py
@@ -19,6 +19,7 @@
from transfer_queue.storage.managers.base import KVStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
+from transfer_queue.utils.zmq_utils import ZMQServerInfo
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
@@ -30,11 +31,11 @@
logger.addHandler(handler)
-@TransferQueueStorageManagerFactory.register("YuanrongStorageManager")
+@TransferQueueStorageManagerFactory.register("Yuanrong")
class YuanrongStorageManager(KVStorageManager):
"""Storage manager for Yuanrong backend."""
- def __init__(self, config: dict[str, Any]):
+ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]):
host = config.get("host", None)
port = config.get("port", None)
client_name = config.get("client_name", None)
@@ -48,4 +49,4 @@ def __init__(self, config: dict[str, Any]):
config["client_name"] = "YuanrongStorageClient"
elif client_name != "YuanrongStorageClient":
raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'YuanrongStorageClient'")
- super().__init__(config)
+ super().__init__(controller_info, config)
diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py
index e25f6b0..5192e9f 100644
--- a/transfer_queue/utils/common.py
+++ b/transfer_queue/utils/common.py
@@ -16,11 +16,13 @@
import logging
import os
from contextlib import contextmanager
-from typing import Optional
+from typing import Any, Optional
+import numpy as np
import psutil
import ray
import torch
+from tensordict import NonTensorStack, TensorDict
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
@@ -98,3 +100,67 @@ def get_env_bool(env_key: str, default: bool = False) -> bool:
true_values = {"true", "1", "yes", "y", "on"}
return env_value_lower in true_values
+
+
+def dict_to_tensordict(data: dict[str, Any]) -> TensorDict:
+ """
+ Create a TensorDict from a dict of tensors and non_tensors.
+ """
+
+ batch = {}
+
+ final_batch_size = None
+ tensor_batch_size = None
+ deterministic_tensor_batch_size = None
+ deterministic_non_tensor_batch_size = None
+
+ for key, val in data.items():
+ if isinstance(val, torch.Tensor):
+ if val.is_nested and val.layout == torch.strided:
+ # must use unbind for strided nested tensor
+ deterministic_tensor_batch_size = len(val.unbind())
+ else:
+ tensor_batch_size = val.shape[0]
+ batch[key] = val
+ elif isinstance(val, np.ndarray):
+ batch[key] = val
+ tensor_batch_size = val.shape[0]
+ elif isinstance(val, str):
+ batch[key] = val
+ deterministic_non_tensor_batch_size = 1
+ elif isinstance(val, list):
+ batch[key] = NonTensorStack(*val)
+ deterministic_non_tensor_batch_size = len(val)
+ else:
+ batch[key] = NonTensorStack(val)
+ deterministic_non_tensor_batch_size = 1
+
+ if deterministic_tensor_batch_size:
+ if deterministic_non_tensor_batch_size:
+ assert deterministic_non_tensor_batch_size == deterministic_tensor_batch_size
+ if final_batch_size:
+ assert final_batch_size == deterministic_tensor_batch_size
+ else:
+ final_batch_size = deterministic_tensor_batch_size
+
+ if deterministic_non_tensor_batch_size:
+ if deterministic_tensor_batch_size:
+ assert deterministic_non_tensor_batch_size == tensor_batch_size
+ if final_batch_size:
+ assert final_batch_size == deterministic_non_tensor_batch_size
+ else:
+ final_batch_size = deterministic_non_tensor_batch_size
+
+ if not final_batch_size:
+ raise RuntimeError("Cannot correctly determine batch_size for input.")
+
+ if tensor_batch_size:
+ if tensor_batch_size != final_batch_size:
+ assert final_batch_size == 1
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.unsqueeze(0)
+ elif isinstance(v, np.ndarray):
+ batch[k] = np.expand_dims(v, 0)
+
+ return TensorDict(batch, batch_size=[final_batch_size])
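A short usage sketch for the helper added above (assuming it stays importable from transfer_queue.utils.common; the field names are illustrative only):

import torch
from transfer_queue.utils.common import dict_to_tensordict

td = dict_to_tensordict(
    {
        "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),  # regular tensor -> batch dim of 2
        "labels": ["pos", "neg"],                            # python list -> NonTensorStack of length 2
    }
)
assert td.batch_size[0] == 2  # the batch size is inferred and cross-checked across fields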
diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py
index 1f6ed92..eaaf65e 100644
--- a/transfer_queue/utils/zmq_utils.py
+++ b/transfer_queue/utils/zmq_utils.py
@@ -23,6 +23,7 @@
from uuid import uuid4
import psutil
+import ray
import zmq
from transfer_queue.utils.common import (
@@ -100,6 +101,12 @@ class ZMQRequestType(ExplicitEnum):
NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK"
NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR"
+ # KV_INTERFACE
+ KV_RETRIEVE_KEYS = "KV_RETRIEVE_KEYS"
+ KV_RETRIEVE_KEYS_RESPONSE = "KV_RETRIEVE_KEYS_RESPONSE"
+ KV_LIST = "KV_LIST"
+ KV_LIST_RESPONSE = "KV_LIST_RESPONSE"
+
class ZMQServerInfo:
"""
@@ -239,3 +246,35 @@ def create_zmq_socket(
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)
return socket
+
+
+def process_zmq_server_info(
+ handlers: dict[Any, Any] | Any,
+): # noqa: UP007
+ """Extract ZMQ server information from handler objects.
+
+ Args:
+ handlers: Dictionary of handler objects (controllers, storage managers, or storage units),
+ or a single handler object
+
+ Returns:
+ If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information
+ If handlers is a single object: ZMQ server information for that object
+
+ Examples:
+ >>> # Single handler
+ >>> controller = TransferQueueController.remote(...)
+ >>> info = process_zmq_server_info(controller)
+ >>>
+ >>> # Multiple handlers
+ >>> handlers = {"storage_0": storage_0, "storage_1": storage_1}
+ >>> info_dict = process_zmq_server_info(handlers)"""
+ # Handle single handler object case
+ if not isinstance(handlers, dict):
+ return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
+ else:
+ # Handle dictionary case
+ server_info = {}
+ for name, handler in handlers.items():
+ server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
+ return server_info
diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py
index 25b530c..59159a6 100644
--- a/tutorial/01_core_components.py
+++ b/tutorial/01_core_components.py
@@ -37,87 +37,23 @@
import ray # noqa: E402
import torch # noqa: E402
-from omegaconf import OmegaConf # noqa: E402
from tensordict import TensorDict # noqa: E402
# Add the parent directory to the path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue import ( # noqa: E402
- SimpleStorageUnit,
- TransferQueueClient,
- TransferQueueController,
- process_zmq_server_info,
-)
+import transfer_queue as tq # noqa: E402
# Configure Ray
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
-
-def demonstrate_basic_setup():
- """
- Demonstrate the basic setup of TransferQueue with three core components.
- """
-
- # Initialize Ray
- if not ray.is_initialized():
- ray.init()
-
- # Configuration
- config = OmegaConf.create(
- {
- "num_data_storage_units": 2,
- }
- )
-
- print("[Step 1] Creating Storage Backend (using default SimpleStorageUnit)...")
- storage_units = {}
- for i in range(config["num_data_storage_units"]):
- storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100)
- print(f" ✓ Created SimpleStorageUnit #{i}")
-
- print("[Step 2] Creating TransferQueueController...")
- controller = TransferQueueController.remote()
- print(" ✓ Controller created - manages data state")
-
- # Get server information
- controller_info = process_zmq_server_info(controller)
- storage_unit_infos = process_zmq_server_info(storage_units)
-
- # Create Client (User-facing API)
- print("[Step 3] Creating TransferQueueClient...")
- client = TransferQueueClient(
- client_id="TutorialClient",
- controller_info=controller_info,
- )
- print(" ✓ Client created - this is what users interact with!")
-
- # Initialize storage manager
- tq_config = OmegaConf.create({}, flags={"allow_objects": True})
- tq_config.controller_info = controller_info
- tq_config.storage_unit_infos = storage_unit_infos
- config = OmegaConf.merge(tq_config, config)
-
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
- print(
- " ✓ Storage manager initialized. It is a class variable inside the client, acting as an adapter to "
- "suit for various storage backends."
- )
-
- print("[Architecture Summary]")
- print(
- " - TransferQueueController: Tracking the production/consumption status as metadata (can define your own "
- "data consumption logics)."
- )
- print(" - SimpleStorageUnit: Distributed data storage that holds actual data (easily swap out by other backends).")
- print(" - TransferQueueClient: User interface that allows you to put/get/clear data or metadata)")
-
- return controller, storage_units, client
+if not ray.is_initialized():
+ ray.init(namespace="TransferQueueTutorial")
-def demonstrate_data_workflow(client):
+def demonstrate_data_workflow():
"""
Demonstrate basic data workflow: put → get → clear.
"""
@@ -148,12 +84,12 @@ def demonstrate_data_workflow(client):
print(f" Created {data_batch.batch_size[0]} samples")
partition_id = "tutorial_partition_0"
- client.put(data=data_batch, partition_id=partition_id)
+ tq.put(data=data_batch, partition_id=partition_id)
print(f" ✓ Data written to partition: {partition_id}")
# Step 2: Get metadata
print("[Step 2] Requesting data metadata...")
- batch_meta = client.get_meta(
+ batch_meta = tq.get_meta(
data_fields=["input_ids", "attention_mask"],
batch_size=data_batch.batch_size[0],
partition_id=partition_id,
@@ -164,7 +100,7 @@ def demonstrate_data_workflow(client):
# Step 3: Get actual data
print("[Step 3] Retrieving actual data...")
- retrieved_data = client.get_data(batch_meta)
+ retrieved_data = tq.get_data(batch_meta)
print(" ✓ Data retrieved successfully")
print(f" Keys: {list(retrieved_data.keys())}")
@@ -176,7 +112,7 @@ def demonstrate_data_workflow(client):
# Step 5: Clear
print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)")
- client.clear_partition(partition_id=partition_id)
+ tq.clear_partition(partition_id=partition_id)
print(" ✓ Partition cleared")
@@ -189,16 +125,16 @@ def demonstrate_storage_backend_options():
print("=" * 80)
print("TransferQueue supports multiple storage backends:")
- print("1. SimpleStorageUnit (default)")
+ print("1. SimpleStorage (default)")
print(" - In-memory storage, fast and simple")
print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer")
print(" - No extra dependencies, good for development and testing")
- print("2. YuanrongStorage")
+ print("2. Yuanrong")
print(" - Ascend native distributed storage solution")
print(" - Hierarchical storage interfaces including HBM/DRAM/SSD")
- print("3. MoonCakeStore (on the way)")
+ print("3. MooncakeStore (on the way)")
print(" - Support multiple transmission protocols")
print(" - RDMA between DRAM and HBM")
@@ -234,10 +170,10 @@ def main():
try:
print("Setting up TransferQueue...")
- controller, storage_units, client = demonstrate_basic_setup()
+ tq.init()
print("Demonstrating the user workflow...")
- demonstrate_data_workflow(client)
+ demonstrate_data_workflow()
demonstrate_storage_backend_options()
@@ -253,7 +189,7 @@ def main():
print("3. You can swap out different storage backends easily")
# Cleanup
- client.close()
+ tq.close()
ray.shutdown()
print("\n✓ Cleanup complete")
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
new file mode 100644
index 0000000..59159a6
--- /dev/null
+++ b/tutorial/02_kv_interface.py
@@ -0,0 +1,205 @@
+# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2025 The TransferQueue Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import textwrap
+import warnings
+from pathlib import Path
+
+warnings.filterwarnings(
+ action="ignore",
+ message=r"The PyTorch API of nested tensors is in prototype stage*",
+ category=UserWarning,
+ module=r"torch\.nested",
+)
+
+warnings.filterwarnings(
+ action="ignore",
+ message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
+ r"devices env var if num_gpus=0 or num_gpus=None.*",
+ category=FutureWarning,
+ module=r"ray\._private\.worker",
+)
+
+
+import ray # noqa: E402
+import torch # noqa: E402
+from tensordict import TensorDict # noqa: E402
+
+# Add the parent directory to the path
+parent_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(parent_dir))
+
+import transfer_queue as tq # noqa: E402
+
+# Configure Ray
+os.environ["RAY_DEDUP_LOGS"] = "0"
+os.environ["RAY_DEBUG"] = "1"
+
+if not ray.is_initialized():
+ ray.init(namespace="TransferQueueTutorial")
+
+
+def demonstrate_data_workflow():
+ """
+ Demonstrate basic data workflow: put → get → clear.
+ """
+ print("=" * 80)
+ print("Data Workflow Demo: put → get → clear")
+ print("=" * 80)
+
+ # Step 1: Put data
+ print("[Step 1] Putting data into TransferQueue...")
+
+ input_ids = torch.tensor(
+ [
+ [1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9],
+ [10, 11, 12],
+ ]
+ )
+ attention_mask = torch.ones_like(input_ids)
+
+ data_batch = TensorDict(
+ {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ },
+ batch_size=input_ids.size(0),
+ )
+
+ print(f" Created {data_batch.batch_size[0]} samples")
+ partition_id = "tutorial_partition_0"
+ tq.put(data=data_batch, partition_id=partition_id)
+ print(f" ✓ Data written to partition: {partition_id}")
+
+ # Step 2: Get metadata
+ print("[Step 2] Requesting data metadata...")
+ batch_meta = tq.get_meta(
+ data_fields=["input_ids", "attention_mask"],
+ batch_size=data_batch.batch_size[0],
+ partition_id=partition_id,
+ task_name="tutorial_task",
+ )
+ print(f" ✓ Got metadata: {len(batch_meta)} samples")
+ print(f" Global indexes: {batch_meta.global_indexes}")
+
+ # Step 3: Get actual data
+ print("[Step 3] Retrieving actual data...")
+ retrieved_data = tq.get_data(batch_meta)
+ print(" ✓ Data retrieved successfully")
+ print(f" Keys: {list(retrieved_data.keys())}")
+
+ # Step 4: Verify
+ print("[Step 4] Verifying data integrity...")
+ assert torch.equal(retrieved_data["input_ids"], input_ids)
+ assert torch.equal(retrieved_data["attention_mask"], attention_mask)
+ print(" ✓ Data matches original!")
+
+ # Step 5: Clear
+ print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)")
+ tq.clear_partition(partition_id=partition_id)
+ print(" ✓ Partition cleared")
+
+
+def demonstrate_storage_backend_options():
+ """
+ Show different storage backend options.
+ """
+ print("=" * 80)
+ print("Storage Backend Options")
+ print("=" * 80)
+
+ print("TransferQueue supports multiple storage backends:")
+ print("1. SimpleStorage (default)")
+ print(" - In-memory storage, fast and simple")
+ print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer")
+ print(" - No extra dependencies, good for development and testing")
+
+ print("2. Yuanrong")
+ print(" - Ascend native distributed storage solution")
+ print(" - Hierarchical storage interfaces including HBM/DRAM/SSD")
+
+ print("3. MooncakeStore (on the way)")
+ print(" - Support multiple transmission protocols")
+ print(" - RDMA between DRAM and HBM")
+
+ print("4. Ray RDT (on the way)")
+ print(" - Leverage Ray's distributed object store to store data")
+
+ print("5. Custom Storage Backends")
+ print(" - Implement your own storage manager by inheriting from `TransferQueueStorageManager` base class")
+ print(" - For KV based storage, you only need to provide a storage client and integrate with `KVStorageManager`")
+
+
+def main():
+ print("=" * 80)
+ print(
+ textwrap.dedent(
+ """
+ TransferQueue Tutorial 1: Core Components Introduction
+
+ This script introduces the three core components of TransferQueue:
+ 1. TransferQueueController - Manages all the metadata and tracks the production and consumption states
+ 2. StorageBackend - Pluggable distributed storage backend that holds the actual data
+ 3. TransferQueueClient - Client interface for reading/writing data (user-facing API)
+
+ Key Concepts:
+ - Data is organized into logical partitions (e.g., "train", "val")
+ - Each sample has multiple fields, with a global index for identification
+ - Controller maintains production/consumption state tracking
+ - Client is the main interface users interact with
+ """
+ )
+ )
+ print("=" * 80)
+
+ try:
+ print("Setting up TransferQueue...")
+ tq.init()
+
+ print("Demonstrating the user workflow...")
+ demonstrate_data_workflow()
+
+ demonstrate_storage_backend_options()
+
+ print("=" * 80)
+ print("Tutorial Complete!")
+ print("=" * 80)
+ print("Key Takeaways:")
+ print("1. TransferQueue has 3 core components:")
+ print(" - Controller: Manages data production/consumption state")
+ print(" - StorageBackend: Persists actual data")
+ print(" - Client: User-facing API (what you use)")
+ print("2. Client is the main interface users interact with")
+ print("3. You can swap out different storage backends easily")
+
+ # Cleanup
+ tq.close()
+ ray.shutdown()
+ print("\n✓ Cleanup complete")
+
+ except Exception as e:
+ print(f"Error during tutorial: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tutorial/02_metadata_concepts.py b/tutorial/03_metadata_concepts.py
similarity index 89%
rename from tutorial/02_metadata_concepts.py
rename to tutorial/03_metadata_concepts.py
index 93e59ac..3a4a57d 100644
--- a/tutorial/02_metadata_concepts.py
+++ b/tutorial/03_metadata_concepts.py
@@ -38,19 +38,13 @@
import ray # noqa: E402
import torch # noqa: E402
-from omegaconf import OmegaConf # noqa: E402
from tensordict import TensorDict # noqa: E402
# Add the parent directory to the path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue import ( # noqa: E402
- SimpleStorageUnit,
- TransferQueueClient,
- TransferQueueController,
- process_zmq_server_info,
-)
+import transfer_queue as tq # noqa: E402
from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
@@ -211,10 +205,12 @@ def demonstrate_batch_meta():
print(f"✓ Extra info: {batch.get_all_extra_info()}")
print("[Example 3] Adding sample-level information through custom_meta...")
- batch.set_custom_meta(
- global_index=0, meta_dict={"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"}
+ batch.update_custom_meta(
+ [
+ {"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"},
+ ]
)
- batch.update_custom_meta({1: {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}})
print(f"✓ Custom meta: {batch.get_all_custom_meta()}")
# Example 4: Chunk a batch
@@ -308,32 +304,8 @@ def demonstrate_real_workflow():
if not ray.is_initialized():
ray.init()
- # Setup TransferQueue
- config = OmegaConf.create(
- {
- "num_data_storage_units": 2,
- }
- )
-
- storage_units = {}
- for i in range(config["num_data_storage_units"]):
- storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100)
-
- controller = TransferQueueController.remote()
- controller_info = process_zmq_server_info(controller)
- storage_unit_infos = process_zmq_server_info(storage_units)
-
- client = TransferQueueClient(
- client_id="TutorialClient",
- controller_info=controller_info,
- )
-
- tq_config = OmegaConf.create({}, flags={"allow_objects": True})
- tq_config.controller_info = controller_info
- tq_config.storage_unit_infos = storage_unit_infos
- config = OmegaConf.merge(tq_config, config)
-
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
+ # Initialize TransferQueue
+ tq.init()
print("[Step 1] Putting data into TransferQueue...")
input_ids = torch.randint(0, 1000, (8, 512))
@@ -348,23 +320,23 @@ def demonstrate_real_workflow():
)
partition_id = "demo_partition"
- batch_meta = client.put(data=data_batch, partition_id=partition_id)
+ batch_meta = tq.put(data=data_batch, partition_id=partition_id)
print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.")
print("[Step 2] [Optional] Setting sample-level custom_meta...")
- custom_meta = {
- global_index: {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0}
- for global_index in batch_meta.global_indexes
- }
+ custom_meta = [
+ {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0}
+ for _ in range(batch_meta.size)
+ ]
batch_meta.update_custom_meta(custom_meta)
print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}")
- client.set_custom_meta(batch_meta)
+ tq.set_custom_meta(batch_meta)
print("✓ Successful to store custom_meta into TQ controller. Now you can retrieve the custom_meta from anywhere.")
print("[Step 3] Try to get metadata from TransferQueue from other places...")
- batch_meta = client.get_meta(
+ batch_meta = tq.get_meta(
data_fields=["input_ids", "attention_mask"],
batch_size=8,
partition_id=partition_id,
@@ -383,7 +355,7 @@ def demonstrate_real_workflow():
print("✓ Selected 'input_ids' field only:")
print(f" New field names: {selected_meta.field_names}")
print(f" Samples still have same global indexes: {selected_meta.global_indexes}")
- retrieved_data = client.get_data(selected_meta)
+ retrieved_data = tq.get_data(selected_meta)
print(f" Retrieved data keys: {list(retrieved_data.keys())}")
print("[Step 5] Select specific samples from the retrieved BatchMeta...")
@@ -391,7 +363,7 @@ def demonstrate_real_workflow():
print("✓ Selected samples at indices [0, 2, 4, 6]:")
print(f" New global indexes: {partial_meta.global_indexes}")
print(f" Number of samples: {len(partial_meta)}")
- retrieved_data = client.get_data(partial_meta)
+ retrieved_data = tq.get_data(partial_meta)
print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}")
print("[Step 6] Demonstrate chunk operation...")
@@ -399,12 +371,12 @@ def demonstrate_real_workflow():
print(f"✓ Chunked into {len(chunks)} parts:")
for i, chunk in enumerate(chunks):
print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}")
- chunk_data = client.get_data(chunk)
+ chunk_data = tq.get_data(chunk)
print(f" Chunk {i}: Retrieved chunk data: {chunk_data}")
# Cleanup
- client.clear_partition(partition_id=partition_id)
- client.close()
+ tq.clear_partition(partition_id=partition_id)
+ tq.close()
ray.shutdown()
print("✓ Partition cleared and resources cleaned up")
diff --git a/tutorial/03_understanding_controller.py b/tutorial/04_understanding_controller.py
similarity index 76%
rename from tutorial/03_understanding_controller.py
rename to tutorial/04_understanding_controller.py
index 8b7e6d0..4ca426d 100644
--- a/tutorial/03_understanding_controller.py
+++ b/tutorial/04_understanding_controller.py
@@ -36,59 +36,19 @@
import ray # noqa: E402
import torch # noqa: E402
-from omegaconf import OmegaConf # noqa: E402
from tensordict import TensorDict # noqa: E402
# Add the parent directory to the path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue import ( # noqa: E402
- SimpleStorageUnit,
- TransferQueueClient,
- TransferQueueController,
- process_zmq_server_info,
-)
+import transfer_queue as tq # noqa: E402
# Configure Ray
os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
-def setup_transfer_queue():
- """Setup TransferQueue components."""
- if not ray.is_initialized():
- ray.init()
-
- config = OmegaConf.create(
- {
- "num_data_storage_units": 2,
- }
- )
-
- storage_units = {}
- for i in range(config["num_data_storage_units"]):
- storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100)
-
- controller = TransferQueueController.remote()
- controller_info = process_zmq_server_info(controller)
- storage_unit_infos = process_zmq_server_info(storage_units)
-
- client = TransferQueueClient(
- client_id="TutorialClient",
- controller_info=controller_info,
- )
-
- tq_config = OmegaConf.create({}, flags={"allow_objects": True})
- tq_config.controller_info = controller_info
- tq_config.storage_unit_infos = storage_unit_infos
- config = OmegaConf.merge(tq_config, config)
-
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
-
- return controller, storage_units, client
-
-
def demonstrate_partition_isolation():
"""Feature 1: Different partitions are isolated - data doesn't interfere."""
print("=" * 80)
@@ -97,7 +57,10 @@ def demonstrate_partition_isolation():
print("\nDifferent partitions are completely isolated - data doesn't interfere between partitions")
- controller, storage_units, client = setup_transfer_queue()
+ if not ray.is_initialized():
+ ray.init(namespace="TransferQueueTutorial")
+
+ tq.init()
# Partition 1: Training data
print("\n[Partition 1] Putting training data...")
@@ -108,7 +71,7 @@ def demonstrate_partition_isolation():
},
batch_size=2,
)
- client.put(data=train_data, partition_id="train")
+ tq.put(data=train_data, partition_id="train")
print(" ✓ Training data added to 'train' partition")
# Partition 2: Validation data
@@ -120,25 +83,23 @@ def demonstrate_partition_isolation():
},
batch_size=2,
)
- client.put(data=val_data, partition_id="val")
+ tq.put(data=val_data, partition_id="val")
print(" ✓ Validation data added to 'val' partition")
# Get from train partition
print("\n[Retrieving from 'train' partition]")
- train_meta = client.get_meta(
+ train_meta = tq.get_meta(
data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task"
)
- retrieved_train_data = client.get_data(train_meta)
+ retrieved_train_data = tq.get_data(train_meta)
print(f" ✓ Got BatchMeta={train_meta} from partition 'train'")
print(f" ✓ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, labels={retrieved_train_data['labels']}")
# Get from val partition
print("\n[Retrieving from 'val' partition]")
- val_meta = client.get_meta(
- data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task"
- )
- retrieved_val_data = client.get_data(val_meta)
+ val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task")
+ retrieved_val_data = tq.get_data(val_meta)
print(f" ✓ Got BatchMeta={val_meta} from partition 'val'")
print(f" ✓ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}")
@@ -146,9 +107,9 @@ def demonstrate_partition_isolation():
print(" ✓ Data isolation: 'train' and 'val' partitions are completely independent")
# Cleanup
- client.clear_partition(partition_id="train")
- client.clear_partition(partition_id="val")
- client.close()
+ tq.clear_partition(partition_id="train")
+ tq.clear_partition(partition_id="val")
+ tq.close()
ray.shutdown()
@@ -160,7 +121,10 @@ def demonstrate_dynamic_expansion():
print("\nPartitions dynamically expand to accommodate new data (rows and columns)")
- controller, storage_units, client = setup_transfer_queue()
+ if not ray.is_initialized():
+ ray.init(namespace="TransferQueueTutorial")
+
+ tq.init()
# Add first batch with 2 samples, 2 fields
print("\n[Step 1] Adding initial data (2 samples, 2 fields)...")
@@ -171,7 +135,7 @@ def demonstrate_dynamic_expansion():
},
batch_size=2,
)
- meta1 = client.put(data=data1, partition_id="dynamic")
+ meta1 = tq.put(data=data1, partition_id="dynamic")
print(" ✓ Added 2 samples")
print(f" ✓ Got BatchMeta: {meta1} samples")
@@ -184,9 +148,9 @@ def demonstrate_dynamic_expansion():
},
batch_size=3,
)
- meta2 = client.put(data=data2, partition_id="dynamic")
+ meta2 = tq.put(data=data2, partition_id="dynamic")
- all_meta = client.get_meta(
+ all_meta = tq.get_meta(
data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task"
)
print(" ✓ Added 3 more samples (total: 5)")
@@ -201,7 +165,7 @@ def demonstrate_dynamic_expansion():
},
batch_size=2,
)
- meta3 = client.put(data=data3, metadata=meta1)
+ meta3 = tq.put(data=data3, metadata=meta1)
print(" ✓ Added 2 samples with new field 'field3'")
print(f" ✓ Got BatchMeta: {meta3} for newly put data with new field")
@@ -210,8 +174,8 @@ def demonstrate_dynamic_expansion():
print(" ✓ Columns auto-expand: Can add new fields anytime")
# Cleanup
- client.clear_partition(partition_id="dynamic")
- client.close()
+ tq.clear_partition(partition_id="dynamic")
+ tq.close()
ray.shutdown()
@@ -221,7 +185,10 @@ def demonstrate_default_consumption_sample_strategy():
print("Feature 3: Default Sampling Strategy for Controller - No Duplicate, Sequential Samples")
print("=" * 80)
- controller, storage_units, client = setup_transfer_queue()
+ if not ray.is_initialized():
+ ray.init(namespace="TransferQueueTutorial")
+
+ tq.init()
# Add 6 samples
print("\n[Setup] Adding 6 samples...")
@@ -231,22 +198,22 @@ def demonstrate_default_consumption_sample_strategy():
},
batch_size=6,
)
- client.put(data=all_data, partition_id="sampling")
+ tq.put(data=all_data, partition_id="sampling")
print(" ✓ 6 samples added")
# First get - should get samples 0,1,2
print("\n[Task A, Get 1] Requesting 3 samples...")
- meta1 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
+ meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
print(f" ✓ Got samples: {meta1.global_indexes}")
# Second get - should get samples 3,4,5 (no duplicates!)
print("\n[Task A, Get 2] Requesting 3 more samples...")
- meta2 = client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
+ meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
print(f" ✓ Got samples: {meta2.global_indexes}")
# Third get - should get samples 0,1
print("\n[Task B, Get 1] Requesting 2 samples...")
- meta3 = client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B")
+ meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B")
print(f" ✓ Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -257,8 +224,8 @@ def demonstrate_default_consumption_sample_strategy():
print(" ✓ Third get (Task B): samples 0,1")
# Cleanup
- client.clear_partition(partition_id="sampling")
- client.close()
+ tq.clear_partition(partition_id="sampling")
+ tq.close()
ray.shutdown()
diff --git a/tutorial/04_custom_sampler.py b/tutorial/05_custom_sampler.py
similarity index 83%
rename from tutorial/04_custom_sampler.py
rename to tutorial/05_custom_sampler.py
index 7bf13cd..c35b4e0 100644
--- a/tutorial/04_custom_sampler.py
+++ b/tutorial/05_custom_sampler.py
@@ -49,12 +49,7 @@
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue import ( # noqa: E402
- SimpleStorageUnit,
- TransferQueueClient,
- TransferQueueController,
- process_zmq_server_info,
-)
+import transfer_queue as tq # noqa: E402
from transfer_queue.sampler import BaseSampler # noqa: E402
@@ -171,36 +166,14 @@ def sample(
def setup_transfer_queue_with_sampler(sampler):
"""Setup TransferQueue with custom sampler."""
if not ray.is_initialized():
- ray.init()
+ ray.init(namespace="TransferQueueTutorial")
config = OmegaConf.create(
- {
- "global_batch_size": 8,
- "num_data_storage_units": 2,
- }
- )
-
- storage_units = {}
- for i in range(2):
- storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100)
-
- controller = TransferQueueController.remote(sampler=sampler)
- controller_info = process_zmq_server_info(controller)
- storage_unit_infos = process_zmq_server_info(storage_units)
-
- client = TransferQueueClient(
- client_id="TutorialClient",
- controller_info=controller_info,
+ {"controller": {"sampler": sampler}, "backend": {"SimpleStorage": {"num_data_storage_units": 2}}},
+ flags={"allow_objects": True},
)
- tq_config = OmegaConf.create({}, flags={"allow_objects": True})
- tq_config.controller_info = controller_info
- tq_config.storage_unit_infos = storage_unit_infos
- config = OmegaConf.merge(tq_config, config)
-
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
-
- return controller, storage_units, client
+ tq.init(config)
def demonstrate_random_sampler_with_replacement():
@@ -211,7 +184,7 @@ def demonstrate_random_sampler_with_replacement():
print("\nSetup TransferQueue with RandomSamplerWithReplacement...")
sampler = RandomSamplerWithReplacement()
- controller, storage_units, client = setup_transfer_queue_with_sampler(sampler)
+ setup_transfer_queue_with_sampler(sampler)
# Add 5 samples
print("\n[Step 1] Adding 5 samples...")
@@ -221,22 +194,22 @@ def demonstrate_random_sampler_with_replacement():
},
batch_size=5,
)
- client.put(data=data, partition_id="test")
+ tq.put(data=data, partition_id="test")
print(" ✓ 5 samples added")
# Get batch 1 (should get 2 random samples)
print("\n[Step 2] Get batch 1 (2 samples)...")
- meta1 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta1.global_indexes}")
# Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!)
print("\n[Step 3] Get batch 2 (1 sample)...")
- meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
+ meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta2.global_indexes}")
# Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!)
print("\n[Step 4] Get batch 3 (2 samples)...")
- meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -246,8 +219,8 @@ def demonstrate_random_sampler_with_replacement():
print(f" ✓ All sampled: {all_sampled}")
# Cleanup
- client.clear_partition(partition_id="test")
- client.close()
+ tq.clear_partition(partition_id="test")
+ tq.close()
ray.shutdown()
@@ -259,7 +232,7 @@ def demonstrate_random_sampler_without_replacement():
print("\nSetup TransferQueue with RandomSamplerWithoutReplacement...")
sampler = RandomSamplerWithoutReplacement()
- controller, storage_units, client = setup_transfer_queue_with_sampler(sampler)
+ setup_transfer_queue_with_sampler(sampler)
# Add 6 samples
print("\n[Step 1] Adding 6 samples...")
@@ -269,22 +242,22 @@ def demonstrate_random_sampler_without_replacement():
},
batch_size=6,
)
- client.put(data=data, partition_id="test")
+ tq.put(data=data, partition_id="test")
print(" ✓ 6 samples added")
# Get batch 1 (should get 3 random samples without replacement)
print("\n[Step 2] Get batch 1 (3 samples)...")
- meta1 = client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task")
+ meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta1.global_indexes}")
# Get batch 2 (should randomly get 1 sample that are different from previous batch)
print("\n[Step 3] Get batch 2 (1 samples)...")
- meta2 = client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
+ meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta2.global_indexes}")
# Get batch 3 (should randomly get 2 samples that are different from previous batch)
print("\n[Step 4] Get batch 3 (2 samples)...")
- meta3 = client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -294,8 +267,8 @@ def demonstrate_random_sampler_without_replacement():
print(f" ✓ Batch 3: {meta3.global_indexes} (none left)")
# Cleanup
- client.clear_partition(partition_id="test")
- client.close()
+ tq.clear_partition(partition_id="test")
+ tq.close()
ray.shutdown()
@@ -307,7 +280,7 @@ def demonstrate_priority_sampler():
print("\nSetup TransferQueue with PrioritySampler...")
sampler = PrioritySampler()
- controller, storage_units, client = setup_transfer_queue_with_sampler(sampler)
+ setup_transfer_queue_with_sampler(sampler)
# Add 8 samples
print("\n[Step 1] Adding 8 samples...")
@@ -317,7 +290,7 @@ def demonstrate_priority_sampler():
},
batch_size=8,
)
- client.put(data=data, partition_id="test")
+ tq.put(data=data, partition_id="test")
print(" ✓ 8 samples added")
time.sleep(1)
@@ -330,7 +303,7 @@ def demonstrate_priority_sampler():
print(f"Priority scores: {priority_scores}")
# Get batch using priority sampling
- meta1 = client.get_meta(
+ meta1 = tq.get_meta(
data_fields=["input"],
batch_size=1,
partition_id="test",
@@ -342,7 +315,7 @@ def demonstrate_priority_sampler():
# Get another batch
print("\n[Step 3] Get another batch (2 samples)...")
- meta2 = client.get_meta(
+ meta2 = tq.get_meta(
data_fields=["input"],
batch_size=2,
partition_id="test",
@@ -358,8 +331,8 @@ def demonstrate_priority_sampler():
print(f" ✓ Batch 2 high-priority indices: {[i for i in meta2.global_indexes if priority_scores[i] >= 0.1]}")
# Cleanup
- client.clear_partition(partition_id="test")
- client.close()
+ tq.clear_partition(partition_id="test")
+ tq.close()
ray.shutdown()
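Since the sampler now travels inside the config, here is a compact sketch of the new setup path (RankAwareSampler is used only because it is importable from the package, as the streaming tutorial shows; the custom BaseSampler subclasses defined in this tutorial slot into the same `controller.sampler` entry):

from omegaconf import OmegaConf

import transfer_queue as tq
from transfer_queue import RankAwareSampler

config = OmegaConf.create(
    {
        "controller": {"sampler": RankAwareSampler},
        "backend": {"SimpleStorage": {"num_data_storage_units": 2}},
    },
    flags={"allow_objects": True},
)
tq.init(config)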
diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py
similarity index 84%
rename from tutorial/05_streaming_dataloader.py
rename to tutorial/06_streaming_dataloader.py
index 916be00..4d92aa2 100644
--- a/tutorial/05_streaming_dataloader.py
+++ b/tutorial/06_streaming_dataloader.py
@@ -57,39 +57,25 @@
import ray # noqa: E402
import torch # noqa: E402
-from omegaconf import DictConfig, OmegaConf # noqa: E402
+from omegaconf import OmegaConf # noqa: E402
from tensordict import TensorDict # noqa: E402
# Add the parent directory to the path for imports
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-
+import transfer_queue as tq # noqa: E402
from transfer_queue import ( # noqa: E402
RankAwareSampler,
- SimpleStorageUnit,
StreamingDataLoader,
StreamingDataset,
- TransferQueueClient,
- TransferQueueController,
- process_zmq_server_info,
)
def setup_transfer_queue():
"""Setup TransferQueue components."""
if not ray.is_initialized():
- ray.init()
-
- config = OmegaConf.create(
- {
- "num_data_storage_units": 2,
- }
- )
-
- storage_units = {}
- for i in range(config["num_data_storage_units"]):
- storage_units[i] = SimpleStorageUnit.remote(storage_unit_size=100)
+ ray.init(namespace="TransferQueueTutorial")
print("[Setup]: Setup TransferQueue components")
print(
@@ -101,26 +87,23 @@ def setup_transfer_queue():
"TransferQueueController. In polling_mode, the controller will return empty BatchMeta when "
"available data cannot meet the consumption requirements. User side need to retry later."
)
- controller = TransferQueueController.remote(
- sampler=RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank
- polling_mode=True, # Enable polling mode for streaming data retrieval
- )
-
- controller_info = process_zmq_server_info(controller)
- storage_unit_infos = process_zmq_server_info(storage_units)
- # Build the complete configuration
- tq_config = OmegaConf.create({}, flags={"allow_objects": True})
- tq_config.controller_info = controller_info
- tq_config.storage_unit_infos = storage_unit_infos
- config.storage_backend = "AsyncSimpleStorageManager"
- config = OmegaConf.merge(tq_config, config)
+ config = OmegaConf.create(
+ {
+ "controller": {
+ "sampler": RankAwareSampler, # RankAwareSampler enables consistent sampling for each DP rank
+ "polling_mode": True, # Enable polling mode for streaming data retrieval
+ },
+ "backend": {"SimpleStorage": {"num_data_storage_units": 2}},
+ },
+ flags={"allow_objects": True},
+ )
- return controller, storage_units, config
+ tq.init(config)
@ray.remote(num_cpus=0.1)
-def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20):
+def generate_worker(rank_id: int, num_samples: int = 20):
"""
Generate actor that produces training samples.
@@ -129,7 +112,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20):
Args:
rank_id: Unique identifier for this generator (used for sample indexing)
- config: TransferQueue configuration
num_samples: Number of samples to generate
Note:
@@ -137,13 +119,9 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20):
This ensures global uniqueness across all generator actors.
"""
# Create a client for interacting with TransferQueue
- client = TransferQueueClient(
- client_id=f"gen_worker_{rank_id}",
- controller_info=config.controller_info,
- )
- # Initialize the storage manager for this client
- client.initialize_storage_manager(manager_type=config.storage_backend, config=config)
+ # Need to call tq.init() in each process
+ tq.init()
# Generate and put samples into the queue
for i in range(num_samples):
@@ -159,7 +137,7 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20):
print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue")
# Put data into the specified partition
- client.put(data, partition_id="train")
+ tq.put(data, partition_id="train")
print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue")
@@ -168,7 +146,6 @@ def generate_worker(rank_id: int, config: DictConfig, num_samples: int = 20):
def update_worker(
rank_id: int,
dp_rank: int,
- config: DictConfig,
max_steps: int = 5,
):
"""
@@ -182,7 +159,6 @@ def update_worker(
rank_id: Global rank identifier for logging and display purposes
dp_rank: Data parallel rank ID that this worker belongs to
The same Ranks receive the same data samples
- config: TransferQueue configuration
max_steps: Maximum number of batches to consume
Returns:
@@ -200,8 +176,15 @@ def update_worker(
- batch_meta: Metadata for TransferQueue coordination (contains global_indexes)
"""
+ # Need to call tq.init() in each process
+ tq.init()
+
# Step 1: Create StreamingDataset
# This dataset integrates with TransferQueue and handles batch retrieval
+
+ controller = ray.get_actor("TransferQueueController")
+ config = ray.get(controller.get_config.remote())
+
dataset = StreamingDataset(
config=config,
batch_size=2,
@@ -253,7 +236,7 @@ def update_worker(
}
-def start_all_generate_actors(config):
+def start_all_generate_actors():
"""
Launch generate_actors for producing training samples.
"""
@@ -261,12 +244,12 @@ def start_all_generate_actors(config):
handlers = []
for i in range(num_workers):
- handlers.append(generate_worker.remote(rank_id=i, config=config, num_samples=20))
+ handlers.append(generate_worker.remote(rank_id=i, num_samples=20))
return handlers
-def start_all_update_actors(config):
+def start_all_update_actors():
"""
Launch update_actors for consuming training samples.
"""
@@ -285,7 +268,6 @@ def start_all_update_actors(config):
update_worker.remote(
rank_id=rank_ids[i],
dp_rank=dp_rank[i],
- config=config,
)
)
@@ -331,15 +313,15 @@ def main():
"global_batch_size to make sure consumers can accurately determine consumption status even before "
"producers have generated the samples."
)
- controller, storage_units, config = setup_transfer_queue()
+ setup_transfer_queue()
# Step 2: Launch data generation actors
print("\n[Phase 2] Starting data generation...")
- generate_worker_handlers = start_all_generate_actors(config)
+ generate_worker_handlers = start_all_generate_actors()
# Step 3: Launch data consumption actors
print("\n[Phase 3] Starting data consumption...")
- update_worker_handlers = start_all_update_actors(config)
+ update_worker_handlers = start_all_update_actors()
# Wait for completion
print("\n[Phase 4] Waiting for actors to complete...")
From 4836b484ab36d412447caf2056f6611c6dc4657e Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 00:06:07 +0800
Subject: [PATCH 02/34] todo: fix ut
Signed-off-by: 0oshowero0
---
tests/test_kv_interface.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tests/test_kv_interface.py b/tests/test_kv_interface.py
index 24d74af..6b9eee1 100644
--- a/tests/test_kv_interface.py
+++ b/tests/test_kv_interface.py
@@ -257,15 +257,15 @@ def test_kv_list_with_keys(self):
def test_kv_list_empty_partition(self):
"""Test kv_list returns None when partition is empty."""
mock_client = MagicMock()
- mock_client.kv_list.return_value = [], []
+ mock_client.kv_list.return_value = []
with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
from transfer_queue.interface import kv_list
keys, custom_meta = kv_list(partition_id="empty_partition")
- assert keys == []
- assert custom_meta == []
+ assert keys is None
+ assert custom_meta is None
class TestKVClear:
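The expectation pinned by this test, as a usage note: on an empty partition the module-level kv_list should yield (None, None) rather than two empty lists, so callers guard on None (sketch only; it assumes tq.init() has been run against a live cluster, unlike the mocked test):

import transfer_queue as tq

tq.init()
keys, tags = tq.kv_list(partition_id="empty_partition")
if keys is None:
    print("partition 'empty_partition' holds no keys yet")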
From c1cb09f82bd8c08ff9c58dd1b2ff00caf96d2c1a Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 15:36:19 +0800
Subject: [PATCH 03/34] try add tutorial
Signed-off-by: 0oshowero0
---
tutorial/02_kv_interface.py | 179 +++++++++++++-----------
tutorial/03_metadata_concepts.py | 2 +-
tutorial/04_understanding_controller.py | 2 +-
tutorial/05_custom_sampler.py | 2 +-
tutorial/06_streaming_dataloader.py | 2 +-
5 files changed, 98 insertions(+), 89 deletions(-)
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
index 59159a6..b54b03a 100644
--- a/tutorial/02_kv_interface.py
+++ b/tutorial/02_kv_interface.py
@@ -53,97 +53,108 @@
ray.init(namespace="TransferQueueTutorial")
-def demonstrate_data_workflow():
+def demonstrate_kv_api():
"""
- Demonstrate basic data workflow: put → get → clear.
+ Demonstrate xxxxx
"""
print("=" * 80)
- print("Data Workflow Demo: put → get → clear")
+ print("Data xxxx")
print("=" * 80)
- # Step 1: Put data
- print("[Step 1] Putting data into TransferQueue...")
+ # Step 1: Put single data sample
+ print("[Step 1] Putting single data into TransferQueue...")
- input_ids = torch.tensor(
+ input_ids = torch.tensor([[1, 2, 3]])
+ attention_mask = torch.ones(input_ids.size())
+
+ single_sample = TensorDict(
+ {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ },
+ batch_size=input_ids.size(0),
+ )
+
+ print(f" Created single sample with multiple fields:{single_sample.keys()}.")
+ print(" Leveraging TransferQueue, we can provide fine-grained access into each single field of a sample (key).")
+
+ partition_id = "Train"
+ key = [f"{uid}_{session_id}" for uid in [0] for session_id in [0]]
+ tag = [{"global_steps": 0, "status": "running", "model_version": 0}]
+ tq.kv_put(key, partition_id=partition_id, fields=single_sample, tag=tag)
+ print(f" ✓ Data put to partition: {partition_id}")
+
+ # Step 2: Put data batch
+ print("[Step 2] Putting multiple data into TransferQueue...")
+
+ batch_input_ids = torch.tensor(
[
- [1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
+ [13, 14, 15],
]
)
- attention_mask = torch.ones_like(input_ids)
+ batch_attention_mask = torch.ones_like(batch_input_ids)
data_batch = TensorDict(
{
- "input_ids": input_ids,
- "attention_mask": attention_mask,
+ "input_ids": batch_input_ids,
+ "attention_mask": batch_attention_mask,
},
- batch_size=input_ids.size(0),
+ batch_size=batch_input_ids.size(0),
)
- print(f" Created {data_batch.batch_size[0]} samples")
- partition_id = "tutorial_partition_0"
- tq.put(data=data_batch, partition_id=partition_id)
- print(f" ✓ Data written to partition: {partition_id}")
-
- # Step 2: Get metadata
- print("[Step 2] Requesting data metadata...")
- batch_meta = tq.get_meta(
- data_fields=["input_ids", "attention_mask"],
- batch_size=data_batch.batch_size[0],
- partition_id=partition_id,
- task_name="tutorial_task",
- )
- print(f" ✓ Got metadata: {len(batch_meta)} samples")
- print(f" Global indexes: {batch_meta.global_indexes}")
-
- # Step 3: Get actual data
- print("[Step 3] Retrieving actual data...")
- retrieved_data = tq.get_data(batch_meta)
- print(" ✓ Data retrieved successfully")
- print(f" Keys: {list(retrieved_data.keys())}")
-
- # Step 4: Verify
- print("[Step 4] Verifying data integrity...")
- assert torch.equal(retrieved_data["input_ids"], input_ids)
- assert torch.equal(retrieved_data["attention_mask"], attention_mask)
- print(" ✓ Data matches original!")
+ partition_id = "Train"
+ keys = [f"{uid}_{session_id}" for uid in [1, 1, 1, 2] for session_id in [0, 1, 2, 0]]
+ tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))]
- # Step 5: Clear
- print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)")
- tq.clear_partition(partition_id=partition_id)
- print(" ✓ Partition cleared")
+ print(f" Created {data_batch.batch_size[0]} samples, assigning keys: {keys}, tags: {tags}")
+ tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags)
+ print(f" ✓ Data put to partition: {partition_id}")
+ # Step 3: Append new data fields to existing samples
+ print("[Step 3] Putting multiple data into TransferQueue...")
+ batch_response = torch.tensor(
+ [
+ [4, 5, 6],
+ [7, 8, 9],
+ ]
+ )
+ data_batch = TensorDict(
+ {
+ "response": batch_response,
+ },
+ batch_size=batch_response.size(0),
+ )
-def demonstrate_storage_backend_options():
- """
- Show different storage backend options.
- """
- print("=" * 80)
- print("Storage Backend Options")
- print("=" * 80)
+ keys = [f"{uid}_{session_id}" for uid in [1, 2] for session_id in [1, 0]]
+ tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(keys))]
- print("TransferQueue supports multiple storage backends:")
- print("1. SimpleStorage (default)")
- print(" - In-memory storage, fast and simple")
- print(" - Leveraging ZMQ for communication, with zero-copy serialization and transfer")
- print(" - No extra dependencies, good for development and testing")
+ tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags)
- print("2. Yuanrong")
- print(" - Ascend native distributed storage solution")
- print(" - Hierarchical storage interfaces including HBM/DRAM/SSD")
+ # Step 4: Query all keys and tags
+ print("[Step 4] Query all the keys and tags from TransferQueue...")
+ all_keys, all_tags = tq.kv_list(partition_id=partition_id)
+ print(f" ✓ Got keys: {keys}")
+ print(f" Got tags: {tags}")
- print("3. MooncakeStore (on the way)")
- print(" - Support multiple transmission protocols")
- print(" - RDMA between DRAM and HBM")
+ # Step 5: Get specific fields of values
+ print("[Step 5] ...")
+ retrieved_input_ids_data = tq.kv_get(all_keys, fields="input_ids")
+ print(" ✓ Data retrieved successfully")
+ print(f" {retrieved_input_ids_data}")
- print("4. Ray RDT (on the way)")
- print(" - Leverage Ray's distributed object store to store data")
+ # Step 6: Get all fields of values
+ print("[Step 5] ...")
+ retrieved_all_data = tq.kv_get(all_keys)
+ print(" ✓ Data retrieved successfully")
+ print(f" {retrieved_all_data}")
- print("5. Custom Storage Backends")
- print(" - Implement your own storage manager by inheriting from `TransferQueueStorageManager` base class")
- print(" - For KV based storage, you only need to provide a storage client and integrate with `KVStorageManager`")
+ # Step 7: Clear
+ print("[Step 7] Clearing partition...")
+ tq.kv_clear(all_keys, partition_id=partition_id)
+ print(" ✓ Keys are cleared")
def main():
@@ -151,18 +162,23 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 1: Core Components Introduction
+ TransferQueue Tutorial 2: Key-Value Semantic API
- This script introduces the three core components of TransferQueue:
- 1. TransferQueueController - Manages all the metadata and tracks the production and consumption states
- 2. StorageBackend - Pluggable distributed storage backend that holds the actual data
- 3. TransferQueueClient - Client interface for reading/writing data (user-facing API)
+ This script demonstrates the key-value semantic API of TransferQueue:
+ 1. kv_put & kv_batch_put - Put key-value pairs and custom tags into TransferQueue
+ 2. kv_get - Get values from TransferQueue according to user-specified keys
+ 3. kv_list - Get all the keys and tags from TransferQueue
+ 4. kv_clear - Delete the value and tags of given keys
- Key Concepts:
- - Data is organized into logical partitions (e.g., "train", "val")
- - Each sample has multiple fields, with a global index for identification
- - Controller maintains production/consumption state tracking
- - Client is the main interface users interact with
+ Supported Features:
+ 1. Fine-grained access - users can put/get partial fields inside a data sample (key)
+ 2. Partition management - each logical partition manages its own key-value mapping
+
+ Unsupported Features:
+ 1. Production & consumption management (user have to manually management through tags)
+ 2. User-defined sampler in TransferQueue controller (user need to do sampling by themselves through tags)
+ 3. Fully streamed data pipeline (TQ controller cannot determine which sample to dispatch to the consumers)
+
"""
)
)
@@ -172,21 +188,14 @@ def main():
print("Setting up TransferQueue...")
tq.init()
- print("Demonstrating the user workflow...")
- demonstrate_data_workflow()
-
- demonstrate_storage_backend_options()
+ print("Demonstrating the key-value semantic API...")
+ demonstrate_kv_api()
print("=" * 80)
print("Tutorial Complete!")
print("=" * 80)
print("Key Takeaways:")
- print("1. TransferQueue has 3 core components:")
- print(" - Controller: Manages data production/consumption state")
- print(" - StorageBackend: Persists actual data")
- print(" - Client: User-facing API (what you use)")
- print("2. Client is the main interface users interact with")
- print("3. You can swap out different storage backends easily")
+ print("1. ")
# Cleanup
tq.close()
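For quick reference, a condensed recap of the KV-semantic calls exercised by the tutorial above; argument names follow this diff, so treat it as a sketch rather than a formal API listing:

import torch
from tensordict import TensorDict

import transfer_queue as tq

tq.init()
sample = TensorDict({"input_ids": torch.tensor([[1, 2, 3]])}, batch_size=1)

tq.kv_put(key="0_0", partition_id="Train", fields=sample, tag={"status": "running"})
keys, tags = tq.kv_list(partition_id="Train")          # every key in the partition plus its tag
input_ids_only = tq.kv_get(keys, fields="input_ids")   # fine-grained: fetch a single field
everything = tq.kv_get(keys)                            # fetch all fields
tq.kv_clear(keys, partition_id="Train")
tq.close()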
diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py
index 3a4a57d..bbda5b9 100644
--- a/tutorial/03_metadata_concepts.py
+++ b/tutorial/03_metadata_concepts.py
@@ -387,7 +387,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 2: Metadata System
+ TransferQueue Tutorial 3: Metadata System
This script introduces the metadata system in TransferQueue, which tracks
the structure and state of data:
diff --git a/tutorial/04_understanding_controller.py b/tutorial/04_understanding_controller.py
index 4ca426d..fbeae21 100644
--- a/tutorial/04_understanding_controller.py
+++ b/tutorial/04_understanding_controller.py
@@ -235,7 +235,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 3: Understanding TransferQueueController
+ TransferQueue Tutorial 4: Understanding TransferQueueController
This script demonstrates TransferQueueController's key features:
diff --git a/tutorial/05_custom_sampler.py b/tutorial/05_custom_sampler.py
index c35b4e0..bdca835 100644
--- a/tutorial/05_custom_sampler.py
+++ b/tutorial/05_custom_sampler.py
@@ -341,7 +341,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 4: Custom Sampler Development
+ TransferQueue Tutorial 5: Custom Sampler Development
This script demonstrates how to develop custom samplers for TransferQueue.
Samplers control HOW data is consumed from the queue.
diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py
index 4d92aa2..d95a3c8 100644
--- a/tutorial/06_streaming_dataloader.py
+++ b/tutorial/06_streaming_dataloader.py
@@ -289,7 +289,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 5: StreamingDataLoader for Distributed Training
+ TransferQueue Tutorial 6: StreamingDataLoader for Distributed Training
This tutorial demonstrates the StreamingDataLoader interface for distributed
training scenarios. It showcases how to use StreamingDataset and StreamingDataLoader
From c53e5fa9161e78277ae55157b169fb34abd9e3c5 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 16:42:39 +0800
Subject: [PATCH 04/34] provide basic runnable tutorial
Signed-off-by: 0oshowero0
---
transfer_queue/controller.py | 10 ++-
tutorial/02_kv_interface.py | 164 ++++++++++++++++++++---------------
2 files changed, 105 insertions(+), 69 deletions(-)
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 6175863..6e46280 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -1504,6 +1504,9 @@ def kv_retrieve_keys(
partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i]
partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]]
+ with partition.data_status_lock:
+ partition.ensure_samples_capacity(max(batch_global_indexes) + 1)
+
verified_global_indexes = [idx for idx in global_indexes if idx is not None]
assert len(verified_global_indexes) == len(keys)
@@ -1809,10 +1812,14 @@ def _process_request(self):
elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"):
+ params = request_msg.body
keys = params["keys"]
partition_id = params["partition_id"]
create = params["create"]
+ print("+++++++++++TQDEBUG+++++++++++++++++")
+ print(params)
+
metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create)
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE,
@@ -1823,6 +1830,7 @@ def _process_request(self):
elif request_msg.request_type == ZMQRequestType.KV_LIST:
with perf_monitor.measure(op_type="KV_LIST"):
+ params = request_msg.body
partition_id = params["partition_id"]
partition = self._get_partition(partition_id)
if not partition:
@@ -1839,7 +1847,7 @@ def _process_request(self):
request_type=ZMQRequestType.KV_LIST_RESPONSE,
sender_id=self.controller_id,
receiver_id=request_msg.sender_id,
- body={"keys": keys, "custom_meta": custom_meta, message: message},
+ body={"keys": keys, "custom_meta": custom_meta, "message": message},
)
self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
index b54b03a..921a02b 100644
--- a/tutorial/02_kv_interface.py
+++ b/tutorial/02_kv_interface.py
@@ -55,14 +55,15 @@
def demonstrate_kv_api():
"""
- Demonstrate xxxxx
+ Demonstrate the Key-Value (KV) semantic API:
+ kv_put & kv_batch_put -> kv_get -> kv_list -> kv_clear
"""
print("=" * 80)
- print("Data xxxx")
+ print("Key-Value Semantic API Demo: kv_put → kv_get → kv_list → kv_clear")
print("=" * 80)
- # Step 1: Put single data sample
- print("[Step 1] Putting single data into TransferQueue...")
+ # Step 1: Put a single key-value pair with kv_put
+ print("[Step 1] Putting a single sample with kv_put...")
input_ids = torch.tensor([[1, 2, 3]])
attention_mask = torch.ones(input_ids.size())
@@ -75,17 +76,18 @@ def demonstrate_kv_api():
batch_size=input_ids.size(0),
)
- print(f" Created single sample with multiple fields:{single_sample.keys()}.")
- print(" Leveraging TransferQueue, we can provide fine-grained access into each single field of a sample (key).")
-
partition_id = "Train"
- key = [f"{uid}_{session_id}" for uid in [0] for session_id in [0]]
- tag = [{"global_steps": 0, "status": "running", "model_version": 0}]
- tq.kv_put(key, partition_id=partition_id, fields=single_sample, tag=tag)
- print(f" ✓ Data put to partition: {partition_id}")
+ key = "0_0" # User-defined key: "{uid}_{session_id}"
+ tag = {"global_steps": 0, "status": "running", "model_version": 0}
+
+ print(f" Created single sample with key: {key}, fields: {list(single_sample.keys())}, and tag: {tag}")
+ print(" Note: kv_put accepts a user-defined string key instead of auto-generated index")
+
+ tq.kv_put(key=key, partition_id=partition_id, fields=single_sample, tag=tag)
+ print(f" ✓ kv_put: key='{key}', tag={tag}")
- # Step 2: Put data batch
- print("[Step 2] Putting multiple data into TransferQueue...")
+ # Step 2: Put multiple key-value pairs with kv_batch_put
+ print("\n[Step 2] Putting batch data with kv_batch_put...")
batch_input_ids = torch.tensor(
[
@@ -105,56 +107,66 @@ def demonstrate_kv_api():
batch_size=batch_input_ids.size(0),
)
- partition_id = "Train"
- keys = [f"{uid}_{session_id}" for uid in [1, 1, 1, 2] for session_id in [0, 1, 2, 0]]
+ keys = ["1_0", "1_1", "1_2", "2_0"] # 4 keys for 4 samples
tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))]
- print(f" Created {data_batch.batch_size[0]} samples, assigning keys: {keys}, tags: {tags}")
- tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags)
- print(f" ✓ Data put to partition: {partition_id}")
+ print(f" Created batch with {data_batch.batch_size[0]} samples")
+ print(f" Batch keys: {keys}")
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data_batch, tags=tags)
+ print(f" ✓ kv_batch_put: {len(keys)} samples written to partition '{partition_id}'")
+
+ # Step 3: Append additional fields to existing samples
+ print("\n[Step 3] Appending new fields to existing samples...")
- # Step 3: Append new data fields to existing samples
- print("[Step 3] Putting multiple data into TransferQueue...")
batch_response = torch.tensor(
[
[4, 5, 6],
[7, 8, 9],
]
)
- data_batch = TensorDict(
+ response_batch = TensorDict(
{
"response": batch_response,
},
batch_size=batch_response.size(0),
)
- keys = [f"{uid}_{session_id}" for uid in [1, 2] for session_id in [1, 0]]
- tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(keys))]
+ append_keys = ["1_1", "2_0"] # Appending to existing samples
+ append_tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(append_keys))]
+ print(f" Adding 'response' field to keys: {append_keys}")
+ tq.kv_batch_put(keys=append_keys, partition_id=partition_id, fields=response_batch, tags=append_tags)
+ print(" ✓ Field appended successfully (sample now has input_ids, attention_mask, and response)")
- tq.kv_batch_put(keys, partition_id=partition_id, fields=data_batch, tags=tags)
-
- # Step 4: Query all keys and tags
- print("[Step 4] Query all the keys and tags from TransferQueue...")
+ # Step 4: List all keys and tags in a partition
+ print("\n[Step 4] Listing all keys and tags in partition...")
all_keys, all_tags = tq.kv_list(partition_id=partition_id)
- print(f" ✓ Got keys: {keys}")
- print(f" Got tags: {tags}")
-
- # Step 5: Get specific fields of values
- print("[Step 5] ...")
- retrieved_input_ids_data = tq.kv_get(all_keys, fields="input_ids")
- print(" ✓ Data retrieved successfully")
- print(f" {retrieved_input_ids_data}")
-
- # Step 6: Get all fields of values
- print("[Step 5] ...")
- retrieved_all_data = tq.kv_get(all_keys)
- print(" ✓ Data retrieved successfully")
- print(f" {retrieved_all_data}")
-
- # Step 5: Clear
- print("[Step 5] Clearing partition...")
- tq.kv_clear(all_keys, partition_id=partition_id)
- print(" ✓ Keys are cleared")
+ print(f" Found {len(all_keys)} keys in partition '{partition_id}':")
+ for k, t in zip(all_keys, all_tags, strict=False):
+ print(f" - key='{k}', tag={t}")
+
+ # Step 5: Retrieve specific fields using kv_get
+ print("\n[Step 5] Retrieving specific fields with kv_get...")
+ retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids")
+ print(f" Retrieved 'input_ids' field for all {len(all_keys)} samples:")
+ print(f" Shape: {retrieved_input_ids.batch_size}")
+ print(f" Values: {retrieved_input_ids['input_ids']}")
+
+    # TODO: this will fail because only a single sample has the extra field...
+    # need to add an additional check during kv_get to make sure the other samples are handled correctly
+ # # Step 6: Retrieve all fields using kv_get
+ # print("\n[Step 6] Retrieving all fields with kv_get...")
+ # retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id)
+ # print(f" Retrieved all fields for {len(all_keys)} samples:")
+ # print(f" Fields: {list(retrieved_all.keys())}")
+
+ # Step 7: Clear specific keys
+ print("\n[Step 7] Clearing keys from partition...")
+ keys_to_clear = all_keys[:2] # Clear first 2 keys
+ tq.kv_clear(keys=keys_to_clear, partition_id=partition_id)
+ print(f" ✓ Cleared keys: {keys_to_clear}")
+
+ remaining_keys, _ = tq.kv_list(partition_id=partition_id)
+ print(f" Remaining keys in partition: {remaining_keys}")
def main():
@@ -162,23 +174,35 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 2: Key-Value Semantic API
-
- This script demonstrate the key-value semantic API of TransferQueue:
- 1. kv_put & kv_batch_put - Put key-value pairs and custom tags into TransferQueue
- 2. kv_get - Get values from TransferQueue according to user-specified keys
- 3. kv_list - Get all the keys and tags from TransferQueue
- 4. kv_clear - Delete the value and tags of given keys
-
- Supported Features:
- 1. Fine-grained access - user can put/get partial fields inside a data sample (key)
- 2. Partition management - each logical partition manages their own key-value mapping
-
- Unsupported Features:
- 1. Production & consumption management (user have to manually management through tags)
- 2. User-defined sampler in TransferQueue controller (user need to do sampling by themselves through tags)
- 3. Fully streamed data pipeline (TQ controller cannot determine which sample to dispatch to the consumers)
-
+ TransferQueue Tutorial 2: Key-Value (KV) Semantic API
+
+ This tutorial demonstrates the KV semantic API, which provides a simpler
+ interface for data storage and retrieval using user-defined string keys
+ instead of auto-generated numeric indexes.
+
+ Key Methods:
+ 1. kv_put - Put a single key-value pair with optional metadata tag
+ 2. kv_batch_put - Put multiple key-value pairs efficiently in batch
+ 3. kv_get - Retrieve data by key(s), optionally specifying fields
+ 4. kv_list - List all keys and their metadata tags in a partition
+ 5. kv_clear - Remove key-value pairs from storage
+
+ Key Features:
+ ✓ User-defined keys - Use meaningful string keys instead of numeric indexes
+ ✓ Fine-grained access - Get/put individual fields within a sample
+ ✓ Partition management - Each partition maintains its own key-value mapping
+ ✓ Metadata tags - Attach custom metadata (status, scores, etc.) to samples
+
+ Use Cases:
+ - Storing per-model-checkpoint states
+ - Managing evaluation results by sample ID
+ - Caching intermediate computation results
+ - Fine-grained data access without full BatchMeta management
+
+ Limitations (vs Full API):
+ - No built-in production/consumption tracking (manage via tags)
+ - No Sampler-based sampling (implement sampling logic externally)
+ - Controller doesn't control streaming (manual key management required)
"""
)
)
@@ -188,19 +212,23 @@ def main():
print("Setting up TransferQueue...")
tq.init()
- print("Demonstrating the key-value semantic API...")
+ print("\nDemonstrating the KV semantic API...")
demonstrate_kv_api()
- print("=" * 80)
+ print("\n" + "=" * 80)
print("Tutorial Complete!")
print("=" * 80)
- print("Key Takeaways:")
- print("1. ")
+ print("\nKey Takeaways:")
+ print(" 1. KV API simplifies data access with user-defined string keys")
+ print(" 2. kv_batch_put is more efficient for bulk operations")
+ print(" 3. Use 'fields' parameter to get/put specific fields only")
+ print(" 4. Tags enable custom metadata for production status, scores, etc.")
+ print(" 5. Use kv_list to inspect partition contents")
# Cleanup
tq.close()
ray.shutdown()
- print("\n✓ Cleanup complete")
+ print("\nCleanup complete")
except Exception as e:
print(f"Error during tutorial: {e}")
From 85567f3624c65c89eedbfb3ed364600c0931aa20 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 16:49:33 +0800
Subject: [PATCH 05/34] fix other tutorials
Signed-off-by: 0oshowero0
---
tutorial/03_metadata_concepts.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py
index bbda5b9..126cc5b 100644
--- a/tutorial/03_metadata_concepts.py
+++ b/tutorial/03_metadata_concepts.py
@@ -209,6 +209,9 @@ def demonstrate_batch_meta():
[
{"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"},
{"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@2", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@3", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@4", "session_id": "session@0", "model_version": "epoch@0"},
]
)
print(f"✓ Custom meta: {batch.get_all_custom_meta()}")
From 063b6b1e36a0e0325582a820b8b3d7c7d15e3c60 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 17:39:55 +0800
Subject: [PATCH 06/34] update tutorial
Signed-off-by: 0oshowero0
---
tutorial/02_kv_interface.py | 110 ++++++++++++++++++++----------------
1 file changed, 62 insertions(+), 48 deletions(-)
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
index 921a02b..fc2e96b 100644
--- a/tutorial/02_kv_interface.py
+++ b/tutorial/02_kv_interface.py
@@ -56,15 +56,16 @@
def demonstrate_kv_api():
"""
Demonstrate the Key-Value (KV) semantic API:
- kv_put & kv_batch_put -> kv_get -> kv_list -> kv_clear
+ kv_put & kv_batch_put -> kv_list -> kv_get -> kv_clear
"""
print("=" * 80)
- print("Key-Value Semantic API Demo: kv_put → kv_get → kv_list → kv_clear")
+ print("Key-Value Semantic API Demo: kv_put/kv_batch_put → kv_list → kv_get → kv_clear")
print("=" * 80)
# Step 1: Put a single key-value pair with kv_put
print("[Step 1] Putting a single sample with kv_put...")
+ # Define the data content (The "Value")
input_ids = torch.tensor([[1, 2, 3]])
attention_mask = torch.ones(input_ids.size())
@@ -77,14 +78,16 @@ def demonstrate_kv_api():
)
partition_id = "Train"
+ # Use a meaningful string key instead of an auto-increment integer
key = "0_0" # User-defined key: "{uid}_{session_id}"
tag = {"global_steps": 0, "status": "running", "model_version": 0}
- print(f" Created single sample with key: {key}, fields: {list(single_sample.keys())}, and tag: {tag}")
- print(" Note: kv_put accepts a user-defined string key instead of auto-generated index")
+ print(f" Inserting Key: {key}")
+ print(f" Fields (Columns): {list(single_sample.keys())}")
+ print(f" Tag (Metadata): {tag}")
tq.kv_put(key=key, partition_id=partition_id, fields=single_sample, tag=tag)
- print(f" ✓ kv_put: key='{key}', tag={tag}")
+ print(" ✓ kv_put success.")
# Step 2: Put multiple key-value pairs with kv_batch_put
print("\n[Step 2] Putting batch data with kv_batch_put...")
@@ -110,13 +113,14 @@ def demonstrate_kv_api():
keys = ["1_0", "1_1", "1_2", "2_0"] # 4 keys for 4 samples
tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))]
- print(f" Created batch with {data_batch.batch_size[0]} samples")
- print(f" Batch keys: {keys}")
+ print(f" Inserting batch of {len(keys)} samples.")
+ print(f" Fields (Columns): {list(data_batch.keys())}")
+ print(f" Tag (Metadata): {tags}")
tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data_batch, tags=tags)
- print(f" ✓ kv_batch_put: {len(keys)} samples written to partition '{partition_id}'")
+ print(" ✓ kv_batch_put success.")
# Step 3: Append additional fields to existing samples
- print("\n[Step 3] Appending new fields to existing samples...")
+ print("\n[Step 3] Appending new fields (Columns) to existing samples...")
batch_response = torch.tensor(
[
@@ -131,25 +135,38 @@ def demonstrate_kv_api():
batch_size=batch_response.size(0),
)
+    # We only update a subset of keys
append_keys = ["1_1", "2_0"] # Appending to existing samples
append_tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(append_keys))]
- print(f" Adding 'response' field to keys: {append_keys}")
+ print(f" Target Keys: {append_keys}")
+ print(" New Field to Add is: 'response'")
+ print(f" The updated tags are: {append_tags}")
tq.kv_batch_put(keys=append_keys, partition_id=partition_id, fields=response_batch, tags=append_tags)
- print(" ✓ Field appended successfully (sample now has input_ids, attention_mask, and response)")
+ print(" ✓ Update success: Samples '1_1' and '2_0' now contain {input_ids, attention_mask, response}.")
+
+ # Step 4: Only update tags through kv_put
+ print("\n[Step 4] Update existing tags without providing value...")
+ key = "0_0"
+ tag = {"global_steps": 0, "status": "finish", "model_version": 0}
+ print(f" Target Key: {key}")
+ print(f" The updated tag is: {tag}")
+ tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=tag)
+ print(f" ✓ Update success: Samples '0_0' now has tag as {tag}.")
+
+ # Step 5: List all keys and tags in a partition
+ print("\n[Step 5] Listing all keys and tags in partition...")
- # Step 4: List all keys and tags in a partition
- print("\n[Step 4] Listing all keys and tags in partition...")
all_keys, all_tags = tq.kv_list(partition_id=partition_id)
print(f" Found {len(all_keys)} keys in partition '{partition_id}':")
for k, t in zip(all_keys, all_tags, strict=False):
- print(f" - key='{k}', tag={t}")
+ print(f" - key='{k}' | tag={t}")
+
+ # Step 6: Retrieve specific fields using kv_get
+ print("\n[Step 6] Retrieving specific fields (Column) with kv_get...")
+ print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').")
- # Step 5: Retrieve specific fields using kv_get
- print("\n[Step 5] Retrieving specific fields with kv_get...")
retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids")
- print(f" Retrieved 'input_ids' field for all {len(all_keys)} samples:")
- print(f" Shape: {retrieved_input_ids.batch_size}")
- print(f" Values: {retrieved_input_ids['input_ids']}")
+ print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.")
# TODO: this will fail because only single sample has an extra fields...
# need to add additional check during kv_get to make sure other samples are correctly tackled
@@ -159,9 +176,9 @@ def demonstrate_kv_api():
# print(f" Retrieved all fields for {len(all_keys)} samples:")
# print(f" Fields: {list(retrieved_all.keys())}")
- # Step 7: Clear specific keys
- print("\n[Step 7] Clearing keys from partition...")
- keys_to_clear = all_keys[:2] # Clear first 2 keys
+ # Step 8: Clear specific keys
+ print("\n[Step 8] Clearing keys from partition...")
+ keys_to_clear = all_keys[:2] # Delete the first 2 keys
tq.kv_clear(keys=keys_to_clear, partition_id=partition_id)
print(f" ✓ Cleared keys: {keys_to_clear}")
@@ -176,33 +193,31 @@ def main():
"""
TransferQueue Tutorial 2: Key-Value (KV) Semantic API
- This tutorial demonstrates the KV semantic API, which provides a simpler
- interface for data storage and retrieval using user-defined string keys
- instead of auto-generated numeric indexes.
+ This tutorial demonstrates the KV semantic API, which provides a simple
+ interface for data storage and retrieval using user-defined string keys.
Key Methods:
- 1. kv_put - Put a single key-value pair with optional metadata tag
- 2. kv_batch_put - Put multiple key-value pairs efficiently in batch
- 3. kv_get - Retrieve data by key(s), optionally specifying fields
- 4. kv_list - List all keys and their metadata tags in a partition
- 5. kv_clear - Remove key-value pairs from storage
+ 1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag
+ 2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch
+ 3. (async_)kv_get - Retrieve samples (by keys), supporting column selection (by fields)
+ 4. (async_)kv_list - List keys and tags (metadata) in a partition
+ 5. (async_)kv_clear - Remove key-value pairs from storage
Key Features:
- ✓ User-defined keys - Use meaningful string keys instead of numeric indexes
- ✓ Fine-grained access - Get/put individual fields within a sample
- ✓ Partition management - Each partition maintains its own key-value mapping
- ✓ Metadata tags - Attach custom metadata (status, scores, etc.) to samples
+ ✓ Redis-style Semantics - Familiar KV interface (Put/Get/List) for zero learning curve
+                 ✓ Fine-grained Access  - Update or retrieve specific fields (columns) within a key (row) without a full-row operation.
+ ✓ Partition Isolation - Logical separation of storage namespaces
+ ✓ Metadata Tags - Lightweight metadata for status tracking
+ ✓ Pluggable Backends - Supports multiple backends
Use Cases:
- - Storing per-model-checkpoint states
- - Managing evaluation results by sample ID
- - Caching intermediate computation results
- - Fine-grained data access without full BatchMeta management
-
- Limitations (vs Full API):
- - No built-in production/consumption tracking (manage via tags)
- - No Sampler-based sampling (implement sampling logic externally)
- - Controller doesn't control streaming (manual key management required)
+                 - Fine-grained data access where extreme streaming performance is not essential
+                 - Integration with an external ReplayBuffer/single-controller that manages sample dispatching
+
+ Limitations (vs low-level native APIs):
+ - No built-in production/consumption tracking: Users have to manually check status through tags.
+                 - No built-in Sampler support: Data dispatch must be implemented externally (e.g., by a ReplayBuffer or single-controller).
+                 - No full streaming: Consumers must wait for the single-controller to dispatch `keys`.
"""
)
)
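Because the KV API leaves production/consumption tracking to user-managed tags (see the limitations above), a consumer typically filters keys by tag before fetching. A minimal sketch, illustrative only; the "status" tag values and partition name follow this tutorial's conventions:

    import transfer_queue as tq

    keys, tags = tq.kv_list(partition_id="Train")
    # Select only samples whose tag marks them as finished.
    finished = [k for k, t in zip(keys, tags) if t.get("status") == "finish"]
    if finished:
        batch = tq.kv_get(keys=finished, partition_id="Train", fields="input_ids")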
@@ -219,11 +234,10 @@ def main():
print("Tutorial Complete!")
print("=" * 80)
print("\nKey Takeaways:")
- print(" 1. KV API simplifies data access with user-defined string keys")
- print(" 2. kv_batch_put is more efficient for bulk operations")
- print(" 3. Use 'fields' parameter to get/put specific fields only")
- print(" 4. Tags enable custom metadata for production status, scores, etc.")
- print(" 5. Use kv_list to inspect partition contents")
+ print(" 1. KV API simplifies data access with Redis-style semantics")
+ print(" 2. Use 'fields' parameter to get/put specific fields only")
+ print(" 3. Tags enable custom metadata for production status, scores, etc.")
+ print(" 4. Use kv_list to inspect partition contents")
# Cleanup
tq.close()
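The docstring above also lists async_ variants of each call. Assuming these variants mirror the synchronous signatures (as async_kv_get does in transfer_queue/interface.py) and are exported alongside the sync ones, an async consumer might look like this sketch:

    import asyncio
    import transfer_queue as tq

    async def fetch_inputs(keys, partition_id="Train"):
        # Awaitable counterpart of kv_get; fetches only the input_ids column.
        return await tq.async_kv_get(keys=keys, partition_id=partition_id, fields="input_ids")

    # asyncio.run(fetch_inputs(["0_0", "1_0"]))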
From 1a86e20b050385da769b3abae295faab05b7c227 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 18:23:59 +0800
Subject: [PATCH 07/34] fix
Signed-off-by: 0oshowero0
---
transfer_queue/controller.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 6e46280..1f430aa 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -829,8 +829,10 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr
self.field_shapes.pop(idx, None)
self.field_custom_backend_meta.pop(idx, None)
self.custom_meta.pop(idx, None)
- self.keys_mapping.pop(self.revert_keys_mapping[idx], None)
- self.revert_keys_mapping.pop(idx, None)
+
+ if idx in self.revert_keys_mapping:
+ self.keys_mapping.pop(self.revert_keys_mapping[idx], None)
+ self.revert_keys_mapping.pop(idx, None)
except Exception as e:
logger.error(
@@ -1817,9 +1819,6 @@ def _process_request(self):
partition_id = params["partition_id"]
create = params["create"]
- print("+++++++++++TQDEBUG+++++++++++++++++")
- print(params)
-
metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create)
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE,
From e1f6ecf7ec930c82291c7af4a0a5a8f8ef4a459b Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 19:06:26 +0800
Subject: [PATCH 08/34] fix cornercase
Signed-off-by: 0oshowero0
---
transfer_queue/controller.py | 10 +++++++++-
transfer_queue/interface.py | 10 ++++++++--
tutorial/02_kv_interface.py | 27 ++++++++++++++-------------
3 files changed, 31 insertions(+), 16 deletions(-)
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 1f430aa..1f01482 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -1512,7 +1512,15 @@ def kv_retrieve_keys(
verified_global_indexes = [idx for idx in global_indexes if idx is not None]
assert len(verified_global_indexes) == len(keys)
- data_fields = list(partition.field_name_mapping.keys())
+        # only fetch fields that every requested sample has produced
+ col_mask = partition.production_status[verified_global_indexes, :].sum(dim=0).reshape(-1) == len(
+ verified_global_indexes
+ )
+ data_fields = []
+ for fname, col_idx in partition.field_name_mapping.items():
+ if col_mask[col_idx]:
+ data_fields.append(fname)
+
metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")
return metadata
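The column mask introduced above keeps a field only when every requested sample has produced it. A standalone toy illustration of the same computation (not part of the controller):

    import torch

    # Rows are samples, columns are fields; 1 means the field has been produced.
    production_status = torch.tensor([
        [1, 1, 0],  # sample 0: input_ids, attention_mask
        [1, 1, 1],  # sample 1: input_ids, attention_mask, response
    ])
    requested_rows = [0, 1]
    col_mask = production_status[requested_rows, :].sum(dim=0) == len(requested_rows)
    field_name_mapping = {"input_ids": 0, "attention_mask": 1, "response": 2}
    common_fields = [name for name, col in field_name_mapping.items() if col_mask[col]]
    # common_fields == ["input_ids", "attention_mask"]; "response" is dropped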
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 2585c35..ed325b8 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -813,6 +813,7 @@ def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str]
Raises:
RuntimeError: If keys or partition are not found
+        RuntimeError: If any requested field is missing for one of the requested keys (samples)
Example:
>>> import transfer_queue as tq
@@ -838,8 +839,10 @@ def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str]
fields = [fields]
batch_meta = batch_meta.select_fields(fields)
- data = tq_client.get_data(batch_meta)
+ if not batch_meta.is_ready:
+ raise RuntimeError("Some fields are not ready in all the requested keys!")
+ data = tq_client.get_data(batch_meta)
return data
@@ -1052,6 +1055,7 @@ async def async_kv_get(
Raises:
RuntimeError: If keys or partition are not found
+        RuntimeError: If any requested field is missing for one of the requested keys (samples)
Example:
>>> import transfer_queue as tq
@@ -1077,8 +1081,10 @@ async def async_kv_get(
fields = [fields]
batch_meta = batch_meta.select_fields(fields)
- data = await tq_client.async_get_data(batch_meta)
+ if not batch_meta.is_ready:
+ raise RuntimeError("Some fields are not ready in all the requested keys!")
+ data = await tq_client.async_get_data(batch_meta)
return data
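With this check in place, kv_get fails fast instead of returning a partially populated batch. A caller can recover by narrowing the field selection; the sketch below is illustrative only, and "input_ids" is simply assumed to be a field present on every sample:

    import transfer_queue as tq

    def get_fields_or_fallback(keys, partition_id, fields):
        try:
            return tq.kv_get(keys=keys, partition_id=partition_id, fields=fields)
        except RuntimeError:
            # Some requested fields are missing on at least one key; retry with
            # a field assumed to be present on every sample.
            return tq.kv_get(keys=keys, partition_id=partition_id, fields="input_ids")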
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
index fc2e96b..2b309ca 100644
--- a/tutorial/02_kv_interface.py
+++ b/tutorial/02_kv_interface.py
@@ -146,12 +146,12 @@ def demonstrate_kv_api():
# Step 4: Only update tags through kv_put
print("\n[Step 4] Update existing tags without providing value...")
- key = "0_0"
- tag = {"global_steps": 0, "status": "finish", "model_version": 0}
- print(f" Target Key: {key}")
- print(f" The updated tag is: {tag}")
- tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=tag)
- print(f" ✓ Update success: Samples '0_0' now has tag as {tag}.")
+ key_for_update_tags = "0_0"
+ tag_update = {"global_steps": 0, "status": "finish", "model_version": 0}
+ print(f" Target Key: {key_for_update_tags}")
+ print(f" The updated tag is: {tag_update}")
+ tq.kv_put(key=key_for_update_tags, partition_id=partition_id, fields=None, tag=tag_update)
+ print(f" ✓ Update success: Samples '0_0' now has tag as {tag_update}.")
# Step 5: List all keys and tags in a partition
print("\n[Step 5] Listing all keys and tags in partition...")
@@ -168,13 +168,14 @@ def demonstrate_kv_api():
retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids")
print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.")
-    # TODO: this will fail because only a single sample has the extra field...
-    # need to add an additional check during kv_get to make sure the other samples are handled correctly
- # # Step 6: Retrieve all fields using kv_get
- # print("\n[Step 6] Retrieving all fields with kv_get...")
- # retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id)
- # print(f" Retrieved all fields for {len(all_keys)} samples:")
- # print(f" Fields: {list(retrieved_all.keys())}")
+    # Step 7: Retrieve all fields using kv_get
+ print("\n[Step 7] Retrieving all fields with kv_get...")
+ retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id)
+ print(f" Retrieved all fields for {all_keys}:")
+ print(f" Fields: {list(retrieved_all.keys())}")
+ print(
+ f" Note: We cannot retrieve fields {list(response_batch.keys())}, since they only available in {append_keys}"
+ )
# Step 8: Clear specific keys
print("\n[Step 8] Clearing keys from partition...")
From f1c31ee2d5a6e725835760cc05522454a6d2e0ab Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 20:55:36 +0800
Subject: [PATCH 09/34] minor fix
Signed-off-by: 0oshowero0
---
tutorial/03_metadata_concepts.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py
index 126cc5b..91752d8 100644
--- a/tutorial/03_metadata_concepts.py
+++ b/tutorial/03_metadata_concepts.py
@@ -305,7 +305,7 @@ def demonstrate_real_workflow():
# Initialize Ray
if not ray.is_initialized():
- ray.init()
+ ray.init(namespace="TransferQueueTutorial")
# Initialize TransferQueue
tq.init()
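Pinning a Ray namespace matters because named actors are only discoverable within the namespace they were created in; the e2e tests below rely on this when calling ray.get_actor("TransferQueueController"). A minimal sketch of the pattern:

    import ray

    ray.init(namespace="TransferQueueTutorial")
    # Any driver or worker that joined the same namespace can look up named actors, e.g.:
    # controller = ray.get_actor("TransferQueueController")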
From 3ecf4d394e5a6cfd8193685f4dcd85d819e147c5 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 21:20:16 +0800
Subject: [PATCH 10/34] add ut
Signed-off-by: 0oshowero0
---
tests/e2e/test_kv_interface_e2e.py | 580 +++++++++++++++++++++++++++++
tests/test_kv_interface.py | 539 ---------------------------
2 files changed, 580 insertions(+), 539 deletions(-)
create mode 100644 tests/e2e/test_kv_interface_e2e.py
delete mode 100644 tests/test_kv_interface.py
diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py
new file mode 100644
index 0000000..458befa
--- /dev/null
+++ b/tests/e2e/test_kv_interface_e2e.py
@@ -0,0 +1,580 @@
+# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2025 The TransferQueue Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""End-to-end tests for KV interface in transfer_queue.interface.
+
+This test module validates the KV interface functionality by:
+1. Using external interfaces (kv_put, kv_batch_put, kv_get, kv_list, kv_clear) for read/write
+2. Verifying correctness by calling TransferQueueController's internal methods directly
+"""
+
+import os
+import sys
+from pathlib import Path
+
+import pytest
+import ray
+import torch
+from tensordict import TensorDict
+
+# Add parent directory to path
+parent_dir = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(parent_dir))
+
+import transfer_queue as tq # noqa: E402
+
+# Configure Ray for tests
+os.environ["RAY_DEDUP_LOGS"] = "0"
+
+
+@pytest.fixture(scope="module")
+def ray_init():
+ """Initialize Ray for the test module."""
+ if not ray.is_initialized():
+ ray.init(namespace="TestKVInterfaceE2E")
+ yield
+ if ray.is_initialized():
+ ray.shutdown()
+
+
+@pytest.fixture(scope="module")
+def tq_system(ray_init):
+ """Initialize TransferQueue system for the test module."""
+ tq.init()
+ yield
+ tq.close()
+
+
+@pytest.fixture
+def controller(tq_system):
+ """Get the TransferQueueController actor for direct verification."""
+ controller = ray.get_actor("TransferQueueController")
+ yield controller
+
+
+@pytest.fixture(autouse=True)
+def cleanup_partition(controller):
+ """Cleanup partition after each test."""
+ yield
+ try:
+ ray.get(controller.clear_partition.remote("test_partition"))
+ except Exception:
+ pass
+
+
+def get_controller_partition(controller, partition_id: str):
+ """Get partition snapshot from controller for verification."""
+ return ray.get(controller.get_partition_snapshot.remote(partition_id))
+
+
+def assert_tensor_equal(tensor_a, tensor_b, msg=""):
+ """Assert two tensors are equal."""
+ assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}"
+
+
+def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""):
+ """Assert two tensors are close."""
+ assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close"
+
+
+class TestKVPutE2E:
+ """End-to-end tests for kv_put functionality."""
+
+ def test_kv_put_single_sample_with_fields_and_tag(self, controller):
+ """Test putting a single sample with fields and tag."""
+ partition_id = "test_partition"
+ key = "sample_0"
+ # Use 1D tensors - kv_put with dict will auto-unsqueeze to add batch dimension
+ input_ids = torch.tensor([1, 2, 3])
+ attention_mask = torch.ones(3)
+ tag = {"global_steps": 0, "status": "running"}
+
+ # Put data using interface
+ tq.kv_put(
+ key=key,
+ partition_id=partition_id,
+ fields={"input_ids": input_ids, "attention_mask": attention_mask},
+ tag=tag,
+ )
+
+ # Verify via controller internal state
+ partition = get_controller_partition(controller, partition_id)
+ assert partition is not None, "Partition should exist"
+
+ # Check key->global_index mapping
+ assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping"
+ global_idx = partition.keys_mapping[key]
+ assert global_idx in partition.global_indexes, f"Global index {global_idx} should be in global_indexes"
+
+ # Check custom_meta (tag)
+ assert global_idx in partition.custom_meta, f"Custom meta should exist for global index {global_idx}"
+ assert partition.custom_meta[global_idx]["global_steps"] == 0
+ assert partition.custom_meta[global_idx]["status"] == "running"
+
+ # Check production status - fields should be marked as produced
+ assert "input_ids" in partition.field_name_mapping, "input_ids field should be registered"
+ assert "attention_mask" in partition.field_name_mapping, "attention_mask field should be registered"
+ input_ids_col_idx = partition.field_name_mapping["input_ids"]
+ assert partition.production_status[global_idx, input_ids_col_idx] == 1, "input_ids should be marked as produced"
+
+ # Retrieve and verify data via kv_get - tensors will have batch dimension
+ retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ assert "input_ids" in retrieved.keys()
+ assert "attention_mask" in retrieved.keys()
+ # After unsqueeze, tensors become 2D [batch_size=1, original_size]
+ expected_input_ids = input_ids.unsqueeze(0)
+ expected_attention_mask = attention_mask.unsqueeze(0)
+ assert_tensor_equal(retrieved["input_ids"], expected_input_ids)
+ assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask)
+
+ def test_kv_put_update_tag_only(self, controller):
+ """Test updating only tag without providing fields."""
+ partition_id = "test_partition"
+ key = "sample_1"
+
+ # First put with fields - use TensorDict to avoid unsqueeze
+ single_data = TensorDict({"value": torch.tensor([[10]])}, batch_size=1)
+ tq.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1})
+
+ # Update only tag
+ new_tag = {"version": 2, "status": "updated"}
+ tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=new_tag)
+
+ # Verify via controller
+ partition = get_controller_partition(controller, partition_id)
+ global_idx = partition.keys_mapping[key]
+ assert partition.custom_meta[global_idx]["version"] == 2
+ assert partition.custom_meta[global_idx]["status"] == "updated"
+
+ # Data should still be accessible
+ retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ assert_tensor_equal(retrieved["value"], torch.tensor([[10]]))
+
+ def test_kv_put_with_dict_fields(self, controller):
+ """Test kv_put with dict fields (auto-converted to TensorDict)."""
+ partition_id = "test_partition"
+ key = "sample_2"
+
+ # Put with dict fields - will be auto-unsqueezed
+ tq.kv_put(
+ key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"}
+ )
+
+ # Verify - retrieved data will have batch dimension
+ retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed
+ assert_tensor_equal(retrieved["data"], expected)
+
+
+class TestKVBatchPutE2E:
+ """End-to-end tests for kv_batch_put functionality."""
+
+ def test_kv_batch_put_multiple_samples(self, controller):
+ """Test batch putting multiple samples."""
+ partition_id = "test_partition"
+ keys = ["batch_0", "batch_1", "batch_2", "batch_3"]
+ batch_input_ids = torch.tensor(
+ [
+ [4, 5, 6],
+ [7, 8, 9],
+ [10, 11, 12],
+ [13, 14, 15],
+ ]
+ )
+ batch_attention_mask = torch.ones_like(batch_input_ids)
+
+ fields = TensorDict(
+ {
+ "input_ids": batch_input_ids,
+ "attention_mask": batch_attention_mask,
+ },
+ batch_size=4,
+ )
+
+ tags = [{"idx": i, "batch": True} for i in range(4)]
+
+ # Batch put using interface
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags)
+
+ # Verify via controller
+ partition = get_controller_partition(controller, partition_id)
+ assert partition is not None
+
+ # All keys should be registered
+ for key in keys:
+ assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping"
+
+ # Verify tags
+ for i, key in enumerate(keys):
+ global_idx = partition.keys_mapping[key]
+ assert partition.custom_meta[global_idx]["idx"] == i
+ assert partition.custom_meta[global_idx]["batch"] is True
+
+ # Verify all data via kv_get
+ retrieved = tq.kv_get(keys=keys, partition_id=partition_id)
+ assert_tensor_equal(retrieved["input_ids"], batch_input_ids)
+ assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask)
+
+ def test_kv_batch_put_partial_update(self, controller):
+ """Test adding new fields to existing samples."""
+ partition_id = "test_partition"
+ keys = ["partial_0", "partial_1"]
+
+ # First put initial data
+ initial_data = TensorDict(
+ {
+ "input_ids": torch.tensor([[1, 2], [3, 4]]),
+ },
+ batch_size=2,
+ )
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=initial_data, tags=[{"v": 1}, {"v": 1}])
+
+ # Add new fields to subset of keys
+ new_fields = TensorDict(
+ {
+ "response": torch.tensor([[5, 6]]), # Only for 1 sample
+ },
+ batch_size=1,
+ )
+ tq.kv_batch_put(keys=[keys[1]], partition_id=partition_id, fields=new_fields, tags=[{"v": 2}])
+
+ # Verify via controller - only keys[1] should have response field
+ partition = get_controller_partition(controller, partition_id)
+ global_idx_1 = partition.keys_mapping[keys[1]]
+
+ # Check that fields were added
+ assert "response" in partition.field_name_mapping
+ response_col_idx = partition.field_name_mapping["response"]
+
+ # keys[0] should NOT have response marked as produced
+ global_idx_0 = partition.keys_mapping[keys[0]]
+ assert partition.production_status[global_idx_0, response_col_idx] == 0, "Keys[0] should not have response"
+
+ # keys[1] should have response marked as produced
+ assert partition.production_status[global_idx_1, response_col_idx] == 1, "Keys[1] should have response"
+
+
+class TestKVGetE2E:
+ """End-to-end tests for kv_get functionality."""
+
+ def test_kv_get_single_key(self, controller):
+ """Test getting data for a single key."""
+ partition_id = "test_partition"
+ key = "get_single"
+ # Use TensorDict to avoid auto-unsqueeze issue with dict input
+ expected_data = torch.tensor([[100, 200, 300]])
+ fields = TensorDict({"data": expected_data}, batch_size=1)
+
+ tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None)
+
+ retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ assert_tensor_equal(retrieved["data"], expected_data)
+
+ def test_kv_get_multiple_keys(self, controller):
+ """Test getting data for multiple keys."""
+ partition_id = "test_partition"
+ keys = ["get_multi_0", "get_multi_1", "get_multi_2"]
+ expected_data = torch.tensor([[1, 2], [3, 4], [5, 6]])
+
+ fields = TensorDict({"data": expected_data}, batch_size=3)
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}])
+
+ retrieved = tq.kv_get(keys=keys, partition_id=partition_id)
+ assert_tensor_equal(retrieved["data"], expected_data)
+
+ def test_kv_get_specific_fields(self, controller):
+ """Test getting only specific fields."""
+ partition_id = "test_partition"
+ key = "get_fields"
+ # Use TensorDict to avoid auto-unsqueeze issue
+ input_ids = torch.tensor([[1, 2, 3]])
+ attention_mask = torch.ones(1, 3)
+ response = torch.tensor([[10, 20]])
+
+ fields = TensorDict(
+ {"input_ids": input_ids, "attention_mask": attention_mask, "response": response}, batch_size=1
+ )
+
+ # Put all fields
+ tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None)
+
+ # Get only input_ids
+ retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields="input_ids")
+ assert "input_ids" in retrieved.keys()
+ assert "attention_mask" not in retrieved.keys()
+ assert "response" not in retrieved.keys()
+ assert_tensor_equal(retrieved["input_ids"], input_ids)
+
+ # Get multiple specific fields
+ retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"])
+ assert "input_ids" in retrieved.keys()
+ assert "response" in retrieved.keys()
+ assert "attention_mask" not in retrieved.keys()
+
+ def test_kv_get_nonexistent_key(self, controller):
+ """Test that getting data for non-existent key returns empty result."""
+ partition_id = "test_partition"
+
+ # Try to get data for a key that doesn't exist - should return empty or raise error
+ try:
+ retrieved = tq.kv_get(keys="nonexistent_key", partition_id=partition_id)
+ # If it returns, it should be empty
+ assert retrieved.batch_size[0] == 0
+ except RuntimeError as e:
+ # Or it might raise an error about keys not found
+ assert "not found" in str(e).lower() or "empty" in str(e).lower()
+
+
+class TestKVListE2E:
+ """End-to-end tests for kv_list functionality."""
+
+ def test_kv_list_all_keys(self, controller):
+ """Test listing all keys in a partition."""
+ partition_id = "test_partition"
+ keys = ["list_0", "list_1", "list_2"]
+
+ for i, key in enumerate(keys):
+ tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i})
+
+ # List all keys
+ listed_keys, tags = tq.kv_list(partition_id=partition_id)
+
+ assert len(listed_keys) == 3
+ for key in keys:
+ assert key in listed_keys
+
+ # Verify tags match
+ for i, (key, tag) in enumerate(zip(listed_keys, tags, strict=False)):
+ assert tag["id"] == i
+
+ def test_kv_list_empty_partition(self):
+ """Test listing empty partition."""
+ partition_id = "test_partition_empty"
+
+ keys, tags = tq.kv_list(partition_id=partition_id)
+
+ assert len(keys) == 0
+ assert len(tags) == 0
+
+
+class TestKVClearE2E:
+ """End-to-end tests for kv_clear functionality."""
+
+ def test_kv_clear_single_key(self, controller):
+ """Test clearing a single key."""
+ partition_id = "test_partition"
+ key = "clear_single"
+ other_key = "clear_other"
+
+ tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[1]])}, tag={"id": "single"})
+ tq.kv_put(key=other_key, partition_id=partition_id, fields={"data": torch.tensor([[2]])}, tag={"id": "other"})
+
+ # Clear single key
+ tq.kv_clear(keys=key, partition_id=partition_id)
+
+ # Verify via kv_list
+ listed_keys, _ = tq.kv_list(partition_id=partition_id)
+ assert key not in listed_keys
+ assert other_key in listed_keys
+
+ # Verify via controller - key should be removed
+ partition = get_controller_partition(controller, partition_id)
+ assert key not in partition.keys_mapping
+
+ def test_kv_clear_multiple_keys(self, controller):
+ """Test clearing multiple keys."""
+ partition_id = "test_partition"
+ keys = ["clear_multi_0", "clear_multi_1", "clear_multi_2", "clear_multi_3"]
+
+ for i, key in enumerate(keys):
+ tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag=None)
+
+ # Clear first 2 keys
+ tq.kv_clear(keys=keys[:2], partition_id=partition_id)
+
+ # Verify
+ listed_keys, _ = tq.kv_list(partition_id=partition_id)
+ assert len(listed_keys) == 2
+ assert keys[0] not in listed_keys
+ assert keys[1] not in listed_keys
+ assert keys[2] in listed_keys
+ assert keys[3] in listed_keys
+
+
+class TestKVTagsE2E:
+ """End-to-end tests for tag functionality."""
+
+ def test_tag_preservation_across_operations(self, controller):
+ """Test that tags are preserved and updated correctly."""
+ partition_id = "test_partition"
+ key = "tag_test"
+
+ # Put with initial tag
+ tq.kv_put(
+ key=key,
+ partition_id=partition_id,
+ fields={"data": torch.tensor([[1]])},
+ tag={"version": 1, "status": "init"},
+ )
+
+ # Update with new tag (keeping version incrementing)
+ tq.kv_put(key=key, partition_id=partition_id, fields=None, tag={"version": 2, "status": "updated"})
+
+ # Verify tag is updated
+ partition = get_controller_partition(controller, partition_id)
+ global_idx = partition.keys_mapping[key]
+ assert partition.custom_meta[global_idx]["version"] == 2
+ assert partition.custom_meta[global_idx]["status"] == "updated"
+
+ def test_tag_retrieval_via_kv_list(self):
+ """Test retrieving tags via kv_list."""
+ partition_id = "test_partition"
+ keys = ["tag_list_0", "tag_list_1", "tag_list_2"]
+
+ expected_tags = [
+ {"score": 0.9, "label": "A"},
+ {"score": 0.85, "label": "B"},
+ {"score": 0.95, "label": "C"},
+ ]
+
+ for key, tag in zip(keys, expected_tags, strict=False):
+ tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[1]])}, tag=tag)
+
+ # List and verify tags
+ listed_keys, tags = tq.kv_list(partition_id=partition_id)
+
+ for key, expected_tag in zip(keys, expected_tags, strict=False):
+ assert key in listed_keys
+ idx = listed_keys.index(key)
+ assert tags[idx] == expected_tag
+
+
+class TestKVE2ECornerCases:
+ """End-to-end tests for corner cases."""
+
+ def test_key_to_global_index_mapping_consistency(self, controller):
+ """Test that key->global_index mapping is consistent across operations."""
+ partition_id = "test_partition"
+ keys = ["map_0", "map_1", "map_2", "map_3"]
+
+ # Put all keys
+ tq.kv_batch_put(
+ keys=keys,
+ partition_id=partition_id,
+ fields=TensorDict({"data": torch.randn(4, 5)}, batch_size=4),
+ tags=[{"i": i} for i in range(4)],
+ )
+
+ # Verify mapping consistency via controller
+ partition = get_controller_partition(controller, partition_id)
+
+ for key in keys:
+ assert key in partition.keys_mapping
+ global_idx = partition.keys_mapping[key]
+ assert global_idx in partition.revert_keys_mapping
+ assert partition.revert_keys_mapping[global_idx] == key
+
+ def test_field_expansion_across_samples(self, controller):
+ """Test that new fields can be added across samples."""
+ partition_id = "test_partition"
+ keys = ["expand_0", "expand_1"]
+
+ # Put initial fields
+ tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_a": torch.tensor([[1]])}, tag=None)
+
+ # Add new field to first key
+ tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None)
+
+ # Add different field to second key
+ tq.kv_put(key=keys[1], partition_id=partition_id, fields={"field_c": torch.tensor([[3]])}, tag=None)
+
+ # Verify field expansion in controller
+ partition = get_controller_partition(controller, partition_id)
+
+ # All fields should be registered
+ assert "field_a" in partition.field_name_mapping
+ assert "field_b" in partition.field_name_mapping
+ assert "field_c" in partition.field_name_mapping
+
+ def test_empty_tag_list(self):
+ """Test operations with empty tags."""
+ partition_id = "test_partition"
+ key = "empty_tag"
+
+ # Use 1D tensor - will be auto-unsqueezed to 2D
+ tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([1])}, tag={})
+
+ # Should work and data should be retrievable - will be 2D after unsqueeze
+ retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ assert_tensor_equal(retrieved["data"], torch.tensor([[1]]))
+
+ def test_large_batch_put_and_get(self):
+ """Test putting and getting a large batch of samples."""
+ partition_id = "test_partition"
+ num_samples = 100
+ keys = [f"large_{i}" for i in range(num_samples)]
+
+ # Create batch data
+ data = TensorDict(
+ {
+ "input_ids": torch.randn(num_samples, 10),
+ "attention_mask": torch.ones(num_samples, 10),
+ },
+ batch_size=num_samples,
+ )
+
+ tags = [{"idx": i} for i in range(num_samples)]
+
+ # Batch put
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data, tags=tags)
+
+ # Batch get all
+ retrieved = tq.kv_get(keys=keys, partition_id=partition_id)
+
+ assert retrieved["input_ids"].shape == (num_samples, 10)
+ assert retrieved["attention_mask"].shape == (num_samples, 10)
+
+ # Verify specific samples
+ assert_tensor_equal(retrieved["input_ids"][0], data["input_ids"][0])
+ assert_tensor_equal(retrieved["input_ids"][99], data["input_ids"][99])
+
+ def test_controller_partition_synchronization(self, controller):
+ """Test that controller partition state is synchronized with operations."""
+ partition_id = "test_partition"
+ key = "sync_test"
+
+ # Put data
+ tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[42]])}, tag={"sync": True})
+
+ # Get snapshot before clear
+ partition_before = get_controller_partition(controller, partition_id)
+ global_idx = partition_before.keys_mapping[key]
+ assert partition_before.production_status[global_idx, partition_before.field_name_mapping["x"]] == 1
+
+ # Clear
+ tq.kv_clear(keys=key, partition_id=partition_id)
+
+ # Get snapshot after clear
+ partition_after = get_controller_partition(controller, partition_id)
+ assert key not in partition_after.keys_mapping
+
+
+def run_tests():
+ """Run all e2e tests manually for debugging."""
+ pytest.main([__file__, "-v", "-s"])
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/tests/test_kv_interface.py b/tests/test_kv_interface.py
deleted file mode 100644
index 6b9eee1..0000000
--- a/tests/test_kv_interface.py
+++ /dev/null
@@ -1,539 +0,0 @@
-# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# Copyright 2025 The TransferQueue Team
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Unit tests for kv interface in transfer_queue.interface."""
-
-import asyncio
-import sys
-from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-import torch
-from tensordict import TensorDict
-
-# Setup path
-parent_dir = Path(__file__).resolve().parent.parent
-sys.path.append(str(parent_dir))
-
-from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
-from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
-
-
-def create_batch_meta(global_indexes, fields_data):
- """Helper to create BatchMeta for testing."""
- samples = []
- for sample_id, global_idx in enumerate(global_indexes):
- fields_dict = {}
- for field_name, tensor in fields_data.items():
- field_meta = FieldMeta(
- name=field_name,
- dtype=tensor.dtype,
- shape=tensor.shape,
- production_status=ProductionStatus.READY_FOR_CONSUME,
- )
- fields_dict[field_name] = field_meta
- sample = SampleMeta(
- partition_id="test_partition",
- global_index=global_idx,
- fields=fields_dict,
- )
- samples.append(sample)
- return BatchMeta(samples=samples)
-
-
-class TestKVPut:
- """Tests for kv_put function."""
-
- def test_kv_put_with_fields(self):
- """Test kv_put with fields parameter."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- tensor_data = TensorDict(
- {"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])},
- batch_size=[1],
- )
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_put
-
- kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"})
-
- # Verify kv_retrieve_keys was called
- mock_client.kv_retrieve_keys.assert_called_once_with(keys=["test_key"], partition_id="partition_1", create=True)
-
- # Verify update_custom_meta was called
- mock_batch_meta.update_custom_meta.assert_called_once()
-
- # Verify put was called
- mock_client.put.assert_called_once()
-
- def test_kv_put_with_dict_fields(self):
- """Test kv_put converts dict to TensorDict correctly."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_put
-
- # Test with dict containing tensor
- kv_put(
- key="test_key",
- partition_id="partition_1",
- fields={"text": torch.tensor([1, 2, 3])},
- tag=None,
- )
-
- # Verify put was called
- mock_client.put.assert_called_once()
- call_args = mock_client.put.call_args
- fields_arg = call_args[0][0]
- assert "text" in fields_arg
- # The dict should be converted to TensorDict
- assert isinstance(fields_arg, TensorDict)
-
- def test_kv_put_with_tag_only(self):
- """Test kv_put with only tag parameter (no fields)."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_put
-
- kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9})
-
- # Verify put was NOT called (only set_custom_meta)
- mock_client.put.assert_not_called()
- mock_client.set_custom_meta.assert_called_once_with(mock_batch_meta)
-
- def test_kv_put_raises_error_without_fields_and_tag(self):
- """Test kv_put raises ValueError when neither fields nor tag provided."""
- mock_client = MagicMock()
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_put
-
- with pytest.raises(ValueError, match="Please provide at least one parameter"):
- kv_put(key="test_key", partition_id="partition_1", fields=None, tag=None)
-
-
-class TestKVBatchPut:
- """Tests for kv_batch_put function."""
-
- def test_kv_batch_put_success(self):
- """Test kv_batch_put with valid inputs."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- batch_data = TensorDict(
- {
- "text": torch.tensor([[1, 2], [3, 4], [5, 6]]),
- "label": torch.tensor([0, 1, 2]),
- },
- batch_size=[3],
- )
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_batch_put
-
- keys = ["key1", "key2", "key3"]
- tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}]
-
- kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags)
-
- mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True)
- mock_batch_meta.update_custom_meta.assert_called_once_with(tags)
- mock_client.put.assert_called_once()
-
- def test_kv_batch_put_tags_length_mismatch(self):
- """Test kv_batch_put raises error when tags length doesn't match keys."""
- mock_client = MagicMock()
-
- batch_data = TensorDict(
- {
- "text": torch.tensor([[1, 2], [3, 4], [5, 6]]),
- "label": torch.tensor([0, 1, 2]),
- },
- batch_size=[3],
- )
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_batch_put
-
- keys = ["key1", "key2", "key3"]
- tags = [{"tag": "v1"}, {"tag": "v2"}] # Only 2 tags for 3 keys
-
- with pytest.raises(ValueError, match="does not match length of tags"):
- kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags)
-
-
-class TestKVGet:
- """Tests for kv_get function."""
-
- def test_kv_get_single_key(self):
- """Test kv_get with single key."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
- mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])})
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_get
-
- kv_get(keys="test_key", partition_id="partition_1")
-
- # keys is passed directly (not wrapped in list) for single key
- mock_client.kv_retrieve_keys.assert_called_once_with(keys="test_key", partition_id="partition_1", create=False)
- mock_client.get_data.assert_called_once_with(mock_batch_meta)
-
- def test_kv_get_multiple_keys(self):
- """Test kv_get with multiple keys."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
- mock_client.get_data.return_value = TensorDict({"data": torch.tensor([1, 2])})
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_get
-
- keys = ["key1", "key2", "key3"]
- kv_get(keys=keys, partition_id="partition_1")
-
- mock_client.kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=False)
-
- def test_kv_get_with_fields(self):
- """Test kv_get with specific fields."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_batch_meta.select_fields = MagicMock(return_value=mock_batch_meta)
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
- mock_client.get_data.return_value = TensorDict({"text": torch.tensor([1, 2])})
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_get
-
- kv_get(keys="test_key", partition_id="partition_1", fields="text")
-
- mock_batch_meta.select_fields.assert_called_once_with(["text"])
-
-
-class TestKVList:
- """Tests for kv_list function."""
-
- def test_kv_list_with_keys(self):
- """Test kv_list returns keys and custom_meta."""
- mock_client = MagicMock()
- mock_client.kv_list.return_value = ["key1", "key2", "key3"]
- mock_batch_meta = MagicMock()
- mock_batch_meta.global_indexes = [0, 1, 2]
- mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}})
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_list
-
- keys, custom_meta = kv_list(partition_id="partition_1")
-
- assert keys == ["key1", "key2", "key3"]
- assert len(custom_meta) == 3
-
- def test_kv_list_empty_partition(self):
- """Test kv_list returns None when partition is empty."""
- mock_client = MagicMock()
- mock_client.kv_list.return_value = []
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_list
-
- keys, custom_meta = kv_list(partition_id="empty_partition")
-
- assert keys is None
- assert custom_meta is None
-
-
-class TestKVClear:
- """Tests for kv_clear function."""
-
- def test_kv_clear_single_key(self):
- """Test kv_clear with single key."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_batch_meta.size = 1
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_clear
-
- kv_clear(keys="test_key", partition_id="partition_1")
-
- mock_client.kv_retrieve_keys.assert_called_once_with(
- keys=["test_key"], partition_id="partition_1", create=False
- )
- mock_client.clear_samples.assert_called_once_with(mock_batch_meta)
-
- def test_kv_clear_multiple_keys(self):
- """Test kv_clear with multiple keys."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_batch_meta.size = 3
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_clear
-
- kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1")
-
- mock_client.kv_retrieve_keys.assert_called_once_with(
- keys=["key1", "key2", "key3"], partition_id="partition_1", create=False
- )
- mock_client.clear_samples.assert_called_once()
-
-
-class TestAsyncKVPut:
- """Tests for async_kv_put function."""
-
- def test_async_kv_put_with_fields(self):
- """Test async_kv_put with fields parameter."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
- mock_client.async_put = AsyncMock()
- mock_client.async_set_custom_meta = AsyncMock()
-
- tensor_data = TensorDict(
- {"text": torch.tensor([[1, 2, 3]]), "label": torch.tensor([0])},
- batch_size=[1],
- )
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_put
-
- asyncio.run(
- async_kv_put(key="test_key", partition_id="partition_1", fields=tensor_data, tag={"type": "test"})
- )
-
- mock_client.async_kv_retrieve_keys.assert_called_once_with(
- keys=["test_key"], partition_id="partition_1", create=True
- )
- mock_batch_meta.update_custom_meta.assert_called_once()
- mock_client.async_put.assert_called_once()
-
- def test_async_kv_put_with_tag_only(self):
- """Test async_kv_put with only tag (no fields)."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
- mock_client.async_put = AsyncMock()
- mock_client.async_set_custom_meta = AsyncMock()
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_put
-
- asyncio.run(async_kv_put(key="test_key", partition_id="partition_1", fields=None, tag={"score": 0.9}))
-
- mock_client.async_put.assert_not_called()
- mock_client.async_set_custom_meta.assert_called_once_with(mock_batch_meta)
-
-
-class TestAsyncKVBatchPut:
- """Tests for async_kv_batch_put function."""
-
- def test_async_kv_batch_put_success(self):
- """Test async_kv_batch_put with valid inputs."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
- mock_client.async_put = AsyncMock()
- mock_client.async_set_custom_meta = AsyncMock()
-
- batch_data = TensorDict(
- {
- "text": torch.tensor([[1, 2], [3, 4], [5, 6]]),
- "label": torch.tensor([0, 1, 2]),
- },
- batch_size=[3],
- )
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_batch_put
-
- keys = ["key1", "key2", "key3"]
- tags = [{"tag": "v1"}, {"tag": "v2"}, {"tag": "v3"}]
-
- asyncio.run(async_kv_batch_put(keys=keys, partition_id="partition_1", fields=batch_data, tags=tags))
-
- mock_client.async_kv_retrieve_keys.assert_called_once_with(keys=keys, partition_id="partition_1", create=True)
- mock_batch_meta.update_custom_meta.assert_called_once_with(tags)
- mock_client.async_put.assert_called_once()
-
-
-class TestAsyncKVGet:
- """Tests for async_kv_get function."""
-
- def test_async_kv_get_single_key(self):
- """Test async_kv_get with single key."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
- mock_client.async_get_data = AsyncMock(return_value=TensorDict({"data": torch.tensor([1, 2])}))
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_get
-
- asyncio.run(async_kv_get(keys="test_key", partition_id="partition_1"))
-
- # keys is passed directly (not wrapped in list) for single key
- mock_client.async_kv_retrieve_keys.assert_called_once_with(
- keys="test_key", partition_id="partition_1", create=False
- )
- mock_client.async_get_data.assert_called_once_with(mock_batch_meta)
-
-
-class TestAsyncKVList:
- """Tests for async_kv_list function."""
-
- def test_async_kv_list_with_keys(self):
- """Test async_kv_list returns keys and custom_meta."""
- mock_client = MagicMock()
- mock_client.async_kv_list = AsyncMock(return_value=["key1", "key2", "key3"])
- mock_batch_meta = MagicMock()
- mock_batch_meta.global_indexes = [0, 1, 2]
- mock_batch_meta.get_all_custom_meta = MagicMock(return_value={0: {}, 1: {}, 2: {}})
- mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_list
-
- keys, custom_meta = asyncio.run(async_kv_list(partition_id="partition_1"))
-
- assert keys == ["key1", "key2", "key3"]
- assert len(custom_meta) == 3
-
- def test_async_kv_list_empty_partition(self):
- """Test async_kv_list returns None when partition is empty."""
- mock_client = MagicMock()
- mock_client.async_kv_list = AsyncMock(return_value=[])
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_list
-
- keys, custom_meta = asyncio.run(async_kv_list(partition_id="empty_partition"))
-
- assert keys is None
- assert custom_meta is None
-
-
-class TestAsyncKVClear:
- """Tests for async_kv_clear function."""
-
- def test_async_kv_clear_single_key(self):
- """Test async_kv_clear with single key."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_batch_meta.size = 1
- mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
- mock_client.async_clear_samples = AsyncMock()
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_clear
-
- asyncio.run(async_kv_clear(keys="test_key", partition_id="partition_1"))
-
- mock_client.async_kv_retrieve_keys.assert_called_once_with(
- keys=["test_key"], partition_id="partition_1", create=False
- )
- mock_client.async_clear_samples.assert_called_once_with(mock_batch_meta)
-
- def test_async_kv_clear_multiple_keys(self):
- """Test async_kv_clear with multiple keys."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_batch_meta.size = 3
- mock_client.async_kv_retrieve_keys = AsyncMock(return_value=mock_batch_meta)
- mock_client.async_clear_samples = AsyncMock()
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import async_kv_clear
-
- asyncio.run(async_kv_clear(keys=["key1", "key2", "key3"], partition_id="partition_1"))
-
- mock_client.async_kv_retrieve_keys.assert_called_once()
- mock_client.async_clear_samples.assert_called_once()
-
-
-class TestKVInterfaceDictConversion:
- """Tests for dict to TensorDict conversion in kv_put."""
-
- def test_kv_put_with_nontensor_value(self):
- """Test kv_put converts non-tensor values using NonTensorStack."""
- mock_client = MagicMock()
- mock_batch_meta = MagicMock()
- mock_client.kv_retrieve_keys.return_value = mock_batch_meta
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_put
-
- # Test with non-tensor value (like a string or list)
- kv_put(
- key="test_key",
- partition_id="partition_1",
- fields={"meta": {"key": "value"}},
- tag=None,
- )
-
- # Verify put was called
- mock_client.put.assert_called_once()
- call_args = mock_client.put.call_args
- fields_arg = call_args[0][0]
- # The dict should be converted to TensorDict
- assert isinstance(fields_arg, TensorDict)
- assert "meta" in fields_arg
-
- def test_kv_put_rejects_nested_tensor(self):
- """Test kv_put raises ValueError for nested tensors (requires batch_put)."""
- mock_client = MagicMock()
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_put
-
- nested_tensor = torch.nested.nested_tensor([[1, 2], [3, 4]])
-
- with pytest.raises(ValueError, match="Please use.*kv_batch_put"):
- kv_put(
- key="test_key",
- partition_id="partition_1",
- fields={"nested": nested_tensor},
- tag=None,
- )
-
- def test_kv_put_invalid_fields_type(self):
- """Test kv_put raises ValueError for invalid fields type."""
- mock_client = MagicMock()
-
- with patch("transfer_queue.interface._maybe_create_transferqueue_client", return_value=mock_client):
- from transfer_queue.interface import kv_put
-
- with pytest.raises(ValueError, match="field can only be dict or TensorDict"):
- kv_put(
- key="test_key",
- partition_id="partition_1",
- fields="invalid_string",
- tag=None,
- )
From 4d479ab896128dc7c01ec5cf4f140bc7652b6765 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 21:38:55 +0800
Subject: [PATCH 11/34] try fix tutorial ci
Signed-off-by: 0oshowero0
---
.github/workflows/tutorial-check.yml | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml
index 1202e21..ed8a2ba 100644
--- a/.github/workflows/tutorial-check.yml
+++ b/.github/workflows/tutorial-check.yml
@@ -36,4 +36,11 @@ jobs:
- name: Run tutorials
run: |
export TQ_NUM_THREADS=2
- for file in tutorial/*.py; do python3 "$file"; done
\ No newline at end of file
+ export RAY_DEDUP_LOGS=0
+ ray stop --force || true
+ for file in tutorial/*.py; do
+ echo "================ Running $file ================"
+ python3 "$file"
+ ray stop --force
+ sleep 5
+ done
\ No newline at end of file
From 5556278548237f84207741475f696ef799dc8de1 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 21:54:35 +0800
Subject: [PATCH 12/34] try only run 06
Signed-off-by: 0oshowero0
---
.github/workflows/tutorial-check.yml | 8 +-------
1 file changed, 1 insertion(+), 7 deletions(-)
diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml
index ed8a2ba..0c72b76 100644
--- a/.github/workflows/tutorial-check.yml
+++ b/.github/workflows/tutorial-check.yml
@@ -37,10 +37,4 @@ jobs:
run: |
export TQ_NUM_THREADS=2
export RAY_DEDUP_LOGS=0
- ray stop --force || true
- for file in tutorial/*.py; do
- echo "================ Running $file ================"
- python3 "$file"
- ray stop --force
- sleep 5
- done
\ No newline at end of file
+ python tutorial/06_streaming_dataloader.py
\ No newline at end of file
From 4f6ce684df0d2e83cefc917b257ddea5fc9e637a Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 21:59:46 +0800
Subject: [PATCH 13/34] try add sleep in tutorial
Signed-off-by: 0oshowero0
---
tutorial/06_streaming_dataloader.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py
index d95a3c8..c0131ef 100644
--- a/tutorial/06_streaming_dataloader.py
+++ b/tutorial/06_streaming_dataloader.py
@@ -319,6 +319,8 @@ def main():
print("\n[Phase 2] Starting data generation...")
generate_worker_handlers = start_all_generate_actors()
+ time.sleep(10)
+
# Step 3: Launch data consumption actors
print("\n[Phase 3] Starting data consumption...")
update_worker_handlers = start_all_update_actors()
From e7ea8e29680b72a2650478c049b8dfceb3bb1570 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 22:20:53 +0800
Subject: [PATCH 14/34] fix race condition
Signed-off-by: 0oshowero0
---
transfer_queue/controller.py | 2 ++
tutorial/06_streaming_dataloader.py | 2 --
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 1f01482..ea803f5 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -547,6 +547,8 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
)
if mask:
+ with self.data_status_lock:
+ self.ensure_samples_capacity(max(partition_global_index) + 1)
consumption_status = consumption_status[partition_global_index]
return partition_global_index, consumption_status
diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py
index c0131ef..d95a3c8 100644
--- a/tutorial/06_streaming_dataloader.py
+++ b/tutorial/06_streaming_dataloader.py
@@ -319,8 +319,6 @@ def main():
print("\n[Phase 2] Starting data generation...")
generate_worker_handlers = start_all_generate_actors()
- time.sleep(10)
-
# Step 3: Launch data consumption actors
print("\n[Phase 3] Starting data consumption...")
update_worker_handlers = start_all_update_actors()
From 29f5a1ed5e340742314c72d57cf4f4f44b2ce86e Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 22:24:07 +0800
Subject: [PATCH 15/34] fix race condition
Signed-off-by: 0oshowero0
---
transfer_queue/controller.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index ea803f5..538ba21 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -540,8 +540,6 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8)
# Get consumption status for requested task
- consumption_status = self.consumption_status[task_name]
-
partition_global_index = torch.tensor(
sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long
)
@@ -549,8 +547,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
if mask:
with self.data_status_lock:
self.ensure_samples_capacity(max(partition_global_index) + 1)
- consumption_status = consumption_status[partition_global_index]
-
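+ # Re-read the per-task status tensor under the lock after growing capacity; the tensor captured before ensure_samples_capacity() may have been replaced by the resize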
+ consumption_status = self.consumption_status[task_name][partition_global_index]
+ else:
+ consumption_status = self.consumption_status[task_name]
return partition_global_index, consumption_status
def reset_consumption(self, task_name: Optional[str] = None):
From e1615715574b22b9bb81570f0005c223f12fb403 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sat, 7 Feb 2026 22:26:35 +0800
Subject: [PATCH 16/34] recover ci
Signed-off-by: 0oshowero0
---
.github/workflows/tutorial-check.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml
index 0c72b76..c24cab8 100644
--- a/.github/workflows/tutorial-check.yml
+++ b/.github/workflows/tutorial-check.yml
@@ -37,4 +37,4 @@ jobs:
run: |
export TQ_NUM_THREADS=2
export RAY_DEDUP_LOGS=0
- python tutorial/06_streaming_dataloader.py
\ No newline at end of file
+ for file in tutorial/*.py; do python3 "$file"; done
\ No newline at end of file
From 928edda0f8408e17ec25c9b64af9b91e8359b1e8 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sun, 8 Feb 2026 12:36:11 +0800
Subject: [PATCH 17/34] fix batchmeta deserial
Signed-off-by: 0oshowero0
---
transfer_queue/client.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index f0c899b..2af07a7 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -970,6 +970,7 @@ async def async_kv_retrieve_keys(
if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE:
metadata = response_msg.body.get("metadata", BatchMeta.empty())
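+ # The controller may return the metadata serialized as a plain dict; rebuild the BatchMeta object in that case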
+ metadata = BatchMeta.from_dict(metadata) if isinstance(metadata, dict) else metadata
return metadata
else:
raise RuntimeError(
From ac9779505aa9345b47d1eafb697a7db88f8d358e Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sun, 8 Feb 2026 13:17:22 +0800
Subject: [PATCH 18/34] update kv interface test
Signed-off-by: 0oshowero0
---
tests/e2e/test_kv_interface_e2e.py | 275 ++++++++++++-----------------
transfer_queue/__init__.py | 8 +-
transfer_queue/interface.py | 12 +-
tutorial/02_kv_interface.py | 18 +-
4 files changed, 131 insertions(+), 182 deletions(-)
diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py
index 458befa..3918bbe 100644
--- a/tests/e2e/test_kv_interface_e2e.py
+++ b/tests/e2e/test_kv_interface_e2e.py
@@ -16,7 +16,7 @@
"""End-to-end tests for KV interface in transfer_queue.interface.
This test module validates the KV interface functionality by:
-1. Using external interfaces (kv_put, kv_batch_put, kv_get, kv_list, kv_clear) for read/write
+1. Using external interfaces (kv_put, kv_batch_put, kv_batch_get, kv_list, kv_clear) for read/write
2. Verifying correctness by calling TransferQueueController's internal methods directly
"""
@@ -92,10 +92,44 @@ def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""):
class TestKVPutE2E:
"""End-to-end tests for kv_put functionality."""
+ def test_kv_put_with_dict_fields(self, controller):
+ """Test kv_put with dict fields (auto-converted to TensorDict)."""
+ partition_id = "test_partition"
+ key = "sample_0"
+
+ # Put with dict fields - will be auto-unsqueezed
+ tq.kv_put(
+ key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"}
+ )
+
+ # Verify - retrieved data will have batch dimension
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
+ expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed
+ assert_tensor_equal(retrieved["data"], expected)
+
+ def test_kv_put_with_tensordict_fields(self, controller):
+ """Test kv_put with tensordict fields."""
+ partition_id = "test_partition"
+ key = "sample_1"
+
+ tensordict_data = TensorDict(
+ {
+ "input_ids": torch.tensor([[1, 2, 3, 4]]),
+ },
+ batch_size=1,
+ )
+ # Put with TensorDict fields - the provided batch dimension is kept as-is
+ tq.kv_put(key=key, partition_id=partition_id, fields=tensordict_data, tag={"type": "tensordict_test"})
+
+ # Verify - retrieved data will have batch dimension
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
+ expected = torch.tensor([[1, 2, 3, 4]]) # already carries the batch dimension
+ assert_tensor_equal(retrieved["input_ids"], expected)
+
def test_kv_put_single_sample_with_fields_and_tag(self, controller):
"""Test putting a single sample with fields and tag."""
partition_id = "test_partition"
- key = "sample_0"
+ key = "sample_2"
# Use 1D tensors - kv_put with dict will auto-unsqueeze to add batch dimension
input_ids = torch.tensor([1, 2, 3])
attention_mask = torch.ones(3)
@@ -129,8 +163,8 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller):
input_ids_col_idx = partition.field_name_mapping["input_ids"]
assert partition.production_status[global_idx, input_ids_col_idx] == 1, "input_ids should be marked as produced"
- # Retrieve and verify data via kv_get - tensors will have batch dimension
- retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ # Retrieve and verify data via kv_batch_get - tensors will have batch dimension
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
assert "input_ids" in retrieved.keys()
assert "attention_mask" in retrieved.keys()
# After unsqueeze, tensors become 2D [batch_size=1, original_size]
@@ -142,9 +176,9 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller):
def test_kv_put_update_tag_only(self, controller):
"""Test updating only tag without providing fields."""
partition_id = "test_partition"
- key = "sample_1"
+ key = "sample_3"
- # First put with fields - use TensorDict to avoid unsqueeze
+ # First put with fields - use TensorDict as another example
single_data = TensorDict({"value": torch.tensor([[10]])}, batch_size=1)
tq.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1})
@@ -159,23 +193,42 @@ def test_kv_put_update_tag_only(self, controller):
assert partition.custom_meta[global_idx]["status"] == "updated"
# Data should still be accessible
- retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
assert_tensor_equal(retrieved["value"], torch.tensor([[10]]))
- def test_kv_put_with_dict_fields(self, controller):
- """Test kv_put with dict fields (auto-converted to TensorDict)."""
+ def test_kv_put_partial_update(self, controller):
+ """Test adding new fields to existing sample."""
partition_id = "test_partition"
- key = "sample_2"
+ key = "sample_4"
- # Put with dict fields - will be auto-unsqueezed
- tq.kv_put(
- key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"}
+ # First put initial data
+ initial_data = TensorDict(
+ {
+ "input_ids": torch.tensor([[1, 2, 3, 4]]),
+ },
+ batch_size=1,
)
+ tq.kv_put(key=key, partition_id=partition_id, fields=initial_data, tag={"v": 1})
- # Verify - retrieved data will have batch dimension
- retrieved = tq.kv_get(keys=key, partition_id=partition_id)
- expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed
- assert_tensor_equal(retrieved["data"], expected)
+ # Add a new field to the existing key
+ new_fields = TensorDict(
+ {
+ "response": torch.tensor([[5, 6]]),
+ },
+ batch_size=1,
+ )
+ tq.kv_put(key=key, partition_id=partition_id, fields=new_fields, tag={"v": 2})
+
+ # Verify via controller - the key should now have the response field registered
+ partition = get_controller_partition(controller, partition_id)
+ global_idx = partition.keys_mapping[key]
+
+ # Check that fields were added
+ assert "response" in partition.field_name_mapping
+ response_col_idx = partition.field_name_mapping["response"]
+
+ # key should have response marked as produced
+ assert partition.production_status[global_idx, response_col_idx] == 1, "Key should have response"
class TestKVBatchPutE2E:
@@ -222,8 +275,8 @@ def test_kv_batch_put_multiple_samples(self, controller):
assert partition.custom_meta[global_idx]["idx"] == i
assert partition.custom_meta[global_idx]["batch"] is True
- # Verify all data via kv_get
- retrieved = tq.kv_get(keys=keys, partition_id=partition_id)
+ # Verify all data via kv_batch_get
+ retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id)
assert_tensor_equal(retrieved["input_ids"], batch_input_ids)
assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask)
@@ -267,9 +320,9 @@ def test_kv_batch_put_partial_update(self, controller):
class TestKVGetE2E:
- """End-to-end tests for kv_get functionality."""
+ """End-to-end tests for kv_batch_get functionality."""
- def test_kv_get_single_key(self, controller):
+ def test_kv_batch_get_single_key(self, controller):
"""Test getting data for a single key."""
partition_id = "test_partition"
key = "get_single"
@@ -279,10 +332,10 @@ def test_kv_get_single_key(self, controller):
tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None)
- retrieved = tq.kv_get(keys=key, partition_id=partition_id)
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
assert_tensor_equal(retrieved["data"], expected_data)
- def test_kv_get_multiple_keys(self, controller):
+ def test_kv_batch_get_multiple_keys(self, controller):
"""Test getting data for multiple keys."""
partition_id = "test_partition"
keys = ["get_multi_0", "get_multi_1", "get_multi_2"]
@@ -291,11 +344,25 @@ def test_kv_get_multiple_keys(self, controller):
fields = TensorDict({"data": expected_data}, batch_size=3)
tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}])
- retrieved = tq.kv_get(keys=keys, partition_id=partition_id)
+ retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id)
+ assert_tensor_equal(retrieved["data"], expected_data)
+
+ def test_kv_batch_get_partial_keys(self, controller):
+ """Test getting data for partial keys."""
+ partition_id = "test_partition"
+ keys = ["get_multi_3", "get_multi_4", "get_multi_5"]
+ partial_keys = ["get_multi_3", "get_multi_5"]
+ input_data = torch.tensor([[1, 2], [3, 4], [5, 6]])
+ expected_data = torch.tensor([[1, 2], [5, 6]])
+
+ fields = TensorDict({"data": input_data}, batch_size=3)
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}])
+
+ retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id)
assert_tensor_equal(retrieved["data"], expected_data)
- def test_kv_get_specific_fields(self, controller):
- """Test getting only specific fields."""
+ def test_kv_batch_get_partial_fields(self, controller):
+ """Test getting only partial fields."""
partition_id = "test_partition"
key = "get_fields"
# Use TensorDict to avoid auto-unsqueeze issue
@@ -311,25 +378,27 @@ def test_kv_get_specific_fields(self, controller):
tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None)
# Get only input_ids
- retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields="input_ids")
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields="input_ids")
assert "input_ids" in retrieved.keys()
assert "attention_mask" not in retrieved.keys()
assert "response" not in retrieved.keys()
assert_tensor_equal(retrieved["input_ids"], input_ids)
# Get multiple specific fields
- retrieved = tq.kv_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"])
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"])
assert "input_ids" in retrieved.keys()
assert "response" in retrieved.keys()
assert "attention_mask" not in retrieved.keys()
+ assert_tensor_equal(retrieved["input_ids"], input_ids)
+ assert_tensor_equal(retrieved["response"], response)
- def test_kv_get_nonexistent_key(self, controller):
+ def test_kv_batch_get_nonexistent_key(self, controller):
"""Test that getting data for non-existent key returns empty result."""
partition_id = "test_partition"
# Try to get data for a key that doesn't exist - should return empty or raise error
try:
- retrieved = tq.kv_get(keys="nonexistent_key", partition_id=partition_id)
+ retrieved = tq.kv_batch_get(keys="nonexistent_key", partition_id=partition_id)
# If it returns, it should be empty
assert retrieved.batch_size[0] == 0
except RuntimeError as e:
@@ -392,6 +461,7 @@ def test_kv_clear_single_key(self, controller):
# Verify via controller - key should be removed
partition = get_controller_partition(controller, partition_id)
assert key not in partition.keys_mapping
+ assert other_key in partition.keys_mapping
def test_kv_clear_multiple_keys(self, controller):
"""Test clearing multiple keys."""
@@ -413,79 +483,9 @@ def test_kv_clear_multiple_keys(self, controller):
assert keys[3] in listed_keys
-class TestKVTagsE2E:
- """End-to-end tests for tag functionality."""
-
- def test_tag_preservation_across_operations(self, controller):
- """Test that tags are preserved and updated correctly."""
- partition_id = "test_partition"
- key = "tag_test"
-
- # Put with initial tag
- tq.kv_put(
- key=key,
- partition_id=partition_id,
- fields={"data": torch.tensor([[1]])},
- tag={"version": 1, "status": "init"},
- )
-
- # Update with new tag (keeping version incrementing)
- tq.kv_put(key=key, partition_id=partition_id, fields=None, tag={"version": 2, "status": "updated"})
-
- # Verify tag is updated
- partition = get_controller_partition(controller, partition_id)
- global_idx = partition.keys_mapping[key]
- assert partition.custom_meta[global_idx]["version"] == 2
- assert partition.custom_meta[global_idx]["status"] == "updated"
-
- def test_tag_retrieval_via_kv_list(self):
- """Test retrieving tags via kv_list."""
- partition_id = "test_partition"
- keys = ["tag_list_0", "tag_list_1", "tag_list_2"]
-
- expected_tags = [
- {"score": 0.9, "label": "A"},
- {"score": 0.85, "label": "B"},
- {"score": 0.95, "label": "C"},
- ]
-
- for key, tag in zip(keys, expected_tags, strict=False):
- tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[1]])}, tag=tag)
-
- # List and verify tags
- listed_keys, tags = tq.kv_list(partition_id=partition_id)
-
- for key, expected_tag in zip(keys, expected_tags, strict=False):
- assert key in listed_keys
- idx = listed_keys.index(key)
- assert tags[idx] == expected_tag
-
-
class TestKVE2ECornerCases:
"""End-to-end tests for corner cases."""
- def test_key_to_global_index_mapping_consistency(self, controller):
- """Test that key->global_index mapping is consistent across operations."""
- partition_id = "test_partition"
- keys = ["map_0", "map_1", "map_2", "map_3"]
-
- # Put all keys
- tq.kv_batch_put(
- keys=keys,
- partition_id=partition_id,
- fields=TensorDict({"data": torch.randn(4, 5)}, batch_size=4),
- tags=[{"i": i} for i in range(4)],
- )
-
- # Verify mapping consistency via controller
- partition = get_controller_partition(controller, partition_id)
-
- for key in keys:
- assert key in partition.keys_mapping
- global_idx = partition.keys_mapping[key]
- assert global_idx in partition.revert_keys_mapping
- assert partition.revert_keys_mapping[global_idx] == key
-
def test_field_expansion_across_samples(self, controller):
"""Test that new fields can be added across samples."""
partition_id = "test_partition"
@@ -498,77 +498,26 @@ def test_field_expansion_across_samples(self, controller):
tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None)
# Add different field to second key
- tq.kv_put(key=keys[1], partition_id=partition_id, fields={"field_c": torch.tensor([[3]])}, tag=None)
+ tq.kv_put(
+ key=keys[1],
+ partition_id=partition_id,
+ fields={"field_a": torch.tensor([[3]]), "field_c": torch.tensor([[4]])},
+ tag=None,
+ )
# Verify field expansion in controller
partition = get_controller_partition(controller, partition_id)
- # All fields should be registered
+ # All fields should be registered, but only samples with the actual fields are labeled as READY_FOR_CONSUME
assert "field_a" in partition.field_name_mapping
assert "field_b" in partition.field_name_mapping
assert "field_c" in partition.field_name_mapping
- def test_empty_tag_list(self):
- """Test operations with empty tags."""
- partition_id = "test_partition"
- key = "empty_tag"
-
- # Use 1D tensor - will be auto-unsqueezed to 2D
- tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([1])}, tag={})
-
- # Should work and data should be retrievable - will be 2D after unsqueeze
- retrieved = tq.kv_get(keys=key, partition_id=partition_id)
- assert_tensor_equal(retrieved["data"], torch.tensor([[1]]))
-
- def test_large_batch_put_and_get(self):
- """Test putting and getting a large batch of samples."""
- partition_id = "test_partition"
- num_samples = 100
- keys = [f"large_{i}" for i in range(num_samples)]
-
- # Create batch data
- data = TensorDict(
- {
- "input_ids": torch.randn(num_samples, 10),
- "attention_mask": torch.ones(num_samples, 10),
- },
- batch_size=num_samples,
- )
-
- tags = [{"idx": i} for i in range(num_samples)]
-
- # Batch put
- tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data, tags=tags)
-
- # Batch get all
- retrieved = tq.kv_get(keys=keys, partition_id=partition_id)
-
- assert retrieved["input_ids"].shape == (num_samples, 10)
- assert retrieved["attention_mask"].shape == (num_samples, 10)
-
- # Verify specific samples
- assert_tensor_equal(retrieved["input_ids"][0], data["input_ids"][0])
- assert_tensor_equal(retrieved["input_ids"][99], data["input_ids"][99])
-
- def test_controller_partition_synchronization(self, controller):
- """Test that controller partition state is synchronized with operations."""
- partition_id = "test_partition"
- key = "sync_test"
-
- # Put data
- tq.kv_put(key=key, partition_id=partition_id, fields={"x": torch.tensor([[42]])}, tag={"sync": True})
-
- # Get snapshot before clear
- partition_before = get_controller_partition(controller, partition_id)
- global_idx = partition_before.keys_mapping[key]
- assert partition_before.production_status[global_idx, partition_before.field_name_mapping["x"]] == 1
-
- # Clear
- tq.kv_clear(keys=key, partition_id=partition_id)
-
- # Get snapshot after clear
- partition_after = get_controller_partition(controller, partition_id)
- assert key not in partition_after.keys_mapping
+ # We can only fetch "field_a" because not all requested keys have the other fields
+ data = tq.kv_batch_get(keys=keys, partition_id=partition_id)
+ assert "field_a" in data
+ assert "field_b" not in data
+ assert "field_c" not in data
def run_tests():
diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py
index f0421c0..26f1225 100644
--- a/transfer_queue/__init__.py
+++ b/transfer_queue/__init__.py
@@ -23,9 +23,9 @@
async_clear_samples,
async_get_data,
async_get_meta,
+ async_kv_batch_get,
async_kv_batch_put,
async_kv_clear,
- async_kv_get,
async_kv_list,
async_kv_put,
async_put,
@@ -36,9 +36,9 @@
get_data,
get_meta,
init,
+ kv_batch_get,
kv_batch_put,
kv_clear,
- kv_get,
kv_list,
kv_put,
put,
@@ -70,12 +70,12 @@
"close",
"kv_put",
"kv_batch_put",
- "kv_get",
+ "kv_batch_get",
"kv_list",
"kv_clear",
"async_kv_put",
"async_kv_batch_put",
- "async_kv_get",
+ "async_kv_batch_get",
"async_kv_list",
"async_kv_clear",
] + [
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index ed325b8..49b706f 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -798,7 +798,7 @@ def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: l
tq_client.set_custom_meta(batch_meta)
-def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict:
+def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict:
"""Get data from TransferQueue using user-specified keys.
This is a convenience method for retrieving data using keys instead of indexes.
@@ -819,9 +819,9 @@ def kv_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str]
>>> import transfer_queue as tq
>>> tq.init()
>>> # Get single key with all fields
- >>> data = tq.kv_get(key="sample_1", partition_id="train")
+ >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train")
>>> # Get multiple keys with specific fields
- >>> data = tq.kv_get(
+ >>> data = tq.kv_batch_get(
... keys=["sample_1", "sample_2"],
... partition_id="train",
... fields="input_ids"
@@ -1038,7 +1038,7 @@ async def async_kv_batch_put(
await tq_client.async_set_custom_meta(batch_meta)
-async def async_kv_get(
+async def async_kv_batch_get(
keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None
) -> TensorDict:
"""Asynchronously get data from TransferQueue using user-specified keys.
@@ -1061,9 +1061,9 @@ async def async_kv_get(
>>> import transfer_queue as tq
>>> tq.init()
>>> # Get single key with all fields
- >>> data = await tq.async_kv_get(key="sample_1", partition_id="train")
+ >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train")
>>> # Get multiple keys with specific fields
- >>> data = await tq.async_kv_get(
+ >>> data = await tq.async_kv_batch_get(
... keys=["sample_1", "sample_2"],
... partition_id="train",
... fields="input_ids"
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
index 2b309ca..9f8f3a7 100644
--- a/tutorial/02_kv_interface.py
+++ b/tutorial/02_kv_interface.py
@@ -56,10 +56,10 @@
def demonstrate_kv_api():
"""
Demonstrate the Key-Value (KV) semantic API:
- kv_put & kv_batch_put -> kv_list -> kv_get -> kv_clear
+ kv_put & kv_batch_put -> kv_list -> kv_batch_get -> kv_clear
"""
print("=" * 80)
- print("Key-Value Semantic API Demo: kv_put/kv_batch_put → kv_list → kv_get → kv_clear")
+ print("Key-Value Semantic API Demo: kv_put/kv_batch_put → kv_list → kv_batch_get → kv_clear")
print("=" * 80)
# Step 1: Put a single key-value pair with kv_put
@@ -161,16 +161,16 @@ def demonstrate_kv_api():
for k, t in zip(all_keys, all_tags, strict=False):
print(f" - key='{k}' | tag={t}")
- # Step 6: Retrieve specific fields using kv_get
- print("\n[Step 6] Retrieving specific fields (Column) with kv_get...")
+ # Step 6: Retrieve specific fields using kv_batch_get
+ print("\n[Step 6] Retrieving specific fields (Column) with kv_batch_get...")
print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').")
- retrieved_input_ids = tq.kv_get(keys=all_keys, partition_id=partition_id, fields="input_ids")
+ retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, fields="input_ids")
print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.")
- # # Step 7: Retrieve all fields using kv_get
- print("\n[Step 7] Retrieving all fields with kv_get...")
- retrieved_all = tq.kv_get(keys=all_keys, partition_id=partition_id)
+ # # Step 7: Retrieve all fields using kv_batch_get
+ print("\n[Step 7] Retrieving all fields with kv_batch_get...")
+ retrieved_all = tq.kv_batch_get(keys=all_keys, partition_id=partition_id)
print(f" Retrieved all fields for {all_keys}:")
print(f" Fields: {list(retrieved_all.keys())}")
print(
@@ -200,7 +200,7 @@ def main():
Key Methods:
1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag
2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch
- 3. (async_)kv_get - Retrieve samples (by keys), supporting column selection (by fields)
+ 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields)
4. (async_)kv_list - List keys and tags (metadata) in a partition
5. (async_)kv_clear - Remove key-value pairs from storage
From 99b7fc995bd0f1aaa8e5d4018acc751a6ca377c7 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sun, 8 Feb 2026 15:54:52 +0800
Subject: [PATCH 19/34] update readme
Signed-off-by: 0oshowero0
---
README.md | 79 ++++++++++++++++++++++---------------
tutorial/02_kv_interface.py | 6 +--
2 files changed, 50 insertions(+), 35 deletions(-)
diff --git a/README.md b/README.md
index 2c7d4f0..ba9e54d 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,8 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo
🔄 Updates
- - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for details.
+ - **Feb 8, 2026**: 🔥 Initialization and usage are greatly simplified by the new high-level APIs [PR#26](https://github.com/Ascend/TransferQueue/pull/26), [PR#28](https://github.com/Ascend/TransferQueue/pull/28). You can now use a Redis-style API to take advantage of most of the advanced features provided by TransferQueue!
+ - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details.
- **Dec 30, 2025**: **TransferQueue x verl** integration is tested with the DAPO algorithm at scale **(64 nodes, 1024 cards)**. It significantly optimizes host memory utilization and accelerates data transfers. Stay tuned for more details!
- **Dec 20, 2025**: 🔥 The official [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) is released! Feel free to check it out.
- **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data.
@@ -91,26 +92,55 @@ This data structure design is motivated by the computational characteristics of
-### User Interface: Asynchronous & Synchronous Client
-To simplify the usage of TransferQueue, we have encapsulated this process into `TransferQueueClient`. The client provides both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their framework.
+### User Interface: High-Level & Low-Level APIs
-We also experimentally provide a `StreamingDataLoader` interface as a standard PyTorch DataLoader. Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks.
-This interface simplifies TransferQueue's integration, ensuring seamless compatibility with existing training workflows. Please refer to our [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for more details.
+| Level | Tier | Style | Fine-Grained Access | Streaming | Sampler | Multiple-Backends |
+|---|---|---|---|---|---|---|
+| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✓ | ✗ | ✗ | ✓ |
+| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ✓ | ✓ | ✓ | ✓ |
+| Low | **TransferQueueClient** | Metadata-based | ✓ | ✓ | ✓ | ✓ |
-🔥 Showcases
-### General Usage
+#### Key-Value based API
+
+To simplify the usage of TransferQueue, we provide a Redis-style high-level API that exposes most of the advanced features of TransferQueue ([#PR28](https://github.com/Ascend/TransferQueue/pull/28)). A brief usage sketch follows the feature list below.
+
+**Methods**
+
+- **(async_)kv_put**: Insert/Update a multi-column sample by key, with optional metadata tag
+- **(async_)kv_batch_put**: Put multiple key-value pairs efficiently in batch
+- **(async_)kv_batch_get**: Retrieve samples (by keys), supporting column selection (by fields)
+- **(async_)kv_list**: List keys and tags (metadata) in a partition
+- **(async_)kv_clear**: Remove key-value pairs from storage
+
+**Key Features**
+
+- **Redis-style Semantics**: Familiar KV interface (Put/Get/List) for zero learning curve
+- **Fine-grained Access**: Update or retrieve specific fields (columns) within a key (row) without reading or writing the full sample.
+- **Partition Isolation**: Logical separation of storage namespaces
+- **Metadata Tags**: Lightweight metadata for status tracking
+- **Pluggable Backends**: Supports multiple backends
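+
+A minimal usage sketch (the partition id, keys, and field names below are illustrative, and it assumes `tq.init()` has set up the TransferQueue system as in tutorial/02_kv_interface.py):
+
+```python
+import torch
+import transfer_queue as tq
+from tensordict import TensorDict
+
+tq.init()
+
+# Insert one sample by key; plain-dict tensor values gain a leading batch dimension automatically
+tq.kv_put(
+    key="sample_0",
+    partition_id="train",
+    fields={"input_ids": torch.tensor([1, 2, 3])},
+    tag={"status": "rollout_done"},
+)
+
+# Insert several samples at once; the TensorDict batch_size must match len(keys)
+tq.kv_batch_put(
+    keys=["sample_1", "sample_2"],
+    partition_id="train",
+    fields=TensorDict({"input_ids": torch.tensor([[4, 5, 6], [7, 8, 9]])}, batch_size=[2]),
+    tags=[{"status": "rollout_done"}, {"status": "rollout_done"}],
+)
+
+# Column-selective read, key/tag listing, and cleanup
+batch = tq.kv_batch_get(keys=["sample_0", "sample_1"], partition_id="train", fields="input_ids")
+keys, tags = tq.kv_list(partition_id="train")
+tq.kv_clear(keys=keys, partition_id="train")
+```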
+
+#### StreamingDataLoader API
-The primary interaction points are `AsyncTransferQueueClient` and `TransferQueueClient`, serving as the communication interface with the TransferQueue system.
+Designed as a drop-in replacement for the standard PyTorch `DataLoader`, this API allows each rank to automatically consume data without single-controller intervention.
-Core interfaces:
+In this scenario, `TransferQueueController` serves as a side-controller for data dispatching, with a user-defined `Sampler` class organizing the dataflow.
+It encapsulates the complex scheduling and data transfer logic required for various parallelism strategies, seamlessly integrating TransferQueue into existing training workflows and simplifying the development of disaggregated frameworks.
-- `(async_)get_meta(data_fields: list[str], batch_size:int, partition_id: str, mode: str, task_name:str, sampling_config: Optional[dict[str, Any]]) -> BatchMeta`
-- `(async_)get_data(metadata: BatchMeta) -> TensorDict`
-- `(async_)put(data: TensorDict, metadata: Optional[BatchMeta], partition_id: Optional[str])`
-- `(async_)clear_partition(partition_id: str)` and `(async_)clear_samples(metadata: BatchMeta)`
+See [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for more details.
-**Refer to our [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) for detailed examples.**
+#### Low-Level Native API
+
+The native interface of TransferQueue is implemented in `TransferQueueClient`. It offers maximum flexibility through native, atomic operations.
+
+Developers can leverage `TransferQueueClient` directly to implement advanced features that require fine-grained control and fully streamed data scheduling, as illustrated in the following tutorials:
+- [tutorial/03_metadata_concepts.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/03_metadata_concepts.py)
+- [tutorial/04_understanding_controller.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_understanding_controller.py)
+- [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py)
+
+
+🔥 Showcases
### Collocated Example
@@ -131,7 +161,7 @@ You may refer to the [recipe](https://github.com/Ascend/TransferQueue/tree/dev/r
### Disaggregated Example
-We have implemented a series of PRs ([#4](https://github.com/Ascend/TransferQueue/pull/4), [#7](https://github.com/Ascend/TransferQueue/pull/7), [#9](https://github.com/Ascend/TransferQueue/pull/9), [#16](https://github.com/Ascend/TransferQueue/pull/16)) to establish a **standardized, fully-streamed distributed** workflow via TransferQueue.
+We have experimentally implemented a **standardized, fully-streamed distributed** workflow via TransferQueue.
By leveraging the `RankAwareSampler` and `StreamingDataLoader` interfaces, we achieve a **streamlined micro-batch-level producer-consumer pipeline**. This design eliminates the need to manually determine data dispatching logic across varying parallelism strategies—a typical complexity in the single-controller paradigm—thereby greatly simplifying framework design.
@@ -186,7 +216,7 @@ pip install TransferQueue
-> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community!
+> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorage` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community!
For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#).
@@ -250,7 +280,7 @@ batch_meta = client.get_meta(
)
```
-**Refer to [tutorial/04_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_custom_sampler.py) for more details.**
+**Refer to [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py) for more details.**
### How to integrate a new storage backend
@@ -299,21 +329,6 @@ pip install pre-commit
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
```
- 🛣️ Roadmap
-
-- [x] Support data rewrite for partial rollout & agentic post-training
-- [x] Provide a general storage abstraction layer `TransferQueueStorageManager` to manage distributed storage units, which simplifies `Client` design and makes it possible to introduce different storage backends ([PR#66](https://github.com/TransferQueue/TransferQueue/pull/66), [issue#72](https://github.com/TransferQueue/TransferQueue/issues/72))
-- [x] Implement `AsyncSimpleStorageManager` as the default storage backend based on the `TransferQueueStorageManager` abstraction
-- [x] Provide a `KVStorageManager` to cover all the KV based storage backends ([PR#96](https://github.com/TransferQueue/TransferQueue/pull/96))
-- [x] Support topic-based data partitioning to maintain train/val/test data simultaneously ([PR#98](https://github.com/TransferQueue/TransferQueue/pull/98))
-- [x] Release the first stable version through PyPI
-- [ ] Support disaggregated framework (each rank retrieves its own data without going through a centralized node)
-- [ ] Provide a `StreamingDataLoader` interface for disaggregated framework
-- [ ] Support load-balancing and dynamic batching
-- [x] Support high-performance storage backends for RDMA transmission (e.g., [Mooncake Store](https://github.com/kvcache-ai/Mooncake), [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html)...)
-- [x] High-performance serialization and deserialization
-- [ ] More documentation, examples and tutorials
-
📑 Citation
Please kindly cite our paper if you find this repo is useful:
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
index 9f8f3a7..51c994b 100644
--- a/tutorial/02_kv_interface.py
+++ b/tutorial/02_kv_interface.py
@@ -200,7 +200,7 @@ def main():
Key Methods:
1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag
2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch
- 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields)
+ 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields)
4. (async_)kv_list - List keys and tags (metadata) in a partition
5. (async_)kv_clear - Remove key-value pairs from storage
@@ -216,9 +216,9 @@ def main():
- Integration with external ReplayBuffer/single-controller that manage sample dispatching
Limitations (vs low-level native APIs):
- - No built-in production/consumption tracking: Users have to manually check status through tags.
+ - No built-in production/consumption tracking: Users must manually check status via tags externally.
- No built-in Sampler support: Must implement data dispatch by ReplayBuffer or single-controller externally.
- - No fully streaming: Consumers must wait for single-controller to dispatch `keys`.
+ - Not fully streaming: Consumers must wait for the single-controller to dispatch `keys`.
"""
)
)
From 8e80d6df97961c97f0fe6ec31b6dc7acbde1b0be Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sun, 8 Feb 2026 16:11:13 +0800
Subject: [PATCH 20/34] fix comments
Signed-off-by: 0oshowero0
---
tests/test_controller.py | 4 +-
transfer_queue/client.py | 6 +--
transfer_queue/interface.py | 18 ++++-----
transfer_queue/utils/common.py | 68 +---------------------------------
4 files changed, 15 insertions(+), 81 deletions(-)
diff --git a/tests/test_controller.py b/tests/test_controller.py
index f3a5c26..243bc19 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -922,8 +922,8 @@ def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup):
# Verify custom_meta is preserved
all_custom_meta = retrieved_metadata.get_all_custom_meta()
assert len(all_custom_meta) == 2
- assert all_custom_meta[metadata.global_indexes[0]]["score"] == 0.9
- assert all_custom_meta[metadata.global_indexes[1]]["tag"] == "B"
+ assert all_custom_meta[0]["score"] == 0.9
+ assert all_custom_meta[1]["tag"] == "B"
print("✓ kv_retrieve_keys preserves custom_meta")
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 2af07a7..302215e 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -943,7 +943,7 @@ async def async_kv_retrieve_keys(
raise ValueError("Received an empty list as keys.")
# validate all the elements are str
if not all(isinstance(k, str) for k in keys):
- raise TypeError("Not all element in `keys` are strings.")
+ raise TypeError("Not all elements in `keys` are strings.")
else:
raise TypeError("Only string or list of strings are allowed as `keys`.")
@@ -1028,7 +1028,7 @@ async def async_kv_list(
f"{response_msg.body.get('message', 'Unknown error')}"
)
except Exception as e:
- raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
+ raise RuntimeError(f"[{self.client_id}]: Error in kv_list: {str(e)}") from e
def close(self) -> None:
"""Close the client and cleanup resources including storage manager."""
@@ -1475,7 +1475,7 @@ def kv_retrieve_keys(
def kv_list(
self,
partition_id: str,
- ) -> tuple[list[Optional[str]], list[Optional[dict]]]:
+ ) -> tuple[list[str], list[dict]]:
"""Synchronously retrieve keys and custom_meta from the controller for partition.
Args:
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 49b706f..905566e 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -717,13 +717,13 @@ def kv_put(
if isinstance(fields, dict):
# TODO: consider whether to support this...
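+ # Wrap a plain dict as a single-sample TensorDict: tensors gain a leading batch dimension, other values become NonTensorStack entries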
batch = {}
- for key, value in fields.items():
+ for field_name, value in fields.items():
if isinstance(value, torch.Tensor):
if value.is_nested:
raise ValueError("Please use (async)kv_batch_put for batch operation")
- batch[key] = value.unsqueeze(0)
+ batch[field_name] = value.unsqueeze(0)
else:
- batch[key] = NonTensorStack(value)
+ batch[field_name] = NonTensorStack(value)
fields = TensorDict(batch, batch_size=[1])
elif not isinstance(fields, TensorDict):
raise ValueError("field can only be dict or TensorDict")
@@ -768,7 +768,7 @@ def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: l
if fields is None and tags is None:
raise ValueError("Please provide at least one parameter of fields or tag.")
- if fields.batch_size[0] != len(keys):
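+ # fields may be None when only tags are updated, so validate batch_size only when data is provided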
+ if fields is not None and fields.batch_size[0] != len(keys):
raise ValueError(
f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
f"batch_size {fields.batch_size[0]}"
@@ -846,7 +846,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list
return data
-def kv_list(partition_id: str) -> tuple[list[Optional[str]], list[Optional[dict[str, Any]]]]:
+def kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]:
"""List all keys and their metadata in a partition.
Args:
@@ -956,13 +956,13 @@ async def async_kv_put(
if isinstance(fields, dict):
# TODO: consider whether to support this...
batch = {}
- for key, value in fields.items():
+ for field_name, value in fields.items():
if isinstance(value, torch.Tensor):
if value.is_nested:
raise ValueError("Please use (async)kv_batch_put for batch operation")
- batch[key] = value.unsqueeze(0)
+ batch[field_name] = value.unsqueeze(0)
else:
- batch[key] = NonTensorStack(value)
+ batch[field_name] = NonTensorStack(value)
fields = TensorDict(batch, batch_size=[1])
elif not isinstance(fields, TensorDict):
raise ValueError("field can only be dict or TensorDict")
@@ -1008,7 +1008,7 @@ async def async_kv_batch_put(
if fields is None and tags is None:
raise ValueError("Please provide at least one parameter of fields or tag.")
- if fields.batch_size[0] != len(keys):
+ if fields is not None and fields.batch_size[0] != len(keys):
raise ValueError(
f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
f"batch_size {fields.batch_size[0]}"
diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py
index 5192e9f..e25f6b0 100644
--- a/transfer_queue/utils/common.py
+++ b/transfer_queue/utils/common.py
@@ -16,13 +16,11 @@
import logging
import os
from contextlib import contextmanager
-from typing import Any, Optional
+from typing import Optional
-import numpy as np
import psutil
import ray
import torch
-from tensordict import NonTensorStack, TensorDict
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
@@ -100,67 +98,3 @@ def get_env_bool(env_key: str, default: bool = False) -> bool:
true_values = {"true", "1", "yes", "y", "on"}
return env_value_lower in true_values
-
-
-def dict_to_tensordict(data: dict[str, Any]) -> TensorDict:
- """
- Create a TensorDict from a dict of tensors and non_tensors.
- """
-
- batch = {}
-
- final_batch_size = None
- tensor_batch_size = None
- deterministic_tensor_batch_size = None
- deterministic_non_tensor_batch_size = None
-
- for key, val in data.items():
- if isinstance(val, torch.Tensor):
- if val.is_nested and val.layout == torch.strided:
- # must use unbind for strided nested tensor
- deterministic_tensor_batch_size = len(val.unbind())
- else:
- tensor_batch_size = val.shape[0]
- batch[key] = val
- elif isinstance(val, np.ndarray):
- batch[key] = val
- tensor_batch_size = val.shape[0]
- elif isinstance(val, str):
- batch[key] = val
- deterministic_non_tensor_batch_size = 1
- elif isinstance(val, list):
- batch[key] = NonTensorStack(*val)
- deterministic_non_tensor_batch_size = len(val)
- else:
- batch[key] = NonTensorStack(val)
- deterministic_non_tensor_batch_size = 1
-
- if deterministic_tensor_batch_size:
- if deterministic_non_tensor_batch_size:
- assert deterministic_non_tensor_batch_size == deterministic_tensor_batch_size
- if final_batch_size:
- assert final_batch_size == deterministic_tensor_batch_size
- else:
- final_batch_size = deterministic_tensor_batch_size
-
- if deterministic_non_tensor_batch_size:
- if deterministic_tensor_batch_size:
- assert deterministic_non_tensor_batch_size == tensor_batch_size
- if final_batch_size:
- assert final_batch_size == deterministic_non_tensor_batch_size
- else:
- final_batch_size = deterministic_non_tensor_batch_size
-
- if not final_batch_size:
- raise RuntimeError("Cannot correctly determine batch_size for input.")
-
- if tensor_batch_size:
- if tensor_batch_size != final_batch_size:
- assert final_batch_size == 1
- for k, v in batch.items():
- if isinstance(v, torch.Tensor):
- batch[k] = v.unsqueeze(0)
- elif isinstance(v, np.ndarray):
- batch[k] = np.expand_dims(v, 0)
-
- return TensorDict(batch, batch_size=[final_batch_size])
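With the generic `dict_to_tensordict` helper removed, the single-sample dict path now lives inline in `kv_put`. A minimal standalone sketch of that normalization, mirroring the interface.py hunk above (`normalize_single_sample` is a hypothetical name used only for illustration; torch and tensordict are assumed to be installed):

```python
# Sketch of kv_put's single-sample normalization: every field of one sample is
# promoted to a batch of size 1 so it fits a TensorDict with batch_size=[1].
import torch
from tensordict import NonTensorStack, TensorDict


def normalize_single_sample(fields: dict) -> TensorDict:
    batch = {}
    for field_name, value in fields.items():
        if isinstance(value, torch.Tensor):
            if value.is_nested:
                raise ValueError("Please use (async)kv_batch_put for batch operation")
            batch[field_name] = value.unsqueeze(0)  # single sample -> batch of 1
        else:
            batch[field_name] = NonTensorStack(value)  # wrap non-tensor payloads
    return TensorDict(batch, batch_size=[1])


sample = normalize_single_sample({"input_ids": torch.tensor([1, 2, 3]), "label": "positive"})
print(sample.batch_size)  # torch.Size([1])
```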
From b6b7d4b51c11d9b9f7a7728ea4b4daf6d8bad337 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sun, 8 Feb 2026 17:07:41 +0800
Subject: [PATCH 21/34] fix comment
Signed-off-by: 0oshowero0
---
transfer_queue/controller.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 538ba21..03e3631 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -545,6 +545,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
)
if mask:
+ if partition_global_index.numel() == 0:
+ empty_status = self.consumption_status[task_name].new_zeros(0)
+ return partition_global_index, empty_status
with self.data_status_lock:
self.ensure_samples_capacity(max(partition_global_index) + 1)
consumption_status = self.consumption_status[task_name][partition_global_index]
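The early return added above guards the empty-partition case: `max(partition_global_index) + 1` would raise on an empty index tensor. A standalone sketch of the behavior (`consumption_status_for` is a hypothetical helper used only to illustrate the guard):

```python
# Demonstrates the empty-index guard: return an empty status tensor instead of
# calling max() on an empty index tensor, which would raise an error.
import torch


def consumption_status_for(partition_global_index: torch.Tensor,
                           consumption_status: torch.Tensor) -> torch.Tensor:
    if partition_global_index.numel() == 0:
        return consumption_status.new_zeros(0)  # same dtype/device, zero length
    return consumption_status[partition_global_index]


status = torch.zeros(8, dtype=torch.int64)
print(consumption_status_for(torch.tensor([], dtype=torch.long), status))  # tensor([], dtype=torch.int64)
print(consumption_status_for(torch.tensor([1, 3]), status))                # tensor([0, 0])
```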
From d3dd0147523268d26b54e74ac358d12396fb2163 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sun, 8 Feb 2026 17:16:43 +0800
Subject: [PATCH 22/34] fix more minor comments
Signed-off-by: 0oshowero0
---
transfer_queue/client.py | 8 ++++----
transfer_queue/interface.py | 34 +++++++++++++++++++++-------------
2 files changed, 25 insertions(+), 17 deletions(-)
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 302215e..2a0f3f1 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -270,7 +270,7 @@ async def async_set_custom_meta(
Example:
>>> # Create batch with custom metadata
>>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
>>> asyncio.run(client.async_set_custom_meta(batch_meta))
"""
assert socket is not None
@@ -1193,8 +1193,8 @@ def set_custom_meta(self, metadata: BatchMeta) -> None:
Example:
>>> # Create batch with custom metadata
- >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
+ >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=2, ...)
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
>>> client.set_custom_meta(batch_meta)
"""
@@ -1271,7 +1271,7 @@ def get_data(self, metadata: BatchMeta) -> TensorDict:
- Requested data fields (e.g., "prompts", "attention_mask")
Example:
- >>> batch_meta = client.get_data(
+ >>> batch_meta = client.get_meta(
... data_fields=["prompts", "attention_mask"],
... batch_size=4,
... partition_id="train_0",
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 905566e..75fedbd 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -282,8 +282,8 @@ def set_custom_meta(metadata: BatchMeta) -> None:
>>> tq.init()
>>>
>>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
+ >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
>>> tq.set_custom_meta(batch_meta)
"""
tq_client = _maybe_create_transferqueue_client()
@@ -367,7 +367,7 @@ def get_data(metadata: BatchMeta) -> TensorDict:
>>> import transfer_queue as tq
>>> tq.init()
>>>
- >>> batch_meta = tq.get_data(
+ >>> batch_meta = tq.get_meta(
... data_fields=["prompts", "attention_mask"],
... batch_size=4,
... partition_id="train_0",
@@ -496,8 +496,8 @@ async def async_set_custom_meta(
>>> tq.init()
>>>
>>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
+ >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
>>> asyncio.run(tq.async_set_custom_meta(batch_meta))
"""
tq_client = _maybe_create_transferqueue_client()
@@ -664,7 +664,10 @@ def close():
# ==================== KV Interface API ====================
def kv_put(
- key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]]
+ key: str,
+ partition_id: str,
+ fields: Optional[TensorDict | dict[str, Any]] = None,
+ tag: Optional[dict[str, Any]] = None,
) -> None:
"""Put a single key-value pair to TransferQueue.
@@ -735,7 +738,9 @@ def kv_put(
tq_client.set_custom_meta(batch_meta)
-def kv_batch_put(keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]]) -> None:
+def kv_batch_put(
+ keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None
+) -> None:
"""Put multiple key-value pairs to TransferQueue in batch.
This method stores multiple key-value pairs in a single operation, which is more
@@ -819,7 +824,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list
>>> import transfer_queue as tq
>>> tq.init()
>>> # Get single key with all fields
- >>> data = tq.kv_batch_get(key="sample_1", partition_id="train")
+ >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train")
>>> # Get multiple keys with specific fields
>>> data = tq.kv_batch_get(
... keys=["sample_1", "sample_2"],
@@ -885,7 +890,7 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None:
>>> import transfer_queue as tq
>>> tq.init()
>>> # Clear single key
- >>> tq.kv_clear(key="sample_1", partition_id="train")
+ >>> tq.kv_clear(keys="sample_1", partition_id="train")
>>> # Clear multiple keys
>>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
"""
@@ -902,7 +907,10 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None:
# ==================== KV Interface API ====================
async def async_kv_put(
- key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]]
+ key: str,
+ partition_id: str,
+ fields: Optional[TensorDict | dict[str, Any]] = None,
+ tag: Optional[dict[str, Any]] = None,
) -> None:
"""Asynchronously put a single key-value pair to TransferQueue.
@@ -975,7 +983,7 @@ async def async_kv_put(
async def async_kv_batch_put(
- keys: list[str], partition_id: str, fields: TensorDict, tags: list[dict[str, Any]]
+ keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None
) -> None:
"""Asynchronously put multiple key-value pairs to TransferQueue in batch.
@@ -1061,7 +1069,7 @@ async def async_kv_batch_get(
>>> import transfer_queue as tq
>>> tq.init()
>>> # Get single key with all fields
- >>> data = await tq.async_kv_batch_get(key="sample_1", partition_id="train")
+ >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train")
>>> # Get multiple keys with specific fields
>>> data = await tq.async_kv_batch_get(
... keys=["sample_1", "sample_2"],
@@ -1127,7 +1135,7 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
>>> import transfer_queue as tq
>>> tq.init()
>>> # Clear single key
- >>> await tq.async_kv_clear(key="sample_1", partition_id="train")
+ >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
>>> # Clear multiple keys
>>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
"""
From 1eda4e931ffe0cef2874c431a9cdf0292db8b3de Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Sun, 8 Feb 2026 18:53:44 +0800
Subject: [PATCH 23/34] fix comments and add FIXME comments
Signed-off-by: 0oshowero0
---
README.md | 2 +-
transfer_queue/client.py | 8 ++++----
transfer_queue/interface.py | 8 ++++----
transfer_queue/storage/managers/base.py | 4 +++-
4 files changed, 12 insertions(+), 10 deletions(-)
diff --git a/README.md b/README.md
index ba9e54d..150c666 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo
🔄 Updates
- **Feb 8, 2026**: 🔥 The initialization and usage is greatly simplified by high-level APIs [PR#26](https://github.com/Ascend/TransferQueue/pull/26), [PR#28](https://github.com/Ascend/TransferQueue/pull/28). You can now use a Redis-style API to take advantage of most of the advanced features provided by TransferQueue!
- - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details.
+ - **Jan 28, 2026**: We experimentally introduce `StreamingDataLoader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details.
- **Dec 30, 2025**: **TransferQueue x verl** integration is tested with the DAPO algorithm at scale **(64 nodes, 1024 cards)**. It significantly optimizes host memory utilization and accelerates data transfers. Stay tuned for more details!
- **Dec 20, 2025**: 🔥 The official [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) is released! Feel free to check it out.
- **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data.
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 2a0f3f1..470d57d 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -260,8 +260,8 @@ async def async_set_custom_meta(
Args:
metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
socket: ZMQ async socket for message transmission (injected by decorator)
Raises:
@@ -1185,8 +1185,8 @@ def set_custom_meta(self, metadata: BatchMeta) -> None:
Args:
metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
Raises:
RuntimeError: If communication fails or controller returns error response
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 75fedbd..42e8839 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -271,8 +271,8 @@ def set_custom_meta(metadata: BatchMeta) -> None:
Args:
metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
Raises:
RuntimeError: If communication fails or controller returns error response
@@ -484,8 +484,8 @@ async def async_set_custom_meta(
Args:
metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
socket: ZMQ async socket for message transmission (injected by decorator)
Raises:
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py
index 258dd91..c0c5058 100644
--- a/transfer_queue/storage/managers/base.py
+++ b/transfer_queue/storage/managers/base.py
@@ -589,7 +589,9 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
for global_idx in metadata.global_indexes:
per_field_custom_backend_meta[global_idx] = {}
- # TODO(tianyi): the order of custom meta is coupled with keys/values
+ # FIXME(tianyi): the order of custom backend meta is coupled with keys/values
+ # FIXME: if put_data is called to partially update/add new fields, the current
+            # implementation will cause custom_backend_meta loss or mismatch!
for (field_name, global_idx), meta_value in zip(
itertools.product(sorted(metadata.field_names), metadata.global_indexes),
custom_backend_meta,
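A standalone sketch of the ordering coupling the FIXME above refers to: the flat `custom_backend_meta` list is consumed in the order produced by `itertools.product(sorted(field_names), global_indexes)`, so the backend must emit its metadata in exactly that enumeration (field names, indexes, and metadata values below are made up):

```python
# Reproduces the zip/product pattern from put_data to show how the flat
# metadata list maps onto (field, sample) pairs.
import itertools

field_names = ["response", "input_ids"]         # hypothetical field names
global_indexes = [7, 8]                          # hypothetical sample indexes
custom_backend_meta = ["m0", "m1", "m2", "m3"]   # one entry per (field, sample)

per_field_meta = {idx: {} for idx in global_indexes}
for (field_name, global_idx), meta_value in zip(
    itertools.product(sorted(field_names), global_indexes), custom_backend_meta
):
    per_field_meta[global_idx][field_name] = meta_value

print(per_field_meta)
# {7: {'input_ids': 'm0', 'response': 'm2'}, 8: {'input_ids': 'm1', 'response': 'm3'}}
```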
From 620edfe0bd0922aa6d500b04d90d97e34f6c03bb Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 13:00:34 +0800
Subject: [PATCH 24/34] improve kv_list
Signed-off-by: 0oshowero0
---
tests/e2e/test_kv_interface_e2e.py | 97 ++++++++++++++++++++++++------
tests/test_client.py | 48 ++++++---------
transfer_queue/client.py | 57 ++++++++++++------
transfer_queue/controller.py | 28 ++++++---
transfer_queue/interface.py | 77 +++++++++++++++++-------
tutorial/02_kv_interface.py | 14 +++--
6 files changed, 218 insertions(+), 103 deletions(-)
diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py
index 3918bbe..71e9833 100644
--- a/tests/e2e/test_kv_interface_e2e.py
+++ b/tests/e2e/test_kv_interface_e2e.py
@@ -69,7 +69,9 @@ def cleanup_partition(controller):
"""Cleanup partition after each test."""
yield
try:
- ray.get(controller.clear_partition.remote("test_partition"))
+ partition_ids = ray.get(controller.list_partitions.remote())
+ for partition_id in partition_ids:
+ ray.get(controller.clear_partition.remote(partition_id))
except Exception:
pass
@@ -409,8 +411,8 @@ def test_kv_batch_get_nonexistent_key(self, controller):
class TestKVListE2E:
"""End-to-end tests for kv_list functionality."""
- def test_kv_list_all_keys(self, controller):
- """Test listing all keys in a partition."""
+ def test_kv_list_single_partition(self, controller):
+        """Test listing all keys and tags in a single partition."""
partition_id = "test_partition"
keys = ["list_0", "list_1", "list_2"]
@@ -418,24 +420,81 @@ def test_kv_list_all_keys(self, controller):
tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i})
# List all keys
- listed_keys, tags = tq.kv_list(partition_id=partition_id)
+ partition_info = tq.kv_list(partition_id=partition_id)
- assert len(listed_keys) == 3
+ assert len(partition_info.keys()) == 1
+ assert "test_partition" in partition_info.keys()
+ assert len(partition_info["test_partition"]) == 3
for key in keys:
- assert key in listed_keys
+ assert key in partition_info["test_partition"]
# Verify tags match
- for i, (key, tag) in enumerate(zip(listed_keys, tags, strict=False)):
+ for i, (key, tag) in enumerate(partition_info["test_partition"].items()):
assert tag["id"] == i
+ def test_kv_list_all_partitions(self, controller):
+ """Test listing keys and tags in all partitions."""
+ partition_id = ["test_partition0", "test_partition1", "test_partition2"]
+
+ keys_partition0 = ["list_0", "list_1", "list_2"]
+ keys_partition1 = ["list_0", "list_1", "list_2"] # deliberately set same keys
+ keys_partition2 = ["list_3", "list_4", "list_5", "list_6"]
+
+ fields_partition0 = TensorDict({"data": torch.tensor([[0], [1], [2]])}, batch_size=3)
+ fields_partition1 = TensorDict({"data": torch.tensor([[3], [4], [5]])}, batch_size=3)
+ fields_partition2 = TensorDict({"data": torch.tensor([[6], [7], [8], [9]])}, batch_size=4)
+
+ tags_partition0 = [{"id": i} for i in range(3)]
+ tags_partition1 = [{"id": i + 3} for i in range(3)]
+ tags_partition2 = [{"id": i + 6} for i in range(4)]
+
+ # Put to TQ
+ tq.kv_batch_put(
+ keys=keys_partition0, partition_id=partition_id[0], fields=fields_partition0, tags=tags_partition0
+ )
+ tq.kv_batch_put(
+ keys=keys_partition1, partition_id=partition_id[1], fields=fields_partition1, tags=tags_partition1
+ )
+ tq.kv_batch_put(
+ keys=keys_partition2, partition_id=partition_id[2], fields=fields_partition2, tags=tags_partition2
+ )
+
+ # List all keys
+ partition_info = tq.kv_list()
+
+        # Verify all partitions exist
+ assert len(partition_info.keys()) == 3
+ assert "test_partition0" in partition_info.keys()
+ assert "test_partition1" in partition_info.keys()
+ assert "test_partition2" in partition_info.keys()
+
+ assert len(partition_info["test_partition0"]) == 3
+ for key in keys_partition0:
+ assert key in partition_info["test_partition0"]
+
+ assert len(partition_info["test_partition1"]) == 3
+ for key in keys_partition1:
+ assert key in partition_info["test_partition1"]
+
+ assert len(partition_info["test_partition2"]) == 4
+ for key in keys_partition2:
+ assert key in partition_info["test_partition2"]
+
+ # Verify tags match
+ for i, (key, tag) in enumerate(partition_info["test_partition0"].items()):
+ assert tag["id"] == i
+ for i, (key, tag) in enumerate(partition_info["test_partition1"].items()):
+ assert tag["id"] == i + 3
+ for i, (key, tag) in enumerate(partition_info["test_partition2"].items()):
+ assert tag["id"] == i + 6
+
def test_kv_list_empty_partition(self):
"""Test listing empty partition."""
partition_id = "test_partition_empty"
- keys, tags = tq.kv_list(partition_id=partition_id)
+ partition_info = tq.kv_list(partition_id=partition_id)
- assert len(keys) == 0
- assert len(tags) == 0
+ assert len(partition_info) == 0
class TestKVClearE2E:
@@ -454,9 +513,9 @@ def test_kv_clear_single_key(self, controller):
tq.kv_clear(keys=key, partition_id=partition_id)
# Verify via kv_list
- listed_keys, _ = tq.kv_list(partition_id=partition_id)
- assert key not in listed_keys
- assert other_key in listed_keys
+ partition_info = tq.kv_list(partition_id=partition_id)
+ assert key not in partition_info[partition_id]
+ assert other_key in partition_info[partition_id]
# Verify via controller - key should be removed
partition = get_controller_partition(controller, partition_id)
@@ -475,12 +534,12 @@ def test_kv_clear_multiple_keys(self, controller):
tq.kv_clear(keys=keys[:2], partition_id=partition_id)
# Verify
- listed_keys, _ = tq.kv_list(partition_id=partition_id)
- assert len(listed_keys) == 2
- assert keys[0] not in listed_keys
- assert keys[1] not in listed_keys
- assert keys[2] in listed_keys
- assert keys[3] in listed_keys
+ partition_info = tq.kv_list(partition_id=partition_id)
+ assert len(partition_info[partition_id]) == 2
+ assert keys[0] not in partition_info[partition_id]
+ assert keys[1] not in partition_info[partition_id]
+ assert keys[2] in partition_info[partition_id]
+ assert keys[3] in partition_info[partition_id]
class TestKVE2ECornerCases:
diff --git a/tests/test_client.py b/tests/test_client.py
index 5d3360f..5d308d8 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -239,7 +239,7 @@ def _mock_kv_retrieve_keys(self, request_body):
def _mock_kv_list(self, request_body):
"""Mock KV list response."""
- partition_id = request_body.get("partition_id", "")
+ partition_id = request_body.get("partition_id", None)
# Initialize key tracking if not exists
if not hasattr(self, "_kv_partition_keys"):
@@ -247,7 +247,8 @@ def _mock_kv_list(self, request_body):
# Return cached keys for this partition
keys = self._kv_partition_keys.get(partition_id, [])
- return {"keys": keys, "custom_meta": [{} for _ in range(len(keys))]}
+
+ return {"partition_info": {partition_id: {k: {} for k in keys}}, "message": "success"}
def _get_next_kv_index(self, partition_id):
"""Get next available index for KV keys in partition."""
@@ -1041,18 +1042,6 @@ async def test_async_kv_retrieve_keys_invalid_keys_type(self, client_setup):
create=True,
)
- @pytest.mark.asyncio
- async def test_async_kv_list_empty_partition(self, client_setup):
- """Test async_kv_list with empty partition."""
- client, _, _ = client_setup
-
- # Test async_kv_list with empty partition
- keys, custom_meta = await client.async_kv_list(partition_id="empty_partition")
-
- # Should return empty list for partition with no keys
- assert keys == []
- assert custom_meta == []
-
@pytest.mark.asyncio
async def test_async_kv_list_with_keys(self, client_setup):
"""Test async_kv_list returns keys after they are registered."""
@@ -1066,14 +1055,13 @@ async def test_async_kv_list_with_keys(self, client_setup):
)
# Then list them
- keys, custom_meta = await client.async_kv_list(partition_id="kv_partition")
+ partition_info = await client.async_kv_list(partition_id="kv_partition")
# Verify keys are returned
- assert keys is not None
- assert len(keys) >= 2
- assert "key_1" in keys
- assert "key_2" in keys
- assert custom_meta == [{}, {}]
+ assert len(partition_info["kv_partition"]) >= 2
+ assert "key_1" in partition_info["kv_partition"]
+ assert "key_2" in partition_info["kv_partition"]
+ assert list(partition_info["kv_partition"].values()) == [{}, {}]
@pytest.mark.asyncio
async def test_async_kv_list_multiple_partitions(self, client_setup):
@@ -1093,16 +1081,20 @@ async def test_async_kv_list_multiple_partitions(self, client_setup):
)
# List keys for each partition
- keys_a, custom_meta_a = await client.async_kv_list(partition_id="partition_a")
- keys_b, custom_meta_b = await client.async_kv_list(partition_id="partition_b")
+ partition_a = await client.async_kv_list(partition_id="partition_a")
+ partition_b = await client.async_kv_list(partition_id="partition_b")
# Verify keys are isolated per partition
- assert "partition_a_key" in keys_a
- assert "partition_b_key" not in keys_a
- assert "partition_b_key" in keys_b
- assert "partition_a_key" not in keys_b
- assert custom_meta_a == [{}]
- assert custom_meta_b == [{}]
+ assert "partition_a" in partition_a
+ assert "partition_b" in partition_b
+ assert "partition_a" not in partition_b
+ assert "partition_b" not in partition_a
+ assert "partition_a_key" in partition_a["partition_a"]
+ assert "partition_b_key" not in partition_a["partition_a"]
+ assert "partition_b_key" in partition_b["partition_b"]
+ assert "partition_a_key" not in partition_b["partition_b"]
+ assert list(partition_a["partition_a"].values()) == [{}]
+ assert list(partition_b["partition_b"].values()) == [{}]
def test_kv_retrieve_keys_type_validation(self, client_setup):
"""Test synchronous kv_retrieve_keys type validation."""
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 470d57d..5b0c1e1 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -983,23 +983,32 @@ async def async_kv_retrieve_keys(
@dynamic_socket(socket_name="request_handle_socket")
async def async_kv_list(
self,
- partition_id: str,
+ partition_id: Optional[str] = None,
socket: Optional[zmq.asyncio.Socket] = None,
- ) -> tuple[list[str], list[dict]]:
- """Asynchronously retrieve keys and custom_meta from the controller for partition.
+ ) -> dict[str, dict[str, Any]]:
+ """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions.
Args:
- partition_id: Partition to retrieve from the controller
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
socket: ZMQ socket (injected by decorator)
Returns:
- keys: list of keys in the partition
- custom_meta: list of dict for custom_meta
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+                        "tag1": tag_value,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
"""
- if partition_id is None:
- return [], []
-
request_msg = ZMQMessage.create(
request_type=ZMQRequestType.KV_LIST, # type: ignore[arg-type]
sender_id=self.client_id,
@@ -1019,9 +1028,8 @@ async def async_kv_list(
)
if response_msg.request_type == ZMQRequestType.KV_LIST_RESPONSE:
- keys = response_msg.body.get("keys", [])
- custom_meta = response_msg.body.get("custom_meta", [])
- return keys, custom_meta
+ partition_info = response_msg.body.get("partition_info", {})
+ return partition_info
else:
raise RuntimeError(
f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: "
@@ -1474,16 +1482,29 @@ def kv_retrieve_keys(
def kv_list(
self,
- partition_id: str,
- ) -> tuple[list[str], list[dict]]:
- """Synchronously retrieve keys and custom_meta from the controller for partition.
+ partition_id: Optional[str] = None,
+ ) -> dict[str, dict[str, Any]]:
+ """Synchronously retrieve keys and custom_meta from the controller for one or all partitions.
Args:
- partition_id: Partition to retrieve from the controller
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
Returns:
- keys: list of keys in the partition
- custom_meta: list of dict for custom_meta
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+                        "tag1": tag_value,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
"""
return self._kv_list(partition_id=partition_id)
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 03e3631..3dcc1b5 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -1843,22 +1843,30 @@ def _process_request(self):
with perf_monitor.measure(op_type="KV_LIST"):
params = request_msg.body
partition_id = params["partition_id"]
- partition = self._get_partition(partition_id)
- if not partition:
- keys = []
- custom_meta = []
- message = f"Partition {partition_id} not found for kv_list."
- logger.debug(f"[{self.controller_id}]: {message}")
+ if partition_id is None:
+ partition_id = list(self.partitions.keys())
else:
- keys = list(partition.keys_mapping.keys())
- custom_meta = [partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys]
- message = "Success"
+ partition_id = [partition_id]
+
+ message = "success"
+ partition_info = {}
+ for pid in partition_id:
+ partition = self._get_partition(pid)
+ if partition:
+ keys = list(partition.keys_mapping.keys())
+ single_partition_info = {
+ k: partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys
+ }
+ partition_info[pid] = single_partition_info
+ else:
+ # this only happens when params["partition_id"] is not None
+ message = f"partition {pid} does not exist"
response_msg = ZMQMessage.create(
request_type=ZMQRequestType.KV_LIST_RESPONSE,
sender_id=self.controller_id,
receiver_id=request_msg.sender_id,
- body={"keys": keys, "custom_meta": custom_meta, "message": message},
+ body={"partition_info": partition_info, "message": message},
)
self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 42e8839..32670de 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -851,29 +851,45 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list
return data
-def kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]:
- """List all keys and their metadata in a partition.
+def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
+ """List all keys and their metadata in one or all partitions.
Args:
- partition_id: Partition to list keys from
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
Returns:
- Tuple of:
- - List of keys in the partition
- - List of custom metadata (tags) associated with each key
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+                    "tag1": tag_value,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>> keys, tags = tq.kv_list(partition_id="train")
- >>> print(f"Keys: {keys}")
- >>> print(f"Tags: {tags}")
+ >>> # Case 1: Retrieve a specific partition
+ >>> partitions = tq.kv_list(partition_id="train")
+ >>> print(f"Keys: {list(partitions['train'].keys())}")
+ >>> print(f"Tags: {list(partitions['train'].values())}")
+ >>> # Case 2: Retrieve all partitions
+ >>> all_partitions = tq.kv_list()
+ >>> for pid, keys in all_partitions.items():
+    ...     print(f"Partition: {pid}, Key count: {len(keys)}")
"""
tq_client = _maybe_create_transferqueue_client()
- keys, custom_meta = tq_client.kv_list(partition_id)
+ partition_info = tq_client.kv_list(partition_id)
- return keys, custom_meta
+ return partition_info
def kv_clear(keys: list[str] | str, partition_id: str) -> None:
@@ -1096,29 +1112,46 @@ async def async_kv_batch_get(
return data
-async def async_kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]:
- """Asynchronously list all keys and their metadata in a partition.
+async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
+ """Asynchronously list all keys and their metadata in one or all partitions.
Args:
- partition_id: Partition to list keys from
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
Returns:
- Tuple of:
- - List of keys in the partition
- - List of custom metadata (tags) associated with each key
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+                    "tag1": tag_value,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
+
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>> keys, tags = await tq.async_kv_list(partition_id="train")
- >>> print(f"Keys: {keys}")
- >>> print(f"Tags: {tags}")
+ >>> # Case 1: Retrieve a specific partition
+ >>> partitions = await tq.async_kv_list(partition_id="train")
+ >>> print(f"Keys: {list(partitions['train'].keys())}")
+ >>> print(f"Tags: {list(partitions['train'].values())}")
+ >>> # Case 2: Retrieve all partitions
+ >>> all_partitions = await tq.async_kv_list()
+ >>> for pid, keys in all_partitions.items():
+    ...     print(f"Partition: {pid}, Key count: {len(keys)}")
"""
tq_client = _maybe_create_transferqueue_client()
- keys, custom_meta = await tq_client.async_kv_list(partition_id)
+ partition_info = await tq_client.async_kv_list(partition_id)
- return keys, custom_meta
+ return partition_info
async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
index 51c994b..376ebfb 100644
--- a/tutorial/02_kv_interface.py
+++ b/tutorial/02_kv_interface.py
@@ -156,15 +156,17 @@ def demonstrate_kv_api():
# Step 5: List all keys and tags in a partition
print("\n[Step 5] Listing all keys and tags in partition...")
- all_keys, all_tags = tq.kv_list(partition_id=partition_id)
- print(f" Found {len(all_keys)} keys in partition '{partition_id}':")
- for k, t in zip(all_keys, all_tags, strict=False):
- print(f" - key='{k}' | tag={t}")
+ partition_info = tq.kv_list()
+ print(f" Found {len(partition_info.keys())} partitions: '{list(partition_info.keys())}'")
+ for pid, keys_and_tags in partition_info.items():
+ for k, t in keys_and_tags.items():
+            print(f"    Partition: {pid} - key='{k}' | tag={t}")
# Step 6: Retrieve specific fields using kv_batch_get
print("\n[Step 6] Retrieving specific fields (Column) with kv_batch_get...")
print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').")
+ all_keys = list(partition_info[partition_id].keys())
retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, fields="input_ids")
print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.")
@@ -183,8 +185,8 @@ def demonstrate_kv_api():
tq.kv_clear(keys=keys_to_clear, partition_id=partition_id)
print(f" ✓ Cleared keys: {keys_to_clear}")
- remaining_keys, _ = tq.kv_list(partition_id=partition_id)
- print(f" Remaining keys in partition: {remaining_keys}")
+ partition_info_after_clear = tq.kv_list(partition_id=partition_id)
+ print(f" Remaining keys in partition: {list(partition_info_after_clear[partition_id].keys())}")
def main():
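A consumption sketch for the reworked `kv_list` return value (a nested `{partition_id: {key: tag_dict}}` mapping), mirroring the tutorial and docstring changes above; it assumes `tq.init()` has been called and that some keys were previously stored in a partition named `train`:

```python
import transfer_queue as tq

# single partition: {"train": {key: tag_dict, ...}} (empty mapping if it does not exist)
train_info = tq.kv_list(partition_id="train")
for key, tag in train_info.get("train", {}).items():
    print(f"key={key} tag={tag}")

# all partitions: {partition_id: {key: tag_dict, ...}, ...}
all_info = tq.kv_list()
for pid, keys_and_tags in all_info.items():
    print(f"partition={pid} key_count={len(keys_and_tags)}")
```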
From 3678eac0f9a9e54d4293a9f65303e73fcbe01f18 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 15:01:23 +0800
Subject: [PATCH 25/34] provide KVBatchMeta
Signed-off-by: 0oshowero0
---
transfer_queue/__init__.py | 29 +++---
transfer_queue/metadata.py | 202 +++++++++++++++++++++++++++++++++++++
2 files changed, 217 insertions(+), 14 deletions(-)
diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py
index 26f1225..60c8f1d 100644
--- a/transfer_queue/__init__.py
+++ b/transfer_queue/__init__.py
@@ -44,7 +44,7 @@
put,
set_custom_meta,
)
-from .metadata import BatchMeta
+from .metadata import BatchMeta, KVBatchMeta
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
from .sampler.rank_aware_sampler import RankAwareSampler
@@ -55,18 +55,6 @@
__all__ = [
"init",
- "get_meta",
- "get_data",
- "put",
- "set_custom_meta",
- "clear_samples",
- "clear_partition",
- "async_get_meta",
- "async_get_data",
- "async_put",
- "async_set_custom_meta",
- "async_clear_samples",
- "async_clear_partition",
"close",
"kv_put",
"kv_batch_put",
@@ -78,11 +66,24 @@
"async_kv_batch_get",
"async_kv_list",
"async_kv_clear",
+ "KVBatchMeta",
] + [
+ "get_meta",
+ "get_data",
+ "put",
+ "set_custom_meta",
+ "clear_samples",
+ "clear_partition",
+ "async_get_meta",
+ "async_get_data",
+ "async_put",
+ "async_set_custom_meta",
+ "async_clear_samples",
+ "async_clear_partition",
+ "BatchMeta",
"TransferQueueClient",
"StreamingDataset",
"StreamingDataLoader",
- "BatchMeta",
"TransferQueueController",
"SimpleStorageUnit",
"ZMQServerInfo",
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index 2abde1b..eb2520c 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -829,3 +829,205 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) ->
]
return all_fields
+
+
+# ==================== KV Interface Metadata ====================
+@dataclass
+class KVBatchMeta:
+    """Records the metadata for the KV interface."""
+
+ # keys of each sample
+ keys: list[str] = dataclasses.field(default_factory=list)
+
+ # sample-level tags
+ tags: list[dict] = dataclasses.field(default_factory=list)
+
+ # [optional] partition_id of this batch
+ partition_id: Optional[str] = None
+
+ # [optional] fields of each sample
+ fields: list[str] = dataclasses.field(default_factory=list)
+
+ # [optional] external information for batch-level information
+ extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ def __post_init__(self):
+ """Validate all the variables"""
+ if len(self.keys) != len(self.tags):
+ raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}")
+ if len(self.keys) != len(self.partition_id):
+ raise ValueError(
+ f"keys and partition_ids must have same length, but got {len(self.keys)} and {len(self.partition_ids)}"
+ )
+ if len(self.keys) != len(set(self.keys)):
+ raise ValueError("Got duplicated keys.")
+ if len(self.fields) != len(set(self.fields)):
+ raise ValueError("Got duplicated fields.")
+
+ # deepcopy to prevent unexpected behavior after chunk/concat
+ self.tags = copy.deepcopy(self.tags)
+ self.extra_info = copy.deepcopy(self.extra_info)
+
+ object.__setattr__(self, "_size", len(self.keys))
+
+ @property
+ def size(self) -> int:
+ """Return the number of samples in this batch"""
+ return getattr(self, "_size", 0)
+
+ def __len__(self) -> int:
+ """Return the number of samples in this batch."""
+ return len(self.keys)
+
+ def __str__(self):
+        return f"KVBatchMeta(size={self.size}, fields={self.fields}, extra_info={self.extra_info})"
+
+ def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta":
+ """
+ Select specific keys from this batch.
+ This will construct a new KVBatchMeta instance containing only the specified keys.
+
+ Args:
+ keys_to_select (list[str]): List of keys to retain.
+
+ Returns:
+ KVBatchMeta: A new KVBatchMeta instance containing only the specified keys.
+
+ Raises:
+ ValueError: If duplicate keys exist in input param `keys_to_select`.
+ RuntimeError: If `keys_to_select` contains keys that do not exist in this batch.
+ """
+
+ if len(set(keys_to_select)) != len(keys_to_select):
+ raise ValueError("Contain duplicate keys.")
+
+ non_exist_keys = set(keys_to_select) - set(self.keys)
+ if len(non_exist_keys) > 0:
+ raise RuntimeError(f"Keys {non_exist_keys} not found in current batch.")
+
+ _keys_to_idx = {key: idx for idx, key in enumerate(self.keys)}
+
+ loc_idx = [_keys_to_idx[k] for k in keys_to_select]
+ tags = [self.tags[i] for i in loc_idx]
+
+ return KVBatchMeta(
+ keys=keys_to_select,
+ tags=tags,
+ partition_id=self.partition_id,
+ fields=self.fields,
+ extra_info=self.extra_info,
+ )
+
+ def reorder(self, indexes: list[int]):
+ """
+ Reorder the samples in this batch according to the specified indexes.
+
+ The operation is performed in-place.
+
+ Args:
+ indexes : list[int]
+                A list of integers specifying the new order of the samples.
+
+ Raises:
+ ValueError: If the size of input `indexes` does not match with the batch size.
+ ValueError: If duplicate indexes exist in input param `indexes`.
+ """
+ if len(indexes) != self.size:
+ raise ValueError(
+ f"Attempted to reorder with indexes length {len(indexes)} that does not match "
+ f"the batch size {self.size}."
+ )
+
+ if len(set(indexes)) != len(indexes):
+ raise ValueError("Contain duplicate indexes.")
+
+ self.keys = [self.keys[i] for i in indexes]
+ self.tags = [self.tags[i] for i in indexes]
+
+ def chunk(self, chunks: int) -> list["KVBatchMeta"]:
+ """
+ Split this batch into smaller chunks.
+
+ Args:
+ chunks: number of chunks
+
+        Returns:
+ List of smaller KVBatchMeta chunks
+ """
+
+ chunk_list = []
+ if self.size < chunks:
+ logger.warning(
+                f"Number of chunks {chunks} > number of samples in this batch {self.size}; this will return some "
+ f"empty KVBatchMeta chunks."
+ )
+
+ # Calculate the base size and remainder of each chunk
+ base_size = self.size // chunks
+ remainder = self.size % chunks
+
+ start = 0
+ for i in range(chunks):
+            # Calculate the size of the current chunk (the first `remainder` chunks are one larger than the base size)
+ current_chunk_size = base_size + 1 if i < remainder else base_size
+ end = start + current_chunk_size
+ chunk_keys = self.keys[start:end]
+ chunk_tags = self.tags[start:end]
+
+ chunk = KVBatchMeta(
+ keys=chunk_keys,
+ tags=chunk_tags,
+ partition_id=self.partition_id,
+ fields=self.fields,
+ extra_info=self.extra_info,
+ )
+ chunk_list.append(chunk)
+ start = end
+
+ return chunk_list
+
+ @classmethod
+ def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta":
+ """
+ Concatenate multiple KVBatchMeta chunks into one large batch.
+
+ Args:
+ data: List of KVBatchMeta chunks to concatenate
+
+ Returns:
+ Concatenated KVBatchMeta
+
+ Raises:
+ ValueError: If validation fails (e.g., field names do not match)
+ """
+ if not data:
+            logger.warning("Attempted to concat an empty list of KVBatchMeta chunks. Returning empty KVBatchMeta.")
+ return KVBatchMeta()
+
+ # skip empty chunks
+ data = [chunk for chunk in data if chunk and chunk.size > 0]
+
+ if len(data) == 0:
+ logger.warning("No valid KVBatchMeta chunks to concatenate. Returning empty KVBatchMeta.")
+ return KVBatchMeta()
+
+ base_fields = data[0].fields
+ base_fields_set = set(base_fields)
+ base_partition_id = data[0].partition_id
+
+ all_keys = []
+ all_tags = []
+ all_extra_info = {}
+ for chunk in data:
+ if set(chunk.fields) != base_fields_set:
+ raise ValueError("Field names do not match for concatenation.")
+ if chunk.partition_id != base_partition_id:
+                raise ValueError("Partitions do not match for concatenation.")
+
+ all_keys.extend(chunk.keys)
+ all_tags.extend(chunk.tags)
+ all_extra_info.update(chunk.extra_info)
+
+ return KVBatchMeta(
+ keys=all_keys, tags=all_tags, partition_id=base_partition_id, fields=base_fields, extra_info=all_extra_info
+ )
From 42a78cf5f2e27bbffe454a0c526ea7bc31620071 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 15:18:38 +0800
Subject: [PATCH 26/34] add KVBatchMeta tests
Signed-off-by: 0oshowero0
---
tests/test_metadata.py | 289 ++++++++++++++++++++++++++++++++++++-
transfer_queue/metadata.py | 4 -
2 files changed, 288 insertions(+), 5 deletions(-)
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 6db10a6..23ba56d 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -27,7 +27,7 @@
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
+from transfer_queue.metadata import BatchMeta, FieldMeta, KVBatchMeta, SampleMeta # noqa: E402
from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
@@ -1045,3 +1045,290 @@ def test_batch_meta_concat_validation_error(self):
with pytest.raises(ValueError) as exc_info:
BatchMeta.concat([batch1, batch2], validate=True)
assert "Field names do not match" in str(exc_info.value)
+
+
+class TestKVBatchMeta:
+ """KVBatchMeta Tests"""
+
+ def test_kv_batch_meta_basic_init(self):
+ """Example: Basic KVBatchMeta initialization."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{"sample_id": 0}, {"sample_id": 1}, {"sample_id": 2}],
+ partition_id="partition_0",
+ fields=["field1", "field2"],
+ )
+
+ assert kv_meta.size == 3
+ assert len(kv_meta) == 3
+ assert kv_meta.keys == ["key1", "key2", "key3"]
+ assert kv_meta.partition_id == "partition_0"
+ assert kv_meta.fields == ["field1", "field2"]
+
+ def test_kv_batch_meta_empty_init(self):
+ """Example: Empty KVBatchMeta initialization."""
+ kv_meta = KVBatchMeta()
+
+ assert kv_meta.size == 0
+ assert len(kv_meta) == 0
+ assert kv_meta.keys == []
+ assert kv_meta.tags == []
+ assert kv_meta.partition_id is None
+ assert kv_meta.fields == []
+
+ def test_kv_batch_meta_init_validation_keys_tags_mismatch(self):
+ """Example: Init validation catches keys and tags length mismatch."""
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta(
+ keys=["key1", "key2"],
+ tags=[{"sample_id": 0}], # Only one tag
+ )
+ assert "keys and tags must have same length" in str(exc_info.value)
+
+ def test_kv_batch_meta_init_validation_duplicate_keys(self):
+ """Example: Init validation catches duplicate keys."""
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta(
+ keys=["key1", "key1"],
+ tags=[{"sample_id": 0}, {"sample_id": 1}],
+ partition_id="partition_0",
+ )
+ assert "Got duplicated keys" in str(exc_info.value)
+
+ def test_kv_batch_meta_init_validation_duplicate_fields(self):
+ """Example: Init validation catches duplicate fields."""
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta(
+ keys=["key1"],
+ tags=[{"sample_id": 0}],
+ partition_id="partition_0",
+ fields=["field1", "field1"],
+ )
+ assert "Got duplicated fields" in str(exc_info.value)
+
+ def test_kv_batch_meta_select_keys(self):
+ """Example: Select specific keys from KVBatchMeta."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}],
+ partition_id="partition_0",
+ fields=["field1", "field2"],
+ extra_info={"test": "value"},
+ )
+
+ selected = kv_meta.select_keys(["key1", "key3"])
+
+ assert selected.keys == ["key1", "key3"]
+ assert selected.tags == [{"idx": 0}, {"idx": 2}]
+ assert selected.partition_id == "partition_0"
+ assert selected.fields == ["field1", "field2"]
+ assert selected.extra_info == {"test": "value"}
+
+ def test_kv_batch_meta_select_keys_validation_duplicate(self):
+ """Example: Select keys validation catches duplicate keys in input."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ kv_meta.select_keys(["key1", "key1"])
+ assert "Contain duplicate keys" in str(exc_info.value)
+
+ def test_kv_batch_meta_select_keys_validation_nonexistent(self):
+ """Example: Select keys validation catches non-existent keys."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(RuntimeError) as exc_info:
+ kv_meta.select_keys(["key1", "nonexistent"])
+ assert "not found in current batch" in str(exc_info.value)
+
+ def test_kv_batch_meta_reorder(self):
+ """Example: Reorder samples in KVBatchMeta."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}],
+ )
+
+ kv_meta.reorder([2, 0, 1])
+
+ assert kv_meta.keys == ["key3", "key1", "key2"]
+ assert kv_meta.tags == [{"idx": 2}, {"idx": 0}, {"idx": 1}]
+
+ def test_kv_batch_meta_reorder_validation_size_mismatch(self):
+ """Example: Reorder validation catches size mismatch."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ kv_meta.reorder([0, 1]) # Only 2 indexes for 3 samples
+ assert "does not match" in str(exc_info.value)
+
+ def test_kv_batch_meta_reorder_validation_duplicate_indexes(self):
+ """Example: Reorder validation catches duplicate indexes."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ kv_meta.reorder([0, 0, 1]) # Duplicate index 0
+ assert "Contain duplicate indexes" in str(exc_info.value)
+
+ def test_kv_batch_meta_chunk(self):
+ """Example: Split KVBatchMeta into multiple chunks."""
+ kv_meta = KVBatchMeta(
+ keys=[f"key{i}" for i in range(10)],
+ tags=[{"idx": i} for i in range(10)],
+ partition_id="partition_0",
+ fields=["field1"],
+ extra_info={"test": "value"},
+ )
+
+ chunks = kv_meta.chunk(3)
+
+ assert len(chunks) == 3
+ assert len(chunks[0]) == 4 # First chunk gets extra element
+ assert len(chunks[1]) == 3
+ assert len(chunks[2]) == 3
+
+ # Verify partition_id and fields are preserved
+ assert chunks[0].partition_id == "partition_0"
+ assert chunks[0].fields == ["field1"]
+ assert chunks[0].extra_info == {"test": "value"}
+
+ # Verify keys and tags are correctly chunked
+ assert chunks[0].keys == ["key0", "key1", "key2", "key3"]
+ assert chunks[0].tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}]
+ assert chunks[1].keys == ["key4", "key5", "key6"]
+ assert chunks[1].tags == [{"idx": 4}, {"idx": 5}, {"idx": 6}]
+
+ def test_kv_batch_meta_chunk_with_more_chunks_than_samples(self):
+ """Example: Chunking when chunks > samples produces empty chunks."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2"],
+ tags=[{"idx": 0}, {"idx": 1}],
+ )
+
+ chunks = kv_meta.chunk(5)
+
+ assert len(chunks) == 5
+ assert len(chunks[0]) == 1
+ assert len(chunks[1]) == 1
+ assert len(chunks[2]) == 0
+ assert len(chunks[3]) == 0
+ assert len(chunks[4]) == 0
+
+ def test_kv_batch_meta_concat(self):
+ """Example: Concatenate multiple KVBatchMeta chunks."""
+ kv_meta1 = KVBatchMeta(
+ keys=["key0", "key1"],
+ tags=[{"idx": 0}, {"idx": 1}],
+ partition_id="partition_0",
+ fields=["field1"],
+ extra_info={"test": "value1"},
+ )
+
+ kv_meta2 = KVBatchMeta(
+ keys=["key2", "key3"],
+ tags=[{"idx": 2}, {"idx": 3}],
+ partition_id="partition_0",
+ fields=["field1"],
+ extra_info={"test": "value2"},
+ )
+
+ result = KVBatchMeta.concat([kv_meta1, kv_meta2])
+
+ assert result.size == 4
+ assert result.keys == ["key0", "key1", "key2", "key3"]
+ assert result.tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}]
+ assert result.partition_id == "partition_0"
+ assert result.fields == ["field1"]
+
+ def test_kv_batch_meta_concat_with_empty_chunks(self):
+ """Example: Concat handles empty KVBatchMeta chunks gracefully."""
+ kv_meta1 = KVBatchMeta()
+ kv_meta2 = KVBatchMeta(keys=["key0"], tags=[{"idx": 0}])
+ kv_meta3 = KVBatchMeta()
+
+ result = KVBatchMeta.concat([kv_meta1, kv_meta2, kv_meta3])
+
+ assert result.size == 1
+ assert result.keys == ["key0"]
+ assert result.tags == [{"idx": 0}]
+
+ def test_kv_batch_meta_concat_validation_field_mismatch(self):
+ """Example: Concat validation catches field name mismatches."""
+ kv_meta1 = KVBatchMeta(
+ keys=["key0"],
+ tags=[{}],
+ fields=["field1"],
+ )
+ kv_meta2 = KVBatchMeta(
+ keys=["key1"],
+ tags=[{}],
+ fields=["field2"], # Different field
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta.concat([kv_meta1, kv_meta2])
+ assert "Field names do not match" in str(exc_info.value)
+
+ def test_kv_batch_meta_concat_validation_partition_mismatch(self):
+ """Example: Concat validation catches partition_id mismatches."""
+ kv_meta1 = KVBatchMeta(
+ keys=["key0"],
+ tags=[{}],
+ partition_id="partition_0",
+ )
+ kv_meta2 = KVBatchMeta(
+ keys=["key1"],
+ tags=[{}],
+ partition_id="partition_1", # Different partition
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta.concat([kv_meta1, kv_meta2])
+        assert "Partitions do not match" in str(exc_info.value)
+
+ def test_kv_batch_meta_concat_empty_list(self):
+ """Example: Concat with empty list returns empty KVBatchMeta."""
+ result = KVBatchMeta.concat([])
+
+ assert result.size == 0
+ assert result.keys == []
+ assert result.tags == []
+
+ def test_kv_batch_meta_deepcopy_tags(self):
+ """Example: Tags are deep copied to prevent mutation."""
+ original_tags = [{"data": [1, 2, 3]}]
+ kv_meta = KVBatchMeta(
+ keys=["key1"],
+ tags=original_tags,
+ )
+
+ # Modify the tag in the KVBatchMeta
+ kv_meta.tags[0]["data"].append(4)
+
+ # Original should not be modified
+ assert original_tags[0]["data"] == [1, 2, 3]
+
+ def test_kv_batch_meta_deepcopy_extra_info(self):
+ """Example: Extra info is deep copied to prevent mutation."""
+ original_extra = {"nested": {"value": 1}}
+ kv_meta = KVBatchMeta(
+ keys=["key1"],
+ tags=[{}],
+ extra_info=original_extra,
+ )
+
+ # Modify extra_info
+ kv_meta.extra_info["nested"]["value"] = 999
+
+ # Original should not be modified
+ assert original_extra["nested"]["value"] == 1
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index eb2520c..eec0058 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -855,10 +855,6 @@ def __post_init__(self):
"""Validate all the variables"""
if len(self.keys) != len(self.tags):
raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}")
- if len(self.keys) != len(self.partition_id):
- raise ValueError(
- f"keys and partition_ids must have same length, but got {len(self.keys)} and {len(self.partition_ids)}"
- )
if len(self.keys) != len(set(self.keys)):
raise ValueError("Got duplicated keys.")
if len(self.fields) != len(set(self.fields)):
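A local usage sketch for `KVBatchMeta` as it stands after this patch; it only exercises the metadata container (no running deployment needed), and the key names, tags, and fields are illustrative:

```python
from transfer_queue.metadata import KVBatchMeta

meta = KVBatchMeta(
    keys=[f"sample_{i}" for i in range(4)],
    tags=[{"score": i / 10} for i in range(4)],
    partition_id="train",
    fields=["input_ids", "response"],
)

chunks = meta.chunk(2)               # split into two chunks of two samples each
print([c.keys for c in chunks])      # [['sample_0', 'sample_1'], ['sample_2', 'sample_3']]

merged = KVBatchMeta.concat(chunks)  # back to the original four samples
print(merged.size, merged.partition_id)  # 4 train

subset = merged.select_keys(["sample_3", "sample_0"])  # filter / reorder by key
print(subset.tags)                   # [{'score': 0.3}, {'score': 0.0}]
```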
From b301c740d507607bfcf2184472498a574cbfa66c Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 16:34:43 +0800
Subject: [PATCH 27/34] fix comments
Signed-off-by: 0oshowero0
---
transfer_queue/client.py | 2 +-
transfer_queue/controller.py | 6 +++---
transfer_queue/interface.py | 8 ++++----
3 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 5b0c1e1..b06f90c 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -812,7 +812,7 @@ async def async_reset_consumption(
task_name: Optional[str] = None,
socket: Optional[zmq.asyncio.Socket] = None,
) -> bool:
- """Aynchronously reset consumption status for a partition.
+ """Asynchronously reset consumption status for a partition.
This allows the same data to be re-consumed, useful for debugging scenarios
where the same rollout data needs to be trained multiple times.
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 3dcc1b5..3c7c3f9 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -988,19 +988,19 @@ def list_partitions(self) -> list[str]:
# ==================== Partition Index Management API ====================
- def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]:
+ def get_partition_index_range(self, partition_id: str) -> list[int]:
"""
Get all indexes for a specific partition.
Args:
- partition: Partition identifier
+ partition_id: Partition identifier
Returns:
List of indexes allocated to the partition
"""
# Note: This includes the pre-allocated global_indexes for the partition.
# i.e., partition.global_indexes + partition.pre_allocated_global_indexes
- return self.index_manager.get_indexes_for_partition(partition)
+ return self.index_manager.get_indexes_for_partition(partition_id)
# ==================== Data Production API ====================
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 32670de..70ce7c2 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -701,7 +701,7 @@ def kv_put(
... )
"""
if fields is None and tag is None:
- raise ValueError("Please provide at least one parameter of fields or tag.")
+ raise ValueError("Please provide at least one parameter of `fields` or `tag`.")
tq_client = _maybe_create_transferqueue_client()
@@ -712,7 +712,7 @@ def kv_put(
raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
# 2. register the user-specified tag to BatchMeta
- if tag:
+ if tag is not None:
batch_meta.update_custom_meta([tag])
# 3. put data
@@ -1030,7 +1030,7 @@ async def async_kv_batch_put(
"""
if fields is None and tags is None:
- raise ValueError("Please provide at least one parameter of fields or tag.")
+ raise ValueError("Please provide at least one parameter of `fields` or `tags`.")
if fields is not None and fields.batch_size[0] != len(keys):
raise ValueError(
@@ -1049,7 +1049,7 @@ async def async_kv_batch_put(
)
# 2. register the user-specified tags to BatchMeta
- if tags:
+ if tags is not None:
if len(tags) != len(keys):
raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
batch_meta.update_custom_meta(tags)
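A minimal illustration of why the checks above moved from truthiness to explicit `is not None`: an empty tag dict is falsy, so `if tag:` would silently skip registering metadata the caller actually passed:

```python
def register_tag_truthy(tag):
    return "registered" if tag else "skipped"


def register_tag_explicit(tag):
    return "registered" if tag is not None else "skipped"


print(register_tag_truthy({}))      # skipped  -- empty tag silently dropped
print(register_tag_explicit({}))    # registered
print(register_tag_explicit(None))  # skipped
```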
From f53cad973867b3fa0f7383b217f23ab91f7569d9 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 17:22:22 +0800
Subject: [PATCH 28/34] change default value to None in KVBatchMeta
Signed-off-by: 0oshowero0
---
transfer_queue/metadata.py | 21 +++++++++++++++------
1 file changed, 15 insertions(+), 6 deletions(-)
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index eec0058..473a487 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -846,10 +846,10 @@ class KVBatchMeta:
partition_id: Optional[str] = None
# [optional] fields of each sample
- fields: list[str] = dataclasses.field(default_factory=list)
+ fields: Optional[list[str]] = None
# [optional] external information for batch-level information
- extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)
+ extra_info: Optional[dict[str, Any]] = None
def __post_init__(self):
"""Validate all the variables"""
@@ -1008,22 +1008,31 @@ def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta":
return KVBatchMeta()
base_fields = data[0].fields
- base_fields_set = set(base_fields)
+ if base_fields is not None:
+ base_fields_set = set(base_fields)
+ else:
+ base_fields_set = set()
+
base_partition_id = data[0].partition_id
all_keys = []
all_tags = []
all_extra_info = {}
for chunk in data:
- if set(chunk.fields) != base_fields_set:
+ if chunk.fields is not None and set(chunk.fields) != base_fields_set:
raise ValueError("Field names do not match for concatenation.")
if chunk.partition_id != base_partition_id:
raise ValueError("Partition do not match for concatenation.")
all_keys.extend(chunk.keys)
all_tags.extend(chunk.tags)
- all_extra_info.update(chunk.extra_info)
+ if chunk.extra_info is not None:
+ all_extra_info.update(chunk.extra_info)
return KVBatchMeta(
- keys=all_keys, tags=all_tags, partition_id=base_partition_id, fields=base_fields, extra_info=all_extra_info
+ keys=all_keys,
+ tags=all_tags,
+ partition_id=base_partition_id,
+ fields=base_fields,
+ extra_info=all_extra_info if all_extra_info else None,
)
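A small sketch of the intended concat() behavior after this change. The constructor keywords are taken from the concat() return shown above; the import path and the defaults for `keys`/`tags` are assumptions:

    from transfer_queue.metadata import KVBatchMeta

    # Chunks may now carry fields=None / extra_info=None instead of empty containers.
    a = KVBatchMeta(keys=["k1"], tags=[{"score": 0.9}], partition_id="train")
    b = KVBatchMeta(keys=["k2"], tags=[{"score": 0.8}], partition_id="train",
                    extra_info={"epoch": 1})

    merged = KVBatchMeta.concat([a, b])
    print(merged.keys)        # ["k1", "k2"]
    print(merged.fields)      # None, since no chunk declared fields
    print(merged.extra_info)  # {"epoch": 1}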
From f9a9830e74cc6797577a0a55b466ef0bc1feb4f9 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 17:28:06 +0800
Subject: [PATCH 29/34] update
Signed-off-by: 0oshowero0
---
README.md | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 150c666..8131b94 100644
--- a/README.md
+++ b/README.md
@@ -95,10 +95,10 @@ This data structure design is motivated by the computational characteristics of
### User Interface: High-Level & Low-Level APIs
| Level | Tier | Style | Fine-Grained Access | Streaming | Sampler | Multiple-Backends |
-|---|---|---|---|---|---|---|
-| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✓ | ✗ | ✗ | ✓ |
-| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ✓ |✓ | ✓ | ✓ |
-| Low | **TransferQueueClient** | Metadata-based | ✓ | ✓ | ✓ | ✓ |
+|---|---|---|---|---|---|---|
+| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✓ | ○ | ✗ | ✓ |
+| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ✓ | ✓ | ✓ | ✓ |
+| Low | **TransferQueueClient** | Metadata-based | ✓ | ✓ | ✓ | ✓ |
#### Key-Value based API
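As a rough illustration of the two tiers compared in the table above (field names, shapes, and the task_name are made up; both paths assume the data is still available and tq.init() has been called):

    import torch
    import transfer_queue as tq
    from tensordict import TensorDict

    tq.init()

    # High-level KV interface: address samples by user-specified keys.
    tq.kv_batch_put(
        keys=["sample_1", "sample_2"],
        partition_id="train",
        fields=TensorDict({"input_ids": torch.randint(0, 100, (2, 8))}, batch_size=2),
    )
    batch = tq.kv_batch_get(keys=["sample_1", "sample_2"], partition_id="train")

    # Low-level metadata-based path: fetch metadata first, then the data it points to.
    meta = tq.get_meta(data_fields=["input_ids"], batch_size=2,
                       partition_id="train", mode="fetch", task_name="demo")
    batch = tq.get_data(meta)

    tq.close()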
From f49741d853cdf31ca868184dfd6c25c2450207f2 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 17:32:18 +0800
Subject: [PATCH 30/34] change function order
Signed-off-by: 0oshowero0
---
transfer_queue/interface.py | 1364 +++++++++++++++++------------------
1 file changed, 681 insertions(+), 683 deletions(-)
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 70ce7c2..fb6b5f7 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -107,6 +107,7 @@ def _init_from_existing() -> None:
time.sleep(1)
+# ==================== Initialization API ====================
def init(conf: Optional[DictConfig] = None) -> None:
"""Initialize the TransferQueue system.
@@ -196,480 +197,306 @@ def init(conf: Optional[DictConfig] = None) -> None:
_maybe_create_transferqueue_client(final_conf)
-# ==================== Basic API ====================
-def get_meta(
- data_fields: list[str],
- batch_size: int,
- partition_id: str,
- mode: str = "fetch",
- task_name: Optional[str] = None,
- sampling_config: Optional[dict[str, Any]] = None,
-) -> BatchMeta:
- """Synchronously fetch data metadata from the controller via ZMQ.
-
- Args:
- data_fields: List of data field names to retrieve metadata for
- batch_size: Number of samples to request in the batch
- partition_id: Current data partition id
- mode: Data fetch mode. Options:
- - 'fetch': Get ready data only
- - 'force_fetch': Get data regardless of readiness (may return unready samples)
- - 'insert': Internal usage - should not be used by users
- task_name: Optional task name associated with the request
- sampling_config: Optional sampling configuration for custom samplers.
-
- Returns:
- BatchMeta: Metadata object containing data structure, sample information, and readiness status
-
- Raises:
- RuntimeError: If communication fails or controller returns error response
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> # Example 1: Basic fetch metadata
- >>> batch_meta = tq.get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences"
- ... )
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
- >>> batch_meta = tq.get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=8,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... sampling_config={"n_samples_per_prompt": 4}
- ... )
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
- >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
- >>> batch_meta = tq.get_meta(
- ... partition_id="train_0", # optional
- ... mode="force_fetch",
- ... )
- >>> print(batch_meta.is_ready) # May be False if some samples not ready
- """
-
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
-
-
-def set_custom_meta(metadata: BatchMeta) -> None:
- """Synchronously send custom metadata to the controller.
-
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
-
- Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta()
- before calling this method.
+def close():
+ """Close the TransferQueue system.
- Raises:
- RuntimeError: If communication fails or controller returns error response
+ This function cleans up the TransferQueue system, including:
+ - Closing the client and its associated resources
+ - Cleaning up distributed storage (only for the process that initialized it)
+ - Killing the controller actor
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
- >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
- >>> tq.set_custom_meta(batch_meta)
+ Note:
+ This function should be called when the TransferQueue system is no longer needed.
"""
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.set_custom_meta(metadata)
+ global _TRANSFER_QUEUE_CLIENT
+ global _TRANSFER_QUEUE_STORAGE
+ if _TRANSFER_QUEUE_CLIENT:
+ _TRANSFER_QUEUE_CLIENT.close()
+ _TRANSFER_QUEUE_CLIENT = None
+ try:
+ if _TRANSFER_QUEUE_STORAGE:
+ # only the process that performed the first-time init can clean up the distributed storage
+ for storage in _TRANSFER_QUEUE_STORAGE.values():
+ ray.kill(storage)
+ _TRANSFER_QUEUE_STORAGE = None
+ except Exception:
+ pass
-def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta:
- """Synchronously write data to storage units based on metadata.
+ try:
+ controller = ray.get_actor("TransferQueueController")
+ ray.kill(controller)
+ except Exception:
+ pass
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
+# ==================== High-Level KV Interface API ====================
+def kv_put(
+ key: str,
+ partition_id: str,
+ fields: Optional[TensorDict | dict[str, Any]] = None,
+ tag: Optional[dict[str, Any]] = None,
+) -> None:
+ """Put a single key-value pair to TransferQueue.
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
+ This is a convenience method for putting data using a user-specified key
+ instead of BatchMeta. Internally, the key is translated to a BatchMeta
+ and the data is stored using the regular put mechanism.
Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
-
- Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
+ key: User-specified key for the data sample (a single row)
+ partition_id: Logical partition to store the data in
+ fields: Data fields to store. Can be a TensorDict or a dict of tensors.
+ Each key in `fields` will be treated as a column for the data sample.
+ If dict is provided, tensors will be unsqueezed to add batch dimension.
+ tag: Optional metadata tag to associate with the key
Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = tq.get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... )
- >>> batch = tq.get_data(batch_meta)
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> tq.put(data=output, metadata=batch_meta)
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id)
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.put(data, metadata, partition_id)
-
-
-def get_data(metadata: BatchMeta) -> TensorDict:
- """Synchronously fetch data from storage units and organize into TensorDict.
-
- Args:
- metadata: Batch metadata containing data location information and global indexes
-
- Returns:
- TensorDict containing:
- - Requested data fields (e.g., "prompts", "attention_mask")
+ ValueError: If neither fields nor tag is provided
+ ValueError: If nested tensors are provided (use kv_batch_put instead)
+ RuntimeError: If the retrieved BatchMeta size does not match the single input `key`
Example:
>>> import transfer_queue as tq
+ >>> import torch
>>> tq.init()
- >>>
- >>> batch_meta = tq.get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
+ >>> # Put with both fields and tag
+ >>> tq.kv_put(
+ ... key="sample_1",
+ ... partition_id="train",
+ ... fields={"input_ids": torch.tensor([1, 2, 3])},
+ ... tag={"score": 0.95}
... )
- >>> batch = tq.get_data(batch_meta)
- >>> print(batch)
- >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
"""
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.get_data(metadata)
-
-
-def clear_partition(partition_id: str):
- """Synchronously clear the whole partition from all storage units and the controller.
-
- Args:
- partition_id: The partition id to clear data for
-
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.clear_partition(partition_id)
-
-
-def clear_samples(metadata: BatchMeta):
- """Synchronously clear specific samples from all storage units and the controller.
-
- Args:
- metadata: The BatchMeta of the corresponding data to be cleared
+ if fields is None and tag is None:
+ raise ValueError("Please provide at least one parameter of `fields` or `tag`.")
- Raises:
- RuntimeError: If clear operation fails
- """
tq_client = _maybe_create_transferqueue_client()
- return tq_client.clear_samples(metadata)
-
-async def async_get_meta(
- data_fields: list[str],
- batch_size: int,
- partition_id: str,
- mode: str = "fetch",
- task_name: Optional[str] = None,
- sampling_config: Optional[dict[str, Any]] = None,
-) -> BatchMeta:
- """Asynchronously fetch data metadata from the controller via ZMQ.
-
- Args:
- data_fields: List of data field names to retrieve metadata for
- batch_size: Number of samples to request in the batch
- partition_id: Current data partition id
- mode: Data fetch mode. Options:
- - 'fetch': Get ready data only
- - 'force_fetch': Get data regardless of readiness (may return unready samples)
- - 'insert': Internal usage - should not be used by users
- task_name: Optional task name associated with the request
- sampling_config: Optional sampling configuration for custom samplers.
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
- Returns:
- BatchMeta: Metadata object containing data structure, sample information, and readiness status
+ if batch_meta.size != 1:
+ raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
- Raises:
- RuntimeError: If communication fails or controller returns error response
+ # 2. register the user-specified tag to BatchMeta
+ if tag is not None:
+ batch_meta.update_custom_meta([tag])
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> # Example 1: Basic fetch metadata
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences"
- ... ))
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=8,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
- >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... partition_id="train_0", # optional
- ... mode="force_fetch",
- ... ))
- >>> print(batch_meta.is_ready) # May be False if some samples not ready
- """
+ # 3. put data
+ if fields is not None:
+ if isinstance(fields, dict):
+ # TODO: consider whether to support this...
+ batch = {}
+ for field_name, value in fields.items():
+ if isinstance(value, torch.Tensor):
+ if value.is_nested:
+ raise ValueError("Please use (async)kv_batch_put for batch operation")
+ batch[field_name] = value.unsqueeze(0)
+ else:
+ batch[field_name] = NonTensorStack(value)
+ fields = TensorDict(batch, batch_size=[1])
+ elif not isinstance(fields, TensorDict):
+ raise ValueError("field can only be dict or TensorDict")
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
+ # custom_meta (tag) will be put to controller through the internal put process
+ tq_client.put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tag) to controller
+ tq_client.set_custom_meta(batch_meta)
-async def async_set_custom_meta(
- metadata: BatchMeta,
+def kv_batch_put(
+ keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None
) -> None:
- """
- Asynchronously send custom metadata to the controller.
+ """Put multiple key-value pairs to TransferQueue in batch.
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
+ This method stores multiple key-value pairs in a single operation, which is more
+ efficient than calling kv_put multiple times.
Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta()
- before calling this method.
- socket: ZMQ async socket for message transmission (injected by decorator)
+ keys: List of user-specified keys for the data
+ partition_id: Logical partition to store the data in
+ fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
+ tags: List of metadata tags, one for each key
Raises:
- RuntimeError: If communication fails or controller returns error response
+ ValueError: If neither `fields` nor `tags` is provided
+ ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
Example:
>>> import transfer_queue as tq
+ >>> from tensordict import TensorDict
>>> tq.init()
- >>>
- >>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
- >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
- >>> asyncio.run(tq.async_set_custom_meta(batch_meta))
+ >>> keys = ["sample_1", "sample_2", "sample_3"]
+ >>> fields = TensorDict({
+ ... "input_ids": torch.randn(3, 10),
+ ... "attention_mask": torch.ones(3, 10),
+ ... }, batch_size=3)
+ >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
+ >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
"""
+
+ if fields is None and tags is None:
+ raise ValueError("Please provide at least one parameter of fields or tag.")
+
+ if fields is not None and fields.batch_size[0] != len(keys):
+ raise ValueError(
+ f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
+ f"batch_size {fields.batch_size[0]}"
+ )
+
tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_set_custom_meta(metadata)
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
-async def async_put(
- data: TensorDict,
- metadata: Optional[BatchMeta] = None,
- partition_id: Optional[str] = None,
-) -> BatchMeta:
- """Asynchronously write data to storage units based on metadata.
+ if batch_meta.size != len(keys):
+ raise RuntimeError(
+ f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
+ )
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
+ # 2. register the user-specified tags to BatchMeta
+ if tags is not None:
+ if len(tags) != len(keys):
+ raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
+ batch_meta.update_custom_meta(tags)
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
+ # 3. put data
+ if fields is not None:
+ tq_client.put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tags) to controller
+ tq_client.set_custom_meta(batch_meta)
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
+
+def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict:
+ """Get data from TransferQueue using user-specified keys.
+
+ This is a convenience method for retrieving data using keys instead of indexes.
Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
+ keys: Single key or list of keys to retrieve
+ partition_id: Partition containing the keys
+ fields: Optional field(s) to retrieve. If None, retrieves all fields
Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
+ TensorDict with the requested data
Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
+ RuntimeError: If keys or partition are not found
+ RuntimeError: If empty fields exist in any key (sample)
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>>
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> batch = asyncio.run(tq.async_get_data(batch_meta))
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta))
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
+ >>> # Get single key with all fields
+ >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train")
+ >>> # Get multiple keys with specific fields
+ >>> data = tq.kv_batch_get(
+ ... keys=["sample_1", "sample_2"],
+ ... partition_id="train",
+ ... fields="input_ids"
+ ... )
"""
tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_put(data, metadata, partition_id)
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
-async def async_get_data(metadata: BatchMeta) -> TensorDict:
- """Asynchronously fetch data from storage units and organize into TensorDict.
+ if batch_meta.size == 0:
+ raise RuntimeError("keys or partition were not found!")
+
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ batch_meta = batch_meta.select_fields(fields)
+
+ if not batch_meta.is_ready:
+ raise RuntimeError("Some fields are not ready in all the requested keys!")
+
+ data = tq_client.get_data(batch_meta)
+ return data
+
+
+def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
+ """List all keys and their metadata in one or all partitions.
Args:
- metadata: Batch metadata containing data location information and global indexes
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
Returns:
- TensorDict containing:
- - Requested data fields (e.g., "prompts", "attention_mask")
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+ "tag1": ,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>>
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> batch = asyncio.run(tq.async_get_data(batch_meta))
- >>> print(batch)
- >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
+ >>> # Case 1: Retrieve a specific partition
+ >>> partitions = tq.kv_list(partition_id="train")
+ >>> print(f"Keys: {list(partitions['train'].keys())}")
+ >>> print(f"Tags: {list(partitions['train'].values())}")
+ >>> # Case 2: Retrieve all partitions
+ >>> all_partitions = tq.kv_list()
+ >>> for pid, keys in all_partitions.items():
+ >>> print(f"Partition: {pid}, Key count: {len(keys)}")
"""
tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_get_data(metadata)
-
-
-# ==================== Data Operations API ====================
+ partition_info = tq_client.kv_list(partition_id)
-async def async_clear_samples(metadata: BatchMeta):
- """Asynchronously clear specific samples from all storage units and the controller.
-
- Args:
- metadata: The BatchMeta of the corresponding data to be cleared
+ return partition_info
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_clear_samples(metadata)
+def kv_clear(keys: list[str] | str, partition_id: str) -> None:
+ """Clear key-value pairs from TransferQueue.
-async def async_clear_partition(partition_id: str):
- """Asynchronously clear the whole partition from all storage units and the controller.
+ This removes the specified keys and their associated data from both
+ the controller and storage units.
Args:
- partition_id: The partition id to clear data for
+ keys: Single key or list of keys to clear
+ partition_id: Partition containing the keys
- Raises:
- RuntimeError: If clear operation fails
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Clear single key
+ >>> tq.kv_clear(keys="sample_1", partition_id="train")
+ >>> # Clear multiple keys
+ >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
"""
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_clear_partition(partition_id)
-
-
-def close():
- """Close the TransferQueue system.
- This function cleans up the TransferQueue system, including:
- - Closing the client and its associated resources
- - Cleaning up distributed storage (only for the process that initialized it)
- - Killing the controller actor
-
- Note:
- This function should be called when the TransferQueue system is no longer needed.
- """
- global _TRANSFER_QUEUE_CLIENT
- global _TRANSFER_QUEUE_STORAGE
- if _TRANSFER_QUEUE_CLIENT:
- _TRANSFER_QUEUE_CLIENT.close()
- _TRANSFER_QUEUE_CLIENT = None
+ if isinstance(keys, str):
+ keys = [keys]
- try:
- if _TRANSFER_QUEUE_STORAGE:
- # only the process that do first-time init can clean the distributed storage
- for storage in _TRANSFER_QUEUE_STORAGE.values():
- ray.kill(storage)
- _TRANSFER_QUEUE_STORAGE = None
- except Exception:
- pass
+ tq_client = _maybe_create_transferqueue_client()
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
- try:
- controller = ray.get_actor("TransferQueueController")
- ray.kill(controller)
- except Exception:
- pass
+ if batch_meta.size > 0:
+ tq_client.clear_samples(batch_meta)
# ==================== KV Interface API ====================
-def kv_put(
+async def async_kv_put(
key: str,
partition_id: str,
fields: Optional[TensorDict | dict[str, Any]] = None,
tag: Optional[dict[str, Any]] = None,
) -> None:
- """Put a single key-value pair to TransferQueue.
+ """Asynchronously put a single key-value pair to TransferQueue.
This is a convenience method for putting data using a user-specified key
instead of BatchMeta. Internally, the key is translated to a BatchMeta
@@ -693,26 +520,27 @@ def kv_put(
>>> import torch
>>> tq.init()
>>> # Put with both fields and tag
- >>> tq.kv_put(
+ >>> await tq.async_kv_put(
... key="sample_1",
... partition_id="train",
... fields={"input_ids": torch.tensor([1, 2, 3])},
... tag={"score": 0.95}
- ... )
+ ... )
"""
+
if fields is None and tag is None:
- raise ValueError("Please provide at least one parameter of `fields` or `tag`.")
+ raise ValueError("Please provide at least one parameter of fields or tag.")
tq_client = _maybe_create_transferqueue_client()
# 1. translate user-specified key to BatchMeta
- batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
if batch_meta.size != 1:
raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
# 2. register the user-specified tag to BatchMeta
- if tag is not None:
+ if tag:
batch_meta.update_custom_meta([tag])
# 3. put data
@@ -731,17 +559,17 @@ def kv_put(
elif not isinstance(fields, TensorDict):
raise ValueError("field can only be dict or TensorDict")
- # custom_meta (tag) will be put to controller through the internal put process
- tq_client.put(fields, batch_meta)
+ # custom_meta (tag) will be put to controller through the put process
+ await tq_client.async_put(fields, batch_meta)
else:
# directly update custom_meta (tag) to controller
- tq_client.set_custom_meta(batch_meta)
+ await tq_client.async_set_custom_meta(batch_meta)
-def kv_batch_put(
+async def async_kv_batch_put(
keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None
) -> None:
- """Put multiple key-value pairs to TransferQueue in batch.
+ """Asynchronously put multiple key-value pairs to TransferQueue in batch.
This method stores multiple key-value pairs in a single operation, which is more
efficient than calling kv_put multiple times.
@@ -759,7 +587,6 @@ def kv_batch_put(
Example:
>>> import transfer_queue as tq
- >>> from tensordict import TensorDict
>>> tq.init()
>>> keys = ["sample_1", "sample_2", "sample_3"]
>>> fields = TensorDict({
@@ -767,11 +594,11 @@ def kv_batch_put(
... "attention_mask": torch.ones(3, 10),
... }, batch_size=3)
>>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
- >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
+ >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
"""
if fields is None and tags is None:
- raise ValueError("Please provide at least one parameter of fields or tag.")
+ raise ValueError("Please provide at least one parameter of `fields` or `tags`.")
if fields is not None and fields.batch_size[0] != len(keys):
raise ValueError(
@@ -782,7 +609,7 @@ def kv_batch_put(
tq_client = _maybe_create_transferqueue_client()
# 1. translate user-specified key to BatchMeta
- batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
if batch_meta.size != len(keys):
raise RuntimeError(
@@ -790,394 +617,565 @@ def kv_batch_put(
)
# 2. register the user-specified tags to BatchMeta
- if tags:
+ if tags is not None:
if len(tags) != len(keys):
raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
batch_meta.update_custom_meta(tags)
- # 3. put data
- if fields is not None:
- tq_client.put(fields, batch_meta)
- else:
- # directly update custom_meta (tags) to controller
- tq_client.set_custom_meta(batch_meta)
+ # 3. put data
+ if fields is not None:
+ await tq_client.async_put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tags) to controller
+ await tq_client.async_set_custom_meta(batch_meta)
+
+
+async def async_kv_batch_get(
+ keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None
+) -> TensorDict:
+ """Asynchronously get data from TransferQueue using user-specified keys.
+
+ This is a convenience method for retrieving data using keys instead of indexes.
+
+ Args:
+ keys: Single key or list of keys to retrieve
+ partition_id: Partition containing the keys
+ fields: Optional field(s) to retrieve. If None, retrieves all fields
+
+ Returns:
+ TensorDict with the requested data
+
+ Raises:
+ RuntimeError: If keys or partition are not found
+ RuntimeError: If empty fields exist in any key (sample)
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Get single key with all fields
+ >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train")
+ >>> # Get multiple keys with specific fields
+ >>> data = await tq.async_kv_batch_get(
+ ... keys=["sample_1", "sample_2"],
+ ... partition_id="train",
+ ... fields="input_ids"
+ ... )
+ """
+ tq_client = _maybe_create_transferqueue_client()
+
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+ if batch_meta.size == 0:
+ raise RuntimeError("keys or partition were not found!")
+
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ batch_meta = batch_meta.select_fields(fields)
+
+ if not batch_meta.is_ready:
+ raise RuntimeError("Some fields are not ready in all the requested keys!")
+
+ data = await tq_client.async_get_data(batch_meta)
+ return data
+
+
+async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
+ """Asynchronously list all keys and their metadata in one or all partitions.
+
+ Args:
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
+
+ Returns:
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+ "tag1": ,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
+
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Case 1: Retrieve a specific partition
+ >>> partitions = await tq.async_kv_list(partition_id="train")
+ >>> print(f"Keys: {list(partitions['train'].keys())}")
+ >>> print(f"Tags: {list(partitions['train'].values())}")
+ >>> # Case 2: Retrieve all partitions
+ >>> all_partitions = await tq.async_kv_list()
+ >>> for pid, keys in all_partitions.items():
+ >>> print(f"Partition: {pid}, Key count: {len(keys)}")
+ """
+ tq_client = _maybe_create_transferqueue_client()
+
+ partition_info = await tq_client.async_kv_list(partition_id)
+
+ return partition_info
+
+
+async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
+ """Asynchronously clear key-value pairs from TransferQueue.
+
+ This removes the specified keys and their associated data from both
+ the controller and storage units.
+
+ Args:
+ keys: Single key or list of keys to clear
+ partition_id: Partition containing the keys
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Clear single key
+ >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
+ >>> # Clear multiple keys
+ >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
+ """
+
+ if isinstance(keys, str):
+ keys = [keys]
+
+ tq_client = _maybe_create_transferqueue_client()
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+ if batch_meta.size > 0:
+ await tq_client.async_clear_samples(batch_meta)
+
+
+# ==================== Low-Level Native API ====================
+def get_meta(
+ data_fields: list[str],
+ batch_size: int,
+ partition_id: str,
+ mode: str = "fetch",
+ task_name: Optional[str] = None,
+ sampling_config: Optional[dict[str, Any]] = None,
+) -> BatchMeta:
+ """Synchronously fetch data metadata from the controller via ZMQ.
+
+ Args:
+ data_fields: List of data field names to retrieve metadata for
+ batch_size: Number of samples to request in the batch
+ partition_id: Current data partition id
+ mode: Data fetch mode. Options:
+ - 'fetch': Get ready data only
+ - 'force_fetch': Get data regardless of readiness (may return unready samples)
+ - 'insert': Internal usage - should not be used by users
+ task_name: Optional task name associated with the request
+ sampling_config: Optional sampling configuration for custom samplers.
+
+ Returns:
+ BatchMeta: Metadata object containing data structure, sample information, and readiness status
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> # Example 1: Basic fetch metadata
+ >>> batch_meta = tq.get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences"
+ ... )
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
+ >>> batch_meta = tq.get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=8,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... sampling_config={"n_samples_per_prompt": 4}
+ ... )
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
+ >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
+ >>> batch_meta = tq.get_meta(
+ ... partition_id="train_0", # optional
+ ... mode="force_fetch",
+ ... )
+ >>> print(batch_meta.is_ready) # May be False if some samples not ready
+ """
+
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
+
+
+def set_custom_meta(metadata: BatchMeta) -> None:
+ """Synchronously send custom metadata to the controller.
+
+ This method sends per-sample custom metadata (custom_meta) to the controller.
+ The custom_meta is stored in the controller and can be retrieved along with
+ the BatchMeta in subsequent get_meta calls.
+
+ Args:
+ metadata: BatchMeta containing the samples and their custom metadata to store.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>>
+ >>> # Create batch with custom metadata
+ >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
+ >>> tq.set_custom_meta(batch_meta)
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.set_custom_meta(metadata)
+
+
+def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta:
+ """Synchronously write data to storage units based on metadata.
+ If metadata is not provided, it will be created automatically using insert mode
+ with the provided data fields and partition_id.
-def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict:
- """Get data from TransferQueue using user-specified keys.
+ During put, the custom_meta in metadata will update the corresponding custom_meta in
+ TransferQueue Controller.
- This is a convenience method for retrieving data using keys instead of indexes.
+ Note:
+ When using multiple workers for distributed execution, there may be data
+ ordering inconsistencies between workers during put operations.
Args:
- keys: Single key or list of keys to retrieve
- partition_id: Partition containing the keys
- fields: Optional field(s) to retrieve. If None, retrieves all fields
+ data: Data to write as TensorDict
+ metadata: Records the metadata of a batch of data samples, containing index and
+ storage unit information. If None, metadata will be auto-generated.
+ partition_id: Target data partition id (required if metadata is not provided)
Returns:
- TensorDict with the requested data
+ BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+ metadata; will be updated in a future version to reflect the post-put state)
Raises:
- RuntimeError: If keys or partition are not found
- RuntimeError: If empty fields exist in any key (sample)
+ ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+ RuntimeError: If storage operation fails
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>> # Get single key with all fields
- >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train")
- >>> # Get multiple keys with specific fields
- >>> data = tq.kv_batch_get(
- ... keys=["sample_1", "sample_2"],
- ... partition_id="train",
- ... fields="input_ids"
+ >>>
+ >>> batch_size = 4
+ >>> seq_len = 16
+ >>> current_partition_id = "train_0"
+ >>> # Example 1: Normal usage with existing metadata
+ >>> batch_meta = tq.get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=batch_size,
+ ... partition_id=current_partition_id,
+ ... mode="fetch",
+ ... task_name="generate_sequences",
... )
+ >>> batch = tq.get_data(batch_meta)
+ >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+ >>> tq.put(data=output, metadata=batch_meta)
+ >>>
+ >>> # Example 2: Initial data insertion without pre-existing metadata
+ >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+ >>> # Please make sure the corresponding partition_id is empty before calling put()
+ >>> # without metadata.
+ >>> # Currently, all data for the corresponding partition_id must be put in a single call. Repeat-interleave
+ >>> # the initial data first if n_samples > 1 before calling put().
+ >>> original_prompts = torch.randn(batch_size, seq_len)
+ >>> n_samples = 4
+ >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+ >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+ >>> # This will create metadata in "insert" mode internally.
+ >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id)
"""
tq_client = _maybe_create_transferqueue_client()
-
- batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
-
- if batch_meta.size == 0:
- raise RuntimeError("keys or partition were not found!")
-
- if fields is not None:
- if isinstance(fields, str):
- fields = [fields]
- batch_meta = batch_meta.select_fields(fields)
-
- if not batch_meta.is_ready:
- raise RuntimeError("Some fields are not ready in all the requested keys!")
-
- data = tq_client.get_data(batch_meta)
- return data
+ return tq_client.put(data, metadata, partition_id)
-def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
- """List all keys and their metadata in one or all partitions.
+def get_data(metadata: BatchMeta) -> TensorDict:
+ """Synchronously fetch data from storage units and organize into TensorDict.
Args:
- partition_id: The specific partition_id to query.
- If None (default), returns keys from all partitions.
+ metadata: Batch metadata containing data location information and global indexes
Returns:
- A nested dictionary mapping partition IDs to their keys and metadata.
-
- Structure:
- {
- "partition_id": {
- "key_name": {
- "tag1": ,
- ... (other metadata)
- },
- ...,
- },
- ...
- }
+ TensorDict containing:
+ - Requested data fields (e.g., "prompts", "attention_mask")
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>> # Case 1: Retrieve a specific partition
- >>> partitions = tq.kv_list(partition_id="train")
- >>> print(f"Keys: {list(partitions['train'].keys())}")
- >>> print(f"Tags: {list(partitions['train'].values())}")
- >>> # Case 2: Retrieve all partitions
- >>> all_partitions = tq.kv_list()
- >>> for pid, keys in all_partitions.items():
- >>> print(f"Partition: {pid}, Key count: {len(keys)}")
+ >>>
+ >>> batch_meta = tq.get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... )
+ >>> batch = tq.get_data(batch_meta)
+ >>> print(batch)
+ >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
"""
tq_client = _maybe_create_transferqueue_client()
+ return tq_client.get_data(metadata)
- partition_info = tq_client.kv_list(partition_id)
- return partition_info
+def clear_partition(partition_id: str):
+ """Synchronously clear the whole partition from all storage units and the controller.
+ Args:
+ partition_id: The partition id to clear data for
-def kv_clear(keys: list[str] | str, partition_id: str) -> None:
- """Clear key-value pairs from TransferQueue.
+ Raises:
+ RuntimeError: If clear operation fails
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return tq_client.clear_partition(partition_id)
- This removes the specified keys and their associated data from both
- the controller and storage units.
+
+def clear_samples(metadata: BatchMeta):
+ """Synchronously clear specific samples from all storage units and the controller.
Args:
- keys: Single key or list of keys to clear
- partition_id: Partition containing the keys
+ metadata: The BatchMeta of the corresponding data to be cleared
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>> # Clear single key
- >>> tq.kv_clear(keys="sample_1", partition_id="train")
- >>> # Clear multiple keys
- >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
+ Raises:
+ RuntimeError: If clear operation fails
"""
-
- if isinstance(keys, str):
- keys = [keys]
-
tq_client = _maybe_create_transferqueue_client()
- batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
-
- if batch_meta.size > 0:
- tq_client.clear_samples(batch_meta)
+ return tq_client.clear_samples(metadata)
-# ==================== KV Interface API ====================
-async def async_kv_put(
- key: str,
+async def async_get_meta(
+ data_fields: list[str],
+ batch_size: int,
partition_id: str,
- fields: Optional[TensorDict | dict[str, Any]] = None,
- tag: Optional[dict[str, Any]] = None,
-) -> None:
- """Asynchronously put a single key-value pair to TransferQueue.
-
- This is a convenience method for putting data using a user-specified key
- instead of BatchMeta. Internally, the key is translated to a BatchMeta
- and the data is stored using the regular put mechanism.
+ mode: str = "fetch",
+ task_name: Optional[str] = None,
+ sampling_config: Optional[dict[str, Any]] = None,
+) -> BatchMeta:
+ """Asynchronously fetch data metadata from the controller via ZMQ.
Args:
- key: User-specified key for the data sample (in row)
- partition_id: Logical partition to store the data in
- fields: Data fields to store. Can be a TensorDict or a dict of tensors.
- Each key in `fields` will be treated as a column for the data sample.
- If dict is provided, tensors will be unsqueezed to add batch dimension.
- tag: Optional metadata tag to associate with the key
+ data_fields: List of data field names to retrieve metadata for
+ batch_size: Number of samples to request in the batch
+ partition_id: Current data partition id
+ mode: Data fetch mode. Options:
+ - 'fetch': Get ready data only
+ - 'force_fetch': Get data regardless of readiness (may return unready samples)
+ - 'insert': Internal usage - should not be used by users
+ task_name: Optional task name associated with the request
+ sampling_config: Optional sampling configuration for custom samplers.
+
+ Returns:
+ BatchMeta: Metadata object containing data structure, sample information, and readiness status
Raises:
- ValueError: If neither fields nor tag is provided
- ValueError: If nested tensors are provided (use kv_batch_put instead)
- RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+ RuntimeError: If communication fails or controller returns error response
Example:
>>> import transfer_queue as tq
- >>> import torch
>>> tq.init()
- >>> # Put with both fields and tag
- >>> await tq.async_kv_put(
- ... key="sample_1",
- ... partition_id="train",
- ... fields={"input_ids": torch.tensor([1, 2, 3])},
- ... tag={"score": 0.95}
+ >>>
+ >>> # Example 1: Basic fetch metadata
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences"
... ))
- """
-
- if fields is None and tag is None:
- raise ValueError("Please provide at least one parameter of fields or tag.")
-
- tq_client = _maybe_create_transferqueue_client()
-
- # 1. translate user-specified key to BatchMeta
- batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
-
- if batch_meta.size != 1:
- raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
-
- # 2. register the user-specified tag to BatchMeta
- if tag:
- batch_meta.update_custom_meta([tag])
-
- # 3. put data
- if fields is not None:
- if isinstance(fields, dict):
- # TODO: consider whether to support this...
- batch = {}
- for field_name, value in fields.items():
- if isinstance(value, torch.Tensor):
- if value.is_nested:
- raise ValueError("Please use (async)kv_batch_put for batch operation")
- batch[field_name] = value.unsqueeze(0)
- else:
- batch[field_name] = NonTensorStack(value)
- fields = TensorDict(batch, batch_size=[1])
- elif not isinstance(fields, TensorDict):
- raise ValueError("field can only be dict or TensorDict")
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... batch_size=8,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... ))
+ >>> print(batch_meta.is_ready) # True if all samples ready
+ >>>
+ >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
+ >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... partition_id="train_0", # optional
+ ... mode="force_fetch",
+ ... ))
+ >>> print(batch_meta.is_ready) # May be False if some samples not ready
+ """
- # custom_meta (tag) will be put to controller through the put process
- await tq_client.async_put(fields, batch_meta)
- else:
- # directly update custom_meta (tag) to controller
- await tq_client.async_set_custom_meta(batch_meta)
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
-async def async_kv_batch_put(
- keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None
+async def async_set_custom_meta(
+ metadata: BatchMeta,
) -> None:
- """Asynchronously put multiple key-value pairs to TransferQueue in batch.
+ """
+ Asynchronously send custom metadata to the controller.
- This method stores multiple key-value pairs in a single operation, which is more
- efficient than calling kv_put multiple times.
+ This method sends per-sample custom metadata (custom_meta) to the controller.
+ The custom_meta is stored in the controller and can be retrieved along with
+ the BatchMeta in subsequent get_meta calls.
Args:
- keys: List of user-specified keys for the data
- partition_id: Logical partition to store the data in
- fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
- tags: List of metadata tags, one for each key
+ metadata: BatchMeta containing the samples and their custom metadata to store.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
+ socket: ZMQ async socket for message transmission (injected by decorator)
Raises:
- ValueError: If neither `fields` nor `tags` is provided
- ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
- RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
+ RuntimeError: If communication fails or controller returns error response
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>> keys = ["sample_1", "sample_2", "sample_3"]
- >>> fields = TensorDict({
- ... "input_ids": torch.randn(3, 10),
- ... "attention_mask": torch.ones(3, 10),
- ... }, batch_size=3)
- >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
- >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
+ >>>
+ >>> # Create batch with custom metadata
+ >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
+ >>> asyncio.run(tq.async_set_custom_meta(batch_meta))
"""
-
- if fields is None and tags is None:
- raise ValueError("Please provide at least one parameter of `fields` or `tags`.")
-
- if fields is not None and fields.batch_size[0] != len(keys):
- raise ValueError(
- f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
- f"batch_size {fields.batch_size[0]}"
- )
-
tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_set_custom_meta(metadata)
- # 1. translate user-specified key to BatchMeta
- batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
-
- if batch_meta.size != len(keys):
- raise RuntimeError(
- f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
- )
-
- # 2. register the user-specified tags to BatchMeta
- if tags is not None:
- if len(tags) != len(keys):
- raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
- batch_meta.update_custom_meta(tags)
- # 3. put data
- if fields is not None:
- await tq_client.async_put(fields, batch_meta)
- else:
- # directly update custom_meta (tags) to controller
- await tq_client.async_set_custom_meta(batch_meta)
+async def async_put(
+ data: TensorDict,
+ metadata: Optional[BatchMeta] = None,
+ partition_id: Optional[str] = None,
+) -> BatchMeta:
+ """Asynchronously write data to storage units based on metadata.
+ If metadata is not provided, it will be created automatically using insert mode
+ with the provided data fields and partition_id.
-async def async_kv_batch_get(
- keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None
-) -> TensorDict:
- """Asynchronously get data from TransferQueue using user-specified keys.
+ During put, the custom_meta in metadata will update the corresponding custom_meta in
+ TransferQueue Controller.
- This is a convenience method for retrieving data using keys instead of indexes.
+ Note:
+ When using multiple workers for distributed execution, there may be data
+ ordering inconsistencies between workers during put operations.
Args:
- keys: Single key or list of keys to retrieve
- partition_id: Partition containing the keys
- fields: Optional field(s) to retrieve. If None, retrieves all fields
+ data: Data to write as TensorDict
+ metadata: Records the metadata of a batch of data samples, containing index and
+ storage unit information. If None, metadata will be auto-generated.
+ partition_id: Target data partition id (required if metadata is not provided)
Returns:
- TensorDict with the requested data
+ BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+ metadata; will be updated in a future version to reflect the post-put state)
Raises:
- RuntimeError: If keys or partition are not found
- RuntimeError: If empty fields exist in any key (sample)
+ ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+ RuntimeError: If storage operation fails
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>> # Get single key with all fields
- >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train")
- >>> # Get multiple keys with specific fields
- >>> data = await tq.async_kv_batch_get(
- ... keys=["sample_1", "sample_2"],
- ... partition_id="train",
- ... fields="input_ids"
- ... )
+ >>>
+ >>> batch_size = 4
+ >>> seq_len = 16
+ >>> current_partition_id = "train_0"
+ >>> # Example 1: Normal usage with existing metadata
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=batch_size,
+ ... partition_id=current_partition_id,
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... ))
+ >>> batch = asyncio.run(tq.async_get_data(batch_meta))
+ >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+ >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta))
+ >>>
+ >>> # Example 2: Initial data insertion without pre-existing metadata
+ >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+ >>> # Please make sure the corresponding partition_id is empty before calling async_put()
+ >>> # without metadata.
+ >>> # Currently, all data for a partition_id must be put in a single call. If n_samples > 1,
+ >>> # repeat-interleave the initial data before calling async_put().
+ >>> original_prompts = torch.randn(batch_size, seq_len)
+ >>> n_samples = 4
+ >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+ >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+ >>> # This will create metadata in "insert" mode internally.
+ >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
"""
tq_client = _maybe_create_transferqueue_client()
-
- batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
-
- if batch_meta.size == 0:
- raise RuntimeError("keys or partition were not found!")
-
- if fields is not None:
- if isinstance(fields, str):
- fields = [fields]
- batch_meta = batch_meta.select_fields(fields)
-
- if not batch_meta.is_ready:
- raise RuntimeError("Some fields are not ready in all the requested keys!")
-
- data = await tq_client.async_get_data(batch_meta)
- return data
+ return await tq_client.async_put(data, metadata, partition_id)
-async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
- """Asynchronously list all keys and their metadata in one or all partitions.
+async def async_get_data(metadata: BatchMeta) -> TensorDict:
+ """Asynchronously fetch data from storage units and organize into TensorDict.
Args:
- partition_id: The specific partition_id to query.
- If None (default), returns keys from all partitions.
+ metadata: Batch metadata containing data location information and global indexes
Returns:
- A nested dictionary mapping partition IDs to their keys and metadata.
-
- Structure:
- {
- "partition_id": {
- "key_name": {
- "tag1": ,
- ... (other metadata)
- },
- ...,
- },
- ...
- }
-
+ TensorDict containing:
+ - Requested data fields (e.g., "prompts", "attention_mask")
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>> # Case 1: Retrieve a specific partition
- >>> partitions = await tq.async_kv_list(partition_id="train")
- >>> print(f"Keys: {list(partitions['train'].keys())}")
- >>> print(f"Tags: {list(partitions['train'].values())}")
- >>> # Case 2: Retrieve all partitions
- >>> all_partitions = await tq.async_kv_list()
- >>> for pid, keys in all_partitions.items():
- >>> print(f"Partition: {pid}, Key count: {len(keys)}")
+ >>>
+ >>> batch_meta = asyncio.run(tq.async_get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=4,
+ ... partition_id="train_0",
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... ))
+ >>> batch = asyncio.run(tq.async_get_data(batch_meta))
+ >>> print(batch)
+ >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
"""
tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_get_data(metadata)
- partition_info = await tq_client.async_kv_list(partition_id)
- return partition_info
+async def async_clear_samples(metadata: BatchMeta):
+ """Asynchronously clear specific samples from all storage units and the controller.
+ Args:
+ metadata: The BatchMeta of the corresponding data to be cleared
-async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
- """Asynchronously clear key-value pairs from TransferQueue.
+ Raises:
+ RuntimeError: If clear operation fails
+ """
+ tq_client = _maybe_create_transferqueue_client()
+ return await tq_client.async_clear_samples(metadata)
- This removes the specified keys and their associated data from both
- the controller and storage units.
+
+async def async_clear_partition(partition_id: str):
+ """Asynchronously clear the whole partition from all storage units and the controller.
Args:
- keys: Single key or list of keys to clear
- partition_id: Partition containing the keys
+ partition_id: The partition id to clear data for
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>> # Clear single key
- >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
- >>> # Clear multiple keys
- >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
+ Raises:
+ RuntimeError: If clear operation fails
"""
-
- if isinstance(keys, str):
- keys = [keys]
-
tq_client = _maybe_create_transferqueue_client()
- batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
-
- if batch_meta.size > 0:
- await tq_client.async_clear_samples(batch_meta)
+ return await tq_client.async_clear_partition(partition_id)
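For reference, a minimal end-to-end sketch of the async flow introduced above (illustrative only; it assumes tq.init() can be called in the current environment and uses only the signatures shown in this patch, where these are still module-level functions):

import asyncio

import torch
import transfer_queue as tq
from tensordict import TensorDict


async def main():
    tq.init()

    # Initial insert: creates metadata in "insert" mode internally.
    prompts = TensorDict({"prompts": torch.randn(4, 16)}, batch_size=4)
    await tq.async_put(data=prompts, partition_id="train_0")

    # Fetch metadata, read the batch, and write a new field back via the same metadata.
    meta = await tq.async_get_meta(
        data_fields=["prompts"],
        batch_size=4,
        partition_id="train_0",
        mode="fetch",
        task_name="generate_sequences",
    )
    batch = await tq.async_get_data(meta)
    print(batch["prompts"].shape)
    output = TensorDict({"response": torch.randn(4, 16)}, batch_size=4)
    await tq.async_put(data=output, metadata=meta)

    # Attach per-sample custom metadata, then clean up the partition.
    meta.update_custom_meta([{"score": 0.9}] * 4)
    await tq.async_set_custom_meta(meta)
    await tq.async_clear_partition("train_0")
    tq.close()


asyncio.run(main())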
From 1bac20f2de639b0b6b3b04c4bc4691546cb4be85 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 21:03:28 +0800
Subject: [PATCH 31/34] simplify interface
Signed-off-by: 0oshowero0
---
transfer_queue/__init__.py | 94 +++-----
transfer_queue/interface.py | 434 +-----------------------------------
2 files changed, 43 insertions(+), 485 deletions(-)
diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py
index 60c8f1d..4732e84 100644
--- a/transfer_queue/__init__.py
+++ b/transfer_queue/__init__.py
@@ -16,84 +16,64 @@
import os
from .client import TransferQueueClient
-from .controller import TransferQueueController
from .dataloader import StreamingDataLoader, StreamingDataset
from .interface import (
- async_clear_partition,
- async_clear_samples,
- async_get_data,
- async_get_meta,
async_kv_batch_get,
async_kv_batch_put,
async_kv_clear,
async_kv_list,
async_kv_put,
- async_put,
- async_set_custom_meta,
- clear_partition,
- clear_samples,
close,
- get_data,
- get_meta,
+ get_client,
init,
kv_batch_get,
kv_batch_put,
kv_clear,
kv_list,
kv_put,
- put,
- set_custom_meta,
)
from .metadata import BatchMeta, KVBatchMeta
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
from .sampler.rank_aware_sampler import RankAwareSampler
from .sampler.sequential_sampler import SequentialSampler
-from .storage import SimpleStorageUnit
-from .utils.common import get_placement_group
-from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info
-__all__ = [
- "init",
- "close",
- "kv_put",
- "kv_batch_put",
- "kv_batch_get",
- "kv_list",
- "kv_clear",
- "async_kv_put",
- "async_kv_batch_put",
- "async_kv_batch_get",
- "async_kv_list",
- "async_kv_clear",
- "KVBatchMeta",
-] + [
- "get_meta",
- "get_data",
- "put",
- "set_custom_meta",
- "clear_samples",
- "clear_partition",
- "async_get_meta",
- "async_get_data",
- "async_put",
- "async_set_custom_meta",
- "async_clear_samples",
- "async_clear_partition",
- "BatchMeta",
- "TransferQueueClient",
- "StreamingDataset",
- "StreamingDataLoader",
- "TransferQueueController",
- "SimpleStorageUnit",
- "ZMQServerInfo",
- "process_zmq_server_info",
- "get_placement_group",
- "BaseSampler",
- "GRPOGroupNSampler",
- "SequentialSampler",
- "RankAwareSampler",
-]
+__all__ = (
+ [
+ # High-Level KV Interface
+ "init",
+ "close",
+ "kv_put",
+ "kv_batch_put",
+ "kv_batch_get",
+ "kv_list",
+ "kv_clear",
+ "async_kv_put",
+ "async_kv_batch_put",
+ "async_kv_batch_get",
+ "async_kv_list",
+ "async_kv_clear",
+ "KVBatchMeta",
+ ]
+ + [
+ # High-Level StreamingDataLoader Interface
+ "StreamingDataset",
+ "StreamingDataLoader",
+ ]
+ + [
+ # Low-Level Native Interface
+ "get_client",
+ "BatchMeta",
+ "TransferQueueClient",
+ ]
+ + [
+ # Sampler
+ "BaseSampler",
+ "GRPOGroupNSampler",
+ "SequentialSampler",
+ "RankAwareSampler",
+ ]
+)
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index fb6b5f7..733df24 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -28,7 +28,6 @@
from transfer_queue.client import TransferQueueClient
from transfer_queue.controller import TransferQueueController
-from transfer_queue.metadata import BatchMeta
from transfer_queue.sampler import * # noqa: F401
from transfer_queue.sampler import BaseSampler
from transfer_queue.storage.simple_backend import SimpleStorageUnit
@@ -752,430 +751,9 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
# ==================== Low-Level Native API ====================
-def get_meta(
- data_fields: list[str],
- batch_size: int,
- partition_id: str,
- mode: str = "fetch",
- task_name: Optional[str] = None,
- sampling_config: Optional[dict[str, Any]] = None,
-) -> BatchMeta:
- """Synchronously fetch data metadata from the controller via ZMQ.
-
- Args:
- data_fields: List of data field names to retrieve metadata for
- batch_size: Number of samples to request in the batch
- partition_id: Current data partition id
- mode: Data fetch mode. Options:
- - 'fetch': Get ready data only
- - 'force_fetch': Get data regardless of readiness (may return unready samples)
- - 'insert': Internal usage - should not be used by users
- task_name: Optional task name associated with the request
- sampling_config: Optional sampling configuration for custom samplers.
-
- Returns:
- BatchMeta: Metadata object containing data structure, sample information, and readiness status
-
- Raises:
- RuntimeError: If communication fails or controller returns error response
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> # Example 1: Basic fetch metadata
- >>> batch_meta = tq.get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences"
- ... )
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
- >>> batch_meta = tq.get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=8,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... sampling_config={"n_samples_per_prompt": 4}
- ... )
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
- >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
- >>> batch_meta = tq.get_meta(
- ... partition_id="train_0", # optional
- ... mode="force_fetch",
- ... )
- >>> print(batch_meta.is_ready) # May be False if some samples not ready
- """
-
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
-
-
-def set_custom_meta(metadata: BatchMeta) -> None:
- """Synchronously send custom metadata to the controller.
-
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
-
- Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta()
- before calling this method.
-
- Raises:
- RuntimeError: If communication fails or controller returns error response
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
- >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
- >>> tq.set_custom_meta(batch_meta)
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.set_custom_meta(metadata)
-
-
-def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta:
- """Synchronously write data to storage units based on metadata.
-
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
-
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
-
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
-
- Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
-
- Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
-
- Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = tq.get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... )
- >>> batch = tq.get_data(batch_meta)
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> tq.put(data=output, metadata=batch_meta)
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id)
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.put(data, metadata, partition_id)
-
-
-def get_data(metadata: BatchMeta) -> TensorDict:
- """Synchronously fetch data from storage units and organize into TensorDict.
-
- Args:
- metadata: Batch metadata containing data location information and global indexes
-
- Returns:
- TensorDict containing:
- - Requested data fields (e.g., "prompts", "attention_mask")
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> batch_meta = tq.get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... )
- >>> batch = tq.get_data(batch_meta)
- >>> print(batch)
- >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.get_data(metadata)
-
-
-def clear_partition(partition_id: str):
- """Synchronously clear the whole partition from all storage units and the controller.
-
- Args:
- partition_id: The partition id to clear data for
-
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.clear_partition(partition_id)
-
-
-def clear_samples(metadata: BatchMeta):
- """Synchronously clear specific samples from all storage units and the controller.
-
- Args:
- metadata: The BatchMeta of the corresponding data to be cleared
-
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.clear_samples(metadata)
-
-
-async def async_get_meta(
- data_fields: list[str],
- batch_size: int,
- partition_id: str,
- mode: str = "fetch",
- task_name: Optional[str] = None,
- sampling_config: Optional[dict[str, Any]] = None,
-) -> BatchMeta:
- """Asynchronously fetch data metadata from the controller via ZMQ.
-
- Args:
- data_fields: List of data field names to retrieve metadata for
- batch_size: Number of samples to request in the batch
- partition_id: Current data partition id
- mode: Data fetch mode. Options:
- - 'fetch': Get ready data only
- - 'force_fetch': Get data regardless of readiness (may return unready samples)
- - 'insert': Internal usage - should not be used by users
- task_name: Optional task name associated with the request
- sampling_config: Optional sampling configuration for custom samplers.
-
- Returns:
- BatchMeta: Metadata object containing data structure, sample information, and readiness status
-
- Raises:
- RuntimeError: If communication fails or controller returns error response
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> # Example 1: Basic fetch metadata
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences"
- ... ))
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=8,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
- >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... partition_id="train_0", # optional
- ... mode="force_fetch",
- ... ))
- >>> print(batch_meta.is_ready) # May be False if some samples not ready
- """
-
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
-
-
-async def async_set_custom_meta(
- metadata: BatchMeta,
-) -> None:
- """
- Asynchronously send custom metadata to the controller.
-
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
-
- Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta()
- before calling this method.
- socket: ZMQ async socket for message transmission (injected by decorator)
-
- Raises:
- RuntimeError: If communication fails or controller returns error response
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=2, ...)
- >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
- >>> asyncio.run(tq.async_set_custom_meta(batch_meta))
- """
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_set_custom_meta(metadata)
-
-
-async def async_put(
- data: TensorDict,
- metadata: Optional[BatchMeta] = None,
- partition_id: Optional[str] = None,
-) -> BatchMeta:
- """Asynchronously write data to storage units based on metadata.
-
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
-
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
-
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
-
- Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
-
- Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
-
- Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> batch = asyncio.run(tq.async_get_data(batch_meta))
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta))
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
- """
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_put(data, metadata, partition_id)
-
-
-async def async_get_data(metadata: BatchMeta) -> TensorDict:
- """Asynchronously fetch data from storage units and organize into TensorDict.
-
- Args:
- metadata: Batch metadata containing data location information and global indexes
-
- Returns:
- TensorDict containing:
- - Requested data fields (e.g., "prompts", "attention_mask")
-
- Example:
- >>> import transfer_queue as tq
- >>> tq.init()
- >>>
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> batch = asyncio.run(tq.async_get_data(batch_meta))
- >>> print(batch)
- >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
- """
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_get_data(metadata)
-
-
-async def async_clear_samples(metadata: BatchMeta):
- """Asynchronously clear specific samples from all storage units and the controller.
-
- Args:
- metadata: The BatchMeta of the corresponding data to be cleared
-
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_clear_samples(metadata)
-
-
-async def async_clear_partition(partition_id: str):
- """Asynchronously clear the whole partition from all storage units and the controller.
-
- Args:
- partition_id: The partition id to clear data for
-
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_clear_partition(partition_id)
+# For low-level API support, please refer to transfer_queue/client.py for details.
+def get_client():
+ global _TRANSFER_QUEUE_CLIENT
+ if _TRANSFER_QUEUE_CLIENT is None:
+ raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!")
+ return _TRANSFER_QUEUE_CLIENT
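After this simplification, only the high-level KV and StreamingDataLoader interfaces remain as module-level functions; the native API is reached through the client object returned by get_client(). A short sketch of the intended usage (illustrative only; it mirrors the tutorial updates in the following patches):

import torch
import transfer_queue as tq
from tensordict import TensorDict

tq.init()
tq_client = tq.get_client()  # low-level native API

# Put a small batch, fetch its metadata, read it back, then clean up.
data = TensorDict({"input_ids": torch.randint(0, 1000, (4, 8))}, batch_size=4)
tq_client.put(data=data, partition_id="train")

meta = tq_client.get_meta(
    data_fields=["input_ids"], batch_size=4, partition_id="train", task_name="demo"
)
batch = tq_client.get_data(meta)
print(batch)

tq_client.clear_partition(partition_id="train")
tq.close()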
From 3830bcf57868e90cbef14282caeeeafb555a585b Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 21:11:00 +0800
Subject: [PATCH 32/34] fix ut
Signed-off-by: 0oshowero0
---
tests/test_controller.py | 2 +-
tests/test_metadata.py | 2 +-
tests/test_simple_storage_unit.py | 2 +-
transfer_queue/metadata.py | 5 +++--
4 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/tests/test_controller.py b/tests/test_controller.py
index 243bc19..3528d1a 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -28,7 +28,7 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-from transfer_queue import TransferQueueController # noqa: E402
+from transfer_queue.controller import TransferQueueController # noqa: E402
from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 23ba56d..2a129b5 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -1074,7 +1074,7 @@ def test_kv_batch_meta_empty_init(self):
assert kv_meta.keys == []
assert kv_meta.tags == []
assert kv_meta.partition_id is None
- assert kv_meta.fields == []
+ assert kv_meta.fields is None
def test_kv_batch_meta_init_validation_keys_tags_mismatch(self):
"""Example: Init validation catches keys and tags length mismatch."""
diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py
index 043b506..ed43e41 100644
--- a/tests/test_simple_storage_unit.py
+++ b/tests/test_simple_storage_unit.py
@@ -27,7 +27,7 @@
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue import SimpleStorageUnit # noqa: E402
+from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # noqa: E402
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index 473a487..056a146 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -857,8 +857,9 @@ def __post_init__(self):
raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}")
if len(self.keys) != len(set(self.keys)):
raise ValueError("Got duplicated keys.")
- if len(self.fields) != len(set(self.fields)):
- raise ValueError("Got duplicated fields.")
+ if self.fields is not None:
+ if len(self.fields) != len(set(self.fields)):
+ raise ValueError("Got duplicated fields.")
# deepcopy to prevent unexpected behavior after chunk/concat
self.tags = copy.deepcopy(self.tags)
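A small sketch of the KVBatchMeta validation behavior after this fix (illustrative only; keyword construction and the exact field defaults are assumptions inferred from the tests above):

from transfer_queue.metadata import KVBatchMeta

# Empty init: fields may now be None (per the updated test) rather than [].
kv_meta = KVBatchMeta()
assert kv_meta.fields is None

# __post_init__ still rejects inconsistent inputs, e.g. duplicated keys
# or a keys/tags length mismatch (keyword construction assumed here).
try:
    KVBatchMeta(keys=["s1", "s1"], tags=[{"score": 0.9}, {"score": 0.8}], partition_id="train")
except ValueError as err:
    print(err)  # Got duplicated keys.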
From 5ec124b0173cc74294e4c6069fee8867242715a8 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 21:18:40 +0800
Subject: [PATCH 33/34] fix tutorial
Signed-off-by: 0oshowero0
---
tutorial/01_core_components.py | 9 ++---
tutorial/03_metadata_concepts.py | 16 +++++----
tutorial/04_understanding_controller.py | 44 +++++++++++++++----------
tutorial/05_custom_sampler.py | 34 +++++++++++--------
tutorial/06_streaming_dataloader.py | 4 ++-
5 files changed, 63 insertions(+), 44 deletions(-)
diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py
index 59159a6..9a66b8e 100644
--- a/tutorial/01_core_components.py
+++ b/tutorial/01_core_components.py
@@ -63,6 +63,7 @@ def demonstrate_data_workflow():
# Step 1: Put data
print("[Step 1] Putting data into TransferQueue...")
+ tq_client = tq.get_client()
input_ids = torch.tensor(
[
@@ -84,12 +85,12 @@ def demonstrate_data_workflow():
print(f" Created {data_batch.batch_size[0]} samples")
partition_id = "tutorial_partition_0"
- tq.put(data=data_batch, partition_id=partition_id)
+ tq_client.put(data=data_batch, partition_id=partition_id)
print(f" ✓ Data written to partition: {partition_id}")
# Step 2: Get metadata
print("[Step 2] Requesting data metadata...")
- batch_meta = tq.get_meta(
+ batch_meta = tq_client.get_meta(
data_fields=["input_ids", "attention_mask"],
batch_size=data_batch.batch_size[0],
partition_id=partition_id,
@@ -100,7 +101,7 @@ def demonstrate_data_workflow():
# Step 3: Get actual data
print("[Step 3] Retrieving actual data...")
- retrieved_data = tq.get_data(batch_meta)
+ retrieved_data = tq_client.get_data(batch_meta)
print(" ✓ Data retrieved successfully")
print(f" Keys: {list(retrieved_data.keys())}")
@@ -112,7 +113,7 @@ def demonstrate_data_workflow():
# Step 5: Clear
print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)")
- tq.clear_partition(partition_id=partition_id)
+ tq_client.clear_partition(partition_id=partition_id)
print(" ✓ Partition cleared")
diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py
index 91752d8..2c81941 100644
--- a/tutorial/03_metadata_concepts.py
+++ b/tutorial/03_metadata_concepts.py
@@ -310,6 +310,8 @@ def demonstrate_real_workflow():
# Initialize TransferQueue
tq.init()
+ tq_client = tq.get_client()
+
print("[Step 1] Putting data into TransferQueue...")
input_ids = torch.randint(0, 1000, (8, 512))
attention_mask = torch.ones(8, 512)
@@ -323,7 +325,7 @@ def demonstrate_real_workflow():
)
partition_id = "demo_partition"
- batch_meta = tq.put(data=data_batch, partition_id=partition_id)
+ batch_meta = tq_client.put(data=data_batch, partition_id=partition_id)
print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.")
print("[Step 2] [Optional] Setting sample-level custom_meta...")
@@ -335,11 +337,11 @@ def demonstrate_real_workflow():
batch_meta.update_custom_meta(custom_meta)
print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}")
- tq.set_custom_meta(batch_meta)
+ tq_client.set_custom_meta(batch_meta)
print("✓ Successful to store custom_meta into TQ controller. Now you can retrieve the custom_meta from anywhere.")
print("[Step 3] Try to get metadata from TransferQueue from other places...")
- batch_meta = tq.get_meta(
+ batch_meta = tq_client.get_meta(
data_fields=["input_ids", "attention_mask"],
batch_size=8,
partition_id=partition_id,
@@ -358,7 +360,7 @@ def demonstrate_real_workflow():
print("✓ Selected 'input_ids' field only:")
print(f" New field names: {selected_meta.field_names}")
print(f" Samples still have same global indexes: {selected_meta.global_indexes}")
- retrieved_data = tq.get_data(selected_meta)
+ retrieved_data = tq_client.get_data(selected_meta)
print(f" Retrieved data keys: {list(retrieved_data.keys())}")
print("[Step 5] Select specific samples from the retrieved BatchMeta...")
@@ -366,7 +368,7 @@ def demonstrate_real_workflow():
print("✓ Selected samples at indices [0, 2, 4, 6]:")
print(f" New global indexes: {partial_meta.global_indexes}")
print(f" Number of samples: {len(partial_meta)}")
- retrieved_data = tq.get_data(partial_meta)
+ retrieved_data = tq_client.get_data(partial_meta)
print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}")
print("[Step 6] Demonstrate chunk operation...")
@@ -374,11 +376,11 @@ def demonstrate_real_workflow():
print(f"✓ Chunked into {len(chunks)} parts:")
for i, chunk in enumerate(chunks):
print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}")
- chunk_data = tq.get_data(chunk)
+ chunk_data = tq_client.get_data(chunk)
print(f" Chunk {i}: Retrieved chunk data: {chunk_data}")
# Cleanup
- tq.clear_partition(partition_id=partition_id)
+ tq_client.clear_partition(partition_id=partition_id)
tq.close()
ray.shutdown()
print("✓ Partition cleared and resources cleaned up")
diff --git a/tutorial/04_understanding_controller.py b/tutorial/04_understanding_controller.py
index fbeae21..746de14 100644
--- a/tutorial/04_understanding_controller.py
+++ b/tutorial/04_understanding_controller.py
@@ -62,6 +62,8 @@ def demonstrate_partition_isolation():
tq.init()
+ tq_client = tq.get_client()
+
# Partition 1: Training data
print("\n[Partition 1] Putting training data...")
train_data = TensorDict(
@@ -71,7 +73,7 @@ def demonstrate_partition_isolation():
},
batch_size=2,
)
- tq.put(data=train_data, partition_id="train")
+ tq_client.put(data=train_data, partition_id="train")
print(" ✓ Training data added to 'train' partition")
# Partition 2: Validation data
@@ -83,23 +85,25 @@ def demonstrate_partition_isolation():
},
batch_size=2,
)
- tq.put(data=val_data, partition_id="val")
+ tq_client.put(data=val_data, partition_id="val")
print(" ✓ Validation data added to 'val' partition")
# Get from train partition
print("\n[Retrieving from 'train' partition]")
- train_meta = tq.get_meta(
+ train_meta = tq_client.get_meta(
data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task"
)
- retrieved_train_data = tq.get_data(train_meta)
+ retrieved_train_data = tq_client.get_data(train_meta)
print(f" ✓ Got BatchMeta={train_meta} from partition 'train'")
print(f" ✓ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, labels={retrieved_train_data['labels']}")
# Get from val partition
print("\n[Retrieving from 'val' partition]")
- val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task")
- retrieved_val_data = tq.get_data(val_meta)
+ val_meta = tq_client.get_meta(
+ data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task"
+ )
+ retrieved_val_data = tq_client.get_data(val_meta)
print(f" ✓ Got BatchMeta={val_meta} from partition 'val'")
print(f" ✓ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}")
@@ -107,8 +111,8 @@ def demonstrate_partition_isolation():
print(" ✓ Data isolation: 'train' and 'val' partitions are completely independent")
# Cleanup
- tq.clear_partition(partition_id="train")
- tq.clear_partition(partition_id="val")
+ tq_client.clear_partition(partition_id="train")
+ tq_client.clear_partition(partition_id="val")
tq.close()
ray.shutdown()
@@ -126,6 +130,8 @@ def demonstrate_dynamic_expansion():
tq.init()
+ tq_client = tq.get_client()
+
# Add first batch with 2 samples, 2 fields
print("\n[Step 1] Adding initial data (2 samples, 2 fields)...")
data1 = TensorDict(
@@ -135,7 +141,7 @@ def demonstrate_dynamic_expansion():
},
batch_size=2,
)
- meta1 = tq.put(data=data1, partition_id="dynamic")
+ meta1 = tq_client.put(data=data1, partition_id="dynamic")
print(" ✓ Added 2 samples")
print(f" ✓ Got BatchMeta: {meta1} samples")
@@ -148,9 +154,9 @@ def demonstrate_dynamic_expansion():
},
batch_size=3,
)
- meta2 = tq.put(data=data2, partition_id="dynamic")
+ meta2 = tq_client.put(data=data2, partition_id="dynamic")
- all_meta = tq.get_meta(
+ all_meta = tq_client.get_meta(
data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task"
)
print(" ✓ Added 3 more samples (total: 5)")
@@ -165,7 +171,7 @@ def demonstrate_dynamic_expansion():
},
batch_size=2,
)
- meta3 = tq.put(data=data3, metadata=meta1)
+ meta3 = tq_client.put(data=data3, metadata=meta1)
print(" ✓ Added 2 samples with new field 'field3'")
print(f" ✓ Got BatchMeta: {meta3} for newly put data with new field")
@@ -174,7 +180,7 @@ def demonstrate_dynamic_expansion():
print(" ✓ Columns auto-expand: Can add new fields anytime")
# Cleanup
- tq.clear_partition(partition_id="dynamic")
+ tq_client.clear_partition(partition_id="dynamic")
tq.close()
ray.shutdown()
@@ -190,6 +196,8 @@ def demonstrate_default_consumption_sample_strategy():
tq.init()
+ tq_client = tq.get_client()
+
# Add 6 samples
print("\n[Setup] Adding 6 samples...")
all_data = TensorDict(
@@ -198,22 +206,22 @@ def demonstrate_default_consumption_sample_strategy():
},
batch_size=6,
)
- tq.put(data=all_data, partition_id="sampling")
+ tq_client.put(data=all_data, partition_id="sampling")
print(" ✓ 6 samples added")
# First get - should get samples 0,1,2
print("\n[Task A, Get 1] Requesting 3 samples...")
- meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
+ meta1 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
print(f" ✓ Got samples: {meta1.global_indexes}")
# Second get - should get samples 3,4,5 (no duplicates!)
print("\n[Task A, Get 2] Requesting 3 more samples...")
- meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
+ meta2 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
print(f" ✓ Got samples: {meta2.global_indexes}")
# Third get - should get samples 0,1
print("\n[Task B, Get 1] Requesting 2 samples...")
- meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B")
+ meta3 = tq_client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B")
print(f" ✓ Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -224,7 +232,7 @@ def demonstrate_default_consumption_sample_strategy():
print(" ✓ Third get (Task B): samples 0,1")
# Cleanup
- tq.clear_partition(partition_id="sampling")
+ tq_client.clear_partition(partition_id="sampling")
tq.close()
ray.shutdown()
diff --git a/tutorial/05_custom_sampler.py b/tutorial/05_custom_sampler.py
index bdca835..e30487b 100644
--- a/tutorial/05_custom_sampler.py
+++ b/tutorial/05_custom_sampler.py
@@ -186,6 +186,8 @@ def demonstrate_random_sampler_with_replacement():
sampler = RandomSamplerWithReplacement()
setup_transfer_queue_with_sampler(sampler)
+ tq_client = tq.get_client()
+
# Add 5 samples
print("\n[Step 1] Adding 5 samples...")
data = TensorDict(
@@ -194,22 +196,22 @@ def demonstrate_random_sampler_with_replacement():
},
batch_size=5,
)
- tq.put(data=data, partition_id="test")
+ tq_client.put(data=data, partition_id="test")
print(" ✓ 5 samples added")
# Get batch 1 (should get 2 random samples)
print("\n[Step 2] Get batch 1 (2 samples)...")
- meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta1 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta1.global_indexes}")
# Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!)
print("\n[Step 3] Get batch 2 (1 sample)...")
- meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
+ meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta2.global_indexes}")
# Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!)
print("\n[Step 4] Get batch 3 (2 samples)...")
- meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -219,7 +221,7 @@ def demonstrate_random_sampler_with_replacement():
print(f" ✓ All sampled: {all_sampled}")
# Cleanup
- tq.clear_partition(partition_id="test")
+ tq_client.clear_partition(partition_id="test")
tq.close()
ray.shutdown()
@@ -234,6 +236,8 @@ def demonstrate_random_sampler_without_replacement():
sampler = RandomSamplerWithoutReplacement()
setup_transfer_queue_with_sampler(sampler)
+ tq_client = tq.get_client()
+
# Add 6 samples
print("\n[Step 1] Adding 6 samples...")
data = TensorDict(
@@ -242,22 +246,22 @@ def demonstrate_random_sampler_without_replacement():
},
batch_size=6,
)
- tq.put(data=data, partition_id="test")
+ tq_client.put(data=data, partition_id="test")
print(" ✓ 6 samples added")
# Get batch 1 (should get 3 random samples without replacement)
print("\n[Step 2] Get batch 1 (3 samples)...")
- meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task")
+ meta1 = tq_client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta1.global_indexes}")
# Get batch 2 (should randomly get 1 sample that are different from previous batch)
print("\n[Step 3] Get batch 2 (1 samples)...")
- meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
+ meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta2.global_indexes}")
# Get batch 3 (should randomly get 2 samples that are different from previous batch)
print("\n[Step 4] Get batch 3 (2 samples)...")
- meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" ✓ Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -267,7 +271,7 @@ def demonstrate_random_sampler_without_replacement():
print(f" ✓ Batch 3: {meta3.global_indexes} (none left)")
# Cleanup
- tq.clear_partition(partition_id="test")
+ tq_client.clear_partition(partition_id="test")
tq.close()
ray.shutdown()
@@ -282,6 +286,8 @@ def demonstrate_priority_sampler():
sampler = PrioritySampler()
setup_transfer_queue_with_sampler(sampler)
+ tq_client = tq.get_client()
+
# Add 8 samples
print("\n[Step 1] Adding 8 samples...")
data = TensorDict(
@@ -290,7 +296,7 @@ def demonstrate_priority_sampler():
},
batch_size=8,
)
- tq.put(data=data, partition_id="test")
+ tq_client.put(data=data, partition_id="test")
print(" ✓ 8 samples added")
time.sleep(1)
@@ -303,7 +309,7 @@ def demonstrate_priority_sampler():
print(f"Priority scores: {priority_scores}")
# Get batch using priority sampling
- meta1 = tq.get_meta(
+ meta1 = tq_client.get_meta(
data_fields=["input"],
batch_size=1,
partition_id="test",
@@ -315,7 +321,7 @@ def demonstrate_priority_sampler():
# Get another batch
print("\n[Step 3] Get another batch (2 samples)...")
- meta2 = tq.get_meta(
+ meta2 = tq_client.get_meta(
data_fields=["input"],
batch_size=2,
partition_id="test",
@@ -331,7 +337,7 @@ def demonstrate_priority_sampler():
print(f" ✓ Batch 2 high-priority indices: {[i for i in meta2.global_indexes if priority_scores[i] >= 0.1]}")
# Cleanup
- tq.clear_partition(partition_id="test")
+ tq_client.clear_partition(partition_id="test")
tq.close()
ray.shutdown()
diff --git a/tutorial/06_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py
index d95a3c8..c60b4e3 100644
--- a/tutorial/06_streaming_dataloader.py
+++ b/tutorial/06_streaming_dataloader.py
@@ -123,6 +123,8 @@ def generate_worker(rank_id: int, num_samples: int = 20):
# Need to call tq.init() in each process
tq.init()
+ tq_client = tq.get_client()
+
# Generate and put samples into the queue
for i in range(num_samples):
# Create unique sequence ID for this sample
@@ -137,7 +139,7 @@ def generate_worker(rank_id: int, num_samples: int = 20):
print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue")
# Put data into the specified partition
- tq.put(data, partition_id="train")
+ tq_client.put(data, partition_id="train")
print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue")
From e9e5872a01630a69b6438c1e66c4500ef88af417 Mon Sep 17 00:00:00 2001
From: 0oshowero0
Date: Mon, 9 Feb 2026 21:25:42 +0800
Subject: [PATCH 34/34] fix pre-commit
Signed-off-by: 0oshowero0
---
transfer_queue/interface.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 733df24..8e4f8c3 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -753,7 +753,7 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
# ==================== Low-Level Native API ====================
# For low-level API support, please refer to transfer_queue/client.py for details.
def get_client():
- global _TRANSFER_QUEUE_CLIENT
+ """Get a TransferQueueClient for using low-level API"""
if _TRANSFER_QUEUE_CLIENT is None:
raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!")
return _TRANSFER_QUEUE_CLIENT