Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions scripts/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@
parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))


from transfer_queue import ( # noqa: E402
SimpleStorageUnit,
TransferQueueClient,
TransferQueueController,
process_zmq_server_info,
)
from transfer_queue.client import TransferQueueClient # noqa: E402
from transfer_queue.controller import TransferQueueController # noqa: E402
from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
from transfer_queue.utils.common import get_placement_group # noqa: E402
from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ def kv_retrieve_keys(
)
data_fields = []
for fname, col_idx in partition.field_name_mapping.items():
if col_mask[col_idx]:
if col_idx < len(col_mask) and col_mask[col_idx]:
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change adds a bounds check to prevent an IndexError when col_idx is out of range for col_mask. While this is a good defensive fix, it appears to be unrelated to the multi-threading changes in this PR.

Consider:

  1. Moving this fix to a separate PR for easier tracking and review.
  2. Adding a comment explaining under what conditions col_idx might exceed len(col_mask), as this could indicate a data consistency issue elsewhere in the code.
  3. Adding a warning log when this condition is detected to help identify the root cause.

Copilot uses AI. Check for mistakes.
data_fields.append(fname)

metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")
Expand Down
4 changes: 3 additions & 1 deletion transfer_queue/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig:
placement_group_bundle_index=storage_unit_rank,
name=f"TransferQueueStorageUnit#{storage_unit_rank}",
lifetime="detached",
).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units))
).remote(
storage_unit_size=math.ceil(total_storage_size / num_data_storage_units),
)
_TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node
logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.")

Expand Down
157 changes: 134 additions & 23 deletions transfer_queue/storage/simple_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import dataclasses
import logging
import os
import time
import weakref
from dataclasses import dataclass
from operator import itemgetter
from threading import Thread
from typing import Any
from threading import Event, Thread
from typing import Any, Optional
from uuid import uuid4

import ray
Expand Down Expand Up @@ -173,16 +175,41 @@ def __init__(self, storage_unit_size: int):

self.storage_data = StorageUnitData(self.storage_unit_size)

# Internal communication address for proxy and workers
self._inproc_addr = f"inproc://simple_storage_workers_{self.storage_unit_id}"

# Shutdown event for graceful termination
self._shutdown_event = Event()

# Placeholder for zmq_context, proxy_thread and worker_threads
self.zmq_context: Optional[zmq.Context] = None
self.put_get_socket: Optional[zmq.Socket] = None
self.proxy_thread: Optional[Thread] = None
self.worker_thread: Optional[Thread] = None

self._init_zmq_socket()
self._start_process_put_get()

# Register finalizer for graceful cleanup when garbage collected
self._finalizer = weakref.finalize(
self,
self._shutdown_resources,
self._shutdown_event,
self.worker_thread,
self.proxy_thread,
self.zmq_context,
self.put_get_socket,
)
Comment on lines 194 to 202
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The finalizer is registered before the threads and zmq_context are fully initialized. At this point (line 206), self.worker_threads is an empty list, self.proxy_thread is None, and self.zmq_context is None. These values are captured by the finalizer at registration time, not at cleanup time. When garbage collection occurs, the finalizer will attempt to shut down the wrong (empty/None) references instead of the actual running threads and context. The finalizer should be registered after _init_zmq_socket() and _start_process_put_get() complete, or it should pass self and access attributes dynamically.

Copilot uses AI. Check for mistakes.

def _init_zmq_socket(self) -> None:
"""
Initialize ZMQ socket connections between storage unit and controller/clients:
- put_get_socket:
Handle put/get requests from clients.
- put_get_socket (ROUTER): Handle put/get requests from clients.
- worker_socket (DEALER): Backend socket for worker communication.
"""
self.zmq_context = zmq.Context()

# Frontend: ROUTER for receiving client requests
self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
self._node_ip = get_node_ip_address()

Expand All @@ -195,6 +222,10 @@ def _init_zmq_socket(self) -> None:
logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...")
continue

