Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
ff01c9a
Fix the coordinator signaling logic
tdene Feb 19, 2026
b3e59f0
Fix coordinator shutdown
tdene Feb 20, 2026
f8c8025
Fix coordinator unit tests
tdene Feb 19, 2026
fc9d093
Address reviewer comments
tdene Feb 24, 2026
8d6eeee
Address reviewer comments
tdene Feb 24, 2026
9d0ff35
Address reviewer comments
tdene Feb 24, 2026
bb961a3
Cleanup
tdene Feb 24, 2026
a87ea0b
Address reviewer comments
tdene Feb 24, 2026
d0c19a9
Cleanup
tdene Feb 24, 2026
3fb373c
Fix cleanup
tdene Feb 24, 2026
291b820
Cleanup
tdene Feb 24, 2026
683312f
Cleanup
tdene Feb 24, 2026
526ccb2
Cleanup
tdene Feb 24, 2026
178680c
Correct shutdown
tdene Feb 24, 2026
8c426d7
Correct pause, no ACK
tdene Feb 24, 2026
2e2c8d5
lint
tdene Feb 25, 2026
e8d55b9
Fix cleanup
tdene Feb 25, 2026
68e3c6f
Fix cleanup
tdene Feb 25, 2026
5a3e1bb
Clean up test shutdown
tdene Feb 25, 2026
611de48
Cleanup
tdene Feb 26, 2026
b3605b8
Fix typo
tdene Feb 26, 2026
45b7436
tmp
tdene Feb 26, 2026
8081aff
Final cleanup
tdene Feb 26, 2026
a7f2dab
Cleanup tests
tdene Feb 26, 2026
aac10bb
Fix golden values
tdene Feb 26, 2026
690ab38
lint
tdene Feb 26, 2026
cee4c28
Fix CI unit test
tdene Feb 26, 2026
97d3343
Address reviewer comments
tdene Feb 26, 2026
136cfdc
Fix CI
tdene Feb 27, 2026
dbc7117
Address reviewer comments
tdene Feb 27, 2026
9ce666b
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Feb 27, 2026
6d3afc0
Remove immediate shutdown logic
tdene Feb 27, 2026
45ffc2b
Fix CI
tdene Feb 28, 2026
3eb054f
Address reviewer comments
tdene Feb 28, 2026
f8639d0
lint
tdene Feb 28, 2026
9a17764
Fix transposed if statement
tdene Feb 28, 2026
bdfe52c
Fix spurious CI hang
tdene Mar 1, 2026
13cf3ac
Add UNPAUSING state
tdene Mar 2, 2026
6e93217
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Mar 2, 2026
ee4f148
Add functional test
tdene Mar 2, 2026
e91e0e2
lint
tdene Mar 3, 2026
d025a9c
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Mar 3, 2026
c3a315e
Fix lack of barrier after resume
tdene Mar 4, 2026
78af3ca
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Mar 4, 2026
80c3f20
Fix merge
tdene Mar 4, 2026
e3ea13f
add suspend/resume to moe functional test
sidsingh-nvidia Mar 4, 2026
baa8b0c
Address reviewer comments
tdene Mar 4, 2026
e9d1070
Fix more merge conflicts
tdene Mar 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/inference/gpt/gpt_dynamic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ def _process_step_result(result):
break

# Resume engine (NOOP if not suspended).
if engine.is_suspended:
engine.resume()
engine.resume()

