diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 2e9ce30..cad91f9 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -28,6 +28,7 @@ import ray import torch import zmq +import zmq.asyncio from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict from torch import Tensor @@ -65,10 +66,10 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.config = config self.controller_info = controller_info - self.data_status_update_socket: Optional[zmq.Socket[bytes]] = None - self.controller_handshake_socket: Optional[zmq.Socket[bytes]] = None + # Handshake socket is sync (used only during initialization) + self.controller_handshake_socket: Optional[zmq.Socket] = None - self.zmq_context: Optional[zmq.Context[Any]] = None + self.zmq_context: Optional[zmq.asyncio.Context] = None self._connect_to_controller() def _connect_to_controller(self) -> None: @@ -77,26 +78,28 @@ def _connect_to_controller(self) -> None: raise ValueError(f"controller_info should be ZMQServerInfo, but got {type(self.controller_info)}") try: - # create zmq context - self.zmq_context = zmq.Context() + # Create a synchronous context for handshake (blocking operation) + sync_zmq_context = zmq.Context() - # create zmq sockets for handshake and data status update + # create zmq socket for handshake (sync, for initial connection) self.controller_handshake_socket = create_zmq_socket( - self.zmq_context, + sync_zmq_context, zmq.DEALER, identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(), ) - self.data_status_update_socket = create_zmq_socket( - self.zmq_context, - zmq.DEALER, - identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(), - ) - assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" - 
self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket")) - # do handshake with controller + # do handshake with controller using sync socket self._do_handshake_with_controller() + # close the sync handshake socket and context after handshake + if self.controller_handshake_socket and not self.controller_handshake_socket.closed: + self.controller_handshake_socket.close(linger=0) + self.controller_handshake_socket = None + sync_zmq_context.term() + + # create async context for data status update + self.zmq_context = zmq.asyncio.Context() + except Exception as e: logger.error(f"Failed to connect to controller: {e}") raise @@ -210,22 +213,19 @@ async def notify_data_update( shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. custom_backend_meta: Per-field custom_meta for each sample, in {global_index: {field: custom_meta}} format. """ - # Create zmq poller for notifying data update information - if not self.controller_info: logger.warning(f"No controller connected for storage manager {self.storage_manager_id}") return - # Create zmq poller for notifying data update information - poller = zmq.Poller() - # Note: data_status_update_socket is already connected during initialization - assert self.data_status_update_socket is not None, "data_status_update_socket is not properly initialized" + # create dynamic socket + identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode() + sock = create_zmq_socket(self.zmq_context, zmq.DEALER, identity=identity) try: - poller.register(self.data_status_update_socket, zmq.POLLIN) + sock.connect(self.controller_info.to_addr("data_status_update_socket")) request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, sender_id=self.storage_manager_id, body={ "partition_id": partition_id, @@ -237,51 +237,54 @@ async def notify_data_update( }, ).serialize() - 
self.data_status_update_socket.send_multipart(request_msg) + await sock.send_multipart(request_msg) logger.debug( f"[{self.storage_manager_id}]: Send data status update request " f"from storage manager id #{self.storage_manager_id} " f"to controller id #{self.controller_info.id} successfully." ) - except Exception as e: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type] - sender_id=self.storage_manager_id, - body={ - "message": f"Failed to notify data status update information from " - f"storage manager id #{self.storage_manager_id}, " - f"detail error message: {str(e)}" - }, - ).serialize() - - self.data_status_update_socket.send_multipart(request_msg) - # Make sure controller successfully receives data status update information. - response_received: bool = False - start_time = time.time() + response_received = False + timeout = TQ_DATA_UPDATE_RESPONSE_TIMEOUT - while ( - not response_received # Only one controller to get response from - and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT - ): - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000)) + while not response_received and timeout > 0: + try: + poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout) + messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval) + response_msg = ZMQMessage.deserialize(messages) - if self.data_status_update_socket in socks: - response_msg = ZMQMessage.deserialize(self.data_status_update_socket.recv_multipart()) + if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: + response_received = True + logger.debug( + f"[{self.storage_manager_id}]: Get data status update ACK response " + f"from controller id #{response_msg.sender_id} successfully." 
+ ) + except asyncio.TimeoutError: + timeout -= poll_interval + except Exception as e: + logger.warning(f"[{self.storage_manager_id}]: Error receiving response: {e}") + break - if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: - response_received = True - logger.debug( - f"[{self.storage_manager_id}]: Get data status update ACK response " - f"from controller id #{response_msg.sender_id} " - f"to storage manager id #{self.storage_manager_id} successfully." - ) + if not response_received: + logger.error(f"[{self.storage_manager_id}]: Did not receive data status update ACK.") - if not response_received: - logger.error( - f"[{self.storage_manager_id}]: Storage manager id #{self.storage_manager_id} " - f"did not receive data status update ACK response from controller." - ) + except Exception as e: + logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}") + try: + error_msg = ZMQMessage.create( + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, + sender_id=self.storage_manager_id, + body={"message": f"Failed to notify: {str(e)}"}, + ).serialize() + await sock.send_multipart(error_msg) + except Exception: + pass + finally: + try: + if not sock.closed: + sock.close(linger=-1) + except Exception: + pass @abstractmethod async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: @@ -319,18 +322,16 @@ async def clear_data(self, metadata: BatchMeta) -> None: def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" - for sock in (self.controller_handshake_socket, self.data_status_update_socket): + # Close handshake socket if it exists + if self.controller_handshake_socket: try: - if sock and not sock.closed: - sock.close(linger=0) + if not self.controller_handshake_socket.closed: + self.controller_handshake_socket.close(linger=0) except Exception as e: - logger.error(f"[{self.storage_manager_id}]: Error closing socket {sock}: {str(e)}") + 
logger.error(f"[{self.storage_manager_id}]: Error closing controller_handshake_socket: {str(e)}") - try: - if self.zmq_context: - self.zmq_context.term() - except Exception as e: - logger.error(f"[{self.storage_manager_id}]: Error terminating zmq_context: {str(e)}") + if self.zmq_context: + self.zmq_context.term() def __del__(self): """Destructor to ensure resources are cleaned up.""" diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index b658ba4..745c04b 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -45,8 +45,7 @@ handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) logger.addHandler(handler) -TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds -TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds +TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False) @@ -119,11 +118,12 @@ def _build_storage_mapping_functions(self): # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. @staticmethod - def dynamic_storage_manager_socket(socket_name: str): + def dynamic_storage_manager_socket(socket_name: str, timeout: int): """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close). Args: socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port"). + timeout (int): Send/receive timeout applied to the ZMQ socket, in seconds. Decorated Function Rules: 1. Must be an async class method (needs `self`). 
@@ -157,8 +157,8 @@ async def wrapper(self, *args, **kwargs): try: sock.connect(address) # Timeouts to avoid indefinite await on recv/send - sock.setsockopt(zmq.RCVTIMEO, TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT * 1000) - sock.setsockopt(zmq.SNDTIMEO, TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT * 1000) + sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) + sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) logger.debug( f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} " f"with identity {identity.decode()}" @@ -249,7 +249,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: partition_id, list(results.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes ) - @dynamic_storage_manager_socket(socket_name="put_get_socket") + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _put_to_single_storage_unit( self, local_indexes: list[int], @@ -348,7 +348,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: return TensorDict(tensor_data, batch_size=len(metadata)) - @dynamic_storage_manager_socket(socket_name="put_get_socket") + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _get_from_single_storage_unit( self, storage_meta_group: StorageMetaGroup, target_storage_unit: str, socket: zmq.Socket = None ): @@ -407,7 +407,7 @@ async def clear_data(self, metadata: BatchMeta) -> None: if isinstance(result, Exception): logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") - @dynamic_storage_manager_socket(socket_name="put_get_socket") + @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) async def _clear_single_storage_unit(self, local_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create(