Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions scripts/put_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@
parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import ( # noqa: E402
AsyncTransferQueueClient,
SimpleStorageUnit,
TransferQueueController,
process_zmq_server_info,
)
from transfer_queue import TransferQueueClient # noqa: E402
from transfer_queue.controller import TransferQueueController # noqa: E402
from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
from transfer_queue.utils.common import get_placement_group # noqa: E402
from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -309,7 +307,7 @@ def initialize_system(self, config_dict):
self.tq_config = OmegaConf.merge(tq_internal_conf, self.tq_config)

# Client Init
self.data_system_client = AsyncTransferQueueClient(
self.data_system_client = TransferQueueClient(
client_id="Trainer", controller_info=self.data_system_controller_info
)
self.data_system_client.initialize_storage_manager(
Expand Down
21 changes: 13 additions & 8 deletions tests/e2e/test_e2e_lifecycle_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,15 +420,20 @@ def test_cross_shard_complex_update(e2e_client):
"Region 30-39 tensor_f32 should match original Put B"
)

# 9. Verify new fields exist in update region
extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
update_region_meta = poll_for_meta(
client, partition_id, extended_fields, 20, "update_region_task", mode="force_fetch"
# 9. Verify new fields exist in update region (indices 10-29 only have new fields).
# Build extended_meta from full_meta (which has valid _custom_backend_meta/_su_id)
# by selecting the subset of samples whose global_indexes match meta_update.
# Using meta_update directly would fail because it was derived from alloc_meta
# before put(), so its _custom_backend_meta lacks _su_id.
update_gis = set(meta_update.global_indexes)
update_positions_in_full = [i for i, gi in enumerate(full_meta.global_indexes) if gi in update_gis]
update_meta_with_backend = full_meta.select_samples(update_positions_in_full)
extended_meta = update_meta_with_backend.with_data_fields(
base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
)
if update_region_meta is not None and update_region_meta.size > 0:
update_region_data = client.get_data(update_region_meta)
assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist"
assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist"
update_region_data = client.get_data(extended_meta)
assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist"
assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist"
finally:
client.clear_partition(partition_id)

Expand Down
Loading