[perf] Add zmq.proxy to accelerate request processing for SimpleStorageUnit (#37)
Changes from all commits: the PR replaces SimpleStorageUnit's single request-processing thread with a ROUTER frontend, a `zmq.proxy` forwarding thread, and a DEALER-based worker thread, plus graceful shutdown via a `weakref` finalizer.
```diff
@@ -16,10 +16,12 @@
 import dataclasses
 import logging
 import os
+import time
+import weakref
 from dataclasses import dataclass
 from operator import itemgetter
-from threading import Thread
-from typing import Any
+from threading import Event, Thread
+from typing import Any, Optional
 from uuid import uuid4

 import ray
```
```diff
@@ -173,16 +175,41 @@ def __init__(self, storage_unit_size: int):

         self.storage_data = StorageUnitData(self.storage_unit_size)

+        # Internal communication address for proxy and workers
+        self._inproc_addr = f"inproc://simple_storage_workers_{self.storage_unit_id}"
+
+        # Shutdown event for graceful termination
+        self._shutdown_event = Event()
+
+        # Placeholder for zmq_context, proxy_thread and worker_threads
+        self.zmq_context: Optional[zmq.Context] = None
+        self.put_get_socket: Optional[zmq.Socket] = None
+        self.proxy_thread: Optional[Thread] = None
+        self.worker_thread: Optional[Thread] = None
+
         self._init_zmq_socket()
         self._start_process_put_get()

+        # Register finalizer for graceful cleanup when garbage collected
+        self._finalizer = weakref.finalize(
+            self,
+            self._shutdown_resources,
+            self._shutdown_event,
+            self.worker_thread,
+            self.proxy_thread,
+            self.zmq_context,
+            self.put_get_socket,
+        )
+
     def _init_zmq_socket(self) -> None:
         """
         Initialize ZMQ socket connections between storage unit and controller/clients:
-        - put_get_socket:
-            Handle put/get requests from clients.
+        - put_get_socket (ROUTER): Handle put/get requests from clients.
+        - worker_socket (DEALER): Backend socket for worker communication.
         """
         self.zmq_context = zmq.Context()

+        # Frontend: ROUTER for receiving client requests
         self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
         self._node_ip = get_node_ip_address()
```
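The finalizer above hands every resource to the callback as a plain argument rather than binding `self`: a `weakref.finalize` callback that captured `self` would keep the object alive and the finalizer would never fire. A minimal, self-contained sketch of the same pattern with hypothetical names (`Unit`, `_cleanup`), not the PR's code:

```python
import weakref
from threading import Event, Thread


class Unit:
    def __init__(self) -> None:
        self._shutdown_event = Event()
        # Daemon thread that simply waits until told to stop.
        self._worker = Thread(target=self._shutdown_event.wait, daemon=True)
        self._worker.start()
        # The callback receives the resources directly, never `self`,
        # so the finalizer does not keep this object alive.
        self._finalizer = weakref.finalize(
            self, Unit._cleanup, self._shutdown_event, self._worker
        )

    @staticmethod
    def _cleanup(shutdown_event: Event, worker: Thread) -> None:
        shutdown_event.set()    # signal the thread to exit its wait
        worker.join(timeout=1)  # then give it a moment to finish


u = Unit()
del u  # _cleanup runs here (or, at the latest, at interpreter exit)
```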
```diff
@@ -195,6 +222,10 @@ def _init_zmq_socket(self) -> None:
                 logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...")
                 continue

+        # Backend: DEALER for worker communication (connected via zmq.proxy)
+        self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
+        self.worker_socket.bind(self._inproc_addr)
+
         self.zmq_server_info = ZMQServerInfo(
             role=TransferQueueRole.STORAGE,
             id=str(self.storage_unit_id),
```
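For readers unfamiliar with the pattern: `zmq.proxy(frontend, backend)` blocks forever, shuttling frames between the ROUTER (which stamps each request with the sending client's identity) and the DEALER (which fair-queues across connected workers). A self-contained sketch of the end-to-end flow; the addresses (`tcp://127.0.0.1:5555`, `inproc://workers_demo`) and the echo worker are illustrative, not the PR's helpers:

```python
import threading

import zmq

ctx = zmq.Context()

frontend = ctx.socket(zmq.ROUTER)       # clients talk to this socket
frontend.bind("tcp://127.0.0.1:5555")

backend = ctx.socket(zmq.DEALER)        # workers connect to this socket
backend.bind("inproc://workers_demo")   # inproc requires bind before connect


def worker() -> None:
    sock = ctx.socket(zmq.DEALER)
    sock.connect("inproc://workers_demo")
    while True:
        # First frame is the client identity added by the ROUTER;
        # echoing it back is what routes the reply to the right client.
        frames = sock.recv_multipart()
        sock.send_multipart(frames)


threading.Thread(target=worker, daemon=True).start()
threading.Thread(target=lambda: zmq.proxy(frontend, backend), daemon=True).start()

client = ctx.socket(zmq.REQ)
client.connect("tcp://127.0.0.1:5555")
client.send(b"ping")
print(client.recv())                    # b'ping'
```

Because the proxy is a single forwarding loop, more worker threads (each with its own DEALER connected to the same inproc address) can be added without touching the client-facing socket; the PR starts one worker, but the structure leaves room for more.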
```diff
@@ -203,33 +234,78 @@ def _init_zmq_socket(self) -> None:
         )

     def _start_process_put_get(self) -> None:
-        """Create a daemon thread and start put/get process."""
-        self.process_put_get_thread = Thread(
-            target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.storage_unit_id}", daemon=True
-        )
-        self.process_put_get_thread.start()
+        """Start worker threads and ZMQ proxy for handling requests."""
+
+        # Start worker thread
+        self.worker_thread = Thread(
+            target=self._worker_routine,
+            name=f"StorageUnitWorkerThread-{self.storage_unit_id}",
+            daemon=True,
+        )
+        self.worker_thread.start()
+
+        time.sleep(0.5)  # make sure worker thread is ready before zmq.proxy forwarding messages
+
+        # Start proxy thread (ROUTER <-> DEALER)
+        self.proxy_thread = Thread(
+            target=self._proxy_routine,
+            name=f"StorageUnitProxyThread-{self.storage_unit_id}",
+            daemon=True,
+        )
+        self.proxy_thread.start()
+
+    def _proxy_routine(self) -> None:
+        """ZMQ proxy for message forwarding between frontend ROUTER and backend DEALER."""
+        logger.info(f"[{self.storage_unit_id}]: start ZMQ proxy...")
+        try:
+            zmq.proxy(self.put_get_socket, self.worker_socket)
+        except zmq.ContextTerminated:
+            logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy stopped gracefully (Context Terminated)")
+        except Exception as e:
+            if self._shutdown_event.is_set():
+                logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy shutting down...")
+            else:
+                logger.error(f"[{self.storage_unit_id}]: ZMQ Proxy unexpected error: {e}")

-    def _process_put_get(self) -> None:
-        """Process put_get_socket request."""
+    def _worker_routine(self) -> None:
+        """Worker thread for processing requests."""
+        # Each worker must have its own socket
+        worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
+        worker_socket.connect(self._inproc_addr)
+
         poller = zmq.Poller()
-        poller.register(self.put_get_socket, zmq.POLLIN)
+        poller.register(worker_socket, zmq.POLLIN)

-        logger.info(f"[{self.storage_unit_id}]: start processing put/get requests...")
-        perf_monitor = IntervalPerfMonitor(caller_name=f"{self.storage_unit_id}")
+        logger.info(f"[{self.storage_unit_id}]: worker thread started...")
+        perf_monitor = IntervalPerfMonitor(caller_name=self.storage_unit_id)

-        while True:
-            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
+        while not self._shutdown_event.is_set():
+            try:
+                socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
+            except zmq.error.ContextTerminated:
+                # ZMQ context was terminated, exit gracefully
+                logger.info(f"[{self.storage_unit_id}]: worker stopped gracefully (Context Terminated)")
+                break
+            except Exception as e:
+                logger.warning(f"[{self.storage_unit_id}]: worker poll error: {e}")
+                continue
+
+            if self._shutdown_event.is_set():
+                break

-            if self.put_get_socket in socks:
-                messages = self.put_get_socket.recv_multipart()
-                identity = messages.pop(0)
-                serialized_msg = messages
+            if worker_socket in socks:
+                # Messages received from proxy: [identity, serialized_msg_frame1, ...]
+                messages = worker_socket.recv_multipart()
+                identity = messages[0]
+                serialized_msg = messages[1:]
                 request_msg = ZMQMessage.deserialize(serialized_msg)
                 operation = request_msg.request_type

                 try:
-                    logger.debug(f"[{self.storage_unit_id}]: receive operation: {operation}, message: {request_msg}")
+                    logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}")

                     # Process request
                     if operation == ZMQRequestType.PUT_DATA:
                         with perf_monitor.measure(op_type="PUT_DATA"):
                             response_msg = self._handle_put(request_msg)
```
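The `time.sleep(0.5)` gives the worker time to connect its DEALER before the proxy starts forwarding; since ZeroMQ documents DEALER sends as blocking (not dropping) while no peer is connected, this is mostly a startup-latency nicety. A more deterministic variant would have the worker signal readiness explicitly; a hypothetical, self-contained sketch:

```python
import threading
import time

ready = threading.Event()


def worker() -> None:
    # ... create and connect the DEALER socket here ...
    ready.set()        # announce that connect() has been issued
    while True:
        time.sleep(1)  # stand-in for the recv/process loop


threading.Thread(target=worker, daemon=True).start()

# Instead of time.sleep(0.5): wait for an explicit signal, bounded by a
# timeout so startup cannot hang if the worker dies early.
if not ready.wait(timeout=5.0):
    raise RuntimeError("worker failed to start in time")
```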
```diff
@@ -253,12 +329,17 @@ def _process_put_get(self) -> None:
                                 request_type=ZMQRequestType.PUT_GET_ERROR,
                                 sender_id=self.storage_unit_id,
                                 body={
-                                    "message": f"Storage unit id #{self.storage_unit_id} occur error in processing "
-                                    f"put/get/clear request, detail error message: {str(e)}."
+                                    "message": f"{self.storage_unit_id}, worker encountered error "
+                                    f"during operation {operation}: {str(e)}."
                                 },
                             )

-                self.put_get_socket.send_multipart([identity, *response_msg.serialize()], copy=False)
+                # Send response back with identity for routing
+                worker_socket.send_multipart([identity] + response_msg.serialize(), copy=False)
+
+        logger.info(f"[{self.storage_unit_id}]: worker stopped.")
+        poller.unregister(worker_socket)
+        worker_socket.close(linger=0)

     def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
         """
```
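Two details in the reply path are worth noting. The identity frame that arrived with the request must be re-sent as the first frame of the reply: the DEALER passes it back through the proxy, and the ROUTER consumes it to pick the destination client. And `copy=False` asks pyzmq to send the serialized frames zero-copy, avoiding an extra buffer copy for large payloads at the cost of pyzmq holding the buffers until libzmq finishes the send.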
```diff
@@ -365,6 +446,36 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
         )
         return response_msg

+    @staticmethod
+    def _shutdown_resources(
+        shutdown_event: Event,
+        worker_thread: Optional[Thread],
+        proxy_thread: Optional[Thread],
+        zmq_context: Optional[zmq.Context],
+        put_get_socket: Optional[zmq.Socket],
+    ) -> None:
+        """Clean up resources on garbage collection."""
+        logger.info("Shutting down SimpleStorageUnit resources...")
+
+        # Signal all threads to stop
+        shutdown_event.set()
+
+        # Terminate put_get_socket
+        if put_get_socket:
+            put_get_socket.close(linger=0)
+
+        # Terminate ZMQ context to unblock proxy and workers
+        if zmq_context:
+            zmq_context.term()
+
+        # Wait for threads to finish (with timeout)
+        if worker_thread and worker_thread.is_alive():
+            worker_thread.join(timeout=5)
+        if proxy_thread and proxy_thread.is_alive():
+            proxy_thread.join(timeout=5)
+
+        logger.info("SimpleStorageUnit resources shutdown complete.")
+
     def get_zmq_server_info(self) -> ZMQServerInfo:
         """Get the ZMQ server information for this storage unit.
```
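The shutdown order here matters: `zmq.proxy()` has no built-in stop signal, so terminating the context is the standard way to break it out of its blocking call (it raises `ContextTerminated`, which `_proxy_routine` treats as a graceful exit). A condensed sketch of the ordering, with hypothetical names:

```python
import threading

import zmq


def shutdown(stop: threading.Event, ctx: zmq.Context, threads: list) -> None:
    # 1. Ask cooperative loops (like the worker's poll loop) to exit.
    stop.set()
    # 2. Terminate the context: blocking calls such as zmq.proxy() and
    #    poller.poll() raise ContextTerminated. Note that term() itself
    #    returns only once every socket in the context is closed, so each
    #    loop must close its own socket on the way out (as the worker's
    #    worker_socket.close(linger=0) does above).
    ctx.term()
    # 3. Reap the threads, bounded so shutdown can never hang forever.
    for t in threads:
        t.join(timeout=5)
```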
Review comment:

This change adds a bounds check to prevent an IndexError when `col_idx` is out of range for `col_mask`. While this is a good defensive fix, it appears to be unrelated to the multi-threading changes in this PR. Consider investigating why `col_idx` might exceed `len(col_mask)`, as this could indicate a data consistency issue elsewhere in the code.
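The code the comment refers to is not part of the excerpted diff above; purely for illustration, a guard of the kind described (all names and semantics hypothetical) might look like:

```python
import logging

logger = logging.getLogger(__name__)


def mark_column(col_mask: list, col_idx: int) -> None:
    # Guard against IndexError, as the change under review reportedly does;
    # the surrounding semantics here are invented for illustration.
    if col_idx < len(col_mask):
        col_mask[col_idx] = True
    else:
        # An out-of-range col_idx may point at a data consistency bug
        # upstream, which is why the reviewer suggests investigating
        # rather than only guarding here.
        logger.warning(
            "col_idx %d out of range for col_mask of length %d",
            col_idx,
            len(col_mask),
        )
```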