From d43b0141550863fe976cbf0f96c55685583deee5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Wed, 4 Feb 2026 11:56:51 +0800 Subject: [PATCH 1/9] test: Add E2E consistency tests with comprehensive data type coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added tests/test_e2e_consistency.py covering core types, multi-round puts, and slicing. - Updated terminology to Standard/Complex groups. - Verified cross-batch put scenarios. - Fixed linter errors. Signed-off-by: 看我72遍 --- tests/test_e2e_consistency.py | 530 ++++++++++++++++++++++++++++++++++ 1 file changed, 530 insertions(+) create mode 100644 tests/test_e2e_consistency.py diff --git a/tests/test_e2e_consistency.py b/tests/test_e2e_consistency.py new file mode 100644 index 0000000..88b38ad --- /dev/null +++ b/tests/test_e2e_consistency.py @@ -0,0 +1,530 @@ +import sys +import time +from pathlib import Path + +import numpy as np +import pytest +import ray +import torch +from tensordict import NonTensorStack, TensorDict + +from transfer_queue import ( + SimpleStorageUnit, + TransferQueueClient, + TransferQueueController, +) + +# Setup paths +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + + +@pytest.fixture(scope="module") +def ray_cluster(): + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def tq_setup(ray_cluster): + # 1. Start Controller + controller_actor = TransferQueueController.options( + name="test_controller", + get_if_exists=True, + ).remote() + controller_info = ray.get(controller_actor.get_zmq_server_info.remote()) + + # 2. Start Storage Unit + storage_actor = SimpleStorageUnit.options( + name="test_storage", + get_if_exists=True, + ).remote(storage_unit_size=10000) + storage_info = ray.get(storage_actor.get_zmq_server_info.remote()) + + # 3. Setup Client + + client_id = "test_e2e_client" + client = TransferQueueClient( + client_id=client_id, + controller_info=controller_info, + ) + + # Initialize Storage Manager (AsyncSimpleStorageManager) + # We need to manually configure it to know about our specific storage unit + config = { + "controller_info": controller_info, + "storage_unit_infos": {storage_info.id: storage_info}, + "storage_backend_config": {"storage_unit_size": 10000}, + } + + client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + + yield client, controller_actor, storage_actor + + +def assert_data_equal(original, retrieved, msg=""): + """Recursively check data equality for various types.""" + if isinstance(original, torch.Tensor): + assert isinstance(retrieved, torch.Tensor), f"{msg} Type mismatch: expected Tensor, got {type(retrieved)}" + # Check for nested tensor + if original.is_nested: + assert retrieved.is_nested, f"{msg} Expected nested tensor" + assert len(original) == len(retrieved), f"{msg} Nested tensor length mismatch" + for i in range(len(original)): + # Recurse for nested elements + assert_data_equal(original[i], retrieved[i], msg=f"{msg} Nested index {i} mismatch") + else: + # Handle potential NaN/Inf + # equal_nan=True is generally safe for equality checks in tests + torch.testing.assert_close(original, retrieved, msg=f"{msg} Tensor mismatch", equal_nan=True) + + elif isinstance(original, list | tuple): + # If it's a list, retrieved might be a NonTensorStack or list + if isinstance(retrieved, NonTensorStack | list | tuple): + assert len(original) == len(retrieved), f"{msg} Length mismatch" + for i, (o, r) in enumerate(zip(original, retrieved, strict=False)): + assert_data_equal(o, r, msg=f"{msg} List index {i}") + else: + pytest.fail(f"{msg} Type mismatch: expected List/Tuple, got {type(retrieved)}") + + elif isinstance(original, np.ndarray): + np.testing.assert_array_equal(original, retrieved, err_msg=f"{msg} Numpy array mismatch") + + elif isinstance(original, TensorDict | dict): + assert isinstance(retrieved, TensorDict | dict), f"{msg} Type mismatch: expected Dict, got {type(retrieved)}" + for k in original.keys(): + assert k in retrieved, f"{msg} Missing key {k}" + assert_data_equal(original[k], retrieved[k], msg=f"{msg} Key {k}") + + else: + # Primitive types + assert original == retrieved, f"{msg} Value mismatch: {original} != {retrieved}" + + +@pytest.mark.timeout(60) +def test_consistency_core_types(tq_setup): + """ + Test Case 1: Core Data Types Coverage + - Tensor, NestedTensor, Non-Tensor (stackable/non-stackable) + """ + client, _, _ = tq_setup + partition_id = "test_core_types" + + # Define test data + batch_size = 5 + + # 1. Normal Tensor + tensor_data = torch.randn(batch_size, 10) + + # 2. Nested Tensor (Ragged) + nested_list = [torch.randn(i + 2) for i in range(batch_size)] + nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) + + # 3. Stackable Non-Tensor (List of ints) + list_int = [i * 10 for i in range(batch_size)] + + # 4. Non-Stackable / Scalar-like mixed (Strings) + list_str = [f"sample_{i}" for i in range(batch_size)] + + # 5. List of numpy arrays + list_numpy = [np.array([i, i + 1]) for i in range(batch_size)] + + # 6. Numpy Object (Strings/Mixed) + # TransferQueue should handle this as NonTensor or specific serialization + np_obj = np.array([f"obj_{i}" for i in range(batch_size)], dtype=object) + + # 7. Special Values (Inf/NaN) & Bool + special_tensor = torch.zeros(batch_size, 3) + special_tensor[:, 0] = float("inf") + special_tensor[:, 1] = float("nan") + bool_tensor = torch.rand(batch_size, 5) > 0.5 + + # 8. Non-contiguous Tensor + large_t = torch.randn(batch_size, 20) + non_contiguous = large_t[:, ::2] # Stride 2 + + data = TensorDict( + { + "tensor_field": tensor_data, + "nested_field": nested_tensor, + "list_int_field": list_int, + "list_str_field": list_str, + "list_numpy_field": list_numpy, + "np_object_field": np_obj, + "special_field": special_tensor, + "bool_field": bool_tensor, + "non_orig_field": non_contiguous, + }, + batch_size=batch_size, + ) + + # Put Data + client.put(partition_id=partition_id, data=data) + + # Get Data + + # Poll for metadata until ready + max_retries = 10 + retrieved_data = None + + fields = [ + "tensor_field", + "nested_field", + "list_int_field", + "list_str_field", + "np_object_field", + "special_field", + "bool_field", + "non_orig_field", + ] + + meta = None + for _ in range(max_retries): + try: + meta = client.get_meta( + partition_id=partition_id, + data_fields=fields, + batch_size=batch_size, + mode="fetch", + task_name="test_worker", + ) + break + except Exception: + time.sleep(0.5) + + assert meta is not None, "Failed to retrieve metadata" + + retrieved_data = client.get_data(meta) + + # Verification + assert_data_equal(data["tensor_field"], retrieved_data["tensor_field"], "Tensor Field") + assert_data_equal(data["nested_field"], retrieved_data["nested_field"], "Nested Field") + + # For Non-Tensor, TransferQueue might return them as NonTensorStack or list + assert_data_equal(data["list_int_field"], retrieved_data["list_int_field"], "List Int Field") + assert_data_equal(data["list_str_field"], retrieved_data["list_str_field"], "List Str Field") + + # Verify complex types + assert_data_equal(data["np_object_field"], retrieved_data["np_object_field"], "Numpy Object Field") + + # Special Floats - NaN needs special check in assert checking/allclose + assert_data_equal(data["special_field"], retrieved_data["special_field"], "Special Float Field") + + assert_data_equal(data["bool_field"], retrieved_data["bool_field"], "Bool Field") + assert_data_equal(data["non_orig_field"], retrieved_data["non_orig_field"], "Non-contiguous Field") + + +@pytest.mark.timeout(120) +def test_consistency_multi_round_put_get(tq_setup): + """ + Test Case 2: Multi-round Put & Field Merge + Simulate fragmented writing and field stitching. + """ + client, _, _ = tq_setup + partition_id = "test_multi_round" + + # Define Indices + idx_round1 = list(range(0, 20)) + idx_round2 = list(range(20, 41)) # 21 items + + # ... (gen_data functions same) ... + # Data Generators with Descriptive Names + # Group 1: Standard & Scalar Types (Tensor, List[str], List[int], Special, Bool) + def gen_group_standard(indices): + n = len(indices) + # 1. Normal Tensor + tensor_data = torch.randn(n, 5) + indices[0] + # 2. List of Strings + list_str = [f"str_{i}" for i in indices] + # 3. List of Ints + list_int = [i * 10 for i in indices] + # 4. Special Floats + special_tensor = torch.zeros(n, 3) + special_tensor[:, 0] = float("inf") + special_tensor[:, 1] = float("nan") + # 5. Bool + bool_tensor = torch.rand(n, 5) > 0.5 + + return TensorDict( + { + "f_tensor": tensor_data, + "f_list_str": list_str, + "f_list_int": list_int, + "f_special": special_tensor, + "f_bool": bool_tensor, + }, + batch_size=n, + ) + + # Group 2: Complex & Nested Types (NestedTensor, List[numpy], NumpyObj, Non-Contig) + def gen_group_complex(indices): + n = len(indices) + # 6. Nested Tensor + nested_list = [torch.full((i % 5 + 1,), float(i)) for i in indices] + nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) + # 7. List of Numpy Arrays + list_numpy = [np.array([i, i * 2]) for i in indices] + # 8. Numpy Object + np_obj = np.array([f"obj_{i}" for i in indices], dtype=object) + # 9. Non-contiguous Tensor + large_t = torch.randn(n, 20) + non_contiguous = large_t[:, ::2] + + return TensorDict( + {"f_nested": nested_tensor, "f_list_numpy": list_numpy, "f_np_obj": np_obj, "f_non_contig": non_contiguous}, + batch_size=n, + ) + + # Helper to support updates on specific indices + def update_samples(global_indices, data): + import asyncio + + from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta + + samples = [] + field_names = list(data.keys()) + for i, idx in enumerate(global_indices): + fields_dict = {name: FieldMeta(name=name, dtype=None, shape=None) for name in field_names} + samples.append(SampleMeta(partition_id=partition_id, global_index=idx, fields=fields_dict)) + + meta = BatchMeta(samples=samples) + + future = asyncio.run_coroutine_threadsafe(client.storage_manager.put_data(data, meta), client._loop) + try: + future.result(timeout=10) + except Exception: + raise + + # Step 1: Pre-allocate Indices in TWO separate batches to test cross-batch retrieval + # Batch 1 (0-19) | Batch 2 (20-40) + all_fields = [ + "f_tensor", + "f_list_str", + "f_list_int", + "f_special", + "f_bool", + "f_nested", + "f_list_numpy", + "f_np_obj", + "f_non_contig", + ] + + # Allocation 1: Indices 0-19 + meta_alloc_1 = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator_1" + ) + idx_round1 = meta_alloc_1.global_indexes + assert len(idx_round1) == 20 + + # Allocation 2: Indices 20-40 + meta_alloc_2 = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=21, mode="insert", task_name="allocator_2" + ) + idx_round2 = meta_alloc_2.global_indexes + assert len(idx_round2) == 21 + + # Full list for reference + all_indices = idx_round1 + idx_round2 + + # --- Write Operations with Cross-Batch Logic --- + + # Op 1: Put Standard Group for Batch 1 (0-19) + data_1_std = gen_group_standard(idx_round1) + update_samples(idx_round1, data_1_std) + + # Op 2: Put Standard Group for Batch 2 (20-40) + data_2_std = gen_group_standard(idx_round2) + update_samples(idx_round2, data_2_std) + + # Op 3: Cross-batch Put for Complex Group (indices 15-25) + idx_cross = all_indices[15:26] # 15 to 25 inclusive + data_cross_complex = gen_group_complex(idx_cross) + update_samples(idx_cross, data_cross_complex) + + # Op 4: Fill remaining Complex Group (0-14) + idx_remaining_1 = all_indices[0:15] + data_rem_1_complex = gen_group_complex(idx_remaining_1) + update_samples(idx_remaining_1, data_rem_1_complex) + + # Op 5: Fill remaining Complex Group (26-40) + idx_remaining_2 = all_indices[26:] + data_rem_2_complex = gen_group_complex(idx_remaining_2) + update_samples(idx_remaining_2, data_rem_2_complex) + + # Verifications + + # 1. Get 0-10 + meta_1 = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=10, mode="fetch", task_name="verifier_1" + ) + res_1 = client.get_data(meta_1) + assert len(res_1["f_tensor"]) == 10 + + # Verify presence of all fields + for field in all_fields: + assert field in res_1, f"Missing field {field}" + + # Verify NestedTensor property + assert res_1["f_nested"].is_nested, "f_nested should be a NestedTensor" + + # 2. Get 10-30 + meta_2 = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="fetch", task_name="verifier_1" + ) + res_2 = client.get_data(meta_2) + assert len(res_2["f_tensor"]) == 20 + + # 3. Get 30-40 + meta_3 = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=11, mode="fetch", task_name="verifier_1" + ) + res_3 = client.get_data(meta_3) + assert len(res_3["f_tensor"]) == 11 + + +@pytest.mark.timeout(60) +def test_consistency_slicing_and_subset(tq_setup): + """ + Test Case 3: Slicing and Field Subsetting + """ + client, _, _ = tq_setup + partition_id = "test_slicing" + + # Pre-allocate Partition + + all_fields = [ + "F_tensor", + "F_nested", + "F_list_int", + "F_list_str", + "F_list_numpy", + "F_np_obj", + "F_special", + "F_bool", + "F_non_contig", + ] + meta_alloc = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator" + ) + # USE THE ALLOCATED INDICES + indices = meta_alloc.global_indexes + assert len(indices) == 20 + + # Create Data using these indices + n = len(indices) + + # 1. Tensor + tensor_data = torch.randn(n, 5) + indices[0] + # 2. Nested + nested_list = [torch.tensor([i, i * 2, i * 3]) if i % 2 == 0 else torch.tensor([i]) for i in indices] + nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) + # 3. List int + list_int = [i * 10 for i in indices] + # 4. List str + list_str = [f"slice_{i}" for i in indices] + # 5. List numpy + list_numpy = [np.array([i]) for i in indices] + # 6. Numpy Object + np_obj = np.array([f"obj_{i}" for i in indices], dtype=object) + # 7. Special + special_tensor = torch.zeros(n, 3) + special_tensor[:, 0] = float("inf") + # 8. Bool + bool_tensor = torch.rand(n, 1) > 0.5 + # 9. Non-contig + large_t = torch.randn(n, 10) + non_contig = large_t[:, ::2] + + data = TensorDict( + { + "F_tensor": tensor_data, + "F_nested": nested_tensor, + "F_list_int": list_int, + "F_list_str": list_str, + "F_list_numpy": list_numpy, + "F_np_obj": np_obj, + "F_special": special_tensor, + "F_bool": bool_tensor, + "F_non_contig": non_contig, + }, + batch_size=20, + ) + + # Helper for manual put + def update_samples(global_indices, data): + import asyncio + + from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta + + samples = [] + field_names = list(data.keys()) + for i, idx in enumerate(global_indices): + # Populate fields just in case + fields_dict = {name: FieldMeta(name=name, dtype=None, shape=None) for name in field_names} + samples.append(SampleMeta(partition_id=partition_id, global_index=idx, fields=fields_dict)) + + meta = BatchMeta(samples=samples) + future = asyncio.run_coroutine_threadsafe(client.storage_manager.put_data(data, meta), client._loop) + future.result() + + update_samples(indices, data) + + # 1. Field Subset: Get only F_nested (Ragged) + meta_subset = client.get_meta( + partition_id=partition_id, + data_fields=["F_nested"], + batch_size=20, # Ignored in force_fetch + mode="force_fetch", + task_name="inspector", + ) + res_subset = client.get_data(meta_subset) + + assert "F_nested" in res_subset + assert "F_tensor" not in res_subset + assert len(res_subset["F_nested"]) == 20 + + fetched_indices = meta_subset.global_indexes + fetched_values = res_subset["F_nested"] + + # Verify F_nested is NestedTensor + assert fetched_values.is_nested + + for idx, val in zip(fetched_indices, fetched_values, strict=False): + if idx % 2 == 0: + expected = torch.tensor([idx, idx * 2, idx * 3]) + else: + expected = torch.tensor([idx]) + + torch.testing.assert_close(val, expected.to(val.dtype)) + + # 2. Get F_np_obj (Numpy Objects/Mixed) & F_special (Inf/Nan) + meta_mixed = client.get_meta( + partition_id=partition_id, + data_fields=["F_np_obj", "F_special"], + batch_size=20, + mode="force_fetch", + task_name="inspector_2", + ) + res_mixed = client.get_data(meta_mixed) + assert "F_np_obj" in res_mixed + assert "F_special" in res_mixed + + # Verify F_np_obj + fetched_indices_mixed = meta_mixed.global_indexes + fetched_obj = res_mixed["F_np_obj"] + + assert len(fetched_obj) == 20 + for idx, val in zip(fetched_indices_mixed, fetched_obj, strict=False): + expected_val = f"obj_{idx}" + assert val == expected_val, f"Mismatch for index {idx}: {val} != {expected_val}" + + # Verify F_special (checking logic roughly, we know it has Inf) + fetched_special = res_mixed["F_special"] + assert torch.isinf(fetched_special).any(), "Expected Inf values in special tensor" + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) From adf6b56912457ba88c9c2c21dfc8bd76315cf1dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Wed, 4 Feb 2026 14:40:55 +0800 Subject: [PATCH 2/9] test: move sys.path setup before package import in e2e test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 看我72遍 --- tests/test_e2e_consistency.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_e2e_consistency.py b/tests/test_e2e_consistency.py index 88b38ad..c6a4beb 100644 --- a/tests/test_e2e_consistency.py +++ b/tests/test_e2e_consistency.py @@ -8,16 +8,16 @@ import torch from tensordict import NonTensorStack, TensorDict -from transfer_queue import ( +# Setup paths +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +from transfer_queue import ( # noqa: E402 SimpleStorageUnit, TransferQueueClient, TransferQueueController, ) -# Setup paths -parent_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(parent_dir)) - @pytest.fixture(scope="module") def ray_cluster(): From fc5948edf30da49d648a1abc16866ed98430a3c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Wed, 4 Feb 2026 16:34:04 +0800 Subject: [PATCH 3/9] Refactor e2e consistency tests: cleanup fixture and deduplicate validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 看我72遍 --- tests/e2e/test_e2e_consistency.py | 463 ++++++++++++++++++++++++++ tests/test_e2e_consistency.py | 530 ------------------------------ 2 files changed, 463 insertions(+), 530 deletions(-) create mode 100644 tests/e2e/test_e2e_consistency.py delete mode 100644 tests/test_e2e_consistency.py diff --git a/tests/e2e/test_e2e_consistency.py b/tests/e2e/test_e2e_consistency.py new file mode 100644 index 0000000..50c9bfd --- /dev/null +++ b/tests/e2e/test_e2e_consistency.py @@ -0,0 +1,463 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import collections.abc +import hashlib +import sys +import time +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +import ray +import torch +from tensordict import TensorDict + +# Setup paths +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +from transfer_queue import ( # noqa: E402 + SimpleStorageUnit, + TransferQueueClient, + TransferQueueController, +) +from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 + + +@pytest.fixture(scope="module") +def ray_cluster(): + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def e2e_client(ray_cluster): + controller_actor = TransferQueueController.options( + name="test_controller", + get_if_exists=True, + ).remote() + controller_info = ray.get(controller_actor.get_zmq_server_info.remote()) + + # Start two storage units to test sharding/scatter + storage_actor_1 = SimpleStorageUnit.options( + name="test_storage_1", + get_if_exists=True, + ).remote(storage_unit_size=10000) + storage_info_1 = ray.get(storage_actor_1.get_zmq_server_info.remote()) + + storage_actor_2 = SimpleStorageUnit.options( + name="test_storage_2", + get_if_exists=True, + ).remote(storage_unit_size=10000) + storage_info_2 = ray.get(storage_actor_2.get_zmq_server_info.remote()) + + client_id = "test_e2e_client" + client = TransferQueueClient( + client_id=client_id, + controller_info=controller_info, + ) + + # Initialize Storage Manager (AsyncSimpleStorageManager) and configure it + config = { + "controller_info": controller_info, + "storage_unit_infos": { + storage_info_1.id: storage_info_1, + storage_info_2.id: storage_info_2, + }, + "storage_backend_config": {"storage_unit_size": 10000}, + } + + client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + + yield client + + +def generate_consistency_data(indices: list[int]) -> TensorDict: + """ + Generates a TensorDict with deterministic data based on the provided indices. + Includes all field types used in consistency tests. + """ + n = len(indices) + + # Standard Tensor: shape (n, 5), values based on index + tensor_data = torch.stack([torch.arange(i, i + 5, dtype=torch.float32) for i in indices]) + + # Nested Tensor (Jagged): + nested_list = [] + for i in indices: + if i % 2 == 0: + nested_list.append(torch.tensor([i, i * 2, i * 3], dtype=torch.float32)) + else: + nested_list.append(torch.tensor([i], dtype=torch.float32)) + nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) + + # Strided Nested Tensor (or fallback) + try: + tensors = [torch.full((3, 4), float(i)) for i in indices] + strided_nested = torch.nested.nested_tensor(tensors, layout=torch.strided) + except Exception: + # Fallback for environments without strided nested support + tensors = [torch.full((3, 4), float(i)) for i in indices] + strided_nested = torch.nested.as_nested_tensor(tensors, layout=torch.jagged) + + # List of Ints + int_values = [i * 10 for i in indices] + + # List of Strings + str_values = [f"str_{i}" for i in indices] + + # List of Numpy Arrays + numpy_list = [np.array([i, i + 1], dtype=np.int64) for i in indices] + + # Numpy Object Array (Strings) + numpy_objects = np.array([f"obj_{i}" for i in indices], dtype=object) + + # Special Values (Inf/NaN) + special_tensor = torch.zeros(n, 3) + special_tensor[:, 0] = float("inf") + special_tensor[:, 1] = float("nan") + special_tensor[:, 2] = torch.tensor(indices, dtype=torch.float32) + + # Bool Tensor + bool_tensor = torch.tensor([[i % 2 == 0] for i in indices], dtype=torch.bool) + + # Non-contiguous Tensor + # orig shape (n, 20), slice ::2 -> (n, 10). + large_tensor = torch.stack([torch.full((20,), float(i)) for i in indices]) + non_contiguous_tensor = large_tensor[:, ::2] + + return TensorDict( + { + "tensor_field": tensor_data, + "nested_field": nested_tensor, + "strided_nested_field": strided_nested, + "list_int_field": int_values, + "list_str_field": str_values, + "list_numpy_field": numpy_list, + "np_object_field": numpy_objects, + "special_field": special_tensor, + "bool_field": bool_tensor, + "non_orig_field": non_contiguous_tensor, + }, + batch_size=n, + ) + + +def compute_data_hash(data: Any, algorithm="sha256") -> str: + """ + Computes a structure-agnostic hash of the data. + - Flattens nested structures (lists, tuples, tensors). + - Ignores container types. + - Normalizes Tensors (detach, cpu). + - Consistent for NestedTensor vs List[Tensor]. + - Consistent for List[int] vs Tensor(int). + """ + hasher = hashlib.new(algorithm) + _update_hash(hasher, data) + return hasher.hexdigest() + + +def _update_hash(hasher, data): + if isinstance(data, dict | TensorDict) or (hasattr(data, "keys") and hasattr(data, "__getitem__")): + # Check if it behaves like a mapping + try: + keys = sorted(list(data.keys())) + for k in keys: + _hash_scalar(hasher, str(k)) + _update_hash(hasher, data[k]) + return + except TypeError: + pass + + if isinstance(data, torch.Tensor): + if data.is_nested: + for t in data.unbind(): + _update_hash(hasher, t) + return + + data = data.detach().cpu() + if data.ndim == 0: + _hash_scalar(hasher, data.item()) + return + + if isinstance(data, np.ndarray): + if data.ndim == 0: + _hash_scalar(hasher, data.item()) + elif data.dtype == object: + for item in data: + _update_hash(hasher, item) + else: + for item in data: + _update_hash(hasher, item) + return + + if isinstance(data, str | bytes): + _hash_scalar(hasher, data) + return + + # Handles list, tuple, and any custom Sequence + if isinstance(data, collections.abc.Sequence): + for item in data: + _update_hash(hasher, item) + return + + _hash_scalar(hasher, data) + + +def _hash_scalar(hasher, data): + s = str(data) + if s == "nan": + s = "NAN" + hasher.update(s.encode("utf-8")) + hasher.update(b"|") + + +def put_data_with_indices(client, partition_id, indices, data): + """ + Helper to put data with specific global indices. + """ + samples = [] + field_names = list(data.keys()) + + # Pre-compute field metas + for idx in indices: + fields_dict = {name: FieldMeta(name=name, dtype=None, shape=None) for name in field_names} + samples.append(SampleMeta(partition_id=partition_id, global_index=idx, fields=fields_dict)) + + meta = BatchMeta(samples=samples) + + future = asyncio.run_coroutine_threadsafe(client.storage_manager.put_data(data, meta), client._loop) + # Wait for result + future.result(timeout=30) + + +def get_fields_subset(indices, field_names): + """Helper to slice TensorDict by fields.""" + full_data = generate_consistency_data(indices) + # Select only specific fields + # Note: TensorDict.select returns a new TensorDict with shared storage + return full_data.select(*field_names) + + +def verify_data_consistency( + client, partition_id: str, task_name: str, data_fields: list[str], batch_size: int, mode: str = "fetch" +) -> TensorDict: + """Helper to retrieve data and verify it matches deterministic generation logic.""" + # Poll for metadata until ready + max_retries = 10 + meta = None + for _ in range(max_retries): + try: + meta = client.get_meta( + partition_id=partition_id, + data_fields=data_fields, + batch_size=batch_size, + mode=mode, + task_name=task_name, + ) + break + except Exception: + time.sleep(0.5) + + assert meta is not None, f"Failed to retrieve metadata for {task_name}" + + retrieved_data = client.get_data(meta) + + # Generate partial expected data based on retrieved indices + full_expected = generate_consistency_data(meta.global_indexes) + expected_data = full_expected.select(*data_fields) + + # Verify Hash + retrieved_hash = compute_data_hash(retrieved_data) + expected_hash = compute_data_hash(expected_data) + assert retrieved_hash == expected_hash, f"Hash mismatch for {task_name}" + + return retrieved_data + + +@pytest.mark.timeout(60) +def test_consistency_core_types(e2e_client): + """ + Test Case 1: Core Data Types Coverage + - Tensor, NestedTensor, Non-Tensor (stackable/non-stackable) + """ + client = e2e_client + + # Use distinct partition to avoid conflict if tests run shared (though scope is module) + partition_id = "test_core_types" + + batch_size = 5 + + # Get fields list from dummy data + dummy_data = generate_consistency_data([0]) + fields = list(dummy_data.keys()) + + # 1. Allocate Partition & Indices + allocation_meta = client.get_meta( + partition_id=partition_id, data_fields=fields, batch_size=batch_size, mode="insert", task_name="allocator" + ) + indices = allocation_meta.global_indexes + assert len(indices) == batch_size + + # 2. Generate Data using allocated indices + data = generate_consistency_data(indices) + + put_data_with_indices(client, partition_id, indices, data) + + # 3. Get Data and Verify + verify_data_consistency(client, partition_id, "test_worker", list(data.keys()), batch_size) + + +@pytest.mark.timeout(120) +def test_consistency_multi_round_put_get(e2e_client): + """ + Test Case 2: Multi-round Put & Field Merge + Simulate fragmented writing and field stitching. + """ + client = e2e_client + partition_id = "test_multi_round" + + # Define Indices + idx_round1 = list(range(0, 20)) + idx_round2 = list(range(20, 41)) # 21 items + + # Step 1: Allocations + # Define Field Groups + # Group 1: Standard + group_std_fields = ["tensor_field", "list_str_field", "list_int_field", "special_field", "bool_field"] + # Group 2: Complex + group_complex_fields = [ + "nested_field", + "strided_nested_field", + "list_numpy_field", + "np_object_field", + "non_orig_field", + ] + + all_fields = group_std_fields + group_complex_fields + + # Allocation 1: Indices 0-19 + allocation_meta_1 = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator_1" + ) + assert len(allocation_meta_1.global_indexes) == 20 + + # Allocation 2: Indices 20-40 + allocation_meta_2 = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=21, mode="insert", task_name="allocator_2" + ) + assert len(allocation_meta_2.global_indexes) == 21 + + # Full list for reference + all_indices = idx_round1 + idx_round2 + + # --- Write Operations with Cross-Batch Logic --- + + # Op 1: Put Standard Group for Batch 1 (0-19) + data_1_std = get_fields_subset(idx_round1, group_std_fields) + put_data_with_indices(client, partition_id, idx_round1, data_1_std) + + # Op 2: Put Standard Group for Batch 2 (20-40) + data_2_std = get_fields_subset(idx_round2, group_std_fields) + put_data_with_indices(client, partition_id, idx_round2, data_2_std) + + # Op 3: Cross-batch Put for Complex Group (indices 15-25) + idx_cross = all_indices[15:26] # 15 to 25 inclusive + data_cross_complex = get_fields_subset(idx_cross, group_complex_fields) + put_data_with_indices(client, partition_id, idx_cross, data_cross_complex) + + # Op 4: Fill remaining Complex Group (0-14) + idx_remaining_1 = all_indices[0:15] + data_rem_1_complex = get_fields_subset(idx_remaining_1, group_complex_fields) + put_data_with_indices(client, partition_id, idx_remaining_1, data_rem_1_complex) + + # Op 5: Fill remaining Complex Group (26-40) + idx_remaining_2 = all_indices[26:] + data_rem_2_complex = get_fields_subset(idx_remaining_2, group_complex_fields) + put_data_with_indices(client, partition_id, idx_remaining_2, data_rem_2_complex) + + # Verifications + + # 1. Get 0-10 + verify_data_consistency(client, partition_id, "verifier_1", all_fields, 10) + + # 2. Get 10-30 (Cross boundary) + verify_data_consistency(client, partition_id, "verifier_2", all_fields, 20) + + # 3. Get 30-40 + verify_data_consistency(client, partition_id, "verifier_3", all_fields, 11) + + +@pytest.mark.timeout(60) +def test_consistency_slicing_and_subset(e2e_client): + """ + Test Case 3: Slicing and Field Subsetting + """ + client = e2e_client + partition_id = "test_slicing" + + # Pre-allocate Partition + all_fields = [ + "tensor_field", + "nested_field", + "strided_nested_field", + "list_int_field", + "list_str_field", + "list_numpy_field", + "np_object_field", + "special_field", + "bool_field", + "non_orig_field", + ] + allocation_meta = client.get_meta( + partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator" + ) + indices = allocation_meta.global_indexes + assert len(indices) == 20 + + # Create Data using these indices + data = generate_consistency_data(indices) + + # Put Data + put_data_with_indices(client, partition_id, indices, data) + + # 1. Field Subset: Get only nested_field (Ragged) + result_subset = verify_data_consistency(client, partition_id, "inspector", ["nested_field"], 20, mode="force_fetch") + + assert "nested_field" in result_subset + assert "tensor_field" not in result_subset + assert len(result_subset["nested_field"]) == 20 + + # 2. Get np_object_field & special_field + result_mixed = verify_data_consistency( + client, partition_id, "inspector_2", ["np_object_field", "special_field"], 20, mode="force_fetch" + ) + + assert "np_object_field" in result_mixed + assert "special_field" in result_mixed + + # 3. Get strided_nested_field explicitly + verify_data_consistency(client, partition_id, "inspector_3", ["strided_nested_field"], 20, mode="force_fetch") + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/tests/test_e2e_consistency.py b/tests/test_e2e_consistency.py deleted file mode 100644 index c6a4beb..0000000 --- a/tests/test_e2e_consistency.py +++ /dev/null @@ -1,530 +0,0 @@ -import sys -import time -from pathlib import Path - -import numpy as np -import pytest -import ray -import torch -from tensordict import NonTensorStack, TensorDict - -# Setup paths -parent_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(parent_dir)) - -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, -) - - -@pytest.fixture(scope="module") -def ray_cluster(): - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - yield - if ray.is_initialized(): - ray.shutdown() - - -@pytest.fixture(scope="module") -def tq_setup(ray_cluster): - # 1. Start Controller - controller_actor = TransferQueueController.options( - name="test_controller", - get_if_exists=True, - ).remote() - controller_info = ray.get(controller_actor.get_zmq_server_info.remote()) - - # 2. Start Storage Unit - storage_actor = SimpleStorageUnit.options( - name="test_storage", - get_if_exists=True, - ).remote(storage_unit_size=10000) - storage_info = ray.get(storage_actor.get_zmq_server_info.remote()) - - # 3. Setup Client - - client_id = "test_e2e_client" - client = TransferQueueClient( - client_id=client_id, - controller_info=controller_info, - ) - - # Initialize Storage Manager (AsyncSimpleStorageManager) - # We need to manually configure it to know about our specific storage unit - config = { - "controller_info": controller_info, - "storage_unit_infos": {storage_info.id: storage_info}, - "storage_backend_config": {"storage_unit_size": 10000}, - } - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - yield client, controller_actor, storage_actor - - -def assert_data_equal(original, retrieved, msg=""): - """Recursively check data equality for various types.""" - if isinstance(original, torch.Tensor): - assert isinstance(retrieved, torch.Tensor), f"{msg} Type mismatch: expected Tensor, got {type(retrieved)}" - # Check for nested tensor - if original.is_nested: - assert retrieved.is_nested, f"{msg} Expected nested tensor" - assert len(original) == len(retrieved), f"{msg} Nested tensor length mismatch" - for i in range(len(original)): - # Recurse for nested elements - assert_data_equal(original[i], retrieved[i], msg=f"{msg} Nested index {i} mismatch") - else: - # Handle potential NaN/Inf - # equal_nan=True is generally safe for equality checks in tests - torch.testing.assert_close(original, retrieved, msg=f"{msg} Tensor mismatch", equal_nan=True) - - elif isinstance(original, list | tuple): - # If it's a list, retrieved might be a NonTensorStack or list - if isinstance(retrieved, NonTensorStack | list | tuple): - assert len(original) == len(retrieved), f"{msg} Length mismatch" - for i, (o, r) in enumerate(zip(original, retrieved, strict=False)): - assert_data_equal(o, r, msg=f"{msg} List index {i}") - else: - pytest.fail(f"{msg} Type mismatch: expected List/Tuple, got {type(retrieved)}") - - elif isinstance(original, np.ndarray): - np.testing.assert_array_equal(original, retrieved, err_msg=f"{msg} Numpy array mismatch") - - elif isinstance(original, TensorDict | dict): - assert isinstance(retrieved, TensorDict | dict), f"{msg} Type mismatch: expected Dict, got {type(retrieved)}" - for k in original.keys(): - assert k in retrieved, f"{msg} Missing key {k}" - assert_data_equal(original[k], retrieved[k], msg=f"{msg} Key {k}") - - else: - # Primitive types - assert original == retrieved, f"{msg} Value mismatch: {original} != {retrieved}" - - -@pytest.mark.timeout(60) -def test_consistency_core_types(tq_setup): - """ - Test Case 1: Core Data Types Coverage - - Tensor, NestedTensor, Non-Tensor (stackable/non-stackable) - """ - client, _, _ = tq_setup - partition_id = "test_core_types" - - # Define test data - batch_size = 5 - - # 1. Normal Tensor - tensor_data = torch.randn(batch_size, 10) - - # 2. Nested Tensor (Ragged) - nested_list = [torch.randn(i + 2) for i in range(batch_size)] - nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) - - # 3. Stackable Non-Tensor (List of ints) - list_int = [i * 10 for i in range(batch_size)] - - # 4. Non-Stackable / Scalar-like mixed (Strings) - list_str = [f"sample_{i}" for i in range(batch_size)] - - # 5. List of numpy arrays - list_numpy = [np.array([i, i + 1]) for i in range(batch_size)] - - # 6. Numpy Object (Strings/Mixed) - # TransferQueue should handle this as NonTensor or specific serialization - np_obj = np.array([f"obj_{i}" for i in range(batch_size)], dtype=object) - - # 7. Special Values (Inf/NaN) & Bool - special_tensor = torch.zeros(batch_size, 3) - special_tensor[:, 0] = float("inf") - special_tensor[:, 1] = float("nan") - bool_tensor = torch.rand(batch_size, 5) > 0.5 - - # 8. Non-contiguous Tensor - large_t = torch.randn(batch_size, 20) - non_contiguous = large_t[:, ::2] # Stride 2 - - data = TensorDict( - { - "tensor_field": tensor_data, - "nested_field": nested_tensor, - "list_int_field": list_int, - "list_str_field": list_str, - "list_numpy_field": list_numpy, - "np_object_field": np_obj, - "special_field": special_tensor, - "bool_field": bool_tensor, - "non_orig_field": non_contiguous, - }, - batch_size=batch_size, - ) - - # Put Data - client.put(partition_id=partition_id, data=data) - - # Get Data - - # Poll for metadata until ready - max_retries = 10 - retrieved_data = None - - fields = [ - "tensor_field", - "nested_field", - "list_int_field", - "list_str_field", - "np_object_field", - "special_field", - "bool_field", - "non_orig_field", - ] - - meta = None - for _ in range(max_retries): - try: - meta = client.get_meta( - partition_id=partition_id, - data_fields=fields, - batch_size=batch_size, - mode="fetch", - task_name="test_worker", - ) - break - except Exception: - time.sleep(0.5) - - assert meta is not None, "Failed to retrieve metadata" - - retrieved_data = client.get_data(meta) - - # Verification - assert_data_equal(data["tensor_field"], retrieved_data["tensor_field"], "Tensor Field") - assert_data_equal(data["nested_field"], retrieved_data["nested_field"], "Nested Field") - - # For Non-Tensor, TransferQueue might return them as NonTensorStack or list - assert_data_equal(data["list_int_field"], retrieved_data["list_int_field"], "List Int Field") - assert_data_equal(data["list_str_field"], retrieved_data["list_str_field"], "List Str Field") - - # Verify complex types - assert_data_equal(data["np_object_field"], retrieved_data["np_object_field"], "Numpy Object Field") - - # Special Floats - NaN needs special check in assert checking/allclose - assert_data_equal(data["special_field"], retrieved_data["special_field"], "Special Float Field") - - assert_data_equal(data["bool_field"], retrieved_data["bool_field"], "Bool Field") - assert_data_equal(data["non_orig_field"], retrieved_data["non_orig_field"], "Non-contiguous Field") - - -@pytest.mark.timeout(120) -def test_consistency_multi_round_put_get(tq_setup): - """ - Test Case 2: Multi-round Put & Field Merge - Simulate fragmented writing and field stitching. - """ - client, _, _ = tq_setup - partition_id = "test_multi_round" - - # Define Indices - idx_round1 = list(range(0, 20)) - idx_round2 = list(range(20, 41)) # 21 items - - # ... (gen_data functions same) ... - # Data Generators with Descriptive Names - # Group 1: Standard & Scalar Types (Tensor, List[str], List[int], Special, Bool) - def gen_group_standard(indices): - n = len(indices) - # 1. Normal Tensor - tensor_data = torch.randn(n, 5) + indices[0] - # 2. List of Strings - list_str = [f"str_{i}" for i in indices] - # 3. List of Ints - list_int = [i * 10 for i in indices] - # 4. Special Floats - special_tensor = torch.zeros(n, 3) - special_tensor[:, 0] = float("inf") - special_tensor[:, 1] = float("nan") - # 5. Bool - bool_tensor = torch.rand(n, 5) > 0.5 - - return TensorDict( - { - "f_tensor": tensor_data, - "f_list_str": list_str, - "f_list_int": list_int, - "f_special": special_tensor, - "f_bool": bool_tensor, - }, - batch_size=n, - ) - - # Group 2: Complex & Nested Types (NestedTensor, List[numpy], NumpyObj, Non-Contig) - def gen_group_complex(indices): - n = len(indices) - # 6. Nested Tensor - nested_list = [torch.full((i % 5 + 1,), float(i)) for i in indices] - nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) - # 7. List of Numpy Arrays - list_numpy = [np.array([i, i * 2]) for i in indices] - # 8. Numpy Object - np_obj = np.array([f"obj_{i}" for i in indices], dtype=object) - # 9. Non-contiguous Tensor - large_t = torch.randn(n, 20) - non_contiguous = large_t[:, ::2] - - return TensorDict( - {"f_nested": nested_tensor, "f_list_numpy": list_numpy, "f_np_obj": np_obj, "f_non_contig": non_contiguous}, - batch_size=n, - ) - - # Helper to support updates on specific indices - def update_samples(global_indices, data): - import asyncio - - from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta - - samples = [] - field_names = list(data.keys()) - for i, idx in enumerate(global_indices): - fields_dict = {name: FieldMeta(name=name, dtype=None, shape=None) for name in field_names} - samples.append(SampleMeta(partition_id=partition_id, global_index=idx, fields=fields_dict)) - - meta = BatchMeta(samples=samples) - - future = asyncio.run_coroutine_threadsafe(client.storage_manager.put_data(data, meta), client._loop) - try: - future.result(timeout=10) - except Exception: - raise - - # Step 1: Pre-allocate Indices in TWO separate batches to test cross-batch retrieval - # Batch 1 (0-19) | Batch 2 (20-40) - all_fields = [ - "f_tensor", - "f_list_str", - "f_list_int", - "f_special", - "f_bool", - "f_nested", - "f_list_numpy", - "f_np_obj", - "f_non_contig", - ] - - # Allocation 1: Indices 0-19 - meta_alloc_1 = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator_1" - ) - idx_round1 = meta_alloc_1.global_indexes - assert len(idx_round1) == 20 - - # Allocation 2: Indices 20-40 - meta_alloc_2 = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=21, mode="insert", task_name="allocator_2" - ) - idx_round2 = meta_alloc_2.global_indexes - assert len(idx_round2) == 21 - - # Full list for reference - all_indices = idx_round1 + idx_round2 - - # --- Write Operations with Cross-Batch Logic --- - - # Op 1: Put Standard Group for Batch 1 (0-19) - data_1_std = gen_group_standard(idx_round1) - update_samples(idx_round1, data_1_std) - - # Op 2: Put Standard Group for Batch 2 (20-40) - data_2_std = gen_group_standard(idx_round2) - update_samples(idx_round2, data_2_std) - - # Op 3: Cross-batch Put for Complex Group (indices 15-25) - idx_cross = all_indices[15:26] # 15 to 25 inclusive - data_cross_complex = gen_group_complex(idx_cross) - update_samples(idx_cross, data_cross_complex) - - # Op 4: Fill remaining Complex Group (0-14) - idx_remaining_1 = all_indices[0:15] - data_rem_1_complex = gen_group_complex(idx_remaining_1) - update_samples(idx_remaining_1, data_rem_1_complex) - - # Op 5: Fill remaining Complex Group (26-40) - idx_remaining_2 = all_indices[26:] - data_rem_2_complex = gen_group_complex(idx_remaining_2) - update_samples(idx_remaining_2, data_rem_2_complex) - - # Verifications - - # 1. Get 0-10 - meta_1 = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=10, mode="fetch", task_name="verifier_1" - ) - res_1 = client.get_data(meta_1) - assert len(res_1["f_tensor"]) == 10 - - # Verify presence of all fields - for field in all_fields: - assert field in res_1, f"Missing field {field}" - - # Verify NestedTensor property - assert res_1["f_nested"].is_nested, "f_nested should be a NestedTensor" - - # 2. Get 10-30 - meta_2 = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="fetch", task_name="verifier_1" - ) - res_2 = client.get_data(meta_2) - assert len(res_2["f_tensor"]) == 20 - - # 3. Get 30-40 - meta_3 = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=11, mode="fetch", task_name="verifier_1" - ) - res_3 = client.get_data(meta_3) - assert len(res_3["f_tensor"]) == 11 - - -@pytest.mark.timeout(60) -def test_consistency_slicing_and_subset(tq_setup): - """ - Test Case 3: Slicing and Field Subsetting - """ - client, _, _ = tq_setup - partition_id = "test_slicing" - - # Pre-allocate Partition - - all_fields = [ - "F_tensor", - "F_nested", - "F_list_int", - "F_list_str", - "F_list_numpy", - "F_np_obj", - "F_special", - "F_bool", - "F_non_contig", - ] - meta_alloc = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator" - ) - # USE THE ALLOCATED INDICES - indices = meta_alloc.global_indexes - assert len(indices) == 20 - - # Create Data using these indices - n = len(indices) - - # 1. Tensor - tensor_data = torch.randn(n, 5) + indices[0] - # 2. Nested - nested_list = [torch.tensor([i, i * 2, i * 3]) if i % 2 == 0 else torch.tensor([i]) for i in indices] - nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) - # 3. List int - list_int = [i * 10 for i in indices] - # 4. List str - list_str = [f"slice_{i}" for i in indices] - # 5. List numpy - list_numpy = [np.array([i]) for i in indices] - # 6. Numpy Object - np_obj = np.array([f"obj_{i}" for i in indices], dtype=object) - # 7. Special - special_tensor = torch.zeros(n, 3) - special_tensor[:, 0] = float("inf") - # 8. Bool - bool_tensor = torch.rand(n, 1) > 0.5 - # 9. Non-contig - large_t = torch.randn(n, 10) - non_contig = large_t[:, ::2] - - data = TensorDict( - { - "F_tensor": tensor_data, - "F_nested": nested_tensor, - "F_list_int": list_int, - "F_list_str": list_str, - "F_list_numpy": list_numpy, - "F_np_obj": np_obj, - "F_special": special_tensor, - "F_bool": bool_tensor, - "F_non_contig": non_contig, - }, - batch_size=20, - ) - - # Helper for manual put - def update_samples(global_indices, data): - import asyncio - - from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta - - samples = [] - field_names = list(data.keys()) - for i, idx in enumerate(global_indices): - # Populate fields just in case - fields_dict = {name: FieldMeta(name=name, dtype=None, shape=None) for name in field_names} - samples.append(SampleMeta(partition_id=partition_id, global_index=idx, fields=fields_dict)) - - meta = BatchMeta(samples=samples) - future = asyncio.run_coroutine_threadsafe(client.storage_manager.put_data(data, meta), client._loop) - future.result() - - update_samples(indices, data) - - # 1. Field Subset: Get only F_nested (Ragged) - meta_subset = client.get_meta( - partition_id=partition_id, - data_fields=["F_nested"], - batch_size=20, # Ignored in force_fetch - mode="force_fetch", - task_name="inspector", - ) - res_subset = client.get_data(meta_subset) - - assert "F_nested" in res_subset - assert "F_tensor" not in res_subset - assert len(res_subset["F_nested"]) == 20 - - fetched_indices = meta_subset.global_indexes - fetched_values = res_subset["F_nested"] - - # Verify F_nested is NestedTensor - assert fetched_values.is_nested - - for idx, val in zip(fetched_indices, fetched_values, strict=False): - if idx % 2 == 0: - expected = torch.tensor([idx, idx * 2, idx * 3]) - else: - expected = torch.tensor([idx]) - - torch.testing.assert_close(val, expected.to(val.dtype)) - - # 2. Get F_np_obj (Numpy Objects/Mixed) & F_special (Inf/Nan) - meta_mixed = client.get_meta( - partition_id=partition_id, - data_fields=["F_np_obj", "F_special"], - batch_size=20, - mode="force_fetch", - task_name="inspector_2", - ) - res_mixed = client.get_data(meta_mixed) - assert "F_np_obj" in res_mixed - assert "F_special" in res_mixed - - # Verify F_np_obj - fetched_indices_mixed = meta_mixed.global_indexes - fetched_obj = res_mixed["F_np_obj"] - - assert len(fetched_obj) == 20 - for idx, val in zip(fetched_indices_mixed, fetched_obj, strict=False): - expected_val = f"obj_{idx}" - assert val == expected_val, f"Mismatch for index {idx}: {val} != {expected_val}" - - # Verify F_special (checking logic roughly, we know it has Inf) - fetched_special = res_mixed["F_special"] - assert torch.isinf(fetched_special).any(), "Expected Inf values in special tensor" - - -if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) From 7fa6c11d4b28ab11ac90720830a52c93ef23e4df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Wed, 4 Feb 2026 16:36:01 +0800 Subject: [PATCH 4/9] move e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 看我72遍 --- tests/{ => e2e}/test_yuanrong_storage_client_e2e.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => e2e}/test_yuanrong_storage_client_e2e.py (100%) diff --git a/tests/test_yuanrong_storage_client_e2e.py b/tests/e2e/test_yuanrong_storage_client_e2e.py similarity index 100% rename from tests/test_yuanrong_storage_client_e2e.py rename to tests/e2e/test_yuanrong_storage_client_e2e.py From 7a663a18cc64f51f92b10050deee610ceedc9e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Wed, 4 Feb 2026 16:54:50 +0800 Subject: [PATCH 5/9] fix polling mode in text MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 看我72遍 --- tests/e2e/test_e2e_consistency.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/e2e/test_e2e_consistency.py b/tests/e2e/test_e2e_consistency.py index 50c9bfd..351049f 100644 --- a/tests/e2e/test_e2e_consistency.py +++ b/tests/e2e/test_e2e_consistency.py @@ -54,7 +54,7 @@ def e2e_client(ray_cluster): controller_actor = TransferQueueController.options( name="test_controller", get_if_exists=True, - ).remote() + ).remote(polling_mode=True) controller_info = ray.get(controller_actor.get_zmq_server_info.remote()) # Start two storage units to test sharding/scatter @@ -266,19 +266,18 @@ def verify_data_consistency( max_retries = 10 meta = None for _ in range(max_retries): - try: - meta = client.get_meta( - partition_id=partition_id, - data_fields=data_fields, - batch_size=batch_size, - mode=mode, - task_name=task_name, - ) + meta = client.get_meta( + partition_id=partition_id, + data_fields=data_fields, + batch_size=batch_size, + mode=mode, + task_name=task_name, + ) + if meta is not None and meta.size > 0: break - except Exception: - time.sleep(0.5) + time.sleep(0.5) - assert meta is not None, f"Failed to retrieve metadata for {task_name}" + assert meta is not None and meta.size > 0, f"Failed to retrieve metadata for {task_name}" retrieved_data = client.get_data(meta) From 30a5757db7feb58c30547aafd0f4616a2c0b5cda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Thu, 5 Feb 2026 17:30:05 +0800 Subject: [PATCH 6/9] refactor: improve e2e lifecycle tests with better error handling and cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 看我72遍 --- consistency_validation_plan.md | 98 +++ tests/e2e/test_e2e_consistency.py | 462 ------------- tests/e2e/test_e2e_lifecycle_consistency.py | 632 ++++++++++++++++++ .../test_yuanrong_storage_client_e2e.py | 0 4 files changed, 730 insertions(+), 462 deletions(-) create mode 100644 consistency_validation_plan.md delete mode 100644 tests/e2e/test_e2e_consistency.py create mode 100644 tests/e2e/test_e2e_lifecycle_consistency.py rename tests/{e2e => }/test_yuanrong_storage_client_e2e.py (100%) diff --git a/consistency_validation_plan.md b/consistency_validation_plan.md new file mode 100644 index 0000000..38f7366 --- /dev/null +++ b/consistency_validation_plan.md @@ -0,0 +1,98 @@ +# TransferQueue 端到端数据一致性校验实施计划 (Consistency Validation Master Plan) + +## 1. 背景与目标 (Background & Objective) +本项目旨在开发一个端到端的自动化测试脚本 (`scripts/test_e2e_lifecycle_consistency.py`),用于验证 `TransferQueue` 在复杂场景下的数据一致性和生命周期管理能力。 +核心目标是确保在数据经过 **分片存储 (Sharding)**、**跨节点传输**、**分批写入 (Multi-round Put)**、**动态更新 (Update/Overwrite)** 等操作后,数据的完整性和状态机(Production/Consumption Status)的正确性。 + +## 2. 核心原则 (Core Principles) +1. **Public API Only**: 测试代码仅允许使用 `TransferQueueClient` 的公共接口(如 `put`, `get_meta`, `get_data`, `check_production_status` 等),**严禁**调用 `StorageUnit` 或 `Controller` 的内部私有方法。 +2. **Complex Data Types**: 所有数据传输场景必须覆盖全量复杂数据类型(见下文)。 +3. **Environment**: 必须模拟真实分布式环境,启动 **2个及以上 Storage Units** 以强制触发 Manager 的自动分片逻辑。 + +--- + +## 3. 详细实施方案 (Implementation Details) + +### 3.1 测试环境配置 +参考 `scripts/performance_test.py` 的初始化逻辑: +- **Ray Cluster**: 使用 `pytest` fixture 启动。 +- **Components**: + - **Controller**: 1个 (`polling_mode=True`)。 + - **Storage Units**: **2个** (Capacity=10000),确保数据会分布在不同 Unit 上。 + - **Client**: 初始化一般 Client,配置 `AsyncSimpleStorageManager`。 + +### 3.2 通用复杂数据生成器 (Universal Data Generator) +实现 `generate_complex_data(indices, fields_subset=None)`,生成 `TensorDict`,必须包含: + +| 类型 (Type) | 字段 (Field) | 特征 (Characteristics) | +| :--- | :--- | :--- | +| **Standard Tensor** | `tensor_f32`, `tensor_i64` | Float32/Int64, 标准形状 | +| **Nested Tensor** | `nested_jagged` | `layout=torch.jagged`, 变长样本 | +| | `nested_strided` | `layout=torch.strided` (若支持) | +| **Lists** | `list_int`, `list_str` | Python 原生列表 | +| **NumPy** | `np_array`, `np_obj` | 标准 Array 及 Object Array (混合类型) | +| **Special Values** | `special_val` | 包含 **NaN** 和 **Inf** (验证传输稳定性) | +| **Non-Tensor** | **`non_tensor_stack`** | 使用 `tensordict.tensorclass.NonTensorData` 封装 | + +### 3.3 验证场景 (Verification Scenarios) + +#### 场景一:核心读写一致性 (Core Consistency) +- **操作**: Put 写入上述全量复杂数据 -> Get 读取。 +- **验证**: + - 输入输出的 Hash 值完全一致 (使用结构无关 Hash)。 + - `NaN` 保持为 `NaN`,`Inf` 保持为 `Inf`。 + - `NonTensorData` 解包后内容无损。 + +#### 场景二:跨分片操作与复杂更新 (Cross-Partition & Complex Update) +- **配置**: 依赖 Manager 自动将不同 Indices 分片到 2 个 Storage Units。 +- **步骤**: + 1. **Put A**: Indices `0-19` (含全量复杂字段)。 + 2. **Put B**: Indices `20-39` (含全量复杂字段)。 + 3. **Update (Cross-Shard)**: Indices `10-29` (跨越分片边界)。 + - **Modify**: 修改 `nested_jagged` (变长), `non_tensor_stack` 等字段的值。 + - **Add**: 新增字段 `new_extra_tensor` 和 `new_extra_non_tensor`。 + 4. **Get Full**: Indices `0-39`。 +- **验证**: + - `0-9`: 保持 Put A 旧值。 + - `10-29` (Update区): 旧字段更新成功,**新字段存在且正确**。 + - `30-39`: 保持 Put B 旧值。 + +#### 场景三:生命周期状态管理 (Status Lifecycle) +- **重点**: 验证 **分字段多轮 Put** 对 `Production Status` 的影响。 +- **步骤**: + 1. **Round 1 Put**: Indices `0-9`, 仅写入 `Set_A` 字段。 + - Check Production(`Set_A`): **True**。 + - Check Production(`Set_B`): **False**。 + - Check Production(`Set_A` + `Set_B`): **False**。 + 2. **Round 2 Put**: Indices `0-9`, 补全 `Set_B` 字段。 + - Check Production(`Set_A` + `Set_B`): **True**。 + 3. **Consumption**: + - Check Consumption: **False**。 + - Get Data (`Set_A` + `Set_B`). + - Check Consumption: **True**。 + +#### 场景四:自定义元数据持久化 (Custom Metadata) +- **操作**: `put` 数据 -> `set_custom_meta` (上传 Sample-level dict) -> `get_meta`。 +- **验证**: 读取到的 `custom_meta` 与上传内容完全一致。 + +#### 场景五:重置与清理 (Reset & Clear) +- **Reset Consumption**: + - 消费后调用 `reset_consumption`。 + - 验证状态变回 `Not Consumed`。 + - 验证数据可再次 `get_meta` 获取。 +- **Clear Partition**: + - 调用 `clear_partition`。 + - 验证数据物理删除 (`get_meta` 返回空或 `check_production` 为 False)。 + +--- + +## 4. 执行指南 (Execution Guide) +1. **脚本位置**: `scripts/test_e2e_lifecycle_consistency.py` +2. **运行命令**: + ```bash + ./venv/bin/python -m pytest scripts/test_e2e_lifecycle_consistency.py -v + ``` +3. **依赖**: 确保 `pytest`, `pytest-asyncio` 已安装。 + +## 5. 注意事项 (Notes) +- 保持 `client.py` 的接口纯净性,如果发现 Client 功能不足以支持测试(如由 Sync/Async 接口缺失导致),应先在 Client 层补充对应公共接口,而非在测试脚本中 Hack 内部实现。 diff --git a/tests/e2e/test_e2e_consistency.py b/tests/e2e/test_e2e_consistency.py deleted file mode 100644 index 351049f..0000000 --- a/tests/e2e/test_e2e_consistency.py +++ /dev/null @@ -1,462 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import collections.abc -import hashlib -import sys -import time -from pathlib import Path -from typing import Any - -import numpy as np -import pytest -import ray -import torch -from tensordict import TensorDict - -# Setup paths -parent_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(parent_dir)) - -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, -) -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 - - -@pytest.fixture(scope="module") -def ray_cluster(): - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - yield - if ray.is_initialized(): - ray.shutdown() - - -@pytest.fixture(scope="module") -def e2e_client(ray_cluster): - controller_actor = TransferQueueController.options( - name="test_controller", - get_if_exists=True, - ).remote(polling_mode=True) - controller_info = ray.get(controller_actor.get_zmq_server_info.remote()) - - # Start two storage units to test sharding/scatter - storage_actor_1 = SimpleStorageUnit.options( - name="test_storage_1", - get_if_exists=True, - ).remote(storage_unit_size=10000) - storage_info_1 = ray.get(storage_actor_1.get_zmq_server_info.remote()) - - storage_actor_2 = SimpleStorageUnit.options( - name="test_storage_2", - get_if_exists=True, - ).remote(storage_unit_size=10000) - storage_info_2 = ray.get(storage_actor_2.get_zmq_server_info.remote()) - - client_id = "test_e2e_client" - client = TransferQueueClient( - client_id=client_id, - controller_info=controller_info, - ) - - # Initialize Storage Manager (AsyncSimpleStorageManager) and configure it - config = { - "controller_info": controller_info, - "storage_unit_infos": { - storage_info_1.id: storage_info_1, - storage_info_2.id: storage_info_2, - }, - "storage_backend_config": {"storage_unit_size": 10000}, - } - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - yield client - - -def generate_consistency_data(indices: list[int]) -> TensorDict: - """ - Generates a TensorDict with deterministic data based on the provided indices. - Includes all field types used in consistency tests. - """ - n = len(indices) - - # Standard Tensor: shape (n, 5), values based on index - tensor_data = torch.stack([torch.arange(i, i + 5, dtype=torch.float32) for i in indices]) - - # Nested Tensor (Jagged): - nested_list = [] - for i in indices: - if i % 2 == 0: - nested_list.append(torch.tensor([i, i * 2, i * 3], dtype=torch.float32)) - else: - nested_list.append(torch.tensor([i], dtype=torch.float32)) - nested_tensor = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) - - # Strided Nested Tensor (or fallback) - try: - tensors = [torch.full((3, 4), float(i)) for i in indices] - strided_nested = torch.nested.nested_tensor(tensors, layout=torch.strided) - except Exception: - # Fallback for environments without strided nested support - tensors = [torch.full((3, 4), float(i)) for i in indices] - strided_nested = torch.nested.as_nested_tensor(tensors, layout=torch.jagged) - - # List of Ints - int_values = [i * 10 for i in indices] - - # List of Strings - str_values = [f"str_{i}" for i in indices] - - # List of Numpy Arrays - numpy_list = [np.array([i, i + 1], dtype=np.int64) for i in indices] - - # Numpy Object Array (Strings) - numpy_objects = np.array([f"obj_{i}" for i in indices], dtype=object) - - # Special Values (Inf/NaN) - special_tensor = torch.zeros(n, 3) - special_tensor[:, 0] = float("inf") - special_tensor[:, 1] = float("nan") - special_tensor[:, 2] = torch.tensor(indices, dtype=torch.float32) - - # Bool Tensor - bool_tensor = torch.tensor([[i % 2 == 0] for i in indices], dtype=torch.bool) - - # Non-contiguous Tensor - # orig shape (n, 20), slice ::2 -> (n, 10). - large_tensor = torch.stack([torch.full((20,), float(i)) for i in indices]) - non_contiguous_tensor = large_tensor[:, ::2] - - return TensorDict( - { - "tensor_field": tensor_data, - "nested_field": nested_tensor, - "strided_nested_field": strided_nested, - "list_int_field": int_values, - "list_str_field": str_values, - "list_numpy_field": numpy_list, - "np_object_field": numpy_objects, - "special_field": special_tensor, - "bool_field": bool_tensor, - "non_orig_field": non_contiguous_tensor, - }, - batch_size=n, - ) - - -def compute_data_hash(data: Any, algorithm="sha256") -> str: - """ - Computes a structure-agnostic hash of the data. - - Flattens nested structures (lists, tuples, tensors). - - Ignores container types. - - Normalizes Tensors (detach, cpu). - - Consistent for NestedTensor vs List[Tensor]. - - Consistent for List[int] vs Tensor(int). - """ - hasher = hashlib.new(algorithm) - _update_hash(hasher, data) - return hasher.hexdigest() - - -def _update_hash(hasher, data): - if isinstance(data, dict | TensorDict) or (hasattr(data, "keys") and hasattr(data, "__getitem__")): - # Check if it behaves like a mapping - try: - keys = sorted(list(data.keys())) - for k in keys: - _hash_scalar(hasher, str(k)) - _update_hash(hasher, data[k]) - return - except TypeError: - pass - - if isinstance(data, torch.Tensor): - if data.is_nested: - for t in data.unbind(): - _update_hash(hasher, t) - return - - data = data.detach().cpu() - if data.ndim == 0: - _hash_scalar(hasher, data.item()) - return - - if isinstance(data, np.ndarray): - if data.ndim == 0: - _hash_scalar(hasher, data.item()) - elif data.dtype == object: - for item in data: - _update_hash(hasher, item) - else: - for item in data: - _update_hash(hasher, item) - return - - if isinstance(data, str | bytes): - _hash_scalar(hasher, data) - return - - # Handles list, tuple, and any custom Sequence - if isinstance(data, collections.abc.Sequence): - for item in data: - _update_hash(hasher, item) - return - - _hash_scalar(hasher, data) - - -def _hash_scalar(hasher, data): - s = str(data) - if s == "nan": - s = "NAN" - hasher.update(s.encode("utf-8")) - hasher.update(b"|") - - -def put_data_with_indices(client, partition_id, indices, data): - """ - Helper to put data with specific global indices. - """ - samples = [] - field_names = list(data.keys()) - - # Pre-compute field metas - for idx in indices: - fields_dict = {name: FieldMeta(name=name, dtype=None, shape=None) for name in field_names} - samples.append(SampleMeta(partition_id=partition_id, global_index=idx, fields=fields_dict)) - - meta = BatchMeta(samples=samples) - - future = asyncio.run_coroutine_threadsafe(client.storage_manager.put_data(data, meta), client._loop) - # Wait for result - future.result(timeout=30) - - -def get_fields_subset(indices, field_names): - """Helper to slice TensorDict by fields.""" - full_data = generate_consistency_data(indices) - # Select only specific fields - # Note: TensorDict.select returns a new TensorDict with shared storage - return full_data.select(*field_names) - - -def verify_data_consistency( - client, partition_id: str, task_name: str, data_fields: list[str], batch_size: int, mode: str = "fetch" -) -> TensorDict: - """Helper to retrieve data and verify it matches deterministic generation logic.""" - # Poll for metadata until ready - max_retries = 10 - meta = None - for _ in range(max_retries): - meta = client.get_meta( - partition_id=partition_id, - data_fields=data_fields, - batch_size=batch_size, - mode=mode, - task_name=task_name, - ) - if meta is not None and meta.size > 0: - break - time.sleep(0.5) - - assert meta is not None and meta.size > 0, f"Failed to retrieve metadata for {task_name}" - - retrieved_data = client.get_data(meta) - - # Generate partial expected data based on retrieved indices - full_expected = generate_consistency_data(meta.global_indexes) - expected_data = full_expected.select(*data_fields) - - # Verify Hash - retrieved_hash = compute_data_hash(retrieved_data) - expected_hash = compute_data_hash(expected_data) - assert retrieved_hash == expected_hash, f"Hash mismatch for {task_name}" - - return retrieved_data - - -@pytest.mark.timeout(60) -def test_consistency_core_types(e2e_client): - """ - Test Case 1: Core Data Types Coverage - - Tensor, NestedTensor, Non-Tensor (stackable/non-stackable) - """ - client = e2e_client - - # Use distinct partition to avoid conflict if tests run shared (though scope is module) - partition_id = "test_core_types" - - batch_size = 5 - - # Get fields list from dummy data - dummy_data = generate_consistency_data([0]) - fields = list(dummy_data.keys()) - - # 1. Allocate Partition & Indices - allocation_meta = client.get_meta( - partition_id=partition_id, data_fields=fields, batch_size=batch_size, mode="insert", task_name="allocator" - ) - indices = allocation_meta.global_indexes - assert len(indices) == batch_size - - # 2. Generate Data using allocated indices - data = generate_consistency_data(indices) - - put_data_with_indices(client, partition_id, indices, data) - - # 3. Get Data and Verify - verify_data_consistency(client, partition_id, "test_worker", list(data.keys()), batch_size) - - -@pytest.mark.timeout(120) -def test_consistency_multi_round_put_get(e2e_client): - """ - Test Case 2: Multi-round Put & Field Merge - Simulate fragmented writing and field stitching. - """ - client = e2e_client - partition_id = "test_multi_round" - - # Define Indices - idx_round1 = list(range(0, 20)) - idx_round2 = list(range(20, 41)) # 21 items - - # Step 1: Allocations - # Define Field Groups - # Group 1: Standard - group_std_fields = ["tensor_field", "list_str_field", "list_int_field", "special_field", "bool_field"] - # Group 2: Complex - group_complex_fields = [ - "nested_field", - "strided_nested_field", - "list_numpy_field", - "np_object_field", - "non_orig_field", - ] - - all_fields = group_std_fields + group_complex_fields - - # Allocation 1: Indices 0-19 - allocation_meta_1 = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator_1" - ) - assert len(allocation_meta_1.global_indexes) == 20 - - # Allocation 2: Indices 20-40 - allocation_meta_2 = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=21, mode="insert", task_name="allocator_2" - ) - assert len(allocation_meta_2.global_indexes) == 21 - - # Full list for reference - all_indices = idx_round1 + idx_round2 - - # --- Write Operations with Cross-Batch Logic --- - - # Op 1: Put Standard Group for Batch 1 (0-19) - data_1_std = get_fields_subset(idx_round1, group_std_fields) - put_data_with_indices(client, partition_id, idx_round1, data_1_std) - - # Op 2: Put Standard Group for Batch 2 (20-40) - data_2_std = get_fields_subset(idx_round2, group_std_fields) - put_data_with_indices(client, partition_id, idx_round2, data_2_std) - - # Op 3: Cross-batch Put for Complex Group (indices 15-25) - idx_cross = all_indices[15:26] # 15 to 25 inclusive - data_cross_complex = get_fields_subset(idx_cross, group_complex_fields) - put_data_with_indices(client, partition_id, idx_cross, data_cross_complex) - - # Op 4: Fill remaining Complex Group (0-14) - idx_remaining_1 = all_indices[0:15] - data_rem_1_complex = get_fields_subset(idx_remaining_1, group_complex_fields) - put_data_with_indices(client, partition_id, idx_remaining_1, data_rem_1_complex) - - # Op 5: Fill remaining Complex Group (26-40) - idx_remaining_2 = all_indices[26:] - data_rem_2_complex = get_fields_subset(idx_remaining_2, group_complex_fields) - put_data_with_indices(client, partition_id, idx_remaining_2, data_rem_2_complex) - - # Verifications - - # 1. Get 0-10 - verify_data_consistency(client, partition_id, "verifier_1", all_fields, 10) - - # 2. Get 10-30 (Cross boundary) - verify_data_consistency(client, partition_id, "verifier_2", all_fields, 20) - - # 3. Get 30-40 - verify_data_consistency(client, partition_id, "verifier_3", all_fields, 11) - - -@pytest.mark.timeout(60) -def test_consistency_slicing_and_subset(e2e_client): - """ - Test Case 3: Slicing and Field Subsetting - """ - client = e2e_client - partition_id = "test_slicing" - - # Pre-allocate Partition - all_fields = [ - "tensor_field", - "nested_field", - "strided_nested_field", - "list_int_field", - "list_str_field", - "list_numpy_field", - "np_object_field", - "special_field", - "bool_field", - "non_orig_field", - ] - allocation_meta = client.get_meta( - partition_id=partition_id, data_fields=all_fields, batch_size=20, mode="insert", task_name="allocator" - ) - indices = allocation_meta.global_indexes - assert len(indices) == 20 - - # Create Data using these indices - data = generate_consistency_data(indices) - - # Put Data - put_data_with_indices(client, partition_id, indices, data) - - # 1. Field Subset: Get only nested_field (Ragged) - result_subset = verify_data_consistency(client, partition_id, "inspector", ["nested_field"], 20, mode="force_fetch") - - assert "nested_field" in result_subset - assert "tensor_field" not in result_subset - assert len(result_subset["nested_field"]) == 20 - - # 2. Get np_object_field & special_field - result_mixed = verify_data_consistency( - client, partition_id, "inspector_2", ["np_object_field", "special_field"], 20, mode="force_fetch" - ) - - assert "np_object_field" in result_mixed - assert "special_field" in result_mixed - - # 3. Get strided_nested_field explicitly - verify_data_consistency(client, partition_id, "inspector_3", ["strided_nested_field"], 20, mode="force_fetch") - - -if __name__ == "__main__": - sys.exit(pytest.main(["-v", __file__])) diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py new file mode 100644 index 0000000..203267b --- /dev/null +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -0,0 +1,632 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +E2E Lifecycle Consistency Tests for TransferQueue. + +Implements all 5 scenarios from consistency_validation_plan.md: +- Scenario One: Core Read/Write Consistency +- Scenario Two: Cross-Partition & Complex Update +- Scenario Three: Production Status Lifecycle +- Scenario Four: Custom Metadata Persistence +- Scenario Five: Reset & Clear +""" + +import sys +import time +from pathlib import Path + +import numpy as np +import pytest +import ray +import torch +from tensordict import TensorDict +from tensordict.tensorclass import NonTensorData + +# Setup paths +parent_dir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(parent_dir)) + +from transfer_queue import ( # noqa: E402 + SimpleStorageUnit, + TransferQueueClient, + TransferQueueController, +) + +# Module-level default fields to avoid repeated generation +DEFAULT_FIELDS = [ + "tensor_f32", + "tensor_i64", + "nested_jagged", + "nested_strided", + "list_int", + "list_str", + "np_array", + "np_obj", + "special_val", + "non_tensor_stack", +] + + +@pytest.fixture(scope="module") +def ray_cluster(): + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def e2e_client(ray_cluster): + """Create a client with 2 storage units for lifecycle testing.""" + controller_actor = TransferQueueController.options( + name="lifecycle_controller", + get_if_exists=True, + ).remote(polling_mode=True) + controller_info = ray.get(controller_actor.get_zmq_server_info.remote()) + + # Two storage units to ensure sharding + storage_actor_1 = SimpleStorageUnit.options( + name="lifecycle_storage_1", + get_if_exists=True, + ).remote(storage_unit_size=10000) + storage_info_1 = ray.get(storage_actor_1.get_zmq_server_info.remote()) + + storage_actor_2 = SimpleStorageUnit.options( + name="lifecycle_storage_2", + get_if_exists=True, + ).remote(storage_unit_size=10000) + storage_info_2 = ray.get(storage_actor_2.get_zmq_server_info.remote()) + + client = TransferQueueClient( + client_id="lifecycle_test_client", + controller_info=controller_info, + ) + + config = { + "controller_info": controller_info, + "storage_unit_infos": { + storage_info_1.id: storage_info_1, + storage_info_2.id: storage_info_2, + }, + "storage_backend_config": {"storage_unit_size": 10000}, + } + + client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) + + yield client + + +def generate_complex_data(indices: list[int]) -> TensorDict: + """ + Generates complex TensorDict with all required field types. + Based on consistency_validation_plan.md Section 3.2. + """ + n = len(indices) + + # Standard Tensor (Float32) + tensor_f32 = torch.stack([torch.arange(i, i + 5, dtype=torch.float32) for i in indices]) + + # Standard Tensor (Int64) + tensor_i64 = torch.stack([torch.arange(i, i + 5, dtype=torch.int64) for i in indices]) + + # Nested Tensor (Jagged) + nested_list = [] + for i in indices: + length = 2 + (i % 4) # Variable length: 2-5 + nested_list.append(torch.arange(i, i + length, dtype=torch.float32)) + nested_jagged = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) + + # Nested Tensor (Strided) - fallback to jagged if not supported + try: + strided_tensors = [torch.full((3, 4), float(i)) for i in indices] + nested_strided = torch.nested.nested_tensor(strided_tensors, layout=torch.strided) + except (TypeError, RuntimeError): + strided_tensors = [torch.full((3, 4), float(i)) for i in indices] + nested_strided = torch.nested.as_nested_tensor(strided_tensors, layout=torch.jagged) + + # Python Lists + list_int = [i * 10 for i in indices] + list_str = [f"sample_{i}" for i in indices] + + # NumPy Arrays + np_array = np.array([np.arange(i, i + 3) for i in indices], dtype=np.float64) + np_obj = np.array([f"obj_{i}" for i in indices], dtype=object) + + # Special Values (NaN and Inf) + special_val = torch.zeros(n, 3) + special_val[:, 0] = float("inf") + special_val[:, 1] = float("nan") + special_val[:, 2] = torch.tensor(indices, dtype=torch.float32) + + # NonTensorData + non_tensor_data = [{"idx": i, "text": f"non_tensor_{i}"} for i in indices] + non_tensor_stack = NonTensorData(data=non_tensor_data, batch_size=(n,), device=None) + + return TensorDict( + { + "tensor_f32": tensor_f32, + "tensor_i64": tensor_i64, + "nested_jagged": nested_jagged, + "nested_strided": nested_strided, + "list_int": list_int, + "list_str": list_str, + "np_array": np_array, + "np_obj": np_obj, + "special_val": special_val, + "non_tensor_stack": non_tensor_stack, + }, + batch_size=n, + ) + + +def poll_for_meta(client, partition_id, data_fields, batch_size, task_name, mode="fetch", max_retries=10): + """Poll until metadata is ready or max retries reached.""" + for _ in range(max_retries): + meta = client.get_meta( + partition_id=partition_id, + data_fields=data_fields, + batch_size=batch_size, + mode=mode, + task_name=task_name, + ) + if meta is not None and meta.size > 0: + return meta + time.sleep(0.3) + return None + + +# ============================================================================= +# Helper Functions for Data Verification +# ============================================================================= +def verify_special_values(retrieved: torch.Tensor, expected: torch.Tensor) -> bool: + """Verify special values (NaN, Inf) are preserved.""" + # Check Inf column + if not torch.all(torch.isinf(retrieved[:, 0]) & (retrieved[:, 0] > 0)): + return False + # Check NaN column + if not torch.all(torch.isnan(retrieved[:, 1])): + return False + # Check regular values column + if not torch.allclose(retrieved[:, 2], expected[:, 2]): + return False + return True + + +def verify_nested_tensor_equal(retrieved, expected) -> bool: + """Verify nested tensors element by element.""" + if len(retrieved.unbind()) != len(expected.unbind()): + return False + for r, e in zip(retrieved.unbind(), expected.unbind(), strict=False): + if not torch.allclose(r, e): + return False + return True + + +def verify_non_tensor_data(retrieved, expected) -> bool: + """Verify NonTensorData content.""" + if hasattr(retrieved, "data"): + retrieved = retrieved.data + if hasattr(expected, "data"): + expected = expected.data + return retrieved == expected + + +def verify_list_equal(retrieved, expected) -> bool: + """Verify list content, handling possible Tensor conversion.""" + # Convert Tensor to list if needed + if isinstance(retrieved, torch.Tensor): + retrieved = retrieved.tolist() + if isinstance(expected, torch.Tensor): + expected = expected.tolist() + return retrieved == expected + + +# ============================================================================= +# Scenario One: Core Read/Write Consistency +# ============================================================================= +def test_core_consistency(e2e_client): + """ + Test Case: Core Read/Write Consistency (Scenario 1) + + Validates: + 1. Put full complex data -> Get retrieves identical data + 2. NaN remains NaN, Inf remains Inf + 3. NonTensorData unpacks without loss + 4. All field types are correctly round-tripped + """ + client = e2e_client + partition_id = "test_core_consistency" + batch_size = 20 + task_name = "core_consistency_task" + + # 1. Put full complex data + indices = list(range(batch_size)) + original_data = generate_complex_data(indices) + fields = DEFAULT_FIELDS + + meta = client.put(data=original_data, partition_id=partition_id) + assert meta.size == batch_size, f"Expected batch_size {batch_size}, got {meta.size}" + try: + # 2. Get data + retrieved_meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="fetch") + assert retrieved_meta is not None and retrieved_meta.size == batch_size, "Failed to retrieve metadata" + retrieved_data = client.get_data(retrieved_meta) + + # 3. Verify Standard Tensors + assert torch.allclose(retrieved_data["tensor_f32"], original_data["tensor_f32"]), "tensor_f32 mismatch" + assert torch.equal(retrieved_data["tensor_i64"], original_data["tensor_i64"]), "tensor_i64 mismatch" + + # 4. Verify Nested Tensors (Jagged) + assert verify_nested_tensor_equal(retrieved_data["nested_jagged"], original_data["nested_jagged"]), ( + "Jagged nested tensor mismatch" + ) + + # 5. Verify Nested Tensors (Strided) + assert verify_nested_tensor_equal(retrieved_data["nested_strided"], original_data["nested_strided"]), ( + "Strided nested tensor mismatch" + ) + + # 6. Verify Python Lists + assert verify_list_equal(retrieved_data["list_int"], original_data["list_int"]), "list_int mismatch" + assert verify_list_equal(retrieved_data["list_str"], original_data["list_str"]), "list_str mismatch" + + # 7. Verify NumPy Arrays + assert np.allclose(retrieved_data["np_array"], original_data["np_array"]), "np_array mismatch" + assert np.array_equal(retrieved_data["np_obj"], original_data["np_obj"]), "np_obj mismatch" + + # 8. Verify Special Values (NaN and Inf) + assert verify_special_values(retrieved_data["special_val"], original_data["special_val"]), ( + "Special values (NaN/Inf) not preserved" + ) + + # 9. Verify NonTensorData + assert verify_non_tensor_data(retrieved_data["non_tensor_stack"], original_data["non_tensor_stack"]), ( + "NonTensorData content mismatch" + ) + finally: + client.clear_partition(partition_id) + + +# ============================================================================= +# Scenario Two: Cross-Partition & Complex Update +# ============================================================================= +def test_cross_partition_complex_update(e2e_client): + """ + Test Case: Cross-Partition & Complex Update (Scenario 2) + + Validates: + 1. Put A (indices 0-19) with full complex fields + 2. Put B (indices 20-39) with full complex fields + 3. Update indices 10-29 (cross-shard): modify existing fields + add new fields + 4. Get Full (0-39) and verify: + - 0-9: original Put A values + - 10-29: updated values with new fields + - 30-39: original Put B values + """ + client = e2e_client + partition_id = "test_cross_partition_update" + task_name = "cross_partition_task" + + # Define index ranges + idx_a = list(range(0, 20)) # Put A + idx_b = list(range(20, 40)) # Put B + idx_update = list(range(10, 30)) # Update (cross-shard) + base_fields = DEFAULT_FIELDS + + # 1. Allocate full partition + alloc_meta = client.get_meta( + partition_id=partition_id, + data_fields=base_fields, + batch_size=40, + mode="insert", + task_name="allocator", + ) + assert len(alloc_meta.global_indexes) == 40, "Failed to allocate 40 samples" + + try: + # 2. Put A: indices 0-19 + data_a = generate_complex_data(idx_a) + meta_a = alloc_meta.select_samples(list(range(0, 20))) + client.put(data=data_a, metadata=meta_a) + + # 3. Put B: indices 20-39 + data_b = generate_complex_data(idx_b) + meta_b = alloc_meta.select_samples(list(range(20, 40))) + client.put(data=data_b, metadata=meta_b) + + # 4. Update indices 10-29 with modified values and new fields + modified_indices = [i + 1000 for i in idx_update] # Offset to make values distinguishable + data_update = generate_complex_data(modified_indices) + + # Add new fields + new_extra_tensor = torch.stack([torch.ones(3) * i for i in idx_update]) # Shape: (20, 3) + new_extra_non_tensor = NonTensorData( + data=[{"new_field": f"new_{i}"} for i in idx_update], + batch_size=(len(idx_update),), + device=None, + ) + data_update["new_extra_tensor"] = new_extra_tensor + data_update["new_extra_non_tensor"] = new_extra_non_tensor + + # Put update data + meta_update = alloc_meta.select_samples(list(range(10, 30))) + client.put(data=data_update, metadata=meta_update) + + # 5. Get Full: indices 0-39, only base fields first + full_meta = poll_for_meta(client, partition_id, base_fields, 40, task_name, mode="force_fetch") + assert full_meta is not None and full_meta.size == 40, "Failed to retrieve full metadata" + full_data = client.get_data(full_meta) + + # 6. Verify region 0-9: original Put A values + original_data_0_9 = generate_complex_data(list(range(0, 10))) + assert torch.allclose(full_data["tensor_f32"][:10], original_data_0_9["tensor_f32"]), ( + "Region 0-9 tensor_f32 should match original Put A" + ) + + # 7. Verify region 10-29: updated values (using offset indices 1010-1029) + updated_expected = generate_complex_data([i + 1000 for i in range(10, 30)]) + assert torch.allclose(full_data["tensor_f32"][10:30], updated_expected["tensor_f32"]), ( + "Region 10-29 tensor_f32 should match updated values" + ) + + # 8. Verify region 30-39: original Put B values + original_data_30_39 = generate_complex_data(list(range(30, 40))) + assert torch.allclose(full_data["tensor_f32"][30:40], original_data_30_39["tensor_f32"]), ( + "Region 30-39 tensor_f32 should match original Put B" + ) + + # 9. Verify new fields exist in update region + extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"] + update_region_meta = poll_for_meta( + client, partition_id, extended_fields, 20, "update_region_task", mode="force_fetch" + ) + if update_region_meta is not None and update_region_meta.size > 0: + update_region_data = client.get_data(update_region_meta) + assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist" + assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist" + finally: + client.clear_partition(partition_id) + + +# ============================================================================= +# Scenario Three: Production Status Lifecycle +# ============================================================================= +def test_production_status_lifecycle(e2e_client): + """ + Test Case: Production Status Lifecycle (Scenario 3) + + Validates multi-round partial field put and production status transitions. + + Steps: + 1. Round 1 Put: Indices 0-9, only Set_A fields -> Check production(Set_A)=True, production(Set_B)=False + 2. Round 2 Put: Indices 0-9, complete Set_B fields -> Check production(Set_A+Set_B)=True + 3. Verify consumption status transitions + """ + client = e2e_client + partition_id = "test_production_lifecycle" + batch_size = 10 + task_name = "production_lifecycle_task" + + # Define field sets + set_a_fields = ["tensor_f32", "tensor_i64", "list_int", "list_str"] + set_b_fields = ["nested_jagged", "np_array", "special_val"] + all_fields = set_a_fields + set_b_fields + + # 1. Allocate partition with all fields + alloc_meta = client.get_meta( + partition_id=partition_id, + data_fields=all_fields, + batch_size=batch_size, + mode="insert", + task_name="allocator", + ) + indices = alloc_meta.global_indexes + assert len(indices) == batch_size, f"Expected {batch_size} indices, got {len(indices)}" + + try: + # 2. Round 1: Put only Set_A fields + full_data = generate_complex_data(indices) + set_a_data = full_data.select(*set_a_fields) + client.put(data=set_a_data, metadata=alloc_meta) + + # 3. Check Production Status - Set_A should be ready, Set_B should not + set_a_ready = client.check_production_status(data_fields=set_a_fields, partition_id=partition_id) + assert set_a_ready, "Set_A fields should be ready after Round 1" + + set_b_ready = client.check_production_status(data_fields=set_b_fields, partition_id=partition_id) + assert not set_b_ready, "Set_B fields should NOT be ready after Round 1" + + all_ready_before = client.check_production_status(data_fields=all_fields, partition_id=partition_id) + assert not all_ready_before, "All fields should NOT be ready before Round 2" + + # 4. Round 2: Put Set_B fields + set_b_data = full_data.select(*set_b_fields) + client.put(data=set_b_data, metadata=alloc_meta) + + # 5. Check Production Status - All should be ready + all_ready_after = client.check_production_status(data_fields=all_fields, partition_id=partition_id) + assert all_ready_after, "All fields should be ready after Round 2" + + # 6. Consumption Status Check - should be False initially + is_consumed = client.check_consumption_status(task_name=task_name, partition_id=partition_id) + assert not is_consumed, "Data should not be consumed initially" + + # 7. Consume Data + meta = poll_for_meta(client, partition_id, all_fields, batch_size, task_name, mode="fetch") + assert meta is not None, "Failed to poll metadata" + client.get_data(meta) + + # 8. Post-Consumption Check - should be True + is_consumed_after = client.check_consumption_status(task_name=task_name, partition_id=partition_id) + assert is_consumed_after, "Data should be consumed after get_data" + finally: + client.clear_partition(partition_id) + + +# ============================================================================= +# Scenario Four: Custom Metadata Persistence +# ============================================================================= +def test_custom_metadata_persistence(e2e_client): + """ + + Test Case: Custom Metadata Persistence (Scenario 4) + + Validates: + 1. put data -> set_custom_meta -> get_meta retrieves correct custom_meta + 2. Custom metadata is per-sample and survives roundtrip + """ + client = e2e_client + partition_id = "test_custom_meta" + batch_size = 8 + task_name = "custom_meta_task" + fields = DEFAULT_FIELDS + + # 1. Allocate and Put Data + meta = client.put( + data=generate_complex_data(list(range(batch_size))), + partition_id=partition_id, + ) + assert meta.size == batch_size, f"Expected batch_size {batch_size}, got {meta.size}" + + try: + # 2. Create Custom Metadata for each sample + custom_metadata = {} + for i in range(batch_size): + custom_metadata[meta.global_indexes[i]] = { + "score": float(i) / 10.0, + "label": f"label_{i}", + "tags": [f"tag_{i}_a", f"tag_{i}_b"], + } + meta.update_custom_meta(custom_metadata) + + # 3. Upload Custom Metadata + client.set_custom_meta(meta) + + # 4. Retrieve Metadata and Verify Custom Meta + retrieved_meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") + assert retrieved_meta is not None, "Failed to retrieve metadata" + + # Verify custom metadata content + retrieved_custom = retrieved_meta.get_all_custom_meta() + for global_idx, expected_meta in custom_metadata.items(): + assert global_idx in retrieved_custom, f"Missing custom_meta for index {global_idx}" + actual = retrieved_custom[global_idx] + assert actual["score"] == expected_meta["score"], f"Score mismatch at index {global_idx}" + assert actual["label"] == expected_meta["label"], f"Label mismatch at index {global_idx}" + assert actual["tags"] == expected_meta["tags"], f"Tags mismatch at index {global_idx}" + finally: + client.clear_partition(partition_id) + + +# ============================================================================= +# Scenario Five: Reset & Clear +# ============================================================================= +def test_reset_consumption(e2e_client): + """ + Test Case: Reset Consumption Status (Scenario 5a) + + Validates: + 1. After consuming data, consumption status is True + 2. After reset_consumption, status reverts to False + 3. Data can be re-consumed after reset + """ + client = e2e_client + partition_id = "test_reset_consumption" + batch_size = 10 + task_name = "reset_test_task" + fields = DEFAULT_FIELDS + + # 1. Put Data + data = generate_complex_data(list(range(batch_size))) + client.put(data=data, partition_id=partition_id) + + try: + # 2. Initial Consumption Status Check - should be False (not consumed) + is_consumed_initial = client.check_consumption_status(task_name=task_name, partition_id=partition_id) + assert not is_consumed_initial, "Data should not be consumed initially" + + # 3. Consume Data (get_meta + get_data) + meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="fetch") + assert meta is not None and meta.size == batch_size, "Failed to poll metadata" + retrieved_data = client.get_data(meta) + assert retrieved_data.batch_size[0] == batch_size, "Retrieved data batch_size mismatch" + + # 4. Post-Consumption Status Check - should be True + is_consumed_after = client.check_consumption_status(task_name=task_name, partition_id=partition_id) + assert is_consumed_after, "Data should be consumed after get_data" + + # 5. Reset Consumption + success = client.reset_consumption(partition_id=partition_id, task_name=task_name) + assert success, "reset_consumption should return True" + + # 6. Post-Reset Status Check - should be False again + is_consumed_reset = client.check_consumption_status(task_name=task_name, partition_id=partition_id) + assert not is_consumed_reset, "Consumption status should be False after reset" + + # 7. Verify data can be re-consumed + meta_again = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="fetch") + assert meta_again is not None and meta_again.size == batch_size, "Should be able to fetch metadata again" + finally: + client.clear_partition(partition_id) + + +def test_clear_partition(e2e_client): + """ + Test Case: Clear Partition (Scenario 5b) + + Validates: + 1. Put data -> data is accessible + 2. clear_partition -> data is physically deleted + 3. After clear, check_production_status returns False + 4. After clear, partition is removed from partition list + """ + client = e2e_client + partition_id = "test_clear_partition" + batch_size = 15 + task_name = "clear_test_task" + fields = DEFAULT_FIELDS + + # 1. Put Data + data = generate_complex_data(list(range(batch_size))) + client.put(data=data, partition_id=partition_id) + + # 2. Verify Data Exists - production status should be True + is_ready = client.check_production_status(data_fields=fields, partition_id=partition_id) + assert is_ready, "Data should be ready after put" + + # 3. Get Data to confirm accessibility + meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") + assert meta is not None and meta.size == batch_size, "Failed to poll metadata" + + # 4. Verify partition exists before clear + partition_list_before = client.get_partition_list() + assert partition_id in partition_list_before, "Partition should exist before clear" + + # 5. Clear Partition + client.clear_partition(partition_id) + + # 6. Verify partition is removed from list + partition_list_after = client.get_partition_list() + assert partition_id not in partition_list_after, "Partition should be removed after clear" + + # 7. Verify Production Status returns False for cleared partition + is_ready_after_clear = client.check_production_status(data_fields=fields, partition_id=partition_id) + assert not is_ready_after_clear, "Production status should be False after clear" + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/tests/e2e/test_yuanrong_storage_client_e2e.py b/tests/test_yuanrong_storage_client_e2e.py similarity index 100% rename from tests/e2e/test_yuanrong_storage_client_e2e.py rename to tests/test_yuanrong_storage_client_e2e.py From b999ae428376ecca86e0c951ad9de8abc26fa547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Thu, 5 Feb 2026 19:25:26 +0800 Subject: [PATCH 7/9] remove file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 看我72遍 --- consistency_validation_plan.md | 98 ---------------------------------- 1 file changed, 98 deletions(-) delete mode 100644 consistency_validation_plan.md diff --git a/consistency_validation_plan.md b/consistency_validation_plan.md deleted file mode 100644 index 38f7366..0000000 --- a/consistency_validation_plan.md +++ /dev/null @@ -1,98 +0,0 @@ -# TransferQueue 端到端数据一致性校验实施计划 (Consistency Validation Master Plan) - -## 1. 背景与目标 (Background & Objective) -本项目旨在开发一个端到端的自动化测试脚本 (`scripts/test_e2e_lifecycle_consistency.py`),用于验证 `TransferQueue` 在复杂场景下的数据一致性和生命周期管理能力。 -核心目标是确保在数据经过 **分片存储 (Sharding)**、**跨节点传输**、**分批写入 (Multi-round Put)**、**动态更新 (Update/Overwrite)** 等操作后,数据的完整性和状态机(Production/Consumption Status)的正确性。 - -## 2. 核心原则 (Core Principles) -1. **Public API Only**: 测试代码仅允许使用 `TransferQueueClient` 的公共接口(如 `put`, `get_meta`, `get_data`, `check_production_status` 等),**严禁**调用 `StorageUnit` 或 `Controller` 的内部私有方法。 -2. **Complex Data Types**: 所有数据传输场景必须覆盖全量复杂数据类型(见下文)。 -3. **Environment**: 必须模拟真实分布式环境,启动 **2个及以上 Storage Units** 以强制触发 Manager 的自动分片逻辑。 - ---- - -## 3. 详细实施方案 (Implementation Details) - -### 3.1 测试环境配置 -参考 `scripts/performance_test.py` 的初始化逻辑: -- **Ray Cluster**: 使用 `pytest` fixture 启动。 -- **Components**: - - **Controller**: 1个 (`polling_mode=True`)。 - - **Storage Units**: **2个** (Capacity=10000),确保数据会分布在不同 Unit 上。 - - **Client**: 初始化一般 Client,配置 `AsyncSimpleStorageManager`。 - -### 3.2 通用复杂数据生成器 (Universal Data Generator) -实现 `generate_complex_data(indices, fields_subset=None)`,生成 `TensorDict`,必须包含: - -| 类型 (Type) | 字段 (Field) | 特征 (Characteristics) | -| :--- | :--- | :--- | -| **Standard Tensor** | `tensor_f32`, `tensor_i64` | Float32/Int64, 标准形状 | -| **Nested Tensor** | `nested_jagged` | `layout=torch.jagged`, 变长样本 | -| | `nested_strided` | `layout=torch.strided` (若支持) | -| **Lists** | `list_int`, `list_str` | Python 原生列表 | -| **NumPy** | `np_array`, `np_obj` | 标准 Array 及 Object Array (混合类型) | -| **Special Values** | `special_val` | 包含 **NaN** 和 **Inf** (验证传输稳定性) | -| **Non-Tensor** | **`non_tensor_stack`** | 使用 `tensordict.tensorclass.NonTensorData` 封装 | - -### 3.3 验证场景 (Verification Scenarios) - -#### 场景一:核心读写一致性 (Core Consistency) -- **操作**: Put 写入上述全量复杂数据 -> Get 读取。 -- **验证**: - - 输入输出的 Hash 值完全一致 (使用结构无关 Hash)。 - - `NaN` 保持为 `NaN`,`Inf` 保持为 `Inf`。 - - `NonTensorData` 解包后内容无损。 - -#### 场景二:跨分片操作与复杂更新 (Cross-Partition & Complex Update) -- **配置**: 依赖 Manager 自动将不同 Indices 分片到 2 个 Storage Units。 -- **步骤**: - 1. **Put A**: Indices `0-19` (含全量复杂字段)。 - 2. **Put B**: Indices `20-39` (含全量复杂字段)。 - 3. **Update (Cross-Shard)**: Indices `10-29` (跨越分片边界)。 - - **Modify**: 修改 `nested_jagged` (变长), `non_tensor_stack` 等字段的值。 - - **Add**: 新增字段 `new_extra_tensor` 和 `new_extra_non_tensor`。 - 4. **Get Full**: Indices `0-39`。 -- **验证**: - - `0-9`: 保持 Put A 旧值。 - - `10-29` (Update区): 旧字段更新成功,**新字段存在且正确**。 - - `30-39`: 保持 Put B 旧值。 - -#### 场景三:生命周期状态管理 (Status Lifecycle) -- **重点**: 验证 **分字段多轮 Put** 对 `Production Status` 的影响。 -- **步骤**: - 1. **Round 1 Put**: Indices `0-9`, 仅写入 `Set_A` 字段。 - - Check Production(`Set_A`): **True**。 - - Check Production(`Set_B`): **False**。 - - Check Production(`Set_A` + `Set_B`): **False**。 - 2. **Round 2 Put**: Indices `0-9`, 补全 `Set_B` 字段。 - - Check Production(`Set_A` + `Set_B`): **True**。 - 3. **Consumption**: - - Check Consumption: **False**。 - - Get Data (`Set_A` + `Set_B`). - - Check Consumption: **True**。 - -#### 场景四:自定义元数据持久化 (Custom Metadata) -- **操作**: `put` 数据 -> `set_custom_meta` (上传 Sample-level dict) -> `get_meta`。 -- **验证**: 读取到的 `custom_meta` 与上传内容完全一致。 - -#### 场景五:重置与清理 (Reset & Clear) -- **Reset Consumption**: - - 消费后调用 `reset_consumption`。 - - 验证状态变回 `Not Consumed`。 - - 验证数据可再次 `get_meta` 获取。 -- **Clear Partition**: - - 调用 `clear_partition`。 - - 验证数据物理删除 (`get_meta` 返回空或 `check_production` 为 False)。 - ---- - -## 4. 执行指南 (Execution Guide) -1. **脚本位置**: `scripts/test_e2e_lifecycle_consistency.py` -2. **运行命令**: - ```bash - ./venv/bin/python -m pytest scripts/test_e2e_lifecycle_consistency.py -v - ``` -3. **依赖**: 确保 `pytest`, `pytest-asyncio` 已安装。 - -## 5. 注意事项 (Notes) -- 保持 `client.py` 的接口纯净性,如果发现 Client 功能不足以支持测试(如由 Sync/Async 接口缺失导致),应先在 Client 层补充对应公共接口,而非在测试脚本中 Hack 内部实现。 From c447faf6e1e4437f98f4a26a7aa4a66eb9a98229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Mon, 16 Feb 2026 17:05:27 +0800 Subject: [PATCH 8/9] refactor(e2e): validate field ordering against DEFAULT_FIELDS and cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add runtime assertion in generate_complex_data to ensure field_values keys exactly match DEFAULT_FIELDS, preventing silent field mismatches - Build TensorDict via dict comprehension keyed by DEFAULT_FIELDS order - Extract inline reorder logic into reusable _reorder_tensordict helper - Remove redundant section separator comments (=== banners) - Add missing assertions for tensor_bf16 and list_obj in core consistency - Rename test_cross_partition_complex_update -> test_cross_shard_complex_update - Improve verify_list_equal docstring with TensorDict conversion note Signed-off-by: 看我72遍 --- tests/e2e/test_e2e_lifecycle_consistency.py | 359 ++++++++++---------- 1 file changed, 175 insertions(+), 184 deletions(-) diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 203267b..c82a43d 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -13,16 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -E2E Lifecycle Consistency Tests for TransferQueue. - -Implements all 5 scenarios from consistency_validation_plan.md: -- Scenario One: Core Read/Write Consistency -- Scenario Two: Cross-Partition & Complex Update -- Scenario Three: Production Status Lifecycle -- Scenario Four: Custom Metadata Persistence -- Scenario Five: Reset & Clear -""" +"""E2E lifecycle consistency tests for TransferQueue.""" import sys import time @@ -35,24 +26,20 @@ from tensordict import TensorDict from tensordict.tensorclass import NonTensorData -# Setup paths +# Setup paths (transfer_queue is not pip-installed) parent_dir = Path(__file__).resolve().parent.parent.parent sys.path.append(str(parent_dir)) -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, -) - # Module-level default fields to avoid repeated generation DEFAULT_FIELDS = [ "tensor_f32", "tensor_i64", + "tensor_bf16", "nested_jagged", "nested_strided", "list_int", "list_str", + "list_obj", "np_array", "np_obj", "special_val", @@ -71,50 +58,31 @@ def ray_cluster(): @pytest.fixture(scope="module") def e2e_client(ray_cluster): - """Create a client with 2 storage units for lifecycle testing.""" - controller_actor = TransferQueueController.options( - name="lifecycle_controller", - get_if_exists=True, - ).remote(polling_mode=True) - controller_info = ray.get(controller_actor.get_zmq_server_info.remote()) - - # Two storage units to ensure sharding - storage_actor_1 = SimpleStorageUnit.options( - name="lifecycle_storage_1", - get_if_exists=True, - ).remote(storage_unit_size=10000) - storage_info_1 = ray.get(storage_actor_1.get_zmq_server_info.remote()) - - storage_actor_2 = SimpleStorageUnit.options( - name="lifecycle_storage_2", - get_if_exists=True, - ).remote(storage_unit_size=10000) - storage_info_2 = ray.get(storage_actor_2.get_zmq_server_info.remote()) - - client = TransferQueueClient( - client_id="lifecycle_test_client", - controller_info=controller_info, - ) + """Create a client using transfer_queue.init() for lifecycle testing.""" + from omegaconf import OmegaConf + + import transfer_queue config = { - "controller_info": controller_info, - "storage_unit_infos": { - storage_info_1.id: storage_info_1, - storage_info_2.id: storage_info_2, + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, }, - "storage_backend_config": {"storage_unit_size": 10000}, } - - client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - + transfer_queue.init(OmegaConf.create(config)) + client = transfer_queue.get_client() yield client + transfer_queue.close() def generate_complex_data(indices: list[int]) -> TensorDict: - """ - Generates complex TensorDict with all required field types. - Based on consistency_validation_plan.md Section 3.2. - """ + """Generate complex TensorDict with all supported field types.""" n = len(indices) # Standard Tensor (Float32) @@ -128,15 +96,15 @@ def generate_complex_data(indices: list[int]) -> TensorDict: for i in indices: length = 2 + (i % 4) # Variable length: 2-5 nested_list.append(torch.arange(i, i + length, dtype=torch.float32)) + # Inject special values into jagged tensor components + nested_list[0][0] = float("inf") + if len(nested_list) > 1: + nested_list[1][0] = float("nan") nested_jagged = torch.nested.as_nested_tensor(nested_list, layout=torch.jagged) - # Nested Tensor (Strided) - fallback to jagged if not supported - try: - strided_tensors = [torch.full((3, 4), float(i)) for i in indices] - nested_strided = torch.nested.nested_tensor(strided_tensors, layout=torch.strided) - except (TypeError, RuntimeError): - strided_tensors = [torch.full((3, 4), float(i)) for i in indices] - nested_strided = torch.nested.as_nested_tensor(strided_tensors, layout=torch.jagged) + # Nested Tensor (Strided) + strided_tensors = [torch.full((3, 4), float(i)) for i in indices] + nested_strided = torch.nested.nested_tensor(strided_tensors, layout=torch.strided) # Python Lists list_int = [i * 10 for i in indices] @@ -156,19 +124,36 @@ def generate_complex_data(indices: list[int]) -> TensorDict: non_tensor_data = [{"idx": i, "text": f"non_tensor_{i}"} for i in indices] non_tensor_stack = NonTensorData(data=non_tensor_data, batch_size=(n,), device=None) + # BFloat16 Tensor + tensor_bf16 = torch.stack([torch.arange(i, i + 5, dtype=torch.bfloat16) for i in indices]) + + # List of objects (dicts) + list_obj = [{"key": f"value_{i}", "num": i} for i in indices] + + field_values = { + "tensor_f32": tensor_f32, + "tensor_i64": tensor_i64, + "tensor_bf16": tensor_bf16, + "nested_jagged": nested_jagged, + "nested_strided": nested_strided, + "list_int": list_int, + "list_str": list_str, + "list_obj": list_obj, + "np_array": np_array, + "np_obj": np_obj, + "special_val": special_val, + "non_tensor_stack": non_tensor_stack, + } + + # Validate: field_values must exactly match DEFAULT_FIELDS + assert set(field_values.keys()) == set(DEFAULT_FIELDS), ( + f"generate_complex_data fields mismatch with DEFAULT_FIELDS: " + f"extra={set(field_values.keys()) - set(DEFAULT_FIELDS)}, " + f"missing={set(DEFAULT_FIELDS) - set(field_values.keys())}" + ) + return TensorDict( - { - "tensor_f32": tensor_f32, - "tensor_i64": tensor_i64, - "nested_jagged": nested_jagged, - "nested_strided": nested_strided, - "list_int": list_int, - "list_str": list_str, - "np_array": np_array, - "np_obj": np_obj, - "special_val": special_val, - "non_tensor_stack": non_tensor_stack, - }, + {field: field_values[field] for field in DEFAULT_FIELDS}, batch_size=n, ) @@ -189,9 +174,7 @@ def poll_for_meta(client, partition_id, data_fields, batch_size, task_name, mode return None -# ============================================================================= # Helper Functions for Data Verification -# ============================================================================= def verify_special_values(retrieved: torch.Tensor, expected: torch.Tensor) -> bool: """Verify special values (NaN, Inf) are preserved.""" # Check Inf column @@ -207,27 +190,47 @@ def verify_special_values(retrieved: torch.Tensor, expected: torch.Tensor) -> bo def verify_nested_tensor_equal(retrieved, expected) -> bool: - """Verify nested tensors element by element.""" - if len(retrieved.unbind()) != len(expected.unbind()): + """Verify nested tensors element by element, handling NaN/Inf.""" + r_list = retrieved.unbind() + e_list = expected.unbind() + if len(r_list) != len(e_list): return False - for r, e in zip(retrieved.unbind(), expected.unbind(), strict=False): - if not torch.allclose(r, e): + for r, e in zip(r_list, e_list, strict=True): + # Handle NaN: positions must match + r_nan = torch.isnan(r) + e_nan = torch.isnan(e) + if not torch.equal(r_nan, e_nan): + return False + # Compare non-NaN values (allclose handles Inf correctly) + mask = ~r_nan + if mask.any() and not torch.allclose(r[mask], e[mask]): return False return True def verify_non_tensor_data(retrieved, expected) -> bool: - """Verify NonTensorData content.""" - if hasattr(retrieved, "data"): + """Verify NonTensorData content element by element.""" + if hasattr(retrieved, "tolist"): + retrieved = retrieved.tolist() + elif hasattr(retrieved, "data"): retrieved = retrieved.data - if hasattr(expected, "data"): + if hasattr(expected, "tolist"): + expected = expected.tolist() + elif hasattr(expected, "data"): expected = expected.data + if isinstance(retrieved, list) and isinstance(expected, list): + if len(retrieved) != len(expected): + return False + return all(r == e for r, e in zip(retrieved, expected, strict=True)) return retrieved == expected def verify_list_equal(retrieved, expected) -> bool: - """Verify list content, handling possible Tensor conversion.""" - # Convert Tensor to list if needed + """Verify list content. + + Note: TensorDict automatically converts Python lists to Tensors during storage, + so we convert back to native Python types before comparison. + """ if isinstance(retrieved, torch.Tensor): retrieved = retrieved.tolist() if isinstance(expected, torch.Tensor): @@ -235,19 +238,31 @@ def verify_list_equal(retrieved, expected) -> bool: return retrieved == expected -# ============================================================================= -# Scenario One: Core Read/Write Consistency -# ============================================================================= -def test_core_consistency(e2e_client): - """ - Test Case: Core Read/Write Consistency (Scenario 1) +def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict: + """Reorder a TensorDict by the given index order. - Validates: - 1. Put full complex data -> Get retrieves identical data - 2. NaN remains NaN, Inf remains Inf - 3. NonTensorData unpacks without loss - 4. All field types are correctly round-tripped + Handles regular tensors, nested/jagged tensors, lists, and other indexable types. """ + reordered = {} + for key in td.keys(): + field = td[key] + if hasattr(field, "unbind"): + items = field.unbind(0) + reordered_items = [items[i] for i in order] + try: + reordered[key] = torch.stack(reordered_items) + except RuntimeError: + reordered[key] = torch.nested.as_nested_tensor(reordered_items, layout=field.layout) + elif isinstance(field, list): + reordered[key] = [field[i] for i in order] + else: + reordered[key] = field[torch.tensor(order)] + return TensorDict(reordered, batch_size=td.batch_size) + + +# Scenario One: Core Read/Write Consistency +def test_core_consistency(e2e_client): + """Put full complex data then get — verify all field types are correctly round-tripped.""" client = e2e_client partition_id = "test_core_consistency" batch_size = 20 @@ -269,6 +284,7 @@ def test_core_consistency(e2e_client): # 3. Verify Standard Tensors assert torch.allclose(retrieved_data["tensor_f32"], original_data["tensor_f32"]), "tensor_f32 mismatch" assert torch.equal(retrieved_data["tensor_i64"], original_data["tensor_i64"]), "tensor_i64 mismatch" + assert torch.equal(retrieved_data["tensor_bf16"], original_data["tensor_bf16"]), "tensor_bf16 mismatch" # 4. Verify Nested Tensors (Jagged) assert verify_nested_tensor_equal(retrieved_data["nested_jagged"], original_data["nested_jagged"]), ( @@ -283,6 +299,7 @@ def test_core_consistency(e2e_client): # 6. Verify Python Lists assert verify_list_equal(retrieved_data["list_int"], original_data["list_int"]), "list_int mismatch" assert verify_list_equal(retrieved_data["list_str"], original_data["list_str"]), "list_str mismatch" + assert verify_list_equal(retrieved_data["list_obj"], original_data["list_obj"]), "list_obj mismatch" # 7. Verify NumPy Arrays assert np.allclose(retrieved_data["np_array"], original_data["np_array"]), "np_array mismatch" @@ -301,25 +318,12 @@ def test_core_consistency(e2e_client): client.clear_partition(partition_id) -# ============================================================================= -# Scenario Two: Cross-Partition & Complex Update -# ============================================================================= -def test_cross_partition_complex_update(e2e_client): - """ - Test Case: Cross-Partition & Complex Update (Scenario 2) - - Validates: - 1. Put A (indices 0-19) with full complex fields - 2. Put B (indices 20-39) with full complex fields - 3. Update indices 10-29 (cross-shard): modify existing fields + add new fields - 4. Get Full (0-39) and verify: - - 0-9: original Put A values - - 10-29: updated values with new fields - - 30-39: original Put B values - """ +# Scenario Two: Cross-Shard Update +def test_cross_shard_complex_update(e2e_client): + """Cross-shard update: put A + put B, update overlapping region, verify all regions.""" client = e2e_client - partition_id = "test_cross_partition_update" - task_name = "cross_partition_task" + partition_id = "test_cross_shard_update" + task_name = "cross_shard_task" # Define index ranges idx_a = list(range(0, 20)) # Put A @@ -371,6 +375,11 @@ def test_cross_partition_complex_update(e2e_client): assert full_meta is not None and full_meta.size == 40, "Failed to retrieve full metadata" full_data = client.get_data(full_meta) + # Reorder by global_indexes for deterministic positional assertions + sorted_order = sorted(range(full_meta.size), key=lambda i: full_meta.global_indexes[i]) + if sorted_order != list(range(full_meta.size)): + full_data = _reorder_tensordict(full_data, sorted_order) + # 6. Verify region 0-9: original Put A values original_data_0_9 = generate_complex_data(list(range(0, 10))) assert torch.allclose(full_data["tensor_f32"][:10], original_data_0_9["tensor_f32"]), ( @@ -402,20 +411,9 @@ def test_cross_partition_complex_update(e2e_client): client.clear_partition(partition_id) -# ============================================================================= # Scenario Three: Production Status Lifecycle -# ============================================================================= def test_production_status_lifecycle(e2e_client): - """ - Test Case: Production Status Lifecycle (Scenario 3) - - Validates multi-round partial field put and production status transitions. - - Steps: - 1. Round 1 Put: Indices 0-9, only Set_A fields -> Check production(Set_A)=True, production(Set_B)=False - 2. Round 2 Put: Indices 0-9, complete Set_B fields -> Check production(Set_A+Set_B)=True - 3. Verify consumption status transitions - """ + """Multi-round partial put: verify production & consumption status transitions.""" client = e2e_client partition_id = "test_production_lifecycle" batch_size = 10 @@ -465,30 +463,26 @@ def test_production_status_lifecycle(e2e_client): is_consumed = client.check_consumption_status(task_name=task_name, partition_id=partition_id) assert not is_consumed, "Data should not be consumed initially" - # 7. Consume Data + # 7. Consume Data (consumption is marked during get_meta(fetch)) meta = poll_for_meta(client, partition_id, all_fields, batch_size, task_name, mode="fetch") assert meta is not None, "Failed to poll metadata" + + # Consumption is marked during get_meta(fetch), verify before get_data + is_consumed_mid = client.check_consumption_status(task_name=task_name, partition_id=partition_id) + assert is_consumed_mid, "Data should already be consumed after get_meta in fetch mode" + client.get_data(meta) # 8. Post-Consumption Check - should be True is_consumed_after = client.check_consumption_status(task_name=task_name, partition_id=partition_id) - assert is_consumed_after, "Data should be consumed after get_data" + assert is_consumed_after, "Data should be consumed after get_meta in fetch mode" finally: client.clear_partition(partition_id) -# ============================================================================= # Scenario Four: Custom Metadata Persistence -# ============================================================================= def test_custom_metadata_persistence(e2e_client): - """ - - Test Case: Custom Metadata Persistence (Scenario 4) - - Validates: - 1. put data -> set_custom_meta -> get_meta retrieves correct custom_meta - 2. Custom metadata is per-sample and survives roundtrip - """ + """Set per-sample custom metadata, retrieve it, and verify persistence.""" client = e2e_client partition_id = "test_custom_meta" batch_size = 8 @@ -504,14 +498,15 @@ def test_custom_metadata_persistence(e2e_client): try: # 2. Create Custom Metadata for each sample - custom_metadata = {} - for i in range(batch_size): - custom_metadata[meta.global_indexes[i]] = { + custom_metadata_list = [ + { "score": float(i) / 10.0, "label": f"label_{i}", "tags": [f"tag_{i}_a", f"tag_{i}_b"], } - meta.update_custom_meta(custom_metadata) + for i in range(batch_size) + ] + meta.update_custom_meta(custom_metadata_list) # 3. Upload Custom Metadata client.set_custom_meta(meta) @@ -522,28 +517,20 @@ def test_custom_metadata_persistence(e2e_client): # Verify custom metadata content retrieved_custom = retrieved_meta.get_all_custom_meta() - for global_idx, expected_meta in custom_metadata.items(): - assert global_idx in retrieved_custom, f"Missing custom_meta for index {global_idx}" - actual = retrieved_custom[global_idx] - assert actual["score"] == expected_meta["score"], f"Score mismatch at index {global_idx}" - assert actual["label"] == expected_meta["label"], f"Label mismatch at index {global_idx}" - assert actual["tags"] == expected_meta["tags"], f"Tags mismatch at index {global_idx}" + assert len(retrieved_custom) == batch_size, ( + f"Expected {batch_size} custom_meta entries, got {len(retrieved_custom)}" + ) + for i, (actual, expected) in enumerate(zip(retrieved_custom, custom_metadata_list, strict=True)): + assert actual["score"] == expected["score"], f"Score mismatch at sample {i}" + assert actual["label"] == expected["label"], f"Label mismatch at sample {i}" + assert actual["tags"] == expected["tags"], f"Tags mismatch at sample {i}" finally: client.clear_partition(partition_id) -# ============================================================================= # Scenario Five: Reset & Clear -# ============================================================================= def test_reset_consumption(e2e_client): - """ - Test Case: Reset Consumption Status (Scenario 5a) - - Validates: - 1. After consuming data, consumption status is True - 2. After reset_consumption, status reverts to False - 3. Data can be re-consumed after reset - """ + """Consume data, reset consumption status, verify re-consumability.""" client = e2e_client partition_id = "test_reset_consumption" batch_size = 10 @@ -559,15 +546,20 @@ def test_reset_consumption(e2e_client): is_consumed_initial = client.check_consumption_status(task_name=task_name, partition_id=partition_id) assert not is_consumed_initial, "Data should not be consumed initially" - # 3. Consume Data (get_meta + get_data) + # 3. Consume Data (consumption is marked during get_meta(fetch)) meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="fetch") assert meta is not None and meta.size == batch_size, "Failed to poll metadata" + + # Consumption is marked during get_meta(fetch), verify before get_data + is_consumed_mid = client.check_consumption_status(task_name=task_name, partition_id=partition_id) + assert is_consumed_mid, "Data should already be consumed after get_meta in fetch mode" + retrieved_data = client.get_data(meta) assert retrieved_data.batch_size[0] == batch_size, "Retrieved data batch_size mismatch" # 4. Post-Consumption Status Check - should be True is_consumed_after = client.check_consumption_status(task_name=task_name, partition_id=partition_id) - assert is_consumed_after, "Data should be consumed after get_data" + assert is_consumed_after, "Data should be consumed after get_meta in fetch mode" # 5. Reset Consumption success = client.reset_consumption(partition_id=partition_id, task_name=task_name) @@ -585,15 +577,7 @@ def test_reset_consumption(e2e_client): def test_clear_partition(e2e_client): - """ - Test Case: Clear Partition (Scenario 5b) - - Validates: - 1. Put data -> data is accessible - 2. clear_partition -> data is physically deleted - 3. After clear, check_production_status returns False - 4. After clear, partition is removed from partition list - """ + """Clear partition: verify data removal and production status reset.""" client = e2e_client partition_id = "test_clear_partition" batch_size = 15 @@ -604,28 +588,35 @@ def test_clear_partition(e2e_client): data = generate_complex_data(list(range(batch_size))) client.put(data=data, partition_id=partition_id) - # 2. Verify Data Exists - production status should be True - is_ready = client.check_production_status(data_fields=fields, partition_id=partition_id) - assert is_ready, "Data should be ready after put" + try: + # 2. Verify Data Exists - production status should be True + is_ready = client.check_production_status(data_fields=fields, partition_id=partition_id) + assert is_ready, "Data should be ready after put" - # 3. Get Data to confirm accessibility - meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") - assert meta is not None and meta.size == batch_size, "Failed to poll metadata" + # 3. Get Data to confirm accessibility + meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") + assert meta is not None and meta.size == batch_size, "Failed to poll metadata" - # 4. Verify partition exists before clear - partition_list_before = client.get_partition_list() - assert partition_id in partition_list_before, "Partition should exist before clear" + # 4. Verify partition exists before clear + partition_list_before = client.get_partition_list() + assert partition_id in partition_list_before, "Partition should exist before clear" - # 5. Clear Partition - client.clear_partition(partition_id) + # 5. Clear Partition + client.clear_partition(partition_id) - # 6. Verify partition is removed from list - partition_list_after = client.get_partition_list() - assert partition_id not in partition_list_after, "Partition should be removed after clear" + # 6. Verify partition is removed from list + partition_list_after = client.get_partition_list() + assert partition_id not in partition_list_after, "Partition should be removed after clear" - # 7. Verify Production Status returns False for cleared partition - is_ready_after_clear = client.check_production_status(data_fields=fields, partition_id=partition_id) - assert not is_ready_after_clear, "Production status should be False after clear" + # 7. Verify Production Status returns False for cleared partition + is_ready_after_clear = client.check_production_status(data_fields=fields, partition_id=partition_id) + assert not is_ready_after_clear, "Production status should be False after clear" + finally: + # Ensure cleanup even if assertions fail + try: + client.clear_partition(partition_id) + except Exception: + pass if __name__ == "__main__": From 0ac69a201217a5e7618ff58c993b22ced0e6f8be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9C=8B=E6=88=9172=E9=81=8D?= Date: Wed, 25 Feb 2026 14:36:49 +0800 Subject: [PATCH 9/9] fix some review issue MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 看我72遍 --- tests/e2e/test_e2e_lifecycle_consistency.py | 38 ++++++++++++++++----- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index c82a43d..4e7025f 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -228,12 +228,18 @@ def verify_non_tensor_data(retrieved, expected) -> bool: def verify_list_equal(retrieved, expected) -> bool: """Verify list content. - Note: TensorDict automatically converts Python lists to Tensors during storage, - so we convert back to native Python types before comparison. + Note: TensorDict may materialize Python lists as Tensors or NonTensorStack during + storage/retrieval, so we normalize both sides to native Python types before comparison. """ - if isinstance(retrieved, torch.Tensor): + from tensordict.tensorclass import NonTensorStack # local import to avoid circular deps + + if isinstance(retrieved, NonTensorStack): + retrieved = retrieved.tolist() + elif isinstance(retrieved, torch.Tensor): retrieved = retrieved.tolist() - if isinstance(expected, torch.Tensor): + if isinstance(expected, NonTensorStack): + expected = expected.tolist() + elif isinstance(expected, torch.Tensor): expected = expected.tolist() return retrieved == expected @@ -241,17 +247,26 @@ def verify_list_equal(retrieved, expected) -> bool: def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict: """Reorder a TensorDict by the given index order. - Handles regular tensors, nested/jagged tensors, lists, and other indexable types. + Handles regular tensors, nested/jagged tensors, NonTensorStack, lists, and other + indexable types. """ + from tensordict.tensorclass import NonTensorStack # local import to avoid circular deps + reordered = {} for key in td.keys(): field = td[key] - if hasattr(field, "unbind"): + if isinstance(field, NonTensorStack): + # NonTensorStack: reorder by converting to list and re-wrapping + items = field.tolist() + reordered_items = [items[i] for i in order] + reordered[key] = NonTensorStack(*reordered_items, batch_size=[len(order)]) + elif hasattr(field, "unbind"): items = field.unbind(0) reordered_items = [items[i] for i in order] try: reordered[key] = torch.stack(reordered_items) - except RuntimeError: + except (RuntimeError, TypeError): + # RuntimeError: shape mismatch (jagged); TypeError: non-Tensor items reordered[key] = torch.nested.as_nested_tensor(reordered_items, layout=field.layout) elif isinstance(field, list): reordered[key] = [field[i] for i in order] @@ -303,7 +318,14 @@ def test_core_consistency(e2e_client): # 7. Verify NumPy Arrays assert np.allclose(retrieved_data["np_array"], original_data["np_array"]), "np_array mismatch" - assert np.array_equal(retrieved_data["np_obj"], original_data["np_obj"]), "np_obj mismatch" + # np_obj may be returned as NonTensorStack; normalize to list before comparing + retrieved_np_obj = retrieved_data["np_obj"] + if hasattr(retrieved_np_obj, "tolist"): + retrieved_np_obj = retrieved_np_obj.tolist() + expected_np_obj = original_data["np_obj"] + if hasattr(expected_np_obj, "tolist") and not isinstance(expected_np_obj, np.ndarray): + expected_np_obj = expected_np_obj.tolist() + assert list(retrieved_np_obj) == list(expected_np_obj), "np_obj mismatch" # 8. Verify Special Values (NaN and Inf) assert verify_special_values(retrieved_data["special_val"], original_data["special_val"]), (