diff --git a/scripts/performance_test.py b/scripts/performance_test.py index 396ddb2..14d06a4 100644 --- a/scripts/performance_test.py +++ b/scripts/performance_test.py @@ -30,14 +30,11 @@ parent_dir = Path(__file__).resolve().parent.parent.parent sys.path.append(str(parent_dir)) - -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) +from transfer_queue.client import TransferQueueClient # noqa: E402 +from transfer_queue.controller import TransferQueueController # noqa: E402 +from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402 from transfer_queue.utils.common import get_placement_group # noqa: E402 +from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index f7cd239..fdc0840 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1523,7 +1523,7 @@ def kv_retrieve_keys( ) data_fields = [] for fname, col_idx in partition.field_name_mapping.items(): - if col_mask[col_idx]: + if col_idx < len(col_mask) and col_mask[col_idx]: data_fields.append(fname) metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch") diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index d415a2f..23d0bc9 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -82,7 +82,9 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: placement_group_bundle_index=storage_unit_rank, name=f"TransferQueueStorageUnit#{storage_unit_rank}", lifetime="detached", - ).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units)) + ).remote( + storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), + ) _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index 37dcca4..ed12d54 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -16,10 +16,12 @@ import dataclasses import logging import os +import time +import weakref from dataclasses import dataclass from operator import itemgetter -from threading import Thread -from typing import Any +from threading import Event, Thread +from typing import Any, Optional from uuid import uuid4 import ray @@ -173,16 +175,41 @@ def __init__(self, storage_unit_size: int): self.storage_data = StorageUnitData(self.storage_unit_size) + # Internal communication address for proxy and workers + self._inproc_addr = f"inproc://simple_storage_workers_{self.storage_unit_id}" + + # Shutdown event for graceful termination + self._shutdown_event = Event() + + # Placeholder for zmq_context, proxy_thread and worker_threads + self.zmq_context: Optional[zmq.Context] = None + self.put_get_socket: Optional[zmq.Socket] = None + self.proxy_thread: Optional[Thread] = None + self.worker_thread: Optional[Thread] = None + self._init_zmq_socket() self._start_process_put_get() + # Register finalizer for graceful cleanup when garbage collected + self._finalizer = weakref.finalize( + self, + self._shutdown_resources, + self._shutdown_event, + self.worker_thread, + self.proxy_thread, + self.zmq_context, + self.put_get_socket, + ) + def _init_zmq_socket(self) -> None: """ Initialize ZMQ socket connections between storage unit and controller/clients: - - put_get_socket: - Handle put/get requests from clients. + - put_get_socket (ROUTER): Handle put/get requests from clients. + - worker_socket (DEALER): Backend socket for worker communication. """ self.zmq_context = zmq.Context() + + # Frontend: ROUTER for receiving client requests self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER) self._node_ip = get_node_ip_address() @@ -195,6 +222,10 @@ def _init_zmq_socket(self) -> None: logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...") continue + # Backend: DEALER for worker communication (connected via zmq.proxy) + self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER) + self.worker_socket.bind(self._inproc_addr) + self.zmq_server_info = ZMQServerInfo( role=TransferQueueRole.STORAGE, id=str(self.storage_unit_id), @@ -203,33 +234,78 @@ def _init_zmq_socket(self) -> None: ) def _start_process_put_get(self) -> None: - """Create a daemon thread and start put/get process.""" - self.process_put_get_thread = Thread( - target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.storage_unit_id}", daemon=True + """Start worker threads and ZMQ proxy for handling requests.""" + + # Start worker thread + self.worker_thread = Thread( + target=self._worker_routine, + name=f"StorageUnitWorkerThread-{self.storage_unit_id}", + daemon=True, + ) + self.worker_thread.start() + + time.sleep(0.5) # make sure worker thread is ready before zmq.proxy forwarding messages + + # Start proxy thread (ROUTER <-> DEALER) + self.proxy_thread = Thread( + target=self._proxy_routine, + name=f"StorageUnitProxyThread-{self.storage_unit_id}", + daemon=True, ) - self.process_put_get_thread.start() + self.proxy_thread.start() + + def _proxy_routine(self) -> None: + """ZMQ proxy for message forwarding between frontend ROUTER and backend DEALER.""" + logger.info(f"[{self.storage_unit_id}]: start ZMQ proxy...") + try: + zmq.proxy(self.put_get_socket, self.worker_socket) + except zmq.ContextTerminated: + logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy stopped gracefully (Context Terminated)") + except Exception as e: + if self._shutdown_event.is_set(): + logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy shutting down...") + else: + logger.error(f"[{self.storage_unit_id}]: ZMQ Proxy unexpected error: {e}") + + def _worker_routine(self) -> None: + """Worker thread for processing requests.""" + # Each worker must have its own socket + worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER) + worker_socket.connect(self._inproc_addr) - def _process_put_get(self) -> None: - """Process put_get_socket request.""" poller = zmq.Poller() - poller.register(self.put_get_socket, zmq.POLLIN) + poller.register(worker_socket, zmq.POLLIN) - logger.info(f"[{self.storage_unit_id}]: start processing put/get requests...") + logger.info(f"[{self.storage_unit_id}]: worker thread started...") + perf_monitor = IntervalPerfMonitor(caller_name=f"{self.storage_unit_id}") + + while not self._shutdown_event.is_set(): + try: + socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000)) + except zmq.error.ContextTerminated: + # ZMQ context was terminated, exit gracefully + logger.info(f"[{self.storage_unit_id}]: worker stopped gracefully (Context Terminated)") + break + except Exception as e: + logger.warning(f"[{self.storage_unit_id}]: worker poll error: {e}") + continue - perf_monitor = IntervalPerfMonitor(caller_name=self.storage_unit_id) + if self._shutdown_event.is_set(): + break - while True: - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000)) + if worker_socket in socks: + # Messages received from proxy: [identity, serialized_msg_frame1, ...] + messages = worker_socket.recv_multipart() + identity = messages[0] + serialized_msg = messages[1:] - if self.put_get_socket in socks: - messages = self.put_get_socket.recv_multipart() - identity = messages.pop(0) - serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) operation = request_msg.request_type + try: - logger.debug(f"[{self.storage_unit_id}]: receive operation: {operation}, message: {request_msg}") + logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}") + # Process request if operation == ZMQRequestType.PUT_DATA: with perf_monitor.measure(op_type="PUT_DATA"): response_msg = self._handle_put(request_msg) @@ -253,12 +329,17 @@ def _process_put_get(self) -> None: request_type=ZMQRequestType.PUT_GET_ERROR, sender_id=self.storage_unit_id, body={ - "message": f"Storage unit id #{self.storage_unit_id} occur error in processing " - f"put/get/clear request, detail error message: {str(e)}." + "message": f"{self.storage_unit_id}, worker encountered error " + f"during operation {operation}: {str(e)}." }, ) - self.put_get_socket.send_multipart([identity, *response_msg.serialize()], copy=False) + # Send response back with identity for routing + worker_socket.send_multipart([identity] + response_msg.serialize(), copy=False) + + logger.info(f"[{self.storage_unit_id}]: worker stopped.") + poller.unregister(worker_socket) + worker_socket.close(linger=0) def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: """ @@ -365,6 +446,36 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: ) return response_msg + @staticmethod + def _shutdown_resources( + shutdown_event: Event, + worker_thread: Optional[Thread], + proxy_thread: Optional[Thread], + zmq_context: Optional[zmq.Context], + put_get_socket: Optional[zmq.Socket], + ) -> None: + """Clean up resources on garbage collection.""" + logger.info("Shutting down SimpleStorageUnit resources...") + + # Signal all threads to stop + shutdown_event.set() + + # Terminate put_get_socket + if put_get_socket: + put_get_socket.close(linger=0) + + # Terminate ZMQ context to unblock proxy and workers + if zmq_context: + zmq_context.term() + + # Wait for threads to finish (with timeout) + if worker_thread and worker_thread.is_alive(): + worker_thread.join(timeout=5) + if proxy_thread and proxy_thread.is_alive(): + proxy_thread.join(timeout=5) + + logger.info("SimpleStorageUnit resources shutdown complete.") + def get_zmq_server_info(self) -> ZMQServerInfo: """Get the ZMQ server information for this storage unit.