From 318cc0558f87701a6c15312bc04e23f0f4ed762a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 27 Feb 2026 09:25:10 +0800 Subject: [PATCH 1/5] use async zmq context Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 69 +++++++++++++++---------- 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 2e9ce30..083cd46 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -28,6 +28,7 @@ import ray import torch import zmq +import zmq.asyncio from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from torch import Tensor @@ -65,10 +66,11 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.config = config self.controller_info = controller_info - self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None - self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None + self.data_status_update_socket: Optional[zmq.asyncio.Socket] = None + # Handshake socket is sync (used only during initialization) + self.controller_handshake_socket: Optional[zmq.Socket] = None - self.zmq_context: Optional[zmq.Context[Any]] = None + self.zmq_context: Optional[zmq.asyncio.Context] = None self._connect_to_controller() def _connect_to_controller(self) -> None: @@ -77,15 +79,28 @@ def _connect_to_controller(self) -> None: raise ValueError(f"controller_info should be ZMQServerInfo, but got {type(self.controller_info)}") try: - # create zmq context - self.zmq_context = zmq.Context() + # Create a synchronous context for handshake (blocking operation) + sync_zmq_context = zmq.Context() - # create zmq sockets for handshake and data status update + # create zmq socket for handshake (sync, for initial connection) self.controller_handshake_socket = create_zmq_socket( - self.zmq_context, + sync_zmq_context, zmq.DEALER, 
identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(), ) + + # do handshake with controller using sync socket + self._do_handshake_with_controller() + + # Close the sync handshake socket and context after handshake + if self.controller_handshake_socket and not self.controller_handshake_socket.closed: + self.controller_handshake_socket.close(linger=0) + sync_zmq_context.term() + + # Create async context for data update notifications (non-blocking) + self.zmq_context = zmq.asyncio.Context() + + # create zmq async socket for data status update self.data_status_update_socket = create_zmq_socket( self.zmq_context, zmq.DEALER, @@ -94,9 +109,6 @@ def _connect_to_controller(self) -> None: assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket")) - # do handshake with controller - self._do_handshake_with_controller() - except Exception as e: logger.error(f"Failed to connect to controller: {e}") raise @@ -210,20 +222,14 @@ async def notify_data_update( shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. custom_backend_meta: Per-field custom_meta for each sample, in {global_index: {field: custom_meta}} format. 
""" - # Create zmq poller for notifying data update information - if not self.controller_info: logger.warning(f"No controller connected for storage manager {self.storage_manager_id}") return - # Create zmq poller for notifying data update information - poller = zmq.Poller() # Note: data_status_update_socket is already connected during initialization assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" try: - poller.register(self.data_status_update_socket, zmq.POLLIN) - request_msg = ZMQMessage.create( request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] sender_id=self.storage_manager_id, @@ -237,7 +243,7 @@ async def notify_data_update( }, ).serialize() - self.data_status_update_socket.send_multipart(request_msg) + await self.data_status_update_socket.send_multipart(request_msg) logger.debug( f"[{self.storage_manager_id}]: Send data status update request " f"from storage manager id #{self.storage_manager_id} " @@ -254,20 +260,21 @@ async def notify_data_update( }, ).serialize() - self.data_status_update_socket.send_multipart(request_msg) + await self.data_status_update_socket.send_multipart(request_msg) # Make sure controller successfully receives data status update information. 
+ # Use asyncio.wait_for with timeout instead of blocking poller response_received: bool = False - start_time = time.time() + timeout = TQ_DATA_UPDATE_RESPONSE_TIMEOUT - while ( - not response_received # Only one controller to get response from - and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT - ): - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000)) - - if self.data_status_update_socket in socks: - response_msg = ZMQMessage.deserialize(self.data_status_update_socket.recv_multipart()) + while not response_received and timeout > 0: + try: + # Use asyncio.wait_for with polling interval for responsiveness + poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout) + messages = await asyncio.wait_for( + self.data_status_update_socket.recv_multipart(), timeout=poll_interval + ) + response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: response_received = True @@ -276,6 +283,14 @@ async def notify_data_update( f"from controller id #{response_msg.sender_id} " f"to storage manager id #{self.storage_manager_id} successfully." 
) + except asyncio.TimeoutError: + # Timeout waiting for response, check if we should continue + timeout -= poll_interval + if timeout <= 0: + break + except Exception as e: + logger.warning(f"[{self.storage_manager_id}]: Error receiving data status update response: {e}") + break if not response_received: logger.error( From be0a58ec26ad39202fed5c742c650b82d70652bb Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 27 Feb 2026 10:42:23 +0800 Subject: [PATCH 2/5] use dynamic socket Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 122 ++++++++++-------------- 1 file changed, 52 insertions(+), 70 deletions(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 083cd46..36cc873 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -66,7 +66,6 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.config = config self.controller_info = controller_info - self.data_status_update_socket: Optional[zmq.asyncio.Socket] = None # Handshake socket is sync (used only during initialization) self.controller_handshake_socket: Optional[zmq.Socket] = None @@ -97,18 +96,6 @@ def _connect_to_controller(self) -> None: self.controller_handshake_socket.close(linger=0) sync_zmq_context.term() - # Create async context for data update notifications (non-blocking) - self.zmq_context = zmq.asyncio.Context() - - # create zmq async socket for data status update - self.data_status_update_socket = create_zmq_socket( - self.zmq_context, - zmq.DEALER, - identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(), - ) - assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" - self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket")) - except Exception as e: logger.error(f"Failed to connect to controller: {e}") raise @@ -226,12 +213,17 @@ 
async def notify_data_update( logger.warning(f"No controller connected for storage manager {self.storage_manager_id}") return - # Note: data_status_update_socket is already connected during initialization - assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" + # create dynamic socket + # TODO: use unified dynamic socket register + context = zmq.asyncio.Context() + identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode() + sock = create_zmq_socket(context, zmq.DEALER, identity=identity) try: + sock.connect(self.controller_info.to_addr("data_status_update_socket")) + request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, sender_id=self.storage_manager_id, body={ "partition_id": partition_id, @@ -243,60 +235,55 @@ async def notify_data_update( }, ).serialize() - await self.data_status_update_socket.send_multipart(request_msg) + await sock.send_multipart(request_msg) logger.debug( f"[{self.storage_manager_id}]: Send data status update request " f"from storage manager id #{self.storage_manager_id} " f"to controller id #{self.controller_info.id} successfully." ) - except Exception as e: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type] - sender_id=self.storage_manager_id, - body={ - "message": f"Failed to notify data status update information from " - f"storage manager id #{self.storage_manager_id}, " - f"detail error message: {str(e)}" - }, - ).serialize() - await self.data_status_update_socket.send_multipart(request_msg) + response_received = False + timeout = TQ_DATA_UPDATE_RESPONSE_TIMEOUT - # Make sure controller successfully receives data status update information. 
- # Use asyncio.wait_for with timeout instead of blocking poller - response_received: bool = False - timeout = TQ_DATA_UPDATE_RESPONSE_TIMEOUT + while not response_received and timeout > 0: + try: + poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout) + messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval) + response_msg = ZMQMessage.deserialize(messages) - while not response_received and timeout > 0: - try: - # Use asyncio.wait_for with polling interval for responsiveness - poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout) - messages = await asyncio.wait_for( - self.data_status_update_socket.recv_multipart(), timeout=poll_interval - ) - response_msg = ZMQMessage.deserialize(messages) - - if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: - response_received = True - logger.debug( - f"[{self.storage_manager_id}]: Get data status update ACK response " - f"from controller id #{response_msg.sender_id} " - f"to storage manager id #{self.storage_manager_id} successfully." - ) - except asyncio.TimeoutError: - # Timeout waiting for response, check if we should continue - timeout -= poll_interval - if timeout <= 0: + if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: + response_received = True + logger.debug( + f"[{self.storage_manager_id}]: Get data status update ACK response " + f"from controller id #{response_msg.sender_id} successfully." + ) + except asyncio.TimeoutError: + timeout -= poll_interval + except Exception as e: + logger.warning(f"[{self.storage_manager_id}]: Error receiving response: {e}") break - except Exception as e: - logger.warning(f"[{self.storage_manager_id}]: Error receiving data status update response: {e}") - break - if not response_received: - logger.error( - f"[{self.storage_manager_id}]: Storage manager id #{self.storage_manager_id} " - f"did not receive data status update ACK response from controller." 
- ) + if not response_received: + logger.error(f"[{self.storage_manager_id}]: Did not receive data status update ACK.") + + except Exception as e: + # 发送错误通知 + try: + error_msg = ZMQMessage.create( + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, + sender_id=self.storage_manager_id, + body={"message": f"Failed to notify: {str(e)}"}, + ).serialize() + await sock.send_multipart(error_msg) + except Exception: + pass + finally: + try: + if not sock.closed: + sock.close(linger=0) + except Exception: + pass + context.term() @abstractmethod async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: @@ -334,18 +321,13 @@ async def clear_data(self, metadata: BatchMeta) -> None: def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" - for sock in (self.controller_handshake_socket, self.data_status_update_socket): + # Close handshake socket if it exists + if self.controller_handshake_socket: try: - if sock and not sock.closed: - sock.close(linger=0) + if not self.controller_handshake_socket.closed: + self.controller_handshake_socket.close(linger=0) except Exception as e: - logger.error(f"[{self.storage_manager_id}]: Error closing socket {sock}: {str(e)}") - - try: - if self.zmq_context: - self.zmq_context.term() - except Exception as e: - logger.error(f"[{self.storage_manager_id}]: Error terminating zmq_context: {str(e)}") + logger.error(f"[{self.storage_manager_id}]: Error closing controller_handshake_socket: {str(e)}") def __del__(self): """Destructor to ensure resources are cleaned up.""" From 5aafaf93ca65dccabae4da53b10a8de9b4dbc77c Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 27 Feb 2026 11:09:44 +0800 Subject: [PATCH 3/5] fix Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 14 +++++++++----- .../storage/managers/simple_backend_manager.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/transfer_queue/storage/managers/base.py 
b/transfer_queue/storage/managers/base.py index 36cc873..61330bc 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -91,11 +91,15 @@ def _connect_to_controller(self) -> None: # do handshake with controller using sync socket self._do_handshake_with_controller() - # Close the sync handshake socket and context after handshake + # close the sync handshake socket and context after handshake if self.controller_handshake_socket and not self.controller_handshake_socket.closed: self.controller_handshake_socket.close(linger=0) + self.controller_handshake_socket = None sync_zmq_context.term() + # create async context for data status update + self.zmq_context = zmq.asyncio.Context() + except Exception as e: logger.error(f"Failed to connect to controller: {e}") raise @@ -214,10 +218,8 @@ async def notify_data_update( return # create dynamic socket - # TODO: use unified dynamic socket register - context = zmq.asyncio.Context() identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity) + sock = create_zmq_socket(self.zmq_context, zmq.DEALER, identity=identity) try: sock.connect(self.controller_info.to_addr("data_status_update_socket")) @@ -283,7 +285,6 @@ async def notify_data_update( sock.close(linger=0) except Exception: pass - context.term() @abstractmethod async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: @@ -329,6 +330,9 @@ def close(self) -> None: except Exception as e: logger.error(f"[{self.storage_manager_id}]: Error closing controller_handshake_socket: {str(e)}") + if self.zmq_context: + self.zmq_context.term() + def __del__(self): """Destructor to ensure resources are cleaned up.""" try: diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index b658ba4..8066d6a 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ 
b/transfer_queue/storage/managers/simple_backend_manager.py @@ -45,8 +45,7 @@ handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) logger.addHandler(handler) -TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds -TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds +TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) @@ -119,11 +118,12 @@ def _build_storage_mapping_functions(self): # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. @staticmethod - def dynamic_storage_manager_socket(socket_name: str): + def dynamic_storage_manager_socket(socket_name: str, timeout: int): """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close). Args: socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port"). + timeout (int): Send/receive timeout applied to the socket (RCVTIMEO/SNDTIMEO), in seconds. Decorated Function Rules: 1. Must be an async class method (needs `self`). 
@@ -157,8 +157,8 @@ async def wrapper(self, *args, **kwargs): try: sock.connect(address) # Timeouts to avoid indefinite await on recv/send - sock.setsockopt(zmq.RCVTIMEO, TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT * 1000) - sock.setsockopt(zmq.SNDTIMEO, TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT * 1000) + sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) + sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) logger.debug( f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} " f"with identity {identity.decode()}" @@ -249,7 +249,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes ) - @dynamic_storage_manager_socket(socket_name="put_get_socket") + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT) async def _put_to_single_storage_unit( self, local_indexes: list[int], @@ -348,7 +348,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: return TensorDict(tensor_data, batch_size=len(metadata)) - @dynamic_storage_manager_socket(socket_name="put_get_socket") + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT) async def _get_from_single_storage_unit( self, storage_meta_group: StorageMetaGroup, target_storage_unit: str, socket: zmq.Socket = None ): @@ -407,7 +407,7 @@ async def clear_data(self, metadata: BatchMeta) -> None: if isinstance(result, Exception): logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") - @dynamic_storage_manager_socket(socket_name="put_get_socket") + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT) async def _clear_single_storage_unit(self, local_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create( From f9dbb9e7c82f47ddec48f0142954e5f39a3d0ef0 Mon Sep 17 
00:00:00 2001 From: 0oshowero0 Date: Fri, 27 Feb 2026 11:13:04 +0800 Subject: [PATCH 4/5] fix Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 61330bc..cad91f9 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -269,7 +269,7 @@ async def notify_data_update( logger.error(f"[{self.storage_manager_id}]: Did not receive data status update ACK.") except Exception as e: - # 发送错误通知 + logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}") try: error_msg = ZMQMessage.create( request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, @@ -282,7 +282,7 @@ async def notify_data_update( finally: try: if not sock.closed: - sock.close(linger=0) + sock.close(linger=-1) except Exception: pass From 72ce3d3aba624c7ab465c9255f9d0df6f87b8dac Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 27 Feb 2026 14:04:50 +0800 Subject: [PATCH 5/5] update Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/simple_backend_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 8066d6a..745c04b 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -45,7 +45,7 @@ handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) logger.addHandler(handler) -TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds +TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) @@ -249,7 
+249,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes ) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT) + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _put_to_single_storage_unit( self, local_indexes: list[int], @@ -348,7 +348,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: return TensorDict(tensor_data, batch_size=len(metadata)) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT) + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _get_from_single_storage_unit( self, storage_meta_group: StorageMetaGroup, target_storage_unit: str, socket: zmq.Socket = None ): @@ -407,7 +407,7 @@ async def clear_data(self, metadata: BatchMeta) -> None: if isinstance(result, Exception): logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_MANAGER_COMM_TIMEOUT) + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _clear_single_storage_unit(self, local_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create(