# Backend: DEALER for worker communication (connected via zmq.proxy)
self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
self.worker_socket.bind(self._inproc_addr)
Comment on lines +225 to +227
Copy link

Copilot AI Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The worker socket binding happens in _init_zmq_socket (line 235), but worker threads connect to this address in _worker_routine (line 283). There's a potential race condition where worker threads might try to connect before the backend socket is fully bound and ready.

While there's a retry mechanism in the bind operation (lines 225-231), adding a small delay or verification after the bind at line 235 would ensure the socket is ready before starting worker threads. Alternatively, consider binding the backend socket before starting any threads, or add connection retry logic in the worker threads.

Copilot uses AI. Check for mistakes.

self.zmq_server_info = ZMQServerInfo(
role=TransferQueueRole.STORAGE,
id=str(self.storage_unit_id),
Expand All @@ -203,33 +234,78 @@ def _init_zmq_socket(self) -> None:
)

def _start_process_put_get(self) -> None:
"""Create a daemon thread and start put/get process."""
self.process_put_get_thread = Thread(
target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.storage_unit_id}", daemon=True
"""Start worker threads and ZMQ proxy for handling requests."""

# Start worker thread
self.worker_thread = Thread(
target=self._worker_routine,
name=f"StorageUnitWorkerThread-{self.storage_unit_id}",
daemon=True,
)
self.worker_thread.start()

time.sleep(0.5) # make sure worker thread is ready before zmq.proxy forwarding messages

# Start proxy thread (ROUTER <-> DEALER)
self.proxy_thread = Thread(
target=self._proxy_routine,
name=f"StorageUnitProxyThread-{self.storage_unit_id}",
daemon=True,
)
self.process_put_get_thread.start()
self.proxy_thread.start()

def _proxy_routine(self) -> None:
    """Forward messages between the frontend ROUTER and backend DEALER sockets.

    Runs ``zmq.proxy`` in a dedicated thread; the call blocks until the ZMQ
    context is terminated or the proxy fails. All exit paths are logged
    rather than re-raised so the thread terminates quietly during shutdown.
    """
    logger.info(f"[{self.storage_unit_id}]: start ZMQ proxy...")
    try:
        zmq.proxy(self.put_get_socket, self.worker_socket)
    except zmq.ContextTerminated:
        # Expected shutdown path: the context was terminated elsewhere.
        logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy stopped gracefully (Context Terminated)")
    except Exception as exc:
        # Any other error during an active shutdown is expected noise;
        # outside of shutdown it is a genuine failure worth surfacing.
        if not self._shutdown_event.is_set():
            logger.error(f"[{self.storage_unit_id}]: ZMQ Proxy unexpected error: {exc}")
        else:
            logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy shutting down...")

def _worker_routine(self) -> None:
"""Worker thread for processing requests."""
# Each worker must have its own socket
worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
worker_socket.connect(self._inproc_addr)

def _process_put_get(self) -> None:
"""Process put_get_socket request."""
poller = zmq.Poller()
poller.register(self.put_get_socket, zmq.POLLIN)
poller.register(worker_socket, zmq.POLLIN)

logger.info(f"[{self.storage_unit_id}]: start processing put/get requests...")
logger.info(f"[{self.storage_unit_id}]: worker thread started...")
perf_monitor = IntervalPerfMonitor(caller_name=f"{self.storage_unit_id}")

while not self._shutdown_event.is_set():
try:
socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
except zmq.error.ContextTerminated:
# ZMQ context was terminated, exit gracefully
logger.info(f"[{self.storage_unit_id}]: worker stopped gracefully (Context Terminated)")
break
except Exception as e:
logger.warning(f"[{self.storage_unit_id}]: worker poll error: {e}")
continue

perf_monitor = IntervalPerfMonitor(caller_name=self.storage_unit_id)
if self._shutdown_event.is_set():
break

