diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml
index 1202e21..c24cab8 100644
--- a/.github/workflows/tutorial-check.yml
+++ b/.github/workflows/tutorial-check.yml
@@ -36,4 +36,5 @@ jobs:
- name: Run tutorials
run: |
export TQ_NUM_THREADS=2
+ export RAY_DEDUP_LOGS=0
for file in tutorial/*.py; do python3 "$file"; done
\ No newline at end of file
diff --git a/README.md b/README.md
index 2c7d4f0..8131b94 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,8 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo
Updates
- - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for details.
+ - **Feb 8, 2026**: 🔥 Initialization and usage are greatly simplified by the new high-level APIs [PR#26](https://github.com/Ascend/TransferQueue/pull/26), [PR#28](https://github.com/Ascend/TransferQueue/pull/28). You can now use a Redis-style API to take advantage of most of the advanced features provided by TransferQueue!
+ - **Jan 28, 2026**: We experimentally introduce the `StreamingDataLoader` interface for a fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details.
- **Dec 30, 2025**: **TransferQueue x verl** integration is tested with the DAPO algorithm at scale **(64 nodes, 1024 cards)**. It significantly optimizes host memory utilization and accelerates data transfers. Stay tuned for more details!
- **Dec 20, 2025**: 🔥 The official [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) is released! Feel free to check it out.
- **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data.
@@ -91,26 +92,55 @@ This data structure design is motivated by the computational characteristics of
-### User Interface: Asynchronous & Synchronous Client
-To simplify the usage of TransferQueue, we have encapsulated this process into `TransferQueueClient`. The client provides both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their framework.
+### User Interface: High-Level & Low-Level APIs
-We also experimentally provide a `StreamingDataLoader` interface as a standard PyTorch DataLoader. Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks.
-This interface simplifies TransferQueue's integration, ensuring seamless compatibility with existing training workflows. Please refer to our [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for more details.
+| Level | Tier | Style | Fine-Grained Access | Streaming | Sampler | Multiple Backends |
+|---|---|---|---|---|---|---|
+| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✅ | ❌ | ❌ | ✅ |
+| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ❌ | ✅ | ✅ | ✅ |
+| Low | **TransferQueueClient** | Metadata-based | ✅ | ✅ | ✅ | ✅ |
-🔥 Showcases
-### General Usage
+#### Key-Value Based API
+
+To simplify the usage of TransferQueue, we provide a Redis-style high-level API that exposes most of TransferQueue's advanced features ([PR#28](https://github.com/Ascend/TransferQueue/pull/28)). A minimal usage sketch follows the method list below.
+
+**Methods**
+
+- **(async_)kv_put**: Insert or update a multi-column sample by key, with an optional metadata tag
+- **(async_)kv_batch_put**: Put multiple key-value pairs efficiently in a single batch
+- **(async_)kv_batch_get**: Retrieve samples by keys, with optional column selection via fields
+- **(async_)kv_list**: List keys and tags (metadata) in a partition
+- **(async_)kv_clear**: Remove key-value pairs from storage
+
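+The sketch below follows the call patterns exercised in `tests/e2e/test_kv_interface_e2e.py`; argument values are illustrative, and the tutorials remain the authoritative reference:
+
+```python
+import torch
+import transfer_queue as tq
+
+tq.init()  # spin up the controller and storage actors
+
+# Insert or update a multi-column sample under a key, with an optional metadata tag
+tq.kv_put(
+    key="sample_0",
+    partition_id="train",
+    fields={"input_ids": torch.tensor([1, 2, 3])},
+    tag={"status": "running"},
+)
+
+# Retrieve selected columns for one or more keys (fine-grained access)
+batch = tq.kv_batch_get(keys=["sample_0"], partition_id="train", fields=["input_ids"])
+
+# Inspect keys/tags per partition, then remove entries that are no longer needed
+info = tq.kv_list(partition_id="train")
+tq.kv_clear(keys=["sample_0"], partition_id="train")
+
+tq.close()
+```
+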
+**Key Features**
+
+- **Redis-style Semantics**: Familiar KV interface (Put/Get/List) for zero learning curve
+- **Fine-grained Access**: Update or retrieve specific fields (columns) within a key (row) without touching the full sample
+- **Partition Isolation**: Logical separation of storage namespaces
+- **Metadata Tags**: Lightweight metadata for status tracking
+- **Pluggable Backends**: Supports multiple backends
+
+#### StreamingDataLoader API
-The primary interaction points are `AsyncTransferQueueClient` and `TransferQueueClient`, serving as the communication interface with the TransferQueue system.
+Designed as a drop-in replacement for the standard PyTorch `DataLoader`, this API allows each rank to automatically consume data without single-controller intervention.
-Core interfaces:
+In this scenario, `TransferQueueController` serves as a side-controller for data dispatching, with a user-defined `Sampler` class organizing the dataflow.
+It encapsulates the complex scheduling and data transfer logic required for various parallelism strategies, seamlessly integrating TransferQueue into existing training workflows and simplifying the development of disaggregated frameworks.
-- `(async_)get_meta(data_fields: list[str], batch_size:int, partition_id: str, mode: str, task_name:str, sampling_config: Optional[dict[str, Any]]) -> BatchMeta`
-- `(async_)get_data(metadata: BatchMeta) -> TensorDict`
-- `(async_)put(data: TensorDict, metadata: Optional[BatchMeta], partition_id: Optional[str])`
-- `(async_)clear_partition(partition_id: str)` and `(async_)clear_samples(metadata: BatchMeta)`
+See [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for more details.
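+
+The following is a hypothetical sketch of the intended usage; the import path and constructor arguments are assumptions for illustration only, and the authoritative example is [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py):
+
+```python
+# NOTE: the import path and constructor arguments below are illustrative assumptions;
+# consult tutorial/06_streaming_dataloader.py for the actual signature.
+from transfer_queue import StreamingDataLoader
+
+loader = StreamingDataLoader(
+    data_fields=["input_ids", "attention_mask"],  # assumed parameter names
+    batch_size=8,
+    partition_id="train",
+)
+
+# Each rank iterates like a normal PyTorch DataLoader; TransferQueue schedules
+# and transfers the data implied by this rank's parallelism strategy.
+for batch in loader:
+    ...  # run forward/backward on this rank's micro-batch
+```
+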
-**Refer to our [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) for detailed examples.**
+#### Low-Level Native API
+
+The native interface of TransferQueue is implemented in `TransferQueueClient`. It offers maximum flexibility through native, atomic operations.
+
+Developers can leverage `TransferQueueClient` directly to implement advanced features that require fine-grained control and fully streamed data scheduling, as illustrated in the following tutorials (a condensed usage sketch follows the list):
+- [tutorial/03_metadata_concepts.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/03_metadata_concepts.py)
+- [tutorial/04_understanding_controller.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_understanding_controller.py)
+- [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py)
+
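+A condensed sketch of the low-level flow, assuming the `TransferQueueClient` methods shown in the tutorials (`put`, `get_meta`, `get_data`, `clear_partition`); connection setup and argument values are illustrative:
+
+```python
+import torch
+from tensordict import TensorDict
+from transfer_queue import TransferQueueClient  # import path assumed
+
+client = TransferQueueClient(...)  # connection arguments omitted; see tutorial/03_metadata_concepts.py
+
+# Producer side: write a batch of samples into a partition.
+client.put(
+    data=TensorDict({"input_ids": torch.randint(0, 100, (4, 16))}, batch_size=4),
+    partition_id="train",
+)
+
+# Consumer side: fetch metadata describing which samples/fields to read,
+# then fetch the actual tensors, and clean up when finished.
+meta = client.get_meta(data_fields=["input_ids"], batch_size=4, partition_id="train")
+batch = client.get_data(meta)
+client.clear_partition("train")
+```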
+
+🔥 Showcases
### Collocated Example
@@ -131,7 +161,7 @@ You may refer to the [recipe](https://github.com/Ascend/TransferQueue/tree/dev/r
### Disaggregated Example
-We have implemented a series of PRs ([#4](https://github.com/Ascend/TransferQueue/pull/4), [#7](https://github.com/Ascend/TransferQueue/pull/7), [#9](https://github.com/Ascend/TransferQueue/pull/9), [#16](https://github.com/Ascend/TransferQueue/pull/16)) to establish a **standardized, fully-streamed distributed** workflow via TransferQueue.
+We have experimentally implemented a **standardized, fully-streamed distributed** workflow via TransferQueue.
By leveraging the `RankAwareSampler` and `StreamingDataLoader` interfaces, we achieve a **streamlined micro-batch-level producer-consumer pipeline**. This design eliminates the need to manually determine data dispatching logic across varying parallelism strategiesβa typical complexity in the single-controller paradigmβthereby greatly simplifying framework design.
@@ -186,7 +216,7 @@ pip install TransferQueue
-> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community!
+> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorage` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. We warmly welcome contributions from the community!
For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#).
@@ -250,7 +280,7 @@ batch_meta = client.get_meta(
)
```
-**Refer to [tutorial/04_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_custom_sampler.py) for more details.**
+**Refer to [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py) for more details.**
### How to integrate a new storage backend
@@ -299,21 +329,6 @@ pip install pre-commit
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
```
- 🛣️ Roadmap
-
-- [x] Support data rewrite for partial rollout & agentic post-training
-- [x] Provide a general storage abstraction layer `TransferQueueStorageManager` to manage distributed storage units, which simplifies `Client` design and makes it possible to introduce different storage backends ([PR#66](https://github.com/TransferQueue/TransferQueue/pull/66), [issue#72](https://github.com/TransferQueue/TransferQueue/issues/72))
-- [x] Implement `AsyncSimpleStorageManager` as the default storage backend based on the `TransferQueueStorageManager` abstraction
-- [x] Provide a `KVStorageManager` to cover all the KV based storage backends ([PR#96](https://github.com/TransferQueue/TransferQueue/pull/96))
-- [x] Support topic-based data partitioning to maintain train/val/test data simultaneously ([PR#98](https://github.com/TransferQueue/TransferQueue/pull/98))
-- [x] Release the first stable version through PyPI
-- [ ] Support disaggregated framework (each rank retrieves its own data without going through a centralized node)
-- [ ] Provide a `StreamingDataLoader` interface for disaggregated framework
-- [ ] Support load-balancing and dynamic batching
-- [x] Support high-performance storage backends for RDMA transmission (e.g., [Mooncake Store](https://github.com/kvcache-ai/Mooncake), [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html)...)
-- [x] High-performance serialization and deserialization
-- [ ] More documentation, examples and tutorials
-
Citation
Please kindly cite our paper if you find this repo is useful:
diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py
new file mode 100644
index 0000000..71e9833
--- /dev/null
+++ b/tests/e2e/test_kv_interface_e2e.py
@@ -0,0 +1,588 @@
+# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2025 The TransferQueue Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""End-to-end tests for KV interface in transfer_queue.interface.
+
+This test module validates the KV interface functionality by:
+1. Using external interfaces (kv_put, kv_batch_put, kv_batch_get, kv_list, kv_clear) for read/write
+2. Verifying correctness by calling TransferQueueController's internal methods directly
+"""
+
+import os
+import sys
+from pathlib import Path
+
+import pytest
+import ray
+import torch
+from tensordict import TensorDict
+
+# Add parent directory to path
+parent_dir = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(parent_dir))
+
+import transfer_queue as tq # noqa: E402
+
+# Configure Ray for tests
+os.environ["RAY_DEDUP_LOGS"] = "0"
+
+
+@pytest.fixture(scope="module")
+def ray_init():
+ """Initialize Ray for the test module."""
+ if not ray.is_initialized():
+ ray.init(namespace="TestKVInterfaceE2E")
+ yield
+ if ray.is_initialized():
+ ray.shutdown()
+
+
+@pytest.fixture(scope="module")
+def tq_system(ray_init):
+ """Initialize TransferQueue system for the test module."""
+ tq.init()
+ yield
+ tq.close()
+
+
+@pytest.fixture
+def controller(tq_system):
+ """Get the TransferQueueController actor for direct verification."""
+ controller = ray.get_actor("TransferQueueController")
+ yield controller
+
+
+@pytest.fixture(autouse=True)
+def cleanup_partition(controller):
+ """Cleanup partition after each test."""
+ yield
+ try:
+ partition_ids = ray.get(controller.list_partitions.remote())
+ for partition_id in partition_ids:
+ ray.get(controller.clear_partition.remote(partition_id))
+ except Exception:
+ pass
+
+
+def get_controller_partition(controller, partition_id: str):
+ """Get partition snapshot from controller for verification."""
+ return ray.get(controller.get_partition_snapshot.remote(partition_id))
+
+
+def assert_tensor_equal(tensor_a, tensor_b, msg=""):
+ """Assert two tensors are equal."""
+ assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}"
+
+
+def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""):
+ """Assert two tensors are close."""
+ assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close"
+
+
+class TestKVPutE2E:
+ """End-to-end tests for kv_put functionality."""
+
+ def test_kv_put_with_dict_fields(self, controller):
+ """Test kv_put with dict fields (auto-converted to TensorDict)."""
+ partition_id = "test_partition"
+ key = "sample_0"
+
+ # Put with dict fields - will be auto-unsqueezed
+ tq.kv_put(
+ key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"}
+ )
+
+ # Verify - retrieved data will have batch dimension
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
+ expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed
+ assert_tensor_equal(retrieved["data"], expected)
+
+ def test_kv_put_with_tensordict_fields(self, controller):
+ """Test kv_put with tensordict fields."""
+ partition_id = "test_partition"
+ key = "sample_1"
+
+ tensordict_data = TensorDict(
+ {
+ "input_ids": torch.tensor([[1, 2, 3, 4]]),
+ },
+ batch_size=1,
+ )
+ # Put with TensorDict fields - the batch dimension is already present, no auto-unsqueeze is applied
+ tq.kv_put(key=key, partition_id=partition_id, fields=tensordict_data, tag={"type": "tensordict_test"})
+
+ # Verify - retrieved data will have batch dimension
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
+ expected = torch.tensor([[1, 2, 3, 4]])  # already batched
+ assert_tensor_equal(retrieved["input_ids"], expected)
+
+ def test_kv_put_single_sample_with_fields_and_tag(self, controller):
+ """Test putting a single sample with fields and tag."""
+ partition_id = "test_partition"
+ key = "sample_2"
+ # Use 1D tensors - kv_put with dict will auto-unsqueeze to add batch dimension
+ input_ids = torch.tensor([1, 2, 3])
+ attention_mask = torch.ones(3)
+ tag = {"global_steps": 0, "status": "running"}
+
+ # Put data using interface
+ tq.kv_put(
+ key=key,
+ partition_id=partition_id,
+ fields={"input_ids": input_ids, "attention_mask": attention_mask},
+ tag=tag,
+ )
+
+ # Verify via controller internal state
+ partition = get_controller_partition(controller, partition_id)
+ assert partition is not None, "Partition should exist"
+
+ # Check key->global_index mapping
+ assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping"
+ global_idx = partition.keys_mapping[key]
+ assert global_idx in partition.global_indexes, f"Global index {global_idx} should be in global_indexes"
+
+ # Check custom_meta (tag)
+ assert global_idx in partition.custom_meta, f"Custom meta should exist for global index {global_idx}"
+ assert partition.custom_meta[global_idx]["global_steps"] == 0
+ assert partition.custom_meta[global_idx]["status"] == "running"
+
+ # Check production status - fields should be marked as produced
+ assert "input_ids" in partition.field_name_mapping, "input_ids field should be registered"
+ assert "attention_mask" in partition.field_name_mapping, "attention_mask field should be registered"
+ input_ids_col_idx = partition.field_name_mapping["input_ids"]
+ assert partition.production_status[global_idx, input_ids_col_idx] == 1, "input_ids should be marked as produced"
+
+ # Retrieve and verify data via kv_batch_get - tensors will have batch dimension
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
+ assert "input_ids" in retrieved.keys()
+ assert "attention_mask" in retrieved.keys()
+ # After unsqueeze, tensors become 2D [batch_size=1, original_size]
+ expected_input_ids = input_ids.unsqueeze(0)
+ expected_attention_mask = attention_mask.unsqueeze(0)
+ assert_tensor_equal(retrieved["input_ids"], expected_input_ids)
+ assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask)
+
+ def test_kv_put_update_tag_only(self, controller):
+ """Test updating only tag without providing fields."""
+ partition_id = "test_partition"
+ key = "sample_3"
+
+ # First put with fields - use TensorDict as another example
+ single_data = TensorDict({"value": torch.tensor([[10]])}, batch_size=1)
+ tq.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1})
+
+ # Update only tag
+ new_tag = {"version": 2, "status": "updated"}
+ tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=new_tag)
+
+ # Verify via controller
+ partition = get_controller_partition(controller, partition_id)
+ global_idx = partition.keys_mapping[key]
+ assert partition.custom_meta[global_idx]["version"] == 2
+ assert partition.custom_meta[global_idx]["status"] == "updated"
+
+ # Data should still be accessible
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
+ assert_tensor_equal(retrieved["value"], torch.tensor([[10]]))
+
+ def test_kv_put_partial_update(self, controller):
+ """Test adding new fields to existing sample."""
+ partition_id = "test_partition"
+ key = "sample_4"
+
+ # First put initial data
+ initial_data = TensorDict(
+ {
+ "input_ids": torch.tensor([[1, 2, 3, 4]]),
+ },
+ batch_size=1,
+ )
+ tq.kv_put(key=key, partition_id=partition_id, fields=initial_data, tag={"v": 1})
+
+ # Add a new field to the existing key
+ new_fields = TensorDict(
+ {
+ "response": torch.tensor([[5, 6]]),
+ },
+ batch_size=1,
+ )
+ tq.kv_put(key=key, partition_id=partition_id, fields=new_fields, tag={"v": 2})
+
+ # Verify via controller - the key should now have the response field
+ partition = get_controller_partition(controller, partition_id)
+ global_idx = partition.keys_mapping[key]
+
+ # Check that fields were added
+ assert "response" in partition.field_name_mapping
+ response_col_idx = partition.field_name_mapping["response"]
+
+ # key should have response marked as produced
+ assert partition.production_status[global_idx, response_col_idx] == 1, "Key should have response"
+
+
+class TestKVBatchPutE2E:
+ """End-to-end tests for kv_batch_put functionality."""
+
+ def test_kv_batch_put_multiple_samples(self, controller):
+ """Test batch putting multiple samples."""
+ partition_id = "test_partition"
+ keys = ["batch_0", "batch_1", "batch_2", "batch_3"]
+ batch_input_ids = torch.tensor(
+ [
+ [4, 5, 6],
+ [7, 8, 9],
+ [10, 11, 12],
+ [13, 14, 15],
+ ]
+ )
+ batch_attention_mask = torch.ones_like(batch_input_ids)
+
+ fields = TensorDict(
+ {
+ "input_ids": batch_input_ids,
+ "attention_mask": batch_attention_mask,
+ },
+ batch_size=4,
+ )
+
+ tags = [{"idx": i, "batch": True} for i in range(4)]
+
+ # Batch put using interface
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags)
+
+ # Verify via controller
+ partition = get_controller_partition(controller, partition_id)
+ assert partition is not None
+
+ # All keys should be registered
+ for key in keys:
+ assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping"
+
+ # Verify tags
+ for i, key in enumerate(keys):
+ global_idx = partition.keys_mapping[key]
+ assert partition.custom_meta[global_idx]["idx"] == i
+ assert partition.custom_meta[global_idx]["batch"] is True
+
+ # Verify all data via kv_batch_get
+ retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id)
+ assert_tensor_equal(retrieved["input_ids"], batch_input_ids)
+ assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask)
+
+ def test_kv_batch_put_partial_update(self, controller):
+ """Test adding new fields to existing samples."""
+ partition_id = "test_partition"
+ keys = ["partial_0", "partial_1"]
+
+ # First put initial data
+ initial_data = TensorDict(
+ {
+ "input_ids": torch.tensor([[1, 2], [3, 4]]),
+ },
+ batch_size=2,
+ )
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=initial_data, tags=[{"v": 1}, {"v": 1}])
+
+ # Add new fields to subset of keys
+ new_fields = TensorDict(
+ {
+ "response": torch.tensor([[5, 6]]), # Only for 1 sample
+ },
+ batch_size=1,
+ )
+ tq.kv_batch_put(keys=[keys[1]], partition_id=partition_id, fields=new_fields, tags=[{"v": 2}])
+
+ # Verify via controller - only keys[1] should have response field
+ partition = get_controller_partition(controller, partition_id)
+ global_idx_1 = partition.keys_mapping[keys[1]]
+
+ # Check that fields were added
+ assert "response" in partition.field_name_mapping
+ response_col_idx = partition.field_name_mapping["response"]
+
+ # keys[0] should NOT have response marked as produced
+ global_idx_0 = partition.keys_mapping[keys[0]]
+ assert partition.production_status[global_idx_0, response_col_idx] == 0, "Keys[0] should not have response"
+
+ # keys[1] should have response marked as produced
+ assert partition.production_status[global_idx_1, response_col_idx] == 1, "Keys[1] should have response"
+
+
+class TestKVGetE2E:
+ """End-to-end tests for kv_batch_get functionality."""
+
+ def test_kv_batch_get_single_key(self, controller):
+ """Test getting data for a single key."""
+ partition_id = "test_partition"
+ key = "get_single"
+ # Use TensorDict to avoid auto-unsqueeze issue with dict input
+ expected_data = torch.tensor([[100, 200, 300]])
+ fields = TensorDict({"data": expected_data}, batch_size=1)
+
+ tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None)
+
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id)
+ assert_tensor_equal(retrieved["data"], expected_data)
+
+ def test_kv_batch_get_multiple_keys(self, controller):
+ """Test getting data for multiple keys."""
+ partition_id = "test_partition"
+ keys = ["get_multi_0", "get_multi_1", "get_multi_2"]
+ expected_data = torch.tensor([[1, 2], [3, 4], [5, 6]])
+
+ fields = TensorDict({"data": expected_data}, batch_size=3)
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}])
+
+ retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id)
+ assert_tensor_equal(retrieved["data"], expected_data)
+
+ def test_kv_batch_get_partial_keys(self, controller):
+ """Test getting data for partial keys."""
+ partition_id = "test_partition"
+ keys = ["get_multi_3", "get_multi_4", "get_multi_5"]
+ partial_keys = ["get_multi_3", "get_multi_5"]
+ input_data = torch.tensor([[1, 2], [3, 4], [5, 6]])
+ expected_data = torch.tensor([[1, 2], [5, 6]])
+
+ fields = TensorDict({"data": input_data}, batch_size=3)
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}])
+
+ retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id)
+ assert_tensor_equal(retrieved["data"], expected_data)
+
+ def test_kv_batch_get_partial_fields(self, controller):
+ """Test getting only partial fields."""
+ partition_id = "test_partition"
+ key = "get_fields"
+ # Use TensorDict to avoid auto-unsqueeze issue
+ input_ids = torch.tensor([[1, 2, 3]])
+ attention_mask = torch.ones(1, 3)
+ response = torch.tensor([[10, 20]])
+
+ fields = TensorDict(
+ {"input_ids": input_ids, "attention_mask": attention_mask, "response": response}, batch_size=1
+ )
+
+ # Put all fields
+ tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None)
+
+ # Get only input_ids
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields="input_ids")
+ assert "input_ids" in retrieved.keys()
+ assert "attention_mask" not in retrieved.keys()
+ assert "response" not in retrieved.keys()
+ assert_tensor_equal(retrieved["input_ids"], input_ids)
+
+ # Get multiple specific fields
+ retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"])
+ assert "input_ids" in retrieved.keys()
+ assert "response" in retrieved.keys()
+ assert "attention_mask" not in retrieved.keys()
+ assert_tensor_equal(retrieved["input_ids"], input_ids)
+ assert_tensor_equal(retrieved["response"], response)
+
+ def test_kv_batch_get_nonexistent_key(self, controller):
+ """Test that getting data for a non-existent key returns an empty result or raises an error."""
+ partition_id = "test_partition"
+
+ # Try to get data for a key that doesn't exist - should return empty or raise error
+ try:
+ retrieved = tq.kv_batch_get(keys="nonexistent_key", partition_id=partition_id)
+ # If it returns, it should be empty
+ assert retrieved.batch_size[0] == 0
+ except RuntimeError as e:
+ # Or it might raise an error about keys not found
+ assert "not found" in str(e).lower() or "empty" in str(e).lower()
+
+
+class TestKVListE2E:
+ """End-to-end tests for kv_list functionality."""
+
+ def test_kv_list_single_partition(self, controller):
+ """Test listing all keys and tags in single partition."""
+ partition_id = "test_partition"
+ keys = ["list_0", "list_1", "list_2"]
+
+ for i, key in enumerate(keys):
+ tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i})
+
+ # List all keys
+ partition_info = tq.kv_list(partition_id=partition_id)
+
+ assert len(partition_info.keys()) == 1
+ assert "test_partition" in partition_info.keys()
+ assert len(partition_info["test_partition"]) == 3
+ for key in keys:
+ assert key in partition_info["test_partition"]
+
+ # Verify tags match
+ for i, (key, tag) in enumerate(partition_info["test_partition"].items()):
+ assert tag["id"] == i
+
+ def test_kv_list_all_partitions(self, controller):
+ """Test listing keys and tags in all partitions."""
+ partition_id = ["test_partition0", "test_partition1", "test_partition2"]
+
+ keys_partition0 = ["list_0", "list_1", "list_2"]
+ keys_partition1 = ["list_0", "list_1", "list_2"]  # deliberately use the same keys as partition0
+ keys_partition2 = ["list_3", "list_4", "list_5", "list_6"]
+
+ fields_partition0 = TensorDict({"data": torch.tensor([[0], [1], [2]])}, batch_size=3)
+ fields_partition1 = TensorDict({"data": torch.tensor([[3], [4], [5]])}, batch_size=3)
+ fields_partition2 = TensorDict({"data": torch.tensor([[6], [7], [8], [9]])}, batch_size=4)
+
+ tags_partition0 = [{"id": i} for i in range(3)]
+ tags_partition1 = [{"id": i + 3} for i in range(3)]
+ tags_partition2 = [{"id": i + 6} for i in range(4)]
+
+ # Put to TQ
+ tq.kv_batch_put(
+ keys=keys_partition0, partition_id=partition_id[0], fields=fields_partition0, tags=tags_partition0
+ )
+ tq.kv_batch_put(
+ keys=keys_partition1, partition_id=partition_id[1], fields=fields_partition1, tags=tags_partition1
+ )
+ tq.kv_batch_put(
+ keys=keys_partition2, partition_id=partition_id[2], fields=fields_partition2, tags=tags_partition2
+ )
+
+ # List all keys
+ partition_info = tq.kv_list()
+
+ # Verify all partitions exist
+ assert len(partition_info.keys()) == 3
+ assert "test_partition0" in partition_info.keys()
+ assert "test_partition1" in partition_info.keys()
+ assert "test_partition2" in partition_info.keys()
+
+ assert len(partition_info["test_partition0"]) == 3
+ for key in keys_partition0:
+ assert key in partition_info["test_partition0"]
+
+ assert len(partition_info["test_partition1"]) == 3
+ for key in keys_partition1:
+ assert key in partition_info["test_partition1"]
+
+ assert len(partition_info["test_partition2"]) == 4
+ for key in keys_partition2:
+ assert key in partition_info["test_partition2"]
+
+ # Verify tags match
+ for i, (key, tag) in enumerate(partition_info["test_partition0"].items()):
+ assert tag["id"] == i
+ for i, (key, tag) in enumerate(partition_info["test_partition1"].items()):
+ assert tag["id"] == i + 3
+ for i, (key, tag) in enumerate(partition_info["test_partition2"].items()):
+ assert tag["id"] == i + 6
+
+ def test_kv_list_empty_partition(self):
+ """Test listing empty partition."""
+ partition_id = "test_partition_empty"
+
+ partition_info = tq.kv_list(partition_id=partition_id)
+
+ assert len(partition_info) == 0
+
+
+class TestKVClearE2E:
+ """End-to-end tests for kv_clear functionality."""
+
+ def test_kv_clear_single_key(self, controller):
+ """Test clearing a single key."""
+ partition_id = "test_partition"
+ key = "clear_single"
+ other_key = "clear_other"
+
+ tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[1]])}, tag={"id": "single"})
+ tq.kv_put(key=other_key, partition_id=partition_id, fields={"data": torch.tensor([[2]])}, tag={"id": "other"})
+
+ # Clear single key
+ tq.kv_clear(keys=key, partition_id=partition_id)
+
+ # Verify via kv_list
+ partition_info = tq.kv_list(partition_id=partition_id)
+ assert key not in partition_info[partition_id]
+ assert other_key in partition_info[partition_id]
+
+ # Verify via controller - key should be removed
+ partition = get_controller_partition(controller, partition_id)
+ assert key not in partition.keys_mapping
+ assert other_key in partition.keys_mapping
+
+ def test_kv_clear_multiple_keys(self, controller):
+ """Test clearing multiple keys."""
+ partition_id = "test_partition"
+ keys = ["clear_multi_0", "clear_multi_1", "clear_multi_2", "clear_multi_3"]
+
+ for i, key in enumerate(keys):
+ tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag=None)
+
+ # Clear first 2 keys
+ tq.kv_clear(keys=keys[:2], partition_id=partition_id)
+
+ # Verify
+ partition_info = tq.kv_list(partition_id=partition_id)
+ assert len(partition_info[partition_id]) == 2
+ assert keys[0] not in partition_info[partition_id]
+ assert keys[1] not in partition_info[partition_id]
+ assert keys[2] in partition_info[partition_id]
+ assert keys[3] in partition_info[partition_id]
+
+
+class TestKVE2ECornerCases:
+ """End-to-end tests for corner cases."""
+
+ def test_field_expansion_across_samples(self, controller):
+ """Test that new fields can be added across samples."""
+ partition_id = "test_partition"
+ keys = ["expand_0", "expand_1"]
+
+ # Put initial fields
+ tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_a": torch.tensor([[1]])}, tag=None)
+
+ # Add new field to first key
+ tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None)
+
+ # Add different field to second key
+ tq.kv_put(
+ key=keys[1],
+ partition_id=partition_id,
+ fields={"field_a": torch.tensor([[3]]), "field_c": torch.tensor([[4]])},
+ tag=None,
+ )
+
+ # Verify field expansion in controller
+ partition = get_controller_partition(controller, partition_id)
+
+ # All fields should be registered, but only samples with the actual fields are labeled as READY_FOR_CONSUME
+ assert "field_a" in partition.field_name_mapping
+ assert "field_b" in partition.field_name_mapping
+ assert "field_c" in partition.field_name_mapping
+
+ # We can only fetch "field_a" because not all requested keys have the other fields
+ data = tq.kv_batch_get(keys=keys, partition_id=partition_id)
+ assert "field_a" in data
+ assert "field_b" not in data
+ assert "field_c" not in data
+
+
+def run_tests():
+ """Run all e2e tests manually for debugging."""
+ pytest.main([__file__, "-v", "-s"])
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/tests/test_client.py b/tests/test_client.py
index 42cd63b..5d308d8 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -34,7 +34,7 @@
FieldMeta,
SampleMeta,
)
-from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402
+from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole # noqa: E402
from transfer_queue.utils.zmq_utils import ( # noqa: E402
ZMQMessage,
ZMQRequestType,
@@ -144,6 +144,12 @@ def _handle_requests(self):
"message": "Consumption reset successfully",
}
response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE
+ elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
+ response_body = self._mock_kv_retrieve_keys(request_msg.body)
+ response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE
+ elif request_msg.request_type == ZMQRequestType.KV_LIST:
+ response_body = self._mock_kv_list(request_msg.body)
+ response_type = ZMQRequestType.KV_LIST_RESPONSE
else:
response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
response_type = ZMQRequestType.CLEAR_META_RESPONSE
@@ -174,7 +180,7 @@ def _mock_batch_meta(self, request_body):
name=field_name,
dtype=None,
shape=None,
- production_status=0,
+ production_status=ProductionStatus.NOT_PRODUCED,
)
fields.append(field_meta)
sample = SampleMeta(
@@ -187,6 +193,81 @@ def _mock_batch_meta(self, request_body):
return {"metadata": metadata}
+ def _mock_kv_retrieve_keys(self, request_body):
+ """Mock KV retrieve keys response."""
+ keys = request_body.get("keys", [])
+ create = request_body.get("create", False)
+ partition_id = request_body.get("partition_id", "")
+
+ # Initialize key tracking if not exists
+ if not hasattr(self, "_kv_partition_keys"):
+ self._kv_partition_keys = {}
+
+ # Generate global indexes for the keys
+ start_index = self._get_next_kv_index(partition_id)
+ global_indexes = list(range(start_index, start_index + len(keys)))
+
+ # Create metadata for each key
+ samples = []
+ for i, key in enumerate(keys):
+ field_meta = FieldMeta(
+ name="data",
+ dtype=torch.float32,
+ shape=torch.Size([1, 10]),
+ production_status=ProductionStatus.READY_FOR_CONSUME,
+ )
+ sample = SampleMeta(
+ partition_id=partition_id,
+ global_index=global_indexes[i],
+ fields={"data": field_meta},
+ )
+ samples.append(sample)
+
+ metadata = BatchMeta(samples=samples)
+
+ # Store keys for this partition (only when create=True)
+ if create:
+ if partition_id not in self._kv_partition_keys:
+ self._kv_partition_keys[partition_id] = []
+ self._kv_partition_keys[partition_id].extend(keys)
+
+ # Update the next index for this partition
+ if global_indexes:
+ self._update_kv_index(partition_id, global_indexes[-1] + 1)
+
+ return {"metadata": metadata}
+
+ def _mock_kv_list(self, request_body):
+ """Mock KV list response."""
+ partition_id = request_body.get("partition_id", None)
+
+ # Initialize key tracking if not exists
+ if not hasattr(self, "_kv_partition_keys"):
+ self._kv_partition_keys = {}
+
+ # Return cached keys for this partition
+ keys = self._kv_partition_keys.get(partition_id, [])
+
+ return {"partition_info": {partition_id: {k: {} for k in keys}}, "message": "success"}
+
+ def _get_next_kv_index(self, partition_id):
+ """Get next available index for KV keys in partition."""
+ if not hasattr(self, "_kv_index_map"):
+ self._kv_index_map = {}
+ if partition_id not in self._kv_index_map:
+ self._kv_index_map[partition_id] = 0
+ # Also initialize key tracking
+ if not hasattr(self, "_kv_partition_keys"):
+ self._kv_partition_keys = {}
+ self._kv_partition_keys[partition_id] = []
+ return self._kv_index_map[partition_id]
+
+ def _update_kv_index(self, partition_id, next_index):
+ """Update next available index for KV keys."""
+ if not hasattr(self, "_kv_index_map"):
+ self._kv_index_map = {}
+ self._kv_index_map[partition_id] = next_index
+
def stop(self):
self.running = False
time.sleep(0.2) # Give thread time to stop
@@ -850,10 +931,10 @@ def test_set_custom_meta_sync(self, client_setup):
metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
# Set custom_meta on the metadata
metadata.update_custom_meta(
- {
- 0: {"input_ids": {"token_count": 100}},
- 1: {"input_ids": {"token_count": 120}},
- }
+ [
+ {"input_ids": {"token_count": 100}},
+ {"input_ids": {"token_count": 120}},
+ ]
)
# Call set_custom_meta with metadata (BatchMeta)
@@ -869,12 +950,165 @@ async def test_set_custom_meta_async(self, client_setup):
metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
# Set custom_meta on the metadata
metadata.update_custom_meta(
- {
- 0: {"input_ids": {"token_count": 100}},
- 1: {"input_ids": {"token_count": 120}},
- }
+ [
+ {"input_ids": {"token_count": 100}},
+ {"input_ids": {"token_count": 120}},
+ ]
)
# Call async_set_custom_meta with metadata (BatchMeta)
await client.async_set_custom_meta(metadata)
print("✓ async_set_custom_meta async method works")
+
+
+# =====================================================
+# KV Interface Tests
+# =====================================================
+
+
+class TestClientKVInterface:
+ """Tests for client KV interface methods."""
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_single(self, client_setup):
+ """Test async_kv_retrieve_keys with single key."""
+ client, _, _ = client_setup
+
+ # Test async_kv_retrieve_keys with single key
+ metadata = await client.async_kv_retrieve_keys(
+ keys="test_key_1",
+ partition_id="test_partition",
+ create=True,
+ )
+
+ # Verify metadata structure
+ assert metadata is not None
+ assert hasattr(metadata, "global_indexes")
+ assert hasattr(metadata, "size")
+ assert metadata.size == 1
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_multiple(self, client_setup):
+ """Test async_kv_retrieve_keys with multiple keys."""
+ client, _, _ = client_setup
+
+ # Test async_kv_retrieve_keys with multiple keys
+ keys = ["key_a", "key_b", "key_c"]
+ metadata = await client.async_kv_retrieve_keys(
+ keys=keys,
+ partition_id="test_partition",
+ create=True,
+ )
+
+ # Verify metadata structure
+ assert metadata is not None
+ assert hasattr(metadata, "global_indexes")
+ assert hasattr(metadata, "size")
+ assert metadata.size == 3
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_create_false(self, client_setup):
+ """Test async_kv_retrieve_keys with create=False (retrieve existing keys)."""
+ client, _, _ = client_setup
+
+ # create some keys
+ await client.async_kv_retrieve_keys(
+ keys="existing_key",
+ partition_id="existing_partition",
+ create=True,
+ )
+
+ # Then retrieve them with create=False
+ metadata = await client.async_kv_retrieve_keys(
+ keys="existing_key",
+ partition_id="existing_partition",
+ create=False,
+ )
+
+ # Verify metadata structure
+ assert metadata is not None
+ assert metadata.size == 1
+
+ @pytest.mark.asyncio
+ async def test_async_kv_retrieve_keys_invalid_keys_type(self, client_setup):
+ """Test async_kv_retrieve_keys raises error with invalid keys type."""
+ client, _, _ = client_setup
+
+ # Test with invalid keys type (not string or list)
+ with pytest.raises(TypeError):
+ await client.async_kv_retrieve_keys(
+ keys=123, # Invalid type
+ partition_id="test_partition",
+ create=True,
+ )
+
+ @pytest.mark.asyncio
+ async def test_async_kv_list_with_keys(self, client_setup):
+ """Test async_kv_list returns keys after they are registered."""
+ client, mock_controller, _ = client_setup
+
+ # First register some keys
+ await client.async_kv_retrieve_keys(
+ keys=["key_1", "key_2"],
+ partition_id="kv_partition",
+ create=True,
+ )
+
+ # Then list them
+ partition_info = await client.async_kv_list(partition_id="kv_partition")
+
+ # Verify keys are returned
+ assert len(partition_info["kv_partition"]) >= 2
+ assert "key_1" in partition_info["kv_partition"]
+ assert "key_2" in partition_info["kv_partition"]
+ assert list(partition_info["kv_partition"].values()) == [{}, {}]
+
+ @pytest.mark.asyncio
+ async def test_async_kv_list_multiple_partitions(self, client_setup):
+ """Test async_kv_list with multiple partitions."""
+ client, _, _ = client_setup
+
+ # Create keys in different partitions
+ await client.async_kv_retrieve_keys(
+ keys="partition_a_key",
+ partition_id="partition_a",
+ create=True,
+ )
+ await client.async_kv_retrieve_keys(
+ keys="partition_b_key",
+ partition_id="partition_b",
+ create=True,
+ )
+
+ # List keys for each partition
+ partition_a = await client.async_kv_list(partition_id="partition_a")
+ partition_b = await client.async_kv_list(partition_id="partition_b")
+
+ # Verify keys are isolated per partition
+ assert "partition_a" in partition_a
+ assert "partition_b" in partition_b
+ assert "partition_a" not in partition_b
+ assert "partition_b" not in partition_a
+ assert "partition_a_key" in partition_a["partition_a"]
+ assert "partition_b_key" not in partition_a["partition_a"]
+ assert "partition_b_key" in partition_b["partition_b"]
+ assert "partition_a_key" not in partition_b["partition_b"]
+ assert list(partition_a["partition_a"].values()) == [{}]
+ assert list(partition_b["partition_b"].values()) == [{}]
+
+ def test_kv_retrieve_keys_type_validation(self, client_setup):
+ """Test kv_retrieve_keys type validation with a non-string element in the keys list."""
+ import asyncio
+
+ client, _, _ = client_setup
+
+ # Test with non-string element in list
+ async def test_invalid_list():
+ with pytest.raises(TypeError):
+ await client.async_kv_retrieve_keys(
+ keys=["valid_key", 123], # Invalid: 123 is not a string
+ partition_id="test_partition",
+ create=True,
+ )
+
+ asyncio.run(test_invalid_list())
diff --git a/tests/test_controller.py b/tests/test_controller.py
index 09f5778..3528d1a 100644
--- a/tests/test_controller.py
+++ b/tests/test_controller.py
@@ -28,7 +28,7 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-from transfer_queue import TransferQueueController # noqa: E402
+from transfer_queue.controller import TransferQueueController # noqa: E402
from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
@@ -760,3 +760,205 @@ def test_controller_with_custom_meta(self, ray_setup):
# Clean up
ray.get(tq_controller.clear_partition.remote(partition_id))
+
+
+class TestTransferQueueControllerKvInterface:
+ """End-to-end tests for TransferQueueController KV interface functionality.
+
+ Tests for kv_retrieve_keys method that supports key-value interface operations
+ across the controller and partition layers.
+ """
+
+ def test_controller_kv_retrieve_keys_create_mode(self, ray_setup):
+ """Test kv_retrieve_keys with create=True creates new keys in partition."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_test_partition"
+
+ # Retrieve keys with create=True - should create partition and keys
+ keys = ["key_a", "key_b", "key_c"]
+ metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+
+ # Verify partition was created
+ partitions = ray.get(tq_controller.list_partitions.remote())
+ assert partition_id in partitions
+
+ # Verify metadata contains correct number of global_indexes
+ assert len(metadata.global_indexes) == len(keys)
+
+ # Verify partition has keys_mapping
+ partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
+ assert "key_a" in partition.keys_mapping
+ assert "key_b" in partition.keys_mapping
+ assert "key_c" in partition.keys_mapping
+ assert metadata.global_indexes[0] == partition.keys_mapping["key_a"]
+ assert metadata.global_indexes[1] == partition.keys_mapping["key_b"]
+ assert metadata.global_indexes[2] == partition.keys_mapping["key_c"]
+ assert partition.revert_keys_mapping[metadata.global_indexes[0]] == "key_a"
+ assert partition.revert_keys_mapping[metadata.global_indexes[1]] == "key_b"
+ assert partition.revert_keys_mapping[metadata.global_indexes[2]] == "key_c"
+
+ print("✓ kv_retrieve_keys with create=True creates keys correctly")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_existing_keys(self, ray_setup):
+ """Test kv_retrieve_keys retrieves existing keys correctly."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_existing_test"
+
+ # First, create some keys
+ keys = ["existing_key_1", "existing_key_2"]
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+
+ # Retrieve the same keys again (should return existing)
+ retrieved_metadata = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False)
+ )
+
+ # Verify the same global_indexes are returned
+ assert len(retrieved_metadata.global_indexes) == len(keys)
+
+ print("✓ kv_retrieve_keys retrieves existing keys correctly")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_non_existent_without_create(self, ray_setup):
+ """Test kv_retrieve_keys returns an empty BatchMeta for non-existent keys without create."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_nonexistent_test"
+
+ # Create partition first
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=["initial_key"], partition_id=partition_id, create=True))
+
+ # Try to retrieve non-existent key without create
+ batch_meta = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=["nonexistent_key"], partition_id=partition_id, create=False)
+ )
+ assert batch_meta.size == 0
+
+ print("✓ kv_retrieve_keys returns an empty BatchMeta for non-existent keys without create")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_empty_partition_without_create(self, ray_setup):
+ """Test kv_retrieve_keys returns an empty BatchMeta for a non-existent partition without create."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "nonexistent_partition"
+
+ batch_meta = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=["key_1"], partition_id=partition_id, create=False)
+ )
+ assert batch_meta.size == 0
+
+ print("✓ kv_retrieve_keys returns an empty BatchMeta for a non-existent partition_id without create")
+
+ def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup):
+ """Test kv_retrieve_keys works with production status update."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_production_test"
+
+ # Create keys
+ keys = ["sample_1", "sample_2", "sample_3"]
+ metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+ global_indexes = metadata.global_indexes
+
+ # Update production status
+ dtypes = {idx: {"data": "torch.float32"} for idx in global_indexes}
+ shapes = {idx: {"data": (64,)} for idx in global_indexes}
+ success = ray.get(
+ tq_controller.update_production_status.remote(
+ partition_id=partition_id,
+ global_indexes=global_indexes,
+ field_names=["data"],
+ dtypes=dtypes,
+ shapes=shapes,
+ )
+ )
+ assert success
+
+ # Retrieve keys again (should include production info)
+ retrieved_metadata = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False)
+ )
+
+ # Verify production status is available
+ assert len(retrieved_metadata.samples) == len(keys)
+ for sample in retrieved_metadata.samples:
+ assert "data" in sample.fields
+ assert sample.fields["data"].dtype == "torch.float32"
+ assert sample.fields["data"].shape == (64,)
+
+ print("✓ kv_retrieve_keys works with production status")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup):
+ """Test kv_retrieve_keys preserves custom_meta through retrieve."""
+ tq_controller = TransferQueueController.remote()
+ partition_id = "kv_custom_meta_test"
+
+ # Create keys
+ keys = ["key_1", "key_2"]
+ metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True))
+
+ # Set custom_meta
+ custom_meta = {
+ partition_id: {
+ metadata.global_indexes[0]: {"score": 0.9, "tag": "A"},
+ metadata.global_indexes[1]: {"score": 0.8, "tag": "B"},
+ }
+ }
+ ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta))
+
+ # Retrieve keys and verify custom_meta
+ retrieved_metadata = ray.get(
+ tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False)
+ )
+
+ # Verify custom_meta is preserved
+ all_custom_meta = retrieved_metadata.get_all_custom_meta()
+ assert len(all_custom_meta) == 2
+ assert all_custom_meta[0]["score"] == 0.9
+ assert all_custom_meta[1]["tag"] == "B"
+
+ print("✓ kv_retrieve_keys preserves custom_meta")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_id))
+
+ def test_controller_kv_interface_multiple_partitions(self, ray_setup):
+ """Test KV interface works correctly across multiple partitions."""
+ tq_controller = TransferQueueController.remote()
+
+ # Create keys in partition 1
+ partition_1 = "partition_kv_1"
+ keys_1 = ["p1_key_a", "p1_key_b"]
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_1, partition_id=partition_1, create=True))
+
+ # Create keys in partition 2
+ partition_2 = "partition_kv_2"
+ keys_2 = ["p2_key_x", "p2_key_y", "p2_key_z"]
+ ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_2, partition_id=partition_2, create=True))
+
+ # Verify partitions are isolated
+ partition_1_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_1))
+ partition_2_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_2))
+
+ assert "p1_key_a" in partition_1_snapshot.keys_mapping
+ assert "p1_key_b" in partition_1_snapshot.keys_mapping
+ assert "p2_key_x" in partition_2_snapshot.keys_mapping
+ assert "p2_key_z" in partition_2_snapshot.keys_mapping
+
+ # Verify cross-partition access is isolated
+ assert "p2_key_x" not in partition_1_snapshot.keys_mapping
+ assert "p1_key_a" not in partition_2_snapshot.keys_mapping
+
+ print("✓ KV interface maintains partition isolation")
+
+ # Clean up
+ ray.get(tq_controller.clear_partition.remote(partition_1))
+ ray.get(tq_controller.clear_partition.remote(partition_2))
diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py
index 31478ba..6eba7a0 100644
--- a/tests/test_controller_data_partitions.py
+++ b/tests/test_controller_data_partitions.py
@@ -941,3 +941,62 @@ def test_custom_meta_cleared_with_data(self):
result = partition.get_custom_meta([0, 1])
assert 0 not in result
assert 1 in result # Sample 1 should still have custom_meta
+
+
+class TestDataPartitionStatusKvInterface:
+ """Unit tests for DataPartitionStatus KV interface functionality.
+
+ Tests for the keys_mapping and kv_retrieve_keys methods that support
+ key-value interface operations within a partition.
+ """
+
+ def test_kv_retrieve_keys_with_existing_keys(self):
+ """Test kv_retrieve_keys returns correct global_indexes for existing keys."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ # Simulate keys being registered (as would happen during kv_put)
+ partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2}
+
+ # Retrieve keys
+ global_indexes = partition.kv_retrieve_keys(["key_a", "key_b", "key_c"])
+
+ assert global_indexes == [0, 1, 2]
+
+ def test_kv_retrieve_keys_with_nonexistent_keys(self):
+ """Test kv_retrieve_keys returns None for keys that don't exist."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ # Simulate some keys being registered
+ partition.keys_mapping = {"existing_key": 5}
+
+ # Retrieve mixed existing and non-existing keys
+ global_indexes = partition.kv_retrieve_keys(["existing_key", "nonexistent_key"])
+
+ assert global_indexes == [5, None]
+
+ def test_kv_retrieve_keys_empty_list(self):
+ """Test kv_retrieve_keys handles empty key list."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ global_indexes = partition.kv_retrieve_keys([])
+
+ assert global_indexes == []
+
+ def test_kv_retrieve_keys_partial_match(self):
+ """Test kv_retrieve_keys with partial key matches."""
+ from transfer_queue.controller import DataPartitionStatus
+
+ partition = DataPartitionStatus(partition_id="kv_test_partition")
+
+ partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30}
+
+ # Request only some of the keys
+ global_indexes = partition.kv_retrieve_keys(["key_1", "key_3"])
+
+ assert global_indexes == [10, 30]
diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py
index 083e16d..ec5b29c 100644
--- a/tests/test_kv_storage_manager.py
+++ b/tests/test_kv_storage_manager.py
@@ -27,6 +27,7 @@
from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402
+from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
def get_meta(data, global_indexes=None):
@@ -41,7 +42,7 @@ def get_meta(data, global_indexes=None):
name=field_name,
dtype=tensor.dtype if isinstance(tensor, torch.Tensor) else None,
shape=tensor.shape if isinstance(tensor, torch.Tensor) else None,
- production_status=1,
+ production_status=ProductionStatus.READY_FOR_CONSUME,
)
fields_dict[field_name] = field_meta
sample = SampleMeta(
@@ -163,8 +164,8 @@ def test_merge_tensors_to_tensordict(mock_create, test_data):
assert complex_tensordict[key] == complex_data[key]
-def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data):
- """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_meta."""
+def test_get_shape_type_custom_backend_meta_list_without_custom_backend_meta(test_data):
+ """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_backend_meta."""
shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(
test_data["metadata"]
)
@@ -185,7 +186,7 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data):
torch.Size([2]), # text[2]
]
expected_dtypes = [torch.int64] * (len(test_data["field_names"]) * len(test_data["global_indexes"]))
- # No custom_meta provided, so all should be None
+ # No custom_backend_meta provided, so all should be None
expected_custom_backend_meta = [None] * (len(test_data["field_names"]) * len(test_data["global_indexes"]))
assert shapes == expected_shapes
@@ -193,9 +194,9 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data):
assert custom_backend_meta_list == expected_custom_backend_meta
-def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data):
- """Test _get_shape_type_custom_meta_list returns correct custom_meta when provided."""
- # Add custom_meta to metadata
+def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_data):
+ """Test _get_shape_type_custom_backend_meta_list returns correct custom_backend_meta when provided."""
+ # Add custom_backend_meta to metadata
custom_backend_meta = {
8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}},
9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}},
@@ -206,8 +207,8 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data):
shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata)
- # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index
- expected_custom_meta = [
+ # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index
+ expected_custom_backend_meta = [
{"key2": "value2"}, # label, global_index=8
{"key5": "value5"}, # label, global_index=9
{"key8": "value8"}, # label, global_index=10
@@ -218,15 +219,15 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data):
{"key4": "value4"}, # text, global_index=9
{"key7": "value7"}, # text, global_index=10
]
- assert custom_backend_meta_list == expected_custom_meta
+ assert custom_backend_meta_list == expected_custom_backend_meta
-def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_data):
- """Test _get_shape_type_custom_backend_meta_list handles partial custom_meta correctly."""
- # Add custom_meta only for some global_indexes and fields
+def test_get_shape_type_custom_backend_meta_list_with_partial_custom_backend_meta(test_data):
+ """Test _get_shape_type_custom_backend_meta_list handles partial custom_backend_meta correctly."""
+ # Add custom_backend_meta only for some global_indexes and fields
custom_backend_meta = {
8: {"text": {"key1": "value1"}}, # Only text field
- # global_index 9 has no custom_meta
+ # global_index 9 has no custom_backend_meta
10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only
}
metadata = test_data["metadata"]
@@ -234,17 +235,17 @@ def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_d
shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata)
- # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index
+ # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index
expected_custom_backend_meta = [
- None, # label, global_index=8 (not in custom_meta)
- None, # label, global_index=9 (not in custom_meta)
+ None, # label, global_index=8 (not in custom_backend_meta)
+ None, # label, global_index=9 (not in custom_backend_meta)
{"key2": "value2"}, # label, global_index=10
- None, # mask, global_index=8 (not in custom_meta)
- None, # mask, global_index=9 (not in custom_meta)
+ None, # mask, global_index=8 (not in custom_backend_meta)
+ None, # mask, global_index=9 (not in custom_backend_meta)
{"key3": "value3"}, # mask, global_index=10
{"key1": "value1"}, # text, global_index=8
- None, # text, global_index=9 (not in custom_meta)
- None, # text, global_index=10 (not in custom_meta for text)
+ None, # text, global_index=9 (not in custom_backend_meta)
+ None, # text, global_index=10 (not in custom_backend_meta for text)
]
assert custom_backend_meta_list == expected_custom_backend_meta
@@ -279,13 +280,13 @@ def test_data_for_put_data():
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
-def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_for_put_data):
- """Test that put_data correctly processes custom_meta returned by storage client."""
+def test_put_data_with_custom_backend_meta_from_storage_client(mock_notify, test_data_for_put_data):
+ """Test that put_data correctly processes custom_backend_meta returned by storage client."""
# Create a mock storage client
mock_storage_client = MagicMock()
- # Simulate storage client returning custom_meta (one per key)
+ # Simulate storage client returning custom_backend_meta (one per key)
# Keys order: label[0,1,2], text[0,1,2] (sorted by field name)
- mock_custom_meta = [
+ mock_custom_backend_meta = [
{"storage_key": "0@label"},
{"storage_key": "1@label"},
{"storage_key": "2@label"},
@@ -293,7 +294,7 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
{"storage_key": "1@text"},
{"storage_key": "2@text"},
]
- mock_storage_client.put.return_value = mock_custom_meta
+ mock_storage_client.put.return_value = mock_custom_backend_meta
# Create manager with mocked dependencies
config = {"client_name": "MockClient"}
@@ -314,34 +315,35 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
assert keys == expected_keys
assert len(values) == 6
- # Verify notify_data_update was called with correct custom_meta structure
+ # Verify notify_data_update was called with correct custom_backend_meta structure
mock_notify.assert_called_once()
notify_call_args = mock_notify.call_args
- per_field_custom_meta = notify_call_args[0][5] # 6th positional argument
+ per_field_custom_backend_meta = notify_call_args[0][5] # 6th positional argument
- # Verify custom_meta is structured correctly: {global_index: {field: meta}}
- assert 0 in per_field_custom_meta
- assert 1 in per_field_custom_meta
- assert 2 in per_field_custom_meta
+ # Verify custom_backend_meta is structured correctly: {global_index: {field: meta}}
+ assert 0 in per_field_custom_backend_meta
+ assert 1 in per_field_custom_backend_meta
+ assert 2 in per_field_custom_backend_meta
- assert per_field_custom_meta[0]["label"] == {"storage_key": "0@label"}
- assert per_field_custom_meta[0]["text"] == {"storage_key": "0@text"}
- assert per_field_custom_meta[1]["label"] == {"storage_key": "1@label"}
- assert per_field_custom_meta[1]["text"] == {"storage_key": "1@text"}
- assert per_field_custom_meta[2]["label"] == {"storage_key": "2@label"}
- assert per_field_custom_meta[2]["text"] == {"storage_key": "2@text"}
+ assert per_field_custom_backend_meta[0]["label"] == {"storage_key": "0@label"}
+ assert per_field_custom_backend_meta[0]["text"] == {"storage_key": "0@text"}
+ assert per_field_custom_backend_meta[1]["label"] == {"storage_key": "1@label"}
+ assert per_field_custom_backend_meta[1]["text"] == {"storage_key": "1@text"}
+ assert per_field_custom_backend_meta[2]["label"] == {"storage_key": "2@label"}
+ assert per_field_custom_backend_meta[2]["text"] == {"storage_key": "2@text"}
- # Verify metadata was updated with custom_meta
- all_custom_meta = test_data_for_put_data["metadata"].get_all_custom_meta()
- assert all_custom_meta[0]["label"] == {"storage_key": "0@label"}
- assert all_custom_meta[2]["text"] == {"storage_key": "2@text"}
+ # Verify metadata was updated with custom_backend_meta
+ all_custom_backend_meta = test_data_for_put_data["metadata"]._custom_backend_meta
+ assert len(all_custom_backend_meta) == 3
+ assert all_custom_backend_meta[0]["label"] == {"storage_key": "0@label"}
+ assert all_custom_backend_meta[2]["text"] == {"storage_key": "2@text"}
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
-def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
- """Test that put_data works correctly when storage client returns no custom_meta."""
- # Create a mock storage client that returns None for custom_meta
+def test_put_data_without_custom_backend_meta(mock_notify, test_data_for_put_data):
+ """Test that put_data works correctly when storage client returns no custom_backend_meta."""
+ # Create a mock storage client that returns None for custom_backend_meta
mock_storage_client = MagicMock()
mock_storage_client.put.return_value = None
@@ -353,19 +355,19 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
# Run put_data
asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))
- # Verify notify_data_update was called with empty dict for custom_meta
+ # Verify notify_data_update was called with empty dict for custom_backend_meta
mock_notify.assert_called_once()
notify_call_args = mock_notify.call_args
- per_field_custom_meta = notify_call_args[0][5] # 6th positional argument
- assert per_field_custom_meta == {}
+ per_field_custom_backend_meta = notify_call_args[0][5] # 6th positional argument
+ assert per_field_custom_backend_meta == {}
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
-def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_data):
- """Test that put_data raises ValueError when custom_meta length doesn't match keys."""
- # Create a mock storage client that returns mismatched custom_meta length
+def test_put_data_custom_backend_meta_length_mismatch_raises_error(test_data_for_put_data):
+ """Test that put_data raises ValueError when custom_backend_meta length doesn't match keys."""
+ # Create a mock storage client that returns mismatched custom_backend_meta length
mock_storage_client = MagicMock()
- # Return only 3 custom_meta entries when 6 are expected
+ # Return only 3 custom_backend_meta entries when 6 are expected
mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]
# Create manager with mocked dependencies
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 2bbf40c..2a129b5 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -27,7 +27,7 @@
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
+from transfer_queue.metadata import BatchMeta, FieldMeta, KVBatchMeta, SampleMeta # noqa: E402
from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402
@@ -788,64 +788,6 @@ def test_batch_meta_select_samples_with_extra_info(self):
# =====================================================
# Custom Meta Tests
# =====================================================
-
- def test_batch_meta_set_custom_meta_basic(self):
- """Test set_custom_meta sets metadata for a sample by global_index."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Set custom_meta for sample 0
- batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"})
-
- result = batch.get_all_custom_meta()
- assert 0 in result
- assert result[0]["sample_score"] == 0.9
- assert result[0]["quality"] == "high"
- # Sample 1 should not have custom_meta
- assert 1 not in result
-
- def test_batch_meta_set_custom_meta_overwrites(self):
- """Test set_custom_meta overwrites existing metadata."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Set initial custom_meta
- batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"})
-
- # Overwrite with new custom_meta
- batch.set_custom_meta(0, {"sample_score": 0.1, "quality": "low"})
-
- result = batch.get_all_custom_meta()
- assert result[0]["sample_score"] == 0.1
- assert result[0]["quality"] == "low"
-
- def test_batch_meta_set_custom_meta_invalid_global_index(self):
- """Test set_custom_meta raises error for invalid global_index."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Try to set with non-existent global index
- with pytest.raises(ValueError) as exc_info:
- batch.set_custom_meta(999, {"sample_score": 0.9})
- assert "not found in global_indexes" in str(exc_info.value)
-
def test_batch_meta_update_custom_meta(self):
"""Test update_custom_meta adds metadata for different global indices."""
fields = {
@@ -859,10 +801,7 @@ def test_batch_meta_update_custom_meta(self):
batch = BatchMeta(samples=samples)
# Initial custom_meta for sample 0
- batch.update_custom_meta({0: {"sample_score": 0.9}})
-
- # Update with metadata for sample 1
- batch.update_custom_meta({1: {"sample_score": 0.1}})
+ batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}])
result = batch.get_all_custom_meta()
assert result[0]["sample_score"] == 0.9
@@ -879,10 +818,10 @@ def test_batch_meta_update_custom_meta_overwrites(self):
batch = BatchMeta(samples=samples)
# Initial custom_meta
- batch.update_custom_meta({0: {"sample_score": 0.9, "quality": "high"}})
+ batch.update_custom_meta([{"sample_score": 0.9, "quality": "high"}])
# Update with new value for same field - dict.update replaces
- batch.update_custom_meta({0: {"sample_score": 0.1, "quality": "low"}})
+ batch.update_custom_meta([{"sample_score": 0.1, "quality": "low"}])
result = batch.get_all_custom_meta()
assert result[0]["sample_score"] == 0.1
@@ -899,7 +838,7 @@ def test_batch_meta_update_custom_meta_with_none(self):
batch = BatchMeta(samples=samples)
# Set initial value
- batch.update_custom_meta({0: {"sample_score": 0.9}})
+ batch.update_custom_meta([{"sample_score": 0.9}])
# Update with None should not change anything
batch.update_custom_meta(None)
@@ -907,40 +846,6 @@ def test_batch_meta_update_custom_meta_with_none(self):
result = batch.get_all_custom_meta()
assert result[0]["sample_score"] == 0.9
- def test_batch_meta_update_custom_meta_with_empty_dict(self):
- """Test update_custom_meta with empty dict does nothing."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Set initial value
- batch.update_custom_meta({0: {"sample_score": 0.9}})
-
- # Update with empty dict should not change anything
- batch.update_custom_meta({})
-
- result = batch.get_all_custom_meta()
- assert result[0]["sample_score"] == 0.9
-
- def test_batch_meta_update_custom_meta_invalid_global_index(self):
- """Test update_custom_meta raises error for invalid global_index."""
- fields = {
- "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)),
- }
- samples = [
- SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
- ]
- batch = BatchMeta(samples=samples)
-
- # Try to update with non-existent global index
- with pytest.raises(ValueError) as exc_info:
- batch.update_custom_meta({999: {"sample_score": 0.9}})
- assert "non-exist global_indexes" in str(exc_info.value)
-
def test_batch_meta_clear_custom_meta(self):
"""Test clear_custom_meta removes all custom metadata."""
fields = {
@@ -953,14 +858,13 @@ def test_batch_meta_clear_custom_meta(self):
batch = BatchMeta(samples=samples)
# Set custom_meta
- batch.set_custom_meta(0, {"sample_score": 0.9})
- batch.set_custom_meta(1, {"sample_score": 0.1})
+ batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}])
# Clear all
batch.clear_custom_meta()
result = batch.get_all_custom_meta()
- assert result == {}
+ assert result == [{}, {}]
def test_batch_meta_get_all_custom_meta_returns_deep_copy(self):
"""Test get_all_custom_meta returns a deep copy."""
@@ -972,7 +876,7 @@ def test_batch_meta_get_all_custom_meta_returns_deep_copy(self):
]
batch = BatchMeta(samples=samples)
- custom_meta = {0: {"sample_score": 0.9, "nested": {"value": 1}}}
+ custom_meta = [{"sample_score": 0.9, "nested": {"value": 1}}]
batch.update_custom_meta(custom_meta)
# Get all custom_meta
@@ -997,7 +901,7 @@ def test_batch_meta_get_all_custom_meta_empty(self):
batch = BatchMeta(samples=samples)
result = batch.get_all_custom_meta()
- assert result == {}
+ assert result == [{}]
def test_batch_meta_custom_meta_with_nested_data(self):
"""Test custom_meta supports nested dictionary data."""
@@ -1013,7 +917,7 @@ def test_batch_meta_custom_meta_with_nested_data(self):
"model_info": {"name": "llama", "version": "7b", "config": {"hidden_size": 4096, "num_layers": 32}},
"tags": ["training", "inference"],
}
- batch.set_custom_meta(0, nested_meta)
+ batch.update_custom_meta([nested_meta])
result = batch.get_all_custom_meta()
assert result[0]["model_info"]["name"] == "llama"
@@ -1141,3 +1045,290 @@ def test_batch_meta_concat_validation_error(self):
with pytest.raises(ValueError) as exc_info:
BatchMeta.concat([batch1, batch2], validate=True)
assert "Field names do not match" in str(exc_info.value)
+
+
+class TestKVBatchMeta:
+ """KVBatchMeta Tests"""
+
+ def test_kv_batch_meta_basic_init(self):
+ """Example: Basic KVBatchMeta initialization."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{"sample_id": 0}, {"sample_id": 1}, {"sample_id": 2}],
+ partition_id="partition_0",
+ fields=["field1", "field2"],
+ )
+
+ assert kv_meta.size == 3
+ assert len(kv_meta) == 3
+ assert kv_meta.keys == ["key1", "key2", "key3"]
+ assert kv_meta.partition_id == "partition_0"
+ assert kv_meta.fields == ["field1", "field2"]
+
+ def test_kv_batch_meta_empty_init(self):
+ """Example: Empty KVBatchMeta initialization."""
+ kv_meta = KVBatchMeta()
+
+ assert kv_meta.size == 0
+ assert len(kv_meta) == 0
+ assert kv_meta.keys == []
+ assert kv_meta.tags == []
+ assert kv_meta.partition_id is None
+ assert kv_meta.fields is None
+
+ def test_kv_batch_meta_init_validation_keys_tags_mismatch(self):
+ """Example: Init validation catches keys and tags length mismatch."""
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta(
+ keys=["key1", "key2"],
+ tags=[{"sample_id": 0}], # Only one tag
+ )
+ assert "keys and tags must have same length" in str(exc_info.value)
+
+ def test_kv_batch_meta_init_validation_duplicate_keys(self):
+ """Example: Init validation catches duplicate keys."""
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta(
+ keys=["key1", "key1"],
+ tags=[{"sample_id": 0}, {"sample_id": 1}],
+ partition_id="partition_0",
+ )
+ assert "Got duplicated keys" in str(exc_info.value)
+
+ def test_kv_batch_meta_init_validation_duplicate_fields(self):
+ """Example: Init validation catches duplicate fields."""
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta(
+ keys=["key1"],
+ tags=[{"sample_id": 0}],
+ partition_id="partition_0",
+ fields=["field1", "field1"],
+ )
+ assert "Got duplicated fields" in str(exc_info.value)
+
+ def test_kv_batch_meta_select_keys(self):
+ """Example: Select specific keys from KVBatchMeta."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}],
+ partition_id="partition_0",
+ fields=["field1", "field2"],
+ extra_info={"test": "value"},
+ )
+
+ selected = kv_meta.select_keys(["key1", "key3"])
+
+ assert selected.keys == ["key1", "key3"]
+ assert selected.tags == [{"idx": 0}, {"idx": 2}]
+ assert selected.partition_id == "partition_0"
+ assert selected.fields == ["field1", "field2"]
+ assert selected.extra_info == {"test": "value"}
+
+ def test_kv_batch_meta_select_keys_validation_duplicate(self):
+ """Example: Select keys validation catches duplicate keys in input."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ kv_meta.select_keys(["key1", "key1"])
+ assert "Contain duplicate keys" in str(exc_info.value)
+
+ def test_kv_batch_meta_select_keys_validation_nonexistent(self):
+ """Example: Select keys validation catches non-existent keys."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(RuntimeError) as exc_info:
+ kv_meta.select_keys(["key1", "nonexistent"])
+ assert "not found in current batch" in str(exc_info.value)
+
+ def test_kv_batch_meta_reorder(self):
+ """Example: Reorder samples in KVBatchMeta."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}],
+ )
+
+ kv_meta.reorder([2, 0, 1])
+
+ assert kv_meta.keys == ["key3", "key1", "key2"]
+ assert kv_meta.tags == [{"idx": 2}, {"idx": 0}, {"idx": 1}]
+
+ def test_kv_batch_meta_reorder_validation_size_mismatch(self):
+ """Example: Reorder validation catches size mismatch."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ kv_meta.reorder([0, 1]) # Only 2 indexes for 3 samples
+ assert "does not match" in str(exc_info.value)
+
+ def test_kv_batch_meta_reorder_validation_duplicate_indexes(self):
+ """Example: Reorder validation catches duplicate indexes."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2", "key3"],
+ tags=[{}, {}, {}],
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ kv_meta.reorder([0, 0, 1]) # Duplicate index 0
+ assert "Contain duplicate indexes" in str(exc_info.value)
+
+ def test_kv_batch_meta_chunk(self):
+ """Example: Split KVBatchMeta into multiple chunks."""
+ kv_meta = KVBatchMeta(
+ keys=[f"key{i}" for i in range(10)],
+ tags=[{"idx": i} for i in range(10)],
+ partition_id="partition_0",
+ fields=["field1"],
+ extra_info={"test": "value"},
+ )
+
+ chunks = kv_meta.chunk(3)
+
+ assert len(chunks) == 3
+ assert len(chunks[0]) == 4 # First chunk gets extra element
+ assert len(chunks[1]) == 3
+ assert len(chunks[2]) == 3
+
+ # Verify partition_id and fields are preserved
+ assert chunks[0].partition_id == "partition_0"
+ assert chunks[0].fields == ["field1"]
+ assert chunks[0].extra_info == {"test": "value"}
+
+ # Verify keys and tags are correctly chunked
+ assert chunks[0].keys == ["key0", "key1", "key2", "key3"]
+ assert chunks[0].tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}]
+ assert chunks[1].keys == ["key4", "key5", "key6"]
+ assert chunks[1].tags == [{"idx": 4}, {"idx": 5}, {"idx": 6}]
+
+ def test_kv_batch_meta_chunk_with_more_chunks_than_samples(self):
+ """Example: Chunking when chunks > samples produces empty chunks."""
+ kv_meta = KVBatchMeta(
+ keys=["key1", "key2"],
+ tags=[{"idx": 0}, {"idx": 1}],
+ )
+
+ chunks = kv_meta.chunk(5)
+
+ assert len(chunks) == 5
+ assert len(chunks[0]) == 1
+ assert len(chunks[1]) == 1
+ assert len(chunks[2]) == 0
+ assert len(chunks[3]) == 0
+ assert len(chunks[4]) == 0
+
+ def test_kv_batch_meta_concat(self):
+ """Example: Concatenate multiple KVBatchMeta chunks."""
+ kv_meta1 = KVBatchMeta(
+ keys=["key0", "key1"],
+ tags=[{"idx": 0}, {"idx": 1}],
+ partition_id="partition_0",
+ fields=["field1"],
+ extra_info={"test": "value1"},
+ )
+
+ kv_meta2 = KVBatchMeta(
+ keys=["key2", "key3"],
+ tags=[{"idx": 2}, {"idx": 3}],
+ partition_id="partition_0",
+ fields=["field1"],
+ extra_info={"test": "value2"},
+ )
+
+ result = KVBatchMeta.concat([kv_meta1, kv_meta2])
+
+ assert result.size == 4
+ assert result.keys == ["key0", "key1", "key2", "key3"]
+ assert result.tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}]
+ assert result.partition_id == "partition_0"
+ assert result.fields == ["field1"]
+
+ def test_kv_batch_meta_concat_with_empty_chunks(self):
+ """Example: Concat handles empty KVBatchMeta chunks gracefully."""
+ kv_meta1 = KVBatchMeta()
+ kv_meta2 = KVBatchMeta(keys=["key0"], tags=[{"idx": 0}])
+ kv_meta3 = KVBatchMeta()
+
+ result = KVBatchMeta.concat([kv_meta1, kv_meta2, kv_meta3])
+
+ assert result.size == 1
+ assert result.keys == ["key0"]
+ assert result.tags == [{"idx": 0}]
+
+ def test_kv_batch_meta_concat_validation_field_mismatch(self):
+ """Example: Concat validation catches field name mismatches."""
+ kv_meta1 = KVBatchMeta(
+ keys=["key0"],
+ tags=[{}],
+ fields=["field1"],
+ )
+ kv_meta2 = KVBatchMeta(
+ keys=["key1"],
+ tags=[{}],
+ fields=["field2"], # Different field
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta.concat([kv_meta1, kv_meta2])
+ assert "Field names do not match" in str(exc_info.value)
+
+ def test_kv_batch_meta_concat_validation_partition_mismatch(self):
+ """Example: Concat validation catches partition_id mismatches."""
+ kv_meta1 = KVBatchMeta(
+ keys=["key0"],
+ tags=[{}],
+ partition_id="partition_0",
+ )
+ kv_meta2 = KVBatchMeta(
+ keys=["key1"],
+ tags=[{}],
+ partition_id="partition_1", # Different partition
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ KVBatchMeta.concat([kv_meta1, kv_meta2])
+ assert "Partition do not match" in str(exc_info.value)
+
+ def test_kv_batch_meta_concat_empty_list(self):
+ """Example: Concat with empty list returns empty KVBatchMeta."""
+ result = KVBatchMeta.concat([])
+
+ assert result.size == 0
+ assert result.keys == []
+ assert result.tags == []
+
+ def test_kv_batch_meta_deepcopy_tags(self):
+ """Example: Tags are deep copied to prevent mutation."""
+ original_tags = [{"data": [1, 2, 3]}]
+ kv_meta = KVBatchMeta(
+ keys=["key1"],
+ tags=original_tags,
+ )
+
+ # Modify the tag in the KVBatchMeta
+ kv_meta.tags[0]["data"].append(4)
+
+ # Original should not be modified
+ assert original_tags[0]["data"] == [1, 2, 3]
+
+ def test_kv_batch_meta_deepcopy_extra_info(self):
+ """Example: Extra info is deep copied to prevent mutation."""
+ original_extra = {"nested": {"value": 1}}
+ kv_meta = KVBatchMeta(
+ keys=["key1"],
+ tags=[{}],
+ extra_info=original_extra,
+ )
+
+ # Modify extra_info
+ kv_meta.extra_info["nested"]["value"] = 999
+
+ # Original should not be modified
+ assert original_extra["nested"]["value"] == 1
diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py
index 043b506..ed43e41 100644
--- a/tests/test_simple_storage_unit.py
+++ b/tests/test_simple_storage_unit.py
@@ -27,7 +27,7 @@
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
-from transfer_queue import SimpleStorageUnit # noqa: E402
+from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # noqa: E402
diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py
index 592ef0d..4732e84 100644
--- a/transfer_queue/__init__.py
+++ b/transfer_queue/__init__.py
@@ -16,63 +16,64 @@
import os
from .client import TransferQueueClient
-from .controller import TransferQueueController
from .dataloader import StreamingDataLoader, StreamingDataset
from .interface import (
- async_clear_partition,
- async_clear_samples,
- async_get_data,
- async_get_meta,
- async_put,
- async_set_custom_meta,
- clear_partition,
- clear_samples,
+ async_kv_batch_get,
+ async_kv_batch_put,
+ async_kv_clear,
+ async_kv_list,
+ async_kv_put,
close,
- get_data,
- get_meta,
+ get_client,
init,
- put,
- set_custom_meta,
+ kv_batch_get,
+ kv_batch_put,
+ kv_clear,
+ kv_list,
+ kv_put,
)
-from .metadata import BatchMeta
+from .metadata import BatchMeta, KVBatchMeta
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
from .sampler.rank_aware_sampler import RankAwareSampler
from .sampler.sequential_sampler import SequentialSampler
-from .storage import SimpleStorageUnit
-from .utils.common import get_placement_group
-from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info
-__all__ = [
- "init",
- "get_meta",
- "get_data",
- "put",
- "set_custom_meta",
- "clear_samples",
- "clear_partition",
- "async_get_meta",
- "async_get_data",
- "async_put",
- "async_set_custom_meta",
- "async_clear_samples",
- "async_clear_partition",
- "close",
-] + [
- "TransferQueueClient",
- "StreamingDataset",
- "StreamingDataLoader",
- "BatchMeta",
- "TransferQueueController",
- "SimpleStorageUnit",
- "ZMQServerInfo",
- "process_zmq_server_info",
- "get_placement_group",
- "BaseSampler",
- "GRPOGroupNSampler",
- "SequentialSampler",
- "RankAwareSampler",
-]
+__all__ = (
+ [
+ # High-Level KV Interface
+ "init",
+ "close",
+ "kv_put",
+ "kv_batch_put",
+ "kv_batch_get",
+ "kv_list",
+ "kv_clear",
+ "async_kv_put",
+ "async_kv_batch_put",
+ "async_kv_batch_get",
+ "async_kv_list",
+ "async_kv_clear",
+ "KVBatchMeta",
+ ]
+ + [
+ # High-Level StreamingDataLoader Interface
+ "StreamingDataset",
+ "StreamingDataLoader",
+ ]
+ + [
+ # Low-Level Native Interface
+ "get_client",
+ "BatchMeta",
+ "TransferQueueClient",
+ ]
+ + [
+ # Sampler
+ "BaseSampler",
+ "GRPOGroupNSampler",
+ "SequentialSampler",
+ "RankAwareSampler",
+ ]
+)
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index cd9c6e8..b06f90c 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -153,6 +153,7 @@ async def wrapper(self, *args, **kwargs):
return decorator
+ # ==================== Basic API ====================
@dynamic_socket(socket_name="request_handle_socket")
async def async_get_meta(
self,
@@ -215,7 +216,7 @@ async def async_get_meta(
"""
assert socket is not None
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_META,
+ request_type=ZMQRequestType.GET_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -259,8 +260,8 @@ async def async_set_custom_meta(
Args:
metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
socket: ZMQ async socket for message transmission (injected by decorator)
Raises:
@@ -269,7 +270,7 @@ async def async_set_custom_meta(
Example:
>>> # Create batch with custom metadata
>>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
>>> asyncio.run(client.async_set_custom_meta(batch_meta))
"""
assert socket is not None
@@ -291,10 +292,13 @@ async def async_set_custom_meta(
partition_custom_meta: dict[str, dict[int, dict]] = {pid: {} for pid in set(metadata.partition_ids)}
for meta in metadata_chunks:
- partition_custom_meta[meta.partition_ids[0]].update(meta.get_all_custom_meta())
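+            # get_all_custom_meta() now returns a per-sample list aligned with meta.global_indexes,
+            # so rebuild the {global_index: custom_meta} mapping before sending it to the controller.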
+ custom_meta = meta.get_all_custom_meta()
+ partition_custom_meta[meta.partition_ids[0]].update(
+ {meta.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta))}
+ )
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.SET_CUSTOM_META,
+ request_type=ZMQRequestType.SET_CUSTOM_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -532,7 +536,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
"""
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.CLEAR_META,
+ request_type=ZMQRequestType.CLEAR_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={"global_indexes": metadata.global_indexes, "partition_ids": metadata.partition_ids},
@@ -560,7 +564,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
RuntimeError: If controller returns error response
"""
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_PARTITION_META,
+ request_type=ZMQRequestType.GET_PARTITION_META, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={"partition_id": partition_id},
@@ -605,6 +609,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
if response_msg.request_type != ZMQRequestType.CLEAR_PARTITION_RESPONSE:
raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")
+ # ==================== Status Query API ====================
@dynamic_socket(socket_name="request_handle_socket")
async def async_get_consumption_status(
self,
@@ -638,7 +643,7 @@ async def async_get_consumption_status(
assert socket is not None
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_CONSUMPTION,
+ request_type=ZMQRequestType.GET_CONSUMPTION, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -675,7 +680,7 @@ async def async_get_production_status(
partition_id: str,
socket: Optional[zmq.asyncio.Socket] = None,
) -> tuple[Optional[Tensor], Optional[Tensor]]:
- """Get production status for current partition for specific fields.
+ """Get production status for specific data fields and partition.
Args:
data_fields: Data fields to check production status for
@@ -700,7 +705,7 @@ async def async_get_production_status(
"""
assert socket is not None
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_PRODUCTION,
+ request_type=ZMQRequestType.GET_PRODUCTION, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={
@@ -765,6 +770,41 @@ async def async_check_consumption_status(
return False
return torch.all(consumption_status == 1).item()
+ async def async_check_production_status(
+ self,
+ data_fields: list[str],
+ partition_id: str,
+ ) -> bool:
+        """Check whether all specified fields of the samples in the current partition are ready
+        (produced) for consumption.
+
+ Args:
+ data_fields: Data fields to check production status for
+ partition_id: Partition id to check production status for
+
+ Returns:
+ bool: True if all samples have been produced and ready, False otherwise
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> # Check if all samples are ready for consumption
+ >>> is_ready = asyncio.run(client.async_check_production_status(
+ ... data_fields=["input_ids", "attention_mask"],
+ ... partition_id="train_0"
+ ... ))
+ >>> print(f"All samples ready: {is_ready}")
+ """
+ _, production_status = await self.async_get_production_status(
+ data_fields=data_fields,
+ partition_id=partition_id,
+ )
+
+ if production_status is None:
+ return False
+ return torch.all(production_status == 1).item()
+
@dynamic_socket(socket_name="request_handle_socket")
async def async_reset_consumption(
self,
@@ -772,17 +812,22 @@ async def async_reset_consumption(
task_name: Optional[str] = None,
socket: Optional[zmq.asyncio.Socket] = None,
) -> bool:
- """Reset consumption status for a partition, allowing data to be re-consumed.
- This is useful for debugging scenarios where the same rollout data needs to be
- trained multiple times without regenerating the data.
+ """Asynchronously reset consumption status for a partition.
+
+ This allows the same data to be re-consumed, useful for debugging scenarios
+ where the same rollout data needs to be trained multiple times.
+
Args:
partition_id: Partition id to reset consumption status for
task_name: Name of the task to reset. If None, resets all tasks.
socket: ZMQ async socket for message transmission (injected by decorator)
+
Returns:
bool: True if reset was successful, False otherwise
+
Raises:
RuntimeError: If communication fails or controller returns error response
+
Example:
>>> # Reset consumption for train task to re-train on same data
>>> success = asyncio.run(client.async_reset_consumption(
@@ -796,7 +841,7 @@ async def async_reset_consumption(
if task_name is not None:
body["task_name"] = task_name
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.RESET_CONSUMPTION,
+ request_type=ZMQRequestType.RESET_CONSUMPTION, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body=body,
@@ -822,41 +867,6 @@ async def async_reset_consumption(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e
- async def async_check_production_status(
- self,
- data_fields: list[str],
- partition_id: str,
- ) -> bool:
- """Check if the all specific fields of samples for current partition are ready
- (produced) for consumption.
-
- Args:
- data_fields: Data fields to check production status for
- partition_id: Partition id to check production status for
-
- Returns:
- bool: True if all samples have been produced and ready, False otherwise
-
- Raises:
- RuntimeError: If communication fails or controller returns error response
-
- Example:
- >>> # Check if all samples are ready for consumption
- >>> is_ready = asyncio.run(client.async_check_production_status(
- ... data_fields=["input_ids", "attention_mask"],
- ... partition_id="train_0"
- ... ))
- >>> print(f"All samples ready: {is_ready}")
- """
- _, production_status = await self.async_get_production_status(
- data_fields=data_fields,
- partition_id=partition_id,
- )
-
- if production_status is None:
- return False
- return torch.all(production_status == 1).item()
-
@dynamic_socket(socket_name="request_handle_socket")
async def async_get_partition_list(
self,
@@ -869,9 +879,13 @@ async def async_get_partition_list(
Returns:
list[str]: List of partition ids managed by the controller
+
+ Example:
+            >>> partition_ids = asyncio.run(client.async_get_partition_list())
+ >>> print(f"Available partitions: {partition_ids}")
"""
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.GET_LIST_PARTITIONS,
+ request_type=ZMQRequestType.GET_LIST_PARTITIONS, # type: ignore[arg-type]
sender_id=self.client_id,
receiver_id=self._controller.id,
body={},
@@ -898,6 +912,132 @@ async def async_get_partition_list(
except Exception as e:
raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e
+ # ==================== KV Interface API ====================
+ @dynamic_socket(socket_name="request_handle_socket")
+ async def async_kv_retrieve_keys(
+ self,
+ keys: list[str] | str,
+ partition_id: str,
+ create: bool = False,
+ socket: Optional[zmq.asyncio.Socket] = None,
+ ) -> BatchMeta:
+ """Asynchronously retrieve BatchMeta from the controller using user-specified keys.
+
+ Args:
+ keys: List of keys to retrieve from the controller
+ partition_id: The ID of the logical partition to search for keys.
+ create: Whether to register new keys if not found.
+ socket: ZMQ socket (injected by decorator)
+
+ Returns:
+ metadata: BatchMeta of the corresponding keys
+
+ Raises:
+            TypeError: If `keys` is not a string or a list of strings
+            ValueError: If `keys` is an empty list
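+
+        Example (illustrative sketch; key and partition names are placeholders):
+            >>> metadata = asyncio.run(client.async_kv_retrieve_keys(
+            ...     keys=["sample_0", "sample_1"],
+            ...     partition_id="train_0",
+            ...     create=True,
+            ... ))
+            >>> print(metadata.global_indexes)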
+ """
+
+ if isinstance(keys, str):
+ keys = [keys]
+ elif isinstance(keys, list):
+ if len(keys) < 1:
+ raise ValueError("Received an empty list as keys.")
+ # validate all the elements are str
+ if not all(isinstance(k, str) for k in keys):
+ raise TypeError("Not all elements in `keys` are strings.")
+ else:
+ raise TypeError("Only string or list of strings are allowed as `keys`.")
+
+ request_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_RETRIEVE_KEYS, # type: ignore[arg-type]
+ sender_id=self.client_id,
+ receiver_id=self._controller.id,
+ body={
+ "keys": keys,
+ "partition_id": partition_id,
+ "create": create,
+ },
+ )
+
+ try:
+ assert socket is not None
+ await socket.send_multipart(request_msg.serialize())
+ response_serialized = await socket.recv_multipart()
+ response_msg = ZMQMessage.deserialize(response_serialized)
+ logger.debug(
+ f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} "
+ f"from controller {self._controller.id}"
+ )
+
+ if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE:
+ metadata = response_msg.body.get("metadata", BatchMeta.empty())
+ metadata = BatchMeta.from_dict(metadata) if isinstance(metadata, dict) else metadata
+ return metadata
+ else:
+ raise RuntimeError(
+ f"[{self.client_id}]: Failed to retrieve keys from controller {self._controller.id}: "
+ f"{response_msg.body.get('message', 'Unknown error')}"
+ )
+ except Exception as e:
+ raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
+
+ @dynamic_socket(socket_name="request_handle_socket")
+ async def async_kv_list(
+ self,
+ partition_id: Optional[str] = None,
+ socket: Optional[zmq.asyncio.Socket] = None,
+ ) -> dict[str, dict[str, Any]]:
+ """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions.
+
+ Args:
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
+ socket: ZMQ socket (injected by decorator)
+
+ Returns:
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+                        "tag1": <tag_value>,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
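+
+        Example (illustrative sketch; partition and key names are placeholders):
+            >>> partition_info = asyncio.run(client.async_kv_list(partition_id="train_0"))
+            >>> for key, tags in partition_info.get("train_0", {}).items():
+            ...     print(key, tags)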
+ """
+
+ request_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_LIST, # type: ignore[arg-type]
+ sender_id=self.client_id,
+ receiver_id=self._controller.id,
+ body={
+ "partition_id": partition_id,
+ },
+ )
+
+ try:
+ assert socket is not None
+ await socket.send_multipart(request_msg.serialize())
+ response_serialized = await socket.recv_multipart()
+ response_msg = ZMQMessage.deserialize(response_serialized)
+ logger.debug(
+ f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}"
+ )
+
+ if response_msg.request_type == ZMQRequestType.KV_LIST_RESPONSE:
+ partition_info = response_msg.body.get("partition_info", {})
+ return partition_info
+ else:
+ raise RuntimeError(
+ f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: "
+ f"{response_msg.body.get('message', 'Unknown error')}"
+ )
+ except Exception as e:
+ raise RuntimeError(f"[{self.client_id}]: Error in kv_list: {str(e)}") from e
+
def close(self) -> None:
"""Close the client and cleanup resources including storage manager."""
try:
@@ -972,67 +1112,10 @@ def wrapper(*args, **kwargs):
self._get_partition_list = _make_sync(self.async_get_partition_list)
self._set_custom_meta = _make_sync(self.async_set_custom_meta)
self._reset_consumption = _make_sync(self.async_reset_consumption)
+ self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys)
+ self._kv_list = _make_sync(self.async_kv_list)
- def put(
- self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
- ) -> BatchMeta:
- """Synchronously write data to storage units based on metadata.
-
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
-
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
-
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
-
- Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
-
- Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
-
- Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
-
- Example:
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = client.get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... )
- >>> batch = client.get_data(batch_meta)
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> client.put(data=output, metadata=batch_meta)
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
- """
- return self._put(data=data, metadata=metadata, partition_id=partition_id)
-
+ # ==================== Basic API ====================
def get_meta(
self,
data_fields: list[str],
@@ -1101,6 +1184,90 @@ def get_meta(
sampling_config=sampling_config,
)
+ def set_custom_meta(self, metadata: BatchMeta) -> None:
+ """Synchronously send custom metadata to the controller.
+
+ This method sends per-sample custom metadata (custom_meta) to the controller.
+ The custom_meta is stored in the controller and can be retrieved along with
+ the BatchMeta in subsequent get_meta calls.
+
+ Args:
+ metadata: BatchMeta containing the samples and their custom metadata to store.
+ The custom_meta should be set using BatchMeta.update_custom_meta()
+ before calling this method.
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> # Create batch with custom metadata
+ >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=2, ...)
+ >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}])
+ >>> client.set_custom_meta(batch_meta)
+ """
+
+ return self._set_custom_meta(metadata=metadata)
+
+ def put(
+ self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
+ ) -> BatchMeta:
+ """Synchronously write data to storage units based on metadata.
+
+ If metadata is not provided, it will be created automatically using insert mode
+ with the provided data fields and partition_id.
+
+ During put, the custom_meta in metadata will update the corresponding custom_meta in
+ TransferQueue Controller.
+
+ Note:
+ When using multiple workers for distributed execution, there may be data
+ ordering inconsistencies between workers during put operations.
+
+ Args:
+ data: Data to write as TensorDict
+ metadata: Records the metadata of a batch of data samples, containing index and
+ storage unit information. If None, metadata will be auto-generated.
+ partition_id: Target data partition id (required if metadata is not provided)
+
+ Returns:
+ BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
+ metadata; will be updated in a future version to reflect the post-put state)
+
+ Raises:
+ ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+ RuntimeError: If storage operation fails
+
+ Example:
+ >>> batch_size = 4
+ >>> seq_len = 16
+ >>> current_partition_id = "train_0"
+ >>> # Example 1: Normal usage with existing metadata
+ >>> batch_meta = client.get_meta(
+ ... data_fields=["prompts", "attention_mask"],
+ ... batch_size=batch_size,
+ ... partition_id=current_partition_id,
+ ... mode="fetch",
+ ... task_name="generate_sequences",
+ ... )
+ >>> batch = client.get_data(batch_meta)
+ >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+ >>> client.put(data=output, metadata=batch_meta)
+ >>>
+ >>> # Example 2: Initial data insertion without pre-existing metadata
+ >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+            >>> # Please make sure the corresponding partition_id is empty before calling put()
+            >>> # without metadata.
+            >>> # Currently, all data for a given partition_id must be put in a single call. If n_samples > 1,
+            >>> # repeat-interleave the initial data before calling put().
+ >>> original_prompts = torch.randn(batch_size, seq_len)
+ >>> n_samples = 4
+ >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+ >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+ >>> # This will create metadata in "insert" mode internally.
+ >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
+ """
+ return self._put(data=data, metadata=metadata, partition_id=partition_id)
+
def get_data(self, metadata: BatchMeta) -> TensorDict:
"""Synchronously fetch data from storage units and organize into TensorDict.
@@ -1112,7 +1279,7 @@ def get_data(self, metadata: BatchMeta) -> TensorDict:
- Requested data fields (e.g., "prompts", "attention_mask")
Example:
- >>> batch_meta = client.get_data(
+ >>> batch_meta = client.get_meta(
... data_fields=["prompts", "attention_mask"],
... batch_size=4,
... partition_id="train_0",
@@ -1147,65 +1314,85 @@ def clear_samples(self, metadata: BatchMeta):
"""
return self._clear_samples(metadata=metadata)
- def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
- """Synchronously check if all samples for a partition have been consumed by a specific task.
+ # ==================== Status Query API ====================
+ def get_consumption_status(
+ self,
+ task_name: str,
+ partition_id: str,
+ ) -> tuple[Optional[Tensor], Optional[Tensor]]:
+ """Synchronously get consumption status for a specific task and partition.
Args:
task_name: Name of the task to check consumption for
partition_id: Partition id to check consumption status for
Returns:
- bool: True if all samples have been consumed by the task, False otherwise
+ Tuple of:
+ - Partition global index tensor
+ - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
Raises:
RuntimeError: If communication fails or controller returns error response
Example:
- >>> # Check if all samples have been consumed
- >>> is_consumed = client.check_consumption_status(
+ >>> global_index, consumption_status = client.get_consumption_status(
... task_name="generate_sequences",
... partition_id="train_0"
... )
- >>> print(f"All samples consumed: {is_consumed}")
+ >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
"""
- return self._check_consumption_status(task_name=task_name, partition_id=partition_id)
+ return self._get_consumption_status(task_name, partition_id)
- def get_consumption_status(
+ def get_production_status(
self,
- task_name: str,
+ data_fields: list[str],
partition_id: str,
) -> tuple[Optional[Tensor], Optional[Tensor]]:
- """Synchronously get consumption status for a specific task and partition.
+ """Synchronously get production status for specific data fields and partition.
Args:
- task_name: Name of the task to check consumption for
- partition_id: Partition id to check consumption status for
+ data_fields: Data fields to check production status for
+ partition_id: Partition id to check production status for
Returns:
Tuple of:
- Partition global index tensor
- - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
+ - Production status tensor for the specified fields. 1 for ready, 0 for not ready.
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
Example:
- >>> global_index, consumption_status = client.get_consumption_status(
- ... task_name="generate_sequences",
+ >>> global_index, production_status = client.get_production_status(
+ ... data_fields=["input_ids", "attention_mask"],
... partition_id="train_0"
... )
- >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
+ >>> print(f"Global index: {global_index}, Production status: {production_status}")
"""
- return self._get_consumption_status(task_name, partition_id)
+ return self._get_production_status(data_fields=data_fields, partition_id=partition_id)
+
+ def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
+ """Synchronously check if all samples for a partition have been consumed by a specific task.
- def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
- """Synchronously reset consumption status for a partition.
- This allows the same data to be re-consumed, useful for debugging scenarios
- where the same rollout data needs to be trained multiple times.
Args:
- partition_id: Partition id to reset consumption status for
- task_name: Name of the task to reset. If None, resets all tasks.
+ task_name: Name of the task to check consumption for
+ partition_id: Partition id to check consumption status for
+
Returns:
- bool: True if reset was successful, False otherwise
+ bool: True if all samples have been consumed by the task, False otherwise
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
+
+ Example:
+ >>> # Check if all samples have been consumed
+ >>> is_consumed = client.check_consumption_status(
+ ... task_name="generate_sequences",
+ ... partition_id="train_0"
+ ... )
+ >>> print(f"All samples consumed: {is_consumed}")
"""
- return self._reset_consumption(partition_id, task_name)
+ return self._check_consumption_status(task_name=task_name, partition_id=partition_id)
def check_production_status(self, data_fields: list[str], partition_id: str) -> bool:
"""Synchronously check if all samples for a partition are ready (produced) for consumption.
@@ -1214,6 +1401,9 @@ def check_production_status(self, data_fields: list[str], partition_id: str) ->
data_fields: Data fields to check production status for
partition_id: Partition id to check production status for
+ Returns:
+ bool: True if all samples have been produced and ready, False otherwise
+
Raises:
RuntimeError: If communication fails or controller returns error response
@@ -1227,30 +1417,31 @@ def check_production_status(self, data_fields: list[str], partition_id: str) ->
"""
return self._check_production_status(data_fields=data_fields, partition_id=partition_id)
- def get_production_status(
- self,
- data_fields: list[str],
- partition_id: str,
- ) -> tuple[Optional[Tensor], Optional[Tensor]]:
- """Synchronously get production status for a specific data fields and partition.
+ def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
+ """Synchronously reset consumption status for a partition.
+
+ This allows the same data to be re-consumed, useful for debugging scenarios
+ where the same rollout data needs to be trained multiple times.
Args:
- data_fields: Data fields to check production status for
- partition_id: Partition id to check production status for
+ partition_id: Partition id to reset consumption status for
+ task_name: Name of the task to reset. If None, resets all tasks.
Returns:
- Tuple of:
- - Partition global index tensor
- - Production status tensor for the specified fields. 1 for ready, 0 for not ready.
+ bool: True if reset was successful, False otherwise
+
+ Raises:
+ RuntimeError: If communication fails or controller returns error response
Example:
- >>> global_index, production_status = client.get_production_status(
- ... data_fields=["input_ids", "attention_mask"],
- ... partition_id="train_0"
+ >>> # Reset consumption for train task to re-train on same data
+ >>> success = client.reset_consumption(
+ ... partition_id="train_0",
+ ... task_name="train"
... )
- >>> print(f"Global index: {global_index}, Production status: {production_status}")
+ >>> print(f"Reset successful: {success}")
"""
- return self._get_production_status(data_fields=data_fields, partition_id=partition_id)
+ return self._reset_consumption(partition_id, task_name)
def get_partition_list(
self,
@@ -1259,32 +1450,64 @@ def get_partition_list(
Returns:
list[str]: List of partition ids managed by the controller
+
+ Example:
+ >>> partition_ids = client.get_partition_list()
+ >>> print(f"Available partitions: {partition_ids}")
"""
return self._get_partition_list()
- def set_custom_meta(self, metadata: BatchMeta) -> None:
- """Synchronously send custom metadata to the controller.
-
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
+ # ==================== KV Interface API ====================
+ def kv_retrieve_keys(
+ self,
+ keys: list[str] | str,
+ partition_id: str,
+ create: bool = False,
+ ) -> BatchMeta:
+ """Synchronously retrieve BatchMeta from the controller using user-specified keys.
Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ keys: List of keys to retrieve from the controller
+ partition_id: The ID of the logical partition to search for keys.
+ create: Whether to register new keys if not found.
+
+ Returns:
+ metadata: BatchMeta of the corresponding keys
Raises:
- RuntimeError: If communication fails or controller returns error response
+            TypeError: If `keys` is not a string or a list of strings
+            ValueError: If `keys` is an empty list
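+
+        Example (illustrative sketch; key and partition names are placeholders):
+            >>> metadata = client.kv_retrieve_keys(
+            ...     keys=["sample_0", "sample_1"],
+            ...     partition_id="train_0",
+            ...     create=True,
+            ... )
+            >>> print(metadata.global_indexes)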
+ """
- Example:
- >>> # Create batch with custom metadata
- >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
- >>> client.set_custom_meta(batch_meta)
+ return self._kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create)
+
+ def kv_list(
+ self,
+ partition_id: Optional[str] = None,
+ ) -> dict[str, dict[str, Any]]:
+ """Synchronously retrieve keys and custom_meta from the controller for one or all partitions.
+
+ Args:
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
+
+ Returns:
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+                        "tag1": <tag_value>,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
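+
+        Example (illustrative sketch; output depends on the partitions currently tracked):
+            >>> partition_info = client.kv_list()
+            >>> print(list(partition_info.keys()))  # all partition ids known to the controller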
"""
- return self._set_custom_meta(metadata=metadata)
+ return self._kv_list(partition_id=partition_id)
def close(self) -> None:
"""Close the client and cleanup resources including event loop and thread."""
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 763f3f4..3c7c3f9 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -214,7 +214,7 @@ class DataPartitionStatus:
# Each tensor tracks which samples have been consumed by that task
consumption_status: dict[str, Tensor] = field(default_factory=dict)
- # Sample metadata
+ # Global indexes
global_indexes: set[int] = field(
default_factory=set
) # set of global indexes that have been added to this partition
@@ -233,6 +233,10 @@ class DataPartitionStatus:
# User-defined metadata that may not apply to field level
custom_meta: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {}
+ # User-defined Keys
+ keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx
+ revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key
+
# Threading lock for concurrency control; only for preventing mask operation error when expanding production_status.
# No need to strictly lock for every read/write operation since freshness is not critical.
data_status_lock: Lock = field(default_factory=Lock)
@@ -536,15 +540,19 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8)
# Get consumption status for requested task
- consumption_status = self.consumption_status[task_name]
-
partition_global_index = torch.tensor(
sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long
)
if mask:
- consumption_status = consumption_status[partition_global_index]
-
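+ # Restrict the consumption status to the global indexes owned by this partition;
+ # empty partitions are handled explicitly before indexing.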
+ if partition_global_index.numel() == 0:
+ empty_status = self.consumption_status[task_name].new_zeros(0)
+ return partition_global_index, empty_status
+ with self.data_status_lock:
+ self.ensure_samples_capacity(int(partition_global_index.max()) + 1)
+ consumption_status = self.consumption_status[task_name][partition_global_index]
+ else:
+ consumption_status = self.consumption_status[task_name]
return partition_global_index, consumption_status
def reset_consumption(self, task_name: Optional[str] = None):
@@ -826,12 +834,21 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr
self.field_custom_backend_meta.pop(idx, None)
self.custom_meta.pop(idx, None)
+ if idx in self.revert_keys_mapping:
+ self.keys_mapping.pop(self.revert_keys_mapping[idx], None)
+ self.revert_keys_mapping.pop(idx, None)
+
except Exception as e:
logger.error(
f"Error clearing data for partition {self.partition_id}: {e}. "
f"Attempted to clear global_indexes: {indexes_to_release}"
)
+ def kv_retrieve_keys(self, keys: list[str]) -> list[int | None]:
+ """Translate the user-specified keys to global_indexes"""
+ global_indexes = [self.keys_mapping.get(k, None) for k in keys]
+ return global_indexes
+
@ray.remote(num_cpus=1)
class TransferQueueController:
@@ -971,17 +988,19 @@ def list_partitions(self) -> list[str]:
# ==================== Partition Index Management API ====================
- def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]:
+ def get_partition_index_range(self, partition_id: str) -> list[int]:
"""
Get all indexes for a specific partition.
Args:
- partition: Partition identifier
+ partition_id: Partition identifier
Returns:
List of indexes allocated to the partition
"""
- return self.index_manager.get_indexes_for_partition(partition)
+ # Note: This includes the pre-allocated global_indexes for the partition.
+ # i.e., partition.global_indexes + partition.pre_allocated_global_indexes
+ return self.index_manager.get_indexes_for_partition(partition_id)
# ==================== Data Production API ====================
@@ -1160,9 +1179,10 @@ def get_metadata(
)
batch_global_indexes.extend(new_global_indexes)
+ # register global_indexes in partition
+ partition.global_indexes.update(batch_global_indexes)
+
else:
- # TODO: separate this "clear" related logic into a separated mode
- # clear metadata call passes empty data_fields
batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)
return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)
@@ -1314,7 +1334,7 @@ def generate_batch_meta(
and partition.production_status is not None
and partition.production_status[global_index, field_index] == 1
):
- production_status = ProductionStatus.NOT_PRODUCED
+ production_status = ProductionStatus.READY_FOR_CONSUME
dtype = partition.get_field_dtype(global_index, field_name)
shape = partition.get_field_shape(global_index, field_name)
else:
@@ -1340,7 +1360,7 @@ def generate_batch_meta(
custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields)
batch_meta = BatchMeta(samples=samples)
- batch_meta.update_custom_meta(custom_meta)
+ batch_meta.update_custom_meta([custom_meta.get(idx, {}) for idx in batch_meta.global_indexes])
batch_meta._custom_backend_meta.update(custom_backend_meta)
return batch_meta
@@ -1357,7 +1377,8 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True):
partition = self._get_partition(partition_id)
if not partition:
- raise ValueError(f"Partition {partition_id} not found")
+ logger.warning(f"Try to clear an non-existent partition {partition_id}!")
+ return
global_indexes_range = list(self.index_manager.get_indexes_for_partition(partition_id))
partition.clear_data(global_indexes_range, clear_consumption)
@@ -1374,13 +1395,13 @@ def reset_consumption(self, partition_id: str, task_name: Optional[str] = None):
Args:
partition_id: ID of the partition to reset consumption for
task_name: Name of the task to reset. If None, resets all tasks.
- Raises:
- ValueError: If partition not found
+
"""
logger.debug(f"[{self.controller_id}]: Resetting consumption for partition {partition_id}, task={task_name}")
partition = self._get_partition(partition_id)
if not partition:
- raise ValueError(f"Partition {partition_id} not found")
+ logger.warning(f"Try to reset consumption of an non-existent partition {partition_id}!")
+ return
partition.reset_consumption(task_name)
def clear_meta(
@@ -1432,6 +1453,82 @@ def clear_meta(
# Release the specific indexes from index manager
self.index_manager.release_indexes(partition_id, global_indexes_to_clear)
+ def kv_retrieve_keys(
+ self,
+ keys: list[str],
+ partition_id: str,
+ create: bool = False,
+ ) -> BatchMeta:
+ """
+ Retrieve BatchMeta from the controller using a list of keys.
+
+ Args:
+ keys: List of keys to retrieve from the controller
+ partition_id: Partition id to retrieve from the controller
+ create: Whether to register new keys if not found.
+
+ Returns:
+ metadata: BatchMeta of the requested keys
+ """
+
+ logger.debug(f"[{self.controller_id}]: Retrieve keys {keys} in partition {partition_id}")
+
+ partition = self._get_partition(partition_id)
+
+ if partition is None:
+ if not create:
+ logger.warning(f"Partition {partition_id} were not found in controller!")
+ return BatchMeta.empty()
+ else:
+ self.create_partition(partition_id)
+ partition = self._get_partition(partition_id)
+
+ assert partition is not None
+ global_indexes = partition.kv_retrieve_keys(keys)
+
+ none_indexes = [idx for idx, value in enumerate(global_indexes) if value is None]
+ if len(none_indexes) > 0:
+ if not create:
+ logger.warning(f"Keys {[keys[i] for i in none_indexes]} were not found in partition {partition_id}!")
+ return BatchMeta.empty()
+ else:
+ # create non-existent keys
+ batch_global_indexes = partition.activate_pre_allocated_indexes(len(none_indexes))
+
+ if len(batch_global_indexes) < len(none_indexes):
+ new_global_indexes = self.index_manager.allocate_indexes(
+ partition_id, count=(len(none_indexes) - len(batch_global_indexes))
+ )
+ batch_global_indexes.extend(new_global_indexes)
+
+ # register global_indexes in partition
+ partition.global_indexes.update(batch_global_indexes)
+
+ # register key-global_indexes mapping in partition
+ for i in range(len(none_indexes)):
+ global_indexes[none_indexes[i]] = batch_global_indexes[i]
+ partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i]
+ partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]]
+
+ with partition.data_status_lock:
+ partition.ensure_samples_capacity(max(batch_global_indexes) + 1)
+
+ verified_global_indexes = [idx for idx in global_indexes if idx is not None]
+ assert len(verified_global_indexes) == len(keys)
+
+ # only fetch fields that every requested sample has already produced
+ col_mask = partition.production_status[verified_global_indexes, :].sum(dim=0).reshape(-1) == len(
+ verified_global_indexes
+ )
+ data_fields = []
+ for fname, col_idx in partition.field_name_mapping.items():
+ if col_mask[col_idx]:
+ data_fields.append(fname)
+
+ metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")
+
+ return metadata
+
def _init_zmq_socket(self):
"""Initialize ZMQ sockets for communication."""
self.zmq_context = zmq.Context()
@@ -1727,6 +1824,51 @@ def _process_request(self):
body={"partition_ids": partition_ids},
)
+ elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
+ with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"):
+ params = request_msg.body
+ keys = params["keys"]
+ partition_id = params["partition_id"]
+ create = params["create"]
+
+ metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create)
+ response_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE,
+ sender_id=self.controller_id,
+ receiver_id=request_msg.sender_id,
+ body={"metadata": metadata},
+ )
+
+ elif request_msg.request_type == ZMQRequestType.KV_LIST:
+ with perf_monitor.measure(op_type="KV_LIST"):
+ params = request_msg.body
+ partition_id = params["partition_id"]
+ if partition_id is None:
+ partition_ids = list(self.partitions.keys())
+ else:
+ partition_ids = [partition_id]
+
+ message = "success"
+ partition_info = {}
+ for pid in partition_ids:
+ partition = self._get_partition(pid)
+ if partition:
+ keys = list(partition.keys_mapping.keys())
+ single_partition_info = {
+ k: partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys
+ }
+ partition_info[pid] = single_partition_info
+ else:
+ # this only happens when params["partition_id"] is not None
+ message = f"partition {pid} does not exist"
+
+ response_msg = ZMQMessage.create(
+ request_type=ZMQRequestType.KV_LIST_RESPONSE,
+ sender_id=self.controller_id,
+ receiver_id=request_msg.sender_id,
+ body={"partition_info": partition_info, "message": message},
+ )
+
self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
def _update_data_status(self):
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 37ebdb2..8e4f8c3 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -21,12 +21,13 @@
from typing import Any, Optional
import ray
+import torch
from omegaconf import DictConfig, OmegaConf
from tensordict import TensorDict
+from tensordict.tensorclass import NonTensorStack
from transfer_queue.client import TransferQueueClient
from transfer_queue.controller import TransferQueueController
-from transfer_queue.metadata import BatchMeta
from transfer_queue.sampler import * # noqa: F401
from transfer_queue.sampler import BaseSampler
from transfer_queue.storage.simple_backend import SimpleStorageUnit
@@ -105,6 +106,7 @@ def _init_from_existing() -> None:
time.sleep(1)
+# ==================== Initialization API ====================
def init(conf: Optional[DictConfig] = None) -> None:
"""Initialize the TransferQueue system.
@@ -194,456 +196,564 @@ def init(conf: Optional[DictConfig] = None) -> None:
_maybe_create_transferqueue_client(final_conf)
-def get_meta(
- data_fields: list[str],
- batch_size: int,
- partition_id: str,
- mode: str = "fetch",
- task_name: Optional[str] = None,
- sampling_config: Optional[dict[str, Any]] = None,
-) -> BatchMeta:
- """Synchronously fetch data metadata from the controller via ZMQ.
+def close():
+ """Close the TransferQueue system.
- Args:
- data_fields: List of data field names to retrieve metadata for
- batch_size: Number of samples to request in the batch
- partition_id: Current data partition id
- mode: Data fetch mode. Options:
- - 'fetch': Get ready data only
- - 'force_fetch': Get data regardless of readiness (may return unready samples)
- - 'insert': Internal usage - should not be used by users
- task_name: Optional task name associated with the request
- sampling_config: Optional sampling configuration for custom samplers.
+ This function cleans up the TransferQueue system, including:
+ - Closing the client and its associated resources
+ - Cleaning up distributed storage (only for the process that initialized it)
+ - Killing the controller actor
+ Note:
+ This function should be called when the TransferQueue system is no longer needed.
+ """
+ global _TRANSFER_QUEUE_CLIENT
+ global _TRANSFER_QUEUE_STORAGE
+ if _TRANSFER_QUEUE_CLIENT:
+ _TRANSFER_QUEUE_CLIENT.close()
+ _TRANSFER_QUEUE_CLIENT = None
- Returns:
- BatchMeta: Metadata object containing data structure, sample information, and readiness status
+ try:
+ if _TRANSFER_QUEUE_STORAGE:
+ # only the process that performed the first-time init can clean up the distributed storage
+ for storage in _TRANSFER_QUEUE_STORAGE.values():
+ ray.kill(storage)
+ _TRANSFER_QUEUE_STORAGE = None
+ except Exception:
+ pass
+
+ try:
+ controller = ray.get_actor("TransferQueueController")
+ ray.kill(controller)
+ except Exception:
+ pass
+
+
+# ==================== High-Level KV Interface API ====================
+def kv_put(
+ key: str,
+ partition_id: str,
+ fields: Optional[TensorDict | dict[str, Any]] = None,
+ tag: Optional[dict[str, Any]] = None,
+) -> None:
+ """Put a single key-value pair to TransferQueue.
+
+ This is a convenience method for putting data using a user-specified key
+ instead of BatchMeta. Internally, the key is translated to a BatchMeta
+ and the data is stored using the regular put mechanism.
+
+ Args:
+ key: User-specified key for the data sample (one key identifies one row/sample)
+ partition_id: Logical partition to store the data in
+ fields: Data fields to store. Can be a TensorDict or a dict of tensors.
+ Each key in `fields` will be treated as a column for the data sample.
+ If a dict is provided, tensors will be unsqueezed to add a batch dimension.
+ tag: Optional metadata tag to associate with the key
Raises:
- RuntimeError: If communication fails or controller returns error response
+ ValueError: If neither fields nor tag is provided
+ ValueError: If nested tensors are provided (use kv_batch_put instead)
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
Example:
>>> import transfer_queue as tq
+ >>> import torch
>>> tq.init()
- >>>
- >>> # Example 1: Basic fetch metadata
- >>> batch_meta = tq.get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences"
- ... )
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
- >>> batch_meta = tq.get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=8,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... sampling_config={"n_samples_per_prompt": 4}
+ >>> # Put with both fields and tag
+ >>> tq.kv_put(
+ ... key="sample_1",
+ ... partition_id="train",
+ ... fields={"input_ids": torch.tensor([1, 2, 3])},
+ ... tag={"score": 0.95}
... )
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
- >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
- >>> batch_meta = tq.get_meta(
- ... partition_id="train_0", # optional
- ... mode="force_fetch",
- ... )
- >>> print(batch_meta.is_ready) # May be False if some samples not ready
"""
+ if fields is None and tag is None:
+ raise ValueError("Please provide at least one parameter of `fields` or `tag`.")
tq_client = _maybe_create_transferqueue_client()
- return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
+
+ if batch_meta.size != 1:
+ raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
+
+ # 2. register the user-specified tag to BatchMeta
+ if tag is not None:
+ batch_meta.update_custom_meta([tag])
+
+ # 3. put data
+ if fields is not None:
+ if isinstance(fields, dict):
+ # TODO: consider whether to support this...
+ batch = {}
+ for field_name, value in fields.items():
+ if isinstance(value, torch.Tensor):
+ if value.is_nested:
+ raise ValueError("Please use (async)kv_batch_put for batch operation")
+ batch[field_name] = value.unsqueeze(0)
+ else:
+ batch[field_name] = NonTensorStack(value)
+ fields = TensorDict(batch, batch_size=[1])
+ elif not isinstance(fields, TensorDict):
+ raise ValueError("field can only be dict or TensorDict")
+
+ # custom_meta (tag) will be put to controller through the internal put process
+ tq_client.put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tag) to controller
+ tq_client.set_custom_meta(batch_meta)
-async def async_get_meta(
- data_fields: list[str],
- batch_size: int,
- partition_id: str,
- mode: str = "fetch",
- task_name: Optional[str] = None,
- sampling_config: Optional[dict[str, Any]] = None,
-) -> BatchMeta:
- """Asynchronously fetch data metadata from the controller via ZMQ.
- Args:
- data_fields: List of data field names to retrieve metadata for
- batch_size: Number of samples to request in the batch
- partition_id: Current data partition id
- mode: Data fetch mode. Options:
- - 'fetch': Get ready data only
- - 'force_fetch': Get data regardless of readiness (may return unready samples)
- - 'insert': Internal usage - should not be used by users
- task_name: Optional task name associated with the request
- sampling_config: Optional sampling configuration for custom samplers.
- socket: ZMQ async socket for message transmission (injected by decorator)
+def kv_batch_put(
+ keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None
+) -> None:
+ """Put multiple key-value pairs to TransferQueue in batch.
- Returns:
- BatchMeta: Metadata object containing data structure, sample information, and readiness status
+ This method stores multiple key-value pairs in a single operation, which is more
+ efficient than calling kv_put multiple times.
+
+ Args:
+ keys: List of user-specified keys for the data
+ partition_id: Logical partition to store the data in
+ fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
+ tags: List of metadata tags, one for each key
Raises:
- RuntimeError: If communication fails or controller returns error response
+ ValueError: If neither `fields` nor `tags` is provided
+ ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
Example:
>>> import transfer_queue as tq
+ >>> import torch
+ >>> from tensordict import TensorDict
>>> tq.init()
- >>>
- >>> # Example 1: Basic fetch metadata
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences"
- ... ))
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["input_ids", "attention_mask"],
- ... batch_size=8,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> print(batch_meta.is_ready) # True if all samples ready
- >>>
- >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
- >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.)
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... partition_id="train_0", # optional
- ... mode="force_fetch",
- ... ))
- >>> print(batch_meta.is_ready) # May be False if some samples not ready
+ >>> keys = ["sample_1", "sample_2", "sample_3"]
+ >>> fields = TensorDict({
+ ... "input_ids": torch.randn(3, 10),
+ ... "attention_mask": torch.ones(3, 10),
+ ... }, batch_size=3)
+ >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
+ >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
"""
+ if fields is None and tags is None:
+ raise ValueError("Please provide at least one parameter of fields or tag.")
+
+ if fields is not None and fields.batch_size[0] != len(keys):
+ raise ValueError(
+ f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
+ f"batch_size {fields.batch_size[0]}"
+ )
+
tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config)
+
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
+
+ if batch_meta.size != len(keys):
+ raise RuntimeError(
+ f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
+ )
+
+ # 2. register the user-specified tags to BatchMeta
+ if tags:
+ if len(tags) != len(keys):
+ raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
+ batch_meta.update_custom_meta(tags)
+
+ # 3. put data
+ if fields is not None:
+ tq_client.put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tags) to controller
+ tq_client.set_custom_meta(batch_meta)
-def get_data(metadata: BatchMeta) -> TensorDict:
- """Synchronously fetch data from storage units and organize into TensorDict.
+def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict:
+ """Get data from TransferQueue using user-specified keys.
+
+ This is a convenience method for retrieving data using keys instead of indexes.
Args:
- metadata: Batch metadata containing data location information and global indexes
+ keys: Single key or list of keys to retrieve
+ partition_id: Partition containing the keys
+ fields: Optional field(s) to retrieve. If None, retrieves all fields
Returns:
- TensorDict containing:
- - Requested data fields (e.g., "prompts", "attention_mask")
+ TensorDict with the requested data
+
+ Raises:
+ RuntimeError: If keys or partition are not found
+ RuntimeError: If some requested fields are not ready for all of the requested keys
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>>
- >>> batch_meta = tq.get_data(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
+ >>> # Get single key with all fields
+ >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train")
+ >>> # Get multiple keys with specific fields
+ >>> data = tq.kv_batch_get(
+ ... keys=["sample_1", "sample_2"],
+ ... partition_id="train",
+ ... fields="input_ids"
... )
- >>> batch = tq.get_data(batch_meta)
- >>> print(batch)
- >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
"""
tq_client = _maybe_create_transferqueue_client()
- return tq_client.get_data(metadata)
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
+
+ if batch_meta.size == 0:
+ raise RuntimeError("keys or partition were not found!")
-async def async_get_data(metadata: BatchMeta) -> TensorDict:
- """Asynchronously fetch data from storage units and organize into TensorDict.
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ batch_meta = batch_meta.select_fields(fields)
+
+ if not batch_meta.is_ready:
+ raise RuntimeError("Some fields are not ready in all the requested keys!")
+
+ data = tq_client.get_data(batch_meta)
+ return data
+
+
+def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
+ """List all keys and their metadata in one or all partitions.
Args:
- metadata: Batch metadata containing data location information and global indexes
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
Returns:
- TensorDict containing:
- - Requested data fields (e.g., "prompts", "attention_mask")
+ A nested dictionary mapping partition IDs to their keys and metadata.
+
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+ "tag1": ,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>>
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=4,
- ... partition_id="train_0",
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... ))
- >>> batch = asyncio.run(tq.async_get_data(batch_meta))
- >>> print(batch)
- >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
+ >>> # Case 1: Retrieve a specific partition
+ >>> partitions = tq.kv_list(partition_id="train")
+ >>> print(f"Keys: {list(partitions['train'].keys())}")
+ >>> print(f"Tags: {list(partitions['train'].values())}")
+ >>> # Case 2: Retrieve all partitions
+ >>> all_partitions = tq.kv_list()
+ >>> for pid, keys in all_partitions.items():
+ >>> print(f"Partition: {pid}, Key count: {len(keys)}")
"""
tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_get_data(metadata)
+ partition_info = tq_client.kv_list(partition_id)
-def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta:
- """Synchronously write data to storage units based on metadata.
+ return partition_info
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
+def kv_clear(keys: list[str] | str, partition_id: str) -> None:
+ """Clear key-value pairs from TransferQueue.
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
+ This removes the specified keys and their associated data from both
+ the controller and storage units.
Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
-
- Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
-
- Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
+ keys: Single key or list of keys to clear
+ partition_id: Partition containing the keys
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>>
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = tq.get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
- ... )
- >>> batch = tq.get_data(batch_meta)
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> tq.put(data=output, metadata=batch_meta)
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id)
+ >>> # Clear single key
+ >>> tq.kv_clear(keys="sample_1", partition_id="train")
+ >>> # Clear multiple keys
+ >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
"""
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.put(data, metadata, partition_id)
+ if isinstance(keys, str):
+ keys = [keys]
-async def async_put(
- data: TensorDict,
- metadata: Optional[BatchMeta] = None,
- partition_id: Optional[str] = None,
-) -> BatchMeta:
- """Asynchronously write data to storage units based on metadata.
+ tq_client = _maybe_create_transferqueue_client()
+ batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
- If metadata is not provided, it will be created automatically using insert mode
- with the provided data fields and partition_id.
+ if batch_meta.size > 0:
+ tq_client.clear_samples(batch_meta)
- During put, the custom_meta in metadata will update the corresponding custom_meta in
- TransferQueue Controller.
- Note:
- When using multiple workers for distributed execution, there may be data
- ordering inconsistencies between workers during put operations.
+# ==================== High-Level KV Interface API (Async) ====================
+async def async_kv_put(
+ key: str,
+ partition_id: str,
+ fields: Optional[TensorDict | dict[str, Any]] = None,
+ tag: Optional[dict[str, Any]] = None,
+) -> None:
+ """Asynchronously put a single key-value pair to TransferQueue.
- Args:
- data: Data to write as TensorDict
- metadata: Records the metadata of a batch of data samples, containing index and
- storage unit information. If None, metadata will be auto-generated.
- partition_id: Target data partition id (required if metadata is not provided)
+ This is a convenience method for putting data using a user-specified key
+ instead of BatchMeta. Internally, the key is translated to a BatchMeta
+ and the data is stored using the regular put mechanism.
- Returns:
- BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
- metadata; will be updated in a future version to reflect the post-put state)
+ Args:
+ key: User-specified key for the data sample (one key identifies one row/sample)
+ partition_id: Logical partition to store the data in
+ fields: Data fields to store. Can be a TensorDict or a dict of tensors.
+ Each key in `fields` will be treated as a column for the data sample.
+ If a dict is provided, tensors will be unsqueezed to add a batch dimension.
+ tag: Optional metadata tag to associate with the key
Raises:
- ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
- RuntimeError: If storage operation fails
+ ValueError: If neither fields nor tag is provided
+ ValueError: If nested tensors are provided (use kv_batch_put instead)
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
Example:
>>> import transfer_queue as tq
+ >>> import torch
>>> tq.init()
- >>>
- >>> batch_size = 4
- >>> seq_len = 16
- >>> current_partition_id = "train_0"
- >>> # Example 1: Normal usage with existing metadata
- >>> batch_meta = asyncio.run(tq.async_get_meta(
- ... data_fields=["prompts", "attention_mask"],
- ... batch_size=batch_size,
- ... partition_id=current_partition_id,
- ... mode="fetch",
- ... task_name="generate_sequences",
+ >>> # Put with both fields and tag
+ >>> await tq.async_kv_put(
+ ... key="sample_1",
+ ... partition_id="train",
+ ... fields={"input_ids": torch.tensor([1, 2, 3])},
+ ... tag={"score": 0.95}
- ... ))
+ ... )
- >>> batch = asyncio.run(tq.async_get_data(batch_meta))
- >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
- >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta))
- >>>
- >>> # Example 2: Initial data insertion without pre-existing metadata
- >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
- >>> # Please make sure the corresponding partition_id is empty before calling the async_put()
- >>> # without metadata.
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with
- >>> # interleave the initial data if n_sample > 1 before calling the async_put().
- >>> original_prompts = torch.randn(batch_size, seq_len)
- >>> n_samples = 4
- >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
- >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
- >>> # This will create metadata in "insert" mode internally.
- >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
"""
+
+ if fields is None and tag is None:
+ raise ValueError("Please provide at least one parameter of fields or tag.")
+
tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_put(data, metadata, partition_id)
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
+
+ if batch_meta.size != 1:
+ raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
+
+ # 2. register the user-specified tag to BatchMeta
+ if tag:
+ batch_meta.update_custom_meta([tag])
+
+ # 3. put data
+ if fields is not None:
+ if isinstance(fields, dict):
+ # TODO: consider whether to support this...
+ batch = {}
+ for field_name, value in fields.items():
+ if isinstance(value, torch.Tensor):
+ if value.is_nested:
+ raise ValueError("Please use (async)kv_batch_put for batch operation")
+ batch[field_name] = value.unsqueeze(0)
+ else:
+ batch[field_name] = NonTensorStack(value)
+ fields = TensorDict(batch, batch_size=[1])
+ elif not isinstance(fields, TensorDict):
+ raise ValueError("field can only be dict or TensorDict")
+
+ # custom_meta (tag) will be put to controller through the put process
+ await tq_client.async_put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tag) to controller
+ await tq_client.async_set_custom_meta(batch_meta)
-def set_custom_meta(metadata: BatchMeta) -> None:
- """Synchronously send custom metadata to the controller.
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
+async def async_kv_batch_put(
+ keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None
+) -> None:
+ """Asynchronously put multiple key-value pairs to TransferQueue in batch.
+
+ This method stores multiple key-value pairs in a single operation, which is more
+ efficient than calling kv_put multiple times.
Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
+ keys: List of user-specified keys for the data
+ partition_id: Logical partition to store the data in
+ fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
+ tags: List of metadata tags, one for each key
Raises:
- RuntimeError: If communication fails or controller returns error response
+ ValueError: If neither `fields` nor `tags` is provided
+ ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>>
- >>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
- >>> tq.set_custom_meta(batch_meta)
+ >>> keys = ["sample_1", "sample_2", "sample_3"]
+ >>> fields = TensorDict({
+ ... "input_ids": torch.randn(3, 10),
+ ... "attention_mask": torch.ones(3, 10),
+ ... }, batch_size=3)
+ >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
+ >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
"""
+
+ if fields is None and tags is None:
+ raise ValueError("Please provide at least one parameter of `fields` or `tags`.")
+
+ if fields is not None and fields.batch_size[0] != len(keys):
+ raise ValueError(
+ f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
+ f"batch_size {fields.batch_size[0]}"
+ )
+
tq_client = _maybe_create_transferqueue_client()
- return tq_client.set_custom_meta(metadata)
+ # 1. translate user-specified key to BatchMeta
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
-async def async_set_custom_meta(
- metadata: BatchMeta,
-) -> None:
- """
- Asynchronously send custom metadata to the controller.
+ if batch_meta.size != len(keys):
+ raise RuntimeError(
+ f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
+ )
+
+ # 2. register the user-specified tags to BatchMeta
+ if tags is not None:
+ if len(tags) != len(keys):
+ raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
+ batch_meta.update_custom_meta(tags)
+
+ # 3. put data
+ if fields is not None:
+ await tq_client.async_put(fields, batch_meta)
+ else:
+ # directly update custom_meta (tags) to controller
+ await tq_client.async_set_custom_meta(batch_meta)
+
+
+async def async_kv_batch_get(
+ keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None
+) -> TensorDict:
+ """Asynchronously get data from TransferQueue using user-specified keys.
- This method sends per-sample custom metadata (custom_meta) to the controller.
- The custom_meta is stored in the controller and can be retrieved along with
- the BatchMeta in subsequent get_meta calls.
+ This is a convenience method for retrieving data using keys instead of indexes.
Args:
- metadata: BatchMeta containing the samples and their custom metadata to store.
- The custom_meta should be set using BatchMeta.update_custom_meta() or
- BatchMeta.set_custom_meta() before calling this method.
- socket: ZMQ async socket for message transmission (injected by decorator)
+ keys: Single key or list of keys to retrieve
+ partition_id: Partition containing the keys
+ fields: Optional field(s) to retrieve. If None, retrieves all fields
+
+ Returns:
+ TensorDict with the requested data
Raises:
- RuntimeError: If communication fails or controller returns error response
+ RuntimeError: If keys or partition are not found
+ RuntimeError: If some requested fields are not ready for all of the requested keys
Example:
>>> import transfer_queue as tq
>>> tq.init()
- >>>
- >>> # Create batch with custom metadata
- >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...)
- >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})
- >>> asyncio.run(tq.async_set_custom_meta(batch_meta))
+ >>> # Get single key with all fields
+ >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train")
+ >>> # Get multiple keys with specific fields
+ >>> data = await tq.async_kv_batch_get(
+ ... keys=["sample_1", "sample_2"],
+ ... partition_id="train",
+ ... fields="input_ids"
+ ... )
"""
tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_set_custom_meta(metadata)
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
-def clear_samples(metadata: BatchMeta):
- """Synchronously clear specific samples from all storage units and the controller.
+ if batch_meta.size == 0:
+ raise RuntimeError("keys or partition were not found!")
- Args:
- metadata: The BatchMeta of the corresponding data to be cleared
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ batch_meta = batch_meta.select_fields(fields)
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return tq_client.clear_samples(metadata)
+ if not batch_meta.is_ready:
+ raise RuntimeError("Some fields are not ready in all the requested keys!")
+ data = await tq_client.async_get_data(batch_meta)
+ return data
-async def async_clear_samples(metadata: BatchMeta):
- """Asynchronously clear specific samples from all storage units and the controller.
- Args:
- metadata: The BatchMeta of the corresponding data to be cleared
+async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]:
+ """Asynchronously list all keys and their metadata in one or all partitions.
- Raises:
- RuntimeError: If clear operation fails
- """
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_clear_samples(metadata)
+ Args:
+ partition_id: The specific partition_id to query.
+ If None (default), returns keys from all partitions.
+ Returns:
+ A nested dictionary mapping partition IDs to their keys and metadata.
-def clear_partition(partition_id: str):
- """Synchronously clear the whole partition from all storage units and the controller.
+ Structure:
+ {
+ "partition_id": {
+ "key_name": {
+ "tag1": ,
+ ... (other metadata)
+ },
+ ...,
+ },
+ ...
+ }
- Args:
- partition_id: The partition id to clear data for
- Raises:
- RuntimeError: If clear operation fails
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Case 1: Retrieve a specific partition
+ >>> partitions = await tq.async_kv_list(partition_id="train")
+ >>> print(f"Keys: {list(partitions['train'].keys())}")
+ >>> print(f"Tags: {list(partitions['train'].values())}")
+ >>> # Case 2: Retrieve all partitions
+ >>> all_partitions = await tq.async_kv_list()
+ >>> for pid, keys in all_partitions.items():
+ >>> print(f"Partition: {pid}, Key count: {len(keys)}")
"""
tq_client = _maybe_create_transferqueue_client()
- return tq_client.clear_partition(partition_id)
+
+ partition_info = await tq_client.async_kv_list(partition_id)
+
+ return partition_info
-async def async_clear_partition(partition_id: str):
- """Asynchronously clear the whole partition from all storage units and the controller.
+async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
+ """Asynchronously clear key-value pairs from TransferQueue.
+
+ This removes the specified keys and their associated data from both
+ the controller and storage units.
Args:
- partition_id: The partition id to clear data for
+ keys: Single key or list of keys to clear
+ partition_id: Partition containing the keys
- Raises:
- RuntimeError: If clear operation fails
+ Example:
+ >>> import transfer_queue as tq
+ >>> tq.init()
+ >>> # Clear single key
+ >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
+ >>> # Clear multiple keys
+ >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
"""
- tq_client = _maybe_create_transferqueue_client()
- return await tq_client.async_clear_partition(partition_id)
+ if isinstance(keys, str):
+ keys = [keys]
-def close():
- """Close the TransferQueue system."""
- global _TRANSFER_QUEUE_CLIENT
- global _TRANSFER_QUEUE_STORAGE
- if _TRANSFER_QUEUE_CLIENT:
- _TRANSFER_QUEUE_CLIENT.close()
- _TRANSFER_QUEUE_CLIENT = None
+ tq_client = _maybe_create_transferqueue_client()
+ batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
- try:
- if _TRANSFER_QUEUE_STORAGE:
- # only the process that do first-time init can clean the distributed storage
- for storage in _TRANSFER_QUEUE_STORAGE.values():
- ray.kill(storage)
- _TRANSFER_QUEUE_STORAGE = None
- except Exception:
- pass
+ if batch_meta.size > 0:
+ await tq_client.async_clear_samples(batch_meta)
- try:
- controller = ray.get_actor("TransferQueueController")
- ray.kill(controller)
- except Exception:
- pass
+
+# ==================== Low-Level Native API ====================
+# For low-level API support, please refer to transfer_queue/client.py for details.
+def get_client():
+ """Get a TransferQueueClient for using low-level API"""
+ if _TRANSFER_QUEUE_CLIENT is None:
+ raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!")
+ return _TRANSFER_QUEUE_CLIENT
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index 5e25404..056a146 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -303,60 +303,43 @@ def clear_extra_info(self) -> None:
"""
self.extra_info.clear()
- def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None:
+ def get_all_custom_meta(self) -> list[dict[str, Any]]:
"""
- Set custom_meta for a specific sample by global_index.
-
- Custom metadata is user-defined per-sample metadata that can be stored
- and retrieved along with the BatchMeta.
-
- Args:
- global_index: The global_index of the sample to set custom meta for
- meta_dict: Dictionary containing custom metadata for the sample
-
- Raises:
- ValueError: If the key is not in global_indexes
- """
-
- if global_index not in self.global_indexes:
- raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.")
-
- self.custom_meta[global_index] = copy.deepcopy(meta_dict)
-
- def get_all_custom_meta(self) -> dict[int, dict[str, Any]]:
- """
- Get all custom_meta as a dictionary.
+ Get all custom_meta as a list of dictionaries.
Returns:
- A deep copy of the custom_meta dictionary
+ A deep copy of the custom_meta list
"""
- return copy.deepcopy(self.custom_meta)
+ custom_meta = [self.custom_meta.get(i, {}) for i in self.global_indexes]
+ return copy.deepcopy(custom_meta)
- def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]):
+ def update_custom_meta(self, custom_meta: list[dict[str, Any]]):
"""
- Update custom_meta with a dictionary of new metadata.
+ Update custom_meta with a list of per-sample custom metadata dictionaries.
This method updates the custom_meta dictionary with the provided metadata.
Existing keys will be overwritten with new values.
Args:
- new_meta: Dictionary of new metadata
+ custom_meta: List of custom_meta dictionaries, one per sample in this batch (aligned with global_indexes)
Raises:
- ValueError: If any key in new_meta is not in global_indexes
+ ValueError: If the length of custom_meta does not match the batch size
"""
- if new_meta is None:
+ if custom_meta is None:
return
- non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes)
- if non_exist_global_indexes:
+ if len(custom_meta) != self.size:
raise ValueError(
- f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} "
- f"do not exist in this batch."
+ f"The length of custom_meta list {len(custom_meta)} must match the batch size: {self.size}"
)
- self.custom_meta.update(new_meta)
+ custom_meta_dict: dict[int, dict[str, Any]] = {
+ self.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta))
+ }
+
+ self.custom_meta.update(custom_meta_dict)
def clear_custom_meta(self) -> None:
"""
@@ -846,3 +829,211 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) ->
]
return all_fields
+
+
+# ==================== KV Interface Metadata ====================
+@dataclass
+class KVBatchMeta:
+ """Records the metadata for KV interface."""
+
+ # keys of each sample
+ keys: list[str] = dataclasses.field(default_factory=list)
+
+ # sample-level tags
+ tags: list[dict] = dataclasses.field(default_factory=list)
+
+ # [optional] partition_id of this batch
+ partition_id: Optional[str] = None
+
+ # [optional] fields of each sample
+ fields: Optional[list[str]] = None
+
+ # [optional] external information for batch-level information
+ extra_info: Optional[dict[str, Any]] = None
+
+ def __post_init__(self):
+ """Validate all the variables"""
+ if len(self.keys) != len(self.tags):
+ raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}")
+ if len(self.keys) != len(set(self.keys)):
+ raise ValueError("Got duplicated keys.")
+ if self.fields is not None:
+ if len(self.fields) != len(set(self.fields)):
+ raise ValueError("Got duplicated fields.")
+
+ # deepcopy to prevent unexpected behavior after chunk/concat
+ self.tags = copy.deepcopy(self.tags)
+ self.extra_info = copy.deepcopy(self.extra_info)
+
+ object.__setattr__(self, "_size", len(self.keys))
+
+ @property
+ def size(self) -> int:
+ """Return the number of samples in this batch"""
+ return getattr(self, "_size", 0)
+
+ def __len__(self) -> int:
+ """Return the number of samples in this batch."""
+ return len(self.keys)
+
+ def __str__(self):
+ return f"KVBatchMeta(size={self.size}, field_names={self.fields}, extra_info={self.extra_info})"
+
+ def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta":
+ """
+ Select specific keys from this batch.
+ This will construct a new KVBatchMeta instance containing only the specified keys.
+
+ Args:
+ keys_to_select (list[str]): List of keys to retain.
+
+ Returns:
+ KVBatchMeta: A new KVBatchMeta instance containing only the specified keys.
+
+ Raises:
+ ValueError: If duplicate keys exist in input param `keys_to_select`.
+ RuntimeError: If `keys_to_select` contains keys that do not exist in this batch.
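+
+ Example:
+ >>> # Illustrative sketch; assumes `meta` currently holds keys "sample_1" and "sample_2"
+ >>> subset = meta.select_keys(["sample_1"])
+ >>> print(subset.keys) # ['sample_1']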
+ """
+
+ if len(set(keys_to_select)) != len(keys_to_select):
+ raise ValueError("Contain duplicate keys.")
+
+ non_exist_keys = set(keys_to_select) - set(self.keys)
+ if len(non_exist_keys) > 0:
+ raise RuntimeError(f"Keys {non_exist_keys} not found in current batch.")
+
+ _keys_to_idx = {key: idx for idx, key in enumerate(self.keys)}
+
+ loc_idx = [_keys_to_idx[k] for k in keys_to_select]
+ tags = [self.tags[i] for i in loc_idx]
+
+ return KVBatchMeta(
+ keys=keys_to_select,
+ tags=tags,
+ partition_id=self.partition_id,
+ fields=self.fields,
+ extra_info=self.extra_info,
+ )
+
+ def reorder(self, indexes: list[int]):
+ """
+ Reorder the samples in this batch according to the specified indexes.
+
+ The operation is performed in-place.
+
+ Args:
+ indexes: A list of integers specifying the new order of the samples.
+
+ Raises:
+ ValueError: If the size of input `indexes` does not match with the batch size.
+ ValueError: If duplicate indexes exist in input param `indexes`.
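+
+ Example:
+ >>> # Illustrative sketch for a 2-sample batch: swap the two samples in place
+ >>> meta.reorder([1, 0])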
+ """
+ if len(indexes) != self.size:
+ raise ValueError(
+ f"Attempted to reorder with indexes length {len(indexes)} that does not match "
+ f"the batch size {self.size}."
+ )
+
+ if len(set(indexes)) != len(indexes):
+ raise ValueError("Contain duplicate indexes.")
+
+ self.keys = [self.keys[i] for i in indexes]
+ self.tags = [self.tags[i] for i in indexes]
+
+ def chunk(self, chunks: int) -> list["KVBatchMeta"]:
+ """
+ Split this batch into smaller chunks.
+
+ Args:
+ chunks: Number of chunks to split this batch into
+
+ Returns:
+ List of smaller KVBatchMeta chunks
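+
+ Example:
+ >>> # Illustrative sketch: split a 4-sample batch into 2 chunks of 2 samples each
+ >>> chunks = meta.chunk(2)
+ >>> print([c.size for c in chunks]) # [2, 2]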
+ """
+
+ chunk_list = []
+ if self.size < chunks:
+ logger.warning(
+ f"Chunk size {chunks} > number of samples in this batch {self.size}, this will return some "
+ f"empty KVBatchMeta chunks."
+ )
+
+ # Calculate the base size of each chunk and the remainder
+ base_size = self.size // chunks
+ remainder = self.size % chunks
+
+ start = 0
+ for i in range(chunks):
+ # Calculate the size of the current chunk (the first `remainder` chunks are 1 larger than the base size)
+ current_chunk_size = base_size + 1 if i < remainder else base_size
+ end = start + current_chunk_size
+ chunk_keys = self.keys[start:end]
+ chunk_tags = self.tags[start:end]
+
+ chunk = KVBatchMeta(
+ keys=chunk_keys,
+ tags=chunk_tags,
+ partition_id=self.partition_id,
+ fields=self.fields,
+ extra_info=self.extra_info,
+ )
+ chunk_list.append(chunk)
+ start = end
+
+ return chunk_list
+
+ @classmethod
+ def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta":
+ """
+ Concatenate multiple KVBatchMeta chunks into one large batch.
+
+ Args:
+ data: List of KVBatchMeta chunks to concatenate
+
+ Returns:
+ Concatenated KVBatchMeta
+
+ Raises:
+ ValueError: If validation fails (e.g., field names do not match)
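+
+ Example:
+ >>> # Illustrative sketch: merge previously chunked KVBatchMeta pieces back together
+ >>> merged = KVBatchMeta.concat(chunks)
+ >>> print(merged.size) # equals sum(c.size for c in chunks)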
+ """
+ if not data:
+ logger.warning("Try to concat empty KVBatchMeta chunks. Returning empty KVBatchMeta.")
+ return KVBatchMeta()
+
+ # skip empty chunks
+ data = [chunk for chunk in data if chunk and chunk.size > 0]
+
+ if len(data) == 0:
+ logger.warning("No valid KVBatchMeta chunks to concatenate. Returning empty KVBatchMeta.")
+ return KVBatchMeta()
+
+ base_fields = data[0].fields
+ if base_fields is not None:
+ base_fields_set = set(base_fields)
+ else:
+ base_fields_set = set()
+
+ base_partition_id = data[0].partition_id
+
+ all_keys = []
+ all_tags = []
+ all_extra_info = {}
+ for chunk in data:
+ if chunk.fields is not None and set(chunk.fields) != base_fields_set:
+ raise ValueError("Field names do not match for concatenation.")
+ if chunk.partition_id != base_partition_id:
+ raise ValueError("Partition do not match for concatenation.")
+
+ all_keys.extend(chunk.keys)
+ all_tags.extend(chunk.tags)
+ if chunk.extra_info is not None:
+ all_extra_info.update(chunk.extra_info)
+
+ return KVBatchMeta(
+ keys=all_keys,
+ tags=all_tags,
+ partition_id=base_partition_id,
+ fields=base_fields,
+ extra_info=all_extra_info if all_extra_info else None,
+ )
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py
index 352b3ad..c0c5058 100644
--- a/transfer_queue/storage/managers/base.py
+++ b/transfer_queue/storage/managers/base.py
@@ -177,7 +177,7 @@ def _send_handshake_requests(self) -> None:
"""Send handshake request to controller."""
assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized"
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.HANDSHAKE,
+ request_type=ZMQRequestType.HANDSHAKE, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"storage_manager_id": self.storage_manager_id,
@@ -225,7 +225,7 @@ async def notify_data_update(
poller.register(self.data_status_update_socket, zmq.POLLIN)
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"partition_id": partition_id,
@@ -245,7 +245,7 @@ async def notify_data_update(
)
except Exception as e:
request_msg = ZMQMessage.create(
- request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"message": f"Failed to notify data status update information from "
@@ -556,7 +556,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
loop = asyncio.get_event_loop()
# put to storage backends
- custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)
+ custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)
per_field_dtypes: dict[int, dict[str, Any]] = {}
per_field_shapes: dict[int, dict[str, Any]] = {}
@@ -577,24 +577,28 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None
)
- # Prepare per-field custom_meta if available
- per_field_custom_meta: dict[int, dict[str, Any]] = {}
- if custom_meta:
- if len(custom_meta) != len(keys):
- raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})")
+ # Prepare per-field custom_backend_meta if available
+ per_field_custom_backend_meta: dict[int, dict[str, Any]] = {}
+ if custom_backend_meta:
+ if len(custom_backend_meta) != len(keys):
+ raise ValueError(
+ f"Length of custom_backend_meta ({len(custom_backend_meta)}) does not match expected ({len(keys)})"
+ )
# custom meta is a flat list aligned with keys/values
# Use itertools.product to eliminate nested loops
for global_idx in metadata.global_indexes:
- per_field_custom_meta[global_idx] = {}
+ per_field_custom_backend_meta[global_idx] = {}
- # TODO(tianyi): the order of custom meta is coupled with keys/values
+ # FIXME(tianyi): the order of custom backend meta is coupled with keys/values
+ # FIXME: if put_data is called to partially update or add new fields, the current
+ # implementation will cause custom_backend_meta to be lost or mismatched!
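+ # NOTE: the zip below assumes the backend returns exactly one meta entry per
+ # (field, sample) pair, ordered field-major (sorted field names, then global indexes).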
for (field_name, global_idx), meta_value in zip(
itertools.product(sorted(metadata.field_names), metadata.global_indexes),
- custom_meta,
+ custom_backend_meta,
strict=True,
):
- per_field_custom_meta[global_idx][field_name] = meta_value
- metadata.update_custom_meta(per_field_custom_meta)
+ per_field_custom_backend_meta[global_idx][field_name] = meta_value
+ metadata._custom_backend_meta.update(per_field_custom_backend_meta)
# Get current data partition id
# Note: Currently we only support putting to & getting data from a single data partition simultaneously,
@@ -607,7 +611,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
metadata.global_indexes,
per_field_dtypes,
per_field_shapes,
- per_field_custom_meta,
+ per_field_custom_backend_meta,
)
async def get_data(self, metadata: BatchMeta) -> TensorDict:
diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py
index bf711c8..eaaf65e 100644
--- a/transfer_queue/utils/zmq_utils.py
+++ b/transfer_queue/utils/zmq_utils.py
@@ -101,6 +101,12 @@ class ZMQRequestType(ExplicitEnum):
NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK"
NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR"
+ # KV_INTERFACE
+ KV_RETRIEVE_KEYS = "KV_RETRIEVE_KEYS"
+ KV_RETRIEVE_KEYS_RESPONSE = "KV_RETRIEVE_KEYS_RESPONSE"
+ KV_LIST = "KV_LIST"
+ KV_LIST_RESPONSE = "KV_LIST_RESPONSE"
+
class ZMQServerInfo:
"""
diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py
index 59159a6..9a66b8e 100644
--- a/tutorial/01_core_components.py
+++ b/tutorial/01_core_components.py
@@ -63,6 +63,7 @@ def demonstrate_data_workflow():
# Step 1: Put data
print("[Step 1] Putting data into TransferQueue...")
+ tq_client = tq.get_client()
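+ # The client returned by get_client() is used for the explicit put / get_meta / get_data / clear_partition calls below.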
input_ids = torch.tensor(
[
@@ -84,12 +85,12 @@ def demonstrate_data_workflow():
print(f" Created {data_batch.batch_size[0]} samples")
partition_id = "tutorial_partition_0"
- tq.put(data=data_batch, partition_id=partition_id)
+ tq_client.put(data=data_batch, partition_id=partition_id)
print(f" β Data written to partition: {partition_id}")
# Step 2: Get metadata
print("[Step 2] Requesting data metadata...")
- batch_meta = tq.get_meta(
+ batch_meta = tq_client.get_meta(
data_fields=["input_ids", "attention_mask"],
batch_size=data_batch.batch_size[0],
partition_id=partition_id,
@@ -100,7 +101,7 @@ def demonstrate_data_workflow():
# Step 3: Get actual data
print("[Step 3] Retrieving actual data...")
- retrieved_data = tq.get_data(batch_meta)
+ retrieved_data = tq_client.get_data(batch_meta)
print(" β Data retrieved successfully")
print(f" Keys: {list(retrieved_data.keys())}")
@@ -112,7 +113,7 @@ def demonstrate_data_workflow():
# Step 5: Clear
print("[Step 5] Clearing partition... (you may also use clear_samples() to clear specific samples)")
- tq.clear_partition(partition_id=partition_id)
+ tq_client.clear_partition(partition_id=partition_id)
print(" β Partition cleared")
diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py
new file mode 100644
index 0000000..376ebfb
--- /dev/null
+++ b/tutorial/02_kv_interface.py
@@ -0,0 +1,259 @@
+# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2025 The TransferQueue Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import textwrap
+import warnings
+from pathlib import Path
+
+warnings.filterwarnings(
+ action="ignore",
+ message=r"The PyTorch API of nested tensors is in prototype stage*",
+ category=UserWarning,
+ module=r"torch\.nested",
+)
+
+warnings.filterwarnings(
+ action="ignore",
+ message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible "
+ r"devices env var if num_gpus=0 or num_gpus=None.*",
+ category=FutureWarning,
+ module=r"ray\._private\.worker",
+)
+
+
+import ray # noqa: E402
+import torch # noqa: E402
+from tensordict import TensorDict # noqa: E402
+
+# Add the parent directory to the path
+parent_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(parent_dir))
+
+import transfer_queue as tq # noqa: E402
+
+# Configure Ray
+os.environ["RAY_DEDUP_LOGS"] = "0"
+os.environ["RAY_DEBUG"] = "1"
+
+if not ray.is_initialized():
+ ray.init(namespace="TransferQueueTutorial")
+
+
+def demonstrate_kv_api():
+ """
+ Demonstrate the Key-Value (KV) semantic API:
+ kv_put & kv_batch_put -> kv_list -> kv_batch_get -> kv_clear
+ """
+ print("=" * 80)
+ print("Key-Value Semantic API Demo: kv_put/kv_batch_put β kv_list β kv_batch_get β kv_clear")
+ print("=" * 80)
+
+ # Step 1: Put a single key-value pair with kv_put
+ print("[Step 1] Putting a single sample with kv_put...")
+
+ # Define the data content (The "Value")
+ input_ids = torch.tensor([[1, 2, 3]])
+ attention_mask = torch.ones(input_ids.size())
+
+ single_sample = TensorDict(
+ {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ },
+ batch_size=input_ids.size(0),
+ )
+
+ partition_id = "Train"
+ # Use a meaningful string key instead of an auto-increment integer
+ key = "0_0" # User-defined key: "{uid}_{session_id}"
+ tag = {"global_steps": 0, "status": "running", "model_version": 0}
+
+ print(f" Inserting Key: {key}")
+ print(f" Fields (Columns): {list(single_sample.keys())}")
+ print(f" Tag (Metadata): {tag}")
+
+ tq.kv_put(key=key, partition_id=partition_id, fields=single_sample, tag=tag)
+ print(" β kv_put success.")
+
+ # Step 2: Put multiple key-value pairs with kv_batch_put
+ print("\n[Step 2] Putting batch data with kv_batch_put...")
+
+ batch_input_ids = torch.tensor(
+ [
+ [4, 5, 6],
+ [7, 8, 9],
+ [10, 11, 12],
+ [13, 14, 15],
+ ]
+ )
+ batch_attention_mask = torch.ones_like(batch_input_ids)
+
+ data_batch = TensorDict(
+ {
+ "input_ids": batch_input_ids,
+ "attention_mask": batch_attention_mask,
+ },
+ batch_size=batch_input_ids.size(0),
+ )
+
+ keys = ["1_0", "1_1", "1_2", "2_0"] # 4 keys for 4 samples
+ tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))]
+
+ print(f" Inserting batch of {len(keys)} samples.")
+ print(f" Fields (Columns): {list(data_batch.keys())}")
+ print(f" Tag (Metadata): {tags}")
+ tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data_batch, tags=tags)
+ print(" β kv_batch_put success.")
+
+ # Step 3: Append additional fields to existing samples
+ print("\n[Step 3] Appending new fields (Columns) to existing samples...")
+
+ batch_response = torch.tensor(
+ [
+ [4, 5, 6],
+ [7, 8, 9],
+ ]
+ )
+ response_batch = TensorDict(
+ {
+ "response": batch_response,
+ },
+ batch_size=batch_response.size(0),
+ )
+
+ # We only update a subset of keys
+ append_keys = ["1_1", "2_0"] # Appending to existing samples
+ append_tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(append_keys))]
+ print(f" Target Keys: {append_keys}")
+ print(" New Field to Add is: 'response'")
+ print(f" The updated tags are: {append_tags}")
+ tq.kv_batch_put(keys=append_keys, partition_id=partition_id, fields=response_batch, tags=append_tags)
+ print(" β Update success: Samples '1_1' and '2_0' now contain {input_ids, attention_mask, response}.")
+
+ # Step 4: Only update tags through kv_put
+ print("\n[Step 4] Update existing tags without providing value...")
+ key_for_update_tags = "0_0"
+ tag_update = {"global_steps": 0, "status": "finish", "model_version": 0}
+ print(f" Target Key: {key_for_update_tags}")
+ print(f" The updated tag is: {tag_update}")
+ tq.kv_put(key=key_for_update_tags, partition_id=partition_id, fields=None, tag=tag_update)
+ print(f" β Update success: Samples '0_0' now has tag as {tag_update}.")
+
+ # Step 5: List all keys and tags in a partition
+ print("\n[Step 5] Listing all keys and tags in partition...")
+
+ partition_info = tq.kv_list()
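+ # kv_list() returns a nested mapping of the form {partition_id: {key: tag}}.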
+ print(f" Found {len(partition_info.keys())} partitions: '{list(partition_info.keys())}'")
+ for pid, keys_and_tags in partition_info.items():
+ for k, t in keys_and_tags.items():
+ print(f"Partition: {pid}, - key='{k}' | tag={t}")
+
+ # Step 6: Retrieve specific fields using kv_batch_get
+ print("\n[Step 6] Retrieving specific fields (Column) with kv_batch_get...")
+ print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').")
+
+ all_keys = list(partition_info[partition_id].keys())
+ retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, fields="input_ids")
+ print(f" β Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.")
+
+ # Step 7: Retrieve all fields using kv_batch_get
+ print("\n[Step 7] Retrieving all fields with kv_batch_get...")
+ retrieved_all = tq.kv_batch_get(keys=all_keys, partition_id=partition_id)
+ print(f" Retrieved all fields for {all_keys}:")
+ print(f" Fields: {list(retrieved_all.keys())}")
+ print(
+ f" Note: We cannot retrieve fields {list(response_batch.keys())}, since they only available in {append_keys}"
+ )
+
+ # Step 8: Clear specific keys
+ print("\n[Step 8] Clearing keys from partition...")
+ keys_to_clear = all_keys[:2] # Delete the first 2 keys
+ tq.kv_clear(keys=keys_to_clear, partition_id=partition_id)
+ print(f" β Cleared keys: {keys_to_clear}")
+
+ partition_info_after_clear = tq.kv_list(partition_id=partition_id)
+ print(f" Remaining keys in partition: {list(partition_info_after_clear[partition_id].keys())}")
+
+
+def main():
+ print("=" * 80)
+ print(
+ textwrap.dedent(
+ """
+ TransferQueue Tutorial 2: Key-Value (KV) Semantic API
+
+ This tutorial demonstrates the KV semantic API, which provides a simple
+ interface for data storage and retrieval using user-defined string keys.
+
+ Key Methods:
+ 1. (async_)kv_put - Insert/Update a multi-column sample by key, with optional metadata tag
+ 2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch
+ 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields)
+ 4. (async_)kv_list - List keys and tags (metadata) in a partition
+ 5. (async_)kv_clear - Remove key-value pairs from storage
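+
+ A minimal call sequence (values are placeholders; see demonstrate_kv_api() for the full flow):
+ tq.init()
+ tq.kv_put(key="0_0", partition_id="Train", fields=sample_tensordict, tag={"status": "running"})
+ keys_and_tags = tq.kv_list() # {partition_id: {key: tag}}
+ batch = tq.kv_batch_get(keys=["0_0"], partition_id="Train", fields="input_ids")
+ tq.kv_clear(keys=["0_0"], partition_id="Train")
+ tq.close()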
+
+ Key Features:
+ β Redis-style Semantics - Familiar KV interface (Put/Get/List) for zero learning curve
+ β Fine-grained Access - Update or retrieve specific fields (columns) within a key (row) without reading or writing the full sample.
+ β Partition Isolation - Logical separation of storage namespaces
+ β Metadata Tags - Lightweight metadata for status tracking
+ β Pluggable Backends - Supports multiple backends
+
+ Use Cases:
+ - Fine-grained data access where extreme streaming performance is not essential
+ - Integration with an external ReplayBuffer/single-controller that manages sample dispatching
+
+ Limitations (vs low-level native APIs):
+ - No built-in production/consumption tracking: Users must track consumption status themselves via tags.
+ - No built-in Sampler support: Data dispatch must be implemented externally, e.g. by a ReplayBuffer or single controller.
+ - Not fully streaming: Consumers must wait for the single controller to dispatch `keys`.
+ """
+ )
+ )
+ print("=" * 80)
+
+ try:
+ print("Setting up TransferQueue...")
+ tq.init()
+
+ print("\nDemonstrating the KV semantic API...")
+ demonstrate_kv_api()
+
+ print("\n" + "=" * 80)
+ print("Tutorial Complete!")
+ print("=" * 80)
+ print("\nKey Takeaways:")
+ print(" 1. KV API simplifies data access with Redis-style semantics")
+ print(" 2. Use 'fields' parameter to get/put specific fields only")
+ print(" 3. Tags enable custom metadata for production status, scores, etc.")
+ print(" 4. Use kv_list to inspect partition contents")
+
+ # Cleanup
+ tq.close()
+ ray.shutdown()
+ print("\nCleanup complete")
+
+ except Exception as e:
+ print(f"Error during tutorial: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tutorial/02_metadata_concepts.py b/tutorial/03_metadata_concepts.py
similarity index 93%
rename from tutorial/02_metadata_concepts.py
rename to tutorial/03_metadata_concepts.py
index d864db3..2c81941 100644
--- a/tutorial/02_metadata_concepts.py
+++ b/tutorial/03_metadata_concepts.py
@@ -205,10 +205,15 @@ def demonstrate_batch_meta():
print(f"β Extra info: {batch.get_all_extra_info()}")
print("[Example 3] Adding sample-level information through custom_meta...")
- batch.set_custom_meta(
- global_index=0, meta_dict={"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"}
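+ # update_custom_meta takes one dict per sample, in the same order as the samples in this batch.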
+ batch.update_custom_meta(
+ [
+ {"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@2", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@3", "session_id": "session@0", "model_version": "epoch@0"},
+ {"uid": "prompt@4", "session_id": "session@0", "model_version": "epoch@0"},
+ ]
)
- batch.update_custom_meta({1: {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}})
print(f"β Custom meta: {batch.get_all_custom_meta()}")
# Example 4: Chunk a batch
@@ -300,11 +305,13 @@ def demonstrate_real_workflow():
# Initialize Ray
if not ray.is_initialized():
- ray.init()
+ ray.init(namespace="TransferQueueTutorial")
# Initialize TransferQueue
tq.init()
+ tq_client = tq.get_client()
+
print("[Step 1] Putting data into TransferQueue...")
input_ids = torch.randint(0, 1000, (8, 512))
attention_mask = torch.ones(8, 512)
@@ -318,23 +325,23 @@ def demonstrate_real_workflow():
)
partition_id = "demo_partition"
- batch_meta = tq.put(data=data_batch, partition_id=partition_id)
+ batch_meta = tq_client.put(data=data_batch, partition_id=partition_id)
print(f"β Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.")
print("[Step 2] [Optional] Setting sample-level custom_meta...")
- custom_meta = {
- global_index: {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0}
- for global_index in batch_meta.global_indexes
- }
+ custom_meta = [
+ {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0}
+ for _ in range(batch_meta.size)
+ ]
batch_meta.update_custom_meta(custom_meta)
print(f"β Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}")
- tq.set_custom_meta(batch_meta)
+ tq_client.set_custom_meta(batch_meta)
print("β Successful to store custom_meta into TQ controller. Now you can retrieve the custom_meta from anywhere.")
print("[Step 3] Try to get metadata from TransferQueue from other places...")
- batch_meta = tq.get_meta(
+ batch_meta = tq_client.get_meta(
data_fields=["input_ids", "attention_mask"],
batch_size=8,
partition_id=partition_id,
@@ -353,7 +360,7 @@ def demonstrate_real_workflow():
print("β Selected 'input_ids' field only:")
print(f" New field names: {selected_meta.field_names}")
print(f" Samples still have same global indexes: {selected_meta.global_indexes}")
- retrieved_data = tq.get_data(selected_meta)
+ retrieved_data = tq_client.get_data(selected_meta)
print(f" Retrieved data keys: {list(retrieved_data.keys())}")
print("[Step 5] Select specific samples from the retrieved BatchMeta...")
@@ -361,7 +368,7 @@ def demonstrate_real_workflow():
print("β Selected samples at indices [0, 2, 4, 6]:")
print(f" New global indexes: {partial_meta.global_indexes}")
print(f" Number of samples: {len(partial_meta)}")
- retrieved_data = tq.get_data(partial_meta)
+ retrieved_data = tq_client.get_data(partial_meta)
print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}")
print("[Step 6] Demonstrate chunk operation...")
@@ -369,11 +376,11 @@ def demonstrate_real_workflow():
print(f"β Chunked into {len(chunks)} parts:")
for i, chunk in enumerate(chunks):
print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}")
- chunk_data = tq.get_data(chunk)
+ chunk_data = tq_client.get_data(chunk)
print(f" Chunk {i}: Retrieved chunk data: {chunk_data}")
# Cleanup
- tq.clear_partition(partition_id=partition_id)
+ tq_client.clear_partition(partition_id=partition_id)
tq.close()
ray.shutdown()
print("β Partition cleared and resources cleaned up")
@@ -385,7 +392,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 2: Metadata System
+ TransferQueue Tutorial 3: Metadata System
This script introduces the metadata system in TransferQueue, which tracks
the structure and state of data:
diff --git a/tutorial/03_understanding_controller.py b/tutorial/04_understanding_controller.py
similarity index 86%
rename from tutorial/03_understanding_controller.py
rename to tutorial/04_understanding_controller.py
index 4ca426d..746de14 100644
--- a/tutorial/03_understanding_controller.py
+++ b/tutorial/04_understanding_controller.py
@@ -62,6 +62,8 @@ def demonstrate_partition_isolation():
tq.init()
+ tq_client = tq.get_client()
+
# Partition 1: Training data
print("\n[Partition 1] Putting training data...")
train_data = TensorDict(
@@ -71,7 +73,7 @@ def demonstrate_partition_isolation():
},
batch_size=2,
)
- tq.put(data=train_data, partition_id="train")
+ tq_client.put(data=train_data, partition_id="train")
print(" β Training data added to 'train' partition")
# Partition 2: Validation data
@@ -83,23 +85,25 @@ def demonstrate_partition_isolation():
},
batch_size=2,
)
- tq.put(data=val_data, partition_id="val")
+ tq_client.put(data=val_data, partition_id="val")
print(" β Validation data added to 'val' partition")
# Get from train partition
print("\n[Retrieving from 'train' partition]")
- train_meta = tq.get_meta(
+ train_meta = tq_client.get_meta(
data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task"
)
- retrieved_train_data = tq.get_data(train_meta)
+ retrieved_train_data = tq_client.get_data(train_meta)
print(f" β Got BatchMeta={train_meta} from partition 'train'")
print(f" β Retrieved Data: input_ids={retrieved_train_data['input_ids']}, labels={retrieved_train_data['labels']}")
# Get from val partition
print("\n[Retrieving from 'val' partition]")
- val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task")
- retrieved_val_data = tq.get_data(val_meta)
+ val_meta = tq_client.get_meta(
+ data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task"
+ )
+ retrieved_val_data = tq_client.get_data(val_meta)
print(f" β Got BatchMeta={val_meta} from partition 'val'")
print(f" β Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}")
@@ -107,8 +111,8 @@ def demonstrate_partition_isolation():
print(" β Data isolation: 'train' and 'val' partitions are completely independent")
# Cleanup
- tq.clear_partition(partition_id="train")
- tq.clear_partition(partition_id="val")
+ tq_client.clear_partition(partition_id="train")
+ tq_client.clear_partition(partition_id="val")
tq.close()
ray.shutdown()
@@ -126,6 +130,8 @@ def demonstrate_dynamic_expansion():
tq.init()
+ tq_client = tq.get_client()
+
# Add first batch with 2 samples, 2 fields
print("\n[Step 1] Adding initial data (2 samples, 2 fields)...")
data1 = TensorDict(
@@ -135,7 +141,7 @@ def demonstrate_dynamic_expansion():
},
batch_size=2,
)
- meta1 = tq.put(data=data1, partition_id="dynamic")
+ meta1 = tq_client.put(data=data1, partition_id="dynamic")
print(" β Added 2 samples")
print(f" β Got BatchMeta: {meta1} samples")
@@ -148,9 +154,9 @@ def demonstrate_dynamic_expansion():
},
batch_size=3,
)
- meta2 = tq.put(data=data2, partition_id="dynamic")
+ meta2 = tq_client.put(data=data2, partition_id="dynamic")
- all_meta = tq.get_meta(
+ all_meta = tq_client.get_meta(
data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task"
)
print(" β Added 3 more samples (total: 5)")
@@ -165,7 +171,7 @@ def demonstrate_dynamic_expansion():
},
batch_size=2,
)
- meta3 = tq.put(data=data3, metadata=meta1)
+ meta3 = tq_client.put(data=data3, metadata=meta1)
print(" β Added 2 samples with new field 'field3'")
print(f" β Got BatchMeta: {meta3} for newly put data with new field")
@@ -174,7 +180,7 @@ def demonstrate_dynamic_expansion():
print(" β Columns auto-expand: Can add new fields anytime")
# Cleanup
- tq.clear_partition(partition_id="dynamic")
+ tq_client.clear_partition(partition_id="dynamic")
tq.close()
ray.shutdown()
@@ -190,6 +196,8 @@ def demonstrate_default_consumption_sample_strategy():
tq.init()
+ tq_client = tq.get_client()
+
# Add 6 samples
print("\n[Setup] Adding 6 samples...")
all_data = TensorDict(
@@ -198,22 +206,22 @@ def demonstrate_default_consumption_sample_strategy():
},
batch_size=6,
)
- tq.put(data=all_data, partition_id="sampling")
+ tq_client.put(data=all_data, partition_id="sampling")
print(" β 6 samples added")
# First get - should get samples 0,1,2
print("\n[Task A, Get 1] Requesting 3 samples...")
- meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
+ meta1 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
print(f" β Got samples: {meta1.global_indexes}")
# Second get - should get samples 3,4,5 (no duplicates!)
print("\n[Task A, Get 2] Requesting 3 more samples...")
- meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
+ meta2 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A")
print(f" β Got samples: {meta2.global_indexes}")
# Third get - should get samples 0,1
print("\n[Task B, Get 1] Requesting 2 samples...")
- meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B")
+ meta3 = tq_client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B")
print(f" β Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -224,7 +232,7 @@ def demonstrate_default_consumption_sample_strategy():
print(" β Third get (Task B): samples 0,1")
# Cleanup
- tq.clear_partition(partition_id="sampling")
+ tq_client.clear_partition(partition_id="sampling")
tq.close()
ray.shutdown()
@@ -235,7 +243,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 3: Understanding TransferQueueController
+ TransferQueue Tutorial 4: Understanding TransferQueueController
This script demonstrates TransferQueueController's key features:
diff --git a/tutorial/04_custom_sampler.py b/tutorial/05_custom_sampler.py
similarity index 90%
rename from tutorial/04_custom_sampler.py
rename to tutorial/05_custom_sampler.py
index c35b4e0..e30487b 100644
--- a/tutorial/04_custom_sampler.py
+++ b/tutorial/05_custom_sampler.py
@@ -186,6 +186,8 @@ def demonstrate_random_sampler_with_replacement():
sampler = RandomSamplerWithReplacement()
setup_transfer_queue_with_sampler(sampler)
+ tq_client = tq.get_client()
+
# Add 5 samples
print("\n[Step 1] Adding 5 samples...")
data = TensorDict(
@@ -194,22 +196,22 @@ def demonstrate_random_sampler_with_replacement():
},
batch_size=5,
)
- tq.put(data=data, partition_id="test")
+ tq_client.put(data=data, partition_id="test")
print(" β 5 samples added")
# Get batch 1 (should get 2 random samples)
print("\n[Step 2] Get batch 1 (2 samples)...")
- meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta1 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" β Got samples: {meta1.global_indexes}")
# Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!)
print("\n[Step 3] Get batch 2 (1 sample)...")
- meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
+ meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
print(f" β Got samples: {meta2.global_indexes}")
# Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!)
print("\n[Step 4] Get batch 3 (2 samples)...")
- meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" β Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -219,7 +221,7 @@ def demonstrate_random_sampler_with_replacement():
print(f" β All sampled: {all_sampled}")
# Cleanup
- tq.clear_partition(partition_id="test")
+ tq_client.clear_partition(partition_id="test")
tq.close()
ray.shutdown()
@@ -234,6 +236,8 @@ def demonstrate_random_sampler_without_replacement():
sampler = RandomSamplerWithoutReplacement()
setup_transfer_queue_with_sampler(sampler)
+ tq_client = tq.get_client()
+
# Add 6 samples
print("\n[Step 1] Adding 6 samples...")
data = TensorDict(
@@ -242,22 +246,22 @@ def demonstrate_random_sampler_without_replacement():
},
batch_size=6,
)
- tq.put(data=data, partition_id="test")
+ tq_client.put(data=data, partition_id="test")
print(" β 6 samples added")
# Get batch 1 (should get 3 random samples without replacement)
print("\n[Step 2] Get batch 1 (3 samples)...")
- meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task")
+ meta1 = tq_client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task")
print(f" β Got samples: {meta1.global_indexes}")
# Get batch 2 (should randomly get 1 sample that are different from previous batch)
print("\n[Step 3] Get batch 2 (1 samples)...")
- meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
+ meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task")
print(f" β Got samples: {meta2.global_indexes}")
# Get batch 3 (should randomly get 2 samples that are different from previous batch)
print("\n[Step 4] Get batch 3 (2 samples)...")
- meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
+ meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task")
print(f" β Got samples: {meta3.global_indexes}")
print("\n[Verification]")
@@ -267,7 +271,7 @@ def demonstrate_random_sampler_without_replacement():
print(f" β Batch 3: {meta3.global_indexes} (none left)")
# Cleanup
- tq.clear_partition(partition_id="test")
+ tq_client.clear_partition(partition_id="test")
tq.close()
ray.shutdown()
@@ -282,6 +286,8 @@ def demonstrate_priority_sampler():
sampler = PrioritySampler()
setup_transfer_queue_with_sampler(sampler)
+ tq_client = tq.get_client()
+
# Add 8 samples
print("\n[Step 1] Adding 8 samples...")
data = TensorDict(
@@ -290,7 +296,7 @@ def demonstrate_priority_sampler():
},
batch_size=8,
)
- tq.put(data=data, partition_id="test")
+ tq_client.put(data=data, partition_id="test")
print(" β 8 samples added")
time.sleep(1)
@@ -303,7 +309,7 @@ def demonstrate_priority_sampler():
print(f"Priority scores: {priority_scores}")
# Get batch using priority sampling
- meta1 = tq.get_meta(
+ meta1 = tq_client.get_meta(
data_fields=["input"],
batch_size=1,
partition_id="test",
@@ -315,7 +321,7 @@ def demonstrate_priority_sampler():
# Get another batch
print("\n[Step 3] Get another batch (2 samples)...")
- meta2 = tq.get_meta(
+ meta2 = tq_client.get_meta(
data_fields=["input"],
batch_size=2,
partition_id="test",
@@ -331,7 +337,7 @@ def demonstrate_priority_sampler():
print(f" β Batch 2 high-priority indices: {[i for i in meta2.global_indexes if priority_scores[i] >= 0.1]}")
# Cleanup
- tq.clear_partition(partition_id="test")
+ tq_client.clear_partition(partition_id="test")
tq.close()
ray.shutdown()
@@ -341,7 +347,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 4: Custom Sampler Development
+ TransferQueue Tutorial 5: Custom Sampler Development
This script demonstrates how to develop custom samplers for TransferQueue.
Samplers control HOW data is consumed from the queue.
diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py
similarity index 98%
rename from tutorial/05_streaming_dataloader.py
rename to tutorial/06_streaming_dataloader.py
index 4d92aa2..c60b4e3 100644
--- a/tutorial/05_streaming_dataloader.py
+++ b/tutorial/06_streaming_dataloader.py
@@ -123,6 +123,8 @@ def generate_worker(rank_id: int, num_samples: int = 20):
# Need to call tq.init() in each process
tq.init()
+ tq_client = tq.get_client()
+
# Generate and put samples into the queue
for i in range(num_samples):
# Create unique sequence ID for this sample
@@ -137,7 +139,7 @@ def generate_worker(rank_id: int, num_samples: int = 20):
print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue")
# Put data into the specified partition
- tq.put(data, partition_id="train")
+ tq_client.put(data, partition_id="train")
print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue")
@@ -289,7 +291,7 @@ def main():
print(
textwrap.dedent(
"""
- TransferQueue Tutorial 5: StreamingDataLoader for Distributed Training
+ TransferQueue Tutorial 6: StreamingDataLoader for Distributed Training
This tutorial demonstrates the StreamingDataLoader interface for distributed
training scenarios. It showcases how to use StreamingDataset and StreamingDataLoader