Draft
Changes from all commits
Commits
105 commits
0b64ce8
skeleton of inference moe layer done
sidsingh-nvidia Jan 14, 2026
da29281
restore
sidsingh-nvidia Jan 23, 2026
75cfef8
Merge branch 'main' into siddharth/inference-optimized-moe-layer
sidsingh-nvidia Jan 23, 2026
6e01116
match argument signature with training
sidsingh-nvidia Jan 23, 2026
153265b
support gpt models like qwen
sidsingh-nvidia Jan 26, 2026
7915cff
make torch grouped gemm work
sidsingh-nvidia Jan 28, 2026
8dd410d
add config constraints for single GPU only and make dtoh and sync a nu…
sidsingh-nvidia Jan 28, 2026
b8f5fe5
remove requirement for router fusion
sidsingh-nvidia Jan 29, 2026
5063fb2
confirm that this works with nccl all to alls
sidsingh-nvidia Jan 30, 2026
297f926
disable drop and pad for inference optimized, and propagate cuda grap…
sidsingh-nvidia Feb 2, 2026
629dc1f
confirm that all-gather dispatch runs within cuda graphs
sidsingh-nvidia Feb 2, 2026
21b9140
working
sidsingh-nvidia Feb 2, 2026
a786eda
replace permute/unpermute kernels with triton
sidsingh-nvidia Feb 2, 2026
f6ee32c
minor optimizations
sidsingh-nvidia Feb 2, 2026
3f7f39d
one round of optimizations
sidsingh-nvidia Feb 2, 2026
10da287
reduce kernel calls
sidsingh-nvidia Feb 2, 2026
1983688
symmetric memory AG for hidden states
sidsingh-nvidia Feb 2, 2026
02f315a
nvls all gathers for all three tensors. nvls rs on hidden state
sidsingh-nvidia Feb 2, 2026
0fac929
full model cg optimizations and bump up max blocks for blackwell
sidsingh-nvidia Feb 2, 2026
df930a2
Merge branch 'main' into inf-opt-all-gather-dispatcher
sidsingh-nvidia Feb 2, 2026
3606123
Merge remote-tracking branch 'origin/main' into inf-opt-all-gather-di…
sidsingh-nvidia Feb 4, 2026
371043c
fix full model CG for mamba
sidsingh-nvidia Feb 4, 2026
01cd40f
remove requirement for moe permute fusion
sidsingh-nvidia Feb 4, 2026
30d8cf3
failed attempt at optimizing router and permute
sidsingh-nvidia Feb 4, 2026
6cb8a8a
tested with qwen
sidsingh-nvidia Feb 4, 2026
b85e8fe
add cutlass kernel
sidsingh-nvidia Feb 5, 2026
98a4d9f
optimize dummy forwards
sidsingh-nvidia Feb 5, 2026
acbc841
bugfix in inference router
sidsingh-nvidia Feb 5, 2026
986e2a1
latest
sidsingh-nvidia Feb 9, 2026
bb8890d
return usage characteristics from text gen server
sidsingh-nvidia Feb 13, 2026
9d062c7
add vllm cg utils
sidsingh-nvidia Feb 16, 2026
7a834ae
print engine time in ms instead of seconds
sidsingh-nvidia Feb 16, 2026
8281d86
sleep 0
sidsingh-nvidia Feb 16, 2026
d4f00ca
add fused 3 tensor all gather
sidsingh-nvidia Feb 16, 2026
6251bb9
restore some delay in zmq asyncio
sidsingh-nvidia Feb 17, 2026
72674ee
faster dummy ep cg codepath
sidsingh-nvidia Feb 20, 2026
0b55c2d
format
sidsingh-nvidia Feb 21, 2026
cf40976
Merge branch 'main' into siddharth/optimize-dummy-ep-fwd-pass
sidsingh-nvidia Feb 23, 2026
b944043
Merge branch 'main' into siddharth/optimize-dummy-ep-fwd-pass
sidsingh-nvidia Feb 23, 2026
126d6c1
refactor + make safer
sidsingh-nvidia Feb 23, 2026
7a7d78d
relegate to strict matching for qwen
sidsingh-nvidia Feb 24, 2026
bf6f678
add unit test for ep syncs, and fix a bug in non strict matching
sidsingh-nvidia Feb 24, 2026
ceb83c2
Merge branch 'main' into siddharth/optimize-dummy-ep-fwd-pass
sidsingh-nvidia Feb 24, 2026
3f2de16
linting
sidsingh-nvidia Feb 24, 2026
05e872b
attempt to delete unnecessary modifications
sidsingh-nvidia Feb 24, 2026
71dffc9
Merge branch 'siddharth/optimize-dummy-ep-fwd-pass' into inf-opt-all-…
sidsingh-nvidia Feb 24, 2026
4d654c2
minor
sidsingh-nvidia Feb 24, 2026
afd2ad8
remove code from the dummy ep PR
sidsingh-nvidia Feb 25, 2026
bf8f546
restore utils.py
sidsingh-nvidia Feb 25, 2026
c4091bd
remove torch grouped gemm kernels: we will add them in another PR
sidsingh-nvidia Feb 25, 2026
7c2b2ff
remove mamba metadata changes
sidsingh-nvidia Feb 25, 2026
2631c1a
simplify hybrid spec call
sidsingh-nvidia Feb 25, 2026
27c0f7c
restore dynamic context
sidsingh-nvidia Feb 25, 2026
b055cb6
slight clean up of router
sidsingh-nvidia Feb 25, 2026
3f24597
router cleanup
sidsingh-nvidia Feb 25, 2026
1bbaf82
more router cleanup
sidsingh-nvidia Feb 25, 2026
5607f6f
absorb inference layer into parent moe layer and more cleanup
sidsingh-nvidia Feb 25, 2026
834656b
more cleanup
sidsingh-nvidia Feb 25, 2026
14d4540
Merge branch 'main' into inf-opt-all-gather-dispatcher
sidsingh-nvidia Feb 25, 2026
8163d18
fallback to NCCL as the triton collectives do not work for non 128-bi…
sidsingh-nvidia Feb 25, 2026
b89600b
remove changes related to symm mem comms
sidsingh-nvidia Feb 25, 2026
ba396b1
refactor
sidsingh-nvidia Feb 25, 2026
462ed8a
more cleanup and add warnings if flashinfer-jit and cubin are not ins…
sidsingh-nvidia Feb 25, 2026
e3311a0
more refactor
sidsingh-nvidia Feb 25, 2026
afb807b
make qwen3 work without CGs
sidsingh-nvidia Feb 26, 2026
51c383d
remove comment
sidsingh-nvidia Feb 26, 2026
c761a0d
refactor
sidsingh-nvidia Feb 26, 2026
d9f1712
Revert "fallback to NCCL as the triton collectives do not work for no…
sidsingh-nvidia Feb 26, 2026
3deb50c
Revert "remove changes related to symm mem comms"
sidsingh-nvidia Feb 26, 2026
7aff116
bring back NVLS collectives
sidsingh-nvidia Feb 26, 2026
2f8cf3e
add kill-switch for nvls
sidsingh-nvidia Feb 26, 2026
64bc241
cleanup nvls
sidsingh-nvidia Feb 26, 2026
81054b9
minor changes
sidsingh-nvidia Feb 26, 2026
6f5bf12
only set is inference cg iteration from the engine
sidsingh-nvidia Feb 26, 2026
e533f43
more cleanup
sidsingh-nvidia Feb 26, 2026
7f4bc32
resolve flashinfer activations, small bugfix, and no use flashinfer f…
sidsingh-nvidia Feb 27, 2026
76d7c83
kill switch for torch grouped gemm
sidsingh-nvidia Feb 27, 2026
5e77a87
change name of dispatcher
sidsingh-nvidia Feb 27, 2026
d35b58f
remove bias act function duplication
sidsingh-nvidia Feb 27, 2026
d22a5e1
change the name of cuda graph mixed prefill count to cuda graph mixed…
sidsingh-nvidia Feb 27, 2026
1670f7e
cleanup asserts, disable fused tp kernel for moe
sidsingh-nvidia Feb 27, 2026
cc3e18f
format
sidsingh-nvidia Feb 27, 2026
902dc69
fix linting issues
sidsingh-nvidia Feb 27, 2026
2547534
linting
sidsingh-nvidia Feb 27, 2026
165d6d4
refactor
sidsingh-nvidia Feb 27, 2026
5f517a7
Merge branch 'main' into inf-opt-all-gather-dispatcher
sidsingh-nvidia Feb 27, 2026
14c8a39
fix unit test failures
sidsingh-nvidia Feb 27, 2026
2fc86b2
unit test for inference top-k router
sidsingh-nvidia Feb 28, 2026
04672e4
minor changes
sidsingh-nvidia Feb 28, 2026
0f9f7a8
add warmup to router unit test
sidsingh-nvidia Feb 28, 2026
84a1134
format
sidsingh-nvidia Feb 28, 2026
db0f784
Merge branch 'main' into inf-opt-all-gather-dispatcher
sidsingh-nvidia Feb 28, 2026
f132d59
add error message to assert
sidsingh-nvidia Mar 2, 2026
264cbb2
address feedback
sidsingh-nvidia Mar 2, 2026
df9de35
use decorator for torch compile
sidsingh-nvidia Mar 2, 2026
ca75b2b
bugfix
sidsingh-nvidia Mar 2, 2026
a61aea5
lint
sidsingh-nvidia Mar 2, 2026
4fd23ce
format and guard properly
sidsingh-nvidia Mar 2, 2026
fa25b1b
Merge branch 'main' into inf-opt-all-gather-dispatcher
sidsingh-nvidia Mar 2, 2026
5ae4424
Merge branch 'main' into inf-opt-all-gather-dispatcher
sidsingh-nvidia Mar 2, 2026
01ad0b2
Merge branch 'main' into inf-opt-all-gather-dispatcher
sidsingh-nvidia Mar 2, 2026
b1530a6
fix comments
sidsingh-nvidia Mar 2, 2026
fa61633
working histogram kernel
sidsingh-nvidia Mar 2, 2026
86b438f
add exhaustive unit tests
sidsingh-nvidia Mar 2, 2026
b7821ac
work with qwen3 on hopper
sidsingh-nvidia Mar 3, 2026
1 change: 0 additions & 1 deletion gpt_builders.py
@@ -53,7 +53,6 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
)
)
elif args.num_experts:
assert not (config.transformer_impl == "inference_optimized")
# Define the decoder block spec
transformer_layer_spec = get_gpt_decoder_block_spec(
config,
26 changes: 13 additions & 13 deletions megatron/core/inference/batch_dimensions_utils.py
@@ -133,7 +133,7 @@ def adjust_batch_dims_for_expert_parallelism(
strict: bool,
decode_only_cuda_graphs: bool,
explicit_chunked_prefill: bool,
cuda_graph_mixed_prefill_count: int,
smallest_non_decode_cuda_graph_size: int,
ep_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Optional["InferenceBatchDimensions"]:
"""Adjusted cuda graph batch dimensions for expert parallelism.
@@ -176,9 +176,9 @@ def adjust_batch_dims_for_expert_parallelism(
is_any_ep_rank_in_non_decode = sync_tensor[1].item() == 1

# We force eager mode for scenarios where some ranks will run with CUDA graphs
# while others will not. Without this check, the all-to-all communication in the
# while others will not. Without this check, communication in the
# expert routing layer would pad up to the maximum capacity only for the ranks that
# are using CUDA graphs in this step, leading to a NCCL hang.
# are using CUDA graphs in this step, leading to a hang.
# This can happen in the following cases:
# 1. If we only allow decode CUDA graphs but some ranks are running non-decode batches
# 2. Some ranks are running explicit chunked prefill requests
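The eager-mode fallback described in the comment above can be sketched as follows. `should_force_eager` and its argument are hypothetical names for illustration, not the actual Megatron-Core API:

```python
def should_force_eager(rank_uses_cuda_graph):
    """Return True when CUDA graph usage is mixed across the EP group.

    If some ranks would replay a CUDA graph (padding up to graph capacity)
    while others run eagerly, the collective in the expert routing layer
    would be called with mismatched shapes and hang, so every rank falls
    back to eager execution instead.
    """
    return len(set(rank_uses_cuda_graph)) > 1
```

A decode-only rank next to a prefill rank is exactly the mixed case the comment lists, and it trips the check: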
@@ -203,7 +203,7 @@ def adjust_batch_dims_for_expert_parallelism(
# graph while prefill ranks match a coarser mixed graph, which would
# produce inconsistent token counts across EP ranks.
if is_any_ep_rank_in_non_decode and not strict:
adjusted_token_count = max(adjusted_token_count, cuda_graph_mixed_prefill_count)
adjusted_token_count = max(adjusted_token_count, smallest_non_decode_cuda_graph_size)

adjusted_batch_dim = InferenceBatchDimensions(
token_count=adjusted_token_count,
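The adjustment in this hunk can be sketched single-process as below; the function name and list-based "EP group" are illustrative stand-ins for the real all-reduce over `ep_group`:

```python
def adjust_token_count_for_ep(rank_token_counts, rank_is_non_decode,
                              smallest_non_decode_cuda_graph_size, strict):
    """Pick one padded token count that every EP rank agrees on.

    Step 1 mirrors the sync across the EP group: take the max token count
    seen on any rank. Step 2 mirrors the non-strict branch above: if any
    rank runs a non-decode batch, pad up to the smallest non-decode graph
    so decode and prefill ranks match graphs with identical token counts.
    """
    adjusted = max(rank_token_counts)
    if any(rank_is_non_decode) and not strict:
        adjusted = max(adjusted, smallest_non_decode_cuda_graph_size)
    return adjusted
```

Under strict matching the pad-up step is skipped, which is why strict mode can leave decode and prefill ranks on differently sized graphs.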
@@ -303,7 +303,7 @@ def generate_cuda_graph_batch_dimensions_list(
tp_size: int,
num_cuda_graphs: Optional[int],
cuda_graph_max_tokens: int,
cuda_graph_mixed_prefill_count: Optional[int],
cuda_graph_mixed_prefill_request_count: Optional[int],
max_requests: int,
max_tokens: int,
max_sequence_length: int,
@@ -339,7 +339,7 @@ def generate_cuda_graph_batch_dimensions_list(
tp_size: Tensor parallel size
num_cuda_graphs: Number of CUDA graphs to generate
cuda_graph_max_tokens: Maximum tokens for CUDA graphs
cuda_graph_mixed_prefill_count: Number of mixed prefill requests for CUDA graphs
cuda_graph_mixed_prefill_request_count: Number of mixed prefill requests for CUDA graphs
max_requests: Maximum number of requests
max_tokens: Maximum total tokens
max_sequence_length: Maximum sequence length
@@ -409,8 +409,8 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
if num_cuda_graphs is None:
cuda_graph_batch_dimensions_list = []
elif (
not cuda_graph_mixed_prefill_count
or cuda_graph_mixed_prefill_count <= 0
not cuda_graph_mixed_prefill_request_count
or cuda_graph_mixed_prefill_request_count <= 0
or not use_cuda_graphs_for_non_decode_steps
): # decode only
# Use decode-specific token counts for decode-only graphs
@@ -426,14 +426,14 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
for size in cuda_graph_prefill_token_counts:
add_if_valid(
token_count=size,
prefill_req_count=min(cuda_graph_mixed_prefill_count, max_requests),
prefill_req_count=min(cuda_graph_mixed_prefill_request_count, max_requests),
decode_req_count=min(size, max_requests)
- min(cuda_graph_mixed_prefill_count, max_requests),
- min(cuda_graph_mixed_prefill_request_count, max_requests),
)
# We need to ensure the prefill requests are shorter than the max sequence length,
# considering the one decode token is used for prefill request construction
prefill_only_minimal_num = max(
cuda_graph_mixed_prefill_count,
cuda_graph_mixed_prefill_request_count,
math.ceil(size / max(1, max_sequence_length - 1)),
)
if prefill_only_minimal_num < max_requests:
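The mixed-graph enumeration in this hunk can be sketched as follows. The function name is hypothetical; only the split between prefill and decode request slots is taken from the code above:

```python
def enumerate_mixed_graph_dims(prefill_token_counts,
                               mixed_prefill_request_count, max_requests):
    """Enumerate (token_count, prefill_reqs, decode_reqs) triples.

    Each mixed graph reserves a fixed number of prefill request slots;
    the remaining request slots (one token each) become decode requests,
    mirroring the add_if_valid call in the hunk above.
    """
    dims = []
    for tokens in prefill_token_counts:
        prefill_reqs = min(mixed_prefill_request_count, max_requests)
        decode_reqs = min(tokens, max_requests) - prefill_reqs
        if decode_reqs >= 0:
            dims.append((tokens, prefill_reqs, decode_reqs))
    return dims
```

A token budget too small to cover the reserved prefill slots yields no valid mixed graph, which matches the guard in `add_if_valid`.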
@@ -474,7 +474,7 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
def match_graph_config(
real_batch_dim: InferenceBatchDimensions,
cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions],
cuda_graph_mixed_prefill_count: int,
smallest_non_decode_cuda_graph_size: int,
strict: bool = False,
decode_only_cuda_graphs: bool = False,
explicit_chunked_prefill: bool = False,
@@ -509,7 +509,7 @@ def match_graph_config(
decode_only_cuda_graphs=decode_only_cuda_graphs,
explicit_chunked_prefill=explicit_chunked_prefill,
ep_group=ep_group,
cuda_graph_mixed_prefill_count=cuda_graph_mixed_prefill_count,
smallest_non_decode_cuda_graph_size=smallest_non_decode_cuda_graph_size,
)

if adjusted_batch_dim is None:
@@ -1,4 +1,5 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from .collectives import multimem_all_gather, multimem_reduce_scatter
from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter
from .fused_collectives import fused_multimem_rs_add_norm_ag
from .utils import are_tensors_nvls_eligible, is_device_nvls_capable
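The idea behind the new `multimem_all_gather_fused` export (per the "add fused 3 tensor all gather" commit) can be sketched abstractly: pack several tensors into one buffer so a single collective replaces one call per tensor. This toy version uses flat Python lists and a caller-supplied `all_gather` callable; the real kernel operates on NVLS multicast symmetric memory:

```python
def fused_all_gather(local_tensors, all_gather):
    """Pack local tensors into one buffer, gather once, then unpack.

    `all_gather` stands in for the real collective: it takes this rank's
    packed buffer and returns the concatenation of every rank's buffer.
    """
    sizes = [len(t) for t in local_tensors]
    packed = [x for t in local_tensors for x in t]
    gathered = all_gather(packed)          # one collective instead of three
    stride = sum(sizes)
    # Unpack: each rank's contiguous segment splits back into the tensors.
    out = [[] for _ in local_tensors]
    for base in range(0, len(gathered), stride):
        segment = gathered[base:base + stride]
        offset = 0
        for i, size in enumerate(sizes):
            out[i].extend(segment[offset:offset + size])
            offset += size
    return out
```

Fusing amortizes per-collective launch latency, which matters for the small decode-step tensors this PR targets.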