while True:
socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
if worker_socket in socks:
# Messages received from proxy: [identity, serialized_msg_frame1, ...]
messages = worker_socket.recv_multipart()
identity = messages[0]
serialized_msg = messages[1:]

if self.put_get_socket in socks:
messages = self.put_get_socket.recv_multipart()
identity = messages.pop(0)
serialized_msg = messages
request_msg = ZMQMessage.deserialize(serialized_msg)
operation = request_msg.request_type

try:
logger.debug(f"[{self.storage_unit_id}]: receive operation: {operation}, message: {request_msg}")
logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}")

# Process request
if operation == ZMQRequestType.PUT_DATA:
with perf_monitor.measure(op_type="PUT_DATA"):
response_msg = self._handle_put(request_msg)
Expand All @@ -253,12 +329,17 @@ def _process_put_get(self) -> None:
request_type=ZMQRequestType.PUT_GET_ERROR,
sender_id=self.storage_unit_id,
body={
"message": f"Storage unit id #{self.storage_unit_id} occur error in processing "
f"put/get/clear request, detail error message: {str(e)}."
"message": f"{self.storage_unit_id}, worker encountered error "
f"during operation {operation}: {str(e)}."
},
)

self.put_get_socket.send_multipart([identity, *response_msg.serialize()], copy=False)
# Send response back with identity for routing
worker_socket.send_multipart([identity] + response_msg.serialize(), copy=False)

logger.info(f"[{self.storage_unit_id}]: worker stopped.")
poller.unregister(worker_socket)
worker_socket.close(linger=0)

def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
"""
Expand Down Expand Up @@ -365,6 +446,36 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
)
return response_msg

@staticmethod
def _shutdown_resources(
    shutdown_event: Event,
    worker_thread: Optional[Thread],
    proxy_thread: Optional[Thread],
    zmq_context: Optional[zmq.Context],
    put_get_socket: Optional[zmq.Socket],
) -> None:
    """Clean up resources on garbage collection.

    Intended as a ``weakref.finalize`` callback, which is why it is a
    ``@staticmethod`` receiving every resource explicitly instead of
    holding a reference to ``self`` (that would keep the instance alive).

    Args:
        shutdown_event: Set here so the worker loop exits on its next
            poll timeout.
        worker_thread: Worker thread to join, if one was started.
        proxy_thread: Proxy thread to join, if one was started.
        zmq_context: ZMQ context to terminate, if one was created.
        put_get_socket: Frontend ROUTER socket to close, if one was created.
    """
    logger.info("Shutting down SimpleStorageUnit resources...")

    # Signal all threads to stop
    shutdown_event.set()

    # Terminate put_get_socket
    # linger=0 discards queued messages instead of blocking on close.
    # NOTE(review): this socket is actively used by the proxy thread;
    # closing a ZMQ socket from a different thread is not guaranteed to
    # be safe — confirm against pyzmq threading guarantees.
    if put_get_socket:
        put_get_socket.close(linger=0)

    # Terminate ZMQ context to unblock proxy and workers
    # NOTE(review): Context.term() blocks until every socket created from
    # this context is closed; the backend DEALER socket bound in
    # _init_zmq_socket is not passed into this finalizer and appears never
    # to be closed, so term() may block indefinitely — confirm.
    if zmq_context:
        zmq_context.term()

    # Wait for threads to finish (with timeout)
    # Joins happen after term() so that blocked poll()/proxy() calls have
    # been woken by ContextTerminated before we wait on the threads.
    if worker_thread and worker_thread.is_alive():
        worker_thread.join(timeout=5)
    if proxy_thread and proxy_thread.is_alive():
        proxy_thread.join(timeout=5)

    logger.info("SimpleStorageUnit resources shutdown complete.")

def get_zmq_server_info(self) -> ZMQServerInfo:
"""Get the ZMQ server information for this storage unit.

Expand Down