diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index 23234be93ef..f02aae9c221 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -264,8 +264,7 @@ def _process_step_result(result): break # Resume engine (NOOP if not suspended). - if engine.is_suspended: - engine.resume() + engine.resume() return { "step_times": step_times, diff --git a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py index a7fb8fabcc0..81a96da3611 100644 --- a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py +++ b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py @@ -14,6 +14,7 @@ from examples.inference.gpt.utils import Request, build_dynamic_engine_setup_prefix, build_requests from megatron.core.inference.engines import DynamicInferenceEngine +from megatron.core.inference.engines.dynamic_engine import EngineState from megatron.core.inference.inference_client import InferenceClient from megatron.core.inference.inference_request import DynamicInferenceRequestRecord from megatron.core.inference.sampling_params import SamplingParams @@ -29,6 +30,22 @@ logging.basicConfig(level=logging.INFO, force=True) +async def suspend_resume_cycle(client, engine, args, futures): + """Wait for all in-flight requests, then suspend/train/resume.""" + await asyncio.gather(*futures) + + client.pause_engines() + await engine.wait_until(EngineState.PAUSED) + client.suspend_engines() + await engine.wait_until(EngineState.SUSPENDED) + if args.suspend_timeout > 0: + await asyncio.sleep(args.suspend_timeout) + client.resume_engines() + await engine.wait_until(EngineState.RESUMED) + client.unpause_engines() + await engine.wait_until(EngineState.RUNNING) + + async def main( engine: DynamicInferenceEngine, requests: List[Request], @@ -54,33 +71,22 @@ async def main( coordinator_schedule_output_path=args.coordinator_schedule_output_path, ) - # Test suspend/resume intervals. - if dist.get_rank() == 0 and args.suspend_resume_interval is not None: - # Since the client doesn't directly call engine.async_step here, we test - # the suspend-resume system ~4 times. - suspend_resume_interval = max(1, len(requests) // 4) - suspend_idxs = set( - range(suspend_resume_interval, len(requests) + 1, suspend_resume_interval) - ) - resume_idxs = set( - min(len(requests), i + suspend_resume_interval // 2) for i in suspend_idxs - ) - else: - suspend_idxs = set() - resume_idxs = set() + # All ranks agree on the number of suspend/resume cycles from args. + num_suspend_resume_cycles = len(requests) // args.suspend_resume_interval if args.suspend_resume_interval else 0 # Create client and run example. if dist.get_rank() == 0: client = InferenceClient(dp_addr) # submits requests to the inference coordinator - await client.start() + client.start() base_arrival_time = time.time_ns() / 10**9 for request in requests: request.time_arrival = request.time_offset + base_arrival_time futures = [] num_requests_total = len(requests) num_requests_added = 0 - # logging.info("Waiting for 20 seconds before starting to add requests. This is to mimic an RL style setup..") - # time.sleep(20) + next_suspend_at = args.suspend_resume_interval or 0 + cycles_done = 0 + while True: current_time = time.time_ns() / 10**9 if args.incoming_requests_per_step is None: @@ -96,11 +102,10 @@ async def main( futures.append(client.add_request(request.prompt_text, request.sampling_params)) num_requests_added += 1 - # Test suspend/resume. - if num_requests_added in suspend_idxs: - client.suspend_engines() - if num_requests_added in resume_idxs: - client.resume_engines() + if num_requests_added >= next_suspend_at and cycles_done < num_suspend_resume_cycles: + await suspend_resume_cycle(client, engine, args, futures) + cycles_done += 1 + next_suspend_at += args.suspend_resume_interval else: # Add deterministic number of requests (generally used for debugging). @@ -114,11 +119,10 @@ async def main( futures.append(client.add_request(request.prompt_text, request.sampling_params)) num_requests_added += 1 - # Test suspend/resume. - if num_requests_added in suspend_idxs: - client.suspend_engines() - if num_requests_added in resume_idxs: - client.resume_engines() + if num_requests_added >= next_suspend_at and cycles_done < num_suspend_resume_cycles: + await suspend_resume_cycle(client, engine, args, futures) + cycles_done += 1 + next_suspend_at += args.suspend_resume_interval if num_requests_added == num_requests_total: break @@ -127,6 +131,13 @@ async def main( # While we wait for the requests to complete, the engine runs in the background. results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures) + else: + # Non-rank-0: match the suspend/resume cycles that rank 0 drives. + for _ in range(num_suspend_resume_cycles): + await engine.wait_until(EngineState.PAUSED) + await engine.wait_until(EngineState.SUSPENDED) + await engine.wait_until(EngineState.RESUMED) + await engine.wait_until(EngineState.RUNNING) if dist.get_rank() == 0: # Write results to JSON. Primarily used for functional testing. @@ -173,14 +184,19 @@ async def main( ) ) - # kill the engines and suspend the client - # Right now, we can only call stop when all requests are done. - # Todo: Make this explicit in the Client class.... - await client.stop_engines() - client.stop() + # Pause before stopping: STOP requires PAUSED or SUSPENDED state. + client.pause_engines() + + await engine.wait_until(EngineState.PAUSED) + + if dist.get_rank() == 0: + client.stop_engines() - # once the stop signal eventually makes its way to each GPU, the engines will stop. - await asyncio.gather(engine.engine_loop_task) + await engine.wait_until(EngineState.STOPPED) + + if dist.get_rank() == 0: + client.shutdown_coordinator() + client.stop() logging.info(f"Rank: {dist.get_rank()} stopped their engine instance successfully.") @@ -210,9 +226,7 @@ async def main( model = get_model_for_inference() - requests = ( - build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None - ) + requests = build_requests(args, tokenizer, sampling_params) engine = get_dynamic_inference_engine(model=model) diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py index dc4ddc17ae9..60ca06819e7 100644 --- a/megatron/core/inference/data_parallel_inference_coordinator.py +++ b/megatron/core/inference/data_parallel_inference_coordinator.py @@ -7,7 +7,7 @@ import signal import socket from collections import deque -from itertools import cycle +from enum import Enum, auto from multiprocessing import Event from multiprocessing.connection import Connection @@ -71,6 +71,14 @@ class DataParallelInferenceCoordinator: next_request_id (int): A counter for generating unique server-side request IDs. """ + class CoordinatorState(Enum): + """State machine for the coordinator.""" + + RUNNING = auto() + PAUSED = auto() + SUSPENDED = auto() + STOPPING = auto() + def __init__( self, pipe_connection: Connection, @@ -122,6 +130,8 @@ def __init__( local_ip = socket.gethostname() self.router_socket = self.context.socket(zmq.ROUTER) + # Raise error if the other side of the connection has dropped. + self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1) is_bound = False if inference_coordinator_port is not None: try: @@ -161,15 +171,14 @@ def __init__( self.identities_of_data_parallel_ranks = deque( sorted(self.identities_of_data_parallel_ranks) ) - self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks) - self.data_parallel_pause_acks = set() - self.data_parallel_stop_acks = set() + self._round_robin_idx = 0 self.request_id_to_client_id = {} self.request_id_to_client_request_id = {} self.next_request_id = 0 self.tokenizer = tokenizer + self.state = self.CoordinatorState.RUNNING # Prefix caching state for routing. self.block_size_tokens = block_size_tokens @@ -195,7 +204,36 @@ def get_next_data_parallel_rank(self): Returns: bytes: The ZMQ identity of the next data parallel rank to receive a request. """ - return next(self.data_parallel_rank_iterator) + identities = self.identities_of_data_parallel_ranks + if not identities: + raise RuntimeError("No engines connected") + idx = self._round_robin_idx % len(identities) + self._round_robin_idx = idx + 1 + return identities[idx] + + def _remove_engine(self, identity): + """Remove a disconnected engine from the routing pool.""" + self.identities_of_data_parallel_ranks.remove(identity) + logging.warning( + "Coordinator: removed engine %s (now %d engines)", + identity, + len(self.identities_of_data_parallel_ranks), + ) + + def _send_to_engine(self, identity, payload): + """Send payload to an engine, removing it from the pool if unreachable. + + Returns: + True if the send succeeded, False if the engine was unreachable and removed. + """ + try: + self.router_socket.send_multipart([identity, payload]) + return True + except zmq.error.ZMQError as e: + if e.errno == zmq.EHOSTUNREACH: + self._remove_engine(identity) + return False + raise def compute_request_hashes(self, prompt): """Compute block hashes for a prompt on CPU. @@ -274,6 +312,13 @@ def start(self): known_clients = set() while True: sender_identity, serialized_payload = self.router_socket.recv_multipart() + + # Allow for re-registration if connecting to a running coordinator. + if serialized_payload == b"": + if sender_identity not in self.identities_of_data_parallel_ranks: + self.identities_of_data_parallel_ranks.append(sender_identity) + continue + deserialized_payload = msgpack.unpackb(serialized_payload, raw=False) header = Headers(deserialized_payload[0]) @@ -319,13 +364,30 @@ def start(self): else: raise Exception("specialize for <%s> prompt." % type(prompt).__name__) + payload = msgpack.packb( + [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params], + use_bin_type=True, + ) + request_hashes = self.compute_request_hashes(prompt) if ( self.prefix_caching_coordinator_policy == PrefixCachingCoordinatorPolicy.FIRST_PREFIX_BLOCK ): request_hashes = request_hashes[:1] - next_data_parallel_rank_identity = self.get_best_data_parallel_rank(request_hashes) + + # Account for the fact that some engines may have died. + for _ in range(len(self.identities_of_data_parallel_ranks)): + next_identity = self.get_best_data_parallel_rank(request_hashes) + if self._send_to_engine(next_identity, payload): + break + else: + # If all engines have died, we are in an abnormal state, and must exit cleanly. + logging.error("Coordinator: no reachable engines for request %d", request_id) + del self.request_id_to_client_id[request_id] + del self.request_id_to_client_request_id[request_id] + return + if request_hashes: self._update_rank_hashes(next_data_parallel_rank_identity, request_hashes) if self.schedule_records is not None: @@ -338,76 +400,61 @@ def start(self): "num_hashes": len(request_hashes), } ) - self.router_socket.send_multipart( - [ - next_data_parallel_rank_identity, - msgpack.packb( - [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params], - use_bin_type=True, - ), - ] - ) - elif header in [ + + elif header in ( Headers.PAUSE, Headers.UNPAUSE, Headers.SUSPEND, Headers.RESUME, Headers.INCREMENT_STALENESS, Headers.STOP, - ]: - # control signals for the engine - # broadcast to all data parallel ranks + ): + # Start by checking the current state against the control signal. if sender_identity not in known_clients: + logging.warning("Coordinator: ignoring signal from unknown client.") continue - for data_parallel_rank_id in self.identities_of_data_parallel_ranks: - self.router_socket.send_multipart( - [data_parallel_rank_id, msgpack.packb([header.value], use_bin_type=True)] - ) - if header == Headers.UNPAUSE: - self.data_parallel_pause_acks = set() - elif header == Headers.PAUSE_ACK: - # control signal ack from the engine - assert sender_identity in self.identities_of_data_parallel_ranks - assert sender_identity not in self.data_parallel_pause_acks - self.data_parallel_pause_acks.add(sender_identity) - # route to all clients only once we have gotten an ack from all data parallel ranks - if len(self.data_parallel_pause_acks) == self.data_parallel_size: - for client_id in known_clients: - self.router_socket.send_multipart( - [ - client_id, - msgpack.packb([header.value, sender_identity], use_bin_type=True), - ] - ) - for data_parallel_rank_id in self.identities_of_data_parallel_ranks: - self.router_socket.send_multipart( - [ - data_parallel_rank_id, - msgpack.packb([Headers.PAUSE_ACK.value], use_bin_type=True), - ] - ) - elif header == Headers.STOP_ACK: - # control signal ack from the engine - assert sender_identity in self.identities_of_data_parallel_ranks - assert sender_identity not in self.data_parallel_stop_acks - self.data_parallel_stop_acks.add(sender_identity) - # route to all clients only once we have gotten an ack from all data parallel ranks - if len(self.data_parallel_stop_acks) == self.data_parallel_size: - for client_id in known_clients: - self.router_socket.send_multipart( - [ - client_id, - msgpack.packb([header.value, sender_identity], use_bin_type=True), - ] - ) - for data_parallel_rank_id in self.identities_of_data_parallel_ranks: - self.router_socket.send_multipart( - [ - data_parallel_rank_id, - msgpack.packb([Headers.STOP_ACK.value], use_bin_type=True), - ] - ) - break # Exit the main loop after STOP_ACKs have been processed. + + if header == Headers.PAUSE: + idem_states = (self.CoordinatorState.PAUSED, self.CoordinatorState.SUSPENDED) + if self.state == self.CoordinatorState.RUNNING: + self.state = self.CoordinatorState.PAUSED + elif self.state in idem_states: + # Already paused/suspended, ignore redundant PAUSE. + continue + else: + logging.warning("Coordinator: ignoring PAUSE in state %s", self.state) + continue + elif header == Headers.UNPAUSE: + if self.state != self.CoordinatorState.PAUSED: + logging.warning("Coordinator: ignoring UNPAUSE in state %s", self.state) + continue + self.state = self.CoordinatorState.RUNNING + elif header == Headers.SUSPEND: + if self.state != self.CoordinatorState.PAUSED: + logging.warning("Coordinator: ignoring SUSPEND in state %s", self.state) + continue + self.state = self.CoordinatorState.SUSPENDED + elif header == Headers.RESUME: + if self.state != self.CoordinatorState.SUSPENDED: + logging.warning("Coordinator: ignoring RESUME in state %s", self.state) + continue + self.state = self.CoordinatorState.PAUSED + elif header == Headers.STOP: + good_states = (self.CoordinatorState.PAUSED, self.CoordinatorState.SUSPENDED) + if self.state not in good_states: + logging.warning("Coordinator: ignoring STOP in state %s", self.state) + continue + self.state = self.CoordinatorState.STOPPING + + # Broadcast the control signal if we're in a good state. + broadcast_payload = msgpack.packb([header.value], use_bin_type=True) + for data_parallel_rank_id in list(self.identities_of_data_parallel_ranks): + self._send_to_engine(data_parallel_rank_id, broadcast_payload) + + # STOP affects engines; reset coordinator to RUNNING to allow future engines. + if header == Headers.STOP: + self.state = self.CoordinatorState.RUNNING + elif header == Headers.ENGINE_REPLY: # This is the output of a single engine step on some data parallel rank. assert sender_identity in self.identities_of_data_parallel_ranks @@ -431,6 +478,16 @@ def start(self): ] ) + elif header == Headers.SHUTDOWN: + if sender_identity not in known_clients: + logging.warning("Coordinator: ignoring signal from unknown client.") + continue + break + + elif header == Headers.DISCONNECT: + if sender_identity in self.identities_of_data_parallel_ranks: + self._remove_engine(sender_identity) + else: raise UnknownHeaderError(header) diff --git a/megatron/core/inference/engines/async_zmq_communicator.py b/megatron/core/inference/engines/async_zmq_communicator.py index 7076bb283bd..124f2d46932 100644 --- a/megatron/core/inference/engines/async_zmq_communicator.py +++ b/megatron/core/inference/engines/async_zmq_communicator.py @@ -65,46 +65,54 @@ def __init__(self, zmq_context: zmq.Context, process_group: dist.ProcessGroup): self.bcast_sock.connect(bcast_socket_addr) self.bcast_sock.setsockopt_string(zmq.SUBSCRIBE, "") - async def all_reduce_max(self, local_val: int) -> int: - """ - Asyncio friendly all reduce max operation. Gathers on rank 0, computes max, - and broadcasts the result. + async def all_reduce_max(self, *local_vals: int) -> int | tuple[int, ...]: + """Element-wise all-reduce max of one or more integers. + + Packs all values into a single message so the communication cost + is independent of the number of values. + + Returns a single int when called with one argument, otherwise a tuple. """ + n = len(local_vals) + if n == 0: + raise ValueError("all_reduce_max requires at least one value") + if self.world_size <= 1: - return local_val + return local_vals[0] if n == 1 else local_vals - payload = struct.pack('!i', local_val) + fmt = f'!{n}i' + payload = struct.pack(fmt, *local_vals) if self.is_leader: - # Rank 0: Gather -> Max -> Broadcast - values = [local_val] + rows = [local_vals] - # Non-blocking gather from N-1 peers - while len(values) < self.world_size: + while len(rows) < self.world_size: try: msg = self.gather_sock.recv(flags=zmq.NOBLOCK) - values.append(struct.unpack('!i', msg)[0]) + rows.append(struct.unpack(fmt, msg)) except zmq.Again: - await asyncio.sleep(0.001) # Yield to event loop + await asyncio.sleep(0.001) - max_val = max(values) - self.bcast_sock.send(struct.pack('!i', max_val)) - return max_val + maxes = tuple(max(row[i] for row in rows) for i in range(n)) + self.bcast_sock.send(struct.pack(fmt, *maxes)) + return maxes[0] if n == 1 else maxes else: - # Worker: Send -> Wait for Broadcast self.gather_sock.send(payload) while True: try: msg = self.bcast_sock.recv(flags=zmq.NOBLOCK) - return struct.unpack('!i', msg)[0] + result = struct.unpack(fmt, msg) + return result[0] if n == 1 else result except zmq.Again: - await asyncio.sleep(0.001) # Yield to event loop + await asyncio.sleep(0.001) def close(self): """ Close the ZMQ sockets. """ - self.gather_sock.close() - self.bcast_sock.close() + # linger=0: discard unsent messages immediately on close rather than blocking until sent. + # The ZMQ default is to not allow `close` until all messages have been successfully sent. + self.gather_sock.close(linger=0) + self.bcast_sock.close(linger=0) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 7dec7a14bea..88e0f31b7b6 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -12,6 +12,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime +from enum import Enum, auto from itertools import repeat from typing import Dict, List, Optional, Tuple, Union @@ -110,6 +111,21 @@ ] +class EngineState(Enum): + """State machine for the inference engine.""" + + RUNNING = auto() # Processing requests + PAUSING = auto() # PAUSE received; waiting for EP consensus + world barrier + PAUSED = auto() # Globally confirmed idle + UNPAUSING = auto() # UNPAUSE received; waiting for world barrier + SUSPENDING = auto() # SUSPEND received; offloading GPU; waiting for world barrier + SUSPENDED = auto() # GPU offloaded, all ranks confirmed + RESUMING = auto() # RESUME received; onloading GPU; waiting for world barrier + RESUMED = auto() # GPU onloaded, all ranks confirmed; cleared on next SUSPEND + STOPPING = auto() # STOP received; futures cancelled; waiting for world barrier + STOPPED = auto() # All ranks confirmed; teardown complete + + class EngineSuspendedError(Exception): """Engine is currently suspended and not performing steps.""" @@ -153,6 +169,15 @@ class DynamicInferenceEngine(AbstractEngine): batching and a dynamic block-level KV cache (similar to paged attention). """ + # Map stable states to their corresponding asyncio events. + _STATE_EVENTS = ( + EngineState.RUNNING, + EngineState.PAUSED, + EngineState.SUSPENDED, + EngineState.RESUMED, + EngineState.STOPPED, + ) + @deprecate_args( *DEPRECATED_ARGS, message="Argument `{name}` has been deprecated. Only pass `controller` and `context`", @@ -251,13 +276,11 @@ def reset(self) -> None: # Runtime state. self._loop = get_asyncio_loop(getattr(self, "_loop", None)) self._cond = asyncio.Condition() - self.running = asyncio.Event() - self.paused = asyncio.Event() - self.stopped = asyncio.Event() - self.received_pause: bool = False - self.received_stop: bool = False - self.suspend_signal = False - self.is_suspended = False + self._state_events = {k: asyncio.Event() for k in self._STATE_EVENTS} + self.state = EngineState.RUNNING + self._state_events[EngineState.RUNNING].set() + self._pending_signals = deque() + self.resume_request_ids = None # Prefix caching coordination state. @@ -266,6 +289,18 @@ def reset(self) -> None: # Coordinator state. self.use_coordinator = False + async def wait_until(self, state: EngineState): + """Wait until the engine reaches the given state. + + Only stable states (RUNNING, PAUSED, SUSPENDED, RESUMED, + STOPPED) are supported. Transient states (PAUSING, SUSPENDING, + RESUMING, STOPPING) are not directly waitable. + """ + event = self._state_events.get(state) + if event is None: + raise ValueError(f"Cannot wait for transient state {state}") + await event.wait() + def create_cuda_graphs(self, reset_context: bool = True): """Create cuda graphs. @@ -421,7 +456,7 @@ async def start_listening_to_data_parallel_coordinator( "pip install msgpack" ) - self.zmq_context = zmq.Context().instance() + self.zmq_context = zmq.Context.instance() self.zmq_sockets = [] # keep track of all sockets created by this engine # Get world info. @@ -547,6 +582,11 @@ async def start_listening_to_data_parallel_coordinator( self.zmq_context, process_group=self.pg_collection.ep ) + # initialize zmq-based world communicator for consensus barriers + total_world_size = torch.distributed.get_world_size() + if total_world_size > 1: + self.world_zmq_communicator = AsyncZMQCommunicator(self.zmq_context, process_group=None) + if launch_inference_coordinator and self.is_dp_coordinator: await await_process_call( coordinator_ready_event.wait, self.inference_coordinator_process @@ -626,11 +666,9 @@ def suspend_resume_ctx(key: str, *, unified_memory_level: int) -> None: def suspend(self): """Suspend engine by deallocating context's GPU state.""" - # Skip if already suspended, which can happen when using the inference - # coordinator. - if self.is_suspended: + # Skip if already suspended or in the process of suspending. + if self.state in (EngineState.SUSPENDED, EngineState.SUSPENDING): return - self.is_suspended = True # Deallocate context tensors. with self.__class__.suspend_resume_ctx( @@ -660,14 +698,16 @@ def suspend(self): for request_id in recompute_active_ids: self.requests[request_id].record.checkpoint() + # If we are not using the inference coordinator, we need to manually handle state. + if not self.use_coordinator: + self.state = EngineState.SUSPENDED + def resume(self): """Resume engine by reallocating context's GPU state.""" - # Skip if not suspended, which can happen when using the inference - # coordinator. - if not self.is_suspended: + # Skip if not suspended or in the process of suspending. + if self.state not in (EngineState.SUSPENDED, EngineState.SUSPENDING): return - self.is_suspended = False # Resume. with self.__class__.suspend_resume_ctx( @@ -709,8 +749,13 @@ def resume(self): ) ) - # Notify event loop. - self._loop.call_soon_threadsafe(asyncio.create_task, self._notify_cond_for_new_request()) + # If we are not using the inference coordinator, we need to manually handle state. + if not self.use_coordinator: + self.state = EngineState.RUNNING + # Notify the condition variable that run_engine() waits on. + self._loop.call_soon_threadsafe( + asyncio.create_task, self._notify_cond_for_new_request() + ) @trace_async_exceptions async def _notify_cond_for_new_request(self): @@ -1340,7 +1385,7 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]: """ # If suspended, no stepping. - if self.is_suspended: + if self.state in (EngineState.SUSPENDED, EngineState.SUSPENDING): raise EngineSuspendedError(self.context.step_count) # schedule requests @@ -1713,77 +1758,105 @@ def schedule_requests(self) -> int: all_messages = [] range_pop() + + # First pass: add all requests and detect staleness increments. + # Control signals are queued for the second pass. + has_staleness_increment = False for message in all_messages: data = msgpack.unpackb(message, raw=False) header = Headers(data[0]) - - if self.received_stop: - assert ( - header == Headers.STOP_ACK - ), "Engine is shutting down. No other messages allowed except STOP_ACK." - if header == Headers.SUBMIT_REQUEST: request_id, prompt, sampling_params = data[1:] sampling_params = SamplingParams.deserialize(sampling_params) range_push("add_request") self.add_request(request_id, prompt, sampling_params) range_pop() - elif header == Headers.PAUSE: - # Pause thyself. - self.received_pause = True - self.running.clear() - # Send PAUSE_ACK back to coordinator. - if self.is_mp_coordinator: - payload = msgpack.packb([Headers.PAUSE_ACK.value], use_bin_type=True) - self.socket_for_receiving_requests.send(payload) - elif header == Headers.STOP: - # Stop thyself. - self.received_stop = True - self.running.clear() - # Send STOP_ACK back to coordinator. - if self.is_mp_coordinator: - payload = msgpack.packb([Headers.STOP_ACK.value], use_bin_type=True) - self.socket_for_receiving_requests.send(payload) - elif header == Headers.PAUSE_ACK: - self.paused.set() - self.received_pause = False - elif header == Headers.STOP_ACK: - self.stopped.set() - self.received_stop = False + elif header == Headers.INCREMENT_STALENESS: + has_staleness_increment = True + else: + # Control signal: queue for second pass. + self._pending_signals.append(message) + + if has_staleness_increment: + waiting = set(self.waiting_request_ids) + for request_id, entry in self.requests.items(): + entry.record.increment_staleness(policy_only=request_id in waiting) + + # Second pass: apply at most one control signal (the engine loop + # processes one state transition per iteration). + if self._pending_signals: + message = self._pending_signals.popleft() + data = msgpack.unpackb(message, raw=False) + header = Headers(data[0]) + + if header == Headers.PAUSE: + if self.state == EngineState.RUNNING: + self.state = EngineState.PAUSING + self._state_events[EngineState.RUNNING].clear() + # Any other state can safely ignore PAUSE. + elif header == Headers.UNPAUSE: - self.paused.clear() - self.running.set() + assert self.state == EngineState.PAUSED, f"Received UNPAUSE in state {self.state}" + self.state = EngineState.UNPAUSING + elif header == Headers.SUSPEND: - self.suspend_signal = True + assert self.state == EngineState.PAUSED, f"Received SUSPEND in state {self.state}" + self._state_events[EngineState.RESUMED].clear() + self.suspend() + self.state = EngineState.SUSPENDING + elif header == Headers.RESUME: - self.suspend_signal = False - elif header == Headers.INCREMENT_STALENESS: - waiting = set(self.waiting_request_ids) - for request_id, entry in self.requests.items(): - entry.record.increment_staleness(policy_only=request_id in waiting) + assert self.state == EngineState.SUSPENDED, f"Received RESUME in state {self.state}" + self._state_events[EngineState.SUSPENDED].clear() + self.resume() + self.state = EngineState.RESUMING + elif header == Headers.STOP: - self.received_stop = True + assert self.state in ( + EngineState.PAUSED, + EngineState.SUSPENDED, + ), f"Received STOP in state {self.state}" + if self.state == EngineState.SUSPENDED: + self._state_events[EngineState.SUSPENDED].clear() + self.state = EngineState.STOPPING + else: raise UnknownHeaderError(header) return len(all_messages) - def stop(self): - """ - Stops the inference engine by terminating the inference coordinator process - if it exists, and destroys the model parallel state. - This method ensures that any running inference coordinator subprocess - is properly terminated, and cleans up resources associated with - model parallelism. + async def shutdown(self): + """Shut down the engine and clean up ZMQ resources. + + Called from the engine loop's finally block after the loop exits. """ + self.state = EngineState.STOPPED - if hasattr(self, "inference_coordinator_process"): - self.inference_coordinator_process.join() - for socket in self.zmq_sockets: - socket.close() + # Cleanup the request futures. + for entry in self.requests.values(): + if not entry.future.done(): + entry.future.cancel() + + # ZMQ cleanup; designed to be idempotent. + sock = getattr(self, 'socket_for_receiving_requests', None) + if sock is not None and not sock.closed: + try: + sock.send(msgpack.packb([Headers.DISCONNECT.value], use_bin_type=True)) + except Exception: + pass + for socket in getattr(self, 'zmq_sockets', []): + socket.close(linger=0) + if hasattr(self, 'zmq_sockets'): + self.zmq_sockets.clear() if hasattr(self, "expert_parallel_zmq_communicator"): self.expert_parallel_zmq_communicator.close() - self.zmq_context.term() + if hasattr(self, "world_zmq_communicator"): + self.world_zmq_communicator.close() + if not self.zmq_context.closed: + self.zmq_context.term() + + # Set the stopped state at the very end. + self._state_events[EngineState.STOPPED].set() @trace_async_exceptions async def run_engine(self, *, loop: Optional[asyncio.AbstractEventLoop] = None): @@ -1796,7 +1869,7 @@ async def run_engine(self, *, loop: Optional[asyncio.AbstractEventLoop] = None): async with self._cond: await self._cond.wait_for( lambda: ( - not self.is_suspended + self.state not in (EngineState.SUSPENDED, EngineState.SUSPENDING) and ( self.context.get_active_request_count() > 0 or self.waiting_request_ids @@ -1807,118 +1880,144 @@ async def run_engine(self, *, loop: Optional[asyncio.AbstractEventLoop] = None): except asyncio.CancelledError: pass - async def _ep_group_has_work(self, local_work: int) -> bool: - """Determines if there are some pending requests in the expert parallel group this - rank is a part of. + async def _ep_establish_consensus( + self, local_work: int, signal_consensus: bool + ) -> tuple[int, bool]: + """EP all-reduce to share work counts and pause consensus. + + All-reduces two integers at once: + - local_work: actual pending request count (always >= 0). + - consensus flag: -1 if this rank wants to pause, 0 otherwise. + + Using max for both: + - max(work) > 0 means at least one EP peer has real work. + - max(consensus) == -1 means ALL peers signaled -1 (all PAUSING). + Any RUNNING peer contributes 0, pulling the max to 0. + Args: - local_work (int): The local work count for this rank. This is a sum of active - and waiting requests. + local_work: Pending request count for this rank. + signal_consensus: True if this rank is ready to pause. Returns: - bool: True if there is some work in the EP group, False otherwise. + (global_work, all_pausing): max work across EP, and whether + all peers signaled consensus. """ - range_push("_ep_group_has_work") - - is_stopped = self.stopped.is_set() or self.received_stop - is_paused = self.paused.is_set() or self.received_pause - is_suspended = self.suspend_signal - if is_stopped or is_paused or is_suspended: - # Signals can be received asynchronously on EP ranks. - # We do not want a rank to pause/stop/suspend prematurely if one of it's peers - # is yet to receive the signal. - # So this is an *attempt* to process the signal. This rank has received the signal - # and passes 0 to the all-reduce. If any other rank in the EP group has not received the signal yet, - # it will pass a non-zero value to the all-reduce, and hence the global work will be non-zero, - # and we will defer processing the signal. - # When all ranks receive the signal, global work will be zero, and we can process the signal safely. - local_work = 0 + range_push("_ep_establish_consensus") + + consensus_val = -1 if signal_consensus else 0 + + # Signals can be received asynchronously on EP ranks. + # We do not want a rank to pause prematurely if its peers have yet to receive the signal. + # So this is an *attempt* to process the signal. This rank has received the signal + # and passes -1 to the all-reduce. If any other rank in the EP group has not received + # the signal yet, it will pass a zero value to the all-reduce, hence the global consensus + # will be zero and we will defer processing the signal. + # When all ranks receive the signal, global consensus will be -1 and we can process. if self.ep_world_size > 1: - # Perform all-reduce to get max global work across EP group. # Note that it is important to use a non-blocking asyncio-friendly all-reduce here. # The user may have other tasks running in the event loop that need to be serviced. # Do not using a torch.distributed blocking all-reduce here using nccl/gloo. # We have tried that and it blocks the event loop in megatron-rl. - max_global_work = await self.expert_parallel_zmq_communicator.all_reduce_max(local_work) + global_work, global_consensus = ( + await self.expert_parallel_zmq_communicator.all_reduce_max( + local_work, consensus_val + ) + ) else: - max_global_work = local_work + global_work, global_consensus = local_work, consensus_val range_pop() - return max_global_work > 0 + return global_work, global_consensus == -1 + + async def _world_barrier(self): + """World-wide ZMQ all-reduce barrier for global rank consensus. + + Used for all state transitions that require global synchronization: + PAUSING → PAUSED, UNPAUSING → RUNNING, SUSPENDING → SUSPENDED, + RESUMING → PAUSED, and STOPPING → STOPPED. + + No-op when world_size == 1 (communicator is not created). + """ + range_push("world_barrier") + if hasattr(self, 'world_zmq_communicator'): + await self.world_zmq_communicator.all_reduce_max(1) + range_pop() @trace_async_exceptions async def run_engine_with_coordinator( self, *, loop: Optional[asyncio.AbstractEventLoop] = None ): - """Continually steps the engine asynchronously.""" + """Continually steps the engine asynchronously. + + State-dependent behavior: + - RUNNING: EP all-reduce to check for work, then step or idle. + - PAUSING: EP all-reduce to reach consensus, then world barrier. + - PAUSED / SUSPENDED: Idle-sleep, wait for signals via schedule_requests(). + - UNPAUSING / SUSPENDING / RESUMING / STOPPING: World barrier, then transition. + - STOPPED: Teardown and exit. + """ self._loop = get_asyncio_loop(loop) self.use_coordinator = True + try: while True: self.schedule_requests() - # for the cases below (no active requests, or undergoing a state-change) - # do not use asyncio.sleep(0) - # as tp-rank=0 will flood the num_messages publisher - # with "0" repeatedly. This causes some packets to drop. - # Instead be nice, and sleep - # for a short time. - # The minimum sleep time needed is ~100us i.e. the time - # needed to send one message on an IPC socket. However - # just to be safe, we use 20ms here. - - local_pending_requests = self.context.get_active_request_count() + len( - self.waiting_request_ids - ) - # 1. Check for work availability (Consensus Step) - ep_group_has_work = await self._ep_group_has_work(local_pending_requests) - - # 2. Dummy Work Logic (Keep group alive if peers have work) - if ep_group_has_work and local_pending_requests == 0: - # run dummy forward pass if EP group as a whole has work, - # but this rank does not have any work. - self.step_start_event.record() - self.controller.dummy_forward() - self.step_end_event.record() - self.step_end_event.synchronize() - self.context.step_count += 1 - continue - - # 3. No work in EP group - # We handle control signals (PAUSE/STOP/SUSPEND) only when - # the entire EP group has received the signal. It is important to - # not process these signals immediately upon receipt, because - # other ranks in the EP group may not have received them yet. - # If we exit prematurely, other ranks will deadlock at the all-to-all. - # We use self._ep_group_has_work() to build consensus across the EP group - # as to when it is safe to process these signals. The function returns False - # when all ranks have received the signal. - if not ep_group_has_work: - # Priority A: STOP - if self.stopped.is_set(): - if self.rank == 0: - logging.info("Stopping engine.") - self.stop() - break + if self.state in (EngineState.RUNNING, EngineState.PAUSING): + local_pending = self.context.get_active_request_count() + len( + self.waiting_request_ids + ) + global_work, all_pausing = await self._ep_establish_consensus( + local_pending, signal_consensus=(self.state == EngineState.PAUSING) + ) - # Priority B: SUSPEND - if self.suspend_signal: - self.suspend() + if all_pausing: + # All EP peers are PAUSING: pause immediately. + await self._world_barrier() + self.state = EngineState.PAUSED + self._state_events[EngineState.PAUSED].set() + elif global_work > 0: + # At least one EP peer has work: all must participate. + if local_pending > 0: + await self.async_step() + else: + # Dummy forward to participate in the EP collective. + self.step_start_event.record() + self.controller.dummy_forward() + self.step_end_event.record() + self.step_end_event.synchronize() + self.context.step_count += 1 else: - self.resume() - - # Priority C: PAUSE or no work - nothing needs to be done - # To avoid flooding the TP publisher socket with packets, - # we sleep for 20 ms here. - # todo [Siddharth]: Can this hardcoded sleep be avoided - # with asyncio zmq sockets? - await asyncio.sleep(0.02) # Yield to event loop - continue + # No work, but not all pausing: idle. + await asyncio.sleep(0.02) - try: - await self.async_step() - except EngineSuspendedError: + elif self.state == EngineState.PAUSED: await asyncio.sleep(0.02) - continue - except asyncio.CancelledError: - pass + elif self.state == EngineState.UNPAUSING: + await self._world_barrier() + self.state = EngineState.RUNNING + self._state_events[EngineState.PAUSED].clear() + self._state_events[EngineState.RUNNING].set() + + elif self.state == EngineState.SUSPENDING: + await self._world_barrier() + self.state = EngineState.SUSPENDED + self._state_events[EngineState.SUSPENDED].set() + + elif self.state == EngineState.SUSPENDED: + await asyncio.sleep(0.02) + + elif self.state == EngineState.RESUMING: + await self._world_barrier() + self.state = EngineState.PAUSED + self._state_events[EngineState.RESUMED].set() + + elif self.state == EngineState.STOPPING: + await self._world_barrier() + if self.rank == 0: + logging.info("Stopping engine.") + break + + finally: + await self.shutdown() diff --git a/megatron/core/inference/headers.py b/megatron/core/inference/headers.py index 2551bc54f53..03c290ea925 100644 --- a/megatron/core/inference/headers.py +++ b/megatron/core/inference/headers.py @@ -13,17 +13,17 @@ class Headers(Enum): SUBMIT_REQUEST = auto() ENGINE_REPLY = auto() PAUSE = auto() - PAUSE_ACK = auto() UNPAUSE = auto() SUSPEND = auto() RESUME = auto() INCREMENT_STALENESS = auto() STOP = auto() - STOP_ACK = auto() + DISCONNECT = auto() + SHUTDOWN = auto() class UnknownHeaderError(Exception): """A signal with an unrecognized header was received by the coordinator.""" - def __init_(self, header): + def __init__(self, header): super().__init__(f"specialize for {header}.") diff --git a/megatron/core/inference/inference_client.py b/megatron/core/inference/inference_client.py index 5a4ed24bebb..33c53fa447f 100644 --- a/megatron/core/inference/inference_client.py +++ b/megatron/core/inference/inference_client.py @@ -3,7 +3,7 @@ import asyncio import logging import time -from typing import Awaitable, List, Optional, Union +from typing import List, Optional, Union from megatron.core.inference.inference_request import DynamicInferenceRequestRecord from megatron.core.inference.sampling_params import SamplingParams @@ -25,8 +25,6 @@ except: HAVE_MSGPACK = False -from .headers import Headers - class InferenceClient: """ @@ -72,10 +70,6 @@ def __init__(self, inference_coordinator_address: str): socket.connect(inference_coordinator_address) self._loop = None - self.running = asyncio.Event() - self.paused = asyncio.Event() - self.stopped = asyncio.Event() - self.socket = socket self.completion_futures = {} self.request_submission_times = {} @@ -101,8 +95,6 @@ def add_request( asyncio.Future: A future that will be resolved with a `DynamicInferenceRequestRecord` object containing the completed result. """ - if not self.running.is_set(): - raise RuntimeError("InferenceClient is not currently running.") request_id = self.next_request_id self.next_request_id += 1 payload = [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params.serialize()] @@ -143,10 +135,6 @@ async def _recv_task(self): completion_future.get_loop().call_soon_threadsafe( completion_future.set_result, completed_request ) - elif header == Headers.PAUSE_ACK: - self.paused.set() - elif header == Headers.STOP_ACK: - self.stopped.set() except zmq.Again: await asyncio.sleep(0.005) continue @@ -165,19 +153,16 @@ def _connect_with_inference_coordinator(self): reply = msgpack.unpackb(self.socket.recv(), raw=False)[0] assert Headers(reply) == Headers.CONNECT_ACK - async def start(self, loop: Optional[asyncio.AbstractEventLoop] = None): + def start(self, loop: Optional[asyncio.AbstractEventLoop] = None): """ Connects to the coordinator and starts the background listener task. - This method must be awaited before submitting any requests. It handles + This must be called before submitting any requests. It handles the initial handshake and spawns the `listen_for_completed_requests` coroutine. """ logging.info("Client: Connecting to InferenceCoordinator...") self._loop = get_asyncio_loop(loop) - self.running.set() - self.paused.clear() - self.stopped.clear() self._connect_with_inference_coordinator() self.listener_task = self._loop.create_task(self._recv_task()) @@ -192,58 +177,51 @@ def _send_signal_to_engines(self, signal): payload_serialized = msgpack.packb(payload, use_bin_type=True) self.socket.send(payload_serialized) - def pause_engines(self) -> Awaitable: - """Sends a signal to pause all inference engines. - - The signal first propagates thru the coordinator to all engines. - All engines acknowledge this signal and clear their `running` flags. - The coordinator awaits all acknowledgements before forwarding the ACK - back to the client, as well as to the engines. - The engines set their `paused` flags upon seeing the ACK. + def pause_engines(self): + """Sends PAUSE to all engines via coordinator. - Returns: - Awaitable: An awaitable that resolves when all engines have paused. + The coordinator broadcasts PAUSE. Each engine reaches EP consensus, + then synchronizes via a world-wide barrier before transitioning to + PAUSED. Callers should await engine.paused for confirmation. """ self._send_signal_to_engines(Headers.PAUSE) - return self.paused.wait() def unpause_engines(self) -> None: - """Sends a signal to unpause all inference engines.""" - self.paused.clear() - self.running.set() + """Sends UNPAUSE to all engines. No synchronization needed.""" self._send_signal_to_engines(Headers.UNPAUSE) def increment_staleness(self): """Sends a signal to increment staleness on all in-flight requests.""" - assert self.paused.is_set(), "Can only increment staleness while engines are paused." self._send_signal_to_engines(Headers.INCREMENT_STALENESS) def suspend_engines(self): - """Sends a signal to pause all inference engines.""" - self._send_signal_to_engines(Headers.PAUSE) + """Sends SUSPEND to all engines via coordinator. Requires PAUSED. + + Callers should await engine.suspended for confirmation. + """ self._send_signal_to_engines(Headers.SUSPEND) def resume_engines(self): - """Sends a signal to unpause all inference engines.""" - self.paused.clear() - self._send_signal_to_engines(Headers.RESUME) - self._send_signal_to_engines(Headers.UNPAUSE) + """Sends RESUME to all engines via coordinator. Requires SUSPENDED. - def stop_engines(self) -> Awaitable: - """Sends a signal to gracefully stop all inference engines. + Callers should await engine.paused (or engine.running after UNPAUSE) for confirmation. + """ + self._send_signal_to_engines(Headers.RESUME) - The signal first propagates thru the coordinator to all engines. - All engines acknowledge this signal and clear their `running` flags. - The coordinator awaits all acknowledgements before forwarding the ACK - back to the client, as well as to the engines. - The engines set their `stopped` flags upon seeing the ACK. + def stop_engines(self): + """Sends STOP to all engines via coordinator. Requires PAUSED or SUSPENDED. - Returns: - Awaitable: An awaitable that resolves when all engines have stopped. + Callers should await engine.stopped for confirmation. + Does not affect the coordinator. """ self._send_signal_to_engines(Headers.STOP) - self.running.clear() - return self.stopped.wait() + + def shutdown_coordinator(self): + """Tells the coordinator process to exit its main loop. + + Does not affect the engines. + """ + self._send_signal_to_engines(Headers.SHUTDOWN) def stop(self): """ @@ -253,6 +231,12 @@ def stop(self): and terminates the ZMQ context. It should be called when the client is no longer needed to ensure a graceful shutdown. """ - self.listener_task.cancel() - self.socket.close() + if hasattr(self, 'listener_task') and not self.listener_task.done(): + self.listener_task.cancel() + # Wake up any listeners. + for future in self.completion_futures.values(): + if not future.done(): + future.cancel() + self.completion_futures.clear() + self.socket.close(linger=0) self.context.term() diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py index 73b9684ad48..5366ba977e0 100644 --- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py +++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py @@ -100,7 +100,7 @@ async def run_flask_server( """Initializes and runs the async Flask server starting an InferenceClient with the provided coordinator address.""" inference_client = InferenceClient(coordinator_addr) - await inference_client.start() + inference_client.start() logger.info(f"Rank {rank}: InferenceClient connected.") try: await run_flask_server_on_client(inference_client, tokenizer, flask_port, parsers, verbose) diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 8b06087e124..90063219f8a 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -224,9 +224,15 @@ def add_inference_args(parser: ArgumentParser) -> ArgumentParser: type=int, default=None, help="Suspend and resume the dynamic engine every " - "`suspend_resume_interval` steps. This is used to tet the suspend/resume " + "`suspend_resume_interval` requests. This is used to test the suspend/resume " "system.", ) + group.add_argument( + "--suspend-timeout", + type=float, + default=0.0, + help="Seconds to sleep while the engine is suspended (simulates a training step).", + ) group.add_argument( "--inference-repeat-n", type=int, @@ -258,7 +264,6 @@ def add_inference_args(parser: ArgumentParser) -> ArgumentParser: default=None, help="Path to write coordinator request scheduling decisions as JSON", ) - return parser diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py index b1a7dcd8b83..3a7d17b57a0 100644 --- a/megatron/rl/inference/megatron.py +++ b/megatron/rl/inference/megatron.py @@ -7,7 +7,7 @@ from pydantic import PrivateAttr from megatron.core.inference.config import KVCacheManagementMode -from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine +from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine, EngineState from megatron.core.inference.inference_client import InferenceClient from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.utils import log_single_rank @@ -100,7 +100,7 @@ async def launch(cls, model: GPTModel, **kwargs): from megatron.core.inference.text_generation_server.dynamic_text_gen_server.flask_server import run_flask_server_on_client loop = asyncio.get_event_loop() client = InferenceClient(inference_coordinator_address=dp_addr) - await client.start() + client.start() server_task = loop.create_task(run_flask_server_on_client( client=client, tokenizer=inference_engine.controller.tokenizer, @@ -124,21 +124,38 @@ async def launch(cls, model: GPTModel, **kwargs): async def kill(self): if dist.get_rank() == 0: - await self._client.stop_engines() - await self._inference_engine.stopped.wait() + self._client.pause_engines() + await self._inference_engine.wait_until(EngineState.PAUSED) + + if dist.get_rank() == 0: + self._client.stop_engines() + await self._inference_engine.wait_until(EngineState.STOPPED) + + if dist.get_rank() == 0: + self._client.shutdown_coordinator() + self._client.stop() def increment_staleness(self): if dist.get_rank() == 0: self._client.increment_staleness() async def suspend(self): + if dist.get_rank() == 0: + self._client.pause_engines() + await self._inference_engine.wait_until(EngineState.PAUSED) + if dist.get_rank() == 0: self._client.suspend_engines() - await self._inference_engine.paused.wait() - self._inference_engine.suspend() + await self._inference_engine.wait_until(EngineState.SUSPENDED) async def resume(self): + if self._inference_engine._state_events[EngineState.RUNNING].is_set(): + return + if dist.get_rank() == 0: self._client.resume_engines() - await self._inference_engine.running.wait() - self._inference_engine.resume() + await self._inference_engine.wait_until(EngineState.RESUMED) + + if dist.get_rank() == 0: + self._client.unpause_engines() + await self._inference_engine.wait_until(EngineState.RUNNING) diff --git a/megatron/rl/server/api.py b/megatron/rl/server/api.py index 528e6e880dc..6635b3d48dc 100644 --- a/megatron/rl/server/api.py +++ b/megatron/rl/server/api.py @@ -18,7 +18,7 @@ async def launch(cls) -> Self: async def suspend(self): pass - def resume(self): + async def resume(self): pass async def kill(self): diff --git a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py b/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py index 81e2fc925b5..0138308b991 100644 --- a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py +++ b/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py @@ -56,6 +56,9 @@ def test_inference_pipeline( model_config_content = f3.read() metrics = yaml.safe_load(model_config_content)["METRICS"] + if not metrics: + print("No metrics defined in model_config.yaml, skipping validation.") + return output_groundtruth = json.loads(golden_values_content) diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml index bd34c11fc24..345fc250694 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml @@ -42,9 +42,6 @@ MODEL_ARGS: --top_k: 1 --return-log-probs: true --num-tokens-to-generate: 30 - --inference-dynamic-batching-max-requests-override: 8 # hardcode decode padding tokens to 7 for reproducibility - --inference-dynamic-batching-buffer-guaranteed-fraction: 0 - --inference-dynamic-batching-buffer-overflow-factor: 0.2 --inference-dynamic-batching-buffer-size-gb: 20 --dist-ckpt-strictness: log_unexpected --inference-ckpt-non-strict: true # To handle the extra_state errors diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json index 55d6955055a..3af1d61504a 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/golden_values_dev_dgx_h100.json @@ -1,158 +1,158 @@ { "0": { "input_prompt": "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies.", - "generated_text": " And then you get to the end of the movie, and you realize that this is not New York at all. This is New York at the end", + "generated_text": " And that this is the place where you can be yourself, and be yourself in the most beautiful way. And that this is the place where you can", "generated_tokens": [ 3060, - 2430, - 1636, - 2012, - 1317, - 1278, - 2362, - 1307, + 1455, + 1593, + 1395, 1278, - 16070, + 3535, + 2478, + 1636, + 1710, + 1402, + 14019, 1044, 1321, - 1636, - 23067, + 1402, + 14019, + 1294, + 1278, + 2725, + 15568, + 3039, + 1046, + 3060, 1455, 1593, 1395, - 1605, - 3140, - 5152, - 1513, - 1747, - 1046, - 2409, - 1395, - 3140, - 5152, - 1513, 1278, - 2362 + 3535, + 2478, + 1636, + 1710 ], - "latency": 44.73653959017247, + "latency": 56.18305364251137, "logprobs": [ - -9.358970642089844, - -2.7523813247680664, - -4.628502368927002, - -1.4058877229690552, - -0.6050865054130554, - -1.7354254722595215, - -2.4828507900238037, - -2.0520384311676025, - -2.4089853763580322, - -6.2649126052856445, - -1.5644135475158691, - -3.4096615314483643, - -4.358163833618164, - -3.866471767425537, - -2.0575876235961914, - -1.904883623123169, - -3.7622976303100586, - -6.835415363311768, - -0.2829523980617523, - -0.9827429056167603, - -6.655940055847168, - -7.188957214355469, - -12.757233619689941, - -2.1933951377868652, - -3.808887481689453, - -0.515199601650238, - -4.323916912078857, - -0.067625492811203, - -0.09976530075073242, - -3.228640556335449, - -10.129311561584473, - -1.1787357330322266, - -5.97692346572876, - -5.036575794219971, - -3.8267176151275635, - -2.6010468006134033, - -3.366438865661621, - -5.553505897521973, - -1.6046268939971924, - -5.442874908447266, - -12.218503952026367, - -12.597894668579102, - -0.0976092740893364, - -2.530579090118408, - -1.4139617681503296, - -2.8606526851654053, - -1.1690009832382202, - -0.0066696410067379475, - -3.361189365386963, - -13.191482543945312, - -4.413737773895264, - -2.639688491821289, - -6.0114641189575195, - -0.7672993540763855, - -0.047326065599918365, - -1.550362467765808, - -1.137772798538208, - -5.627618789672852, - -0.40103790163993835, - -4.908735275268555, - -0.5704602599143982, - -0.6625558733940125, - -2.364135503768921, - -13.609526634216309, - -0.08865148574113846, - -3.5251970291137695, - -1.3791766166687012, - -6.395696640014648, - -0.588782787322998, - -3.566770076751709, - -0.8742034435272217, - -1.5827170610427856, - -5.3912353515625, - -17.150842666625977, - -6.6234588623046875, - -0.885993242263794, - -4.162992477416992, - -1.1942744255065918, - -2.281689405441284, - -1.7708709239959717, - -0.22030864655971527, - -9.292593955993652, - -0.1258234828710556, - -7.346449851989746, - -2.5470826625823975, - -4.115433692932129, - -3.5646262168884277, - -1.9410749673843384, - -2.3247878551483154, - -1.523364543914795, - -2.360647678375244, - -1.708706021308899, - -1.131014108657837, - -2.944424867630005, - -0.5273782014846802, - -0.44912564754486084, - -1.753378987312317, - -0.8341047167778015, - -0.4124295711517334, - -0.9006240367889404, - -1.4890273809432983, - -0.4379286766052246, - -1.6497018337249756, - -0.5444425344467163, - -1.2305881977081299, - -1.164027214050293, - -0.002498721005395055, - -1.165798544883728, - -0.007112303748726845, - -0.718407154083252, - -0.7442683577537537, - -0.04299728572368622, - -0.8688321113586426, - -0.021008115261793137, - -2.033963680267334, - -1.2936673164367676, - -0.78721684217453 + -9.358942031860352, + -2.7132151126861572, + -4.606732368469238, + -1.4793059825897217, + -0.604263186454773, + -1.7374769449234009, + -2.485668897628784, + -2.1064839363098145, + -2.4603278636932373, + -6.253784656524658, + -1.4727367162704468, + -3.4053215980529785, + -4.36705207824707, + -3.8439993858337402, + -2.0021021366119385, + -1.8833506107330322, + -3.7835519313812256, + -6.891242980957031, + -0.28234225511550903, + -0.911859393119812, + -6.631955146789551, + -7.208620071411133, + -12.827497482299805, + -2.126032590866089, + -3.8147177696228027, + -0.5067541599273682, + -4.314828872680664, + -0.06301839649677277, + -0.10691610723733902, + -3.262773036956787, + -10.134418487548828, + -1.1751978397369385, + -6.014812469482422, + -5.020193576812744, + -3.8787002563476562, + -2.6112544536590576, + -3.366523027420044, + -5.561098098754883, + -1.622261643409729, + -5.453547477722168, + -12.218475341796875, + -12.583305358886719, + -0.09843693673610687, + -2.528902769088745, + -1.4074000120162964, + -2.8589088916778564, + -1.201108455657959, + -0.006660522893071175, + -3.3809216022491455, + -13.233884811401367, + -4.536578178405762, + -2.6009042263031006, + -6.00333309173584, + -0.7021315693855286, + -0.05108878016471863, + -1.5450650453567505, + -1.135668396949768, + -5.629673957824707, + -0.3985691964626312, + -4.960829257965088, + -0.5763677358627319, + -0.6644068360328674, + -2.4823904037475586, + -13.577603340148926, + -0.09277129173278809, + -3.5255157947540283, + -1.419959306716919, + -6.460269927978516, + -0.5874902009963989, + -3.569597005844116, + -0.8671818971633911, + -1.6106094121932983, + -5.3532209396362305, + -17.28781509399414, + -6.644278526306152, + -0.9023992419242859, + -4.112849235534668, + -1.1764109134674072, + -2.2248306274414062, + -1.7836185693740845, + -0.22097186744213104, + -9.294025421142578, + -0.13404981791973114, + -7.356222629547119, + -2.4790220260620117, + -4.13104248046875, + -3.6036787033081055, + -1.9332302808761597, + -2.328838348388672, + -2.2153103351593018, + -0.30630964040756226, + -1.4429736137390137, + -2.2041285037994385, + -1.1290805339813232, + -1.6353514194488525, + -1.4010272026062012, + -1.969598650932312, + -1.5041292905807495, + -0.8407909870147705, + -1.2370984554290771, + -1.9493464231491089, + -2.3773415088653564, + -1.624193787574768, + -0.9511104822158813, + -1.8184833526611328, + -2.394134521484375, + -1.2718405723571777, + -1.5075287818908691, + -0.5832712054252625, + -0.6682117581367493, + -0.5109573602676392, + -0.022901568561792374, + -0.483473539352417, + -0.2419981062412262, + -0.0857735350728035, + -0.13640964031219482, + -0.06942988932132721 ] } } diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml index 13d409c5968..3b55b09e82e 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml @@ -23,6 +23,7 @@ MODEL_ARGS: --distributed-backend: nccl --log-interval: 1 --transformer-impl: inference_optimized + --inference-fuse-tp-communication: true --sequence-parallel: true --tensor-model-parallel-size: 2 --pipeline-model-parallel-size: 2 @@ -42,9 +43,6 @@ MODEL_ARGS: --top_k: 1 --return-log-probs: true --num-tokens-to-generate: 30 - --inference-dynamic-batching-max-requests-override: 8 # hardcode decode padding tokens to 7 for reproducibility - --inference-dynamic-batching-buffer-guaranteed-fraction: 0 - --inference-dynamic-batching-buffer-overflow-factor: 0.2 --inference-dynamic-batching-buffer-size-gb: 20 --dist-ckpt-strictness: log_unexpected --inference-ckpt-non-strict: true # To handle the extra_state errors diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml index 8d5779a5099..88a3e40a193 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml @@ -51,9 +51,6 @@ MODEL_ARGS: --incoming-requests-per-step: 32 --use-flashinfer-fused-rope: true --inference-logging-step-interval: 1 - --cuda-graph-impl: local - --inference-dynamic-batching-max-requests: 128 - --inference-dynamic-batching-num-cuda-graphs: 2 METRICS: - "generated_tokens" - "logprobs" diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml index f154a91c778..80e2a37c250 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml @@ -79,13 +79,10 @@ MODEL_ARGS: --inference-dynamic-batching-buffer-size-gb: 20 --cuda-graph-impl: local --moe-pad-experts-for-cuda-graph-inference: true - --inference-dynamic-batching-buffer-size-gb: 20 --inference-dynamic-batching-num-cuda-graphs: -1 --inference-dynamic-batching-max-requests: 16 --inference-logging-step-interval: 1 - --sequence-parallel: true --moe-enable-routing-replay: true - METRICS: - "generated_tokens" - "logprobs" diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml index d62d10db7c1..479cb7a4751 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml @@ -76,9 +76,11 @@ MODEL_ARGS: --prompts: "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies." --incoming-requests-per-sec: -1 # all requests arrive up front. --inference-repeat-n: 8 - --inference-logging-step-interval: 1 - --sequence-parallel: true - + --inference-dynamic-batching-buffer-size-gb: 20 + --inference-dynamic-batching-max-requests: 16 + --inference-logging-step-interval: 1 + --moe-enable-routing-replay: true METRICS: - "generated_tokens" - "logprobs" + - "routing_indices" diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..66c9e3e4121 --- /dev/null +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/golden_values_dev_dgx_h100.json @@ -0,0 +1,158 @@ +{ + "0": { + "input_prompt": "Time travel to 2008, and go to a bar or a club or one of the myriad disco-basements on the Lower East Side that does not quite know which of those it is. Dance awkwardly in a room full of other glittered-up nerds, and wait for something to happen, buoyed on the feeling that this is the big swollen heart of life, that this is New York like the movies.", + "generated_text": " Wait for the moment when the music stops, and the lights come up, and the DJ says, \"I'm going to play a song for you", + "generated_tokens": [ + 32844, + 1394, + 1278, + 4735, + 2200, + 1278, + 7146, + 30774, + 1044, + 1321, + 1278, + 26466, + 3930, + 2015, + 1044, + 1321, + 1278, + 30245, + 8223, + 1044, + 1429, + 1073, + 4525, + 4670, + 1317, + 3354, + 1261, + 6947, + 1394, + 1636 + ], + "latency": 28.185462809633464, + "logprobs": [ + -10.737512588500977, + -3.724862575531006, + -2.833397388458252, + -1.2464861869812012, + -0.2549239993095398, + -1.7607988119125366, + -2.419379711151123, + -1.9533929824829102, + -2.1014301776885986, + -6.169030666351318, + -0.8734959363937378, + -2.4733574390411377, + -3.4822516441345215, + -4.180896759033203, + -1.9767613410949707, + -1.8347630500793457, + -2.2581257820129395, + -7.180149078369141, + -0.0453881211578846, + -1.9841610193252563, + -5.015386581420898, + -8.827117919921875, + -9.885746002197266, + -0.8498678207397461, + -4.770059585571289, + -0.855280339717865, + -2.2494924068450928, + -0.017164958640933037, + -0.03715415671467781, + -3.4830124378204346, + -8.635110855102539, + -1.2520610094070435, + -6.62324857711792, + -3.639960765838623, + -3.664339542388916, + -4.182392597198486, + -2.1796066761016846, + -1.0725229978561401, + -0.26311880350112915, + -0.8036076426506042, + -4.6958818435668945, + -9.042495727539062, + -0.013647346757352352, + -3.1747794151306152, + -1.322129487991333, + -3.949110746383667, + -0.7829495072364807, + -0.002083513652905822, + -2.970266580581665, + -10.56244945526123, + -3.2369167804718018, + -1.1530492305755615, + -4.917466163635254, + -0.21241025626659393, + -0.06490474194288254, + -1.372581124305725, + -2.224682092666626, + -4.3847503662109375, + -0.36867555975914, + -4.035493850708008, + -0.39869019389152527, + -0.14373983442783356, + -2.716118812561035, + -10.687016487121582, + -0.04773370549082756, + -3.398231267929077, + -0.8646175265312195, + -4.74052619934082, + -0.23649944365024567, + -2.6610701084136963, + -0.8428961634635925, + -1.614527940750122, + -5.793307781219482, + -16.929147720336914, + -2.6586406230926514, + -0.1385982781648636, + -7.435610771179199, + -1.0483647584915161, + -2.1261863708496094, + -1.5261307954788208, + -0.27082547545433044, + -5.859070777893066, + -0.00648513063788414, + -7.732051849365234, + -2.712515354156494, + -2.9137418270111084, + -3.041210651397705, + -2.3559694290161133, + -0.3973437249660492, + -1.4338903427124023, + -2.2967660427093506, + -0.6096595525741577, + -1.3119444847106934, + -1.93257474899292, + -1.726539134979248, + -0.8397530317306519, + -0.5014236569404602, + -1.2989763021469116, + -1.5857150554656982, + -1.096572995185852, + -0.4009067416191101, + -0.43302634358406067, + -0.041601795703172684, + -1.285712718963623, + -2.214778184890747, + -2.6971933841705322, + -0.8101387619972229, + -0.43101266026496887, + -2.808060884475708, + -1.5226430892944336, + -1.6209226846694946, + -0.048716772347688675, + -1.3497682809829712, + -1.343377947807312, + -1.2755295038223267, + -1.2342015504837036, + -0.5394397377967834 + ] + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/model_config.yaml new file mode 100644 index 00000000000..1f302455440 --- /dev/null +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/model_config.yaml @@ -0,0 +1,89 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 +TEST_TYPE: frozen-start +MODE: inference +MODEL_ARGS: + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --log-memory-to-tensorboard: true + --timing-log-level: 0 + --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints + --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json + --tokenizer-type: TikTokenizer + --tiktoken-pattern: v2 + --distributed-backend: nccl + --log-interval: 1 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 4 + --pipeline-model-parallel-size: 1 + --expert-model-parallel-size: 8 + --expert-tensor-parallel-size: 1 + --sequence-parallel: true + --use-mcore-models: true + --moe-token-dispatcher-type: alltoall + --moe-grouped-gemm: true + --num-experts: 64 + --moe-router-topk: 6 + --moe-z-loss-coeff: 0 + --moe-router-load-balancing-type: seq_aux_loss + --moe-aux-loss-coeff: 1e-3 + --moe-router-score-function: sigmoid + --untie-embeddings-and-output-weights: true + --disable-bias-linear: true + --init-method-std: 0.014 + --position-embedding-type: rope + --rotary-base: 1000000 + --rotary-percent: 1.0 + --num-layers: 27 + --hidden-size: 2048 + --moe-ffn-hidden-size: 1408 + --moe-shared-expert-intermediate-size: 2816 + --ffn-hidden-size: 10944 + --num-attention-heads: 16 + --kv-channels: 128 + --normalization: RMSNorm + --swiglu: true + --attention-dropout: 0.0 + --hidden-dropout: 0.0 + --seq-length: 4096 + --max-position-embeddings: 4096 + --micro-batch-size: 1 + --ckpt-format: torch_dist + --ckpt-fully-parallel-save: true + --ckpt-fully-parallel-load: true + --ckpt-assume-constant-structure: true + --dist-ckpt-strictness: log_unexpected + --bf16: true + --attention-backend: flash + --no-create-attention-mask-in-dataloader: true + --num-workers: 8 + --use-checkpoint-args: true + --no-use-tokenizer-model-from-checkpoint-args: true + --no-load-optim: true + --deterministic-mode: true # moe will use different ops for determinism for inference + --save-interval: 2000 + --temperature: 1.0 + --top_k: 1 + --return-log-probs: true + --num-tokens-to-generate: 30 + --max-tokens-to-oom: 3600000 + --inference-max-seq-length: 4096 + --output-path: ${INFERENCE_OUTPUT_PATH} + --incoming-requests-per-sec: -1 # all requests arrive up front. + --inference-repeat-n: 8 + --inference-dynamic-batching-buffer-size-gb: 20 + --inference-dynamic-batching-max-requests: 16 + --inference-logging-step-interval: 1 + --moe-enable-routing-replay: true + --prompt-file: "./tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/prompts.json" + --suspend-resume-interval: 16 + --suspend-timeout: 4.0 + --rl-kv-cache-management-mode: recompute + --inference-dynamic-batching-unified-memory-level: 0 + --no-rl-persist-cuda-graphs: true + +METRICS: diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/prompts.json b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/prompts.json new file mode 100644 index 00000000000..7dcda2a3fa6 --- /dev/null +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/prompts.json @@ -0,0 +1,159 @@ +{"text": "The cat sat on the mat and"} +{"text": "The sun rose in the east,"} +{"text": "She opened the door and saw"} +{"text": "He picked up the book,"} +{"text": "Rain poured down all day,"} +{"text": "The old house stood"} +{"text": "Birds sang in the trees,"} +{"text": "A car drove by,"} +{"text": "The dog barked loudly,"} +{"text": "The water was very"} +{"text": "The capital of France is Paris. It is"} +{"text": "Water boils at 100 degrees Celsius. This is"} +{"text": "Humans have five senses: sight, sound, smell, taste, and"} +{"text": "The Earth revolves around the Sun, taking approximately"} +{"text": "A triangle has three sides. A square has"} +{"text": "Computers process information using"} +{"text": "The Amazon River is the"} +{"text": "Photosynthesis is the process by which plants"} +{"text": "Gravity is a force that"} +{"text": "DNA stands for"} +{"text": "Once upon a time, there was"} +{"text": "All's well that ends"} +{"text": "Every cloud has a silver"} +{"text": "A stitch in time saves"} +{"text": "Bite the bullet and"} +{"text": "Cost an arm and a"} +{"text": "Hit the road,"} +{"text": "Let the cat out of the"} +{"text": "Speak of the devil and"} +{"text": "The early bird catches the"} +{"text": "Monday, Tuesday, Wednesday,"} +{"text": "1, 2, 3, 4,"} +{"text": "Apple, Banana, Orange,"} +{"text": "Red, Green, Blue,"} +{"text": "North, South, East,"} +{"text": "First, Second, Third,"} +{"text": "A, B, C, D,"} +{"text": "January, February, March,"} +{"text": "Up, Down, Left,"} +{"text": "Small, Medium, Large,"} +{"text": "In a small village, nestled among the hills, lived"} +{"text": "The detective peered through the misty window. Outside,"} +{"text": "A sudden gust of wind swept through the forest,"} +{"text": "She woke up with a start, realizing that"} +{"text": "The ancient map, tattered and worn, showed"} +{"text": "Far, far away, in a land of dragons and magic,"} +{"text": "He hesitated at the crossroads, unsure which"} +{"text": "The ship sailed smoothly across the calm sea,"} +{"text": "A mysterious package arrived on her doorstep. Inside,"} +{"text": "The old clock in the hall chimed midnight."} +{"text": "The sky was a brilliant shade of"} +{"text": "Her hair was long and flowing, like"} +{"text": "The smell of freshly baked bread filled"} +{"text": "The old, gnarled tree stood on the hill, its branches"} +{"text": "The city lights twinkled in the distance,"} +{"text": "His voice was deep and resonant,"} +{"text": "The surface of the lake was smooth as"} +{"text": "The vibrant colors of the sunset painted"} +{"text": "The intricate pattern on the rug showed"} +{"text": "The air was crisp and cool,"} +{"text": "What do you call a group of"} +{"text": "How long does it take to"} +{"text": "Where can one find the most"} +{"text": "Why is the sky"} +{"text": "When did the first"} +{"text": "Who was the last person to"} +{"text": "Which way should we go to"} +{"text": "What would happen if"} +{"text": "How many stars are"} +{"text": "Where does the river"} +{"text": "`def main():`"} +{"text": "``"} +{"text": "{ \"name\": \"John\", \"age\":"} +{"text": "`import os`"} +{"text": "`SELECT * FROM users WHERE`"} +{"text": "\"Hello?\" she whispered into the darkness. \"Is"} +{"text": "He said, \"I don't know what to do.\" She replied, \"Well,"} +{"text": "\"Can you pass the salt?\" asked Tom. Sarah reached for"} +{"text": "\"Good morning,\" the shopkeeper greeted. \"How may I"} +{"text": "\"I'm so tired,\" she sighed. \"Me too,\" he agreed, \"I just want to"} +{"text": "Her laughter was music to"} +{"text": "The city was a sleeping giant,"} +{"text": "Time flew like an arrow,"} +{"text": "His words were daggers,"} +{"text": "The fog was a thick blanket,"} +{"text": "The waves crashed like thunder,"} +{"text": "Hope was a fragile bird,"} +{"text": "The moon was a silver coin,"} +{"text": "The wind whispered secrets,"} +{"text": "Life is a journey, not"} +{"text": "The quick brown fox jumps over"} +{"text": "Once bitten, twice"} +{"text": "As clear as day,"} +{"text": "The more, the merrier."} +{"text": "A penny for your"} +{"text": "You can't have your cake and eat"} +{"text": "The early bird gets the"} +{"text": "Between a rock and a hard"} +{"text": "Don't put all your eggs in one"} +{"text": "When in Rome, do as the"} +{"text": "Spring, Summer, Autumn,"} +{"text": "Do, Re, Mi, Fa,"} +{"text": "North America, South America, Europe,"} +{"text": "Circle, Square, Triangle,"} +{"text": "Tiny, Small, Medium,"} +{"text": "The train arrived late, causing"} +{"text": "She carefully placed the delicate vase on"} +{"text": "He decided to take a long walk to clear"} +{"text": "The old woman sat by the fire, knitting"} +{"text": "The news spread quickly throughout the"} +{"text": "The aroma of coffee filled the air, inviting"} +{"text": "Beneath the surface of the calm water,"} +{"text": "He adjusted his glasses and began to read"} +{"text": "The sound of distant thunder rumbled,"} +{"text": "The forgotten garden was overgrown with"} +{"text": "The doctor examined the patient and prescribed"} +{"text": "The chef prepared a delicious meal, using"} +{"text": "The engineer designed a new bridge, considering"} +{"text": "The artist painted a vibrant landscape, capturing"} +{"text": "The scientist conducted an experiment, observing"} +{"text": "A group of crows is called a"} +{"text": "A parliament of owls is a"} +{"text": "The opposite of hot is"} +{"text": "The square root of 9 is"} +{"text": "The longest river in Africa is the"} +{"text": "The study found that regular exercise can"} +{"text": "The theory of relativity was proposed by"} +{"text": "The process of melting ice is called"} +{"text": "Photosynthesis requires sunlight, water, and"} +{"text": "The main ingredient in bread is"} +{"text": "What is the primary function of the"} +{"text": "How do you define the term"} +{"text": "Where does the saying 'strike while the iron is hot' come from?"} +{"text": "Why do leaves change color in the autumn?"} +{"text": "When was the first personal computer"} +{"text": "Who invented the light bulb?"} +{"text": "Which planet is known as the Red Planet?"} +{"text": "What happens if you mix acid with a base?"} +{"text": "How many days are in a leap year?"} +{"text": "Where do migratory birds go in the winter?"} +{"text": "`function calculateArea(radius) {`"} +{"text": "`# Python comments start with`"} +{"text": "`CREATE TABLE employees (`"} +{"text": "``"} +{"text": "\"Excuse me,\" he coughed, \"but I think you've dropped something.\""} +{"text": "She retorted, \"I'm not sure what you mean by that.\""} +{"text": "\"It's freezing in here!\" he exclaimed. \"Can someone please\""} +{"text": "The customer complained, \"This isn't what I ordered.\" The waiter replied,"} +{"text": "\"Tell me about your day,\" she encouraged. He began,"} +{"text": "His memory was a patchwork quilt of"} +{"text": "The problem was a Gordian knot, impossible to"} +{"text": "The news hit him like a ton of"} +{"text": "Her words were a balm to his"} +{"text": "The traffic moved at a snail's pace on"} +{"text": "The silence in the room was deafening, a heavy"} +{"text": "Opportunity knocked loudly on his"} +{"text": "The city lights were a scatter of diamonds against"} +{"text": "The mountain stood sentinel over the"} +{"text": "Reading a good book is like taking a trip without"} diff --git a/tests/test_utils/recipes/h100/gpt-dynamic-inference-with-coordinator.yaml b/tests/test_utils/recipes/h100/gpt-dynamic-inference-with-coordinator.yaml index 19d523eea8d..0349896345b 100644 --- a/tests/test_utils/recipes/h100/gpt-dynamic-inference-with-coordinator.yaml +++ b/tests/test_utils/recipes/h100/gpt-dynamic-inference-with-coordinator.yaml @@ -58,24 +58,25 @@ products: - test_case: [gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq] products: - environment: [dev] - scope: [flaky] + scope: [mr] platforms: [dgx_h100] - test_case: [gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq] products: - environment: [dev] - scope: [flaky] + scope: [mr] platforms: [dgx_h100] - test_case: [gpt_dynamic_inference_tp1_pp1_dp8_583m_logitsmatch_zmq] products: - environment: [dev] scope: [mr] platforms: [dgx_h100] + - test_case: [gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq] + products: + - environment: [dev] + scope: [mr, mr-github] + platforms: [dgx_h100] # - test_case: [gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq] # products: # - environment: [dev] # scope: [mr] # platforms: [dgx_h100] - - test_case: [gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq] - products: - - environment: [dev] - scope: [flaky] diff --git a/tests/test_utils/recipes/h100/moe-dynamic-inference-with-coordinator.yaml b/tests/test_utils/recipes/h100/moe-dynamic-inference-with-coordinator.yaml index 4ce11808bdc..fb872dee2d4 100644 --- a/tests/test_utils/recipes/h100/moe-dynamic-inference-with-coordinator.yaml +++ b/tests/test_utils/recipes/h100/moe-dynamic-inference-with-coordinator.yaml @@ -58,11 +58,16 @@ products: - test_case: [gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq] products: - environment: [dev] - scope: [flaky] + scope: [mr, mr-github] platforms: [dgx_h100] - test_case: [gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq] products: - environment: [dev] - scope: [mr-github] + scope: [mr, mr-github] + platforms: [dgx_h100] + - test_case: [gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume] + products: + - environment: [dev] + scope: [mr, mr-github] platforms: [dgx_h100] diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index ced9746248c..6a07d7a35ae 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -28,6 +28,7 @@ TokenOverflowError, ) from megatron.core.inference.engines import DynamicInferenceEngine +from megatron.core.inference.engines.dynamic_engine import EngineState from megatron.core.inference.inference_request import DynamicInferenceRequest, Status from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( GPTInferenceWrapper, @@ -1774,7 +1775,7 @@ def test_suspend_resume_cycle(self, kv_cache_management_mode, static_kv_memory_p engine = env.engine context = engine.context - assert not engine.is_suspended + assert engine.state != EngineState.SUSPENDED assert context.is_tensor_state_allocated deallocates = kv_cache_management_mode != "persist" @@ -1797,7 +1798,7 @@ def test_suspend_resume_cycle(self, kv_cache_management_mode, static_kv_memory_p # Suspend. engine.suspend() - assert engine.is_suspended + assert engine.state == EngineState.SUSPENDED assert not context.is_tensor_state_allocated gc.collect() @@ -1827,7 +1828,7 @@ def test_suspend_resume_cycle(self, kv_cache_management_mode, static_kv_memory_p # Resume. engine.resume() - assert not engine.is_suspended + assert engine.state != EngineState.SUSPENDED assert context.is_tensor_state_allocated if deallocates and not uses_tms: diff --git a/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py b/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py index 57326291a73..5bdd16ab94f 100644 --- a/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py +++ b/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py @@ -2,16 +2,25 @@ import asyncio import itertools +import multiprocessing +import os import time +import unittest.mock from collections import deque from typing import Dict, Optional import msgpack import pytest import torch -from tqdm import tqdm -from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine, RequestEntry +from megatron.core.inference.data_parallel_inference_coordinator import ( + DataParallelInferenceCoordinator, +) +from megatron.core.inference.engines.dynamic_engine import ( + DynamicInferenceEngine, + EngineState, + RequestEntry, +) from megatron.core.inference.headers import Headers from megatron.core.inference.inference_client import InferenceClient from megatron.core.inference.inference_request import ( @@ -34,7 +43,6 @@ NUM_REQUESTS = 10 NUM_TOKENS = 2 DEFAULT_PORT = 46581 -ZMQ_FLAKY_SHUTDOWN = True class DummyTokenizer: @@ -64,6 +72,7 @@ class DummyContext: def __init__(self): self.active_cnt = 0 + self.step_count = 0 def get_active_request_count(self) -> int: return self.active_cnt @@ -86,20 +95,41 @@ def __init__(self): """We cannot call super().__init__() because it requires complex setup.""" self.waiting_request_ids = deque() self.requests: Dict[int, RequestEntry] = {} - self.suspend_signal = False - self.is_suspended = False self._loop = get_asyncio_loop() self.context = DummyContext() self.controller = DummyController() - self.running = asyncio.Event() - self.paused = asyncio.Event() - self.stopped = asyncio.Event() self.pending_microbatch = deque() - self.received_pause: bool = False - self.received_stop: bool = False self.pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.rank = torch.distributed.get_rank() + # State machine (mirrors dynamic_engine.py reset()). + self.state = EngineState.RUNNING + self._state_events = {k: asyncio.Event() for k in self._STATE_EVENTS} + self._state_events[EngineState.RUNNING].set() + self._pending_signals = deque() + self.resume_request_ids = None + self.use_coordinator = False + + self.ep_world_size = 1 + + self.step_start_event = unittest.mock.MagicMock() + self.step_end_event = unittest.mock.MagicMock() + + async def run_engine_with_coordinator(self, *, loop=None): + """Override to bypass @trace_async_exceptions for testability. + + In production, @trace_async_exceptions converts AssertionError to sys.exit(1) -> SystemExit. + In Python 3.12+, asyncio re-raises SystemExit from tasks in the main thread. + For tests, we let AssertionErrors propagate directly so pytest.raises can catch them. + """ + return await DynamicInferenceEngine.run_engine_with_coordinator.__wrapped__(self, loop=loop) + + def suspend(self): + pass + + def resume(self): + pass + def add_request( self, request_id: int, prompt: str, sampling_params: Optional[SamplingParams] = None ) -> asyncio.Future[DynamicInferenceRequestRecord]: @@ -122,6 +152,8 @@ def add_request( async def async_step(self, *, verbose: Optional[bool] = False) -> Dict: """Dummy async_step.""" + await asyncio.sleep(0) + # Finish "active" requests. finished_request_records = [] to_remove = [] @@ -163,6 +195,42 @@ async def async_step(self, *, verbose: Optional[bool] = False) -> Dict: } +async def cleanup_engine(engine, client=None, timeout=30.0): + """Disconnect an engine between tests. The coordinator stays alive.""" + task = getattr(engine, 'engine_loop_task', None) + if task is not None and not task.done(): + if client is not None: + client.pause_engines() + try: + await asyncio.wait_for(engine.wait_until(EngineState.PAUSED), timeout=timeout) + except (asyncio.TimeoutError, Exception): + pass + + sub = getattr(engine, 'model_parallel_num_msgs_subscriber_socket', None) + if sub is not None: + sub.setsockopt(zmq.RCVTIMEO, 1000) + + # Close ZMQ communicator sockets to unblock any stuck ranks. + for attr in ('expert_parallel_zmq_communicator', 'world_zmq_communicator'): + comm = getattr(engine, attr, None) + if comm is not None: + comm.close() + + task.cancel() + try: + await asyncio.wait_for(asyncio.shield(task), timeout=5.0) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception): + pass + + if client is not None: + # Walk the coordinator back to RUNNING regardless of its current state + # so the next test starts cleanly. Each call is a no-op when the + # coordinator is already in the target state (just logs a warning). + client.resume_engines() # SUSPENDED → PAUSED (no-op otherwise) + client.unpause_engines() # PAUSED → RUNNING (no-op otherwise) + client.stop() + + @pytest.fixture def initialize_model_parallel(request, monkeypatch): """Fixture to initialize and destroy model parallel. @@ -179,12 +247,65 @@ def initialize_model_parallel(request, monkeypatch): pipeline_model_parallel_size=pp, expert_model_parallel_size=ep, ) - dp = world_size // (tp * pp * ep) + dp = world_size // (tp * pp) yield world_size, dp, tp, pp, ep Utils.destroy_model_parallel() -@pytest.mark.skipif(ZMQ_FLAKY_SHUTDOWN, reason="ZMQ shutdown is flaky") +@pytest.fixture(scope="class") +def coordinator(): + """Launch a single coordinator process for the entire test class. + + Only rank 0 spawns the coordinator process. Non-rank-0 processes use a + placeholder address; the real address is broadcast inside each test's call + to start_listening_to_data_parallel_coordinator (which broadcasts dp_addr + from dp_src within the DP process group). + + The coordinator is spawned with data_parallel_size=0 so it doesn't block + waiting for engines; engines register dynamically via the empty-payload + re-registration path. + """ + rank = int(os.environ.get("RANK", "0")) + + if rank == 0: + spawn_context = multiprocessing.get_context('spawn') + pipe_parent, pipe_child = spawn_context.Pipe() + ready_event = spawn_context.Event() + proc = spawn_context.Process( + target=DataParallelInferenceCoordinator.entrypoint, + args=(pipe_child, ready_event, 0, DummyTokenizer(), DEFAULT_PORT, False), + ) + proc.start() + + # Wait for the coordinator to bind its socket and send the address. + while not pipe_parent.poll(timeout=0.1): + assert proc.is_alive(), "Coordinator process died during init" + dp_addr = pipe_parent.recv() + pipe_parent.close() + ready_event.wait(timeout=10.0) + else: + proc = None + # Placeholder: the engine setup broadcasts rank 0's actual address. + dp_addr = f"tcp://localhost:{DEFAULT_PORT}" + + yield dp_addr + + # Only rank 0 tears down the coordinator process. + if rank == 0 and proc is not None and proc.is_alive(): + ctx = zmq.Context() + sock = ctx.socket(zmq.DEALER) + sock.connect(dp_addr) + sock.send(msgpack.packb([Headers.CONNECT.value], use_bin_type=True)) + sock.recv() # CONNECT_ACK + sock.send(msgpack.packb([Headers.SHUTDOWN.value], use_bin_type=True)) + sock.close(linger=1000) + ctx.term() + proc.join(timeout=10.0) + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5.0) + + class TestCoordinator: """Test class for Data Parallel Inference Coordinator.""" @@ -195,26 +316,43 @@ def build_requests(self, num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS): for _ in range(num_requests) ] - async def run_coordinator_test( - self, - *, - launch_coordinator=True, - stop_engines=True, - num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - ): - """Run a coordinator test. Model parallel must already be initialized.""" + @pytest.mark.internal + @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") + @pytest.mark.asyncio + @pytest.mark.parametrize( + "initialize_model_parallel", + [ + pytest.param((tp, pp, ep), id=f"tp{tp}-pp{pp}-ep{ep}") + for tp, pp, ep in itertools.product([1, 2], [1, 2], [1, 2]) + if tp * pp * ep <= Utils.world_size + ], + indirect=["initialize_model_parallel"], + ) + async def test_parallel_configs(self, initialize_model_parallel, coordinator): + """Test coordinator with various TP, PP, and EP configurations.""" + dp_addr = coordinator + port = int(dp_addr.rsplit(":", 1)[-1]) + requests = self.build_requests() engine = DummyEngine() - requests = self.build_requests(num_requests, num_tokens) + rank = torch.distributed.get_rank() - dp_addr = await engine.start_listening_to_data_parallel_coordinator( - inference_coordinator_port=DEFAULT_PORT, launch_inference_coordinator=launch_coordinator + await engine.start_listening_to_data_parallel_coordinator( + inference_coordinator_port=port, launch_inference_coordinator=False ) + # Ensure all engines are registered before submitting requests. + await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), timeout=30.0 + ) + + client = None try: - if torch.distributed.get_rank() == 0: + if rank == 0: + # Yield so engine loop can run before we block the event loop + # with the client's synchronous connect handshake. + await asyncio.sleep(0) client = InferenceClient(dp_addr) - await client.start() + client.start() futures = [ client.add_request(prompt=prompt, sampling_params=params) @@ -224,266 +362,247 @@ async def run_coordinator_test( for record in results: assert record[-1].status == Status.COMPLETED - finally: - if torch.distributed.get_rank() == 0: - if stop_engines: - await asyncio.wait_for(client.stop_engines(), timeout=10.0) - client.stop() - if stop_engines: - try: - await asyncio.wait_for(engine.engine_loop_task, timeout=30.0) - except asyncio.TimeoutError: - engine.engine_loop_task.cancel() - return dp_addr + await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), + timeout=30.0, + ) + finally: + await cleanup_engine(engine, client) @pytest.mark.internal @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") @pytest.mark.asyncio @pytest.mark.parametrize( "initialize_model_parallel", - [ - pytest.param((tp, pp, ep), id=f"tp{tp}-pp{pp}-ep{ep}") - for tp, pp, ep in itertools.product([1, 2], [1, 2], [1, 2]) - if tp * pp * ep <= Utils.world_size - ], + [pytest.param((2, 2, 2), id="tp2-pp2-ep2")], indirect=["initialize_model_parallel"], ) - async def test_parallel_configs(self, initialize_model_parallel): - """Test coordinator with various TP, PP, and EP configurations.""" - await self.run_coordinator_test() + async def test_control_logic_lifecycle(self, initialize_model_parallel, coordinator): + """Comprehensive lifecycle test for the engine state machine.""" + # States where paused stays set: once set during PAUSE, it's only cleared by UNPAUSE. + PAUSED_FAMILY = { + EngineState.PAUSED, + EngineState.UNPAUSING, + EngineState.SUSPENDING, + EngineState.SUSPENDED, + EngineState.RESUMING, + EngineState.STOPPING, + EngineState.STOPPED, + } - @pytest.mark.internal - @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") - @pytest.mark.asyncio - async def test_coordinator_lifecycle(self, initialize_model_parallel): - """Test coordinator connection and port conflict behavior.""" - engine1 = DummyEngine() - engine2 = None - engine3 = None - third_addr = None - - # Launch first coordinator - binds to DEFAULT_PORT - first_addr = await engine1.start_listening_to_data_parallel_coordinator( - inference_coordinator_port=DEFAULT_PORT, launch_inference_coordinator=True - ) + def assert_state(eng, expected): + """Assert engine state and all four event flags are consistent.""" + assert eng.state == expected, f"Expected state {expected}, got {eng.state}" + assert eng._state_events[EngineState.RUNNING].is_set() == ( + expected == EngineState.RUNNING + ), f"RUNNING.is_set()={eng._state_events[EngineState.RUNNING].is_set()} for state={expected}" + assert eng._state_events[EngineState.PAUSED].is_set() == ( + expected in PAUSED_FAMILY + ), f"PAUSED.is_set()={eng._state_events[EngineState.PAUSED].is_set()} for state={expected}" + assert eng._state_events[EngineState.SUSPENDED].is_set() == ( + expected == EngineState.SUSPENDED + ), f"SUSPENDED.is_set()={eng._state_events[EngineState.SUSPENDED].is_set()} for state={expected}" + assert eng._state_events[EngineState.STOPPED].is_set() == ( + expected == EngineState.STOPPED + ), f"STOPPED.is_set()={eng._state_events[EngineState.STOPPED].is_set()} for state={expected}" + + dp_addr = coordinator + port = int(dp_addr.rsplit(":", 1)[-1]) + requests = self.build_requests(num_requests=16) + engine = DummyEngine() + client = None + doomed_futures = [] + rank = torch.distributed.get_rank() try: - # Cancel engine1 loop without sending stop to coordinator - # This keeps coordinator process alive and holding the port - engine1.engine_loop_task.cancel() - try: - await engine1.engine_loop_task - except asyncio.CancelledError: - pass - - # Connect engine2 to existing coordinator (don't launch new one) - engine2 = DummyEngine() - second_addr = await engine2.start_listening_to_data_parallel_coordinator( - inference_coordinator_port=DEFAULT_PORT, launch_inference_coordinator=False + await engine.start_listening_to_data_parallel_coordinator( + inference_coordinator_port=port, launch_inference_coordinator=False ) - # Should connect to same port, but will not always in CI due to port conflicts. - first_port = int(first_addr.rsplit(":", 1)[-1]) - second_port = int(second_addr.rsplit(":", 1)[-1]) - # assert second_port == first_port - - # Cancel engine2 - engine2.engine_loop_task.cancel() - try: - await engine2.engine_loop_task - except asyncio.CancelledError: - pass - - # Launch new coordinator - should get different port since first is holding it - engine3 = DummyEngine() - third_addr = await engine3.start_listening_to_data_parallel_coordinator( - inference_coordinator_port=DEFAULT_PORT, launch_inference_coordinator=True + # Synchronize all ranks so every engine has registered. + await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), + timeout=30.0, ) - # Verify we got a different port due to conflict - third_port = int(third_addr.rsplit(":", 1)[-1]) - assert ( - third_port != first_port - ), f"Expected different port due to conflict, but got same: {third_port}" + if rank == 0: + client = InferenceClient(dp_addr) + client.start() - finally: - # Clean up engine3's coordinator - if engine3 is not None and third_addr is not None: - client3 = InferenceClient(third_addr) - await client3.start() - await asyncio.wait_for(client3.stop_engines(), timeout=10.0) - client3.stop() - try: - await asyncio.wait_for(engine3.engine_loop_task, timeout=30.0) - except asyncio.TimeoutError: - engine3.engine_loop_task.cancel() - - # Rebuild engine and reconnect to engine1's coordinator - first_port = int(first_addr.rsplit(":", 1)[-1]) - engine1 = DummyEngine() - await engine1.start_listening_to_data_parallel_coordinator( - inference_coordinator_port=first_port, launch_inference_coordinator=False - ) - client1 = InferenceClient(first_addr) - await client1.start() - await asyncio.wait_for(client1.stop_engines(), timeout=10.0) - client1.stop() - try: - await asyncio.wait_for(engine1.engine_loop_task, timeout=30.0) - except asyncio.TimeoutError: - engine1.engine_loop_task.cancel() + await asyncio.wait_for(engine.wait_until(EngineState.RUNNING), timeout=5.0) + assert_state(engine, EngineState.RUNNING) - @pytest.mark.internal - @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") - @pytest.mark.asyncio - async def test_pause(self, initialize_model_parallel): - """Test pause and resume functionality.""" - engine = DummyEngine() - requests = self.build_requests(num_requests=32) + # Try to submit signals out of FSM order. + # The coordinator's state machine filters these out. + client.suspend_engines() + await asyncio.sleep(0.1) + assert_state(engine, EngineState.RUNNING) + client.resume_engines() + await asyncio.sleep(0.1) + assert_state(engine, EngineState.RUNNING) + client.stop_engines() + await asyncio.sleep(0.1) + assert_state(engine, EngineState.RUNNING) - dp_addr = await engine.start_listening_to_data_parallel_coordinator( - inference_coordinator_port=DEFAULT_PORT, launch_inference_coordinator=True - ) + # Submit and complete requests while running. + futures = [client.add_request(prompt=p, sampling_params=s) for p, s in requests[:2]] + results = await asyncio.wait_for(asyncio.gather(*futures), timeout=5.0) + for record in results: + assert record[-1].status == Status.COMPLETED - success = True - try: - if torch.distributed.get_rank() == 0: - client = InferenceClient(dp_addr) - await client.start() + # Submit requests while RUNNING, then PAUSE before they drain. + # These must survive the PAUSE (not be drained during PAUSING). + pre_pause_futures = [ + client.add_request(prompt=p, sampling_params=s) for p, s in requests[2:3] + ] + client.pause_engines() + await asyncio.wait_for(engine.wait_until(EngineState.PAUSED), timeout=5.0) + assert_state(engine, EngineState.PAUSED) - # Submit requests and pause after completion. - futures = [client.add_request(prompt=p, sampling_params=s) for p, s in requests[:2]] + # Pre-pause requests must NOT have been drained. + done, pending = await asyncio.wait(pre_pause_futures, timeout=0.5) + assert len(pending) > 0, "Pre-pause requests should not drain during PAUSING" + + # Try pausing again and see if it breaks. + client.pause_engines() await asyncio.sleep(0.1) - awaitables = futures + [client.pause_engines()] - try: - await asyncio.wait_for(asyncio.gather(*awaitables), timeout=0.5) - except asyncio.TimeoutError: - pytest.fail("Pause operation timed out.") - - # Ensure that requests can be added while paused. - prompt, params = requests[2] - future = client.add_request(prompt=prompt, sampling_params=params) - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(future, timeout=0.1) - - # Resume and verify new requests complete. + assert_state(engine, EngineState.PAUSED) + + # Requests submitted while PAUSED should queue, not complete. + paused_futures = [ + client.add_request(prompt=p, sampling_params=s) for p, s in requests[3:5] + ] + # Use asyncio.wait (not wait_for) so futures aren't cancelled. + done, pending = await asyncio.wait(paused_futures, timeout=0.5) + assert len(done) == 0, "No requests should complete while paused" + assert len(pending) == 2 + + # UNPAUSE and verify all in-flight requests (pre-pause + paused) complete. client.unpause_engines() - # TODO: The system should not be incorrectly raising a cancelled error here. - with pytest.raises(asyncio.CancelledError): - await future + await asyncio.wait_for(engine.wait_until(EngineState.RUNNING), timeout=5.0) + all_queued = pre_pause_futures + paused_futures + results = await asyncio.wait_for(asyncio.gather(*all_queued), timeout=10.0) + for record in results: + assert record[-1].status == Status.COMPLETED + assert_state(engine, EngineState.RUNNING) + # Engine processes new requests normally after unpause. futures = [ - client.add_request(prompt=p, sampling_params=s) for p, s in requests[3:4] + client.add_request(prompt=p, sampling_params=s) for p, s in requests[5:7] ] + results = await asyncio.wait_for(asyncio.gather(*futures), timeout=5.0) + for record in results: + assert record[-1].status == Status.COMPLETED + + # Suspend. + client.pause_engines() + await asyncio.wait_for(engine.wait_until(EngineState.PAUSED), timeout=5.0) + assert_state(engine, EngineState.PAUSED) + + client.suspend_engines() + await asyncio.wait_for(engine.wait_until(EngineState.SUSPENDED), timeout=5.0) + assert_state(engine, EngineState.SUSPENDED) + + # Try pausing again and see if it breaks. + client.pause_engines() + await asyncio.sleep(0.1) + assert_state(engine, EngineState.SUSPENDED) + + # Try suspending again and see if it breaks. + client.pause_engines() await asyncio.sleep(0.1) - try: - await asyncio.wait_for(asyncio.gather(*futures), timeout=0.5) - except asyncio.TimeoutError: - pytest.fail("Resumed requests did not complete in time.") - except: - success = False + assert_state(engine, EngineState.SUSPENDED) + + # Resume. + client.resume_engines() + await asyncio.wait_for(engine.wait_until(EngineState.RESUMED), timeout=5.0) + assert_state(engine, EngineState.PAUSED) + assert not engine._state_events[EngineState.SUSPENDED].is_set() + + # Engine processes requests after suspend/resume cycle. + client.unpause_engines() + await asyncio.wait_for(engine.wait_until(EngineState.RUNNING), timeout=5.0) + + futures = [ + client.add_request(prompt=p, sampling_params=s) for p, s in requests[7:10] + ] + results = await asyncio.wait_for(asyncio.gather(*futures), timeout=5.0) + for record in results: + assert record[-1].status == Status.COMPLETED + + # Submit requests that will be cancelled on STOP. + client.pause_engines() + await asyncio.wait_for(engine.wait_until(EngineState.PAUSED), timeout=5.0) + assert_state(engine, EngineState.PAUSED) + + doomed_futures = [ + client.add_request(prompt=p, sampling_params=s) for p, s in requests[10:13] + ] + + # Synchronize all ranks before STOP. + await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), + timeout=30.0, + ) + + if rank == 0: + # Verify doomed futures are still pending. + for f in doomed_futures: + assert not f.done(), "Client futures should still be pending" + client.stop_engines() + + await asyncio.wait_for(engine.wait_until(EngineState.STOPPED), timeout=60.0) + assert_state(engine, EngineState.STOPPED) + finally: - try: - if torch.distributed.get_rank() == 0: - await asyncio.wait_for(client.stop_engines(), timeout=5.0) - client.stop() - await asyncio.wait_for(engine.engine_loop_task, timeout=30.0) - except asyncio.TimeoutError: - engine.engine_loop_task.cancel() - assert success, "Pause/resume test failed." + await cleanup_engine(engine, client) + + # cleanup_engine called client.stop() which cancels pending futures. + if torch.distributed.get_rank() == 0: + for f in doomed_futures: + assert f.cancelled(), "Client futures should be cancelled after client.stop()" @pytest.mark.internal @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") @pytest.mark.asyncio - async def test_throughput(self, initialize_model_parallel): - """Throughput test with no TP or PP.""" + async def test_throughput(self, initialize_model_parallel, coordinator): + """Throughput benchmark: measures ZMQ packet rate.""" + _, dp, _, _, _ = initialize_model_parallel num_requests = 10**4 num_iterations = 10 + dp_addr = coordinator + port = int(dp_addr.rsplit(":", 1)[-1]) engine = DummyEngine() requests = self.build_requests(num_requests=num_requests) - start_time = time.time() - dp_addr = await engine.start_listening_to_data_parallel_coordinator( - inference_coordinator_port=DEFAULT_PORT, launch_inference_coordinator=True + await engine.start_listening_to_data_parallel_coordinator( + inference_coordinator_port=port, launch_inference_coordinator=False ) + # Ensure all engines are registered before submitting requests. + await asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier) + + client = None try: if torch.distributed.get_rank() == 0: client = InferenceClient(dp_addr) - await client.start() - init_time = time.time() + client.start() + start_time = time.time() for _ in range(num_iterations): futures = [] - for prompt, sampling_params in tqdm(requests, "add_requests"): + for prompt, sampling_params in requests: fut = client.add_request(prompt=prompt, sampling_params=sampling_params) futures.append(fut) - await asyncio.wait_for(asyncio.gather(*futures), timeout=10.0) - done_time = time.time() + await asyncio.wait_for(asyncio.gather(*futures), timeout=30.0) + elapsed_ms = (time.time() - start_time) * 1e3 + total = num_requests * num_iterations // dp + print( + f"ZMQ throughput: {total / elapsed_ms:.2f} requests/ms " + f"({total} reqs in {elapsed_ms:.0f} ms)" + ) + await asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier) finally: - if torch.distributed.get_rank() == 0: - await asyncio.wait_for(client.stop_engines(), timeout=10.0) - client.stop() - try: - await asyncio.wait_for(engine.engine_loop_task, timeout=30.0) - except asyncio.TimeoutError: - engine.engine_loop_task.cancel() - - stop_time = time.time() - - flags = torch.tensor([1, 1, 1], dtype=torch.int, device=torch.cuda.current_device()) - - init_duration = golden_init_duration = None - run_duration = golden_run_duration = None - stop_duration = golden_stop_duration = None - - if torch.distributed.get_rank() == 0: - init_duration = (init_time - start_time) * 10**3 - golden_init_duration = 6974.43 # ms - run_duration = (done_time - init_time) * 10**3 - golden_run_duration = 4392.63 # ms - stop_duration = (stop_time - done_time) * 10**3 - golden_stop_duration = 931.49 # ms - - def clamp_to_golden_value(value, golden_value, delta=0.1): - return value > golden_value * (1 - delta) and value < golden_value * (1 + delta) - - if not clamp_to_golden_value(init_duration, golden_init_duration, delta=0.5): - flags[0] = 0 - if not clamp_to_golden_value(run_duration, golden_run_duration, delta=0.2): - flags[1] = 0 - if not clamp_to_golden_value(stop_duration, golden_stop_duration, delta=1.0): - flags[2] = 0 - - # Synchronize results - torch.distributed.broadcast(flags, src=0) - - if torch.distributed.get_rank() == 0: - print(f"Initialization time: {init_duration:.2f} ms") - print(f"Run time: {run_duration:.2f} ms") - print(f"Stop time: {stop_duration:.2f} ms") - - assert flags[0].item() == 1, ( - f"WARNING: Init duration {init_duration:.2f}s deviates from " - f"golden value {golden_init_duration:.2f}s" - ) - assert flags[1].item() == 1, ( - f"WARNING: Run duration {run_duration:.2f}s deviates from " - f"golden value {golden_run_duration:.2f}s" - ) - assert flags[2].item() == 1, ( - f"WARNING: Stop duration {stop_duration:.2f}s deviates from " - f"golden value {golden_stop_duration:.2f}s" - ) - - print( - f"ZMQ throughput is approximately " - f"{num_requests * num_iterations / run_duration:.2f} " - f"requests/ms" - ) - else: - assert flags[0].item() == 1 - assert flags[1].item() == 1 - assert flags[2].item() == 1 + await cleanup_engine(engine, client) diff --git a/tests/unit_tests/inference/test_dynamic_prefix_caching_coordinator.py b/tests/unit_tests/inference/test_dynamic_prefix_caching_coordinator.py index f2c9f97bb5b..e8d2bc728ef 100644 --- a/tests/unit_tests/inference/test_dynamic_prefix_caching_coordinator.py +++ b/tests/unit_tests/inference/test_dynamic_prefix_caching_coordinator.py @@ -216,6 +216,7 @@ def make_coordinator_direct( ) coordinator.hash_to_rank_info = {} + coordinator._round_robin_idx = 0 coordinator._assignment_counter = 0 return coordinator @@ -442,7 +443,7 @@ async def run_coordinator_test( try: if torch.distributed.get_rank() == 0: client = InferenceClient(dp_addr) - await client.start() + client.start() futures = [ client.add_request(prompt=prompt, sampling_params=params)