diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml index 1202e21..c24cab8 100644 --- a/.github/workflows/tutorial-check.yml +++ b/.github/workflows/tutorial-check.yml @@ -36,4 +36,5 @@ jobs: - name: Run tutorials run: | export TQ_NUM_THREADS=2 + export RAY_DEDUP_LOGS=0 for file in tutorial/*.py; do python3 "$file"; done \ No newline at end of file diff --git a/README.md b/README.md index 2c7d4f0..8131b94 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo

🔄 Updates

- - **Jan 28, 2026**: We experimentally introduce `StreamingDataloader` interface for fully-streamed production-consumption pipeline. Refer to our [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for details.
+ - **Feb 8, 2026**: 🔥 The initialization and usage are greatly simplified by high-level APIs [PR#26](https://github.com/Ascend/TransferQueue/pull/26), [PR#28](https://github.com/Ascend/TransferQueue/pull/28). You can now use a Redis-style API to take advantage of most of the advanced features provided by TransferQueue!
+ - **Jan 28, 2026**: We experimentally introduce the `StreamingDataLoader` interface for a fully-streamed production-consumption pipeline. Refer to our [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for details.
 - **Dec 30, 2025**: **TransferQueue x verl** integration is tested with the DAPO algorithm at scale **(64 nodes, 1024 cards)**. It significantly optimizes host memory utilization and accelerates data transfers. Stay tuned for more details!
 - **Dec 20, 2025**: 🔥 The official [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) is released! Feel free to check it out.
 - **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data.
@@ -91,26 +92,55 @@ This data structure design is motivated by the computational characteristics of

-### User Interface: Asynchronous & Synchronous Client
-To simplify the usage of TransferQueue, we have encapsulated this process into `TransferQueueClient`. The client provides both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their framework.
+### User Interface: High-Level & Low-Level APIs
-We also experimentally provide a `StreamingDataLoader` interface as a standard PyTorch DataLoader. Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks.
-This interface simplifies TransferQueue's integration, ensuring seamless compatibility with existing training workflows. Please refer to our [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/05_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_streaming_dataloader.py) for more details.
+| Level | Tier | Style | Fine-Grained Access | Streaming | Sampler | Multiple Backends |
+|---|---|---|---|------------------|---|---|
+| High | **KV Interface** (this PR) | Put/Get/List/Clear | ✓ | ○ | ✗ | ✓ |
+| High | **StreamingDataLoader** (#23) | PyTorch DataLoader | ✓ | ✓ | ✓ | ✓ |
+| Low | **TransferQueueClient** | Metadata-based | ✓ | ✓ | ✓ | ✓ |
-

🔥 Showcases

-### General Usage
+#### Key-Value Based API
+
+To simplify the usage of TransferQueue, we provide a Redis-style high-level API that exposes most of TransferQueue's advanced features ([PR#28](https://github.com/Ascend/TransferQueue/pull/28)).
+
+**Methods**
+
+- **(async_)kv_put**: Insert or update a multi-column sample by key, with an optional metadata tag
+- **(async_)kv_batch_put**: Put multiple key-value pairs efficiently in batch
+- **(async_)kv_batch_get**: Retrieve samples by keys, with optional column selection by fields
+- **(async_)kv_list**: List keys and tags (metadata) in a partition
+- **(async_)kv_clear**: Remove key-value pairs from storage
+
+**Key Features**
+
+- **Redis-style Semantics**: Familiar KV interface (Put/Get/List) for zero learning curve
+- **Fine-grained Access**: Update or retrieve specific fields (columns) within a key (row) without a full-row operation
+- **Partition Isolation**: Logical separation of storage namespaces
+- **Metadata Tags**: Lightweight metadata for status tracking
+- **Pluggable Backends**: Supports multiple storage backends
+
+#### StreamingDataLoader API
-The primary interaction points are `AsyncTransferQueueClient` and `TransferQueueClient`, serving as the communication interface with the TransferQueue system.
+Designed as a drop-in replacement for the standard PyTorch `DataLoader`, this API allows each rank to automatically consume data without single-controller intervention.
-Core interfaces:
+In this scenario, `TransferQueueController` serves as a side-controller for data dispatching, with a user-defined `Sampler` class organizing the dataflow.
+It encapsulates the complex scheduling and data transfer logic required for various parallelism strategies, seamlessly integrating TransferQueue into existing training workflows and simplifying the development of disaggregated frameworks.
-- `(async_)get_meta(data_fields: list[str], batch_size:int, partition_id: str, mode: str, task_name:str, sampling_config: Optional[dict[str, Any]]) -> BatchMeta`
-- `(async_)get_data(metadata: BatchMeta) -> TensorDict`
-- `(async_)put(data: TensorDict, metadata: Optional[BatchMeta], partition_id: Optional[str])`
-- `(async_)clear_partition(partition_id: str)` and `(async_)clear_samples(metadata: BatchMeta)`
+See [Roadmap](https://github.com/Ascend/TransferQueue/issues/1) and [tutorials/06_streaming_dataloader.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/06_streaming_dataloader.py) for more details.
-**Refer to our [tutorial](https://github.com/Ascend/TransferQueue/tree/main/tutorial) for detailed examples.**
+#### Low-Level Native API
+
+The native interface of TransferQueue is implemented in `TransferQueueClient`. It offers maximum flexibility through native, atomic operations.
+
+Developers can leverage `TransferQueueClient` directly to implement advanced features that require fine-grained control and fully-streamed data scheduling, as illustrated in the following tutorials:
+- [tutorial/03_metadata_concepts.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/03_metadata_concepts.py)
+- [tutorial/04_understanding_controller.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_understanding_controller.py)
+- [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py)
+
+
+
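As a quick illustration of the KV interface described above, here is a minimal usage sketch based on the calls exercised in `tests/e2e/test_kv_interface_e2e.py`; the initialization sequence (starting Ray, then `tq.init()`/`tq.close()`) follows the test fixtures, and the partition name and tensors are arbitrary examples that may differ in a real deployment.

```python
import ray
import torch
from tensordict import TensorDict

import transfer_queue as tq

ray.init()   # the e2e tests initialize Ray before TransferQueue
tq.init()    # start the TransferQueue controller/storage actors

# Single-sample put: dict fields are auto-unsqueezed to add a batch dimension.
tq.kv_put(
    key="sample_0",
    partition_id="train",
    fields={"input_ids": torch.tensor([1, 2, 3])},
    tag={"status": "running"},
)

# Batch put: a TensorDict with batch_size == len(keys), one tag per key.
tq.kv_batch_put(
    keys=["sample_1", "sample_2"],
    partition_id="train",
    fields=TensorDict({"input_ids": torch.tensor([[4, 5, 6], [7, 8, 9]])}, batch_size=2),
    tags=[{"idx": 1}, {"idx": 2}],
)

# Batch get with optional column (field) selection.
batch = tq.kv_batch_get(keys=["sample_1", "sample_2"], partition_id="train", fields="input_ids")

# List keys and tags per partition, then remove the samples.
print(tq.kv_list(partition_id="train"))
tq.kv_clear(keys=["sample_0", "sample_1", "sample_2"], partition_id="train")

tq.close()
ray.shutdown()
```

Per the method list above, each call also has an `async_`-prefixed variant for use inside coroutines.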

🔥 Showcases

 ### Collocated Example
@@ -131,7 +161,7 @@ You may refer to the [recipe](https://github.com/Ascend/TransferQueue/tree/dev/r
 ### Disaggregated Example
-We have implemented a series of PRs ([#4](https://github.com/Ascend/TransferQueue/pull/4), [#7](https://github.com/Ascend/TransferQueue/pull/7), [#9](https://github.com/Ascend/TransferQueue/pull/9), [#16](https://github.com/Ascend/TransferQueue/pull/16)) to establish a **standardized, fully-streamed distributed** workflow via TransferQueue.
+We have experimentally implemented a **standardized, fully-streamed distributed** workflow via TransferQueue.
 By leveraging the `RankAwareSampler` and `StreamingDataLoader` interfaces, we achieve a **streamlined micro-batch-level producer-consumer pipeline**.
 This design eliminates the need to manually determine data dispatching logic across varying parallelism strategies, a typical complexity in the single-controller paradigm, thereby greatly simplifying framework design.
@@ -186,7 +216,7 @@ pip install TransferQueue

-> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community! +> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorage` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community! For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#). @@ -250,7 +280,7 @@ batch_meta = client.get_meta( ) ``` -**Refer to [tutorial/04_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/04_custom_sampler.py) for more details.** +**Refer to [tutorial/05_custom_sampler.py](https://github.com/Ascend/TransferQueue/blob/main/tutorial/05_custom_sampler.py) for more details.** ### How to integrate a new storage backend @@ -299,21 +329,6 @@ pip install pre-commit pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always ``` -

πŸ›£οΈ Roadmap

- -- [x] Support data rewrite for partial rollout & agentic post-training -- [x] Provide a general storage abstraction layer `TransferQueueStorageManager` to manage distributed storage units, which simplifies `Client` design and makes it possible to introduce different storage backends ([PR#66](https://github.com/TransferQueue/TransferQueue/pull/66), [issue#72](https://github.com/TransferQueue/TransferQueue/issues/72)) -- [x] Implement `AsyncSimpleStorageManager` as the default storage backend based on the `TransferQueueStorageManager` abstraction -- [x] Provide a `KVStorageManager` to cover all the KV based storage backends ([PR#96](https://github.com/TransferQueue/TransferQueue/pull/96)) -- [x] Support topic-based data partitioning to maintain train/val/test data simultaneously ([PR#98](https://github.com/TransferQueue/TransferQueue/pull/98)) -- [x] Release the first stable version through PyPI -- [ ] Support disaggregated framework (each rank retrieves its own data without going through a centralized node) -- [ ] Provide a `StreamingDataLoader` interface for disaggregated framework -- [ ] Support load-balancing and dynamic batching -- [x] Support high-performance storage backends for RDMA transmission (e.g., [Mooncake Store](https://github.com/kvcache-ai/Mooncake), [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html)...) -- [x] High-performance serialization and deserialization -- [ ] More documentation, examples and tutorials -

📑 Citation

Please kindly cite our paper if you find this repo is useful: diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py new file mode 100644 index 0000000..71e9833 --- /dev/null +++ b/tests/e2e/test_kv_interface_e2e.py @@ -0,0 +1,588 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for KV interface in transfer_queue.interface. + +This test module validates the KV interface functionality by: +1. Using external interfaces (kv_put, kv_batch_put, kv_batch_get, kv_list, kv_clear) for read/write +2. Verifying correctness by calling TransferQueueController's internal methods directly +""" + +import os +import sys +from pathlib import Path + +import pytest +import ray +import torch +from tensordict import TensorDict + +# Add parent directory to path +parent_dir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(parent_dir)) + +import transfer_queue as tq # noqa: E402 + +# Configure Ray for tests +os.environ["RAY_DEDUP_LOGS"] = "0" + + +@pytest.fixture(scope="module") +def ray_init(): + """Initialize Ray for the test module.""" + if not ray.is_initialized(): + ray.init(namespace="TestKVInterfaceE2E") + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def tq_system(ray_init): + """Initialize TransferQueue system for the test module.""" + tq.init() + yield + tq.close() + + +@pytest.fixture +def controller(tq_system): + """Get the TransferQueueController actor for direct verification.""" + controller = ray.get_actor("TransferQueueController") + yield controller + + +@pytest.fixture(autouse=True) +def cleanup_partition(controller): + """Cleanup partition after each test.""" + yield + try: + partition_ids = ray.get(controller.list_partitions.remote()) + for partition_id in partition_ids: + ray.get(controller.clear_partition.remote(partition_id)) + except Exception: + pass + + +def get_controller_partition(controller, partition_id: str): + """Get partition snapshot from controller for verification.""" + return ray.get(controller.get_partition_snapshot.remote(partition_id)) + + +def assert_tensor_equal(tensor_a, tensor_b, msg=""): + """Assert two tensors are equal.""" + assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + + +def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""): + """Assert two tensors are close.""" + assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close" + + +class TestKVPutE2E: + """End-to-end tests for kv_put functionality.""" + + def test_kv_put_with_dict_fields(self, controller): + """Test kv_put with dict fields (auto-converted to TensorDict).""" + partition_id = "test_partition" + key = "sample_0" + + # Put with dict fields - will be auto-unsqueezed + tq.kv_put( + key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": 
"dict_test"} + ) + + # Verify - retrieved data will have batch dimension + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed + assert_tensor_equal(retrieved["data"], expected) + + def test_kv_put_with_tensordict_fields(self, controller): + """Test kv_put with tensordict fields.""" + partition_id = "test_partition" + key = "sample_1" + + tensordict_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3, 4]]), + }, + batch_size=1, + ) + # Put with dict fields - will be auto-unsqueezed + tq.kv_put(key=key, partition_id=partition_id, fields=tensordict_data, tag={"type": "tensordict_test"}) + + # Verify - retrieved data will have batch dimension + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed + assert_tensor_equal(retrieved["input_ids"], expected) + + def test_kv_put_single_sample_with_fields_and_tag(self, controller): + """Test putting a single sample with fields and tag.""" + partition_id = "test_partition" + key = "sample_2" + # Use 1D tensors - kv_put with dict will auto-unsqueeze to add batch dimension + input_ids = torch.tensor([1, 2, 3]) + attention_mask = torch.ones(3) + tag = {"global_steps": 0, "status": "running"} + + # Put data using interface + tq.kv_put( + key=key, + partition_id=partition_id, + fields={"input_ids": input_ids, "attention_mask": attention_mask}, + tag=tag, + ) + + # Verify via controller internal state + partition = get_controller_partition(controller, partition_id) + assert partition is not None, "Partition should exist" + + # Check key->global_index mapping + assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping" + global_idx = partition.keys_mapping[key] + assert global_idx in partition.global_indexes, f"Global index {global_idx} should be in global_indexes" + + # Check custom_meta (tag) + assert global_idx in partition.custom_meta, f"Custom meta should exist for global index {global_idx}" + assert partition.custom_meta[global_idx]["global_steps"] == 0 + assert partition.custom_meta[global_idx]["status"] == "running" + + # Check production status - fields should be marked as produced + assert "input_ids" in partition.field_name_mapping, "input_ids field should be registered" + assert "attention_mask" in partition.field_name_mapping, "attention_mask field should be registered" + input_ids_col_idx = partition.field_name_mapping["input_ids"] + assert partition.production_status[global_idx, input_ids_col_idx] == 1, "input_ids should be marked as produced" + + # Retrieve and verify data via kv_batch_get - tensors will have batch dimension + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + assert "input_ids" in retrieved.keys() + assert "attention_mask" in retrieved.keys() + # After unsqueeze, tensors become 2D [batch_size=1, original_size] + expected_input_ids = input_ids.unsqueeze(0) + expected_attention_mask = attention_mask.unsqueeze(0) + assert_tensor_equal(retrieved["input_ids"], expected_input_ids) + assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) + + def test_kv_put_update_tag_only(self, controller): + """Test updating only tag without providing fields.""" + partition_id = "test_partition" + key = "sample_3" + + # First put with fields - use TensorDict as another example + single_data = TensorDict({"value": torch.tensor([[10]])}, batch_size=1) + tq.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1}) + + # Update only 
tag + new_tag = {"version": 2, "status": "updated"} + tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=new_tag) + + # Verify via controller + partition = get_controller_partition(controller, partition_id) + global_idx = partition.keys_mapping[key] + assert partition.custom_meta[global_idx]["version"] == 2 + assert partition.custom_meta[global_idx]["status"] == "updated" + + # Data should still be accessible + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + assert_tensor_equal(retrieved["value"], torch.tensor([[10]])) + + def test_kv_put_partial_update(self, controller): + """Test adding new fields to existing sample.""" + partition_id = "test_partition" + key = "sample_4" + + # First put initial data + initial_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3, 4]]), + }, + batch_size=1, + ) + tq.kv_put(key=key, partition_id=partition_id, fields=initial_data, tag={"v": 1}) + + # Add new fields to subset of keys + new_fields = TensorDict( + { + "response": torch.tensor([[5, 6]]), + }, + batch_size=1, + ) + tq.kv_put(key=key, partition_id=partition_id, fields=new_fields, tag={"v": 2}) + + # Verify via controller - only keys[1] should have response field + partition = get_controller_partition(controller, partition_id) + global_idx = partition.keys_mapping[key] + + # Check that fields were added + assert "response" in partition.field_name_mapping + response_col_idx = partition.field_name_mapping["response"] + + # key should have response marked as produced + assert partition.production_status[global_idx, response_col_idx] == 1, "Key should have response" + + +class TestKVBatchPutE2E: + """End-to-end tests for kv_batch_put functionality.""" + + def test_kv_batch_put_multiple_samples(self, controller): + """Test batch putting multiple samples.""" + partition_id = "test_partition" + keys = ["batch_0", "batch_1", "batch_2", "batch_3"] + batch_input_ids = torch.tensor( + [ + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + [13, 14, 15], + ] + ) + batch_attention_mask = torch.ones_like(batch_input_ids) + + fields = TensorDict( + { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + }, + batch_size=4, + ) + + tags = [{"idx": i, "batch": True} for i in range(4)] + + # Batch put using interface + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + + # Verify via controller + partition = get_controller_partition(controller, partition_id) + assert partition is not None + + # All keys should be registered + for key in keys: + assert key in partition.keys_mapping, f"Key {key} should be in keys_mapping" + + # Verify tags + for i, key in enumerate(keys): + global_idx = partition.keys_mapping[key] + assert partition.custom_meta[global_idx]["idx"] == i + assert partition.custom_meta[global_idx]["batch"] is True + + # Verify all data via kv_batch_get + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + assert_tensor_equal(retrieved["input_ids"], batch_input_ids) + assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask) + + def test_kv_batch_put_partial_update(self, controller): + """Test adding new fields to existing samples.""" + partition_id = "test_partition" + keys = ["partial_0", "partial_1"] + + # First put initial data + initial_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2], [3, 4]]), + }, + batch_size=2, + ) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=initial_data, tags=[{"v": 1}, {"v": 1}]) + + # Add new fields to subset of keys + new_fields = 
TensorDict( + { + "response": torch.tensor([[5, 6]]), # Only for 1 sample + }, + batch_size=1, + ) + tq.kv_batch_put(keys=[keys[1]], partition_id=partition_id, fields=new_fields, tags=[{"v": 2}]) + + # Verify via controller - only keys[1] should have response field + partition = get_controller_partition(controller, partition_id) + global_idx_1 = partition.keys_mapping[keys[1]] + + # Check that fields were added + assert "response" in partition.field_name_mapping + response_col_idx = partition.field_name_mapping["response"] + + # keys[0] should NOT have response marked as produced + global_idx_0 = partition.keys_mapping[keys[0]] + assert partition.production_status[global_idx_0, response_col_idx] == 0, "Keys[0] should not have response" + + # keys[1] should have response marked as produced + assert partition.production_status[global_idx_1, response_col_idx] == 1, "Keys[1] should have response" + + +class TestKVGetE2E: + """End-to-end tests for kv_batch_get functionality.""" + + def test_kv_batch_get_single_key(self, controller): + """Test getting data for a single key.""" + partition_id = "test_partition" + key = "get_single" + # Use TensorDict to avoid auto-unsqueeze issue with dict input + expected_data = torch.tensor([[100, 200, 300]]) + fields = TensorDict({"data": expected_data}, batch_size=1) + + tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) + + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + assert_tensor_equal(retrieved["data"], expected_data) + + def test_kv_batch_get_multiple_keys(self, controller): + """Test getting data for multiple keys.""" + partition_id = "test_partition" + keys = ["get_multi_0", "get_multi_1", "get_multi_2"] + expected_data = torch.tensor([[1, 2], [3, 4], [5, 6]]) + + fields = TensorDict({"data": expected_data}, batch_size=3) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) + + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + assert_tensor_equal(retrieved["data"], expected_data) + + def test_kv_batch_get_partial_keys(self, controller): + """Test getting data for partial keys.""" + partition_id = "test_partition" + keys = ["get_multi_3", "get_multi_4", "get_multi_5"] + partial_keys = ["get_multi_3", "get_multi_5"] + input_data = torch.tensor([[1, 2], [3, 4], [5, 6]]) + expected_data = torch.tensor([[1, 2], [5, 6]]) + + fields = TensorDict({"data": input_data}, batch_size=3) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) + + retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id) + assert_tensor_equal(retrieved["data"], expected_data) + + def test_kv_batch_get_partial_fields(self, controller): + """Test getting only partial fields.""" + partition_id = "test_partition" + key = "get_fields" + # Use TensorDict to avoid auto-unsqueeze issue + input_ids = torch.tensor([[1, 2, 3]]) + attention_mask = torch.ones(1, 3) + response = torch.tensor([[10, 20]]) + + fields = TensorDict( + {"input_ids": input_ids, "attention_mask": attention_mask, "response": response}, batch_size=1 + ) + + # Put all fields + tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) + + # Get only input_ids + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields="input_ids") + assert "input_ids" in retrieved.keys() + assert "attention_mask" not in retrieved.keys() + assert "response" not in retrieved.keys() + assert_tensor_equal(retrieved["input_ids"], input_ids) + + # Get multiple specific fields + 
retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"]) + assert "input_ids" in retrieved.keys() + assert "response" in retrieved.keys() + assert "attention_mask" not in retrieved.keys() + assert_tensor_equal(retrieved["input_ids"], input_ids) + assert_tensor_equal(retrieved["response"], response) + + def test_kv_batch_get_nonexistent_key(self, controller): + """Test that getting data for non-existent key returns empty result.""" + partition_id = "test_partition" + + # Try to get data for a key that doesn't exist - should return empty or raise error + try: + retrieved = tq.kv_batch_get(keys="nonexistent_key", partition_id=partition_id) + # If it returns, it should be empty + assert retrieved.batch_size[0] == 0 + except RuntimeError as e: + # Or it might raise an error about keys not found + assert "not found" in str(e).lower() or "empty" in str(e).lower() + + +class TestKVListE2E: + """End-to-end tests for kv_list functionality.""" + + def test_kv_list_single_partition(self, controller): + """Test listing all keys and tags in single partition.""" + partition_id = "test_partition" + keys = ["list_0", "list_1", "list_2"] + + for i, key in enumerate(keys): + tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i}) + + # List all keys + partition_info = tq.kv_list(partition_id=partition_id) + + assert len(partition_info.keys()) == 1 + assert "test_partition" in partition_info.keys() + assert len(partition_info["test_partition"]) == 3 + for key in keys: + assert key in partition_info["test_partition"] + + # Verify tags match + for i, (key, tag) in enumerate(partition_info["test_partition"].items()): + assert tag["id"] == i + + def test_kv_list_all_partitions(self, controller): + """Test listing keys and tags in all partitions.""" + partition_id = ["test_partition0", "test_partition1", "test_partition2"] + + keys_partition0 = ["list_0", "list_1", "list_2"] + keys_partition1 = ["list_0", "list_1", "list_2"] # deliberately set same keys + keys_partition2 = ["list_3", "list_4", "list_5", "list_6"] + + fields_partition0 = TensorDict({"data": torch.tensor([[0], [1], [2]])}, batch_size=3) + fields_partition1 = TensorDict({"data": torch.tensor([[3], [4], [5]])}, batch_size=3) + fields_partition2 = TensorDict({"data": torch.tensor([[6], [7], [8], [9]])}, batch_size=4) + + tags_partition0 = [{"id": i} for i in range(3)] + tags_partition1 = [{"id": i + 3} for i in range(3)] + tags_partition2 = [{"id": i + 6} for i in range(4)] + + # Put to TQ + tq.kv_batch_put( + keys=keys_partition0, partition_id=partition_id[0], fields=fields_partition0, tags=tags_partition0 + ) + tq.kv_batch_put( + keys=keys_partition1, partition_id=partition_id[1], fields=fields_partition1, tags=tags_partition1 + ) + tq.kv_batch_put( + keys=keys_partition2, partition_id=partition_id[2], fields=fields_partition2, tags=tags_partition2 + ) + + # List all keys + partition_info = tq.kv_list() + + # Verify all partitions are exist + assert len(partition_info.keys()) == 3 + assert "test_partition0" in partition_info.keys() + assert "test_partition1" in partition_info.keys() + assert "test_partition2" in partition_info.keys() + + assert len(partition_info["test_partition0"]) == 3 + for key in keys_partition0: + assert key in partition_info["test_partition0"] + + assert len(partition_info["test_partition1"]) == 3 + for key in keys_partition1: + assert key in partition_info["test_partition1"] + + assert len(partition_info["test_partition2"]) == 4 + for key in 
keys_partition2: + assert key in partition_info["test_partition2"] + + # Verify tags match + for i, (key, tag) in enumerate(partition_info["test_partition0"].items()): + assert tag["id"] == i + for i, (key, tag) in enumerate(partition_info["test_partition1"].items()): + assert tag["id"] == i + 3 + for i, (key, tag) in enumerate(partition_info["test_partition2"].items()): + assert tag["id"] == i + 6 + + def test_kv_list_empty_partition(self): + """Test listing empty partition.""" + partition_id = "test_partition_empty" + + partition_info = tq.kv_list(partition_id=partition_id) + + assert len(partition_info) == 0 + + +class TestKVClearE2E: + """End-to-end tests for kv_clear functionality.""" + + def test_kv_clear_single_key(self, controller): + """Test clearing a single key.""" + partition_id = "test_partition" + key = "clear_single" + other_key = "clear_other" + + tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[1]])}, tag={"id": "single"}) + tq.kv_put(key=other_key, partition_id=partition_id, fields={"data": torch.tensor([[2]])}, tag={"id": "other"}) + + # Clear single key + tq.kv_clear(keys=key, partition_id=partition_id) + + # Verify via kv_list + partition_info = tq.kv_list(partition_id=partition_id) + assert key not in partition_info[partition_id] + assert other_key in partition_info[partition_id] + + # Verify via controller - key should be removed + partition = get_controller_partition(controller, partition_id) + assert key not in partition.keys_mapping + assert other_key in partition.keys_mapping + + def test_kv_clear_multiple_keys(self, controller): + """Test clearing multiple keys.""" + partition_id = "test_partition" + keys = ["clear_multi_0", "clear_multi_1", "clear_multi_2", "clear_multi_3"] + + for i, key in enumerate(keys): + tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag=None) + + # Clear first 2 keys + tq.kv_clear(keys=keys[:2], partition_id=partition_id) + + # Verify + partition_info = tq.kv_list(partition_id=partition_id) + assert len(partition_info[partition_id]) == 2 + assert keys[0] not in partition_info[partition_id] + assert keys[1] not in partition_info[partition_id] + assert keys[2] in partition_info[partition_id] + assert keys[3] in partition_info[partition_id] + + +class TestKVE2ECornerCases: + """End-to-end tests for corner cases.""" + + def test_field_expansion_across_samples(self, controller): + """Test that new fields can be added across samples.""" + partition_id = "test_partition" + keys = ["expand_0", "expand_1"] + + # Put initial fields + tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_a": torch.tensor([[1]])}, tag=None) + + # Add new field to first key + tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None) + + # Add different field to second key + tq.kv_put( + key=keys[1], + partition_id=partition_id, + fields={"field_a": torch.tensor([[3]]), "field_c": torch.tensor([[4]])}, + tag=None, + ) + + # Verify field expansion in controller + partition = get_controller_partition(controller, partition_id) + + # All fields should be registered, but only samples with the actual fields are labeled as READY_FOR_CONSUME + assert "field_a" in partition.field_name_mapping + assert "field_b" in partition.field_name_mapping + assert "field_c" in partition.field_name_mapping + + # We can only fetch "field_a" because not all requested keys has other fields + data = tq.kv_batch_get(keys=keys, partition_id=partition_id) + assert "field_a" in 
data + assert "field_b" not in data + assert "field_c" not in data + + +def run_tests(): + """Run all e2e tests manually for debugging.""" + pytest.main([__file__, "-v", "-s"]) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_client.py b/tests/test_client.py index 42cd63b..5d308d8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -34,7 +34,7 @@ FieldMeta, SampleMeta, ) -from transfer_queue.utils.enum_utils import TransferQueueRole # noqa: E402 +from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole # noqa: E402 from transfer_queue.utils.zmq_utils import ( # noqa: E402 ZMQMessage, ZMQRequestType, @@ -144,6 +144,12 @@ def _handle_requests(self): "message": "Consumption reset successfully", } response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE + elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: + response_body = self._mock_kv_retrieve_keys(request_msg.body) + response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE + elif request_msg.request_type == ZMQRequestType.KV_LIST: + response_body = self._mock_kv_list(request_msg.body) + response_type = ZMQRequestType.KV_LIST_RESPONSE else: response_body = {"error": f"Unknown request type: {request_msg.request_type}"} response_type = ZMQRequestType.CLEAR_META_RESPONSE @@ -174,7 +180,7 @@ def _mock_batch_meta(self, request_body): name=field_name, dtype=None, shape=None, - production_status=0, + production_status=ProductionStatus.NOT_PRODUCED, ) fields.append(field_meta) sample = SampleMeta( @@ -187,6 +193,81 @@ def _mock_batch_meta(self, request_body): return {"metadata": metadata} + def _mock_kv_retrieve_keys(self, request_body): + """Mock KV retrieve keys response.""" + keys = request_body.get("keys", []) + create = request_body.get("create", False) + partition_id = request_body.get("partition_id", "") + + # Initialize key tracking if not exists + if not hasattr(self, "_kv_partition_keys"): + self._kv_partition_keys = {} + + # Generate global indexes for the keys + start_index = self._get_next_kv_index(partition_id) + global_indexes = list(range(start_index, start_index + len(keys))) + + # Create metadata for each key + samples = [] + for i, key in enumerate(keys): + field_meta = FieldMeta( + name="data", + dtype=torch.float32, + shape=torch.Size([1, 10]), + production_status=ProductionStatus.READY_FOR_CONSUME, + ) + sample = SampleMeta( + partition_id=partition_id, + global_index=global_indexes[i], + fields={"data": field_meta}, + ) + samples.append(sample) + + metadata = BatchMeta(samples=samples) + + # Store keys for this partition (only when create=True) + if create: + if partition_id not in self._kv_partition_keys: + self._kv_partition_keys[partition_id] = [] + self._kv_partition_keys[partition_id].extend(keys) + + # Update the next index for this partition + if global_indexes: + self._update_kv_index(partition_id, global_indexes[-1] + 1) + + return {"metadata": metadata} + + def _mock_kv_list(self, request_body): + """Mock KV list response.""" + partition_id = request_body.get("partition_id", None) + + # Initialize key tracking if not exists + if not hasattr(self, "_kv_partition_keys"): + self._kv_partition_keys = {} + + # Return cached keys for this partition + keys = self._kv_partition_keys.get(partition_id, []) + + return {"partition_info": {partition_id: {k: {} for k in keys}}, "message": "success"} + + def _get_next_kv_index(self, partition_id): + """Get next available index for KV keys in partition.""" + if not hasattr(self, "_kv_index_map"): + 
self._kv_index_map = {} + if partition_id not in self._kv_index_map: + self._kv_index_map[partition_id] = 0 + # Also initialize key tracking + if not hasattr(self, "_kv_partition_keys"): + self._kv_partition_keys = {} + self._kv_partition_keys[partition_id] = [] + return self._kv_index_map[partition_id] + + def _update_kv_index(self, partition_id, next_index): + """Update next available index for KV keys.""" + if not hasattr(self, "_kv_index_map"): + self._kv_index_map = {} + self._kv_index_map[partition_id] = next_index + def stop(self): self.running = False time.sleep(0.2) # Give thread time to stop @@ -850,10 +931,10 @@ def test_set_custom_meta_sync(self, client_setup): metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0") # Set custom_meta on the metadata metadata.update_custom_meta( - { - 0: {"input_ids": {"token_count": 100}}, - 1: {"input_ids": {"token_count": 120}}, - } + [ + {"input_ids": {"token_count": 100}}, + {"input_ids": {"token_count": 120}}, + ] ) # Call set_custom_meta with metadata (BatchMeta) @@ -869,12 +950,165 @@ async def test_set_custom_meta_async(self, client_setup): metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0") # Set custom_meta on the metadata metadata.update_custom_meta( - { - 0: {"input_ids": {"token_count": 100}}, - 1: {"input_ids": {"token_count": 120}}, - } + [ + {"input_ids": {"token_count": 100}}, + {"input_ids": {"token_count": 120}}, + ] ) # Call async_set_custom_meta with metadata (BatchMeta) await client.async_set_custom_meta(metadata) print("βœ“ async_set_custom_meta async method works") + + +# ===================================================== +# KV Interface Tests +# ===================================================== + + +class TestClientKVInterface: + """Tests for client KV interface methods.""" + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_single(self, client_setup): + """Test async_kv_retrieve_keys with single key.""" + client, _, _ = client_setup + + # Test async_kv_retrieve_keys with single key + metadata = await client.async_kv_retrieve_keys( + keys="test_key_1", + partition_id="test_partition", + create=True, + ) + + # Verify metadata structure + assert metadata is not None + assert hasattr(metadata, "global_indexes") + assert hasattr(metadata, "size") + assert metadata.size == 1 + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_multiple(self, client_setup): + """Test async_kv_retrieve_keys with multiple keys.""" + client, _, _ = client_setup + + # Test async_kv_retrieve_keys with multiple keys + keys = ["key_a", "key_b", "key_c"] + metadata = await client.async_kv_retrieve_keys( + keys=keys, + partition_id="test_partition", + create=True, + ) + + # Verify metadata structure + assert metadata is not None + assert hasattr(metadata, "global_indexes") + assert hasattr(metadata, "size") + assert metadata.size == 3 + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_create_false(self, client_setup): + """Test async_kv_retrieve_keys with create=False (retrieve existing keys).""" + client, _, _ = client_setup + + # create some keys + await client.async_kv_retrieve_keys( + keys="existing_key", + partition_id="existing_partition", + create=True, + ) + + # Then retrieve them with create=False + metadata = await client.async_kv_retrieve_keys( + keys="existing_key", + partition_id="existing_partition", + create=False, + ) + + # Verify metadata structure + assert metadata is not None + assert metadata.size == 1 + + 
@pytest.mark.asyncio + async def test_async_kv_retrieve_keys_invalid_keys_type(self, client_setup): + """Test async_kv_retrieve_keys raises error with invalid keys type.""" + client, _, _ = client_setup + + # Test with invalid keys type (not string or list) + with pytest.raises(TypeError): + await client.async_kv_retrieve_keys( + keys=123, # Invalid type + partition_id="test_partition", + create=True, + ) + + @pytest.mark.asyncio + async def test_async_kv_list_with_keys(self, client_setup): + """Test async_kv_list returns keys after they are registered.""" + client, mock_controller, _ = client_setup + + # First register some keys + await client.async_kv_retrieve_keys( + keys=["key_1", "key_2"], + partition_id="kv_partition", + create=True, + ) + + # Then list them + partition_info = await client.async_kv_list(partition_id="kv_partition") + + # Verify keys are returned + assert len(partition_info["kv_partition"]) >= 2 + assert "key_1" in partition_info["kv_partition"] + assert "key_2" in partition_info["kv_partition"] + assert list(partition_info["kv_partition"].values()) == [{}, {}] + + @pytest.mark.asyncio + async def test_async_kv_list_multiple_partitions(self, client_setup): + """Test async_kv_list with multiple partitions.""" + client, _, _ = client_setup + + # Create keys in different partitions + await client.async_kv_retrieve_keys( + keys="partition_a_key", + partition_id="partition_a", + create=True, + ) + await client.async_kv_retrieve_keys( + keys="partition_b_key", + partition_id="partition_b", + create=True, + ) + + # List keys for each partition + partition_a = await client.async_kv_list(partition_id="partition_a") + partition_b = await client.async_kv_list(partition_id="partition_b") + + # Verify keys are isolated per partition + assert "partition_a" in partition_a + assert "partition_b" in partition_b + assert "partition_a" not in partition_b + assert "partition_b" not in partition_a + assert "partition_a_key" in partition_a["partition_a"] + assert "partition_b_key" not in partition_a["partition_a"] + assert "partition_b_key" in partition_b["partition_b"] + assert "partition_a_key" not in partition_b["partition_b"] + assert list(partition_a["partition_a"].values()) == [{}] + assert list(partition_b["partition_b"].values()) == [{}] + + def test_kv_retrieve_keys_type_validation(self, client_setup): + """Test synchronous kv_retrieve_keys type validation.""" + import asyncio + + client, _, _ = client_setup + + # Test with non-string element in list + async def test_invalid_list(): + with pytest.raises(TypeError): + await client.async_kv_retrieve_keys( + keys=["valid_key", 123], # Invalid: 123 is not a string + partition_id="test_partition", + create=True, + ) + + asyncio.run(test_invalid_list()) diff --git a/tests/test_controller.py b/tests/test_controller.py index 09f5778..3528d1a 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -28,7 +28,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -from transfer_queue import TransferQueueController # noqa: E402 +from transfer_queue.controller import TransferQueueController # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @@ -760,3 +760,205 @@ def test_controller_with_custom_meta(self, ray_setup): # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) + + +class TestTransferQueueControllerKvInterface: + """End-to-end tests for TransferQueueController KV interface functionality. 
+ + Tests for kv_retrieve_keys method that supports key-value interface operations + across the controller and partition layers. + """ + + def test_controller_kv_retrieve_keys_create_mode(self, ray_setup): + """Test kv_retrieve_keys with create=True creates new keys in partition.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_test_partition" + + # Retrieve keys with create=True - should create partition and keys + keys = ["key_a", "key_b", "key_c"] + metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + + # Verify partition was created + partitions = ray.get(tq_controller.list_partitions.remote()) + assert partition_id in partitions + + # Verify metadata contains correct number of global_indexes + assert len(metadata.global_indexes) == len(keys) + + # Verify partition has keys_mapping + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert "key_a" in partition.keys_mapping + assert "key_b" in partition.keys_mapping + assert "key_c" in partition.keys_mapping + assert metadata.global_indexes[0] == partition.keys_mapping["key_a"] + assert metadata.global_indexes[1] == partition.keys_mapping["key_b"] + assert metadata.global_indexes[2] == partition.keys_mapping["key_c"] + assert partition.revert_keys_mapping[metadata.global_indexes[0]] == "key_a" + assert partition.revert_keys_mapping[metadata.global_indexes[1]] == "key_b" + assert partition.revert_keys_mapping[metadata.global_indexes[2]] == "key_c" + + print("βœ“ kv_retrieve_keys with create=True creates keys correctly") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_existing_keys(self, ray_setup): + """Test kv_retrieve_keys retrieves existing keys correctly.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_existing_test" + + # First, create some keys + keys = ["existing_key_1", "existing_key_2"] + ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + + # Retrieve the same keys again (should return existing) + retrieved_metadata = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + ) + + # Verify the same global_indexes are returned + assert len(retrieved_metadata.global_indexes) == len(keys) + + print("βœ“ kv_retrieve_keys retrieves existing keys correctly") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_non_existent_without_create(self, ray_setup): + """Test kv_retrieve_keys raises error for non-existent keys without create.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_nonexistent_test" + + # Create partition first + ray.get(tq_controller.kv_retrieve_keys.remote(keys=["initial_key"], partition_id=partition_id, create=True)) + + # Try to retrieve non-existent key without create + batch_meta = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=["nonexistent_key"], partition_id=partition_id, create=False) + ) + assert batch_meta.size == 0 + + print("βœ“ kv_retrieve_keys return an empty BatchMeta for non-existent keys without create") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_empty_partition_without_create(self, ray_setup): + """Test kv_retrieve_keys raises error for non-existent partition without create.""" + tq_controller = TransferQueueController.remote() + 
partition_id = "nonexistent_partition" + + batch_meta = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=["key_1"], partition_id=partition_id, create=False) + ) + assert batch_meta.size == 0 + + print("βœ“ kv_retrieve_keys return an empty BatchMeta for non-existent partition_id without create") + + def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup): + """Test kv_retrieve_keys works with production status update.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_production_test" + + # Create keys + keys = ["sample_1", "sample_2", "sample_3"] + metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + global_indexes = metadata.global_indexes + + # Update production status + dtypes = {idx: {"data": "torch.float32"} for idx in global_indexes} + shapes = {idx: {"data": (64,)} for idx in global_indexes} + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=global_indexes, + field_names=["data"], + dtypes=dtypes, + shapes=shapes, + ) + ) + assert success + + # Retrieve keys again (should include production info) + retrieved_metadata = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + ) + + # Verify production status is available + assert len(retrieved_metadata.samples) == len(keys) + for sample in retrieved_metadata.samples: + assert "data" in sample.fields + assert sample.fields["data"].dtype == "torch.float32" + assert sample.fields["data"].shape == (64,) + + print("βœ“ kv_retrieve_keys works with production status") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup): + """Test kv_retrieve_keys preserves custom_meta through retrieve.""" + tq_controller = TransferQueueController.remote() + partition_id = "kv_custom_meta_test" + + # Create keys + keys = ["key_1", "key_2"] + metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + + # Set custom_meta + custom_meta = { + partition_id: { + metadata.global_indexes[0]: {"score": 0.9, "tag": "A"}, + metadata.global_indexes[1]: {"score": 0.8, "tag": "B"}, + } + } + ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta)) + + # Retrieve keys and verify custom_meta + retrieved_metadata = ray.get( + tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + ) + + # Verify custom_meta is preserved + all_custom_meta = retrieved_metadata.get_all_custom_meta() + assert len(all_custom_meta) == 2 + assert all_custom_meta[0]["score"] == 0.9 + assert all_custom_meta[1]["tag"] == "B" + + print("βœ“ kv_retrieve_keys preserves custom_meta") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_interface_multiple_partitions(self, ray_setup): + """Test KV interface works correctly across multiple partitions.""" + tq_controller = TransferQueueController.remote() + + # Create keys in partition 1 + partition_1 = "partition_kv_1" + keys_1 = ["p1_key_a", "p1_key_b"] + ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_1, partition_id=partition_1, create=True)) + + # Create keys in partition 2 + partition_2 = "partition_kv_2" + keys_2 = ["p2_key_x", "p2_key_y", "p2_key_z"] + ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_2, partition_id=partition_2, create=True)) + + # Verify 
partitions are isolated + partition_1_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_1)) + partition_2_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_2)) + + assert "p1_key_a" in partition_1_snapshot.keys_mapping + assert "p1_key_b" in partition_1_snapshot.keys_mapping + assert "p2_key_x" in partition_2_snapshot.keys_mapping + assert "p2_key_z" in partition_2_snapshot.keys_mapping + + # Verify cross-partition access is isolated + assert "p2_key_x" not in partition_1_snapshot.keys_mapping + assert "p1_key_a" not in partition_2_snapshot.keys_mapping + + print("βœ“ KV interface maintains partition isolation") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_1)) + ray.get(tq_controller.clear_partition.remote(partition_2)) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 31478ba..6eba7a0 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -941,3 +941,62 @@ def test_custom_meta_cleared_with_data(self): result = partition.get_custom_meta([0, 1]) assert 0 not in result assert 1 in result # Sample 1 should still have custom_meta + + +class TestDataPartitionStatusKvInterface: + """Unit tests for DataPartitionStatus KV interface functionality. + + Tests for the keys_mapping and kv_retrieve_keys methods that support + key-value interface operations within a partition. + """ + + def test_kv_retrieve_keys_with_existing_keys(self): + """Test kv_retrieve_keys returns correct global_indexes for existing keys.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + # Simulate keys being registered (as would happen during kv_put) + partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2} + + # Retrieve keys + global_indexes = partition.kv_retrieve_keys(["key_a", "key_b", "key_c"]) + + assert global_indexes == [0, 1, 2] + + def test_kv_retrieve_keys_with_nonexistent_keys(self): + """Test kv_retrieve_keys returns None for keys that don't exist.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + # Simulate some keys being registered + partition.keys_mapping = {"existing_key": 5} + + # Retrieve mixed existing and non-existing keys + global_indexes = partition.kv_retrieve_keys(["existing_key", "nonexistent_key"]) + + assert global_indexes == [5, None] + + def test_kv_retrieve_keys_empty_list(self): + """Test kv_retrieve_keys handles empty key list.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + global_indexes = partition.kv_retrieve_keys([]) + + assert global_indexes == [] + + def test_kv_retrieve_keys_partial_match(self): + """Test kv_retrieve_keys with partial key matches.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30} + + # Request only some of the keys + global_indexes = partition.kv_retrieve_keys(["key_1", "key_3"]) + + assert global_indexes == [10, 30] diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 083e16d..ec5b29c 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -27,6 +27,7 @@ from transfer_queue.metadata import BatchMeta, FieldMeta, 
SampleMeta # noqa: E402 from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402 +from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 def get_meta(data, global_indexes=None): @@ -41,7 +42,7 @@ def get_meta(data, global_indexes=None): name=field_name, dtype=tensor.dtype if isinstance(tensor, torch.Tensor) else None, shape=tensor.shape if isinstance(tensor, torch.Tensor) else None, - production_status=1, + production_status=ProductionStatus.READY_FOR_CONSUME, ) fields_dict[field_name] = field_meta sample = SampleMeta( @@ -163,8 +164,8 @@ def test_merge_tensors_to_tensordict(mock_create, test_data): assert complex_tensordict[key] == complex_data[key] -def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data): - """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_meta.""" +def test_get_shape_type_custom_backend_meta_list_without_custom_backend_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_backend_meta.""" shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list( test_data["metadata"] ) @@ -185,7 +186,7 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data): torch.Size([2]), # text[2] ] expected_dtypes = [torch.int64] * (len(test_data["field_names"]) * len(test_data["global_indexes"])) - # No custom_meta provided, so all should be None + # No custom_backend_meta provided, so all should be None expected_custom_backend_meta = [None] * (len(test_data["field_names"]) * len(test_data["global_indexes"])) assert shapes == expected_shapes @@ -193,9 +194,9 @@ def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data): assert custom_backend_meta_list == expected_custom_backend_meta -def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): - """Test _get_shape_type_custom_meta_list returns correct custom_meta when provided.""" - # Add custom_meta to metadata +def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list returns correct custom_backend_meta when provided.""" + # Add custom_backend_meta to metadata custom_backend_meta = { 8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, 9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, @@ -206,8 +207,8 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) - # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index - expected_custom_meta = [ + # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index + expected_custom_backend_meta = [ {"key2": "value2"}, # label, global_index=8 {"key5": "value5"}, # label, global_index=9 {"key8": "value8"}, # label, global_index=10 @@ -218,15 +219,15 @@ def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): {"key4": "value4"}, # text, global_index=9 {"key7": "value7"}, # text, global_index=10 ] - assert custom_backend_meta_list == expected_custom_meta + assert custom_backend_meta_list == expected_custom_backend_meta -def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_data): - """Test _get_shape_type_custom_backend_meta_list 
handles partial custom_meta correctly.""" - # Add custom_meta only for some global_indexes and fields +def test_get_shape_type_custom_backend_meta_list_with_partial_custom_backend_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list handles partial custom_backend_meta correctly.""" + # Add custom_backend_meta only for some global_indexes and fields custom_backend_meta = { 8: {"text": {"key1": "value1"}}, # Only text field - # global_index 9 has no custom_meta + # global_index 9 has no custom_backend_meta 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only } metadata = test_data["metadata"] @@ -234,17 +235,17 @@ def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_d shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) - # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index + # Check custom_backend_meta - order is label, mask, text (sorted alphabetically) by global_index expected_custom_backend_meta = [ - None, # label, global_index=8 (not in custom_meta) - None, # label, global_index=9 (not in custom_meta) + None, # label, global_index=8 (not in custom_backend_meta) + None, # label, global_index=9 (not in custom_backend_meta) {"key2": "value2"}, # label, global_index=10 - None, # mask, global_index=8 (not in custom_meta) - None, # mask, global_index=9 (not in custom_meta) + None, # mask, global_index=8 (not in custom_backend_meta) + None, # mask, global_index=9 (not in custom_backend_meta) {"key3": "value3"}, # mask, global_index=10 {"key1": "value1"}, # text, global_index=8 - None, # text, global_index=9 (not in custom_meta) - None, # text, global_index=10 (not in custom_meta for text) + None, # text, global_index=9 (not in custom_backend_meta) + None, # text, global_index=10 (not in custom_backend_meta for text) ] assert custom_backend_meta_list == expected_custom_backend_meta @@ -279,13 +280,13 @@ def test_data_for_put_data(): @patch.object(KVStorageManager, "_connect_to_controller", lambda self: None) @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) -def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_for_put_data): - """Test that put_data correctly processes custom_meta returned by storage client.""" +def test_put_data_with_custom_backend_meta_from_storage_client(mock_notify, test_data_for_put_data): + """Test that put_data correctly processes custom_backend_meta returned by storage client.""" # Create a mock storage client mock_storage_client = MagicMock() - # Simulate storage client returning custom_meta (one per key) + # Simulate storage client returning custom_backend_meta (one per key) # Keys order: label[0,1,2], text[0,1,2] (sorted by field name) - mock_custom_meta = [ + mock_custom_backend_meta = [ {"storage_key": "0@label"}, {"storage_key": "1@label"}, {"storage_key": "2@label"}, @@ -293,7 +294,7 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo {"storage_key": "1@text"}, {"storage_key": "2@text"}, ] - mock_storage_client.put.return_value = mock_custom_meta + mock_storage_client.put.return_value = mock_custom_backend_meta # Create manager with mocked dependencies config = {"client_name": "MockClient"} @@ -314,34 +315,35 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo assert keys == expected_keys assert len(values) == 6 - # Verify notify_data_update was called with correct custom_meta 
structure + # Verify notify_data_update was called with correct custom_backend_meta structure mock_notify.assert_called_once() notify_call_args = mock_notify.call_args - per_field_custom_meta = notify_call_args[0][5] # 6th positional argument + per_field_custom_backend_meta = notify_call_args[0][5] # 6th positional argument - # Verify custom_meta is structured correctly: {global_index: {field: meta}} - assert 0 in per_field_custom_meta - assert 1 in per_field_custom_meta - assert 2 in per_field_custom_meta + # Verify custom_backend_meta is structured correctly: {global_index: {field: meta}} + assert 0 in per_field_custom_backend_meta + assert 1 in per_field_custom_backend_meta + assert 2 in per_field_custom_backend_meta - assert per_field_custom_meta[0]["label"] == {"storage_key": "0@label"} - assert per_field_custom_meta[0]["text"] == {"storage_key": "0@text"} - assert per_field_custom_meta[1]["label"] == {"storage_key": "1@label"} - assert per_field_custom_meta[1]["text"] == {"storage_key": "1@text"} - assert per_field_custom_meta[2]["label"] == {"storage_key": "2@label"} - assert per_field_custom_meta[2]["text"] == {"storage_key": "2@text"} + assert per_field_custom_backend_meta[0]["label"] == {"storage_key": "0@label"} + assert per_field_custom_backend_meta[0]["text"] == {"storage_key": "0@text"} + assert per_field_custom_backend_meta[1]["label"] == {"storage_key": "1@label"} + assert per_field_custom_backend_meta[1]["text"] == {"storage_key": "1@text"} + assert per_field_custom_backend_meta[2]["label"] == {"storage_key": "2@label"} + assert per_field_custom_backend_meta[2]["text"] == {"storage_key": "2@text"} - # Verify metadata was updated with custom_meta - all_custom_meta = test_data_for_put_data["metadata"].get_all_custom_meta() - assert all_custom_meta[0]["label"] == {"storage_key": "0@label"} - assert all_custom_meta[2]["text"] == {"storage_key": "2@text"} + # Verify metadata was updated with custom_backend_meta + all_custom_backend_meta = test_data_for_put_data["metadata"]._custom_backend_meta + assert len(all_custom_backend_meta) == 3 + assert all_custom_backend_meta[0]["label"] == {"storage_key": "0@label"} + assert all_custom_backend_meta[2]["text"] == {"storage_key": "2@text"} @patch.object(KVStorageManager, "_connect_to_controller", lambda self: None) @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) -def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data): - """Test that put_data works correctly when storage client returns no custom_meta.""" - # Create a mock storage client that returns None for custom_meta +def test_put_data_without_custom_backend_meta(mock_notify, test_data_for_put_data): + """Test that put_data works correctly when storage client returns no custom_backend_meta.""" + # Create a mock storage client that returns None for custom_backend_meta mock_storage_client = MagicMock() mock_storage_client.put.return_value = None @@ -353,19 +355,19 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data): # Run put_data asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"])) - # Verify notify_data_update was called with empty dict for custom_meta + # Verify notify_data_update was called with empty dict for custom_backend_meta mock_notify.assert_called_once() notify_call_args = mock_notify.call_args - per_field_custom_meta = notify_call_args[0][5] # 6th positional argument - assert per_field_custom_meta == {} + per_field_custom_backend_meta = 
notify_call_args[0][5] # 6th positional argument + assert per_field_custom_backend_meta == {} @patch.object(KVStorageManager, "_connect_to_controller", lambda self: None) -def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_data): - """Test that put_data raises ValueError when custom_meta length doesn't match keys.""" - # Create a mock storage client that returns mismatched custom_meta length +def test_put_data_custom_backend_meta_length_mismatch_raises_error(test_data_for_put_data): + """Test that put_data raises ValueError when custom_backend_meta length doesn't match keys.""" + # Create a mock storage client that returns mismatched custom_backend_meta length mock_storage_client = MagicMock() - # Return only 3 custom_meta entries when 6 are expected + # Return only 3 custom_backend_meta entries when 6 are expected mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}] # Create manager with mocked dependencies diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 2bbf40c..2a129b5 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -27,7 +27,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.metadata import BatchMeta, FieldMeta, KVBatchMeta, SampleMeta # noqa: E402 from transfer_queue.utils.enum_utils import ProductionStatus # noqa: E402 @@ -788,64 +788,6 @@ def test_batch_meta_select_samples_with_extra_info(self): # ===================================================== # Custom Meta Tests # ===================================================== - - def test_batch_meta_set_custom_meta_basic(self): - """Test set_custom_meta sets metadata for a sample by global_index.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set custom_meta for sample 0 - batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"}) - - result = batch.get_all_custom_meta() - assert 0 in result - assert result[0]["sample_score"] == 0.9 - assert result[0]["quality"] == "high" - # Sample 1 should not have custom_meta - assert 1 not in result - - def test_batch_meta_set_custom_meta_overwrites(self): - """Test set_custom_meta overwrites existing metadata.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set initial custom_meta - batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"}) - - # Overwrite with new custom_meta - batch.set_custom_meta(0, {"sample_score": 0.1, "quality": "low"}) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.1 - assert result[0]["quality"] == "low" - - def test_batch_meta_set_custom_meta_invalid_global_index(self): - """Test set_custom_meta raises error for invalid global_index.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Try to set with non-existent global 
index - with pytest.raises(ValueError) as exc_info: - batch.set_custom_meta(999, {"sample_score": 0.9}) - assert "not found in global_indexes" in str(exc_info.value) - def test_batch_meta_update_custom_meta(self): """Test update_custom_meta adds metadata for different global indices.""" fields = { @@ -859,10 +801,7 @@ def test_batch_meta_update_custom_meta(self): batch = BatchMeta(samples=samples) # Initial custom_meta for sample 0 - batch.update_custom_meta({0: {"sample_score": 0.9}}) - - # Update with metadata for sample 1 - batch.update_custom_meta({1: {"sample_score": 0.1}}) + batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) result = batch.get_all_custom_meta() assert result[0]["sample_score"] == 0.9 @@ -879,10 +818,10 @@ def test_batch_meta_update_custom_meta_overwrites(self): batch = BatchMeta(samples=samples) # Initial custom_meta - batch.update_custom_meta({0: {"sample_score": 0.9, "quality": "high"}}) + batch.update_custom_meta([{"sample_score": 0.9, "quality": "high"}]) # Update with new value for same field - dict.update replaces - batch.update_custom_meta({0: {"sample_score": 0.1, "quality": "low"}}) + batch.update_custom_meta([{"sample_score": 0.1, "quality": "low"}]) result = batch.get_all_custom_meta() assert result[0]["sample_score"] == 0.1 @@ -899,7 +838,7 @@ def test_batch_meta_update_custom_meta_with_none(self): batch = BatchMeta(samples=samples) # Set initial value - batch.update_custom_meta({0: {"sample_score": 0.9}}) + batch.update_custom_meta([{"sample_score": 0.9}]) # Update with None should not change anything batch.update_custom_meta(None) @@ -907,40 +846,6 @@ def test_batch_meta_update_custom_meta_with_none(self): result = batch.get_all_custom_meta() assert result[0]["sample_score"] == 0.9 - def test_batch_meta_update_custom_meta_with_empty_dict(self): - """Test update_custom_meta with empty dict does nothing.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Set initial value - batch.update_custom_meta({0: {"sample_score": 0.9}}) - - # Update with empty dict should not change anything - batch.update_custom_meta({}) - - result = batch.get_all_custom_meta() - assert result[0]["sample_score"] == 0.9 - - def test_batch_meta_update_custom_meta_invalid_global_index(self): - """Test update_custom_meta raises error for invalid global_index.""" - fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), - } - samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - ] - batch = BatchMeta(samples=samples) - - # Try to update with non-existent global index - with pytest.raises(ValueError) as exc_info: - batch.update_custom_meta({999: {"sample_score": 0.9}}) - assert "non-exist global_indexes" in str(exc_info.value) - def test_batch_meta_clear_custom_meta(self): """Test clear_custom_meta removes all custom metadata.""" fields = { @@ -953,14 +858,13 @@ def test_batch_meta_clear_custom_meta(self): batch = BatchMeta(samples=samples) # Set custom_meta - batch.set_custom_meta(0, {"sample_score": 0.9}) - batch.set_custom_meta(1, {"sample_score": 0.1}) + batch.update_custom_meta([{"sample_score": 0.9}, {"sample_score": 0.1}]) # Clear all batch.clear_custom_meta() result = batch.get_all_custom_meta() - assert result == {} + assert result == [{}, {}] def test_batch_meta_get_all_custom_meta_returns_deep_copy(self): """Test 
get_all_custom_meta returns a deep copy.""" @@ -972,7 +876,7 @@ def test_batch_meta_get_all_custom_meta_returns_deep_copy(self): ] batch = BatchMeta(samples=samples) - custom_meta = {0: {"sample_score": 0.9, "nested": {"value": 1}}} + custom_meta = [{"sample_score": 0.9, "nested": {"value": 1}}] batch.update_custom_meta(custom_meta) # Get all custom_meta @@ -997,7 +901,7 @@ def test_batch_meta_get_all_custom_meta_empty(self): batch = BatchMeta(samples=samples) result = batch.get_all_custom_meta() - assert result == {} + assert result == [{}] def test_batch_meta_custom_meta_with_nested_data(self): """Test custom_meta supports nested dictionary data.""" @@ -1013,7 +917,7 @@ def test_batch_meta_custom_meta_with_nested_data(self): "model_info": {"name": "llama", "version": "7b", "config": {"hidden_size": 4096, "num_layers": 32}}, "tags": ["training", "inference"], } - batch.set_custom_meta(0, nested_meta) + batch.update_custom_meta([nested_meta]) result = batch.get_all_custom_meta() assert result[0]["model_info"]["name"] == "llama" @@ -1141,3 +1045,290 @@ def test_batch_meta_concat_validation_error(self): with pytest.raises(ValueError) as exc_info: BatchMeta.concat([batch1, batch2], validate=True) assert "Field names do not match" in str(exc_info.value) + + +class TestKVBatchMeta: + """KVBatchMeta Tests""" + + def test_kv_batch_meta_basic_init(self): + """Example: Basic KVBatchMeta initialization.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{"sample_id": 0}, {"sample_id": 1}, {"sample_id": 2}], + partition_id="partition_0", + fields=["field1", "field2"], + ) + + assert kv_meta.size == 3 + assert len(kv_meta) == 3 + assert kv_meta.keys == ["key1", "key2", "key3"] + assert kv_meta.partition_id == "partition_0" + assert kv_meta.fields == ["field1", "field2"] + + def test_kv_batch_meta_empty_init(self): + """Example: Empty KVBatchMeta initialization.""" + kv_meta = KVBatchMeta() + + assert kv_meta.size == 0 + assert len(kv_meta) == 0 + assert kv_meta.keys == [] + assert kv_meta.tags == [] + assert kv_meta.partition_id is None + assert kv_meta.fields is None + + def test_kv_batch_meta_init_validation_keys_tags_mismatch(self): + """Example: Init validation catches keys and tags length mismatch.""" + with pytest.raises(ValueError) as exc_info: + KVBatchMeta( + keys=["key1", "key2"], + tags=[{"sample_id": 0}], # Only one tag + ) + assert "keys and tags must have same length" in str(exc_info.value) + + def test_kv_batch_meta_init_validation_duplicate_keys(self): + """Example: Init validation catches duplicate keys.""" + with pytest.raises(ValueError) as exc_info: + KVBatchMeta( + keys=["key1", "key1"], + tags=[{"sample_id": 0}, {"sample_id": 1}], + partition_id="partition_0", + ) + assert "Got duplicated keys" in str(exc_info.value) + + def test_kv_batch_meta_init_validation_duplicate_fields(self): + """Example: Init validation catches duplicate fields.""" + with pytest.raises(ValueError) as exc_info: + KVBatchMeta( + keys=["key1"], + tags=[{"sample_id": 0}], + partition_id="partition_0", + fields=["field1", "field1"], + ) + assert "Got duplicated fields" in str(exc_info.value) + + def test_kv_batch_meta_select_keys(self): + """Example: Select specific keys from KVBatchMeta.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}], + partition_id="partition_0", + fields=["field1", "field2"], + extra_info={"test": "value"}, + ) + + selected = kv_meta.select_keys(["key1", "key3"]) + + assert selected.keys == ["key1", "key3"] + 
assert selected.tags == [{"idx": 0}, {"idx": 2}] + assert selected.partition_id == "partition_0" + assert selected.fields == ["field1", "field2"] + assert selected.extra_info == {"test": "value"} + + def test_kv_batch_meta_select_keys_validation_duplicate(self): + """Example: Select keys validation catches duplicate keys in input.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(ValueError) as exc_info: + kv_meta.select_keys(["key1", "key1"]) + assert "Contain duplicate keys" in str(exc_info.value) + + def test_kv_batch_meta_select_keys_validation_nonexistent(self): + """Example: Select keys validation catches non-existent keys.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(RuntimeError) as exc_info: + kv_meta.select_keys(["key1", "nonexistent"]) + assert "not found in current batch" in str(exc_info.value) + + def test_kv_batch_meta_reorder(self): + """Example: Reorder samples in KVBatchMeta.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{"idx": 0}, {"idx": 1}, {"idx": 2}], + ) + + kv_meta.reorder([2, 0, 1]) + + assert kv_meta.keys == ["key3", "key1", "key2"] + assert kv_meta.tags == [{"idx": 2}, {"idx": 0}, {"idx": 1}] + + def test_kv_batch_meta_reorder_validation_size_mismatch(self): + """Example: Reorder validation catches size mismatch.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(ValueError) as exc_info: + kv_meta.reorder([0, 1]) # Only 2 indexes for 3 samples + assert "does not match" in str(exc_info.value) + + def test_kv_batch_meta_reorder_validation_duplicate_indexes(self): + """Example: Reorder validation catches duplicate indexes.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2", "key3"], + tags=[{}, {}, {}], + ) + + with pytest.raises(ValueError) as exc_info: + kv_meta.reorder([0, 0, 1]) # Duplicate index 0 + assert "Contain duplicate indexes" in str(exc_info.value) + + def test_kv_batch_meta_chunk(self): + """Example: Split KVBatchMeta into multiple chunks.""" + kv_meta = KVBatchMeta( + keys=[f"key{i}" for i in range(10)], + tags=[{"idx": i} for i in range(10)], + partition_id="partition_0", + fields=["field1"], + extra_info={"test": "value"}, + ) + + chunks = kv_meta.chunk(3) + + assert len(chunks) == 3 + assert len(chunks[0]) == 4 # First chunk gets extra element + assert len(chunks[1]) == 3 + assert len(chunks[2]) == 3 + + # Verify partition_id and fields are preserved + assert chunks[0].partition_id == "partition_0" + assert chunks[0].fields == ["field1"] + assert chunks[0].extra_info == {"test": "value"} + + # Verify keys and tags are correctly chunked + assert chunks[0].keys == ["key0", "key1", "key2", "key3"] + assert chunks[0].tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}] + assert chunks[1].keys == ["key4", "key5", "key6"] + assert chunks[1].tags == [{"idx": 4}, {"idx": 5}, {"idx": 6}] + + def test_kv_batch_meta_chunk_with_more_chunks_than_samples(self): + """Example: Chunking when chunks > samples produces empty chunks.""" + kv_meta = KVBatchMeta( + keys=["key1", "key2"], + tags=[{"idx": 0}, {"idx": 1}], + ) + + chunks = kv_meta.chunk(5) + + assert len(chunks) == 5 + assert len(chunks[0]) == 1 + assert len(chunks[1]) == 1 + assert len(chunks[2]) == 0 + assert len(chunks[3]) == 0 + assert len(chunks[4]) == 0 + + def test_kv_batch_meta_concat(self): + """Example: Concatenate multiple KVBatchMeta chunks.""" + kv_meta1 = KVBatchMeta( + keys=["key0", 
"key1"], + tags=[{"idx": 0}, {"idx": 1}], + partition_id="partition_0", + fields=["field1"], + extra_info={"test": "value1"}, + ) + + kv_meta2 = KVBatchMeta( + keys=["key2", "key3"], + tags=[{"idx": 2}, {"idx": 3}], + partition_id="partition_0", + fields=["field1"], + extra_info={"test": "value2"}, + ) + + result = KVBatchMeta.concat([kv_meta1, kv_meta2]) + + assert result.size == 4 + assert result.keys == ["key0", "key1", "key2", "key3"] + assert result.tags == [{"idx": 0}, {"idx": 1}, {"idx": 2}, {"idx": 3}] + assert result.partition_id == "partition_0" + assert result.fields == ["field1"] + + def test_kv_batch_meta_concat_with_empty_chunks(self): + """Example: Concat handles empty KVBatchMeta chunks gracefully.""" + kv_meta1 = KVBatchMeta() + kv_meta2 = KVBatchMeta(keys=["key0"], tags=[{"idx": 0}]) + kv_meta3 = KVBatchMeta() + + result = KVBatchMeta.concat([kv_meta1, kv_meta2, kv_meta3]) + + assert result.size == 1 + assert result.keys == ["key0"] + assert result.tags == [{"idx": 0}] + + def test_kv_batch_meta_concat_validation_field_mismatch(self): + """Example: Concat validation catches field name mismatches.""" + kv_meta1 = KVBatchMeta( + keys=["key0"], + tags=[{}], + fields=["field1"], + ) + kv_meta2 = KVBatchMeta( + keys=["key1"], + tags=[{}], + fields=["field2"], # Different field + ) + + with pytest.raises(ValueError) as exc_info: + KVBatchMeta.concat([kv_meta1, kv_meta2]) + assert "Field names do not match" in str(exc_info.value) + + def test_kv_batch_meta_concat_validation_partition_mismatch(self): + """Example: Concat validation catches partition_id mismatches.""" + kv_meta1 = KVBatchMeta( + keys=["key0"], + tags=[{}], + partition_id="partition_0", + ) + kv_meta2 = KVBatchMeta( + keys=["key1"], + tags=[{}], + partition_id="partition_1", # Different partition + ) + + with pytest.raises(ValueError) as exc_info: + KVBatchMeta.concat([kv_meta1, kv_meta2]) + assert "Partition do not match" in str(exc_info.value) + + def test_kv_batch_meta_concat_empty_list(self): + """Example: Concat with empty list returns empty KVBatchMeta.""" + result = KVBatchMeta.concat([]) + + assert result.size == 0 + assert result.keys == [] + assert result.tags == [] + + def test_kv_batch_meta_deepcopy_tags(self): + """Example: Tags are deep copied to prevent mutation.""" + original_tags = [{"data": [1, 2, 3]}] + kv_meta = KVBatchMeta( + keys=["key1"], + tags=original_tags, + ) + + # Modify the tag in the KVBatchMeta + kv_meta.tags[0]["data"].append(4) + + # Original should not be modified + assert original_tags[0]["data"] == [1, 2, 3] + + def test_kv_batch_meta_deepcopy_extra_info(self): + """Example: Extra info is deep copied to prevent mutation.""" + original_extra = {"nested": {"value": 1}} + kv_meta = KVBatchMeta( + keys=["key1"], + tags=[{}], + extra_info=original_extra, + ) + + # Modify extra_info + kv_meta.extra_info["nested"]["value"] = 999 + + # Original should not be modified + assert original_extra["nested"]["value"] == 1 diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index 043b506..ed43e41 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -27,7 +27,7 @@ parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import SimpleStorageUnit # noqa: E402 +from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402 from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # noqa: E402 diff --git a/transfer_queue/__init__.py 
b/transfer_queue/__init__.py index 592ef0d..4732e84 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -16,63 +16,64 @@ import os from .client import TransferQueueClient -from .controller import TransferQueueController from .dataloader import StreamingDataLoader, StreamingDataset from .interface import ( - async_clear_partition, - async_clear_samples, - async_get_data, - async_get_meta, - async_put, - async_set_custom_meta, - clear_partition, - clear_samples, + async_kv_batch_get, + async_kv_batch_put, + async_kv_clear, + async_kv_list, + async_kv_put, close, - get_data, - get_meta, + get_client, init, - put, - set_custom_meta, + kv_batch_get, + kv_batch_put, + kv_clear, + kv_list, + kv_put, ) -from .metadata import BatchMeta +from .metadata import BatchMeta, KVBatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler from .sampler.rank_aware_sampler import RankAwareSampler from .sampler.sequential_sampler import SequentialSampler -from .storage import SimpleStorageUnit -from .utils.common import get_placement_group -from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info -__all__ = [ - "init", - "get_meta", - "get_data", - "put", - "set_custom_meta", - "clear_samples", - "clear_partition", - "async_get_meta", - "async_get_data", - "async_put", - "async_set_custom_meta", - "async_clear_samples", - "async_clear_partition", - "close", -] + [ - "TransferQueueClient", - "StreamingDataset", - "StreamingDataLoader", - "BatchMeta", - "TransferQueueController", - "SimpleStorageUnit", - "ZMQServerInfo", - "process_zmq_server_info", - "get_placement_group", - "BaseSampler", - "GRPOGroupNSampler", - "SequentialSampler", - "RankAwareSampler", -] +__all__ = ( + [ + # High-Level KV Interface + "init", + "close", + "kv_put", + "kv_batch_put", + "kv_batch_get", + "kv_list", + "kv_clear", + "async_kv_put", + "async_kv_batch_put", + "async_kv_batch_get", + "async_kv_list", + "async_kv_clear", + "KVBatchMeta", + ] + + [ + # High-Level StreamingDataLoader Interface + "StreamingDataset", + "StreamingDataLoader", + ] + + [ + # Low-Level Native Interface + "get_client", + "BatchMeta", + "TransferQueueClient", + ] + + [ + # Sampler + "BaseSampler", + "GRPOGroupNSampler", + "SequentialSampler", + "RankAwareSampler", + ] +) version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index cd9c6e8..b06f90c 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -153,6 +153,7 @@ async def wrapper(self, *args, **kwargs): return decorator + # ==================== Basic API ==================== @dynamic_socket(socket_name="request_handle_socket") async def async_get_meta( self, @@ -215,7 +216,7 @@ async def async_get_meta( """ assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_META, + request_type=ZMQRequestType.GET_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -259,8 +260,8 @@ async def async_set_custom_meta( Args: metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. + The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. 
socket: ZMQ async socket for message transmission (injected by decorator) Raises: @@ -269,7 +270,7 @@ async def async_set_custom_meta( Example: >>> # Create batch with custom metadata >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) >>> asyncio.run(client.async_set_custom_meta(batch_meta)) """ assert socket is not None @@ -291,10 +292,13 @@ async def async_set_custom_meta( partition_custom_meta: dict[str, dict[int, dict]] = {pid: {} for pid in set(metadata.partition_ids)} for meta in metadata_chunks: - partition_custom_meta[meta.partition_ids[0]].update(meta.get_all_custom_meta()) + custom_meta = meta.get_all_custom_meta() + partition_custom_meta[meta.partition_ids[0]].update( + {meta.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta))} + ) request_msg = ZMQMessage.create( - request_type=ZMQRequestType.SET_CUSTOM_META, + request_type=ZMQRequestType.SET_CUSTOM_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -532,7 +536,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): """ request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_META, + request_type=ZMQRequestType.CLEAR_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={"global_indexes": metadata.global_indexes, "partition_ids": metadata.partition_ids}, @@ -560,7 +564,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta RuntimeError: If controller returns error response """ request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_PARTITION_META, + request_type=ZMQRequestType.GET_PARTITION_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={"partition_id": partition_id}, @@ -605,6 +609,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): if response_msg.request_type != ZMQRequestType.CLEAR_PARTITION_RESPONSE: raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") + # ==================== Status Query API ==================== @dynamic_socket(socket_name="request_handle_socket") async def async_get_consumption_status( self, @@ -638,7 +643,7 @@ async def async_get_consumption_status( assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_CONSUMPTION, + request_type=ZMQRequestType.GET_CONSUMPTION, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -675,7 +680,7 @@ async def async_get_production_status( partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]]: - """Get production status for current partition for specific fields. + """Get production status for specific data fields and partition. 
Args: data_fields: Data fields to check production status for @@ -700,7 +705,7 @@ async def async_get_production_status( """ assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_PRODUCTION, + request_type=ZMQRequestType.GET_PRODUCTION, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -765,6 +770,41 @@ async def async_check_consumption_status( return False return torch.all(consumption_status == 1).item() + async def async_check_production_status( + self, + data_fields: list[str], + partition_id: str, + ) -> bool: + """Check if the all specific fields of samples for current partition are ready + (produced) for consumption. + + Args: + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for + + Returns: + bool: True if all samples have been produced and ready, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples are ready for consumption + >>> is_ready = asyncio.run(client.async_check_production_status( + ... data_fields=["input_ids", "attention_mask"], + ... partition_id="train_0" + ... )) + >>> print(f"All samples ready: {is_ready}") + """ + _, production_status = await self.async_get_production_status( + data_fields=data_fields, + partition_id=partition_id, + ) + + if production_status is None: + return False + return torch.all(production_status == 1).item() + @dynamic_socket(socket_name="request_handle_socket") async def async_reset_consumption( self, @@ -772,17 +812,22 @@ async def async_reset_consumption( task_name: Optional[str] = None, socket: Optional[zmq.asyncio.Socket] = None, ) -> bool: - """Reset consumption status for a partition, allowing data to be re-consumed. - This is useful for debugging scenarios where the same rollout data needs to be - trained multiple times without regenerating the data. + """Asynchronously reset consumption status for a partition. + + This allows the same data to be re-consumed, useful for debugging scenarios + where the same rollout data needs to be trained multiple times. + Args: partition_id: Partition id to reset consumption status for task_name: Name of the task to reset. If None, resets all tasks. socket: ZMQ async socket for message transmission (injected by decorator) + Returns: bool: True if reset was successful, False otherwise + Raises: RuntimeError: If communication fails or controller returns error response + Example: >>> # Reset consumption for train task to re-train on same data >>> success = asyncio.run(client.async_reset_consumption( @@ -796,7 +841,7 @@ async def async_reset_consumption( if task_name is not None: body["task_name"] = task_name request_msg = ZMQMessage.create( - request_type=ZMQRequestType.RESET_CONSUMPTION, + request_type=ZMQRequestType.RESET_CONSUMPTION, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body=body, @@ -822,41 +867,6 @@ async def async_reset_consumption( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e - async def async_check_production_status( - self, - data_fields: list[str], - partition_id: str, - ) -> bool: - """Check if the all specific fields of samples for current partition are ready - (produced) for consumption. 
- - Args: - data_fields: Data fields to check production status for - partition_id: Partition id to check production status for - - Returns: - bool: True if all samples have been produced and ready, False otherwise - - Raises: - RuntimeError: If communication fails or controller returns error response - - Example: - >>> # Check if all samples are ready for consumption - >>> is_ready = asyncio.run(client.async_check_production_status( - ... data_fields=["input_ids", "attention_mask"], - ... partition_id="train_0" - ... )) - >>> print(f"All samples ready: {is_ready}") - """ - _, production_status = await self.async_get_production_status( - data_fields=data_fields, - partition_id=partition_id, - ) - - if production_status is None: - return False - return torch.all(production_status == 1).item() - @dynamic_socket(socket_name="request_handle_socket") async def async_get_partition_list( self,
@@ -869,9 +879,13 @@ async def async_get_partition_list( Returns: list[str]: List of partition ids managed by the controller + + Example: + >>> partition_ids = asyncio.run(client.async_get_partition_list()) + >>> print(f"Available partitions: {partition_ids}") """ request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_LIST_PARTITIONS, + request_type=ZMQRequestType.GET_LIST_PARTITIONS, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={},
@@ -898,6 +912,132 @@ except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e + # ==================== KV Interface API ==================== + @dynamic_socket(socket_name="request_handle_socket") + async def async_kv_retrieve_keys( + self, + keys: list[str] | str, + partition_id: str, + create: bool = False, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> BatchMeta: + """Asynchronously retrieve BatchMeta from the controller using user-specified keys. + + Args: + keys: List of keys to retrieve from the controller + partition_id: The ID of the logical partition to search for keys. + create: Whether to register new keys if not found.
+ socket: ZMQ socket (injected by decorator) + + Returns: + metadata: BatchMeta of the corresponding keys + + Raises: + TypeError: If `keys` is not a list of string or a string + """ + + if isinstance(keys, str): + keys = [keys] + elif isinstance(keys, list): + if len(keys) < 1: + raise ValueError("Received an empty list as keys.") + # validate all the elements are str + if not all(isinstance(k, str) for k in keys): + raise TypeError("Not all elements in `keys` are strings.") + else: + raise TypeError("Only string or list of strings are allowed as `keys`.") + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_RETRIEVE_KEYS, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={ + "keys": keys, + "partition_id": partition_id, + "create": create, + }, + ) + + try: + assert socket is not None + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} " + f"from controller {self._controller.id}" + ) + + if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE: + metadata = response_msg.body.get("metadata", BatchMeta.empty()) + metadata = BatchMeta.from_dict(metadata) if isinstance(metadata, dict) else metadata + return metadata + else: + raise RuntimeError( + f"[{self.client_id}]: Failed to retrieve keys from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e + + @dynamic_socket(socket_name="request_handle_socket") + async def async_kv_list( + self, + partition_id: Optional[str] = None, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> dict[str, dict[str, Any]]: + """Asynchronously retrieve keys and custom_meta from the controller for one or all partitions. + + Args: + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. + socket: ZMQ socket (injected by decorator) + + Returns: + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... 
+ } + """ + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_LIST, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={ + "partition_id": partition_id, + }, + ) + + try: + assert socket is not None + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}" + ) + + if response_msg.request_type == ZMQRequestType.KV_LIST_RESPONSE: + partition_info = response_msg.body.get("partition_info", {}) + return partition_info + else: + raise RuntimeError( + f"[{self.client_id}]: Failed to list keys from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in kv_list: {str(e)}") from e + def close(self) -> None: """Close the client and cleanup resources including storage manager.""" try: @@ -972,67 +1112,10 @@ def wrapper(*args, **kwargs): self._get_partition_list = _make_sync(self.async_get_partition_list) self._set_custom_meta = _make_sync(self.async_set_custom_meta) self._reset_consumption = _make_sync(self.async_reset_consumption) + self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys) + self._kv_list = _make_sync(self.async_kv_list) - def put( - self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None - ) -> BatchMeta: - """Synchronously write data to storage units based on metadata. - - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. - - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. - - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. - - Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. - partition_id: Target data partition id (required if metadata is not provided) - - Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) - - Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails - - Example: - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = client.get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", - ... ) - >>> batch = client.get_data(batch_meta) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> client.put(data=output, metadata=batch_meta) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. 
- >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) - """ - return self._put(data=data, metadata=metadata, partition_id=partition_id) - + # ==================== Basic API ==================== def get_meta( self, data_fields: list[str], @@ -1101,6 +1184,90 @@ def get_meta( sampling_config=sampling_config, ) + def set_custom_meta(self, metadata: BatchMeta) -> None: + """Synchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() + before calling this method. + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Create batch with custom metadata + >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=2, ...) + >>> batch_meta.update_custom_meta([{"score": 0.9}, {"score": 0.8}]) + >>> client.set_custom_meta(batch_meta) + """ + + return self._set_custom_meta(metadata=metadata) + + def put( + self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None + ) -> BatchMeta: + """Synchronously write data to storage units based on metadata. + + If metadata is not provided, it will be created automatically using insert mode + with the provided data fields and partition_id. + + During put, the custom_meta in metadata will update the corresponding custom_meta in + TransferQueue Controller. + + Note: + When using multiple workers for distributed execution, there may be data + ordering inconsistencies between workers during put operations. + + Args: + data: Data to write as TensorDict + metadata: Records the metadata of a batch of data samples, containing index and + storage unit information. If None, metadata will be auto-generated. + partition_id: Target data partition id (required if metadata is not provided) + + Returns: + BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved + metadata; will be updated in a future version to reflect the post-put state) + + Raises: + ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided + RuntimeError: If storage operation fails + + Example: + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_partition_id = "train_0" + >>> # Example 1: Normal usage with existing metadata + >>> batch_meta = client.get_meta( + ... data_fields=["prompts", "attention_mask"], + ... batch_size=batch_size, + ... partition_id=current_partition_id, + ... mode="fetch", + ... task_name="generate_sequences", + ... 
) + >>> batch = client.get_data(batch_meta) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> client.put(data=output, metadata=batch_meta) + >>> + >>> # Example 2: Initial data insertion without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! + >>> # Please make sure the corresponding partition_id is empty before calling the async_put() + >>> # without metadata. + >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with + >>> # interleave the initial data if n_sample > 1 before calling the async_put(). + >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) + >>> # This will create metadata in "insert" mode internally. + >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) + """ + return self._put(data=data, metadata=metadata, partition_id=partition_id) + def get_data(self, metadata: BatchMeta) -> TensorDict: """Synchronously fetch data from storage units and organize into TensorDict. @@ -1112,7 +1279,7 @@ def get_data(self, metadata: BatchMeta) -> TensorDict: - Requested data fields (e.g., "prompts", "attention_mask") Example: - >>> batch_meta = client.get_data( + >>> batch_meta = client.get_meta( ... data_fields=["prompts", "attention_mask"], ... batch_size=4, ... partition_id="train_0", @@ -1147,65 +1314,85 @@ def clear_samples(self, metadata: BatchMeta): """ return self._clear_samples(metadata=metadata) - def check_consumption_status(self, task_name: str, partition_id: str) -> bool: - """Synchronously check if all samples for a partition have been consumed by a specific task. + # ==================== Status Query API ==================== + def get_consumption_status( + self, + task_name: str, + partition_id: str, + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Synchronously get consumption status for a specific task and partition. Args: task_name: Name of the task to check consumption for partition_id: Partition id to check consumption status for Returns: - bool: True if all samples have been consumed by the task, False otherwise + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. Raises: RuntimeError: If communication fails or controller returns error response Example: - >>> # Check if all samples have been consumed - >>> is_consumed = client.check_consumption_status( + >>> global_index, consumption_status = client.get_consumption_status( ... task_name="generate_sequences", ... partition_id="train_0" ... ) - >>> print(f"All samples consumed: {is_consumed}") + >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") """ - return self._check_consumption_status(task_name=task_name, partition_id=partition_id) + return self._get_consumption_status(task_name, partition_id) - def get_consumption_status( + def get_production_status( self, - task_name: str, + data_fields: list[str], partition_id: str, ) -> tuple[Optional[Tensor], Optional[Tensor]]: - """Synchronously get consumption status for a specific task and partition. + """Synchronously get production status for specific data fields and partition. 
Args: - task_name: Name of the task to check consumption for - partition_id: Partition id to check consumption status for + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for Returns: Tuple of: - Partition global index tensor - - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. + - Production status tensor for the specified fields. 1 for ready, 0 for not ready. + + Raises: + RuntimeError: If communication fails or controller returns error response Example: - >>> global_index, consumption_status = client.get_consumption_status( - ... task_name="generate_sequences", + >>> global_index, production_status = client.get_production_status( + ... data_fields=["input_ids", "attention_mask"], ... partition_id="train_0" ... ) - >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") + >>> print(f"Global index: {global_index}, Production status: {production_status}") """ - return self._get_consumption_status(task_name, partition_id) + return self._get_production_status(data_fields=data_fields, partition_id=partition_id) + + def check_consumption_status(self, task_name: str, partition_id: str) -> bool: + """Synchronously check if all samples for a partition have been consumed by a specific task. - def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: - """Synchronously reset consumption status for a partition. - This allows the same data to be re-consumed, useful for debugging scenarios - where the same rollout data needs to be trained multiple times. Args: - partition_id: Partition id to reset consumption status for - task_name: Name of the task to reset. If None, resets all tasks. + task_name: Name of the task to check consumption for + partition_id: Partition id to check consumption status for + Returns: - bool: True if reset was successful, False otherwise + bool: True if all samples have been consumed by the task, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples have been consumed + >>> is_consumed = client.check_consumption_status( + ... task_name="generate_sequences", + ... partition_id="train_0" + ... ) + >>> print(f"All samples consumed: {is_consumed}") """ - return self._reset_consumption(partition_id, task_name) + return self._check_consumption_status(task_name=task_name, partition_id=partition_id) def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. @@ -1214,6 +1401,9 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> data_fields: Data fields to check production status for partition_id: Partition id to check production status for + Returns: + bool: True if all samples have been produced and ready, False otherwise + Raises: RuntimeError: If communication fails or controller returns error response @@ -1227,30 +1417,31 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> """ return self._check_production_status(data_fields=data_fields, partition_id=partition_id) - def get_production_status( - self, - data_fields: list[str], - partition_id: str, - ) -> tuple[Optional[Tensor], Optional[Tensor]]: - """Synchronously get production status for a specific data fields and partition. 
+ def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool: + """Synchronously reset consumption status for a partition. + + This allows the same data to be re-consumed, useful for debugging scenarios + where the same rollout data needs to be trained multiple times. Args: - data_fields: Data fields to check production status for - partition_id: Partition id to check production status for + partition_id: Partition id to reset consumption status for + task_name: Name of the task to reset. If None, resets all tasks. Returns: - Tuple of: - - Partition global index tensor - - Production status tensor for the specified fields. 1 for ready, 0 for not ready. + bool: True if reset was successful, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response Example: - >>> global_index, production_status = client.get_production_status( - ... data_fields=["input_ids", "attention_mask"], - ... partition_id="train_0" + >>> # Reset consumption for train task to re-train on same data + >>> success = client.reset_consumption( + ... partition_id="train_0", + ... task_name="train" ... ) - >>> print(f"Global index: {global_index}, Production status: {production_status}") + >>> print(f"Reset successful: {success}") """ - return self._get_production_status(data_fields=data_fields, partition_id=partition_id) + return self._reset_consumption(partition_id, task_name) def get_partition_list( self, @@ -1259,32 +1450,64 @@ def get_partition_list( Returns: list[str]: List of partition ids managed by the controller + + Example: + >>> partition_ids = client.get_partition_list() + >>> print(f"Available partitions: {partition_ids}") """ return self._get_partition_list() - def set_custom_meta(self, metadata: BatchMeta) -> None: - """Synchronously send custom metadata to the controller. - - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. + # ==================== KV Interface API ==================== + def kv_retrieve_keys( + self, + keys: list[str] | str, + partition_id: str, + create: bool = False, + ) -> BatchMeta: + """Synchronously retrieve BatchMeta from the controller using user-specified keys. Args: - metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. + keys: List of keys to retrieve from the controller + partition_id: The ID of the logical partition to search for keys. + create: Whether to register new keys if not found. + + Returns: + metadata: BatchMeta of the corresponding keys Raises: - RuntimeError: If communication fails or controller returns error response + TypeError: If `keys` is not a list of string or a string + """ - Example: - >>> # Create batch with custom metadata - >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) - >>> client.set_custom_meta(batch_meta) + return self._kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) + + def kv_list( + self, + partition_id: Optional[str] = None, + ) -> dict[str, dict[str, Any]]: + """Synchronously retrieve keys and custom_meta from the controller for one or all partitions. + + Args: + partition_id: The specific partition_id to query. 
+ If None (default), returns keys from all partitions. + socket: ZMQ socket (injected by decorator) + + Returns: + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } """ - return self._set_custom_meta(metadata=metadata) + return self._kv_list(partition_id=partition_id) def close(self) -> None: """Close the client and cleanup resources including event loop and thread.""" diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 763f3f4..3c7c3f9 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -214,7 +214,7 @@ class DataPartitionStatus: # Each tensor tracks which samples have been consumed by that task consumption_status: dict[str, Tensor] = field(default_factory=dict) - # Sample metadata + # Global indexes global_indexes: set[int] = field( default_factory=set ) # set of global indexes that have been added to this partition @@ -233,6 +233,10 @@ class DataPartitionStatus: # User-defined metadata that may not apply to field level custom_meta: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {} + # User-defined Keys + keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx + revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key + # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. data_status_lock: Lock = field(default_factory=Lock) @@ -536,15 +540,19 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8) # Get consumption status for requested task - consumption_status = self.consumption_status[task_name] - partition_global_index = torch.tensor( sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long ) if mask: - consumption_status = consumption_status[partition_global_index] - + if partition_global_index.numel() == 0: + empty_status = self.consumption_status[task_name].new_zeros(0) + return partition_global_index, empty_status + with self.data_status_lock: + self.ensure_samples_capacity(max(partition_global_index) + 1) + consumption_status = self.consumption_status[task_name][partition_global_index] + else: + consumption_status = self.consumption_status[task_name] return partition_global_index, consumption_status def reset_consumption(self, task_name: Optional[str] = None): @@ -826,12 +834,21 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr self.field_custom_backend_meta.pop(idx, None) self.custom_meta.pop(idx, None) + if idx in self.revert_keys_mapping: + self.keys_mapping.pop(self.revert_keys_mapping[idx], None) + self.revert_keys_mapping.pop(idx, None) + except Exception as e: logger.error( f"Error clearing data for partition {self.partition_id}: {e}. 
" f"Attempted to clear global_indexes: {indexes_to_release}" ) + def kv_retrieve_keys(self, keys: list[str]) -> list[int | None]: + """Translate the user-specified keys to global_indexes""" + global_indexes = [self.keys_mapping.get(k, None) for k in keys] + return global_indexes + @ray.remote(num_cpus=1) class TransferQueueController: @@ -971,17 +988,19 @@ def list_partitions(self) -> list[str]: # ==================== Partition Index Management API ==================== - def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]: + def get_partition_index_range(self, partition_id: str) -> list[int]: """ Get all indexes for a specific partition. Args: - partition: Partition identifier + partition_id: Partition identifier Returns: List of indexes allocated to the partition """ - return self.index_manager.get_indexes_for_partition(partition) + # Note: This includes the pre-allocated global_indexes for the partition. + # i.e., partition.global_indexes + partition.pre_allocated_global_indexes + return self.index_manager.get_indexes_for_partition(partition_id) # ==================== Data Production API ==================== @@ -1160,9 +1179,10 @@ def get_metadata( ) batch_global_indexes.extend(new_global_indexes) + # register global_indexes in partition + partition.global_indexes.update(batch_global_indexes) + else: - # TODO: separate this "clear" related logic into a separated mode - # clear metadata call passes empty data_fields batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) @@ -1314,7 +1334,7 @@ def generate_batch_meta( and partition.production_status is not None and partition.production_status[global_index, field_index] == 1 ): - production_status = ProductionStatus.NOT_PRODUCED + production_status = ProductionStatus.READY_FOR_CONSUME dtype = partition.get_field_dtype(global_index, field_name) shape = partition.get_field_shape(global_index, field_name) else: @@ -1340,7 +1360,7 @@ def generate_batch_meta( custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) batch_meta = BatchMeta(samples=samples) - batch_meta.update_custom_meta(custom_meta) + batch_meta.update_custom_meta([custom_meta.get(idx, {}) for idx in batch_meta.global_indexes]) batch_meta._custom_backend_meta.update(custom_backend_meta) return batch_meta @@ -1357,7 +1377,8 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): partition = self._get_partition(partition_id) if not partition: - raise ValueError(f"Partition {partition_id} not found") + logger.warning(f"Try to clear an non-existent partition {partition_id}!") + return global_indexes_range = list(self.index_manager.get_indexes_for_partition(partition_id)) partition.clear_data(global_indexes_range, clear_consumption) @@ -1374,13 +1395,13 @@ def reset_consumption(self, partition_id: str, task_name: Optional[str] = None): Args: partition_id: ID of the partition to reset consumption for task_name: Name of the task to reset. If None, resets all tasks. 
- Raises: - ValueError: If partition not found + """ logger.debug(f"[{self.controller_id}]: Resetting consumption for partition {partition_id}, task={task_name}") partition = self._get_partition(partition_id) if not partition: - raise ValueError(f"Partition {partition_id} not found") + logger.warning(f"Try to reset consumption of an non-existent partition {partition_id}!") + return partition.reset_consumption(task_name) def clear_meta( @@ -1432,6 +1453,82 @@ def clear_meta( # Release the specific indexes from index manager self.index_manager.release_indexes(partition_id, global_indexes_to_clear) + def kv_retrieve_keys( + self, + keys: list[str], + partition_id: str, + create: bool = False, + ) -> BatchMeta: + """ + Retrieve BatchMeta from the controller using a list of keys. + + Args: + keys: List of keys to retrieve from the controller + partition_id: Partition id to retrieve from the controller + create: Whether to register new keys if not found. + + Returns: + metadata: BatchMeta of the requested keys + """ + + logger.debug(f"[{self.controller_id}]: Retrieve keys {keys} in partition {partition_id}") + + partition = self._get_partition(partition_id) + + if partition is None: + if not create: + logger.warning(f"Partition {partition_id} were not found in controller!") + return BatchMeta.empty() + else: + self.create_partition(partition_id) + partition = self._get_partition(partition_id) + + assert partition is not None + global_indexes = partition.kv_retrieve_keys(keys) + + none_indexes = [idx for idx, value in enumerate(global_indexes) if value is None] + if len(none_indexes) > 0: + if not create: + logger.warning(f"Keys {[keys[i] for i in none_indexes]} were not found in partition {partition_id}!") + return BatchMeta.empty() + else: + # create non-exist keys + batch_global_indexes = partition.activate_pre_allocated_indexes(len(none_indexes)) + + if len(batch_global_indexes) < len(none_indexes): + new_global_indexes = self.index_manager.allocate_indexes( + partition_id, count=(len(none_indexes) - len(batch_global_indexes)) + ) + batch_global_indexes.extend(new_global_indexes) + + # register global_indexes in partition + partition.global_indexes.update(batch_global_indexes) + + # register key-global_indexes mapping in partition + for i in range(len(none_indexes)): + global_indexes[none_indexes[i]] = batch_global_indexes[i] + partition.keys_mapping[keys[none_indexes[i]]] = batch_global_indexes[i] + partition.revert_keys_mapping[batch_global_indexes[i]] = keys[none_indexes[i]] + + with partition.data_status_lock: + partition.ensure_samples_capacity(max(batch_global_indexes) + 1) + + verified_global_indexes = [idx for idx in global_indexes if idx is not None] + assert len(verified_global_indexes) == len(keys) + + # must fetch fields that the requested samples all have + col_mask = partition.production_status[verified_global_indexes, :].sum(dim=0).reshape(-1) == len( + verified_global_indexes + ) + data_fields = [] + for fname, col_idx in partition.field_name_mapping.items(): + if col_mask[col_idx]: + data_fields.append(fname) + + metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch") + + return metadata + def _init_zmq_socket(self): """Initialize ZMQ sockets for communication.""" self.zmq_context = zmq.Context() @@ -1727,6 +1824,51 @@ def _process_request(self): body={"partition_ids": partition_ids}, ) + elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: + with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"): + params = 
request_msg.body + keys = params["keys"] + partition_id = params["partition_id"] + create = params["create"] + + metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"metadata": metadata}, + ) + + elif request_msg.request_type == ZMQRequestType.KV_LIST: + with perf_monitor.measure(op_type="KV_LIST"): + params = request_msg.body + partition_id = params["partition_id"] + if partition_id is None: + partition_id = list(self.partitions.keys()) + else: + partition_id = [partition_id] + + message = "success" + partition_info = {} + for pid in partition_id: + partition = self._get_partition(pid) + if partition: + keys = list(partition.keys_mapping.keys()) + single_partition_info = { + k: partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys + } + partition_info[pid] = single_partition_info + else: + # this only happens when params["partition_id"] is not None + message = f"partition {pid} does not exist" + + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_LIST_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"partition_info": partition_info, "message": message}, + ) + self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def _update_data_status(self): diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 37ebdb2..8e4f8c3 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -21,12 +21,13 @@ from typing import Any, Optional import ray +import torch from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict +from tensordict.tensorclass import NonTensorStack from transfer_queue.client import TransferQueueClient from transfer_queue.controller import TransferQueueController -from transfer_queue.metadata import BatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler from transfer_queue.storage.simple_backend import SimpleStorageUnit @@ -105,6 +106,7 @@ def _init_from_existing() -> None: time.sleep(1) +# ==================== Initialization API ==================== def init(conf: Optional[DictConfig] = None) -> None: """Initialize the TransferQueue system. @@ -194,456 +196,564 @@ def init(conf: Optional[DictConfig] = None) -> None: _maybe_create_transferqueue_client(final_conf) -def get_meta( - data_fields: list[str], - batch_size: int, - partition_id: str, - mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, -) -> BatchMeta: - """Synchronously fetch data metadata from the controller via ZMQ. +def close(): + """Close the TransferQueue system. - Args: - data_fields: List of data field names to retrieve metadata for - batch_size: Number of samples to request in the batch - partition_id: Current data partition id - mode: Data fetch mode. Options: - - 'fetch': Get ready data only - - 'force_fetch': Get data regardless of readiness (may return unready samples) - - 'insert': Internal usage - should not be used by users - task_name: Optional task name associated with the request - sampling_config: Optional sampling configuration for custom samplers. 
+ This function cleans up the TransferQueue system, including: + - Closing the client and its associated resources + - Cleaning up distributed storage (only for the process that initialized it) + - Killing the controller actor + Note: + This function should be called when the TransferQueue system is no longer needed. + """ + global _TRANSFER_QUEUE_CLIENT + global _TRANSFER_QUEUE_STORAGE + if _TRANSFER_QUEUE_CLIENT: + _TRANSFER_QUEUE_CLIENT.close() + _TRANSFER_QUEUE_CLIENT = None - Returns: - BatchMeta: Metadata object containing data structure, sample information, and readiness status + try: + if _TRANSFER_QUEUE_STORAGE: + # only the process that do first-time init can clean the distributed storage + for storage in _TRANSFER_QUEUE_STORAGE.values(): + ray.kill(storage) + _TRANSFER_QUEUE_STORAGE = None + except Exception: + pass + + try: + controller = ray.get_actor("TransferQueueController") + ray.kill(controller) + except Exception: + pass + + +# ==================== High-Level KV Interface API ==================== +def kv_put( + key: str, + partition_id: str, + fields: Optional[TensorDict | dict[str, Any]] = None, + tag: Optional[dict[str, Any]] = None, +) -> None: + """Put a single key-value pair to TransferQueue. + + This is a convenience method for putting data using a user-specified key + instead of BatchMeta. Internally, the key is translated to a BatchMeta + and the data is stored using the regular put mechanism. + + Args: + key: User-specified key for the data sample (in row) + partition_id: Logical partition to store the data in + fields: Data fields to store. Can be a TensorDict or a dict of tensors. + Each key in `fields` will be treated as a column for the data sample. + If dict is provided, tensors will be unsqueezed to add batch dimension. + tag: Optional metadata tag to associate with the key Raises: - RuntimeError: If communication fails or controller returns error response + ValueError: If neither fields nor tag is provided + ValueError: If nested tensors are provided (use kv_batch_put instead) + RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` Example: >>> import transfer_queue as tq + >>> import torch >>> tq.init() - >>> - >>> # Example 1: Basic fetch metadata - >>> batch_meta = tq.get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences" - ... ) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = tq.get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=8, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... sampling_config={"n_samples_per_prompt": 4} + >>> # Put with both fields and tag + >>> tq.kv_put( + ... key="sample_1", + ... partition_id="train", + ... fields={"input_ids": torch.tensor([1, 2, 3])}, + ... tag={"score": 0.95} ... ) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = tq.get_meta( - ... partition_id="train_0", # optional - ... mode="force_fetch", - ... 
) - >>> print(batch_meta.is_ready) # May be False if some samples not ready """ + if fields is None and tag is None: + raise ValueError("Please provide at least one parameter of `fields` or `tag`.") tq_client = _maybe_create_transferqueue_client() - return tq_client.get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + # 1. translate user-specified key to BatchMeta + batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) + + if batch_meta.size != 1: + raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") + + # 2. register the user-specified tag to BatchMeta + if tag is not None: + batch_meta.update_custom_meta([tag]) + + # 3. put data + if fields is not None: + if isinstance(fields, dict): + # TODO: consider whether to support this... + batch = {} + for field_name, value in fields.items(): + if isinstance(value, torch.Tensor): + if value.is_nested: + raise ValueError("Please use (async)kv_batch_put for batch operation") + batch[field_name] = value.unsqueeze(0) + else: + batch[field_name] = NonTensorStack(value) + fields = TensorDict(batch, batch_size=[1]) + elif not isinstance(fields, TensorDict): + raise ValueError("field can only be dict or TensorDict") + + # custom_meta (tag) will be put to controller through the internal put process + tq_client.put(fields, batch_meta) + else: + # directly update custom_meta (tag) to controller + tq_client.set_custom_meta(batch_meta) -async def async_get_meta( - data_fields: list[str], - batch_size: int, - partition_id: str, - mode: str = "fetch", - task_name: Optional[str] = None, - sampling_config: Optional[dict[str, Any]] = None, -) -> BatchMeta: - """Asynchronously fetch data metadata from the controller via ZMQ. - Args: - data_fields: List of data field names to retrieve metadata for - batch_size: Number of samples to request in the batch - partition_id: Current data partition id - mode: Data fetch mode. Options: - - 'fetch': Get ready data only - - 'force_fetch': Get data regardless of readiness (may return unready samples) - - 'insert': Internal usage - should not be used by users - task_name: Optional task name associated with the request - sampling_config: Optional sampling configuration for custom samplers. - socket: ZMQ async socket for message transmission (injected by decorator) +def kv_batch_put( + keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None +) -> None: + """Put multiple key-value pairs to TransferQueue in batch. - Returns: - BatchMeta: Metadata object containing data structure, sample information, and readiness status + This method stores multiple key-value pairs in a single operation, which is more + efficient than calling kv_put multiple times. + + Args: + keys: List of user-specified keys for the data + partition_id: Logical partition to store the data in + fields: TensorDict containing data for all keys. 
Must have batch_size == len(keys) + tags: List of metadata tags, one for each key Raises: - RuntimeError: If communication fails or controller returns error response + ValueError: If neither `fields` nor `tags` is provided + ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict + RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` Example: >>> import transfer_queue as tq + >>> from tensordict import TensorDict >>> tq.init() - >>> - >>> # Example 1: Basic fetch metadata - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences" - ... )) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=8, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... )) - >>> print(batch_meta.is_ready) # True if all samples ready - >>> - >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... partition_id="train_0", # optional - ... mode="force_fetch", - ... )) - >>> print(batch_meta.is_ready) # May be False if some samples not ready + >>> keys = ["sample_1", "sample_2", "sample_3"] + >>> fields = TensorDict({ + ... "input_ids": torch.randn(3, 10), + ... "attention_mask": torch.ones(3, 10), + ... }, batch_size=3) + >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] + >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) """ + if fields is None and tags is None: + raise ValueError("Please provide at least one parameter of fields or tag.") + + if fields is not None and fields.batch_size[0] != len(keys): + raise ValueError( + f"`keys` with length {len(keys)} does not match the `fields` TensorDict with " + f"batch_size {fields.batch_size[0]}" + ) + tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_get_meta(data_fields, batch_size, partition_id, mode, task_name, sampling_config) + + # 1. translate user-specified key to BatchMeta + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) + + if batch_meta.size != len(keys): + raise RuntimeError( + f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!" + ) + + # 2. register the user-specified tags to BatchMeta + if tags: + if len(tags) != len(keys): + raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") + batch_meta.update_custom_meta(tags) + + # 3. put data + if fields is not None: + tq_client.put(fields, batch_meta) + else: + # directly update custom_meta (tags) to controller + tq_client.set_custom_meta(batch_meta) -def get_data(metadata: BatchMeta) -> TensorDict: - """Synchronously fetch data from storage units and organize into TensorDict. +def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict: + """Get data from TransferQueue using user-specified keys. + + This is a convenience method for retrieving data using keys instead of indexes. 
Args: - metadata: Batch metadata containing data location information and global indexes + keys: Single key or list of keys to retrieve + partition_id: Partition containing the keys + fields: Optional field(s) to retrieve. If None, retrieves all fields Returns: - TensorDict containing: - - Requested data fields (e.g., "prompts", "attention_mask") + TensorDict with the requested data + + Raises: + RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq >>> tq.init() - >>> - >>> batch_meta = tq.get_data( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", + >>> # Get single key with all fields + >>> data = tq.kv_batch_get(keys="sample_1", partition_id="train") + >>> # Get multiple keys with specific fields + >>> data = tq.kv_batch_get( + ... keys=["sample_1", "sample_2"], + ... partition_id="train", + ... fields="input_ids" ... ) - >>> batch = tq.get_data(batch_meta) - >>> print(batch) - >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes """ tq_client = _maybe_create_transferqueue_client() - return tq_client.get_data(metadata) + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + + if batch_meta.size == 0: + raise RuntimeError("keys or partition were not found!") -async def async_get_data(metadata: BatchMeta) -> TensorDict: - """Asynchronously fetch data from storage units and organize into TensorDict. + if fields is not None: + if isinstance(fields, str): + fields = [fields] + batch_meta = batch_meta.select_fields(fields) + + if not batch_meta.is_ready: + raise RuntimeError("Some fields are not ready in all the requested keys!") + + data = tq_client.get_data(batch_meta) + return data + + +def kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: + """List all keys and their metadata in one or all partitions. Args: - metadata: Batch metadata containing data location information and global indexes + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. Returns: - TensorDict containing: - - Requested data fields (e.g., "prompts", "attention_mask") + A nested dictionary mapping partition IDs to their keys and metadata. + + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... + } Example: >>> import transfer_queue as tq >>> tq.init() - >>> - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=4, - ... partition_id="train_0", - ... mode="fetch", - ... task_name="generate_sequences", - ... 
)) - >>> batch = asyncio.run(tq.async_get_data(batch_meta)) - >>> print(batch) - >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes + >>> # Case 1: Retrieve a specific partition + >>> partitions = tq.kv_list(partition_id="train") + >>> print(f"Keys: {list(partitions['train'].keys())}") + >>> print(f"Tags: {list(partitions['train'].values())}") + >>> # Case 2: Retrieve all partitions + >>> all_partitions = tq.kv_list() + >>> for pid, keys in all_partitions.items(): + >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_get_data(metadata) + partition_info = tq_client.kv_list(partition_id) -def put(data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None) -> BatchMeta: - """Synchronously write data to storage units based on metadata. + return partition_info - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. +def kv_clear(keys: list[str] | str, partition_id: str) -> None: + """Clear key-value pairs from TransferQueue. - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. + This removes the specified keys and their associated data from both + the controller and storage units. Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. - partition_id: Target data partition id (required if metadata is not provided) - - Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) - - Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails + keys: Single key or list of keys to clear + partition_id: Partition containing the keys Example: >>> import transfer_queue as tq >>> tq.init() - >>> - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = tq.get_meta( - ... data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", - ... ) - >>> batch = tq.get_data(batch_meta) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> tq.put(data=output, metadata=batch_meta) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). 
- >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = tq.put(data=prompts_repeated_batch, partition_id=current_partition_id) + >>> # Clear single key + >>> tq.kv_clear(keys="sample_1", partition_id="train") + >>> # Clear multiple keys + >>> tq.kv_clear(keys=["sample_1", "sample_2"], partition_id="train") """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.put(data, metadata, partition_id) + if isinstance(keys, str): + keys = [keys] -async def async_put( - data: TensorDict, - metadata: Optional[BatchMeta] = None, - partition_id: Optional[str] = None, -) -> BatchMeta: - """Asynchronously write data to storage units based on metadata. + tq_client = _maybe_create_transferqueue_client() + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) - If metadata is not provided, it will be created automatically using insert mode - with the provided data fields and partition_id. + if batch_meta.size > 0: + tq_client.clear_samples(batch_meta) - During put, the custom_meta in metadata will update the corresponding custom_meta in - TransferQueue Controller. - Note: - When using multiple workers for distributed execution, there may be data - ordering inconsistencies between workers during put operations. +# ==================== KV Interface API ==================== +async def async_kv_put( + key: str, + partition_id: str, + fields: Optional[TensorDict | dict[str, Any]] = None, + tag: Optional[dict[str, Any]] = None, +) -> None: + """Asynchronously put a single key-value pair to TransferQueue. - Args: - data: Data to write as TensorDict - metadata: Records the metadata of a batch of data samples, containing index and - storage unit information. If None, metadata will be auto-generated. - partition_id: Target data partition id (required if metadata is not provided) + This is a convenience method for putting data using a user-specified key + instead of BatchMeta. Internally, the key is translated to a BatchMeta + and the data is stored using the regular put mechanism. - Returns: - BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved - metadata; will be updated in a future version to reflect the post-put state) + Args: + key: User-specified key for the data sample (in row) + partition_id: Logical partition to store the data in + fields: Data fields to store. Can be a TensorDict or a dict of tensors. + Each key in `fields` will be treated as a column for the data sample. + If dict is provided, tensors will be unsqueezed to add batch dimension. + tag: Optional metadata tag to associate with the key Raises: - ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided - RuntimeError: If storage operation fails + ValueError: If neither fields nor tag is provided + ValueError: If nested tensors are provided (use kv_batch_put instead) + RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` Example: >>> import transfer_queue as tq + >>> import torch >>> tq.init() - >>> - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_partition_id = "train_0" - >>> # Example 1: Normal usage with existing metadata - >>> batch_meta = asyncio.run(tq.async_get_meta( - ... 
data_fields=["prompts", "attention_mask"], - ... batch_size=batch_size, - ... partition_id=current_partition_id, - ... mode="fetch", - ... task_name="generate_sequences", + >>> # Put with both fields and tag + >>> await tq.async_kv_put( + ... key="sample_1", + ... partition_id="train", + ... fields={"input_ids": torch.tensor([1, 2, 3])}, + ... tag={"score": 0.95} ... )) - >>> batch = asyncio.run(tq.async_get_data(batch_meta)) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> asyncio.run(tq.async_put(data=output, metadata=batch_meta)) - >>> - >>> # Example 2: Initial data insertion without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id! - >>> # Please make sure the corresponding partition_id is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding partition id in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> metadata = asyncio.run(tq.async_put(data=prompts_repeated_batch, partition_id=current_partition_id)) """ + + if fields is None and tag is None: + raise ValueError("Please provide at least one parameter of fields or tag.") + tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_put(data, metadata, partition_id) + # 1. translate user-specified key to BatchMeta + batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) + + if batch_meta.size != 1: + raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") + + # 2. register the user-specified tag to BatchMeta + if tag: + batch_meta.update_custom_meta([tag]) + + # 3. put data + if fields is not None: + if isinstance(fields, dict): + # TODO: consider whether to support this... + batch = {} + for field_name, value in fields.items(): + if isinstance(value, torch.Tensor): + if value.is_nested: + raise ValueError("Please use (async)kv_batch_put for batch operation") + batch[field_name] = value.unsqueeze(0) + else: + batch[field_name] = NonTensorStack(value) + fields = TensorDict(batch, batch_size=[1]) + elif not isinstance(fields, TensorDict): + raise ValueError("field can only be dict or TensorDict") + + # custom_meta (tag) will be put to controller through the put process + await tq_client.async_put(fields, batch_meta) + else: + # directly update custom_meta (tag) to controller + await tq_client.async_set_custom_meta(batch_meta) -def set_custom_meta(metadata: BatchMeta) -> None: - """Synchronously send custom metadata to the controller. - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. +async def async_kv_batch_put( + keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None +) -> None: + """Asynchronously put multiple key-value pairs to TransferQueue in batch. 
+ + This method stores multiple key-value pairs in a single operation, which is more + efficient than calling kv_put multiple times. Args: - metadata: BatchMeta containing the samples and their custom metadata to store. - The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. + keys: List of user-specified keys for the data + partition_id: Logical partition to store the data in + fields: TensorDict containing data for all keys. Must have batch_size == len(keys) + tags: List of metadata tags, one for each key Raises: - RuntimeError: If communication fails or controller returns error response + ValueError: If neither `fields` nor `tags` is provided + ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict + RuntimeError: If retrieved BatchMeta size doesn't match length of `keys` Example: >>> import transfer_queue as tq >>> tq.init() - >>> - >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) - >>> tq.set_custom_meta(batch_meta) + >>> keys = ["sample_1", "sample_2", "sample_3"] + >>> fields = TensorDict({ + ... "input_ids": torch.randn(3, 10), + ... "attention_mask": torch.ones(3, 10), + ... }, batch_size=3) + >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] + >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) """ + + if fields is None and tags is None: + raise ValueError("Please provide at least one parameter of `fields` or `tags`.") + + if fields is not None and fields.batch_size[0] != len(keys): + raise ValueError( + f"`keys` with length {len(keys)} does not match the `fields` TensorDict with " + f"batch_size {fields.batch_size[0]}" + ) + tq_client = _maybe_create_transferqueue_client() - return tq_client.set_custom_meta(metadata) + # 1. translate user-specified key to BatchMeta + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) -async def async_set_custom_meta( - metadata: BatchMeta, -) -> None: - """ - Asynchronously send custom metadata to the controller. + if batch_meta.size != len(keys): + raise RuntimeError( + f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!" + ) + + # 2. register the user-specified tags to BatchMeta + if tags is not None: + if len(tags) != len(keys): + raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") + batch_meta.update_custom_meta(tags) + + # 3. put data + if fields is not None: + await tq_client.async_put(fields, batch_meta) + else: + # directly update custom_meta (tags) to controller + await tq_client.async_set_custom_meta(batch_meta) + + +async def async_kv_batch_get( + keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None +) -> TensorDict: + """Asynchronously get data from TransferQueue using user-specified keys. - This method sends per-sample custom metadata (custom_meta) to the controller. - The custom_meta is stored in the controller and can be retrieved along with - the BatchMeta in subsequent get_meta calls. + This is a convenience method for retrieving data using keys instead of indexes. Args: - metadata: BatchMeta containing the samples and their custom metadata to store. 
- The custom_meta should be set using BatchMeta.update_custom_meta() or - BatchMeta.set_custom_meta() before calling this method. - socket: ZMQ async socket for message transmission (injected by decorator) + keys: Single key or list of keys to retrieve + partition_id: Partition containing the keys + fields: Optional field(s) to retrieve. If None, retrieves all fields + + Returns: + TensorDict with the requested data Raises: - RuntimeError: If communication fails or controller returns error response + RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq >>> tq.init() - >>> - >>> # Create batch with custom metadata - >>> batch_meta = tq.get_meta(data_fields=["input_ids"], batch_size=4, ...) - >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) - >>> asyncio.run(tq.async_set_custom_meta(batch_meta)) + >>> # Get single key with all fields + >>> data = await tq.async_kv_batch_get(keys="sample_1", partition_id="train") + >>> # Get multiple keys with specific fields + >>> data = await tq.async_kv_batch_get( + ... keys=["sample_1", "sample_2"], + ... partition_id="train", + ... fields="input_ids" + ... ) """ tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_set_custom_meta(metadata) + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) -def clear_samples(metadata: BatchMeta): - """Synchronously clear specific samples from all storage units and the controller. + if batch_meta.size == 0: + raise RuntimeError("keys or partition were not found!") - Args: - metadata: The BatchMeta of the corresponding data to be cleared + if fields is not None: + if isinstance(fields, str): + fields = [fields] + batch_meta = batch_meta.select_fields(fields) - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return tq_client.clear_samples(metadata) + if not batch_meta.is_ready: + raise RuntimeError("Some fields are not ready in all the requested keys!") + data = await tq_client.async_get_data(batch_meta) + return data -async def async_clear_samples(metadata: BatchMeta): - """Asynchronously clear specific samples from all storage units and the controller. - Args: - metadata: The BatchMeta of the corresponding data to be cleared +async def async_kv_list(partition_id: Optional[str] = None) -> dict[str, dict[str, Any]]: + """Asynchronously list all keys and their metadata in one or all partitions. - Raises: - RuntimeError: If clear operation fails - """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_clear_samples(metadata) + Args: + partition_id: The specific partition_id to query. + If None (default), returns keys from all partitions. + Returns: + A nested dictionary mapping partition IDs to their keys and metadata. -def clear_partition(partition_id: str): - """Synchronously clear the whole partition from all storage units and the controller. + Structure: + { + "partition_id": { + "key_name": { + "tag1": , + ... (other metadata) + }, + ..., + }, + ... 
+ } - Args: - partition_id: The partition id to clear data for - Raises: - RuntimeError: If clear operation fails + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # Case 1: Retrieve a specific partition + >>> partitions = await tq.async_kv_list(partition_id="train") + >>> print(f"Keys: {list(partitions['train'].keys())}") + >>> print(f"Tags: {list(partitions['train'].values())}") + >>> # Case 2: Retrieve all partitions + >>> all_partitions = await tq.async_kv_list() + >>> for pid, keys in all_partitions.items(): + >>> print(f"Partition: {pid}, Key count: {len(keys)}") """ tq_client = _maybe_create_transferqueue_client() - return tq_client.clear_partition(partition_id) + + partition_info = await tq_client.async_kv_list(partition_id) + + return partition_info -async def async_clear_partition(partition_id: str): - """Asynchronously clear the whole partition from all storage units and the controller. +async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: + """Asynchronously clear key-value pairs from TransferQueue. + + This removes the specified keys and their associated data from both + the controller and storage units. Args: - partition_id: The partition id to clear data for + keys: Single key or list of keys to clear + partition_id: Partition containing the keys - Raises: - RuntimeError: If clear operation fails + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # Clear single key + >>> await tq.async_kv_clear(keys="sample_1", partition_id="train") + >>> # Clear multiple keys + >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train") """ - tq_client = _maybe_create_transferqueue_client() - return await tq_client.async_clear_partition(partition_id) + if isinstance(keys, str): + keys = [keys] -def close(): - """Close the TransferQueue system.""" - global _TRANSFER_QUEUE_CLIENT - global _TRANSFER_QUEUE_STORAGE - if _TRANSFER_QUEUE_CLIENT: - _TRANSFER_QUEUE_CLIENT.close() - _TRANSFER_QUEUE_CLIENT = None + tq_client = _maybe_create_transferqueue_client() + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) - try: - if _TRANSFER_QUEUE_STORAGE: - # only the process that do first-time init can clean the distributed storage - for storage in _TRANSFER_QUEUE_STORAGE.values(): - ray.kill(storage) - _TRANSFER_QUEUE_STORAGE = None - except Exception: - pass + if batch_meta.size > 0: + await tq_client.async_clear_samples(batch_meta) - try: - controller = ray.get_actor("TransferQueueController") - ray.kill(controller) - except Exception: - pass + +# ==================== Low-Level Native API ==================== +# For low-level API support, please refer to transfer_queue/client.py for details. +def get_client(): + """Get a TransferQueueClient for using low-level API""" + if _TRANSFER_QUEUE_CLIENT is None: + raise RuntimeError("Please initialize the TransferQueue first by calling `tq.init()`!") + return _TRANSFER_QUEUE_CLIENT diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 5e25404..056a146 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -303,60 +303,43 @@ def clear_extra_info(self) -> None: """ self.extra_info.clear() - def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: + def get_all_custom_meta(self) -> list[dict[str, Any]]: """ - Set custom_meta for a specific sample by global_index. 
- - Custom metadata is user-defined per-sample metadata that can be stored - and retrieved along with the BatchMeta. - - Args: - global_index: The global_index of the sample to set custom meta for - meta_dict: Dictionary containing custom metadata for the sample - - Raises: - ValueError: If the key is not in global_indexes - """ - - if global_index not in self.global_indexes: - raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") - - self.custom_meta[global_index] = copy.deepcopy(meta_dict) - - def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: - """ - Get all custom_meta as a dictionary. + Get all custom_meta as a list of dictionary. Returns: - A deep copy of the custom_meta dictionary + A deep copy of the custom_meta list """ - return copy.deepcopy(self.custom_meta) + custom_meta = [self.custom_meta.get(i, {}) for i in self.global_indexes] + return copy.deepcopy(custom_meta) - def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): + def update_custom_meta(self, custom_meta: list[dict[str, Any]]): """ - Update custom_meta with a dictionary of new metadata. + Update custom_meta with a list of dictionary of custom metadata. This method updates the custom_meta dictionary with the provided metadata. Existing keys will be overwritten with new values. Args: - new_meta: Dictionary of new metadata + custom_meta: list of custom_meta dictionary Raises: - ValueError: If any key in new_meta is not in global_indexes + ValueError: If the length of custom_meta cannot match the batch size """ - if new_meta is None: + if custom_meta is None: return - non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) - if non_exist_global_indexes: + if len(custom_meta) != self.size: raise ValueError( - f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} " - f"do not exist in this batch." 
+ f"The length of custom_meta list {len(custom_meta)} must match the batch size: {self.size}" ) - self.custom_meta.update(new_meta) + custom_meta_dict: dict[int, dict[str, Any]] = { + self.global_indexes[i]: custom_meta[i] for i in range(len(custom_meta)) + } + + self.custom_meta.update(custom_meta_dict) def clear_custom_meta(self) -> None: """ @@ -846,3 +829,211 @@ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> ] return all_fields + + +# ==================== KV Interface Metadata ==================== +@dataclass +class KVBatchMeta: + """Records the metadata for KV interface.""" + + # keys of each sample + keys: list[str] = dataclasses.field(default_factory=list) + + # sample-level tags + tags: list[dict] = dataclasses.field(default_factory=list) + + # [optional] partition_id of this batch + partition_id: Optional[str] = None + + # [optional] fields of each sample + fields: Optional[list[str]] = None + + # [optional] external information for batch-level information + extra_info: Optional[dict[str, Any]] = None + + def __post_init__(self): + """Validate all the variables""" + if len(self.keys) != len(self.tags): + raise ValueError(f"keys and tags must have same length, but got {len(self.keys)} and {len(self.tags)}") + if len(self.keys) != len(set(self.keys)): + raise ValueError("Got duplicated keys.") + if self.fields is not None: + if len(self.fields) != len(set(self.fields)): + raise ValueError("Got duplicated fields.") + + # deepcopy to prevent unexpected behavior after chunk/concat + self.tags = copy.deepcopy(self.tags) + self.extra_info = copy.deepcopy(self.extra_info) + + object.__setattr__(self, "_size", len(self.keys)) + + @property + def size(self) -> int: + """Return the number of samples in this batch""" + return getattr(self, "_size", 0) + + def __len__(self) -> int: + """Return the number of samples in this batch.""" + return len(self.keys) + + def __str__(self): + return f"KVBatchMeta(size={self.size}, field_names={self.fields}, extra_info={self.extra_info})" + + def select_keys(self, keys_to_select: list[str]) -> "KVBatchMeta": + """ + Select specific keys from this batch. + This will construct a new KVBatchMeta instance containing only the specified keys. + + Args: + keys_to_select (list[str]): List of keys to retain. + + Returns: + KVBatchMeta: A new KVBatchMeta instance containing only the specified keys. + + Raises: + ValueError: If duplicate keys exist in input param `keys_to_select`. + RuntimeError: If `keys_to_select` contains keys that do not exist in this batch. + """ + + if len(set(keys_to_select)) != len(keys_to_select): + raise ValueError("Contain duplicate keys.") + + non_exist_keys = set(keys_to_select) - set(self.keys) + if len(non_exist_keys) > 0: + raise RuntimeError(f"Keys {non_exist_keys} not found in current batch.") + + _keys_to_idx = {key: idx for idx, key in enumerate(self.keys)} + + loc_idx = [_keys_to_idx[k] for k in keys_to_select] + tags = [self.tags[i] for i in loc_idx] + + return KVBatchMeta( + keys=keys_to_select, + tags=tags, + partition_id=self.partition_id, + fields=self.fields, + extra_info=self.extra_info, + ) + + def reorder(self, indexes: list[int]): + """ + Reorder the samples in this batch according to the specified indexes. + + The operation is performed in-place. + + Args: + indexes : list[int] + A list of integers specifying the new order of SampleMeta. + + Raises: + ValueError: If the size of input `indexes` does not match with the batch size. 
+ ValueError: If duplicate indexes exist in input param `indexes`. + """ + if len(indexes) != self.size: + raise ValueError( + f"Attempted to reorder with indexes length {len(indexes)} that does not match " + f"the batch size {self.size}." + ) + + if len(set(indexes)) != len(indexes): + raise ValueError("Contain duplicate indexes.") + + self.keys = [self.keys[i] for i in indexes] + self.tags = [self.tags[i] for i in indexes] + + def chunk(self, chunks: int) -> list["KVBatchMeta"]: + """ + Split this batch into smaller chunks. + + Args: + chunks: number of chunks + + Return: + List of smaller KVBatchMeta chunks + """ + + chunk_list = [] + if self.size < chunks: + logger.warning( + f"Chunk size {chunks} > number of samples in this batch {self.size}, this will return some " + f"empty KVBatchMeta chunks." + ) + + # Calculate the base size and remainder of each chunk + base_size = self.size // chunks + remainder = self.size % chunks + + start = 0 + for i in range(chunks): + # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) + current_chunk_size = base_size + 1 if i < remainder else base_size + end = start + current_chunk_size + chunk_keys = self.keys[start:end] + chunk_tags = self.tags[start:end] + + chunk = KVBatchMeta( + keys=chunk_keys, + tags=chunk_tags, + partition_id=self.partition_id, + fields=self.fields, + extra_info=self.extra_info, + ) + chunk_list.append(chunk) + start = end + + return chunk_list + + @classmethod + def concat(cls, data: list["KVBatchMeta"]) -> "KVBatchMeta": + """ + Concatenate multiple KVBatchMeta chunks into one large batch. + + Args: + data: List of KVBatchMeta chunks to concatenate + + Returns: + Concatenated KVBatchMeta + + Raises: + ValueError: If validation fails (e.g., field names do not match) + """ + if not data: + logger.warning("Try to concat empty KVBatchMeta chunks. Returning empty KVBatchMeta.") + return KVBatchMeta() + + # skip empty chunks + data = [chunk for chunk in data if chunk and chunk.size > 0] + + if len(data) == 0: + logger.warning("No valid KVBatchMeta chunks to concatenate. 
Returning empty KVBatchMeta.") + return KVBatchMeta() + + base_fields = data[0].fields + if base_fields is not None: + base_fields_set = set(base_fields) + else: + base_fields_set = set() + + base_partition_id = data[0].partition_id + + all_keys = [] + all_tags = [] + all_extra_info = {} + for chunk in data: + if chunk.fields is not None and set(chunk.fields) != base_fields_set: + raise ValueError("Field names do not match for concatenation.") + if chunk.partition_id != base_partition_id: + raise ValueError("Partition do not match for concatenation.") + + all_keys.extend(chunk.keys) + all_tags.extend(chunk.tags) + if chunk.extra_info is not None: + all_extra_info.update(chunk.extra_info) + + return KVBatchMeta( + keys=all_keys, + tags=all_tags, + partition_id=base_partition_id, + fields=base_fields, + extra_info=all_extra_info if all_extra_info else None, + ) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 352b3ad..c0c5058 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -177,7 +177,7 @@ def _send_handshake_requests(self) -> None: """Send handshake request to controller.""" assert self.controller_handshake_socket is not None, "controller_handshake_socket is not properly initialized" request_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE, + request_type=ZMQRequestType.HANDSHAKE, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "storage_manager_id": self.storage_manager_id, @@ -225,7 +225,7 @@ async def notify_data_update( poller.register(self.data_status_update_socket, zmq.POLLIN) request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "partition_id": partition_id, @@ -245,7 +245,7 @@ async def notify_data_update( ) except Exception as e: request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type] sender_id=self.storage_manager_id, body={ "message": f"Failed to notify data status update information from " @@ -556,7 +556,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: loop = asyncio.get_event_loop() # put to storage backends - custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) + custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) per_field_dtypes: dict[int, dict[str, Any]] = {} per_field_shapes: dict[int, dict[str, Any]] = {} @@ -577,24 +577,28 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None ) - # Prepare per-field custom_meta if available - per_field_custom_meta: dict[int, dict[str, Any]] = {} - if custom_meta: - if len(custom_meta) != len(keys): - raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})") + # Prepare per-field custom_backend_meta if available + per_field_custom_backend_meta: dict[int, dict[str, Any]] = {} + if custom_backend_meta: + if len(custom_backend_meta) != len(keys): + raise ValueError( + f"Length of custom_backend_meta ({len(custom_backend_meta)}) does not match expected ({len(keys)})" + ) # custom meta is a flat list aligned with keys/values # Use itertools.product to eliminate nested loops for global_idx in 
metadata.global_indexes: - per_field_custom_meta[global_idx] = {} + per_field_custom_backend_meta[global_idx] = {} - # TODO(tianyi): the order of custom meta is coupled with keys/values + # FIXME(tianyi): the order of custom backend meta is coupled with keys/values + # FIXME: if put_data is called to partially update/add new fields, the current + # implementation will cause custom_backend_meta losses or mismatch! for (field_name, global_idx), meta_value in zip( itertools.product(sorted(metadata.field_names), metadata.global_indexes), - custom_meta, + custom_backend_meta, strict=True, ): - per_field_custom_meta[global_idx][field_name] = meta_value - metadata.update_custom_meta(per_field_custom_meta) + per_field_custom_backend_meta[global_idx][field_name] = meta_value + metadata._custom_backend_meta.update(per_field_custom_backend_meta) # Get current data partition id # Note: Currently we only support putting to & getting data from a single data partition simultaneously, @@ -607,7 +611,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: metadata.global_indexes, per_field_dtypes, per_field_shapes, - per_field_custom_meta, + per_field_custom_backend_meta, ) async def get_data(self, metadata: BatchMeta) -> TensorDict: diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index bf711c8..eaaf65e 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -101,6 +101,12 @@ class ZMQRequestType(ExplicitEnum): NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK" NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR" + # KV_INTERFACE + KV_RETRIEVE_KEYS = "KV_RETRIEVE_KEYS" + KV_RETRIEVE_KEYS_RESPONSE = "KV_RETRIEVE_KEYS_RESPONSE" + KV_LIST = "KV_LIST" + KV_LIST_RESPONSE = "KV_LIST_RESPONSE" + class ZMQServerInfo: """ diff --git a/tutorial/01_core_components.py b/tutorial/01_core_components.py index 59159a6..9a66b8e 100644 --- a/tutorial/01_core_components.py +++ b/tutorial/01_core_components.py @@ -63,6 +63,7 @@ def demonstrate_data_workflow(): # Step 1: Put data print("[Step 1] Putting data into TransferQueue...") + tq_client = tq.get_client() input_ids = torch.tensor( [ @@ -84,12 +85,12 @@ def demonstrate_data_workflow(): print(f" Created {data_batch.batch_size[0]} samples") partition_id = "tutorial_partition_0" - tq.put(data=data_batch, partition_id=partition_id) + tq_client.put(data=data_batch, partition_id=partition_id) print(f" βœ“ Data written to partition: {partition_id}") # Step 2: Get metadata print("[Step 2] Requesting data metadata...") - batch_meta = tq.get_meta( + batch_meta = tq_client.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=data_batch.batch_size[0], partition_id=partition_id, @@ -100,7 +101,7 @@ def demonstrate_data_workflow(): # Step 3: Get actual data print("[Step 3] Retrieving actual data...") - retrieved_data = tq.get_data(batch_meta) + retrieved_data = tq_client.get_data(batch_meta) print(" βœ“ Data retrieved successfully") print(f" Keys: {list(retrieved_data.keys())}") @@ -112,7 +113,7 @@ def demonstrate_data_workflow(): # Step 5: Clear print("[Step 5] Clearing partition... 
(you may also use clear_samples() to clear specific samples)") - tq.clear_partition(partition_id=partition_id) + tq_client.clear_partition(partition_id=partition_id) print(" βœ“ Partition cleared") diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py new file mode 100644 index 0000000..376ebfb --- /dev/null +++ b/tutorial/02_kv_interface.py @@ -0,0 +1,259 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import textwrap +import warnings +from pathlib import Path + +warnings.filterwarnings( + action="ignore", + message=r"The PyTorch API of nested tensors is in prototype stage*", + category=UserWarning, + module=r"torch\.nested", +) + +warnings.filterwarnings( + action="ignore", + message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible " + r"devices env var if num_gpus=0 or num_gpus=None.*", + category=FutureWarning, + module=r"ray\._private\.worker", +) + + +import ray # noqa: E402 +import torch # noqa: E402 +from tensordict import TensorDict # noqa: E402 + +# Add the parent directory to the path +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +import transfer_queue as tq # noqa: E402 + +# Configure Ray +os.environ["RAY_DEDUP_LOGS"] = "0" +os.environ["RAY_DEBUG"] = "1" + +if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + +def demonstrate_kv_api(): + """ + Demonstrate the Key-Value (KV) semantic API: + kv_put & kv_batch_put -> kv_list -> kv_batch_get -> kv_clear + """ + print("=" * 80) + print("Key-Value Semantic API Demo: kv_put/kv_batch_put β†’ kv_list β†’ kv_batch_get β†’ kv_clear") + print("=" * 80) + + # Step 1: Put a single key-value pair with kv_put + print("[Step 1] Putting a single sample with kv_put...") + + # Define the data content (The "Value") + input_ids = torch.tensor([[1, 2, 3]]) + attention_mask = torch.ones(input_ids.size()) + + single_sample = TensorDict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + batch_size=input_ids.size(0), + ) + + partition_id = "Train" + # Use a meaningful string key instead of an auto-increment integer + key = "0_0" # User-defined key: "{uid}_{session_id}" + tag = {"global_steps": 0, "status": "running", "model_version": 0} + + print(f" Inserting Key: {key}") + print(f" Fields (Columns): {list(single_sample.keys())}") + print(f" Tag (Metadata): {tag}") + + tq.kv_put(key=key, partition_id=partition_id, fields=single_sample, tag=tag) + print(" βœ“ kv_put success.") + + # Step 2: Put multiple key-value pairs with kv_batch_put + print("\n[Step 2] Putting batch data with kv_batch_put...") + + batch_input_ids = torch.tensor( + [ + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + [13, 14, 15], + ] + ) + batch_attention_mask = torch.ones_like(batch_input_ids) + + data_batch = TensorDict( + { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + }, + 
batch_size=batch_input_ids.size(0), + ) + + keys = ["1_0", "1_1", "1_2", "2_0"] # 4 keys for 4 samples + tags = [{"global_steps": 1, "status": "running", "model_version": 1} for _ in range(len(keys))] + + print(f" Inserting batch of {len(keys)} samples.") + print(f" Fields (Columns): {list(data_batch.keys())}") + print(f" Tags (Metadata): {tags}") + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=data_batch, tags=tags) + print(" βœ“ kv_batch_put success.") + + # Step 3: Append additional fields to existing samples + print("\n[Step 3] Appending new fields (Columns) to existing samples...") + + batch_response = torch.tensor( + [ + [4, 5, 6], + [7, 8, 9], + ] + ) + response_batch = TensorDict( + { + "response": batch_response, + }, + batch_size=batch_response.size(0), + ) + + # We only update a subset of keys + append_keys = ["1_1", "2_0"] # Appending to existing samples + append_tags = [{"global_steps": 1, "status": "finish", "model_version": 1} for _ in range(len(append_keys))] + print(f" Target Keys: {append_keys}") + print(" New field to add: 'response'") + print(f" The updated tags are: {append_tags}") + tq.kv_batch_put(keys=append_keys, partition_id=partition_id, fields=response_batch, tags=append_tags) + print(" βœ“ Update success: Samples '1_1' and '2_0' now contain {input_ids, attention_mask, response}.") + + # Step 4: Only update tags through kv_put + print("\n[Step 4] Updating existing tags without providing a value...") + key_for_update_tags = "0_0" + tag_update = {"global_steps": 0, "status": "finish", "model_version": 0} + print(f" Target Key: {key_for_update_tags}") + print(f" The updated tag is: {tag_update}") + tq.kv_put(key=key_for_update_tags, partition_id=partition_id, fields=None, tag=tag_update) + print(f" βœ“ Update success: Sample '0_0' now has tag {tag_update}.") + + # Step 5: List all keys and tags in a partition + print("\n[Step 5] Listing all keys and tags in partition...") + + partition_info = tq.kv_list() + print(f" Found {len(partition_info.keys())} partitions: {list(partition_info.keys())}") + for pid, keys_and_tags in partition_info.items(): + for k, t in keys_and_tags.items(): + print(f" Partition: {pid} | key='{k}' | tag={t}") + + # Step 6: Retrieve specific fields using kv_batch_get + print("\n[Step 6] Retrieving specific fields (Columns) with kv_batch_get...") + print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').") + + all_keys = list(partition_info[partition_id].keys()) + retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, fields="input_ids") + print(f" βœ“ Successfully retrieved only the {list(retrieved_input_ids.keys())} field for all samples.") + + # Step 7: Retrieve all fields using kv_batch_get + print("\n[Step 7] Retrieving all fields with kv_batch_get...") + retrieved_all = tq.kv_batch_get(keys=all_keys, partition_id=partition_id) + print(f" Retrieved all fields for {all_keys}:") + print(f" Fields: {list(retrieved_all.keys())}") + print( + f" Note: We cannot retrieve fields {list(response_batch.keys())}, since they are only available in {append_keys}" + ) + + # Step 8: Clear specific keys + print("\n[Step 8] Clearing keys from partition...") + keys_to_clear = all_keys[:2] # Delete the first 2 keys + tq.kv_clear(keys=keys_to_clear, partition_id=partition_id) + print(f" βœ“ Cleared keys: {keys_to_clear}") + + partition_info_after_clear = tq.kv_list(partition_id=partition_id) + print(f" Remaining keys in partition: {list(partition_info_after_clear[partition_id].keys())}") 
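+
+
+# A minimal illustrative sketch (not exercised by the demo above; the helper name is hypothetical):
+# since the KV interface has no built-in production/consumption tracking, the tags returned by
+# kv_list() can drive consumption manually. It reuses only the calls shown in demonstrate_kv_api()
+# and assumes the same "Train" partition, the {"status": ...} tag schema, and that every key tagged
+# with status "finish" already carries a 'response' field.
+def fetch_finished_responses(partition_id="Train"):
+    """Select keys whose tag marks them as finished, then fetch only their 'response' field."""
+    keys_and_tags = tq.kv_list(partition_id=partition_id).get(partition_id, {})
+    finished_keys = [k for k, t in keys_and_tags.items() if t is not None and t.get("status") == "finish"]
+    if not finished_keys:
+        return None
+    # Column selection: only the 'response' field is transferred over the network.
+    return tq.kv_batch_get(keys=finished_keys, partition_id=partition_id, fields="response")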
+ + +def main(): + print("=" * 80) + print( + textwrap.dedent( + """ + TransferQueue Tutorial 2: Key-Value (KV) Semantic API + + This tutorial demonstrates the KV semantic API, which provides a simple + interface for data storage and retrieval using user-defined string keys. + + Key Methods: + 1. (async_)kv_put - Insert/Update a multi-column sample by key, with an optional metadata tag + 2. (async_)kv_batch_put - Put multiple key-value pairs efficiently in batch + 3. (async_)kv_batch_get - Retrieve samples (by keys), supporting column selection (by fields) + 4. (async_)kv_list - List keys and tags (metadata) in a partition + 5. (async_)kv_clear - Remove key-value pairs from storage + + Key Features: + βœ“ Redis-style Semantics - Familiar KV interface (Put/Get/List) for zero learning curve + βœ“ Fine-grained Access - Update or retrieve specific fields (columns) within a key (row) without a full-row operation + βœ“ Partition Isolation - Logical separation of storage namespaces + βœ“ Metadata Tags - Lightweight metadata for status tracking + βœ“ Pluggable Backends - Supports multiple backends + + Use Cases: + - Fine-grained data access where extreme streaming performance is not essential + - Integration with an external ReplayBuffer/single-controller that manages sample dispatching + + Limitations (vs low-level native APIs): + - No built-in production/consumption tracking: Users must check status manually via tags. + - No built-in Sampler support: Data dispatch must be implemented externally by a ReplayBuffer or single-controller. + - Not fully streaming: Consumers must wait for the single-controller to dispatch `keys`. + """ + ) + ) + print("=" * 80) + + try: + print("Setting up TransferQueue...") + tq.init() + + print("\nDemonstrating the KV semantic API...") + demonstrate_kv_api() + + print("\n" + "=" * 80) + print("Tutorial Complete!") + print("=" * 80) + print("\nKey Takeaways:") + print("  1. The KV API simplifies data access with Redis-style semantics") + print("  2. Use the 'fields' parameter to get/put specific fields only") + print("  3. Tags enable custom metadata for production status, scores, etc.") + print("  4. 
Use kv_list to inspect partition contents") + + # Cleanup + tq.close() + ray.shutdown() + print("\nCleanup complete") + + except Exception as e: + print(f"Error during tutorial: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tutorial/02_metadata_concepts.py b/tutorial/03_metadata_concepts.py similarity index 93% rename from tutorial/02_metadata_concepts.py rename to tutorial/03_metadata_concepts.py index d864db3..2c81941 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -205,10 +205,15 @@ def demonstrate_batch_meta(): print(f"βœ“ Extra info: {batch.get_all_extra_info()}") print("[Example 3] Adding sample-level information through custom_meta...") - batch.set_custom_meta( - global_index=0, meta_dict={"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"} + batch.update_custom_meta( + [ + {"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@2", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@3", "session_id": "session@0", "model_version": "epoch@0"}, + {"uid": "prompt@4", "session_id": "session@0", "model_version": "epoch@0"}, + ] ) - batch.update_custom_meta({1: {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}}) print(f"βœ“ Custom meta: {batch.get_all_custom_meta()}") # Example 4: Chunk a batch @@ -300,11 +305,13 @@ def demonstrate_real_workflow(): # Initialize Ray if not ray.is_initialized(): - ray.init() + ray.init(namespace="TransferQueueTutorial") # Initialize TransferQueue tq.init() + tq_client = tq.get_client() + print("[Step 1] Putting data into TransferQueue...") input_ids = torch.randint(0, 1000, (8, 512)) attention_mask = torch.ones(8, 512) @@ -318,23 +325,23 @@ def demonstrate_real_workflow(): ) partition_id = "demo_partition" - batch_meta = tq.put(data=data_batch, partition_id=partition_id) + batch_meta = tq_client.put(data=data_batch, partition_id=partition_id) print(f"βœ“ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.") print("[Step 2] [Optional] Setting sample-level custom_meta...") - custom_meta = { - global_index: {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0} - for global_index in batch_meta.global_indexes - } + custom_meta = [ + {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0} + for _ in range(batch_meta.size) + ] batch_meta.update_custom_meta(custom_meta) print(f"βœ“ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}") - tq.set_custom_meta(batch_meta) + tq_client.set_custom_meta(batch_meta) print("βœ“ Successful to store custom_meta into TQ controller. 
Now you can retrieve the custom_meta from anywhere.") print("[Step 3] Try to get metadata from TransferQueue from other places...") - batch_meta = tq.get_meta( + batch_meta = tq_client.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=8, partition_id=partition_id, @@ -353,7 +360,7 @@ def demonstrate_real_workflow(): print("βœ“ Selected 'input_ids' field only:") print(f" New field names: {selected_meta.field_names}") print(f" Samples still have same global indexes: {selected_meta.global_indexes}") - retrieved_data = tq.get_data(selected_meta) + retrieved_data = tq_client.get_data(selected_meta) print(f" Retrieved data keys: {list(retrieved_data.keys())}") print("[Step 5] Select specific samples from the retrieved BatchMeta...") @@ -361,7 +368,7 @@ def demonstrate_real_workflow(): print("βœ“ Selected samples at indices [0, 2, 4, 6]:") print(f" New global indexes: {partial_meta.global_indexes}") print(f" Number of samples: {len(partial_meta)}") - retrieved_data = tq.get_data(partial_meta) + retrieved_data = tq_client.get_data(partial_meta) print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}") print("[Step 6] Demonstrate chunk operation...") @@ -369,11 +376,11 @@ def demonstrate_real_workflow(): print(f"βœ“ Chunked into {len(chunks)} parts:") for i, chunk in enumerate(chunks): print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") - chunk_data = tq.get_data(chunk) + chunk_data = tq_client.get_data(chunk) print(f" Chunk {i}: Retrieved chunk data: {chunk_data}") # Cleanup - tq.clear_partition(partition_id=partition_id) + tq_client.clear_partition(partition_id=partition_id) tq.close() ray.shutdown() print("βœ“ Partition cleared and resources cleaned up") @@ -385,7 +392,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 2: Metadata System + TransferQueue Tutorial 3: Metadata System This script introduces the metadata system in TransferQueue, which tracks the structure and state of data: diff --git a/tutorial/03_understanding_controller.py b/tutorial/04_understanding_controller.py similarity index 86% rename from tutorial/03_understanding_controller.py rename to tutorial/04_understanding_controller.py index 4ca426d..746de14 100644 --- a/tutorial/03_understanding_controller.py +++ b/tutorial/04_understanding_controller.py @@ -62,6 +62,8 @@ def demonstrate_partition_isolation(): tq.init() + tq_client = tq.get_client() + # Partition 1: Training data print("\n[Partition 1] Putting training data...") train_data = TensorDict( @@ -71,7 +73,7 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - tq.put(data=train_data, partition_id="train") + tq_client.put(data=train_data, partition_id="train") print(" βœ“ Training data added to 'train' partition") # Partition 2: Validation data @@ -83,23 +85,25 @@ def demonstrate_partition_isolation(): }, batch_size=2, ) - tq.put(data=val_data, partition_id="val") + tq_client.put(data=val_data, partition_id="val") print(" βœ“ Validation data added to 'val' partition") # Get from train partition print("\n[Retrieving from 'train' partition]") - train_meta = tq.get_meta( + train_meta = tq_client.get_meta( data_fields=["input_ids", "labels"], batch_size=2, partition_id="train", task_name="train_task" ) - retrieved_train_data = tq.get_data(train_meta) + retrieved_train_data = tq_client.get_data(train_meta) print(f" βœ“ Got BatchMeta={train_meta} from partition 'train'") print(f" βœ“ Retrieved Data: input_ids={retrieved_train_data['input_ids']}, 
labels={retrieved_train_data['labels']}") # Get from val partition print("\n[Retrieving from 'val' partition]") - val_meta = tq.get_meta(data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task") - retrieved_val_data = tq.get_data(val_meta) + val_meta = tq_client.get_meta( + data_fields=["input_ids", "labels"], batch_size=2, partition_id="val", task_name="val_task" + ) + retrieved_val_data = tq_client.get_data(val_meta) print(f" βœ“ Got BatchMeta={val_meta} from partition 'val'") print(f" βœ“ Retrieved Data: input_ids={retrieved_val_data['input_ids']}, labels={retrieved_val_data['labels']}") @@ -107,8 +111,8 @@ def demonstrate_partition_isolation(): print(" βœ“ Data isolation: 'train' and 'val' partitions are completely independent") # Cleanup - tq.clear_partition(partition_id="train") - tq.clear_partition(partition_id="val") + tq_client.clear_partition(partition_id="train") + tq_client.clear_partition(partition_id="val") tq.close() ray.shutdown() @@ -126,6 +130,8 @@ def demonstrate_dynamic_expansion(): tq.init() + tq_client = tq.get_client() + # Add first batch with 2 samples, 2 fields print("\n[Step 1] Adding initial data (2 samples, 2 fields)...") data1 = TensorDict( @@ -135,7 +141,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta1 = tq.put(data=data1, partition_id="dynamic") + meta1 = tq_client.put(data=data1, partition_id="dynamic") print(" βœ“ Added 2 samples") print(f" βœ“ Got BatchMeta: {meta1} samples") @@ -148,9 +154,9 @@ def demonstrate_dynamic_expansion(): }, batch_size=3, ) - meta2 = tq.put(data=data2, partition_id="dynamic") + meta2 = tq_client.put(data=data2, partition_id="dynamic") - all_meta = tq.get_meta( + all_meta = tq_client.get_meta( data_fields=["field1", "field2"], batch_size=5, partition_id="dynamic", task_name="dynamic_task" ) print(" βœ“ Added 3 more samples (total: 5)") @@ -165,7 +171,7 @@ def demonstrate_dynamic_expansion(): }, batch_size=2, ) - meta3 = tq.put(data=data3, metadata=meta1) + meta3 = tq_client.put(data=data3, metadata=meta1) print(" βœ“ Added 2 samples with new field 'field3'") print(f" βœ“ Got BatchMeta: {meta3} for newly put data with new field") @@ -174,7 +180,7 @@ def demonstrate_dynamic_expansion(): print(" βœ“ Columns auto-expand: Can add new fields anytime") # Cleanup - tq.clear_partition(partition_id="dynamic") + tq_client.clear_partition(partition_id="dynamic") tq.close() ray.shutdown() @@ -190,6 +196,8 @@ def demonstrate_default_consumption_sample_strategy(): tq.init() + tq_client = tq.get_client() + # Add 6 samples print("\n[Setup] Adding 6 samples...") all_data = TensorDict( @@ -198,22 +206,22 @@ def demonstrate_default_consumption_sample_strategy(): }, batch_size=6, ) - tq.put(data=all_data, partition_id="sampling") + tq_client.put(data=all_data, partition_id="sampling") print(" βœ“ 6 samples added") # First get - should get samples 0,1,2 print("\n[Task A, Get 1] Requesting 3 samples...") - meta1 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta1 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" βœ“ Got samples: {meta1.global_indexes}") # Second get - should get samples 3,4,5 (no duplicates!) 
print("\n[Task A, Get 2] Requesting 3 more samples...") - meta2 = tq.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") + meta2 = tq_client.get_meta(data_fields=["data"], batch_size=3, partition_id="sampling", task_name="A") print(f" βœ“ Got samples: {meta2.global_indexes}") # Third get - should get samples 0,1 print("\n[Task B, Get 1] Requesting 2 samples...") - meta3 = tq.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") + meta3 = tq_client.get_meta(data_fields=["data"], batch_size=2, partition_id="sampling", task_name="B") print(f" βœ“ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -224,7 +232,7 @@ def demonstrate_default_consumption_sample_strategy(): print(" βœ“ Third get (Task B): samples 0,1") # Cleanup - tq.clear_partition(partition_id="sampling") + tq_client.clear_partition(partition_id="sampling") tq.close() ray.shutdown() @@ -235,7 +243,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 3: Understanding TransferQueueController + TransferQueue Tutorial 4: Understanding TransferQueueController This script demonstrates TransferQueueController's key features: diff --git a/tutorial/04_custom_sampler.py b/tutorial/05_custom_sampler.py similarity index 90% rename from tutorial/04_custom_sampler.py rename to tutorial/05_custom_sampler.py index c35b4e0..e30487b 100644 --- a/tutorial/04_custom_sampler.py +++ b/tutorial/05_custom_sampler.py @@ -186,6 +186,8 @@ def demonstrate_random_sampler_with_replacement(): sampler = RandomSamplerWithReplacement() setup_transfer_queue_with_sampler(sampler) + tq_client = tq.get_client() + # Add 5 samples print("\n[Step 1] Adding 5 samples...") data = TensorDict( @@ -194,22 +196,22 @@ def demonstrate_random_sampler_with_replacement(): }, batch_size=5, ) - tq.put(data=data, partition_id="test") + tq_client.put(data=data, partition_id="test") print(" βœ“ 5 samples added") # Get batch 1 (should get 2 random samples) print("\n[Step 2] Get batch 1 (2 samples)...") - meta1 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta1 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" βœ“ Got samples: {meta1.global_indexes}") # Get batch 2 (should get 1 random sample with replacement - may have duplicate with previous batch!) print("\n[Step 3] Get batch 2 (1 sample)...") - meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" βœ“ Got samples: {meta2.global_indexes}") # Get batch 3 (should get 2 random samples with replacement - may have duplicate with previous batches!) 
print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" βœ“ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -219,7 +221,7 @@ def demonstrate_random_sampler_with_replacement(): print(f" βœ“ All sampled: {all_sampled}") # Cleanup - tq.clear_partition(partition_id="test") + tq_client.clear_partition(partition_id="test") tq.close() ray.shutdown() @@ -234,6 +236,8 @@ def demonstrate_random_sampler_without_replacement(): sampler = RandomSamplerWithoutReplacement() setup_transfer_queue_with_sampler(sampler) + tq_client = tq.get_client() + # Add 6 samples print("\n[Step 1] Adding 6 samples...") data = TensorDict( @@ -242,22 +246,22 @@ def demonstrate_random_sampler_without_replacement(): }, batch_size=6, ) - tq.put(data=data, partition_id="test") + tq_client.put(data=data, partition_id="test") print(" βœ“ 6 samples added") # Get batch 1 (should get 3 random samples without replacement) print("\n[Step 2] Get batch 1 (3 samples)...") - meta1 = tq.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") + meta1 = tq_client.get_meta(data_fields=["input"], batch_size=3, partition_id="test", task_name="demo_task") print(f" βœ“ Got samples: {meta1.global_indexes}") # Get batch 2 (should randomly get 1 sample that are different from previous batch) print("\n[Step 3] Get batch 2 (1 samples)...") - meta2 = tq.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") + meta2 = tq_client.get_meta(data_fields=["input"], batch_size=1, partition_id="test", task_name="demo_task") print(f" βœ“ Got samples: {meta2.global_indexes}") # Get batch 3 (should randomly get 2 samples that are different from previous batch) print("\n[Step 4] Get batch 3 (2 samples)...") - meta3 = tq.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") + meta3 = tq_client.get_meta(data_fields=["input"], batch_size=2, partition_id="test", task_name="demo_task") print(f" βœ“ Got samples: {meta3.global_indexes}") print("\n[Verification]") @@ -267,7 +271,7 @@ def demonstrate_random_sampler_without_replacement(): print(f" βœ“ Batch 3: {meta3.global_indexes} (none left)") # Cleanup - tq.clear_partition(partition_id="test") + tq_client.clear_partition(partition_id="test") tq.close() ray.shutdown() @@ -282,6 +286,8 @@ def demonstrate_priority_sampler(): sampler = PrioritySampler() setup_transfer_queue_with_sampler(sampler) + tq_client = tq.get_client() + # Add 8 samples print("\n[Step 1] Adding 8 samples...") data = TensorDict( @@ -290,7 +296,7 @@ def demonstrate_priority_sampler(): }, batch_size=8, ) - tq.put(data=data, partition_id="test") + tq_client.put(data=data, partition_id="test") print(" βœ“ 8 samples added") time.sleep(1) @@ -303,7 +309,7 @@ def demonstrate_priority_sampler(): print(f"Priority scores: {priority_scores}") # Get batch using priority sampling - meta1 = tq.get_meta( + meta1 = tq_client.get_meta( data_fields=["input"], batch_size=1, partition_id="test", @@ -315,7 +321,7 @@ def demonstrate_priority_sampler(): # Get another batch print("\n[Step 3] Get another batch (2 samples)...") - meta2 = tq.get_meta( + meta2 = tq_client.get_meta( data_fields=["input"], batch_size=2, partition_id="test", @@ -331,7 +337,7 @@ def demonstrate_priority_sampler(): print(f" βœ“ Batch 2 high-priority indices: {[i for i in 
meta2.global_indexes if priority_scores[i] >= 0.1]}") # Cleanup - tq.clear_partition(partition_id="test") + tq_client.clear_partition(partition_id="test") tq.close() ray.shutdown() @@ -341,7 +347,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 4: Custom Sampler Development + TransferQueue Tutorial 5: Custom Sampler Development This script demonstrates how to develop custom samplers for TransferQueue. Samplers control HOW data is consumed from the queue. diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/06_streaming_dataloader.py similarity index 98% rename from tutorial/05_streaming_dataloader.py rename to tutorial/06_streaming_dataloader.py index 4d92aa2..c60b4e3 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/06_streaming_dataloader.py @@ -123,6 +123,8 @@ def generate_worker(rank_id: int, num_samples: int = 20): # Need to call tq.init() in each process tq.init() + tq_client = tq.get_client() + # Generate and put samples into the queue for i in range(num_samples): # Create unique sequence ID for this sample @@ -137,7 +139,7 @@ def generate_worker(rank_id: int, num_samples: int = 20): print(f"[Generate Worker@{rank_id}]: Putting sample {seq_id} into TransferQueue") # Put data into the specified partition - tq.put(data, partition_id="train") + tq_client.put(data, partition_id="train") print(f"[Generate Worker@{rank_id}]: Complete putting samples into TransferQueue") @@ -289,7 +291,7 @@ def main(): print( textwrap.dedent( """ - TransferQueue Tutorial 5: StreamingDataLoader for Distributed Training + TransferQueue Tutorial 6: StreamingDataLoader for Distributed Training This tutorial demonstrates the StreamingDataLoader interface for distributed training scenarios. It showcases how to use StreamingDataset and StreamingDataLoader