return {
"step_times": step_times,
Expand Down
88 changes: 51 additions & 37 deletions examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from examples.inference.gpt.utils import Request, build_dynamic_engine_setup_prefix, build_requests
from megatron.core.inference.engines import DynamicInferenceEngine
from megatron.core.inference.engines.dynamic_engine import EngineState
from megatron.core.inference.inference_client import InferenceClient
from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
from megatron.core.inference.sampling_params import SamplingParams
Expand All @@ -29,6 +30,22 @@
logging.basicConfig(level=logging.INFO, force=True)


async def suspend_resume_cycle(client, engine, args, futures):
    """Wait for all in-flight requests, then suspend/train/resume.

    Drains every outstanding request future first, then walks the engine
    through the full control sequence, blocking after each signal until
    the engine reports the corresponding state.
    """

    async def _signal_and_wait(send_signal, target_state):
        # Fire one control signal, then block until the engine reaches
        # the state that acknowledges it.
        send_signal()
        await engine.wait_until(target_state)

    # Let all in-flight requests complete before changing engine state.
    await asyncio.gather(*futures)

    await _signal_and_wait(client.pause_engines, EngineState.PAUSED)
    await _signal_and_wait(client.suspend_engines, EngineState.SUSPENDED)

    # Optional dwell while suspended (e.g. stands in for a training phase).
    if args.suspend_timeout > 0:
        await asyncio.sleep(args.suspend_timeout)

    await _signal_and_wait(client.resume_engines, EngineState.RESUMED)
    await _signal_and_wait(client.unpause_engines, EngineState.RUNNING)


async def main(
engine: DynamicInferenceEngine,
requests: List[Request],
Expand All @@ -54,33 +71,22 @@ async def main(
coordinator_schedule_output_path=args.coordinator_schedule_output_path,
)

# Test suspend/resume intervals.
if dist.get_rank() == 0 and args.suspend_resume_interval is not None:
# Since the client doesn't directly call engine.async_step here, we test
# the suspend-resume system ~4 times.
suspend_resume_interval = max(1, len(requests) // 4)
suspend_idxs = set(
range(suspend_resume_interval, len(requests) + 1, suspend_resume_interval)
)
resume_idxs = set(
min(len(requests), i + suspend_resume_interval // 2) for i in suspend_idxs
)
else:
suspend_idxs = set()
resume_idxs = set()
# All ranks agree on the number of suspend/resume cycles from args.
num_suspend_resume_cycles = len(requests) // args.suspend_resume_interval if args.suspend_resume_interval else 0

# Create client and run example.
if dist.get_rank() == 0:
client = InferenceClient(dp_addr) # submits requests to the inference coordinator
await client.start()
client.start()
base_arrival_time = time.time_ns() / 10**9
for request in requests:
request.time_arrival = request.time_offset + base_arrival_time
futures = []
num_requests_total = len(requests)
num_requests_added = 0
# logging.info("Waiting for 20 seconds before starting to add requests. This is to mimic an RL style setup..")
# time.sleep(20)
next_suspend_at = args.suspend_resume_interval or 0
cycles_done = 0

while True:
current_time = time.time_ns() / 10**9
if args.incoming_requests_per_step is None:
Expand All @@ -96,11 +102,10 @@ async def main(
futures.append(client.add_request(request.prompt_text, request.sampling_params))
num_requests_added += 1

# Test suspend/resume.
if num_requests_added in suspend_idxs:
client.suspend_engines()
if num_requests_added in resume_idxs:
client.resume_engines()
if num_requests_added >= next_suspend_at and cycles_done < num_suspend_resume_cycles:
await suspend_resume_cycle(client, engine, args, futures)
cycles_done += 1
next_suspend_at += args.suspend_resume_interval

else:
# Add deterministic number of requests (generally used for debugging).
Expand All @@ -114,11 +119,10 @@ async def main(
futures.append(client.add_request(request.prompt_text, request.sampling_params))
num_requests_added += 1

# Test suspend/resume.
if num_requests_added in suspend_idxs:
client.suspend_engines()
if num_requests_added in resume_idxs:
client.resume_engines()
if num_requests_added >= next_suspend_at and cycles_done < num_suspend_resume_cycles:
await suspend_resume_cycle(client, engine, args, futures)
cycles_done += 1
next_suspend_at += args.suspend_resume_interval

if num_requests_added == num_requests_total:
break
Expand All @@ -127,6 +131,13 @@ async def main(

# While we wait for the requests to complete, the engine runs in the background.
results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)
else:
# Non-rank-0: match the suspend/resume cycles that rank 0 drives.
for _ in range(num_suspend_resume_cycles):
await engine.wait_until(EngineState.PAUSED)
await engine.wait_until(EngineState.SUSPENDED)
await engine.wait_until(EngineState.RESUMED)
await engine.wait_until(EngineState.RUNNING)

if dist.get_rank() == 0:
# Write results to JSON. Primarily used for functional testing.
Expand Down Expand Up @@ -173,14 +184,19 @@ async def main(
)
)

# kill the engines and suspend the client
# Right now, we can only call stop when all requests are done.
# Todo: Make this explicit in the Client class....
await client.stop_engines()
client.stop()
# Pause before stopping: STOP requires PAUSED or SUSPENDED state.
client.pause_engines()

await engine.wait_until(EngineState.PAUSED)

if dist.get_rank() == 0:
client.stop_engines()

# once the stop signal eventually makes its way to each GPU, the engines will stop.
await asyncio.gather(engine.engine_loop_task)
await engine.wait_until(EngineState.STOPPED)

if dist.get_rank() == 0:
client.shutdown_coordinator()
client.stop()
logging.info(f"Rank: {dist.get_rank()} stopped their engine instance successfully.")


Expand Down Expand Up @@ -210,9 +226,7 @@ async def main(

model = get_model_for_inference()

requests = (
build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
)
requests = build_requests(args, tokenizer, sampling_params)

engine = get_dynamic_inference_engine(model=model)

Expand Down
Loading